# Copyright 2014-2025 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""Backend for ASTRA using CUDA."""
import warnings
from multiprocessing import Lock
import numpy as np
from packaging.version import parse as parse_version
from odl.core.discr import DiscretizedSpace
from odl.applications.tomo.backends.astra_setup import (
ASTRA_VERSION, astra_projection_geometry,
astra_projector, astra_supports, astra_versions_supporting,
astra_volume_geometry)
from odl.applications.tomo.backends.util import _add_default_complex_impl
from odl.applications.tomo.geometry import (
ConeBeamGeometry, FanBeamGeometry, Geometry, Parallel2dGeometry,
Parallel3dAxisGeometry)
from odl.core.discr.discr_space import DiscretizedSpaceElement
from odl.core.array_API_support import empty, get_array_and_backend
# Public names exported by this backend module.
__all__ = ('ASTRA_CUDA_AVAILABLE',)

# Detect whether ASTRA can be imported and reports CUDA support.
try:
    import astra
    # ``astra.experimental`` (which provides ``direct_FP3D`` and
    # ``direct_BP3D`` used below) is a submodule that ``import astra`` does
    # not make visible by itself, so it is imported explicitly even though
    # the bare name is never referenced.
    import astra.experimental
    ASTRA_CUDA_AVAILABLE = astra.astra.use_cuda()
except ImportError:
    # ASTRA (or its ``experimental`` submodule) is not installed.
    ASTRA_CUDA_AVAILABLE = False
[docs]
def index_of_cuda_device(device: "torch.device"):
    """Return the integer index of a CUDA device, or ``None``.

    Only the string form of ``device`` is inspected, so both plain strings
    (``'cpu'``, ``'cuda'``, ``'cuda:1'``) and ``torch.device`` objects are
    accepted.

    Parameters
    ----------
    device : str or `torch.device`
        Device specification.

    Returns
    -------
    index : int or None
        The explicit index of a ``'cuda:<n>'`` device. ``None`` for
        non-CUDA devices (e.g. ``'cpu'``) and for ``'cuda'`` without an
        explicit index; in those cases callers should leave the current
        ASTRA GPU selection unchanged.

    Examples
    --------
    >>> index_of_cuda_device('cpu') is None
    True
    >>> index_of_cuda_device('cuda:1')
    1
    >>> index_of_cuda_device('cuda') is None
    True
    """
    # Comparing a ``torch.device`` against the string 'cpu' is not reliable,
    # and 'cuda' has no index suffix; parse the string form instead.
    kind, _, index = str(device).partition(':')
    if kind != 'cuda' or not index:
        return None
    return int(index)
[docs]
class AstraCudaImpl:
    """`RayTransform` implementation for CUDA algorithms in ASTRA.

    Forward and back-projections are computed with
    ``astra.experimental.direct_FP3D`` / ``direct_BP3D``, which operate
    directly on contiguous array buffers. 2D geometries go through the same
    3D code path: a leading singleton axis is added to the data before the
    ASTRA call and removed afterwards.
    """

    # ASTRA projector id; ``None`` until `create_ids` has run.
    projector_id = None

    # Warning emitted when an ``out`` argument is used with the PyTorch
    # backend (the result has to be cloned before it is written into ``out``).
    _PYTORCH_OUT_WARNING = (
        "You requested an out-of-place transform with PyTorch."
        " This will require cloning the data and will allocate extra memory"
    )

    def __init__(self, geometry, vol_space, proj_space):
        """Initialize a new instance.

        Parameters
        ----------
        geometry : `Geometry`
            Geometry defining the tomographic setup.
        vol_space : `DiscretizedSpace`
            Reconstruction space, the space of the images to be forward
            projected.
        proj_space : `DiscretizedSpace`
            Projection space, the space of the result.

        Raises
        ------
        TypeError
            If an argument does not have the expected type.
        NotImplementedError
            For 3D geometries whose spaces use a backend other than
            ``'numpy'`` or ``'pytorch'``.
        """
        if not isinstance(geometry, Geometry):
            raise TypeError(f"`geometry` must be a `Geometry` instance, got {geometry}")
        if not isinstance(vol_space, DiscretizedSpace):
            raise TypeError(
                f"`vol_space` must be a `DiscretizedSpace` instance, got {vol_space}"
            )
        if not isinstance(proj_space, DiscretizedSpace):
            raise TypeError(
                f"`proj_space` must be a `DiscretizedSpace` instance, got {proj_space}"
            )
        assert (
            vol_space.impl == proj_space.impl
        ), f"Volume space ({vol_space.impl}) != Projection space ({proj_space.impl})"

        # Print a warning if the detector midpoint normal vector at any
        # angle is perpendicular to the geometry axis in parallel 3d
        # single-axis geometry -- this is broken in some ASTRA versions
        if (
            isinstance(geometry, Parallel3dAxisGeometry)
            and not astra_supports('par3d_det_mid_pt_perp_to_axis')
        ):
            req_ver = astra_versions_supporting('par3d_det_mid_pt_perp_to_axis')
            axis = geometry.axis
            mid_pt = geometry.det_params.mid_pt
            for i, angle in enumerate(geometry.angles):
                det_normal = geometry.det_to_src(angle, mid_pt)
                if abs(np.dot(axis, det_normal)) < 1e-4:
                    warnings.warn(
                        f"angle {i}: detector midpoint normal {det_normal}"
                        + f" is perpendicular to the geometry axis {axis} in `Parallel3dAxisGeometry`;"
                        + f" this is broken in ASTRA {ASTRA_VERSION}, please upgrade to ASTRA {req_ver}",
                        RuntimeWarning,
                    )
                    # One warning is enough
                    break

        self.geometry = geometry
        self._vol_space = vol_space
        self._proj_space = proj_space

        # Axis swap between ASTRA's 3D projection layout and ODL's layout
        # (axes 0 and 1 are exchanged). NumPy's ``transpose`` takes a full
        # permutation, PyTorch's takes the two axes to swap.
        self.transpose_tuple = None
        if self.geometry.ndim == 3:
            if vol_space.impl == 'numpy':
                self.transpose_tuple = (1, 0, 2)
            elif vol_space.impl == 'pytorch':
                self.transpose_tuple = (1, 0)
            else:
                raise NotImplementedError("Not implemented for another backend")

        # ASTRA projectors are not thread-safe, thus we need to lock manually
        self._mutex = Lock()

        self.fp_scaling_factor = astra_cuda_fp_scaling_factor(self.geometry)
        self.bp_scaling_factor = astra_cuda_bp_scaling_factor(
            self.proj_space, self.vol_space, self.geometry)

        # Allocate the ASTRA projector last, so that a failing check above
        # does not leave an orphaned projector behind.
        self.create_ids()

    def __del__(self):
        """Release the ASTRA projector owned by this instance, if any."""
        projector_id = self.projector_id
        if projector_id is None:
            return
        self.projector_id = None
        try:
            # 'cuda3d' projectors are managed by ASTRA's 3D projector manager
            astra.projector3d.delete(projector_id)
        except Exception:
            # A finalizer must not raise, e.g. at interpreter shutdown when
            # the ``astra`` module may already have been torn down.
            pass

    @property
    def vol_space(self):
        """Volume Space of the Ray Transform"""
        return self._vol_space

    @property
    def proj_space(self):
        """Projection Space of the Ray Transform"""
        return self._proj_space

    def create_ids(self):
        """Create ASTRA objects.

        Sets ``proj_ndim`` (number of axes of the projection data, with
        2- or 3-dimensional motion parameters flattened into one axis), the
        ASTRA geometries ``vol_geom`` / ``proj_geom`` and the ``'cuda3d'``
        projector ``projector_id``. 2D geometries also use the ``'cuda3d'``
        projector, via ``override_2D``.
        """
        if self.geometry.motion_partition.ndim == 1:
            motion_shape = self.geometry.motion_partition.shape
        else:
            # Need to flatten 2- or 3-dimensional angles into one axis
            motion_shape = (int(np.prod(self.geometry.motion_partition.shape)),)
        proj_shape = motion_shape + self.geometry.det_partition.shape
        self.proj_ndim = len(proj_shape)

        # Create ASTRA data structures
        self.vol_geom = astra_volume_geometry(self.vol_space, 'cuda')
        self.proj_geom = astra_projection_geometry(self.geometry, 'cuda')
        self.projector_id = astra_projector(
            astra_proj_type='cuda3d',
            astra_vol_geom=self.vol_geom,
            astra_proj_geom=self.proj_geom,
            ndim=3,
            override_2D=bool(self.geometry.ndim == 2),
        )

    @staticmethod
    def _select_astra_gpu(space):
        """Point ASTRA at the CUDA device holding ``space``'s data.

        Only done for the ``'pytorch'`` backend, and only when the device
        carries an explicit CUDA index.
        """
        if space.impl != 'pytorch':
            return
        device_index = index_of_cuda_device(space.tspace.device)  # type:ignore
        if device_index is not None:
            astra.set_gpu_index(device_index)

    @_add_default_complex_impl
    def call_forward(self, x, out=None, **kwargs):
        """Run an ASTRA forward projection on the given data using the GPU.

        Parameters
        ----------
        x : ``vol_space`` element
            Volume data to which the projector is applied.
        out : ``proj_space`` element, optional
            Element of the projection space to which the result is written. If
            ``None``, an element in `proj_space` is created.

        Returns
        -------
        out : ``proj_space`` element
            Projection data resulting from the application of the projector.
            If ``out`` was provided, the returned object is a reference to it.
        """
        return self._call_forward_real(x, out, **kwargs)

    def _call_forward_real(self, vol_data: DiscretizedSpaceElement, out=None, **kwargs):
        """Run an ASTRA forward projection on the given data using the GPU.

        Parameters
        ----------
        vol_data : ``vol_space.real_space`` element
            Volume data to which the projector is applied. Although
            ``vol_space`` may be complex, this element needs to be real.
        out : ``proj_space.real_space`` element, optional
            Element of the projection space to which the result is written. If
            ``None``, an element in `proj_space` is created.

        Returns
        -------
        out : ``proj_space`` element
            Projection data resulting from the application of the projector.
            If ``out`` was provided, the returned object is a reference to it.
        """
        with self._mutex:
            assert vol_data in self.vol_space.real_space
            if out is not None:
                assert out in self.proj_space.real_space, (
                    f"The out argument provided is a {type(out)}, which is not an element"
                    f" of the projection space {self.proj_space.real_space}"
                )
                if self.vol_space.impl == 'pytorch':
                    warnings.warn(self._PYTORCH_OUT_WARNING, RuntimeWarning)
                proj_data = out.data[None] if self.proj_ndim == 2 else out.data
                if self.geometry.ndim == 3:
                    proj_data = proj_data.transpose(*self.transpose_tuple)
            else:
                proj_data = empty(
                    impl=self.proj_space.impl,
                    shape=astra.geom_size(self.proj_geom),
                    dtype=self.proj_space.dtype,
                    device=self.proj_space.device,
                )

            # 2D data gets a leading singleton axis for the 3D ASTRA call
            if self.proj_ndim == 2:
                volume_data = vol_data.data[None]
            elif self.proj_ndim == 3:
                volume_data = vol_data.data
            else:
                raise NotImplementedError

            # Ensure data is contiguous otherwise astra will throw an error
            volume_data, _ = get_array_and_backend(volume_data, must_be_contiguous=True)
            proj_data, _ = get_array_and_backend(proj_data, must_be_contiguous=True)

            self._select_astra_gpu(self.proj_space)
            astra.experimental.direct_FP3D(  # type:ignore
                self.projector_id, volume_data, proj_data
            )
            proj_data *= self.fp_scaling_factor

            # Back to ODL layout: drop the singleton axis (2D) or swap axes (3D)
            proj_data = (
                proj_data[0]
                if self.geometry.ndim == 2
                else proj_data.transpose(*self.transpose_tuple)
            )
            if out is not None:
                # Cloning avoids overlapping-memory writes with PyTorch
                out.data[:] = proj_data if self.proj_space.impl == 'numpy' else proj_data.clone()
                return out
            return self.proj_space.element(proj_data)

    @_add_default_complex_impl
    def call_backward(self, x, out=None, **kwargs):
        """Run an ASTRA back-projection on the given data using the GPU.

        Parameters
        ----------
        x : ``proj_space`` element
            Projection data to which the back-projector is applied.
        out : ``vol_space`` element, optional
            Element of the reconstruction space to which the result is written.
            If ``None``, an element in ``vol_space`` is created.

        Returns
        -------
        out : ``vol_space`` element
            Reconstruction data resulting from the application of the
            back-projector. If ``out`` was provided, the returned object is a
            reference to it.
        """
        return self._call_backward_real(x, out, **kwargs)

    def _call_backward_real(self, proj_data: DiscretizedSpaceElement, out=None, **kwargs):
        """Run an ASTRA back-projection on the given data using the GPU.

        Parameters
        ----------
        proj_data : ``proj_space.real_space`` element
            Projection data to which the back-projector is applied. Although
            ``proj_space`` may be complex, this element needs to be real.
        out : ``vol_space.real_space`` element, optional
            Element of the reconstruction space to which the result is written.
            If ``None``, an element in ``vol_space`` is created.

        Returns
        -------
        out : ``vol_space`` element
            Reconstruction data resulting from the application of the
            back-projector. If ``out`` was provided, the returned object is a
            reference to it.
        """
        with self._mutex:
            assert proj_data in self.proj_space.real_space
            if out is not None:
                assert out in self.vol_space.real_space, (
                    f"The out argument provided is a {type(out)}, which is not an element"
                    f" of the volume space {self.vol_space.real_space}"
                )
                if self.vol_space.impl == 'pytorch':
                    warnings.warn(self._PYTORCH_OUT_WARNING, RuntimeWarning)
                volume_data = out.data[None] if self.geometry.ndim == 2 else out.data
            else:
                volume_data = empty(
                    impl=self.vol_space.impl,
                    shape=astra.geom_size(self.vol_geom),
                    dtype=self.vol_space.dtype,
                    device=self.vol_space.device,
                )

            # Bring the projection data into ASTRA's layout
            if self.proj_ndim == 2:
                proj_array = proj_data.data[None]
            elif self.proj_ndim == 3:
                proj_array = proj_data.data.transpose(*self.transpose_tuple)
            else:
                raise NotImplementedError

            # Ensure data is contiguous otherwise astra will throw an error
            volume_data, _ = get_array_and_backend(volume_data, must_be_contiguous=True)
            proj_array, _ = get_array_and_backend(proj_array, must_be_contiguous=True)

            self._select_astra_gpu(self.vol_space)
            astra.experimental.direct_BP3D(  # type:ignore
                self.projector_id, volume_data, proj_array
            )
            volume_data *= self.bp_scaling_factor
            volume_data = volume_data[0] if self.geometry.ndim == 2 else volume_data

            if out is not None:
                # Cloning avoids overlapping-memory writes with PyTorch
                out[:] = volume_data if self.vol_space.impl == 'numpy' else volume_data.clone()
                return out
            return self.vol_space.element(volume_data)
[docs]
def astra_cuda_fp_scaling_factor(geometry):
    """Return the factor that rescales ASTRA's CUDA forward projection.

    ASTRA defines the adjoint operator in terms of a fully discrete
    setting (transposed "projection matrix") without any relation to
    physical dimensions, which makes a re-scaling necessary to
    translate it to spaces with physical dimensions.

    Behavior of ASTRA changes slightly between versions, so we keep
    track of it and adapt the scaling accordingly.

    Parameters
    ----------
    geometry : `Geometry`
        Acquisition geometry of the ray transform.

    Returns
    -------
    factor : float or int
        ``1 / geometry.det_partition.cell_volume`` for a
        `Parallel2dGeometry` combined with an ASTRA version older than
        ``1.9.9.dev``, and ``1`` otherwise.
    """
    if not isinstance(geometry, Parallel2dGeometry):
        return 1
    if parse_version(ASTRA_VERSION) >= parse_version('1.9.9.dev'):
        return 1
    # Older ASTRA versions scale the parallel 2D forward projection with the
    # detector pixel stride; divide it out again.
    det_cell_volume = float(geometry.det_partition.cell_volume)
    return 1 / det_cell_volume
[docs]
def astra_cuda_bp_scaling_factor(proj_space, vol_space, geometry):
    """Volume scaling accounting for differing adjoint definitions.

    ASTRA defines the adjoint operator in terms of a fully discrete
    setting (transposed "projection matrix") without any relation to
    physical dimensions, which makes a re-scaling necessary to
    translate it to spaces with physical dimensions.

    Behavior of ASTRA changes slightly between versions, so we keep
    track of it and adapt the scaling accordingly.

    Parameters
    ----------
    proj_space : `DiscretizedSpace`
        Projection space. Its ``partition`` (extent and size) and
        ``weighting.const`` enter the factor.
    vol_space : `DiscretizedSpace`
        Volume space. Its ``weighting.const`` and ``cell_volume`` enter the
        factor, and for ASTRA versions before ``1.8rc1`` also
        ``cell_sides[0]``.
    geometry : `Geometry`
        Acquisition geometry. Its ``motion_partition`` always enters the
        factor. Depending on ASTRA version and geometry type,
        ``det_partition``, ``src_radius`` and ``det_radius`` are also used.

    Returns
    -------
    scaling_factor : float-like scalar
        Factor by which the raw ASTRA back-projection is multiplied.
    """
    # Angular integration weighting factor
    # angle interval weight by approximate cell volume
    angle_extent = geometry.motion_partition.extent
    num_angles = geometry.motion_partition.shape
    # TODO: this gives the wrong factor for Parallel3dEulerGeometry with
    # 2 angles
    scaling_factor = (angle_extent / num_angles).prod()

    # Correct in case of non-weighted spaces.
    # ``proj_weighting`` is the mean cell volume of the projection partition
    # (total extent divided by number of cells).
    proj_extent = float(proj_space.partition.extent.prod())
    proj_size = float(proj_space.partition.size)
    proj_weighting = proj_extent / proj_size
    scaling_factor *= proj_space.weighting.const / proj_weighting
    scaling_factor /= vol_space.weighting.const / vol_space.cell_volume

    # Version-specific corrections. Geometries not matched by any branch
    # below (e.g. curved detectors with ``det_curvature_radius`` set in the
    # pre-1.9.9 branches) receive no additional correction.
    if parse_version(ASTRA_VERSION) < parse_version('1.8rc1'):
        # Scaling for the old, pre-1.8 behaviour
        if isinstance(geometry, Parallel2dGeometry):
            # Scales with 1 / cell_volume
            scaling_factor *= float(vol_space.cell_volume)
        elif (isinstance(geometry, FanBeamGeometry)
              and geometry.det_curvature_radius is None):
            # Scales with 1 / cell_volume
            scaling_factor *= float(vol_space.cell_volume)
            # Additional magnification correction
            src_radius = geometry.src_radius
            det_radius = geometry.det_radius
            scaling_factor *= ((src_radius + det_radius) / src_radius)
        elif isinstance(geometry, Parallel3dAxisGeometry):
            # Scales with voxel stride
            # In 1.7, only cubic voxels are supported
            voxel_stride = vol_space.cell_sides[0]
            scaling_factor /= float(voxel_stride)
        elif (isinstance(geometry, ConeBeamGeometry)
              and geometry.det_curvature_radius is None):
            # Scales with 1 / cell_volume
            # In 1.7, only cubic voxels are supported
            voxel_stride = vol_space.cell_sides[0]
            scaling_factor /= float(voxel_stride)
            # Magnification correction
            src_radius = geometry.src_radius
            det_radius = geometry.det_radius
            scaling_factor *= ((src_radius + det_radius) / src_radius) ** 2
    elif parse_version(ASTRA_VERSION) < parse_version('1.9.0dev'):
        # Scaling for the 1.8.x releases
        if isinstance(geometry, Parallel2dGeometry):
            # Scales with 1 / cell_volume
            scaling_factor *= float(vol_space.cell_volume)
        elif (isinstance(geometry, FanBeamGeometry)
              and geometry.det_curvature_radius is None):
            # Scales with 1 / cell_volume
            scaling_factor *= float(vol_space.cell_volume)
            # Magnification correction
            src_radius = geometry.src_radius
            det_radius = geometry.det_radius
            scaling_factor *= ((src_radius + det_radius) / src_radius)
        elif isinstance(geometry, Parallel3dAxisGeometry):
            # Scales with cell volume
            # currently only square voxels are supported
            scaling_factor /= vol_space.cell_volume
        elif (isinstance(geometry, ConeBeamGeometry)
              and geometry.det_curvature_radius is None):
            # Scales with cell volume
            scaling_factor /= vol_space.cell_volume
            # Magnification correction (scaling = 1 / magnification ** 2)
            src_radius = geometry.src_radius
            det_radius = geometry.det_radius
            scaling_factor *= ((src_radius + det_radius) / src_radius) ** 2
            # Correction for scaled 1/r^2 factor in ASTRA's density weighting.
            # This compensates for scaled voxels and pixels, as well as a
            # missing factor src_radius ** 2 in the ASTRA BP with
            # density weighting.
            det_px_area = geometry.det_partition.cell_volume
            scaling_factor *= (
                src_radius ** 2 * det_px_area ** 2 / vol_space.cell_volume ** 2
            )
    elif parse_version(ASTRA_VERSION) < parse_version('1.9.9.dev'):
        # Scaling for intermediate dev releases between 1.8.3 and 1.9.9.dev.
        # Identical to the 1.8.x branch above except for the final cone-beam
        # density-weighting correction, which omits the division by
        # ``vol_space.cell_volume ** 2``.
        if isinstance(geometry, Parallel2dGeometry):
            # Scales with 1 / cell_volume
            scaling_factor *= float(vol_space.cell_volume)
        elif (isinstance(geometry, FanBeamGeometry)
              and geometry.det_curvature_radius is None):
            # Scales with 1 / cell_volume
            scaling_factor *= float(vol_space.cell_volume)
            # Magnification correction
            src_radius = geometry.src_radius
            det_radius = geometry.det_radius
            scaling_factor *= ((src_radius + det_radius) / src_radius)
        elif isinstance(geometry, Parallel3dAxisGeometry):
            # Scales with cell volume
            # currently only square voxels are supported
            scaling_factor /= vol_space.cell_volume
        elif (isinstance(geometry, ConeBeamGeometry)
              and geometry.det_curvature_radius is None):
            # Scales with cell volume
            scaling_factor /= vol_space.cell_volume
            # Magnification correction (scaling = 1 / magnification ** 2)
            src_radius = geometry.src_radius
            det_radius = geometry.det_radius
            scaling_factor *= ((src_radius + det_radius) / src_radius) ** 2
            # Correction for scaled 1/r^2 factor in ASTRA's density weighting.
            # This compensates for scaled voxels and pixels, as well as a
            # missing factor src_radius ** 2 in the ASTRA BP with
            # density weighting.
            det_px_area = geometry.det_partition.cell_volume
            scaling_factor *= (src_radius ** 2 * det_px_area ** 2)
    else:
        # Scaling for versions since 1.9.9.dev: the same correction for all
        # geometry types, the ratio of detector cell volume to volume cell
        # volume.
        scaling_factor /= float(vol_space.cell_volume)
        scaling_factor *= float(geometry.det_partition.cell_volume)
    return scaling_factor
if __name__ == '__main__':
    # Executing this module directly runs the doctests it contains.
    from odl.core.util.testutils import run_doctests as _run_doctests
    _run_doctests()