Cone Beam Iterative Reconstruction
This example demonstrates gradient-based iterative reconstruction for 3D cone beam CT using the differentiable ConeProjectorFunction from diffct.
Overview
3D cone beam iterative reconstruction extends optimization methods to full volumetric reconstruction. This example shows how to:
Formulate 3D cone beam CT reconstruction as a large-scale optimization problem
Handle the computational complexity of 3D forward and backward projections
Apply memory-efficient optimization strategies for volumetric data
Monitor convergence in high-dimensional parameter space
Mathematical Background
3D Cone Beam Iterative Formulation
The 3D reconstruction problem is formulated as:
\[
\hat{f} = \arg\min_{f} \; \frac{1}{2} \left\| A_{\text{cone}} f - p \right\|_2^2 + \lambda R(f)
\]

where:

- \(f(x,y,z)\) is the unknown 3D volume
- \(A_{\text{cone}}\) is the cone beam forward projection operator
- \(p(\phi, u, v)\) is the measured 2D projection data
- \(R(f)\) is an optional 3D regularization term with weight \(\lambda\)
3D Forward Projection
The cone beam forward projection integrates along rays through the 3D volume:
\[
p(\phi, u, v) = \int_0^{\infty} f\!\left(\vec{r}_s(\phi) + t\,\vec{d}(\phi, u, v)\right) dt
\]

where \(\vec{r}_s(\phi)\) is the source position at rotation angle \(\phi\) and \(\vec{d}(\phi, u, v)\) is the unit direction of the ray through detector pixel \((u, v)\).
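The line integral can be approximated numerically by a Riemann sum along the ray. The sketch below is a plain NumPy illustration of that idea (not diffct's actual Siddon or separable-footprint kernels): it integrates an analytic unit-density sphere along a central ray, where the exact answer is the chord length.

```python
import numpy as np

def ray_integral(source, direction, f, t_max, dt=0.01):
    """Riemann-sum approximation of the line integral
    int_0^{t_max} f(r_s + t * d) dt along a single ray."""
    ts = np.arange(0.0, t_max, dt)
    points = source[None, :] + ts[:, None] * direction[None, :]
    return f(points).sum() * dt

# Analytic phantom: unit-density sphere of radius 0.5 at the origin
# (evaluated directly, no voxel grid needed for this illustration).
sphere = lambda pts: (np.linalg.norm(pts, axis=1) <= 0.5).astype(np.float64)

source = np.array([-2.0, 0.0, 0.0])
direction = np.array([1.0, 0.0, 0.0])  # unit vector toward the detector

val = ray_integral(source, direction, sphere, t_max=4.0)
# Central chord of a radius-0.5 sphere has length 1.0, so val ~ 1.0.
```

Real projectors replace the brute-force sum with exact ray-voxel intersection lengths (Siddon) or closed-form footprint integrals (SF), but the quantity being computed is the same.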
3D Gradient Computation
The gradient of the 3D loss function uses the cone beam backprojection operator:
\[
\nabla_f L = A_{\text{cone}}^T \left( A_{\text{cone}} f - p \right)
\]

where \(A_{\text{cone}}^T\) is the 3D cone beam backprojection operator (the adjoint of the forward projection); automatic differentiation through the projector applies this adjoint during the backward pass.
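The adjoint-based gradient formula can be checked against finite differences for any linear operator. In this small NumPy sketch a random dense matrix stands in for \(A_{\text{cone}}\) (purely for illustration; the real operator is matrix-free):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((20, 8))   # stand-in for the cone beam operator
p = rng.standard_normal(20)        # "measured" projection data
f = rng.standard_normal(8)         # current volume estimate

loss = lambda v: 0.5 * np.sum((A @ v - p) ** 2)

# Adjoint-based gradient: A^T (A f - p)
grad = A.T @ (A @ f - p)

# Central finite-difference check of the first component
eps = 1e-6
e0 = np.zeros(8)
e0[0] = 1.0
fd = (loss(f + eps * e0) - loss(f - eps * e0)) / (2 * eps)
```

The two values agree to numerical precision, which is exactly the property autograd relies on when it backpropagates through the projector.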
Computational Complexity
3D reconstruction presents significant computational challenges:
Memory Requirements: \(O(N^3)\) for volume storage vs \(O(N^2)\) for 2D images
Projection Data: \(O(N_{\phi} \times N_u \times N_v)\) measurements spread across \(N_{\phi}\) two-dimensional projections
Forward/Backward Operations: \(O(N^3 \times N_{\phi})\) computational complexity
Gradient Storage: Additional memory for automatic differentiation
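A back-of-the-envelope estimate makes these scalings concrete. The numbers below (a \(512^3\) float32 volume, 720 views on a \(512 \times 512\) detector) are illustrative choices, not requirements of the example:

```python
N = 512
n_views, det_u, det_v = 720, 512, 512
bytes_f32 = 4  # float32

volume_gb = N**3 * bytes_f32 / 1e9
sino_gb = n_views * det_u * det_v * bytes_f32 / 1e9

# Autograd stores one gradient tensor per parameter, and Adam-style
# optimizers keep two extra state tensors (exp_avg, exp_avg_sq), so
# the working set for the volume alone is roughly 4x its raw size.
working_set_gb = 4 * volume_gb

print(f"volume: {volume_gb:.2f} GB, sinogram: {sino_gb:.2f} GB, "
      f"volume working set: ~{working_set_gb:.2f} GB")
```

At this scale the volume's optimizer working set alone exceeds 2 GB, before counting projector intermediates, which is why the memory strategies discussed below matter.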
Implementation Steps
3D Problem Setup: Define parameterized 3D volume as learnable tensor
Cone Beam Forward Model: Use ConeProjectorFunction for 2D projection prediction
Loss Computation: Calculate L2 distance between predicted and measured projections
3D Gradient Computation: Use automatic differentiation through cone beam operators
Memory-Efficient Optimization: Apply strategies to handle large 3D parameter space
Convergence Monitoring: Track loss and 3D reconstruction quality
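The steps above can be sketched end-to-end with a toy linear forward model in place of ConeProjectorFunction (a dense matrix here, purely for illustration; the full cone beam version appears in the code example below):

```python
import torch

torch.manual_seed(0)
A = torch.randn(40, 16)       # stand-in for the cone beam projector
f_true = torch.rand(16)
p_meas = A @ f_true           # simulated "measured" projections

f = torch.zeros(16, requires_grad=True)      # step 1: learnable volume
opt = torch.optim.Adam([f], lr=0.1)

losses = []
for _ in range(200):
    opt.zero_grad()
    p_pred = A @ f                            # step 2: forward model
    loss = torch.mean((p_pred - p_meas) ** 2) # step 3: L2 loss
    loss.backward()                           # step 4: autograd gradient
    opt.step()                                # step 5: parameter update
    losses.append(loss.item())                # step 6: convergence monitoring
```

Swapping the matrix product for a differentiable projector call is the only structural change needed for the real reconstruction.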
Model Architecture
The 3D reconstruction model consists of:
Parameterized Volume: Learnable 3D tensor representing the unknown volume
Cone Beam Forward Model: ConeProjectorFunction with 3D geometry parameters
Loss Function: Mean squared error between predicted and measured 2D projections
3D Regularization Options
Common 3D regularization terms:
3D Total Variation: \(R_{\text{TV}}(f) = \sum_{x,y,z} \|\nabla f(x,y,z)\|_2\)
3D Smoothness: \(R_{\text{smooth}}(f) = \sum_{x,y,z} \|\nabla f(x,y,z)\|_2^2\)
L1 Sparsity: \(R_{\text{L1}}(f) = \sum_{x,y,z} |f(x,y,z)|\)
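These penalties are straightforward to write with finite differences. The NumPy sketch below is one minimal realization (forward differences, no smoothing epsilon in the TV norm); it is an illustration, not diffct API:

```python
import numpy as np

def grad3d(f):
    """Forward-difference spatial gradient; zero at the far boundary."""
    gz = np.diff(f, axis=0, append=f[-1:, :, :])
    gy = np.diff(f, axis=1, append=f[:, -1:, :])
    gx = np.diff(f, axis=2, append=f[:, :, -1:])
    return gx, gy, gz

def tv3d(f):
    """Isotropic 3D total variation: sum of gradient magnitudes."""
    gx, gy, gz = grad3d(f)
    return np.sqrt(gx**2 + gy**2 + gz**2).sum()

def smooth3d(f):
    """Squared-gradient (Tikhonov-style) smoothness penalty."""
    gx, gy, gz = grad3d(f)
    return (gx**2 + gy**2 + gz**2).sum()

def l1_sparsity(f):
    return np.abs(f).sum()

# A single axial step edge: TV counts it once, L1 counts the mass.
f = np.zeros((4, 4, 4))
f[2:, :, :] = 1.0
```

In a torch training loop the same expressions (written with `torch.diff` and added to the data term as `loss + lam * tv`) are differentiable and need no hand-derived gradient.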
Memory Management Strategies
3D reconstruction requires careful memory management:
Gradient Checkpointing: Trade computation for memory in backpropagation
Mixed Precision: Use float16 when possible to reduce memory usage
Batch Processing: Process volume slices when memory is extremely limited
Efficient Data Layout: Optimize tensor storage and access patterns
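Gradient checkpointing is the most broadly applicable of these strategies. The sketch below wraps a sub-computation with `torch.utils.checkpoint` so its intermediates are recomputed during the backward pass instead of being stored; a toy function stands in for a memory-hungry projection step:

```python
import torch
from torch.utils.checkpoint import checkpoint

def expensive_block(v):
    # Stand-in for a memory-hungry step: these intermediates would
    # normally be kept alive for the backward pass.
    return torch.sin(v).pow(2).mean()

f = torch.rand(8, 8, 8, requires_grad=True)

# Checkpointed: intermediates are dropped, then recomputed in backward.
loss_ckpt = checkpoint(expensive_block, f, use_reentrant=False)
loss_ckpt.backward()
g_ckpt = f.grad.clone()

# Reference: ordinary autograd through the same function.
f.grad = None
expensive_block(f).backward()
g_ref = f.grad.clone()

# Same gradients, smaller peak memory, roughly one extra forward pass.
```

Mixed precision follows the same drop-in pattern via `torch.autocast`, trading numerical range for a roughly halved activation footprint.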
Convergence Characteristics
3D cone beam reconstruction typically exhibits:
Initial Convergence (0-100 iterations): Rapid loss decrease, basic 3D structure emerges
Detail Refinement (100-500 iterations): Fine 3D features develop progressively
Final Convergence (500+ iterations): Slow improvement, potential overfitting risk
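A simple plateau detector on the loss history can catch the transition into the slow final phase and stop before overfitting sets in. A minimal sketch in plain Python (the window size and threshold are illustrative choices, not tuned values):

```python
def has_plateaued(loss_values, window=50, rel_tol=1e-3):
    """True when the mean loss over the last ``window`` iterations
    improved by less than ``rel_tol`` relative to the window before."""
    if len(loss_values) < 2 * window:
        return False
    prev = sum(loss_values[-2 * window:-window]) / window
    last = sum(loss_values[-window:]) / window
    return (prev - last) < rel_tol * prev

# Synthetic history: fast hyperbolic decay that then flattens out.
history = [1.0 / (1 + 0.5 * i) for i in range(100)] + [0.02] * 100
```

Calling `has_plateaued(loss_values)` once per iteration inside the training loop gives a cheap early-stopping criterion alongside the printed loss.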
Challenges in 3D Reconstruction
Cone Beam Artifacts: Increased artifacts for large cone angles in 3D
Incomplete Sampling: Missing data in certain regions of 3D Fourier space
Computational Cost: Orders of magnitude higher than 2D reconstruction
Memory Limitations: Large volumes may exceed available GPU memory
Convergence Complexity: Higher-dimensional optimization landscape
Applications
3D cone beam iterative reconstruction is essential for:
Medical CBCT: Dental, orthopedic, and interventional imaging
Industrial CT: Non-destructive testing and quality control
Micro-CT: High-resolution imaging of small specimens and materials
Security Screening: Advanced baggage and cargo inspection systems
Code Example
```python
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from diffct.differentiable import ConeProjectorFunction


def shepp_logan_3d(shape):
    zz, yy, xx = np.mgrid[:shape[0], :shape[1], :shape[2]]
    xx = (xx - (shape[2] - 1) / 2) / ((shape[2] - 1) / 2)
    yy = (yy - (shape[1] - 1) / 2) / ((shape[1] - 1) / 2)
    zz = (zz - (shape[0] - 1) / 2) / ((shape[0] - 1) / 2)
    el_params = np.array([
        [0, 0, 0, 0.69, 0.92, 0.81, 0, 0, 0, 1],
        [0, -0.0184, 0, 0.6624, 0.874, 0.78, 0, 0, 0, -0.8],
        [0.22, 0, 0, 0.11, 0.31, 0.22, -np.pi / 10.0, 0, 0, -0.2],
        [-0.22, 0, 0, 0.16, 0.41, 0.28, np.pi / 10.0, 0, 0, -0.2],
        [0, 0.35, -0.15, 0.21, 0.25, 0.41, 0, 0, 0, 0.1],
        [0, 0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [0, -0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [-0.08, -0.605, 0, 0.046, 0.023, 0.05, 0, 0, 0, 0.1],
        [0, -0.605, 0, 0.023, 0.023, 0.02, 0, 0, 0, 0.1],
        [0.06, -0.605, 0, 0.023, 0.046, 0.02, 0, 0, 0, 0.1],
    ], dtype=np.float32)

    # Extract per-ellipsoid parameters for vectorized evaluation
    x_pos = el_params[:, 0][:, None, None, None]
    y_pos = el_params[:, 1][:, None, None, None]
    z_pos = el_params[:, 2][:, None, None, None]
    a_axis = el_params[:, 3][:, None, None, None]
    b_axis = el_params[:, 4][:, None, None, None]
    c_axis = el_params[:, 5][:, None, None, None]
    phi = el_params[:, 6][:, None, None, None]
    val = el_params[:, 9][:, None, None, None]

    # Broadcast the grid against the ellipsoid axes
    xc = xx[None, ...] - x_pos
    yc = yy[None, ...] - y_pos
    zc = zz[None, ...] - z_pos

    c = np.cos(phi)
    s = np.sin(phi)

    # Only rotation around z is used, so the rotation vectorizes:
    xp = c * xc - s * yc
    yp = s * xc + c * yc
    zp = zc

    mask = (
        (xp ** 2) / (a_axis ** 2)
        + (yp ** 2) / (b_axis ** 2)
        + (zp ** 2) / (c_axis ** 2)
        <= 1.0
    )

    # Sum all ellipsoid contributions via broadcasting
    shepp_logan = np.sum(mask * val, axis=0)
    shepp_logan = np.clip(shepp_logan, 0, 1)
    return shepp_logan


class IterativeRecoModel(nn.Module):
    def __init__(self, volume_shape, angles, det_u, det_v, du, dv, sdd, sid,
                 voxel_spacing, backend="siddon"):
        super().__init__()
        self.reco = nn.Parameter(torch.zeros(volume_shape))
        self.angles = angles
        self.det_u = det_u
        self.det_v = det_v
        self.du = du
        self.dv = dv
        self.sdd = sdd
        self.sid = sid
        self.relu = nn.ReLU()  # non-negativity constraint
        self.voxel_spacing = voxel_spacing
        self.backend = backend

    def forward(self, x):
        updated_reco = x + self.reco
        # ``backend`` is the last positional arg to
        # ``ConeProjectorFunction.apply``, so we also pass the five default
        # offsets (two detector offsets + three centre offsets) to line
        # up with the signature.
        current_sino = ConeProjectorFunction.apply(
            updated_reco,
            self.angles,
            self.det_u, self.det_v,
            self.du, self.dv,
            self.sdd, self.sid,
            self.voxel_spacing,
            0.0, 0.0,       # detector_offset_u, detector_offset_v
            0.0, 0.0, 0.0,  # center_offset_x, y, z
            self.backend,
        )
        return current_sino, self.relu(updated_reco)


class Pipeline:
    def __init__(self, lr, volume_shape, angles,
                 det_u, det_v, du, dv,
                 sdd, sid, voxel_spacing,
                 device, epochs=1000, backend="siddon"):
        self.epochs = epochs
        self.model = IterativeRecoModel(volume_shape, angles,
                                        det_u, det_v, du, dv,
                                        sdd, sid, voxel_spacing,
                                        backend=backend).to(device)

        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        self.loss = nn.MSELoss()

    def train(self, x, label):
        loss_values = []
        for epoch in range(self.epochs):
            self.optimizer.zero_grad()
            predictions, current_reco = self.model(x)
            loss_value = self.loss(predictions, label)
            loss_value.backward()
            self.optimizer.step()
            loss_values.append(loss_value.item())

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss_value.item():.6f}")

        return loss_values, self.model


def main():
    Nx, Ny, Nz = 64, 64, 64
    phantom_cpu = shepp_logan_3d((Nz, Ny, Nx))

    num_views = 180
    angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)

    det_u, det_v = 128, 128
    du, dv = 1.0, 1.0
    voxel_spacing = 1.0
    sdd = 600.0
    sid = 400.0

    # Forward projector backend used for BOTH the ground-truth sinogram
    # and the inner iterative loop. Using the same backend on both sides
    # guarantees the loop is solving its own consistent inverse problem
    # and that the adjoint returned by autograd matches the forward
    # byte-for-byte (matched scatter/gather kernel pair, verified by
    # tests/test_adjoint_inner_product.py). Options:
    #
    #   "siddon" - 3D ray-driven Siddon with trilinear
    #              interpolation. Fastest per-iteration step.
    #              Good default when iteration count is the
    #              bottleneck.
    #   "sf_tr"  - 3D SF with trapezoidal transaxial and
    #              rectangular axial footprint. Mass-
    #              conserving per voxel, closed-form cell
    #              integral. ~2x slower forward than siddon.
    #   "sf_tt"  - 3D SF with trapezoidal footprint in BOTH
    #              directions; the axial trapezoid captures
    #              the variation of axial magnification across
    #              the voxel by using ``U_near`` and ``U_far``
    #              corner projections. Strictly more expressive
    #              than SF-TR at ~1.4x the SF-TR cost. Useful
    #              for large cone angles and for research into
    #              the full Long et al. separable-footprint
    #              model.
    projector_backend = "sf_tr"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32).contiguous()

    # Generate the "real" sinogram with the same backend as the inner
    # loop, so reconstruction targets what the loop can actually produce.
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
    real_sinogram = ConeProjectorFunction.apply(
        phantom_torch, angles_torch,
        det_u, det_v, du, dv,
        sdd, sid, voxel_spacing,
        0.0, 0.0,       # detector_offset_u, detector_offset_v
        0.0, 0.0, 0.0,  # center_offset_x, y, z
        projector_backend,
    )

    pipeline_instance = Pipeline(lr=1e-1,
                                 volume_shape=(Nz, Ny, Nx),
                                 angles=angles_torch,
                                 det_u=det_u, det_v=det_v,
                                 du=du, dv=dv, voxel_spacing=voxel_spacing,
                                 sdd=sdd,
                                 sid=sid,
                                 device=device, epochs=1000,
                                 backend=projector_backend)

    ini_guess = torch.zeros_like(phantom_torch)

    loss_values, trained_model = pipeline_instance.train(ini_guess, real_sinogram)

    with torch.no_grad():
        reco = trained_model(ini_guess)[1].squeeze().cpu().numpy()

    plt.figure()
    plt.plot(loss_values)
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    mid_slice = Nz // 2
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(phantom_cpu[mid_slice, :, :], cmap="gray")
    plt.title("Original Phantom Mid-Slice")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reco[mid_slice, :, :], cmap="gray")
    plt.title("Reconstructed Mid-Slice")
    plt.axis("off")
    plt.show()


if __name__ == "__main__":
    main()
```