Cone Beam Iterative Reconstruction
This example demonstrates gradient-based iterative reconstruction for 3D cone beam CT using the differentiable ConeProjectorFunction from diffct.
Overview
3D cone beam iterative reconstruction extends optimization methods to full volumetric reconstruction. This example shows how to:
Formulate 3D cone beam CT reconstruction as a large-scale optimization problem
Handle the computational complexity of 3D forward and backward projections
Apply memory-efficient optimization strategies for volumetric data
Monitor convergence in high-dimensional parameter space
Mathematical Background
3D Cone Beam Iterative Formulation
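The 3D reconstruction problem is formulated as:
\[\hat{f} = \arg\min_{f} \; \frac{1}{2}\left\| A_{\text{cone}} f - p \right\|_2^2 + \lambda R(f)\]
where:
- \(f(x,y,z)\) is the unknown 3D volume
- \(A_{\text{cone}}\) is the cone beam forward projection operator
- \(p(\phi, u, v)\) is the measured 2D projection data
- \(R(f)\) is an optional 3D regularization term
- \(\lambda \ge 0\) weights the regularization term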
3D Forward Projection
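The cone beam forward projection integrates along rays through the 3D volume:
\[p(\phi, u, v) = \int_{0}^{\infty} f\big(\vec{r}_s(\phi) + t\,\vec{d}(\phi, u, v)\big)\,dt\]
where \(\vec{r}_s(\phi)\) is the source position at projection angle \(\phi\) and \(\vec{d}(\phi, u, v)\) is the unit direction vector of the ray from the source to detector pixel \((u, v)\).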
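To make the line integral concrete, here is a minimal NumPy sketch that evaluates \(p(\phi, u, v)\) for a single ray by sampling \(f\) along \(\vec{r}_s(\phi) + t\,\vec{d}\) with the midpoint rule. The geometry convention (source circling the z-axis at radius `sid`, a flat detector at distance `sdd` from the source) and the helper names `source_position`, `ray_direction`, `line_integral`, and `ball` are illustrative assumptions, not diffct's internals; diffct's backends use their own discretizations (Siddon, separable footprint) rather than point sampling.

```python
import numpy as np

# Assumed geometry (illustrative only): the source orbits the z-axis in the
# xy-plane at radius sid; a flat detector sits at distance sdd from the source,
# perpendicular to the central ray, with its centre on that ray.

def source_position(phi, sid):
    return np.array([sid * np.cos(phi), sid * np.sin(phi), 0.0])

def ray_direction(phi, u, v, sid, sdd):
    """Unit vector from the source towards detector coordinate (u, v)."""
    s = source_position(phi, sid)
    to_iso = -np.array([np.cos(phi), np.sin(phi), 0.0])  # points at the origin
    e_u = np.array([-np.sin(phi), np.cos(phi), 0.0])     # transaxial detector axis
    e_v = np.array([0.0, 0.0, 1.0])                       # axial detector axis
    pixel = s + sdd * to_iso + u * e_u + v * e_v
    d = pixel - s
    return d / np.linalg.norm(d)

def line_integral(f, phi, u, v, sid, sdd, dt=0.05):
    """Midpoint-rule approximation of the integral of f(r_s + t d) dt for t in [0, sdd]."""
    s = source_position(phi, sid)
    d = ray_direction(phi, u, v, sid, sdd)
    t = np.arange(0.0, sdd, dt) + 0.5 * dt
    points = s[None, :] + t[:, None] * d[None, :]
    return float(np.sum(f(points)) * dt)

def ball(points, radius=20.0):
    """Uniform sphere of value 1 centred at the origin."""
    return (np.linalg.norm(points, axis=1) <= radius).astype(np.float64)

# The central ray crosses the full diameter, so its integral is close to 2 * 20.
center = line_integral(ball, 0.3, 0.0, 0.0, sid=400.0, sdd=600.0)
# u = 100 on the detector passes well over 20 units from the origin and misses.
miss = line_integral(ball, 0.3, 100.0, 0.0, sid=400.0, sdd=600.0)
print(center, miss)
```

Because the sphere's value is 1, the integral is simply the chord length through it, which makes the result easy to check by hand.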
3D Gradient Computation
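The gradient of the 3D loss \(L(f) = \tfrac{1}{2}\|A_{\text{cone}} f - p\|_2^2 + \lambda R(f)\) uses the cone beam backprojection operator:
\[\nabla_f L = A_{\text{cone}}^T\left(A_{\text{cone}} f - p\right) + \lambda \nabla_f R(f)\]
where \(A_{\text{cone}}^T\) is the 3D cone beam backprojection operator (adjoint). In the example, autograd applies this operator in the backward pass of ConeProjectorFunction.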
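The adjoint relationship can be checked numerically on a small problem. The sketch below substitutes a random dense matrix for \(A_{\text{cone}}\) (an assumption made only to keep it self-contained) and confirms three things: autograd's gradient of \(\tfrac{1}{2}\|A f - p\|_2^2\) equals \(A^T(Af - p)\); `nn.MSELoss`, which the full example uses, gives the same backprojection scaled by \(2/M\) for \(M\) measurements; and the dot-product test \(\langle A x, y\rangle = \langle x, A^T y\rangle\), the kind of inner-product check used to validate a projector/backprojector pair.

```python
import torch

torch.manual_seed(0)

# A small dense matrix stands in for A_cone (assumption: any linear operator
# behaves the same way under autograd).
A = torch.randn(30, 20, dtype=torch.float64)
p = torch.randn(30, dtype=torch.float64)
f = torch.randn(20, dtype=torch.float64, requires_grad=True)

# Data term 0.5 * ||A f - p||^2; autograd should return A^T (A f - p).
loss = 0.5 * torch.sum((A @ f - p) ** 2)
loss.backward()
analytic = A.T @ (A @ f.detach() - p)

# MSE averages over the M = 30 measurements, so its gradient is the same
# backprojection scaled by 2 / M.
g = f.detach().clone().requires_grad_(True)
torch.nn.functional.mse_loss(A @ g, p).backward()

# Adjoint (dot-product) test: <A x, y> should equal <x, A^T y>.
x = torch.randn(20, dtype=torch.float64)
y = torch.randn(30, dtype=torch.float64)
lhs, rhs = torch.dot(A @ x, y), torch.dot(x, A.T @ y)

print(torch.allclose(f.grad, analytic),
      torch.allclose(g.grad, (2.0 / 30) * analytic),
      torch.allclose(lhs, rhs))
```

The \(2/M\) factor is why learning rates tuned for a mean-reduced loss do not transfer directly to a sum-reduced one.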
Computational Complexity
3D reconstruction presents significant computational challenges:
Memory Requirements: \(O(N^3)\) for volume storage vs \(O(N^2)\) for 2D images
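Projection Data: \(O(N_{\phi} \times N_u \times N_v)\) values spread over \(N_{\phi}\) 2D projections
Forward/Backward Operations: roughly \(O(N^3 \times N_{\phi})\) operations per forward or backward projection
Gradient Storage: a volume-sized gradient buffer for automatic differentiation, plus any optimizer state (AdamW keeps two further volume-sized buffers)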
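The memory terms above can be estimated with simple arithmetic. This sketch (the function name `estimate_memory_mib` is hypothetical) counts only the main float32 tensors of an AdamW-based loop and ignores temporaries, autograd bookkeeping, and framework overhead, so the numbers are lower bounds rather than a full budget.

```python
def estimate_memory_mib(nx, ny, nz, n_views, n_u, n_v, bytes_per_value=4):
    """Rough memory (MiB) of the main tensors in a gradient-based loop."""
    mib = 1024 ** 2
    volume = nx * ny * nz * bytes_per_value / mib
    sinogram = n_views * n_u * n_v * bytes_per_value / mib
    return {
        "volume": volume,
        # parameter + its gradient + AdamW's two moment buffers
        "volume_with_adamw_state": 4 * volume,
        # measured projections + projections predicted in the forward pass
        "sinograms": 2 * sinogram,
    }

# Sizes taken from the code example: 64^3 volume, 180 views of 128 x 128.
budget = estimate_memory_mib(64, 64, 64, 180, 128, 128)
print(budget)
```

For the example's sizes this gives 1 MiB for the volume (4 MiB with gradient and AdamW state) but 22.5 MiB for the two sinograms, so at this resolution the projection data dominates. At \(512^3\) the volume alone reaches 512 MiB, or 2 GiB with gradient and optimizer state.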
Implementation Steps
3D Problem Setup: Define parameterized 3D volume as learnable tensor
Cone Beam Forward Model: Use ConeProjectorFunction for 2D projection prediction
Loss Computation: Calculate L2 distance between predicted and measured projections
3D Gradient Computation: Use automatic differentiation through cone beam operators
Memory-Efficient Optimization: Apply strategies to handle large 3D parameter space
Convergence Monitoring: Track loss and 3D reconstruction quality
Model Architecture
The 3D reconstruction model consists of:
Parameterized Volume: Learnable 3D tensor representing the unknown volume
Cone Beam Forward Model: ConeProjectorFunction with 3D geometry parameters
Loss Function: Mean squared error between predicted and measured 2D projections
3D Regularization Options
Common 3D regularization terms:
3D Total Variation: \(R_{\text{TV}}(f) = \sum_{x,y,z} \|\nabla f(x,y,z)\|_2\)
3D Smoothness: \(R_{\text{smooth}}(f) = \sum_{x,y,z} \|\nabla f(x,y,z)\|_2^2\)
L1 Sparsity: \(R_{\text{L1}}(f) = \sum_{x,y,z} |f(x,y,z)|\)
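The three regularizers can be written with forward differences in PyTorch so that autograd supplies \(\nabla_f R(f)\). In this sketch the differences are cropped to a common \((N-1)^3\) grid and a small `eps` smooths the TV square root at zero; both are choices of the sketch, not part of diffct, and the function names are hypothetical.

```python
import torch

def forward_differences(f):
    """Forward differences along z, y, x, cropped to a common (N-1)^3 grid."""
    base = f[:-1, :-1, :-1]
    dz = f[1:, :-1, :-1] - base
    dy = f[:-1, 1:, :-1] - base
    dx = f[:-1, :-1, 1:] - base
    return dz, dy, dx

def tv_3d(f, eps=1e-8):
    """Isotropic 3D total variation; eps keeps sqrt differentiable at zero."""
    dz, dy, dx = forward_differences(f)
    return torch.sqrt(dz ** 2 + dy ** 2 + dx ** 2 + eps).sum()

def smoothness_3d(f):
    """Sum of squared gradient magnitudes (quadratic smoothness)."""
    dz, dy, dx = forward_differences(f)
    return (dz ** 2 + dy ** 2 + dx ** 2).sum()

def l1_sparsity(f):
    return f.abs().sum()

# A 4x4x4 volume (z, y, x) with a single step edge along z.
vol = torch.zeros(4, 4, 4)
vol[2:] = 1.0
print(tv_3d(vol).item(), smoothness_3d(vol).item(), l1_sparsity(vol).item())
```

To use one of these in the full example, a hypothetical weight `lam` could be added to the data term inside `Pipeline.train`, for instance `loss_value = self.loss(predictions, label) + lam * tv_3d(self.model.reco)`; the volume there already has the (z, y, x) layout assumed here.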
Memory Management Strategies
3D reconstruction requires careful memory management:
Gradient Checkpointing: Trade computation for memory in backpropagation
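Mixed Precision: Use float16 where the projector kernels and the required numerical accuracy allow it, to reduce memory usage
Batch Processing: Process subsets of projection views (or volume slabs) when memory is extremely limited; in cone beam geometry oblique rays cross many axial slices, so slabs are not independent the way 2D slices are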
Efficient Data Layout: Optimize tensor storage and access patterns
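Batch processing over projection views works because the data term is a sum over views, so gradients from view subsets add up to the full gradient while only one subset of predicted projections is held in memory at a time. The sketch below checks this with a toy linear projector; `accumulate_view_gradients`, `toy_project`, and the chunk size are hypothetical. With diffct one would pass a slice of the angle tensor to the projector for each chunk, assuming ConeProjectorFunction accepts an arbitrary subset of angles.

```python
import torch

def accumulate_view_gradients(volume, angles, measured, project, chunk=30):
    """Accumulate the gradient of 0.5 * ||A f - p||^2 over chunks of views.

    `project(volume, angles_subset)` is a placeholder for any differentiable
    forward projector returning one projection per angle (assumption).
    """
    total = 0.0
    for start in range(0, len(angles), chunk):
        sl = slice(start, start + chunk)
        residual = project(volume, angles[sl]) - measured[sl]
        loss = 0.5 * torch.sum(residual ** 2)
        loss.backward()  # adds this chunk's contribution into volume.grad
        total += loss.item()
    return total

# Toy linear projector: one random (detector x voxel) matrix per view index.
torch.manual_seed(0)
n_views, n_det, n_vox = 12, 5, 8
W = torch.randn(n_views, n_det, n_vox, dtype=torch.float64)
measured = torch.randn(n_views, n_det, dtype=torch.float64)
views = torch.arange(n_views)

def toy_project(vol, idx):
    return torch.einsum("kdn,n->kd", W[idx], vol)

vol = torch.randn(n_vox, dtype=torch.float64, requires_grad=True)
accumulate_view_gradients(vol, views, measured, toy_project, chunk=5)
chunked_grad = vol.grad.clone()

# Reference: the gradient of the same loss over all views at once.
vol.grad = None
full = 0.5 * torch.sum((toy_project(vol, views) - measured) ** 2)
full.backward()
print(torch.allclose(chunked_grad, vol.grad))
```

In a training loop, gradients are zeroed once before the chunks and the optimizer steps once after them. The equality relies on a sum-reduced loss; with a mean-reduced loss such as `nn.MSELoss`, each chunk would need rescaling by its share of the measurements.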
Convergence Characteristics
3D cone beam reconstruction typically exhibits:
Initial Convergence (0-100 iterations): Rapid loss decrease, basic 3D structure emerges
Detail Refinement (100-500 iterations): Fine 3D features develop progressively
Final Convergence (500+ iterations): Slow improvement, with a growing risk of fitting noise in the measured projections
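Two small helpers can make these phases visible: a plateau test on the loss history and, in simulation where the phantom is known, a voxel-wise RMSE. Both `has_plateaued` and `rmse`, along with the window and tolerance defaults, are illustrative choices rather than part of the example.

```python
import numpy as np

def rmse(reco, reference):
    """Root-mean-square error over all voxels (needs a ground-truth volume)."""
    diff = np.asarray(reco, dtype=np.float64) - np.asarray(reference, dtype=np.float64)
    return float(np.sqrt(np.mean(diff ** 2)))

def has_plateaued(losses, window=50, rel_tol=1e-3):
    """True if the mean loss over the last `window` iterations improved by
    less than `rel_tol` (relative) compared with the preceding window."""
    if len(losses) < 2 * window:
        return False
    previous = np.mean(losses[-2 * window:-window])
    recent = np.mean(losses[-window:])
    return bool((previous - recent) / max(previous, 1e-12) < rel_tol)

decreasing = [1.0 / (k + 1) for k in range(100)]
flat = [1.0] * 100
print(has_plateaued(decreasing, window=10), has_plateaued(flat, window=10))
```

In the full example, `loss_values` returned by `Pipeline.train` could be passed to `has_plateaued`, and `rmse(reco, phantom_cpu)` compares the reconstruction with the ground-truth phantom.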
Challenges in 3D Reconstruction
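Cone Beam Artifacts: Artifacts grow with cone angle, particularly in slices far from the plane of the source orbit
Incomplete Sampling: A circular source orbit does not satisfy Tuy's data-sufficiency condition away from the central plane, so part of 3D Fourier space is not measured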
Computational Cost: Orders of magnitude higher than 2D reconstruction
Memory Limitations: Large volumes may exceed available GPU memory
Convergence Complexity: Higher-dimensional optimization landscape
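Since artifacts scale with cone angle, it helps to compute the angle a given geometry produces. This sketch assumes a centred flat detector with no offsets; the function name `cone_geometry_summary` is hypothetical. For the code example's geometry (128 × 128 detector with 1.0 pixels, `sdd = 600`, `sid = 400`) it gives a full cone angle of roughly 12 degrees and a magnification of 1.5.

```python
import math

def cone_geometry_summary(det_u, det_v, du, dv, sdd, sid):
    """Fan/cone angles, magnification, and axial coverage for a centred flat detector."""
    half_width = 0.5 * det_u * du    # transaxial half-extent on the detector
    half_height = 0.5 * det_v * dv   # axial half-extent on the detector
    magnification = sdd / sid
    return {
        "full_fan_angle_deg": 2 * math.degrees(math.atan(half_width / sdd)),
        "full_cone_angle_deg": 2 * math.degrees(math.atan(half_height / sdd)),
        "magnification": magnification,
        # axial extent covered at the rotation axis, along the central ray
        "axial_fov_at_iso": 2 * half_height / magnification,
    }

summary = cone_geometry_summary(128, 128, 1.0, 1.0, sdd=600.0, sid=400.0)
print(summary)
```

Here the axial coverage at the rotation axis (about 85 units) exceeds the 64-voxel height of the example volume, so the phantom is fully covered along the central ray.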
Applications
3D cone beam iterative reconstruction is essential for:
Medical CBCT: Dental, orthopedic, and interventional imaging
Industrial CT: Non-destructive testing and quality control
Micro-CT: High-resolution imaging of small specimens and materials
Security Screening: Advanced baggage and cargo inspection systems
Code Example
```python
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from diffct.differentiable import ConeProjectorFunction

def shepp_logan_3d(shape):
    zz, yy, xx = np.mgrid[:shape[0], :shape[1], :shape[2]]
    xx = (xx - (shape[2] - 1) / 2) / ((shape[2] - 1) / 2)
    yy = (yy - (shape[1] - 1) / 2) / ((shape[1] - 1) / 2)
    zz = (zz - (shape[0] - 1) / 2) / ((shape[0] - 1) / 2)
    el_params = np.array([
        [0, 0, 0, 0.69, 0.92, 0.81, 0, 0, 0, 1],
        [0, -0.0184, 0, 0.6624, 0.874, 0.78, 0, 0, 0, -0.8],
        [0.22, 0, 0, 0.11, 0.31, 0.22, -np.pi/10.0, 0, 0, -0.2],
        [-0.22, 0, 0, 0.16, 0.41, 0.28, np.pi/10.0, 0, 0, -0.2],
        [0, 0.35, -0.15, 0.21, 0.25, 0.41, 0, 0, 0, 0.1],
        [0, 0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [0, -0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [-0.08, -0.605, 0, 0.046, 0.023, 0.05, 0, 0, 0, 0.1],
        [0, -0.605, 0, 0.023, 0.023, 0.02, 0, 0, 0, 0.1],
        [0.06, -0.605, 0, 0.023, 0.046, 0.02, 0, 0, 0, 0.1],
    ], dtype=np.float32)

    # Extract parameters for vectorization
    x_pos = el_params[:, 0][:, None, None, None]
    y_pos = el_params[:, 1][:, None, None, None]
    z_pos = el_params[:, 2][:, None, None, None]
    a_axis = el_params[:, 3][:, None, None, None]
    b_axis = el_params[:, 4][:, None, None, None]
    c_axis = el_params[:, 5][:, None, None, None]
    phi = el_params[:, 6][:, None, None, None]
    val = el_params[:, 9][:, None, None, None]

    # Broadcast grid to ellipsoid axis
    xc = xx[None, ...] - x_pos
    yc = yy[None, ...] - y_pos
    zc = zz[None, ...] - z_pos

    c = np.cos(phi)
    s = np.sin(phi)

    # Only rotation around z, so can vectorize:
    xp = c * xc - s * yc
    yp = s * xc + c * yc
    zp = zc

    mask = (
        (xp ** 2) / (a_axis ** 2)
        + (yp ** 2) / (b_axis ** 2)
        + (zp ** 2) / (c_axis ** 2)
        <= 1.0
    )

    # Use broadcasting to sum all ellipsoid contributions
    shepp_logan = np.sum(mask * val, axis=0)
    shepp_logan = np.clip(shepp_logan, 0, 1)
    return shepp_logan

class IterativeRecoModel(nn.Module):
    def __init__(self, volume_shape, angles, det_u, det_v, du, dv, sdd, sid,
                 voxel_spacing, backend="siddon"):
        super().__init__()
        self.reco = nn.Parameter(torch.zeros(volume_shape))
        self.angles = angles
        self.det_u = det_u
        self.det_v = det_v
        self.du = du
        self.dv = dv
        self.sdd = sdd
        self.sid = sid
        self.voxel_spacing = voxel_spacing
        self.backend = backend

    def forward(self, x):
        updated_reco = x + self.reco
        # ``backend`` is the last positional arg to
        # ``ConeProjectorFunction.apply`` so we also pass the five default
        # offsets (two detector offsets + three centre offsets) to line
        # up with the signature.
        current_sino = ConeProjectorFunction.apply(
            updated_reco,
            self.angles,
            self.det_u, self.det_v,
            self.du, self.dv,
            self.sdd, self.sid,
            self.voxel_spacing,
            0.0, 0.0,       # detector_offset_u, detector_offset_v
            0.0, 0.0, 0.0,  # center_offset_x, y, z
            self.backend,
        )
        return current_sino, updated_reco

class Pipeline:
    def __init__(self, lr, volume_shape, angles,
                 det_u, det_v, du, dv,
                 sdd, sid, voxel_spacing,
                 device, epoches=1000, backend="siddon"):
        self.epoches = epoches
        self.model = IterativeRecoModel(volume_shape, angles,
                                        det_u, det_v, du, dv,
                                        sdd, sid, voxel_spacing,
                                        backend=backend).to(device)

        self.optimizer = optim.AdamW(list(self.model.parameters()), lr=lr)
        self.loss = nn.MSELoss()

    def train(self, input, label):
        loss_values = []
        for epoch in range(self.epoches):
            self.optimizer.zero_grad()
            predictions, current_reco = self.model(input)
            loss_value = self.loss(predictions, label)
            loss_value.backward()
            self.optimizer.step()
            with torch.no_grad():
                self.model.reco.clamp_(min=0.0)
            loss_values.append(loss_value.item())

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss_value.item()}")

        return loss_values, self.model

def main():
    Nx, Ny, Nz = 64, 64, 64
    phantom_cpu = shepp_logan_3d((Nz, Ny, Nx))

    num_views = 180
    angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)

    det_u, det_v = 128, 128
    du, dv = 1.0, 1.0
    voxel_spacing = 1.0
    sdd = 600.0
    sid = 400.0

    # Forward projector backend used for BOTH the ground-truth sinogram
    # and the inner iterative loop. Using the same backend on both sides
    # guarantees the loop is solving its own consistent inverse problem
    # and that the adjoint returned by autograd matches the forward
    # byte-for-byte (matched scatter/gather kernel pair, verified by
    # tests/test_adjoint_inner_product.py). Options:
    #
    #   "siddon" - 3D ray-driven cell-constant Siddon. Fastest
    #              per-iteration step. Good default when
    #              iteration count is the bottleneck.
    #   "sf_tr"  - 3D SF with trapezoidal transaxial and
    #              rectangular axial footprint. Mass-conserving
    #              per voxel, closed-form cell integral.
    #              ~2x slower forward than siddon.
    #   "sf_tt"  - 3D SF with trapezoidal footprint in BOTH
    #              directions; the axial trapezoid captures
    #              the variation of axial magnification across
    #              the voxel by using ``U_near`` and ``U_far``
    #              corner projections. Strictly more expressive
    #              than SF-TR at ~1.4x the SF-TR cost. Useful
    #              for large cone angles and for research into
    #              the full Long et al. separable-footprint
    #              model.
    projector_backend = "sf_tr"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32).contiguous()

    # Generate the "real" sinogram with the same backend as the inner
    # loop, so reconstruction targets what the loop can actually produce.
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
    real_sinogram = ConeProjectorFunction.apply(
        phantom_torch, angles_torch,
        det_u, det_v, du, dv,
        sdd, sid, voxel_spacing,
        0.0, 0.0,       # detector_offset_u, detector_offset_v
        0.0, 0.0, 0.0,  # center_offset_x, y, z
        projector_backend,
    )

    pipeline_instance = Pipeline(lr=1e-1,
                                 volume_shape=(Nz, Ny, Nx),
                                 angles=angles_torch,
                                 det_u=det_u, det_v=det_v,
                                 du=du, dv=dv, voxel_spacing=voxel_spacing,
                                 sdd=sdd,
                                 sid=sid,
                                 device=device, epoches=1000,
                                 backend=projector_backend)

    ini_guess = torch.zeros_like(phantom_torch)

    loss_values, trained_model = pipeline_instance.train(ini_guess, real_sinogram)

    reco = trained_model(ini_guess)[1].squeeze().cpu().detach().numpy()

    plt.figure()
    plt.plot(loss_values)
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    mid_slice = Nz // 2
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(phantom_cpu[mid_slice, :, :], cmap="gray")
    plt.title("Original Phantom Mid-Slice")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reco[mid_slice, :, :], cmap="gray")
    plt.title("Reconstructed Mid-Slice")
    plt.axis("off")
    plt.show()

if __name__ == "__main__":
    main()
```