Cone Beam Iterative Reconstruction
This example demonstrates gradient-based iterative reconstruction for 3D cone beam CT using the differentiable ConeProjectorFunction from diffct.
Overview
3D cone beam iterative reconstruction extends optimization methods to full volumetric reconstruction. This example shows how to:
Formulate 3D cone beam CT reconstruction as a large-scale optimization problem
Handle the computational complexity of 3D forward and backward projections
Apply memory-efficient optimization strategies for volumetric data
Monitor convergence in high-dimensional parameter space
Mathematical Background
3D Cone Beam Iterative Formulation
The 3D reconstruction problem is formulated as:
\[\hat{f} = \arg\min_{f} \; \|A_{\text{cone}} f - p\|_2^2 + \lambda R(f)\]
where:
- \(f(x,y,z)\) is the unknown 3D volume
- \(A_{\text{cone}}\) is the cone beam forward projection operator
- \(p(\phi, u, v)\) is the measured 2D projection data
- \(R(f)\) is an optional 3D regularization term, weighted by \(\lambda \ge 0\)
3D Forward Projection
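The cone beam forward projection integrates along rays through the 3D volume:
\[p(\phi, u, v) = \int_{0}^{\infty} f\big(\vec{r}_s(\phi) + t\,\vec{d}(\phi, u, v)\big)\,dt\]
where \(\vec{r}_s(\phi)\) is the source position, \(\vec{d}(\phi, u, v)\) is the ray direction vector, and \(t\) parameterizes position along the ray.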
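The ray integral described above can be approximated numerically by sampling the volume at regular steps along one ray. The following is a minimal NumPy sketch, not the diffct kernel: the function name `ray_integral`, the nearest-neighbour sampling, and the convention that voxel centres sit at integer `(z, y, x)` coordinates are all assumptions made for illustration.

```python
import numpy as np

def ray_integral(volume, source, direction, step=0.25, n_steps=2000):
    """Approximate the line integral of `volume` along source + t * direction.

    Assumptions (illustrative only): coordinates are ordered (z, y, x),
    voxel centres sit at integer coordinates, and the volume is sampled
    with nearest-neighbour lookup using a midpoint rule along the ray.
    """
    direction = np.asarray(direction, dtype=np.float64)
    direction = direction / np.linalg.norm(direction)  # unit ray direction
    t = (np.arange(n_steps) + 0.5) * step              # midpoints of each step
    points = np.asarray(source, dtype=np.float64) + t[:, None] * direction
    idx = np.rint(points).astype(int)                  # nearest voxel index
    shape = np.array(volume.shape)
    inside = np.all((idx >= 0) & (idx < shape), axis=1)
    samples = volume[idx[inside, 0], idx[inside, 1], idx[inside, 2]]
    return float(samples.sum() * step)

# A ray travelling along x through a uniform 8x8x8 cube of ones
# should accumulate a path length of about 8 voxels.
cube = np.ones((8, 8, 8))
path_length = ray_integral(cube, source=(3.0, 3.0, -20.0), direction=(0.0, 0.0, 1.0))
```

A production projector would typically use a more careful interpolation or exact voxel intersection lengths; this sketch only illustrates the integral itself.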
3D Gradient Computation
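The gradient of the 3D loss function uses the cone beam backprojection operator. For the data term \(L(f) = \|A_{\text{cone}} f - p\|_2^2\):
\[\nabla_f L = 2\,A_{\text{cone}}^T \left(A_{\text{cone}} f - p\right)\]
where \(A_{\text{cone}}^T\) is the 3D cone beam backprojection operator (adjoint).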
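For a squared L2 data term \(\|A f - p\|_2^2\) the gradient is \(2 A^T (A f - p)\), which is why the adjoint (backprojection) appears. The sketch below checks this identity against central finite differences. It assumes a small random dense matrix `A` as a hypothetical stand-in for the cone beam operator; it does not call diffct.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(12, 6))   # hypothetical dense stand-in for A_cone
p = rng.normal(size=12)        # stand-in for measured projections
f = rng.normal(size=6)         # stand-in for the flattened volume

def data_loss(f):
    """Squared L2 data term ||A f - p||^2."""
    r = A @ f - p
    return float(r @ r)

def analytic_grad(f):
    """Gradient via the adjoint: 2 A^T (A f - p)."""
    return 2.0 * A.T @ (A @ f - p)

def numeric_grad(f, h=1e-6):
    """Central finite differences, one coordinate at a time."""
    g = np.zeros_like(f)
    for i in range(f.size):
        e = np.zeros_like(f)
        e[i] = h
        g[i] = (data_loss(f + e) - data_loss(f - e)) / (2.0 * h)
    return g

max_err = float(np.max(np.abs(analytic_grad(f) - numeric_grad(f))))
```

Note that a mean squared error loss differs from this sum of squares only by a constant factor (one over the number of measurements), so the gradient direction is the same.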
Computational Complexity
3D reconstruction presents significant computational challenges:
Memory Requirements: \(O(N^3)\) for volume storage vs \(O(N^2)\) for 2D images
Projection Data: \(O(N_{\phi} \times N_u \times N_v)\) 2D projections
Forward/Backward Operations: \(O(N^3 \times N_{\phi})\) computational complexity
Gradient Storage: Automatic differentiation keeps a gradient tensor the same size as the volume, and optimizers such as AdamW add further per-voxel state
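To make these scaling terms concrete, the following sketch computes raw tensor sizes for the dimensions used in this page's code example (a 64×64×64 volume and 180 views on a 128×128 detector) and for a larger 512³ volume. The helper name `tensor_mib` is hypothetical, and the figures cover only the tensors themselves, not autograd or optimizer overhead.

```python
def tensor_mib(shape, bytes_per_element=4):
    """Size of a dense tensor in MiB (float32 by default)."""
    n_elements = 1
    for s in shape:
        n_elements *= s
    return n_elements * bytes_per_element / 2**20

volume_mib = tensor_mib((64, 64, 64))            # 1.0 MiB
projections_mib = tensor_mib((180, 128, 128))    # 11.25 MiB
large_volume_mib = tensor_mib((512, 512, 512))   # 512.0 MiB
large_volume_fp16_mib = tensor_mib((512, 512, 512), bytes_per_element=2)  # 256.0 MiB
```

Doubling the side length of the volume multiplies its storage by eight, which is the practical meaning of the \(O(N^3)\) term.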
Implementation Steps
3D Problem Setup: Define parameterized 3D volume as learnable tensor
Cone Beam Forward Model: Use ConeProjectorFunction for 2D projection prediction
Loss Computation: Calculate the mean squared error (a scaled squared L2 distance) between predicted and measured projections
3D Gradient Computation: Use automatic differentiation through cone beam operators
Memory-Efficient Optimization: Apply strategies to handle large 3D parameter space
Convergence Monitoring: Track loss and 3D reconstruction quality
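The steps above can be sketched on a toy problem without a GPU. The example below is an illustration under stated assumptions: a small random dense matrix `A` stands in for the cone beam projector, the "volume" has 27 voxels, and plain projected gradient descent with a non-negativity constraint replaces the AdamW optimizer used in the full example on this page.

```python
import numpy as np

rng = np.random.default_rng(0)
n_vox, n_meas = 27, 60                           # a 3x3x3 "volume", 60 measurements
A = rng.uniform(0.0, 1.0, size=(n_meas, n_vox))  # hypothetical stand-in for A_cone
f_true = rng.uniform(0.0, 1.0, size=n_vox)       # ground-truth volume (flattened)
p = A @ f_true                                   # noise-free "measured" projections

f = np.zeros(n_vox)                              # step 1: parameterized volume
# Step size 1/L, where L = 2 * sigma_max(A)^2 is the Lipschitz constant
# of the gradient 2 A^T (A f - p).
step = 1.0 / (2.0 * np.linalg.norm(A, 2) ** 2)

losses = []
for _ in range(500):
    residual = A @ f - p                         # steps 2-3: forward model and residual
    losses.append(float(residual @ residual))    # step 6: track the loss
    grad = 2.0 * A.T @ residual                  # step 4: gradient via the adjoint
    f = np.maximum(f - step * grad, 0.0)         # step 5: update, then project onto f >= 0
```

With this step size the recorded loss should not increase from one iteration to the next, which makes the toy problem a convenient sanity check before moving to the real projector.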
Model Architecture
The 3D reconstruction model consists of:
Parameterized Volume: Learnable 3D tensor representing the unknown volume
Cone Beam Forward Model: ConeProjectorFunction with 3D geometry parameters
Loss Function: Mean squared error between predicted and measured 2D projections
3D Regularization Options
Common 3D regularization terms:
3D Total Variation: \(R_{\text{TV}}(f) = \sum_{x,y,z} \|\nabla f(x,y,z)\|_2\)
3D Smoothness: \(R_{\text{smooth}}(f) = \sum_{x,y,z} \|\nabla f(x,y,z)\|_2^2\)
L1 Sparsity: \(R_{\text{L1}}(f) = \sum_{x,y,z} |f(x,y,z)|\)
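The three penalties can be evaluated with forward finite differences as the discrete gradient. This is a NumPy sketch for computing the values only; the function names are hypothetical, and the boundary handling (replicating the last sample, so the difference at the far edge is zero) is one assumption among several possible choices. Inside a gradient-based optimizer a differentiable implementation would be needed, and the TV square root usually gets a small epsilon to stay differentiable at zero.

```python
import numpy as np

def forward_gradient(vol):
    """Forward differences along z, y, x with a replicated far boundary."""
    dz = np.diff(vol, axis=0, append=vol[-1:, :, :])
    dy = np.diff(vol, axis=1, append=vol[:, -1:, :])
    dx = np.diff(vol, axis=2, append=vol[:, :, -1:])
    return dz, dy, dx

def tv_3d(vol):
    """Isotropic total variation: sum of gradient magnitudes."""
    dz, dy, dx = forward_gradient(vol)
    return float(np.sum(np.sqrt(dz**2 + dy**2 + dx**2)))

def smoothness_3d(vol):
    """Quadratic smoothness: sum of squared gradient magnitudes."""
    dz, dy, dx = forward_gradient(vol)
    return float(np.sum(dz**2 + dy**2 + dx**2))

def l1_sparsity(vol):
    """L1 norm of the voxel values."""
    return float(np.sum(np.abs(vol)))

# A single bright voxel in an otherwise empty 4x4x4 volume.
vol = np.zeros((4, 4, 4))
vol[1, 1, 1] = 1.0
tv_value = tv_3d(vol)                  # 3 + sqrt(3)
smooth_value = smoothness_3d(vol)      # 6.0
l1_value = l1_sparsity(vol)            # 1.0
```

The single-voxel case shows the difference in behaviour: the quadratic penalty charges the sharp edge more heavily than TV does, which is why TV tends to preserve edges better.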
Memory Management Strategies
3D reconstruction requires careful memory management:
Gradient Checkpointing: Trade computation for memory in backpropagation
Mixed Precision: Use float16 when possible to reduce memory usage
Batch Processing: Process volume slices when memory is extremely limited
Efficient Data Layout: Optimize tensor storage and access patterns
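As an illustration of the gradient checkpointing strategy, the sketch below uses `torch.utils.checkpoint.checkpoint` on a placeholder computation and confirms that the gradients match the non-checkpointed version. The function `expensive_stage` is hypothetical and merely stands in for a memory-hungry stage; whether checkpointing helps with, or is compatible with, `ConeProjectorFunction` is not verified here.

```python
import torch
from torch.utils.checkpoint import checkpoint

def expensive_stage(x, weight):
    # Placeholder for a stage whose intermediate activations are costly to keep.
    return torch.tanh(x @ weight)

torch.manual_seed(0)
x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(16, 16, requires_grad=True)

# Reference: ordinary forward pass, activations stored for backward.
loss_plain = expensive_stage(x, w).sum()
grads_plain = torch.autograd.grad(loss_plain, (x, w))

# Checkpointed: activations are discarded and recomputed during backward.
loss_ckpt = checkpoint(expensive_stage, x, w, use_reentrant=False).sum()
grads_ckpt = torch.autograd.grad(loss_ckpt, (x, w))

grads_match = all(torch.allclose(a, b) for a, b in zip(grads_plain, grads_ckpt))
```

The trade-off is one extra forward evaluation of the checkpointed stage per backward pass in exchange for not storing its intermediates.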
Convergence Characteristics
3D cone beam reconstruction typically exhibits:
Initial Convergence (0-100 iterations): Rapid loss decrease, basic 3D structure emerges
Detail Refinement (100-500 iterations): Fine 3D features develop progressively
Final Convergence (500+ iterations): Slow improvement, with a risk of overfitting to noise in the measurements
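One simple way to detect the slow final phase is to stop when the loss has improved by less than a relative tolerance over a recent window of iterations. The helper below is a hypothetical sketch; the name `has_plateaued` and the default `window` and `rel_tol` values are illustrative choices, not settings taken from diffct.

```python
def has_plateaued(loss_history, window=50, rel_tol=1e-3):
    """Return True when the relative loss improvement over the last
    `window` iterations is smaller than `rel_tol`."""
    if len(loss_history) <= window:
        return False                      # not enough history yet
    old = loss_history[-window - 1]
    new = loss_history[-1]
    if old == 0.0:
        return True                       # loss is already zero
    return (old - new) / abs(old) < rel_tol

# Usage: a still-decreasing loss versus a flat one.
decreasing = [1.0 / (k + 1) for k in range(200)]
flat = [1.0] * 100
still_improving = not has_plateaued(decreasing)
stalled = has_plateaued(flat)
```

In a training loop this check could be called after each loss value is appended, breaking out of the loop once it returns True.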
Challenges in 3D Reconstruction
Cone Beam Artifacts: Increased artifacts for large cone angles in 3D
Incomplete Sampling: Missing data in certain regions of 3D Fourier space
Computational Cost: Orders of magnitude higher than 2D reconstruction
Memory Limitations: Large volumes may exceed available GPU memory
Convergence Complexity: Higher-dimensional optimization landscape
Applications
3D cone beam iterative reconstruction is essential for:
Medical CBCT: Dental, orthopedic, and interventional imaging
Industrial CT: Non-destructive testing and quality control
Micro-CT: High-resolution imaging of small specimens and materials
Security Screening: Advanced baggage and cargo inspection systems
Code Example
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from diffct.differentiable import ConeProjectorFunction

def shepp_logan_3d(shape):
    zz, yy, xx = np.mgrid[:shape[0], :shape[1], :shape[2]]
    xx = (xx - (shape[2] - 1) / 2) / ((shape[2] - 1) / 2)
    yy = (yy - (shape[1] - 1) / 2) / ((shape[1] - 1) / 2)
    zz = (zz - (shape[0] - 1) / 2) / ((shape[0] - 1) / 2)
    el_params = np.array([
        [0, 0, 0, 0.69, 0.92, 0.81, 0, 0, 0, 1],
        [0, -0.0184, 0, 0.6624, 0.874, 0.78, 0, 0, 0, -0.8],
        [0.22, 0, 0, 0.11, 0.31, 0.22, -np.pi/10.0, 0, 0, -0.2],
        [-0.22, 0, 0, 0.16, 0.41, 0.28, np.pi/10.0, 0, 0, -0.2],
        [0, 0.35, -0.15, 0.21, 0.25, 0.41, 0, 0, 0, 0.1],
        [0, 0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [0, -0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [-0.08, -0.605, 0, 0.046, 0.023, 0.05, 0, 0, 0, 0.1],
        [0, -0.605, 0, 0.023, 0.023, 0.02, 0, 0, 0, 0.1],
        [0.06, -0.605, 0, 0.023, 0.046, 0.02, 0, 0, 0, 0.1],
    ], dtype=np.float32)

    # Extract parameters for vectorization
    x_pos = el_params[:, 0][:, None, None, None]
    y_pos = el_params[:, 1][:, None, None, None]
    z_pos = el_params[:, 2][:, None, None, None]
    a_axis = el_params[:, 3][:, None, None, None]
    b_axis = el_params[:, 4][:, None, None, None]
    c_axis = el_params[:, 5][:, None, None, None]
    phi = el_params[:, 6][:, None, None, None]
    val = el_params[:, 9][:, None, None, None]

    # Broadcast grid to ellipsoid axis
    xc = xx[None, ...] - x_pos
    yc = yy[None, ...] - y_pos
    zc = zz[None, ...] - z_pos

    c = np.cos(phi)
    s = np.sin(phi)

    # Only rotation around z, so can vectorize:
    xp = c * xc - s * yc
    yp = s * xc + c * yc
    zp = zc

    mask = (
        (xp ** 2) / (a_axis ** 2)
        + (yp ** 2) / (b_axis ** 2)
        + (zp ** 2) / (c_axis ** 2)
        <= 1.0
    )

    # Use broadcasting to sum all ellipsoid contributions
    shepp_logan = np.sum(mask * val, axis=0)
    shepp_logan = np.clip(shepp_logan, 0, 1)
    return shepp_logan

class IterativeRecoModel(nn.Module):
    def __init__(self, volume_shape, angles, det_u, det_v, du, dv, sdd, sid, voxel_spacing):
        super().__init__()
        self.reco = nn.Parameter(torch.zeros(volume_shape))
        self.angles = angles
        self.det_u = det_u
        self.det_v = det_v
        self.du = du
        self.dv = dv
        self.sdd = sdd
        self.sid = sid
        self.relu = nn.ReLU()  # non negative constraint
        self.voxel_spacing = voxel_spacing

    def forward(self, x):
        updated_reco = x + self.reco
        current_sino = ConeProjectorFunction.apply(updated_reco,
                                                   self.angles,
                                                   self.det_u, self.det_v,
                                                   self.du, self.dv,
                                                   self.sdd, self.sid, self.voxel_spacing)
        return current_sino, self.relu(updated_reco)

class Pipeline:
    def __init__(self, lr, volume_shape, angles,
                 det_u, det_v, du, dv,
                 sdd, sid, voxel_spacing,
                 device, epoches=1000):

        self.epoches = epoches
        self.model = IterativeRecoModel(volume_shape, angles,
                                        det_u, det_v, du, dv,
                                        sdd, sid, voxel_spacing).to(device)

        self.optimizer = optim.AdamW(list(self.model.parameters()), lr=lr)
        self.loss = nn.MSELoss()

    def train(self, input, label):
        loss_values = []
        for epoch in range(self.epoches):
            self.optimizer.zero_grad()
            predictions, current_reco = self.model(input)
            loss_value = self.loss(predictions, label)
            loss_value.backward()
            self.optimizer.step()
            loss_values.append(loss_value.item())

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss_value.item()}")

        return loss_values, self.model

def main():
    Nx, Ny, Nz = 64, 64, 64
    phantom_cpu = shepp_logan_3d((Nz, Ny, Nx))

    num_views = 180
    angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)

    det_u, det_v = 128, 128
    du, dv = 1.0, 1.0
    voxel_spacing = 1.0
    sdd = 600.0
    sid = 400.0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32).contiguous()

    # Generate the "real" sinogram
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
    real_sinogram = ConeProjectorFunction.apply(phantom_torch, angles_torch,
                                                det_u, det_v, du, dv,
                                                sdd, sid, voxel_spacing)

    pipeline_instance = Pipeline(lr=1e-1,
                                 volume_shape=(Nz, Ny, Nx),
                                 angles=angles_torch,
                                 det_u=det_u, det_v=det_v,
                                 du=du, dv=dv, voxel_spacing=voxel_spacing,
                                 sdd=sdd,
                                 sid=sid,
                                 device=device, epoches=1000)

    ini_guess = torch.zeros_like(phantom_torch)

    loss_values, trained_model = pipeline_instance.train(ini_guess, real_sinogram)

    reco = trained_model(ini_guess)[1].squeeze().cpu().detach().numpy()

    plt.figure()
    plt.plot(loss_values)
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    mid_slice = Nz // 2
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(phantom_cpu[mid_slice, :, :], cmap="gray")
    plt.title("Original Phantom Mid-Slice")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reco[mid_slice, :, :], cmap="gray")
    plt.title("Reconstructed Mid-Slice")
    plt.axis("off")
    plt.show()

if __name__ == "__main__":
    main()