Cone Beam Iterative Reconstruction

This example demonstrates gradient-based iterative reconstruction for 3D cone beam CT using the differentiable ConeProjectorFunction from diffct.

Overview

3D cone beam iterative reconstruction extends optimization methods to full volumetric reconstruction. This example shows how to:

  • Formulate 3D cone beam CT reconstruction as a large-scale optimization problem

  • Handle the computational complexity of 3D forward and backward projections

  • Apply memory-efficient optimization strategies for volumetric data

  • Monitor convergence in high-dimensional parameter space

Mathematical Background

3D Cone Beam Iterative Formulation

The 3D reconstruction problem is formulated as:

\[\hat{f} = \arg\min_f \|A_{\text{cone}}(f) - p\|_2^2 + \lambda R(f)\]

where: - \(f(x,y,z)\) is the unknown 3D volume - \(A_{\text{cone}}\) is the cone beam forward projection operator - \(p(\phi, u, v)\) is the measured 2D projection data - \(R(f)\) is an optional 3D regularization term

3D Forward Projection

The cone beam forward projection integrates along rays through the 3D volume:

\[p(\phi, u, v) = \int_0^{\infty} f\left(\vec{r}_s(\phi) + t \cdot \vec{d}(\phi, u, v)\right) dt\]

where \(\vec{r}_s(\phi)\) is the source position and \(\vec{d}(\phi, u, v)\) is the ray direction vector.

3D Gradient Computation

The gradient of the 3D loss function uses the cone beam backprojection operator:

\[\frac{\partial L}{\partial f} = 2A_{\text{cone}}^T(A_{\text{cone}}(f) - p_{\text{measured}})\]

where \(A_{\text{cone}}^T\) is the 3D cone beam backprojection operator (adjoint).

Computational Complexity

3D reconstruction presents significant computational challenges:

  • Memory Requirements: \(O(N^3)\) for volume storage vs \(O(N^2)\) for 2D images

  • Projection Data: \(O(N_{\phi} \times N_u \times N_v)\) 2D projections

  • Forward/Backward Operations: \(O(N^3 \times N_{\phi})\) computational complexity

  • Gradient Storage: Additional memory for automatic differentiation

Implementation Steps

  1. 3D Problem Setup: Define parameterized 3D volume as learnable tensor

  2. Cone Beam Forward Model: Use ConeProjectorFunction for 2D projection prediction

  3. Loss Computation: Calculate L2 distance between predicted and measured projections

  4. 3D Gradient Computation: Use automatic differentiation through cone beam operators

  5. Memory-Efficient Optimization: Apply strategies to handle large 3D parameter space

  6. Convergence Monitoring: Track loss and 3D reconstruction quality

Model Architecture

The 3D reconstruction model consists of:

  • Parameterized Volume: Learnable 3D tensor representing the unknown volume

  • Cone Beam Forward Model: ConeProjectorFunction with 3D geometry parameters

  • Loss Function: Mean squared error between predicted and measured 2D projections

3D Regularization Options

Common 3D regularization terms:

  1. 3D Total Variation: \(R_{\text{TV}}(f) = \sum_{x,y,z} \|\nabla f(x,y,z)\|_2\)

  2. 3D Smoothness: \(R_{\text{smooth}}(f) = \sum_{x,y,z} \|\nabla f(x,y,z)\|_2^2\)

  3. L1 Sparsity: \(R_{\text{L1}}(f) = \sum_{x,y,z} |f(x,y,z)|\)

Memory Management Strategies

3D reconstruction requires careful memory management:

  • Gradient Checkpointing: Trade computation for memory in backpropagation

  • Mixed Precision: Use float16 when possible to reduce memory usage

  • Batch Processing: Process volume slices when memory is extremely limited

  • Efficient Data Layout: Optimize tensor storage and access patterns

Convergence Characteristics

3D cone beam reconstruction typically exhibits:

  1. Initial Convergence (0-100 iterations): Rapid loss decrease, basic 3D structure emerges

  2. Detail Refinement (100-500 iterations): Fine 3D features develop progressively

  3. Final Convergence (500+ iterations): Slow improvement, potential overfitting risk

Challenges in 3D Reconstruction

  • Cone Beam Artifacts: Increased artifacts for large cone angles in 3D

  • Incomplete Sampling: Missing data in certain regions of 3D Fourier space

  • Computational Cost: Orders of magnitude higher than 2D reconstruction

  • Memory Limitations: Large volumes may exceed available GPU memory

  • Convergence Complexity: Higher-dimensional optimization landscape

Applications

3D cone beam iterative reconstruction is essential for:

  • Medical CBCT: Dental, orthopedic, and interventional imaging

  • Industrial CT: Non-destructive testing and quality control

  • Micro-CT: High-resolution imaging of small specimens and materials

  • Security Screening: Advanced baggage and cargo inspection systems

Code Example

3D Cone Beam Iterative Example
  1import math
  2import torch
  3import numpy as np
  4import matplotlib.pyplot as plt
  5import torch.nn as nn
  6import torch.optim as optim
  7from diffct.differentiable import ConeProjectorFunction
  8
  9def shepp_logan_3d(shape):
 10    zz, yy, xx = np.mgrid[:shape[0], :shape[1], :shape[2]]
 11    xx = (xx - (shape[2] - 1) / 2) / ((shape[2] - 1) / 2)
 12    yy = (yy - (shape[1] - 1) / 2) / ((shape[1] - 1) / 2)
 13    zz = (zz - (shape[0] - 1) / 2) / ((shape[0] - 1) / 2)
 14    el_params = np.array([
 15        [0, 0, 0, 0.69, 0.92, 0.81, 0, 0, 0, 1],
 16        [0, -0.0184, 0, 0.6624, 0.874, 0.78, 0, 0, 0, -0.8],
 17        [0.22, 0, 0, 0.11, 0.31, 0.22, -np.pi/10.0, 0, 0, -0.2],
 18        [-0.22, 0, 0, 0.16, 0.41, 0.28, np.pi/10.0, 0, 0, -0.2],
 19        [0, 0.35, -0.15, 0.21, 0.25, 0.41, 0, 0, 0, 0.1],
 20        [0, 0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
 21        [0, -0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
 22        [-0.08, -0.605, 0, 0.046, 0.023, 0.05, 0, 0, 0, 0.1],
 23        [0, -0.605, 0, 0.023, 0.023, 0.02, 0, 0, 0, 0.1],
 24        [0.06, -0.605, 0, 0.023, 0.046, 0.02, 0, 0, 0, 0.1],
 25    ], dtype=np.float32)
 26
 27    # Extract parameters for vectorization
 28    x_pos = el_params[:, 0][:, None, None, None]
 29    y_pos = el_params[:, 1][:, None, None, None]
 30    z_pos = el_params[:, 2][:, None, None, None]
 31    a_axis = el_params[:, 3][:, None, None, None]
 32    b_axis = el_params[:, 4][:, None, None, None]
 33    c_axis = el_params[:, 5][:, None, None, None]
 34    phi = el_params[:, 6][:, None, None, None]
 35    val = el_params[:, 9][:, None, None, None]
 36
 37    # Broadcast grid to ellipsoid axis
 38    xc = xx[None, ...] - x_pos
 39    yc = yy[None, ...] - y_pos
 40    zc = zz[None, ...] - z_pos
 41
 42    c = np.cos(phi)
 43    s = np.sin(phi)
 44
 45    # Only rotation around z, so can vectorize:
 46    xp = c * xc - s * yc
 47    yp = s * xc + c * yc
 48    zp = zc
 49
 50    mask = (
 51        (xp ** 2) / (a_axis ** 2)
 52        + (yp ** 2) / (b_axis ** 2)
 53        + (zp ** 2) / (c_axis ** 2)
 54        <= 1.0
 55    )
 56
 57    # Use broadcasting to sum all ellipsoid contributions
 58    shepp_logan = np.sum(mask * val, axis=0)
 59    shepp_logan = np.clip(shepp_logan, 0, 1)
 60    return shepp_logan
 61
 62class IterativeRecoModel(nn.Module):
 63    def __init__(self, volume_shape, angles, det_u, det_v, du, dv, sdd, sid,
 64                 voxel_spacing, backend="siddon"):
 65        super().__init__()
 66        self.reco = nn.Parameter(torch.zeros(volume_shape))
 67        self.angles = angles
 68        self.det_u = det_u
 69        self.det_v = det_v
 70        self.du = du
 71        self.dv = dv
 72        self.sdd = sdd
 73        self.sid = sid
 74        self.relu = nn.ReLU() # non negative constraint
 75        self.voxel_spacing = voxel_spacing
 76        self.backend = backend
 77
 78    def forward(self, x):
 79        updated_reco = x + self.reco
 80        # ``backend`` is the last positional arg to
 81        # ``ConeProjectorFunction.apply`` so we also pass the five default
 82        # offsets (two detector offsets + three centre offsets) to line
 83        # up with the signature.
 84        current_sino = ConeProjectorFunction.apply(
 85            updated_reco,
 86            self.angles,
 87            self.det_u, self.det_v,
 88            self.du, self.dv,
 89            self.sdd, self.sid,
 90            self.voxel_spacing,
 91            0.0, 0.0,             # detector_offset_u, detector_offset_v
 92            0.0, 0.0, 0.0,        # center_offset_x, y, z
 93            self.backend,
 94        )
 95        return current_sino, self.relu(updated_reco)
 96
 97class Pipeline:
 98    def __init__(self, lr, volume_shape, angles,
 99                 det_u, det_v, du, dv,
100                 sdd, sid, voxel_spacing,
101                 device, epoches=1000, backend="siddon"):
102        self.epoches = epoches
103        self.model = IterativeRecoModel(volume_shape, angles,
104                                        det_u, det_v, du, dv,
105                                        sdd, sid, voxel_spacing,
106                                        backend=backend).to(device)
107
108        self.optimizer = optim.AdamW(list(self.model.parameters()), lr=lr)
109        self.loss = nn.MSELoss()
110
111    def train(self, input, label):
112        loss_values = []
113        for epoch in range(self.epoches):
114            self.optimizer.zero_grad()
115            predictions, current_reco = self.model(input)
116            loss_value = self.loss(predictions, label)
117            loss_value.backward()
118            self.optimizer.step()
119            loss_values.append(loss_value.item())
120
121            if epoch % 10 == 0:
122                print(f"Epoch {epoch}, Loss: {loss_value.item()}")
123
124        return loss_values, self.model
125
126def main():
127    Nx, Ny, Nz = 64, 64, 64
128    phantom_cpu = shepp_logan_3d((Nz, Ny, Nx))
129
130    num_views = 180
131    angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)
132
133    det_u, det_v = 128, 128
134    du, dv = 1.0, 1.0
135    voxel_spacing = 1.0
136    sdd = 600.0
137    sid = 400.0
138
139    # Forward projector backend used for BOTH the ground-truth sinogram
140    # and the inner iterative loop. Using the same backend on both sides
141    # guarantees the loop is solving its own consistent inverse problem
142    # and that the adjoint returned by autograd matches the forward
143    # byte-for-byte (matched scatter/gather kernel pair, verified by
144    # tests/test_adjoint_inner_product.py). Options:
145    #
146    #   "siddon"           - 3D ray-driven Siddon with trilinear
147    #                        interpolation. Fastest per-iteration step.
148    #                        Good default when iteration count is the
149    #                        bottleneck.
150    #   "sf_tr"            - 3D SF with trapezoidal transaxial and
151    #                        rectangular axial footprint. Mass-
152    #                        conserving per voxel, closed-form cell
153    #                        integral. ~2x slower forward than siddon.
154    #   "sf_tt"            - 3D SF with trapezoidal footprint in BOTH
155    #                        directions; the axial trapezoid captures
156    #                        the variation of axial magnification across
157    #                        the voxel by using ``U_near`` and ``U_far``
158    #                        corner projections. Strictly more expressive
159    #                        than SF-TR at ~1.4x the SF-TR cost. Useful
160    #                        for large cone angles and for research into
161    #                        the full Long et al. separable-footprint
162    #                        model.
163    projector_backend = "sf_tr"
164
165    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
166    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32).contiguous()
167
168    # Generate the "real" sinogram with the same backend as the inner
169    # loop, so reconstruction targets what the loop can actually produce.
170    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
171    real_sinogram = ConeProjectorFunction.apply(
172        phantom_torch, angles_torch,
173        det_u, det_v, du, dv,
174        sdd, sid, voxel_spacing,
175        0.0, 0.0,                      # detector_offset_u, detector_offset_v
176        0.0, 0.0, 0.0,                 # center_offset_x, y, z
177        projector_backend,
178    )
179
180    pipeline_instance = Pipeline(lr=1e-1,
181                                 volume_shape=(Nz,Ny,Nx),
182                                 angles=angles_torch,
183                                 det_u=det_u, det_v=det_v,
184                                 du=du, dv=dv, voxel_spacing=voxel_spacing,
185                                 sdd=sdd,
186                                 sid=sid,
187                                 device=device, epoches=1000,
188                                 backend=projector_backend)
189    
190    ini_guess = torch.zeros_like(phantom_torch)
191    
192    loss_values, trained_model = pipeline_instance.train(ini_guess, real_sinogram)
193    
194    reco = trained_model(ini_guess)[1].squeeze().cpu().detach().numpy()
195
196    plt.figure()
197    plt.plot(loss_values)
198    plt.title("Loss Curve")
199    plt.xlabel("Epoch")
200    plt.ylabel("Loss")
201    plt.show()
202
203    mid_slice = Nz // 2
204    plt.figure(figsize=(12, 6))
205    plt.subplot(1, 2, 1)
206    plt.imshow(phantom_cpu[mid_slice, :, :], cmap="gray")
207    plt.title("Original Phantom Mid-Slice")
208    plt.axis("off")
209
210    plt.subplot(1, 2, 2)
211    plt.imshow(reco[mid_slice, :, :], cmap="gray")
212    plt.title("Reconstructed Mid-Slice")
213    plt.axis("off")
214    plt.show()
215
216if __name__ == "__main__":
217    main()