Fan Beam Iterative Reconstruction
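This example demonstrates 2D fan beam iterative reconstruction with diffct's differentiable fan beam operators. The script calls FanProjectorFunction directly. The matching backprojection, which is the operation FanBackprojectorFunction provides, enters through autograd when gradients are computed.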
Overview
Fan beam iterative reconstruction extends the optimization approach from parallel beam geometry to the more realistic fan beam geometry, in which all rays diverge from a single point source. This example shows how to:
Formulate fan beam CT reconstruction as an optimization problem
Handle geometric complexities of divergent ray geometry
Apply gradient-based optimization with fan beam operators
Monitor convergence and reconstruction quality
Mathematical Background
Fan Beam Iterative Formulation
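The fan beam reconstruction problem is formulated as a least-squares fit of the image to the measured sinogram:
\[ \hat{f} = \arg\min_{f} \frac{1}{2} \left\lVert A_{\text{fan}} f - p \right\rVert_2^2 \]
where \(f\) is the image, \(p\) is the measured sinogram, and \(A_{\text{fan}}\) is the fan beam forward projection operator accounting for divergent ray geometry.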
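For a tiny problem the least-squares minimizer can be computed directly, which makes the formulation concrete. This is an illustrative sketch only: the random matrix `A` is a hypothetical dense stand-in for \(A_{\text{fan}}\) (diffct never forms the system matrix explicitly), and `np.linalg.lstsq` replaces the iterative solver used later in this example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for A_fan: 40 "rays" (sinogram samples) x 16 "pixels".
A = rng.random((40, 16))
f_true = rng.random(16)   # unknown image, flattened
p = A @ f_true            # noise-free "measured" sinogram

# argmin_f 0.5 * ||A f - p||^2, solved in closed form for this small system.
f_hat, *_ = np.linalg.lstsq(A, p, rcond=None)

residual = np.linalg.norm(A @ f_hat - p)
print(f"residual = {residual:.2e}, max error = {np.abs(f_hat - f_true).max():.2e}")
```

In the real problem \(A_{\text{fan}}\) is far too large to store, so the script approaches the minimizer iteratively using only forward projections and their adjoint.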
Fan Beam Forward Model
The fan beam projection operator maps the 2D image \(f(x,y)\) to the sinogram \(p(\beta, u)\):
\[ p(\beta, u) = \int_{L(\beta, u)} f(x, y) \, \mathrm{d}\ell \]
where the integration follows the ray \(L(\beta, u)\) from the point source to detector element \(u\) at source angle \(\beta\).
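The sketch below approximates a single sample \(p(\beta, u)\). It steps along the ray from the source to detector coordinate \(u\) and sums nearest-neighbour pixel values. It assumes a flat detector, a source on a circle of radius SID, and an image centred on the rotation axis. These conventions, the function name `fan_ray_integral`, and the sampling scheme are illustrative assumptions, not diffct's Siddon or separable-footprint kernels.

```python
import numpy as np


def fan_ray_integral(img, beta, u, sid, sdd, voxel_spacing=1.0, step=0.25):
    """Approximate p(beta, u) by sampling img along one fan beam ray.

    Assumed geometry (a sketch, not necessarily diffct's convention):
    source at sid * (cos b, sin b); flat detector through the point
    (sid - sdd) * (cos b, sin b), oriented along (-sin b, cos b).
    """
    ny, nx = img.shape
    c, s = np.cos(beta), np.sin(beta)
    src = sid * np.array([c, s])
    det = (sid - sdd) * np.array([c, s]) + u * np.array([-s, c])
    length = np.linalg.norm(det - src)
    direction = (det - src) / length
    # Midpoints of equal sub-intervals along the source-to-detector segment.
    t = np.arange(0.0, length, step) + 0.5 * step
    pts = src[None, :] + t[:, None] * direction[None, :]
    # World coordinates -> nearest pixel indices (image centred on the rotation axis).
    ix = np.rint(pts[:, 0] / voxel_spacing + (nx - 1) / 2).astype(int)
    iy = np.rint(pts[:, 1] / voxel_spacing + (ny - 1) / 2).astype(int)
    inside = (ix >= 0) & (ix < nx) & (iy >= 0) & (iy < ny)
    return float(img[iy[inside], ix[inside]].sum() * step)


# Detector index k -> detector coordinate u for a centred detector (assumed).
num_detectors, detector_spacing = 256, 1.0
k = 128
u_k = (k - (num_detectors - 1) / 2) * detector_spacing

# A uniform disk of radius 30 px; the central ray (u = 0) crosses its full diameter.
yy, xx = np.mgrid[-64:64, -64:64] + 0.5
disk = ((xx ** 2 + yy ** 2) <= 30.0 ** 2).astype(float)
central = fan_ray_integral(disk, 0.3, 0.0, sid=400.0, sdd=600.0)
print(central, fan_ray_integral(disk, 0.3, u_k, sid=400.0, sdd=600.0))
```

Because every ray starts at the same source point, the spacing between neighbouring rays depends on where they cross the object. A parallel beam projector does not have this dependence.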
Gradient Computation
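The gradient of the least-squares objective involves the fan beam backprojection operator (the adjoint):
\[ \nabla_f \, \frac{1}{2} \left\lVert A_{\text{fan}} f - p \right\rVert_2^2 = A_{\text{fan}}^T \left( A_{\text{fan}} f - p \right) \]
where \(A_{\text{fan}}^T\) is the fan beam backprojection operator. The code uses nn.MSELoss, which averages over all \(N\) sinogram samples, so the gradient it computes is this expression scaled by \(2/N\).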
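A finite-difference check on a toy problem confirms the gradient formula. It also confirms that a mean-squared-error loss, which averages over the \(N\) sinogram samples, has exactly \(2/N\) times that gradient. The random matrix `A` is a hypothetical dense stand-in for \(A_{\text{fan}}\), and `A.T` plays the role of the backprojector.

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.random((30, 12))   # hypothetical dense stand-in for A_fan
p = rng.random(30)         # "measured" sinogram
f = rng.random(12)         # current image estimate
N = p.size


def half_sq(f):
    """0.5 * ||A f - p||^2"""
    r = A @ f - p
    return 0.5 * r @ r


def mse(f):
    """mean((A f - p)^2), the quantity nn.MSELoss computes"""
    return np.mean((A @ f - p) ** 2)


def finite_diff_grad(fun, f, eps=1e-6):
    """Central finite differences, one pixel at a time."""
    return np.array([(fun(f + eps * e) - fun(f - eps * e)) / (2 * eps)
                     for e in np.eye(f.size)])


grad = A.T @ (A @ f - p)        # backprojection of the residual
grad_mse = (2.0 / N) * grad     # same direction, rescaled for the mean

print(np.max(np.abs(grad - finite_diff_grad(half_sq, f))))
print(np.max(np.abs(grad_mse - finite_diff_grad(mse, f))))
```

The constant factor changes only the effective step size, not the minimizer. This is one reason the learning rate has to be tuned to the loss normalization.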
Geometric Considerations
Fan beam geometry introduces complexities compared to parallel beam:
Ray Divergence: Non-parallel rays affect sampling density and conditioning
Magnification Effects: A point at distance \(d\) from the source is magnified by \(\text{SDD}/d\) on the detector, so magnification varies across the field of view
Non-uniform Resolution: Spatial resolution varies with distance from rotation center
Geometric Distortion: Requires careful handling of coordinate transformations
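These effects can be quantified for the geometry used in the script below. The sketch takes SID, SDD, detector count, spacing, and image size from the script. It assumes the image is centred on the rotation axis and that all lengths are in pixel units.

```python
import math

# Values taken from the example script.
sid, sdd = 400.0, 600.0            # source-to-isocenter / source-to-detector distance (px)
num_detectors, detector_spacing = 256, 1.0
nx = 128                           # image width in pixels, voxel_spacing = 1.0

# A point at distance d from the source is magnified by sdd / d.
mag_iso = sdd / sid                # at the rotation centre
mag_near = sdd / (sid - nx / 2)    # 64 px closer to the source
mag_far = sdd / (sid + nx / 2)     # 64 px farther from the source

# Half fan angle covered by the detector and radius of the fully sampled FOV.
half_width = num_detectors * detector_spacing / 2
half_fan = math.atan(half_width / sdd)
fov_radius = sid * math.sin(half_fan)
half_diagonal = nx / 2 * math.sqrt(2)

print(f"magnification: {mag_far:.3f} .. {mag_iso:.3f} .. {mag_near:.3f}")
print(f"half fan angle: {math.degrees(half_fan):.2f} deg, FOV radius: {fov_radius:.1f} px")
print(f"image half-diagonal: {half_diagonal:.1f} px")
```

With these values, magnification ranges from about 1.29 to 1.79 across the image and is 1.5 at the rotation centre. The fully sampled field of view has a radius of about 83.5 px, which is smaller than the image half-diagonal of about 90.5 px. The image corners are therefore not seen from every angle. The phantom's outer ellipse (semi-axis 0.92 × 64 ≈ 58.9 px) lies well inside the field of view.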
Implementation Steps
Geometry Setup: Configure fan beam parameters, the source-to-isocenter distance (SID) and source-to-detector distance (SDD)
Problem Formulation: Define parameterized image and fan beam forward model
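Loss Computation: Compute the mean squared error between the sinogram predicted by FanProjectorFunction and the measured sinogram
Gradient Computation: Use automatic differentiation through fan beam operators
Optimization: Apply the AdamW optimizer (Adam with decoupled weight decay) with an appropriate learning rate
Convergence Monitoring: Record the loss at every iteration and compare the final image with the ground truth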
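The steps above can be traced on a toy problem without diffct. In this sketch, a small random matrix `A` is a hypothetical stand-in for the fan beam projector, and its transpose replaces the autograd backward pass. Adam is written out by hand so each update is visible. AdamW's decoupled weight decay, used by the script, is omitted here. The name `adam_reconstruct` is hypothetical.

```python
import numpy as np


def adam_reconstruct(A, p, n_iter=500, lr=0.05, betas=(0.9, 0.999), eps=1e-8):
    """Minimise mean((A f - p)**2) with hand-written Adam and an f >= 0 clamp."""
    f = np.zeros(A.shape[1])        # zero initial guess, like ini_guess in the script
    m = np.zeros_like(f)            # first-moment (mean) estimate
    v = np.zeros_like(f)            # second-moment (uncentred variance) estimate
    losses = []
    for k in range(1, n_iter + 1):
        r = A @ f - p                           # predicted minus measured sinogram
        losses.append(np.mean(r ** 2))          # MSE, as nn.MSELoss computes it
        g = (2.0 / p.size) * (A.T @ r)          # gradient via the adjoint (backprojection)
        m = betas[0] * m + (1 - betas[0]) * g
        v = betas[1] * v + (1 - betas[1]) * g ** 2
        m_hat = m / (1 - betas[0] ** k)         # bias correction
        v_hat = v / (1 - betas[1] ** k)
        f = f - lr * m_hat / (np.sqrt(v_hat) + eps)
        f = np.maximum(f, 0.0)                  # non-negativity, like clamp_(min=0.0)
    return f, losses


rng = np.random.default_rng(2)
A = rng.random((60, 20))    # hypothetical dense stand-in for A_fan
f_true = rng.random(20)
p = A @ f_true
f_hat, losses = adam_reconstruct(A, p)
print(f"loss: {losses[0]:.3e} -> {losses[-1]:.3e}")
```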
Model Architecture
The fan beam reconstruction model consists of:
Parameterized Image: Learnable 2D tensor representing the unknown image
Fan Beam Forward Model: FanProjectorFunction with geometric parameters
Loss Function: Mean squared error between predicted and measured sinograms
Convergence Characteristics
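Fan beam reconstruction typically exhibits the following phases. The iteration counts are approximate and depend on image size, number of views, and learning rate: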
Initial Convergence (0-100 iterations): Rapid loss decrease, basic structure
Detail Refinement (100-500 iterations): Fine features develop, slower progress
Final Convergence (500+ iterations): Minimal improvement, convergence plateau
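The plateau phase is easier to detect with an explicit stopping rule. The hypothetical helper below is one option; it is not part of diffct or the script. It compares the current loss with the loss `window` iterations earlier and reports a plateau when the relative improvement falls below `rel_tol`. Both thresholds are illustrative defaults, not tuned values.

```python
import math


def has_plateaued(losses, window=50, rel_tol=1e-3):
    """True when the loss improved by less than rel_tol (relative) over the last `window` steps."""
    if len(losses) <= window:
        return False
    old, new = losses[-window - 1], losses[-1]
    return (old - new) <= rel_tol * abs(old)


# An exponentially decaying loss keeps improving in relative terms ...
decaying = [math.exp(-k / 20) for k in range(200)]
# ... while a flat tail counts as a plateau.
flat = [1.0] * 100
print(has_plateaued(decaying), has_plateaued(flat))
```

Inside Pipeline.train, a call such as `if has_plateaued(loss_values): break` after appending the loss would end training early instead of always running the full epoch count.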
Challenges and Solutions
Conditioning: The fan beam system matrix may be conditioned differently from the parallel beam one, which affects how quickly gradient-based methods converge
Geometric Artifacts: Proper weighting and filtering help reduce artifacts
Parameter Tuning: Learning rate may need adjustment for optimal convergence
Memory Usage: Comparable to parallel beam, while the divergent-ray geometry adds per-ray computation
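Conditioning and learning rate are linked through the eigenvalues of \(A_{\text{fan}}^T A_{\text{fan}}\). For plain gradient descent on \(\tfrac{1}{2}\lVert A_{\text{fan}} f - p\rVert_2^2\), step sizes below \(2/\lambda_{\max}\) are stable. Power iteration estimates \(\lambda_{\max}\) using only forward and adjoint products, so in principle it can be driven by a projector and its backprojector. The sketch below uses a small random matrix `A` as a hypothetical stand-in, and the helper name is also hypothetical. The bound applies to plain gradient descent, not to Adam or AdamW, whose per-pixel step scaling behaves differently.

```python
import numpy as np


def largest_eigenvalue_AtA(apply_A, apply_At, n, n_iter=100, seed=0):
    """Power iteration for the largest eigenvalue of A^T A using only A and A^T products."""
    rng = np.random.default_rng(seed)
    x = rng.random(n)
    x /= np.linalg.norm(x)
    lam = 0.0
    for _ in range(n_iter):
        y = apply_At(apply_A(x))    # one forward projection + one backprojection
        lam = np.linalg.norm(y)     # ||A^T A x|| with ||x|| = 1
        x = y / lam
    return lam


rng = np.random.default_rng(3)
A = rng.random((50, 15))            # hypothetical dense stand-in for A_fan
lam = largest_eigenvalue_AtA(lambda x: A @ x, lambda y: A.T @ y, A.shape[1])
reference = np.linalg.eigvalsh(A.T @ A).max()

# Plain gradient descent on 0.5 * ||A f - p||^2 is stable for step sizes below this value.
max_step = 2.0 / lam
print(lam, reference, max_step)
```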
```python
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from diffct.differentiable import FanProjectorFunction


def shepp_logan_2d(Nx, Ny):
    """Rasterize a simplified Shepp-Logan-style phantom, clipped to [0, 1]."""
    Nx = int(Nx)
    Ny = int(Ny)
    phantom = np.zeros((Ny, Nx), dtype=np.float32)
    # (x0, y0, semi-axis a, semi-axis b, rotation in degrees, amplitude),
    # in coordinates normalized to [-1, 1].
    ellipses = [
        (0.0, 0.0, 0.69, 0.92, 0, 1.0),
        (0.0, -0.0184, 0.6624, 0.8740, 0, -0.8),
        (0.22, 0.0, 0.11, 0.31, -18.0, -0.8),
        (-0.22, 0.0, 0.16, 0.41, 18.0, -0.8),
        (0.0, 0.35, 0.21, 0.25, 0, 0.7),
    ]
    cx = (Nx - 1)*0.5
    cy = (Ny - 1)*0.5
    for ix in range(Nx):
        for iy in range(Ny):
            xnorm = (ix - cx)/(Nx/2)
            ynorm = (iy - cy)/(Ny/2)
            val = 0.0
            for (x0, y0, a, b, angdeg, ampl) in ellipses:
                th = np.deg2rad(angdeg)
                xprime = (xnorm - x0)*np.cos(th) + (ynorm - y0)*np.sin(th)
                yprime = -(xnorm - x0)*np.sin(th) + (ynorm - y0)*np.cos(th)
                if xprime*xprime/(a*a) + yprime*yprime/(b*b) <= 1.0:
                    val += ampl
            phantom[iy, ix] = val
    phantom = np.clip(phantom, 0.0, 1.0)
    return phantom


class IterativeRecoModel(nn.Module):
    def __init__(self, volume_shape, angles,
                 num_detectors, detector_spacing,
                 sdd, sid, voxel_spacing,
                 backend="siddon"):
        super().__init__()
        # The unknown image, learned as a free parameter initialized to zero.
        self.reco = nn.Parameter(torch.zeros(volume_shape))
        self.angles = angles
        self.num_detectors = num_detectors
        self.detector_spacing = detector_spacing
        self.sdd = sdd
        self.sid = sid
        self.voxel_spacing = voxel_spacing
        self.backend = backend

    def forward(self, x):
        updated_reco = x + self.reco
        # ``backend`` is the last positional argument to
        # ``FanProjectorFunction.apply`` so we must also pass the three
        # default offsets (detector_offset / center_offset_x / center_offset_y)
        # to line up with the signature.
        current_sino = FanProjectorFunction.apply(
            updated_reco,
            self.angles,
            self.num_detectors,
            self.detector_spacing,
            self.sdd,
            self.sid,
            self.voxel_spacing,
            0.0,  # detector_offset
            0.0,  # center_offset_x
            0.0,  # center_offset_y
            self.backend,
        )
        return current_sino, updated_reco


class Pipeline:
    def __init__(self, lr, volume_shape, angles,
                 num_detectors, detector_spacing,
                 sdd, sid, voxel_spacing,
                 device, epoches=1000, backend="siddon"):
        self.epoches = epoches
        self.model = IterativeRecoModel(volume_shape, angles,
                                        num_detectors, detector_spacing,
                                        sdd, sid, voxel_spacing,
                                        backend=backend).to(device)

        self.optimizer = optim.AdamW(list(self.model.parameters()), lr=lr)
        self.loss = nn.MSELoss()

    def train(self, input, label):
        loss_values = []
        for epoch in range(self.epoches):
            self.optimizer.zero_grad()
            predictions, current_reco = self.model(input)
            loss_value = self.loss(predictions, label)
            # Autograd applies the fan beam adjoint (backprojection) here.
            loss_value.backward()
            self.optimizer.step()
            # Keep the image non-negative: attenuation cannot be negative.
            with torch.no_grad():
                self.model.reco.clamp_(min=0.0)
            loss_values.append(loss_value.item())

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss_value.item()}")

        return loss_values, self.model


def main():
    Nx, Ny = 128, 128
    phantom_cpu = shepp_logan_2d(Nx, Ny)

    num_views = 360
    angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)

    num_detectors = 256
    detector_spacing = 1.0
    voxel_spacing = 1.0
    sdd = 600.0  # source-to-detector distance
    sid = 400.0  # source-to-isocenter distance

    # Forward projector backend used for BOTH the ground-truth sinogram
    # and the inner iterative loop. Using the same backend on both sides
    # is the cleanest setup: the loop then solves its own exact inverse
    # problem, and the adjoint returned by autograd matches the forward
    # byte-for-byte (guaranteed by the matched scatter/gather kernel
    # pair, verified by tests/test_adjoint_inner_product.py). Options:
    #
    #   "siddon" - ray-driven cell-constant Siddon. Fastest.
    #              Good default when you want the shortest
    #              iteration step.
    #   "sf"     - voxel-driven separable-footprint projector
    #              (Long et al. SF-TR). Mass-conserving per
    #              voxel, closed-form cell integral, ~3x
    #              slower than siddon. Worth trying when you
    #              want a physically-principled cell-
    #              integrated forward model and don't mind
    #              the per-iteration cost.
    projector_backend = "siddon"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32)
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)

    # Generate the "real" sinogram with the same backend as the inner
    # loop, so reconstruction targets what the loop can actually produce.
    real_sinogram = FanProjectorFunction.apply(
        phantom_torch, angles_torch,
        num_detectors, detector_spacing,
        sdd, sid, voxel_spacing,
        0.0, 0.0, 0.0,  # detector_offset, center_offset_x, center_offset_y
        projector_backend,
    )

    pipeline_instance = Pipeline(lr=1e-1,
                                 volume_shape=(Ny, Nx),
                                 angles=angles_torch,
                                 num_detectors=num_detectors,
                                 detector_spacing=detector_spacing,
                                 sdd=sdd, voxel_spacing=voxel_spacing,
                                 sid=sid,
                                 device=device, epoches=1000,
                                 backend=projector_backend)

    ini_guess = torch.zeros_like(phantom_torch)

    loss_values, trained_model = pipeline_instance.train(ini_guess, real_sinogram)

    # Final image = initial guess + learned correction (no gradients needed).
    with torch.no_grad():
        reco = trained_model(ini_guess)[1].squeeze().cpu().detach().numpy()

    plt.figure()
    plt.plot(loss_values)
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(phantom_cpu, cmap="gray")
    plt.title("Original Phantom")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reco, cmap="gray")
    plt.title("Reconstructed")
    plt.axis("off")
    plt.show()


if __name__ == "__main__":
    main()
```
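The script tracks only the sinogram loss. Image-domain metrics against the phantom measure reconstruction quality directly. The helpers below, `rmse` and `psnr`, are hypothetical additions rather than part of diffct. The default `data_range=1.0` assumes a phantom clipped to [0, 1], as `shepp_logan_2d` produces. scikit-image's `skimage.metrics.peak_signal_noise_ratio` is an existing alternative.

```python
import numpy as np


def rmse(reference, image):
    """Root-mean-square error between two images of the same shape."""
    diff = np.asarray(reference, dtype=float) - np.asarray(image, dtype=float)
    return float(np.sqrt(np.mean(diff ** 2)))


def psnr(reference, image, data_range=1.0):
    """Peak signal-to-noise ratio in dB; data_range=1.0 matches a [0, 1] phantom."""
    err = rmse(reference, image)
    if err == 0.0:
        return float("inf")
    return float(20.0 * np.log10(data_range / err))


# Usage inside main(), after `reco` has been computed:
#     print(f"RMSE {rmse(phantom_cpu, reco):.4f}, PSNR {psnr(phantom_cpu, reco):.2f} dB")
ref = np.zeros((4, 4))
noisy = np.full((4, 4), 0.1)
print(rmse(ref, noisy), psnr(ref, noisy))
```

Evaluating these metrics every few epochs, alongside the loss, would show whether a falling sinogram loss is matched by a closer image.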