Cone Beam Iterative Reconstruction

This example demonstrates gradient-based iterative reconstruction for 3D cone beam CT using the differentiable ConeProjectorFunction from diffct.

Overview

3D cone beam iterative reconstruction extends optimization methods to full volumetric reconstruction. This example shows how to:

  • Formulate 3D cone beam CT reconstruction as a large-scale optimization problem

  • Handle the computational complexity of 3D forward and backward projections

  • Apply memory-efficient optimization strategies for volumetric data

  • Monitor convergence in high-dimensional parameter space

Mathematical Background

3D Cone Beam Iterative Formulation

The 3D reconstruction problem is formulated as:

\[\hat{f} = \arg\min_f \|A_{\text{cone}}(f) - p\|_2^2 + \lambda R(f)\]

where:

  • \(f(x,y,z)\) is the unknown 3D volume

  • \(A_{\text{cone}}\) is the cone beam forward projection operator

  • \(p(\phi, u, v)\) is the measured 2D projection data

  • \(R(f)\) is an optional 3D regularization term

3D Forward Projection

The cone beam forward projection integrates along rays through the 3D volume:

\[p(\phi, u, v) = \int_0^{\infty} f\left(\vec{r}_s(\phi) + t \cdot \vec{d}(\phi, u, v)\right) dt\]

where \(\vec{r}_s(\phi)\) is the source position and \(\vec{d}(\phi, u, v)\) is the ray direction vector.
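
In diffct, a discrete version of this line integral is computed by ConeProjectorFunction.apply. The snippet below shows a single forward projection; the geometry values are placeholders and the parameter order follows the full example at the end of this page:

import math
import torch
from diffct.differentiable import ConeProjectorFunction

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder volume and geometry (illustrative values only).
volume = torch.zeros(64, 64, 64, device=device)   # f(x, y, z)
angles = torch.arange(180, device=device, dtype=torch.float32) * (2 * math.pi / 180)
det_u, det_v = 128, 128      # detector pixels (u, v)
du, dv = 1.0, 1.0            # detector pixel spacing
sdd, sid = 600.0, 400.0      # source-detector and source-isocenter distances
voxel_spacing = 1.0

# Stack of 2D cone beam projections p(phi, u, v)
projections = ConeProjectorFunction.apply(volume, angles, det_u, det_v,
                                          du, dv, sdd, sid, voxel_spacing)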

3D Gradient Computation

The gradient of the 3D loss function uses the cone beam backprojection operator:

\[\frac{\partial L}{\partial f} = 2A_{\text{cone}}^T\left(A_{\text{cone}}(f) - p\right)\]

where \(A_{\text{cone}}^T\) is the 3D cone beam backprojection operator (adjoint).
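
Because ConeProjectorFunction is differentiable, this gradient does not need to be implemented by hand: calling backward() on the squared-residual loss fills the volume's gradient with the backprojected residual. A minimal sketch, using a random volume and zero "measured" data purely as placeholders:

import math
import torch
from diffct.differentiable import ConeProjectorFunction

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Current volume estimate f, marked as a leaf tensor that collects gradients.
f = torch.rand(64, 64, 64, device=device, requires_grad=True)
angles = torch.arange(180, device=device, dtype=torch.float32) * (2 * math.pi / 180)

p_predicted = ConeProjectorFunction.apply(f, angles, 128, 128, 1.0, 1.0,
                                          600.0, 400.0, 1.0)
p_measured = torch.zeros_like(p_predicted)   # placeholder for the acquired data

loss = torch.sum((p_predicted - p_measured) ** 2)   # ||A_cone(f) - p||_2^2
loss.backward()

# f.grad now holds 2 * A_cone^T (A_cone(f) - p): the backprojected residual.
print(f.grad.shape)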

Computational Complexity

3D reconstruction presents significant computational challenges:

  • Memory Requirements: \(O(N^3)\) for volume storage vs \(O(N^2)\) for 2D images

  • Projection Data: \(N_{\phi}\) projections of \(N_u \times N_v\) detector pixels, i.e. \(O(N_{\phi} \times N_u \times N_v)\) measured values

  • Forward/Backward Operations: \(O(N^3 \times N_{\phi})\) computational complexity

  • Gradient Storage: Additional memory for automatic differentiation
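
A quick back-of-the-envelope calculation makes these scaling factors concrete; the volume size, view count, and detector size below are illustrative only:

# Rough memory estimate for one float32 copy of each array (illustrative numbers).
N = 256                       # volume size per axis
n_views, n_u, n_v = 360, 512, 512
bytes_per_value = 4           # float32

volume_bytes = N ** 3 * bytes_per_value
projection_bytes = n_views * n_u * n_v * bytes_per_value

print(f"volume:      {volume_bytes / 1e9:.2f} GB")      # ~0.07 GB
print(f"projections: {projection_bytes / 1e9:.2f} GB")  # ~0.38 GB
# Autograd at least doubles the volume footprint (parameter + gradient),
# and an optimizer such as AdamW adds two more volume-sized state tensors.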

Implementation Steps

  1. 3D Problem Setup: Define parameterized 3D volume as learnable tensor

  2. Cone Beam Forward Model: Use ConeProjectorFunction for 2D projection prediction

  3. Loss Computation: Calculate L2 distance between predicted and measured projections

  4. 3D Gradient Computation: Use automatic differentiation through cone beam operators

  5. Memory-Efficient Optimization: Apply strategies to handle large 3D parameter space

  6. Convergence Monitoring: Track loss and 3D reconstruction quality
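
The skeleton below condenses these six steps into a bare optimization loop; the complete, runnable version with the Shepp-Logan phantom, an nn.Module wrapper, and plotting is listed under Code Example at the end of this page. Geometry values here are placeholders.

import math
import torch
from diffct.differentiable import ConeProjectorFunction

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
angles = torch.arange(180, device=device, dtype=torch.float32) * (2 * math.pi / 180)
geometry = (128, 128, 1.0, 1.0, 600.0, 400.0, 1.0)  # det_u, det_v, du, dv, sdd, sid, voxel_spacing

# Simulated "measured" data from a stand-in object (in practice: acquired projections).
phantom = torch.rand(64, 64, 64, device=device)
with torch.no_grad():
    measured = ConeProjectorFunction.apply(phantom, angles, *geometry)

# Step 1: learnable 3D volume
volume = torch.zeros(64, 64, 64, device=device, requires_grad=True)
optimizer = torch.optim.AdamW([volume], lr=1e-1)

for epoch in range(200):
    optimizer.zero_grad()
    # Step 2: cone beam forward model
    predicted = ConeProjectorFunction.apply(volume, angles, *geometry)
    # Step 3: loss between predicted and measured projections
    loss = torch.nn.functional.mse_loss(predicted, measured)
    # Step 4: gradients via automatic differentiation
    loss.backward()
    # Step 5: parameter update
    optimizer.step()
    # Step 6: convergence monitoring
    if epoch % 10 == 0:
        print(epoch, loss.item())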

Model Architecture

The 3D reconstruction model consists of:

  • Parameterized Volume: Learnable 3D tensor representing the unknown volume

  • Cone Beam Forward Model: ConeProjectorFunction with 3D geometry parameters

  • Loss Function: Mean squared error between predicted and measured 2D projections

3D Regularization Options

Common 3D regularization terms:

  1. 3D Total Variation: \(R_{\text{TV}}(f) = \sum_{x,y,z} \|\nabla f(x,y,z)\|_2\)

  2. 3D Smoothness: \(R_{\text{smooth}}(f) = \sum_{x,y,z} \|\nabla f(x,y,z)\|_2^2\)

  3. L1 Sparsity: \(R_{\text{L1}}(f) = \sum_{x,y,z} |f(x,y,z)|\)
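
These penalties can be added directly to the data-fidelity loss with standard tensor operations. The sketch below uses one-sided finite differences for \(\nabla f\) with simplified boundary handling; the small eps keeps the TV term differentiable at zero.

import torch

def tv_3d(f: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Isotropic 3D total variation (one-sided finite differences)."""
    dz = f[1:, :-1, :-1] - f[:-1, :-1, :-1]
    dy = f[:-1, 1:, :-1] - f[:-1, :-1, :-1]
    dx = f[:-1, :-1, 1:] - f[:-1, :-1, :-1]
    return torch.sqrt(dx ** 2 + dy ** 2 + dz ** 2 + eps).sum()

def smoothness_3d(f: torch.Tensor) -> torch.Tensor:
    """Quadratic (Tikhonov-style) smoothness penalty."""
    dz = f[1:, :, :] - f[:-1, :, :]
    dy = f[:, 1:, :] - f[:, :-1, :]
    dx = f[:, :, 1:] - f[:, :, :-1]
    return (dx ** 2).sum() + (dy ** 2).sum() + (dz ** 2).sum()

def l1_sparsity(f: torch.Tensor) -> torch.Tensor:
    """L1 sparsity penalty."""
    return f.abs().sum()

# Usage (lam is a hypothetical regularization weight):
#   total_loss = data_fidelity + lam * tv_3d(volume)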

Memory Management Strategies

3D reconstruction requires careful memory management:

  • Gradient Checkpointing: Trade computation for memory in backpropagation

  • Mixed Precision: Use float16 when possible to reduce memory usage

  • Batch Processing: Process volume slices when memory is extremely limited

  • Efficient Data Layout: Optimize tensor storage and access patterns
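
As one concrete illustration of the first point, the forward projection can be wrapped in torch.utils.checkpoint so that intermediates are recomputed during the backward pass instead of stored. How much memory this saves in practice depends on what the projector caches internally, so treat this as a sketch rather than a guaranteed optimization:

import torch
from torch.utils.checkpoint import checkpoint
from diffct.differentiable import ConeProjectorFunction

def project(volume, angles):
    # Geometry values as in the full example below (det_u, det_v, du, dv, sdd, sid, voxel_spacing).
    return ConeProjectorFunction.apply(volume, angles, 128, 128, 1.0, 1.0,
                                       600.0, 400.0, 1.0)

def checkpointed_project(volume, angles):
    # Recompute the forward projection during backpropagation instead of caching it.
    return checkpoint(project, volume, angles, use_reentrant=False)

# Drop-in replacement for the direct ConeProjectorFunction.apply call in the training loop.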

Convergence Characteristics

3D cone beam reconstruction typically exhibits:

  1. Initial Convergence (0-100 iterations): Rapid loss decrease, basic 3D structure emerges

  2. Detail Refinement (100-500 iterations): Fine 3D features develop progressively

  3. Final Convergence (500+ iterations): Slow improvement, potential overfitting risk
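
In simulation studies such as the example below, the ground-truth phantom is available, so an image-quality metric can be logged alongside the loss to detect the transition into the overfitting-prone regime. A simple PSNR helper (hypothetical, not part of diffct) might look like this:

import torch

def psnr(reco: torch.Tensor, reference: torch.Tensor) -> float:
    """Peak signal-to-noise ratio in dB, using the reference maximum as the peak value."""
    mse = torch.mean((reco - reference) ** 2)
    peak = reference.max()
    return (20 * torch.log10(peak) - 10 * torch.log10(mse)).item()

# Inside the training loop (see the full example), e.g. every 10 epochs:
#   print(f"Epoch {epoch}, Loss {loss_value.item():.3e}, "
#         f"PSNR {psnr(current_reco, phantom_torch):.2f} dB")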

Challenges in 3D Reconstruction

  • Cone Beam Artifacts: Increased artifacts for large cone angles in 3D

  • Incomplete Sampling: Missing data in certain regions of 3D Fourier space

  • Computational Cost: Orders of magnitude higher than 2D reconstruction

  • Memory Limitations: Large volumes may exceed available GPU memory

  • Convergence Complexity: Higher-dimensional optimization landscape

Applications

3D cone beam iterative reconstruction is essential for:

  • Medical CBCT: Dental, orthopedic, and interventional imaging

  • Industrial CT: Non-destructive testing and quality control

  • Micro-CT: High-resolution imaging of small specimens and materials

  • Security Screening: Advanced baggage and cargo inspection systems

Code Example

3D Cone Beam Iterative Example
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from diffct.differentiable import ConeProjectorFunction

def shepp_logan_3d(shape):
    zz, yy, xx = np.mgrid[:shape[0], :shape[1], :shape[2]]
    xx = (xx - (shape[2] - 1) / 2) / ((shape[2] - 1) / 2)
    yy = (yy - (shape[1] - 1) / 2) / ((shape[1] - 1) / 2)
    zz = (zz - (shape[0] - 1) / 2) / ((shape[0] - 1) / 2)
    el_params = np.array([
        [0, 0, 0, 0.69, 0.92, 0.81, 0, 0, 0, 1],
        [0, -0.0184, 0, 0.6624, 0.874, 0.78, 0, 0, 0, -0.8],
        [0.22, 0, 0, 0.11, 0.31, 0.22, -np.pi/10.0, 0, 0, -0.2],
        [-0.22, 0, 0, 0.16, 0.41, 0.28, np.pi/10.0, 0, 0, -0.2],
        [0, 0.35, -0.15, 0.21, 0.25, 0.41, 0, 0, 0, 0.1],
        [0, 0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [0, -0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [-0.08, -0.605, 0, 0.046, 0.023, 0.05, 0, 0, 0, 0.1],
        [0, -0.605, 0, 0.023, 0.023, 0.02, 0, 0, 0, 0.1],
        [0.06, -0.605, 0, 0.023, 0.046, 0.02, 0, 0, 0, 0.1],
    ], dtype=np.float32)

    # Extract parameters for vectorization
    x_pos = el_params[:, 0][:, None, None, None]
    y_pos = el_params[:, 1][:, None, None, None]
    z_pos = el_params[:, 2][:, None, None, None]
    a_axis = el_params[:, 3][:, None, None, None]
    b_axis = el_params[:, 4][:, None, None, None]
    c_axis = el_params[:, 5][:, None, None, None]
    phi = el_params[:, 6][:, None, None, None]
    val = el_params[:, 9][:, None, None, None]

    # Shift the grid to each ellipsoid's center (broadcast over ellipsoids)
    xc = xx[None, ...] - x_pos
    yc = yy[None, ...] - y_pos
    zc = zz[None, ...] - z_pos

    c = np.cos(phi)
    s = np.sin(phi)

    # Only rotation around z, so the rotation can be vectorized:
    xp = c * xc - s * yc
    yp = s * xc + c * yc
    zp = zc

    mask = (
        (xp ** 2) / (a_axis ** 2)
        + (yp ** 2) / (b_axis ** 2)
        + (zp ** 2) / (c_axis ** 2)
        <= 1.0
    )

    # Use broadcasting to sum all ellipsoid contributions
    shepp_logan = np.sum(mask * val, axis=0)
    shepp_logan = np.clip(shepp_logan, 0, 1)
    return shepp_logan

class IterativeRecoModel(nn.Module):
    def __init__(self, volume_shape, angles, det_u, det_v, du, dv, sdd, sid, voxel_spacing):
        super().__init__()
        # Learnable 3D volume, initialized to zeros
        self.reco = nn.Parameter(torch.zeros(volume_shape))
        self.angles = angles
        self.det_u = det_u
        self.det_v = det_v
        self.du = du
        self.dv = dv
        self.sdd = sdd
        self.sid = sid
        self.relu = nn.ReLU()  # non-negativity constraint
        self.voxel_spacing = voxel_spacing

    def forward(self, x):
        updated_reco = x + self.reco
        # Differentiable cone beam forward projection of the current estimate
        current_sino = ConeProjectorFunction.apply(updated_reco,
                                                   self.angles,
                                                   self.det_u, self.det_v,
                                                   self.du, self.dv,
                                                   self.sdd, self.sid, self.voxel_spacing)
        return current_sino, self.relu(updated_reco)

class Pipeline:
    def __init__(self, lr, volume_shape, angles,
                 det_u, det_v, du, dv,
                 sdd, sid, voxel_spacing,
                 device, epochs=1000):

        self.epochs = epochs
        self.model = IterativeRecoModel(volume_shape, angles,
                                        det_u, det_v, du, dv,
                                        sdd, sid, voxel_spacing).to(device)

        self.optimizer = optim.AdamW(list(self.model.parameters()), lr=lr)
        self.loss = nn.MSELoss()

    def train(self, input, label):
        loss_values = []
        for epoch in range(self.epochs):
            self.optimizer.zero_grad()
            predictions, current_reco = self.model(input)
            loss_value = self.loss(predictions, label)
            loss_value.backward()
            self.optimizer.step()
            loss_values.append(loss_value.item())

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss_value.item()}")

        return loss_values, self.model

def main():
    Nx, Ny, Nz = 64, 64, 64
    phantom_cpu = shepp_logan_3d((Nz, Ny, Nx))

    num_views = 180
    angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)

    det_u, det_v = 128, 128
    du, dv = 1.0, 1.0
    voxel_spacing = 1.0
    sdd = 600.0   # source-to-detector distance
    sid = 400.0   # source-to-isocenter distance

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32).contiguous()

    # Generate the "measured" sinogram from the phantom
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
    real_sinogram = ConeProjectorFunction.apply(phantom_torch, angles_torch,
                                                det_u, det_v, du, dv,
                                                sdd, sid, voxel_spacing)

    pipeline_instance = Pipeline(lr=1e-1,
                                 volume_shape=(Nz, Ny, Nx),
                                 angles=angles_torch,
                                 det_u=det_u, det_v=det_v,
                                 du=du, dv=dv, voxel_spacing=voxel_spacing,
                                 sdd=sdd,
                                 sid=sid,
                                 device=device, epochs=1000)

    ini_guess = torch.zeros_like(phantom_torch)

    loss_values, trained_model = pipeline_instance.train(ini_guess, real_sinogram)

    reco = trained_model(ini_guess)[1].squeeze().cpu().detach().numpy()

    plt.figure()
    plt.plot(loss_values)
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    mid_slice = Nz // 2
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(phantom_cpu[mid_slice, :, :], cmap="gray")
    plt.title("Original Phantom Mid-Slice")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reco[mid_slice, :, :], cmap="gray")
    plt.title("Reconstructed Mid-Slice")
    plt.axis("off")
    plt.show()

if __name__ == "__main__":
    main()