Cone Beam Iterative Reconstruction

This example demonstrates gradient-based iterative reconstruction for 3D cone beam CT using the differentiable ConeProjectorFunction from diffct.

Overview

3D cone beam iterative reconstruction extends gradient-based optimization from 2D slices to full volumetric reconstruction. This example shows how to:

  • Formulate 3D cone beam CT reconstruction as a large-scale optimization problem

  • Handle the computational complexity of 3D forward and backward projections

  • Apply memory-efficient optimization strategies for volumetric data

  • Monitor convergence in high-dimensional parameter space

Mathematical Background

3D Cone Beam Iterative Formulation

The 3D reconstruction problem is formulated as:

\[\hat{f} = \arg\min_f \|A_{\text{cone}}(f) - p\|_2^2 + \lambda R(f)\]

where:

  • \(f(x,y,z)\) is the unknown 3D volume

  • \(A_{\text{cone}}\) is the cone beam forward projection operator

  • \(p(\phi, u, v)\) is the measured 2D projection data

  • \(R(f)\) is an optional 3D regularization term, weighted by \(\lambda \geq 0\)

3D Forward Projection

The cone beam forward projection integrates along rays through the 3D volume:

\[p(\phi, u, v) = \int_0^{\infty} f\left(\vec{r}_s(\phi) + t \cdot \vec{d}(\phi, u, v)\right) dt\]

where \(\vec{r}_s(\phi)\) is the source position and \(\vec{d}(\phi, u, v)\) is the ray direction vector.
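As a geometric sketch of the ray parameterization above, the following computes \(\vec{r}_s(\phi)\) and a unit-length \(\vec{d}(\phi, u, v)\) for a single detector pixel. It assumes a circular source orbit of radius `sid` in the z = 0 plane and a flat detector at distance `sdd` from the source, with `u` tangential and `v` along the rotation axis. These conventions and the helper name `cone_ray` are illustrative assumptions and may differ from the geometry diffct uses internally.

```python
import numpy as np

def cone_ray(phi, u, v, sid, sdd):
    """Return (source, unit direction) for the detector pixel at offsets (u, v).

    Assumed convention, for illustration only: the source moves on a circle
    of radius sid around the z axis and the detector plane is perpendicular
    to the central ray at distance sdd from the source.
    """
    radial = np.array([np.cos(phi), np.sin(phi), 0.0])   # isocenter -> source
    e_u = np.array([-np.sin(phi), np.cos(phi), 0.0])     # detector u axis
    e_v = np.array([0.0, 0.0, 1.0])                      # detector v axis
    source = sid * radial
    pixel = source - sdd * radial + u * e_u + v * e_v
    direction = pixel - source
    return source, direction / np.linalg.norm(direction)

# The central ray (u = v = 0) should pass through the isocenter at t = sid.
source, central = cone_ray(phi=np.pi / 3, u=0.0, v=0.0, sid=400.0, sdd=600.0)
closest = source + 400.0 * central

# An off-center pixel gives an oblique ray with a nonzero axial component.
_, oblique = cone_ray(phi=np.pi / 3, u=20.0, v=-35.0, sid=400.0, sdd=600.0)
print("central ray at t = sid:", closest)
print("oblique direction:", oblique)
```

The nonzero axial component of the oblique ray is what distinguishes cone beam geometry from a stack of independent fan beam slices.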

3D Gradient Computation

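For the data-fidelity term \(L(f) = \|A_{\text{cone}}(f) - p\|_2^2\), the gradient with respect to the volume uses the cone beam backprojection operator:

\[\frac{\partial L}{\partial f} = 2A_{\text{cone}}^T(A_{\text{cone}}(f) - p)\]

where \(A_{\text{cone}}^T\) is the 3D cone beam backprojection operator, the adjoint of \(A_{\text{cone}}\). When a regularizer is used, \(\lambda \, \partial R / \partial f\) is added to this gradient.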
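The gradient formula can be checked numerically on a toy problem. The sketch below replaces the cone beam operator with a small random dense matrix `A` (a stand-in chosen purely for illustration, not the diffct projector) and compares the gradient produced by PyTorch autograd with the closed-form expression.

```python
import torch

torch.manual_seed(0)

# Stand-in for A_cone: a small dense matrix (hypothetical, illustration only).
num_rays, num_voxels = 12, 8
A = torch.randn(num_rays, num_voxels, dtype=torch.float64)
p = torch.randn(num_rays, dtype=torch.float64)            # "measured" data
f = torch.randn(num_voxels, dtype=torch.float64, requires_grad=True)

# Data-fidelity loss: squared L2 norm of the residual.
loss = torch.sum((A @ f - p) ** 2)
loss.backward()

# Closed-form gradient: 2 A^T (A f - p).
analytic_grad = 2.0 * A.T @ (A @ f.detach() - p)

grad_matches = torch.allclose(f.grad, analytic_grad)
print("autograd matches 2 A^T (A f - p):", grad_matches)
```

If the loss is a mean rather than a sum of squared residuals, as with `nn.MSELoss`, the same gradient is scaled by one over the number of projection samples.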

Computational Complexity

3D reconstruction presents significant computational challenges:

  • Memory Requirements: \(O(N^3)\) for volume storage vs \(O(N^2)\) for 2D images

  • Projection Data: \(N_{\phi}\) 2D projections of \(N_u \times N_v\) pixels each, \(O(N_{\phi} \times N_u \times N_v)\) values in total

  • Forward/Backward Operations: \(O(N^3 \times N_{\phi})\) work per forward or backward projection

  • Gradient Storage: Additional memory for automatic differentiation
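To make these scalings concrete, the following back-of-the-envelope sketch estimates storage for a cubic volume and its projection stack. The helper name `estimate_memory_mib` and the default of four volume-sized buffers (parameter, gradient, and two Adam moment buffers) are assumptions for illustration; actual usage depends on the projector backend and the autograd graph.

```python
def estimate_memory_mib(n, num_views, det_u, det_v,
                        bytes_per_value=4, volume_copies=4):
    """Rough memory estimate in MiB for an n^3 volume and its projections.

    volume_copies is an assumed count of volume-sized buffers held by the
    optimizer loop: parameter, gradient, and two Adam moment buffers.
    """
    mib = 1024 ** 2
    volume_bytes = n ** 3 * bytes_per_value
    projection_bytes = num_views * det_u * det_v * bytes_per_value
    return {
        "volume_mib": volume_bytes / mib,
        "volume_buffers_mib": volume_copies * volume_bytes / mib,
        "projections_mib": projection_bytes / mib,
    }

# Same sizes as the example on this page: 64^3 volume, 180 views of 128 x 128.
small = estimate_memory_mib(64, 180, 128, 128)
# A larger, hypothetical acquisition: 512^3 volume, 360 views of 1024 x 1024.
large = estimate_memory_mib(512, 360, 1024, 1024)

print(small)
print(large)
```

Doubling the side length `n` multiplies the volume terms by eight, which is why volume sizes that are routine in 2D quickly become memory-bound in 3D.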

Implementation Steps

  1. 3D Problem Setup: Define parameterized 3D volume as learnable tensor

  2. Cone Beam Forward Model: Use ConeProjectorFunction for 2D projection prediction

  3. Loss Computation: Calculate the mean squared error (a scaled squared L2 distance) between predicted and measured projections

  4. 3D Gradient Computation: Use automatic differentiation through cone beam operators

  5. Memory-Efficient Optimization: Apply strategies to handle large 3D parameter space

  6. Convergence Monitoring: Track loss and 3D reconstruction quality

Model Architecture

The 3D reconstruction model consists of:

  • Parameterized Volume: Learnable 3D tensor representing the unknown volume

  • Cone Beam Forward Model: ConeProjectorFunction with 3D geometry parameters

  • Loss Function: Mean squared error between predicted and measured 2D projections

3D Regularization Options

Common 3D regularization terms:

  1. 3D Total Variation: \(R_{\text{TV}}(f) = \sum_{x,y,z} \|\nabla f(x,y,z)\|_2\)

  2. 3D Smoothness: \(R_{\text{smooth}}(f) = \sum_{x,y,z} \|\nabla f(x,y,z)\|_2^2\)

  3. L1 Sparsity: \(R_{\text{L1}}(f) = \sum_{x,y,z} |f(x,y,z)|\)
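The three regularizers above can be written as differentiable PyTorch functions, so they can be added to the data-fidelity loss and differentiated by autograd. This is a minimal sketch under stated assumptions: the gradient is approximated by forward differences with a zero last sample along each axis, and a small `eps` is added inside the square root of the total variation term to keep it differentiable where the gradient vanishes. The function names are hypothetical and not part of diffct.

```python
import torch

def forward_differences(f):
    """Forward differences of a (Z, Y, X) volume along each axis.

    The last sample along each axis is left at zero so that all three
    outputs keep the shape of f.
    """
    dz = torch.zeros_like(f)
    dy = torch.zeros_like(f)
    dx = torch.zeros_like(f)
    dz[:-1, :, :] = f[1:, :, :] - f[:-1, :, :]
    dy[:, :-1, :] = f[:, 1:, :] - f[:, :-1, :]
    dx[:, :, :-1] = f[:, :, 1:] - f[:, :, :-1]
    return dz, dy, dx

def tv_3d(f, eps=1e-8):
    """Isotropic 3D total variation: sum of gradient magnitudes."""
    dz, dy, dx = forward_differences(f)
    return torch.sqrt(dz ** 2 + dy ** 2 + dx ** 2 + eps).sum()

def smoothness_3d(f):
    """Quadratic smoothness: sum of squared gradient magnitudes."""
    dz, dy, dx = forward_differences(f)
    return (dz ** 2 + dy ** 2 + dx ** 2).sum()

def l1_sparsity(f):
    """L1 norm of the voxel values."""
    return f.abs().sum()

# A 4x4x4 volume with a single unit step along z between slices 1 and 2.
step = torch.zeros(4, 4, 4)
step[2:] = 1.0
tv_value = tv_3d(step, eps=0.0).item()
smooth_value = smoothness_3d(step).item()
l1_value = l1_sparsity(step).item()

# Gradients flow through the regularizer like through any other loss term.
f = step.clone().requires_grad_(True)
tv_3d(f).backward()
print(tv_value, smooth_value, l1_value, f.grad.shape)
```

In a training loop, the chosen term would be added to the projection loss as `loss + lam * tv_3d(volume)`, with `lam` playing the role of \(\lambda\).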

Memory Management Strategies

3D reconstruction requires careful memory management:

  • Gradient Checkpointing: Trade computation for memory in backpropagation

  • Mixed Precision: Use float16 when possible to reduce memory usage

  • Batch Processing: Process volume slices when memory is extremely limited

  • Efficient Data Layout: Optimize tensor storage and access patterns
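One way to bound peak memory, sketched below, is to process the projection views in chunks and let autograd accumulate the gradient across chunks. Note that this splits the data by view rather than by volume slice, and that the projector here is a hypothetical stand-in (one small dense matrix per view) rather than `ConeProjectorFunction`. The sketch checks that the accumulated gradient equals the full-batch gradient of a sum-of-squares loss.

```python
import torch

torch.manual_seed(0)
num_views, rays_per_view, num_voxels = 6, 5, 10

# Hypothetical stand-in for the cone beam projector: one matrix per view.
A = torch.randn(num_views, rays_per_view, num_voxels, dtype=torch.float64)
p = torch.randn(num_views, rays_per_view, dtype=torch.float64)
f = torch.zeros(num_voxels, dtype=torch.float64, requires_grad=True)

def project(volume, views):
    """Project the flattened volume onto the selected views."""
    return A[views] @ volume          # shape: (selected views, rays_per_view)

# Reference: gradient of the full sum-of-squares loss in one pass.
full_loss = ((project(f, slice(None)) - p) ** 2).sum()
grad_full = torch.autograd.grad(full_loss, f)[0]

# Chunked: each backward() adds into f.grad and frees that chunk's graph.
chunk = 2
f.grad = None
for start in range(0, num_views, chunk):
    views = slice(start, start + chunk)
    chunk_loss = ((project(f, views) - p[views]) ** 2).sum()
    chunk_loss.backward()

chunks_match = torch.allclose(f.grad, grad_full)
print("chunked gradient equals full gradient:", chunks_match)
```

With a mean-based loss such as `nn.MSELoss`, each chunk loss would need to be weighted by its share of the total number of samples for the accumulated gradient to match.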

Convergence Characteristics

3D cone beam reconstruction typically exhibits:

  1. Initial Convergence (0-100 iterations): Rapid loss decrease, basic 3D structure emerges

  2. Detail Refinement (100-500 iterations): Fine 3D features develop progressively

  3. Final Convergence (500+ iterations): Slow improvement, with a risk of fitting noise in the measured projections
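Since late iterations bring little improvement, a simple plateau test on the recorded loss values can be used to decide when to stop. The helper below is a hypothetical sketch, not part of diffct: it compares the mean loss over the most recent window with the mean over the window before it, and the `window` and `rel_tol` values are arbitrary defaults to tune per problem.

```python
def has_plateaued(loss_values, window=50, rel_tol=1e-3):
    """Return True when the windowed mean loss has stopped improving.

    Compares the mean of the last `window` losses with the mean of the
    `window` losses before them and reports a plateau when the relative
    improvement is below rel_tol.
    """
    if len(loss_values) < 2 * window:
        return False                      # not enough history yet
    previous = sum(loss_values[-2 * window:-window]) / window
    recent = sum(loss_values[-window:]) / window
    if previous == 0.0:
        return True                       # loss is already exactly zero
    return (previous - recent) / abs(previous) < rel_tol

# A loss that is still decaying versus one that has flattened out.
decaying = [1.0 / (k + 1) for k in range(100)]
flat = [0.5] * 100

still_improving = has_plateaued(decaying, window=10)
stalled = has_plateaued(flat, window=10)
too_short = has_plateaued([1.0, 0.9], window=10)
print(still_improving, stalled, too_short)
```

Inside a training loop, the check could be called on the running list of loss values after each epoch, breaking out of the loop once it returns `True`.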

Challenges in 3D Reconstruction

  • Cone Beam Artifacts: Increased artifacts for large cone angles in 3D

  • Incomplete Sampling: Missing data in certain regions of 3D Fourier space

  • Computational Cost: Orders of magnitude higher than 2D reconstruction

  • Memory Limitations: Large volumes may exceed available GPU memory

  • Convergence Complexity: Higher-dimensional optimization landscape

Applications

3D cone beam iterative reconstruction is essential for:

  • Medical CBCT: Dental, orthopedic, and interventional imaging

  • Industrial CT: Non-destructive testing and quality control

  • Micro-CT: High-resolution imaging of small specimens and materials

  • Security Screening: Advanced baggage and cargo inspection systems

Code Example

3D Cone Beam Iterative Example
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from diffct.differentiable import ConeProjectorFunction

def shepp_logan_3d(shape):
    zz, yy, xx = np.mgrid[:shape[0], :shape[1], :shape[2]]
    xx = (xx - (shape[2] - 1) / 2) / ((shape[2] - 1) / 2)
    yy = (yy - (shape[1] - 1) / 2) / ((shape[1] - 1) / 2)
    zz = (zz - (shape[0] - 1) / 2) / ((shape[0] - 1) / 2)
    el_params = np.array([
        [0, 0, 0, 0.69, 0.92, 0.81, 0, 0, 0, 1],
        [0, -0.0184, 0, 0.6624, 0.874, 0.78, 0, 0, 0, -0.8],
        [0.22, 0, 0, 0.11, 0.31, 0.22, -np.pi/10.0, 0, 0, -0.2],
        [-0.22, 0, 0, 0.16, 0.41, 0.28, np.pi/10.0, 0, 0, -0.2],
        [0, 0.35, -0.15, 0.21, 0.25, 0.41, 0, 0, 0, 0.1],
        [0, 0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [0, -0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [-0.08, -0.605, 0, 0.046, 0.023, 0.05, 0, 0, 0, 0.1],
        [0, -0.605, 0, 0.023, 0.023, 0.02, 0, 0, 0, 0.1],
        [0.06, -0.605, 0, 0.023, 0.046, 0.02, 0, 0, 0, 0.1],
    ], dtype=np.float32)

    # Extract parameters for vectorization
    x_pos = el_params[:, 0][:, None, None, None]
    y_pos = el_params[:, 1][:, None, None, None]
    z_pos = el_params[:, 2][:, None, None, None]
    a_axis = el_params[:, 3][:, None, None, None]
    b_axis = el_params[:, 4][:, None, None, None]
    c_axis = el_params[:, 5][:, None, None, None]
    phi = el_params[:, 6][:, None, None, None]
    val = el_params[:, 9][:, None, None, None]

    # Broadcast grid to ellipsoid axis
    xc = xx[None, ...] - x_pos
    yc = yy[None, ...] - y_pos
    zc = zz[None, ...] - z_pos

    c = np.cos(phi)
    s = np.sin(phi)

    # Only rotation around z, so can vectorize:
    xp = c * xc - s * yc
    yp = s * xc + c * yc
    zp = zc

    mask = (
        (xp ** 2) / (a_axis ** 2)
        + (yp ** 2) / (b_axis ** 2)
        + (zp ** 2) / (c_axis ** 2)
        <= 1.0
    )

    # Use broadcasting to sum all ellipsoid contributions
    shepp_logan = np.sum(mask * val, axis=0)
    shepp_logan = np.clip(shepp_logan, 0, 1)
    return shepp_logan

class IterativeRecoModel(nn.Module):
    def __init__(self, volume_shape, angles, det_u, det_v, du, dv, sdd, sid,
                 voxel_spacing, backend="siddon"):
        super().__init__()
        self.reco = nn.Parameter(torch.zeros(volume_shape))
        self.angles = angles
        self.det_u = det_u
        self.det_v = det_v
        self.du = du
        self.dv = dv
        self.sdd = sdd
        self.sid = sid
        self.voxel_spacing = voxel_spacing
        self.backend = backend

    def forward(self, x):
        updated_reco = x + self.reco
        # ``backend`` is the last positional arg to
        # ``ConeProjectorFunction.apply`` so we also pass the five default
        # offsets (two detector offsets + three centre offsets) to line
        # up with the signature.
        current_sino = ConeProjectorFunction.apply(
            updated_reco,
            self.angles,
            self.det_u, self.det_v,
            self.du, self.dv,
            self.sdd, self.sid,
            self.voxel_spacing,
            0.0, 0.0,             # detector_offset_u, detector_offset_v
            0.0, 0.0, 0.0,        # center_offset_x, y, z
            self.backend,
        )
        return current_sino, updated_reco

class Pipeline:
    def __init__(self, lr, volume_shape, angles,
                 det_u, det_v, du, dv,
                 sdd, sid, voxel_spacing,
                 device, epoches=1000, backend="siddon"):
        self.epoches = epoches
        self.model = IterativeRecoModel(volume_shape, angles,
                                        det_u, det_v, du, dv,
                                        sdd, sid, voxel_spacing,
                                        backend=backend).to(device)

        self.optimizer = optim.AdamW(list(self.model.parameters()), lr=lr)
        self.loss = nn.MSELoss()

    def train(self, input, label):
        loss_values = []
        for epoch in range(self.epoches):
            self.optimizer.zero_grad()
            predictions, current_reco = self.model(input)
            loss_value = self.loss(predictions, label)
            loss_value.backward()
            self.optimizer.step()
            with torch.no_grad():
                self.model.reco.clamp_(min=0.0)
            loss_values.append(loss_value.item())

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss_value.item()}")

        return loss_values, self.model

def main():
    Nx, Ny, Nz = 64, 64, 64
    phantom_cpu = shepp_logan_3d((Nz, Ny, Nx))

    num_views = 180
    angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)

    det_u, det_v = 128, 128
    du, dv = 1.0, 1.0
    voxel_spacing = 1.0
    sdd = 600.0
    sid = 400.0

    # Forward projector backend used for BOTH the ground-truth sinogram
    # and the inner iterative loop. Using the same backend on both sides
    # guarantees the loop is solving its own consistent inverse problem
    # and that the adjoint returned by autograd matches the forward
    # byte-for-byte (matched scatter/gather kernel pair, verified by
    # tests/test_adjoint_inner_product.py). Options:
    #
    #   "siddon"           - 3D ray-driven cell-constant Siddon. Fastest
    #                        per-iteration step. Good default when
    #                        iteration count is the bottleneck.
    #   "sf_tr"            - 3D SF with trapezoidal transaxial and
    #                        rectangular axial footprint. Mass-
    #                        conserving per voxel, closed-form cell
    #                        integral. ~2x slower forward than siddon.
    #   "sf_tt"            - 3D SF with trapezoidal footprint in BOTH
    #                        directions; the axial trapezoid captures
    #                        the variation of axial magnification across
    #                        the voxel by using ``U_near`` and ``U_far``
    #                        corner projections. Strictly more expressive
    #                        than SF-TR at ~1.4x the SF-TR cost. Useful
    #                        for large cone angles and for research into
    #                        the full Long et al. separable-footprint
    #                        model.
    projector_backend = "sf_tr"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32).contiguous()

    # Generate the "real" sinogram with the same backend as the inner
    # loop, so reconstruction targets what the loop can actually produce.
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
    real_sinogram = ConeProjectorFunction.apply(
        phantom_torch, angles_torch,
        det_u, det_v, du, dv,
        sdd, sid, voxel_spacing,
        0.0, 0.0,                      # detector_offset_u, detector_offset_v
        0.0, 0.0, 0.0,                 # center_offset_x, y, z
        projector_backend,
    )

    pipeline_instance = Pipeline(lr=1e-1,
                                 volume_shape=(Nz,Ny,Nx),
                                 angles=angles_torch,
                                 det_u=det_u, det_v=det_v,
                                 du=du, dv=dv, voxel_spacing=voxel_spacing,
                                 sdd=sdd,
                                 sid=sid,
                                 device=device, epoches=1000,
                                 backend=projector_backend)

    ini_guess = torch.zeros_like(phantom_torch)

    loss_values, trained_model = pipeline_instance.train(ini_guess, real_sinogram)

    reco = trained_model(ini_guess)[1].squeeze().cpu().detach().numpy()

    plt.figure()
    plt.plot(loss_values)
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    mid_slice = Nz // 2
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(phantom_cpu[mid_slice, :, :], cmap="gray")
    plt.title("Original Phantom Mid-Slice")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reco[mid_slice, :, :], cmap="gray")
    plt.title("Reconstructed Mid-Slice")
    plt.axis("off")
    plt.show()

if __name__ == "__main__":
    main()
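The example above compares the phantom and the reconstruction visually. For a quantitative check of 3D reconstruction quality, a minimal sketch of two common error metrics follows. The helper names `rmse` and `psnr` are hypothetical and not part of diffct, and `data_range=1.0` assumes volumes scaled to [0, 1], as the clipped phantom is.

```python
import numpy as np

def rmse(reference, estimate):
    """Root-mean-square error between two volumes of equal shape."""
    diff = np.asarray(reference, dtype=np.float64) - np.asarray(estimate, dtype=np.float64)
    return float(np.sqrt(np.mean(diff ** 2)))

def psnr(reference, estimate, data_range=1.0):
    """Peak signal-to-noise ratio in dB for the given intensity range."""
    err = rmse(reference, estimate)
    if err == 0.0:
        return float("inf")               # identical volumes
    return float(20.0 * np.log10(data_range / err))

# Toy check: a constant offset of 0.1 on a zero volume.
reference = np.zeros((2, 2, 2))
estimate = np.full((2, 2, 2), 0.1)
toy_rmse = rmse(reference, estimate)
toy_psnr = psnr(reference, estimate)
print(toy_rmse, toy_psnr)
```

With the variables from `main()`, the same functions could be called as `psnr(phantom_cpu, reco)` after training.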