Parallel Beam Iterative Reconstruction
This example demonstrates 2D parallel beam iterative reconstruction using the differentiable ParallelProjectorFunction and ParallelBackprojectorFunction from diffct.
Overview
Parallel beam iterative reconstruction solves the CT inverse problem by numerical optimization instead of the closed-form inversion used by analytical methods such as filtered backprojection (FBP). This example shows how to:
Formulate parallel beam CT reconstruction as an optimization problem
Use automatic differentiation for gradient computation
Apply gradient-based optimization with parallel beam operators
Monitor convergence and reconstruction quality
Mathematical Background
Parallel Beam Iterative Formulation
The parallel beam reconstruction problem is formulated as:
\[ \min_{f} \; \frac{1}{2}\left\| A_{\text{parallel}} f - p \right\|_2^2 + \lambda R(f) \]
where:
- \(f\) is the unknown 2D image
- \(A_{\text{parallel}}\) is the parallel beam forward projection operator (Radon transform)
- \(p\) is the measured sinogram data
- \(R(f)\) is an optional regularization term
- \(\lambda\) is the regularization parameter
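As a minimal sketch of this objective, the code below uses a deliberately tiny stand-in for \(A_{\text{parallel}}\): a hypothetical projector `toy_parallel_project` with only the 0° and 90° views and one detector bin per pixel column or row, so each projection is simply a column or row sum. The Tikhonov choice \(R(f) = \tfrac{1}{2}\|f\|_2^2\) is also an assumption for illustration; neither is part of diffct.

```python
import numpy as np


def toy_parallel_project(f):
    """Hypothetical 2-view parallel-beam projector (0 and 90 degrees).

    With detector bins aligned to pixel columns/rows, each ray sums one
    column (0 degrees) or one row (90 degrees) of the image.
    """
    return np.concatenate([f.sum(axis=0), f.sum(axis=1)])


def objective(f, p, lam):
    """0.5 * ||A f - p||^2 + lam * R(f), with R(f) = 0.5 * ||f||^2 (assumed Tikhonov term)."""
    residual = toy_parallel_project(f) - p
    data_term = 0.5 * np.sum(residual ** 2)
    reg_term = 0.5 * np.sum(f ** 2)
    return data_term + lam * reg_term


f_true = np.zeros((4, 4))
f_true[1:3, 1:3] = 1.0            # a 2x2 square inside a 4x4 image
p = toy_parallel_project(f_true)  # simulated "measured" sinogram, 8 values

print(objective(f_true, p, lam=0.0))            # 0.0: data term vanishes at the true image
print(objective(np.zeros((4, 4)), p, lam=0.0))  # 8.0: 0.5 * ||p||^2
```

The data term is zero at the true image, and \(\lambda\) trades that fit off against the penalty.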
Gradient-Based Optimization
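The gradient of the data fidelity term \(\tfrac{1}{2}\|A_{\text{parallel}} f - p\|_2^2\) is computed using the adjoint operator:
\[ \nabla_f \, \frac{1}{2}\left\| A_{\text{parallel}} f - p \right\|_2^2 = A_{\text{parallel}}^T \left( A_{\text{parallel}} f - p \right) \]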
where \(A_{\text{parallel}}^T\) is the parallel beam backprojection operator (adjoint of the forward projector).
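The adjoint relationship can be checked numerically. The sketch below uses a hypothetical two-view (0° and 90°) projector whose adjoint spreads each detector value back along its ray; it verifies the dot-product identity \(\langle A f, q\rangle = \langle f, A^T q\rangle\) and compares \(A^T(Af - p)\) with a finite-difference gradient. These operators are illustrative stand-ins, not diffct's kernels.

```python
import numpy as np


def project(f):
    """Hypothetical 2-view parallel projector: column sums (0 deg), then row sums (90 deg)."""
    return np.concatenate([f.sum(axis=0), f.sum(axis=1)])


def backproject(q, shape):
    """Adjoint of `project`: spread each ray value uniformly back along its ray."""
    ny, nx = shape
    q_cols, q_rows = q[:nx], q[nx:]
    return q_cols[np.newaxis, :] + q_rows[:, np.newaxis]


def data_term(f, p):
    r = project(f) - p
    return 0.5 * np.dot(r, r)


rng = np.random.default_rng(0)
shape = (5, 6)
f = rng.standard_normal(shape)
q = rng.standard_normal(shape[0] + shape[1])
p = rng.standard_normal(shape[0] + shape[1])

# Dot-product (adjoint) test: <A f, q> should equal <f, A^T q>
lhs = np.dot(project(f), q)
rhs = np.sum(f * backproject(q, shape))

# Analytic gradient A^T (A f - p) versus central finite differences
grad = backproject(project(f) - p, shape)
eps = 1e-6
fd = np.zeros_like(f)
for idx in np.ndindex(shape):
    e = np.zeros_like(f)
    e[idx] = eps
    fd[idx] = (data_term(f + e, p) - data_term(f - e, p)) / (2 * eps)

print(abs(lhs - rhs), np.max(np.abs(grad - fd)))
```

The same dot-product test is a common sanity check for any projector/backprojector pair.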
Automatic Differentiation
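PyTorch’s automatic differentiation computes gradients through the differentiable operators by the chain rule. For the predicted sinogram \(\hat{p} = A_{\text{parallel}} f\) and any loss \(\mathcal{L}(\hat{p}, p)\):
\[ \frac{\partial \mathcal{L}}{\partial f} = A_{\text{parallel}}^T \, \frac{\partial \mathcal{L}}{\partial \hat{p}} \]
The gradient is therefore available without writing the backprojection step by hand, and any optimizer in torch.optim, such as Adam, can use it directly.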
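A differentiable projector pairs a forward pass with a backward pass that applies the adjoint. The sketch below wraps a hypothetical two-view (0° and 90°) projector in `torch.autograd.Function` and checks it with `torch.autograd.gradcheck`; it only illustrates the pattern and makes no claim about how diffct's `ParallelProjectorFunction` is implemented internally.

```python
import torch


class ToyParallelProjector(torch.autograd.Function):
    """Hypothetical 2-view (0 and 90 degree) parallel projector with a hand-written adjoint."""

    @staticmethod
    def forward(ctx, image):
        ctx.image_shape = image.shape
        # Forward model A f: column sums followed by row sums
        return torch.cat([image.sum(dim=0), image.sum(dim=1)])

    @staticmethod
    def backward(ctx, grad_sino):
        ny, nx = ctx.image_shape
        g_cols, g_rows = grad_sino[:nx], grad_sino[nx:]
        # Adjoint A^T: backproject each gradient value along its ray
        return g_cols.unsqueeze(0) + g_rows.unsqueeze(1)


torch.manual_seed(0)
f_true = torch.rand(4, 5, dtype=torch.float64)
p = ToyParallelProjector.apply(f_true)

f = torch.zeros(4, 5, dtype=torch.float64, requires_grad=True)
loss = 0.5 * ((ToyParallelProjector.apply(f) - p) ** 2).sum()
loss.backward()  # autograd calls backward(), i.e. applies A^T to (A f - p)

# gradcheck compares the hand-written backward against finite differences
ok = torch.autograd.gradcheck(
    ToyParallelProjector.apply,
    (torch.rand(4, 5, dtype=torch.float64, requires_grad=True),),
)
print(ok)
```

Because the backward pass applies the adjoint, `f.grad` equals \(A^T(Af - p)\), which reduces to \(-A^T p\) at the zero starting point.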
Adam Optimizer
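The Adam optimizer adapts per-parameter step sizes using running gradient statistics. With \(g^{(k)} = \nabla_f \mathcal{L}(f^{(k-1)})\) at iteration \(k\) and \(m^{(0)} = v^{(0)} = 0\):
\[ m^{(k)} = \beta_1 m^{(k-1)} + (1-\beta_1)\, g^{(k)}, \qquad v^{(k)} = \beta_2 v^{(k-1)} + (1-\beta_2)\, \big(g^{(k)}\big)^2 \]
\[ \hat{m}^{(k)} = \frac{m^{(k)}}{1-\beta_1^k}, \qquad \hat{v}^{(k)} = \frac{v^{(k)}}{1-\beta_2^k}, \qquad f^{(k)} = f^{(k-1)} - \alpha \, \frac{\hat{m}^{(k)}}{\sqrt{\hat{v}^{(k)}} + \epsilon} \]
where \(m^{(k)}\) and \(v^{(k)}\) are biased first and second moment estimates, \(\hat{m}^{(k)}\) and \(\hat{v}^{(k)}\) are their bias-corrected versions, \(\alpha\) is the learning rate, \(\beta_1, \beta_2\) are decay rates, and \(\epsilon\) is a small constant for numerical stability; squares and square roots act elementwise. The script uses AdamW, which applies the same update plus a decoupled weight decay step.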
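The update rule fits in a few lines of NumPy. The sketch below follows the Adam equations with bias correction; the default \(\beta_1 = 0.9\), \(\beta_2 = 0.999\), \(\epsilon = 10^{-8}\) match PyTorch's `torch.optim.Adam` defaults, while the learning rate of 0.1 mirrors the script's `lr=1e-1`. The quadratic test function is an arbitrary illustration.

```python
import numpy as np


def adam_step(x, grad, m, v, k, lr=1e-1, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update at iteration k (1-based). Returns the new x, m, v."""
    m = beta1 * m + (1 - beta1) * grad          # biased first moment
    v = beta2 * v + (1 - beta2) * grad ** 2     # biased second moment
    m_hat = m / (1 - beta1 ** k)                # bias-corrected moments
    v_hat = v / (1 - beta2 ** k)
    x = x - lr * m_hat / (np.sqrt(v_hat) + eps)
    return x, m, v


# Minimize 0.5 * ||x - target||^2, whose gradient is x - target
target = np.array([1.0, -2.0, 0.5])
x = np.zeros(3)
m = np.zeros(3)
v = np.zeros(3)
first_x = None
for k in range(1, 501):
    x, m, v = adam_step(x, x - target, m, v, k)
    if k == 1:
        first_x = x.copy()

print(first_x)  # each entry moved by about lr toward the target
print(x)
```

On the first iteration the bias-corrected ratio \(\hat{m}/\sqrt{\hat{v}}\) equals \(g/|g|\) up to \(\epsilon\), so every entry moves by about the learning rate regardless of the gradient's magnitude.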
Implementation Steps
Problem Setup: Define the unknown 2D image as a learnable tensor (nn.Parameter)
Forward Model: Compute the predicted sinogram with ParallelProjectorFunction
Loss Computation: Compute the mean squared error between predicted and measured sinograms
Gradient Computation: Obtain the gradient by automatic differentiation (loss.backward())
Parameter Update: Take an AdamW step, then clamp the image to non-negative values
Convergence Monitoring: Record the loss at every iteration and compare the final image with the phantom
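The six steps can be traced in a dependency-free NumPy sketch. It uses a hypothetical two-view projector and its adjoint, and it swaps AdamW for plain projected gradient descent with a fixed step \(1/\|A\|_2^2\) so the loop stays short; for this particular operator \(\|A\|_2^2 = n_x + n_y\). It illustrates the loop structure only and is not the script's code path.

```python
import numpy as np


def project(f):
    """Hypothetical 2-view parallel projector (column sums, then row sums)."""
    return np.concatenate([f.sum(axis=0), f.sum(axis=1)])


def backproject(q, shape):
    """Adjoint of `project`."""
    ny, nx = shape
    return q[:nx][np.newaxis, :] + q[nx:][:, np.newaxis]


def reconstruct(p, shape, n_iter=200):
    ny, nx = shape
    f = np.zeros(shape)                 # 1. problem setup: unknown image
    step = 1.0 / (nx + ny)              # 1 / ||A||^2 for this operator
    losses = []
    for _ in range(n_iter):
        residual = project(f) - p                          # 2. forward model
        losses.append(0.5 * np.dot(residual, residual))   # 3. loss
        grad = backproject(residual, shape)                # 4. gradient A^T (A f - p)
        f = f - step * grad                                # 5. update...
        f = np.maximum(f, 0.0)                             # ...then non-negativity clamp
    return f, losses                    # 6. loss history for monitoring


f_true = np.zeros((8, 8))
f_true[3:5, 3:5] = 1.0
p = project(f_true)
f, losses = reconstruct(p, f_true.shape)
print(losses[0], losses[-1])
```

With only two views the problem is heavily underdetermined, so a near-zero loss does not imply that `f` matches `f_true`; this is the sparse-view situation in which regularization helps.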
Model Architecture
The reconstruction model consists of:
Parameterized Image: Learnable 2D tensor representing the unknown image
Forward Projection: ParallelProjectorFunction for sinogram prediction
Loss Function: Mean squared error between predicted and measured sinograms
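Because `nn.MSELoss` averages over all \(M\) sinogram entries by default (`reduction='mean'`), the loss and its gradient differ from the data fidelity term only by constant factors. In this script \(M = 360 \times 256 = 92160\), one entry per view and detector bin:

```latex
\mathcal{L}_{\text{MSE}}(f) = \frac{1}{M}\,\left\|A_{\text{parallel}} f - p\right\|_2^2,
\qquad
\nabla_f \mathcal{L}_{\text{MSE}} = \frac{2}{M}\,A_{\text{parallel}}^T\left(A_{\text{parallel}} f - p\right)
```

The factor \(2/M\) changes the gradient's scale but not its direction, and Adam-type updates normalize each entry by its own gradient history, so this scale has little effect on them apart from its interplay with \(\epsilon\).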
Advantages of Iterative Methods
Noise Robustness: The data term and regularizer can be matched to the noise model, so noisy measurements are typically handled better than with FBP
Regularization: Natural incorporation of prior knowledge
Incomplete Data: Effective with limited-angle or sparse-view acquisitions
Flexibility: Easy modification of cost functions and constraints
Artifact Reduction: Constraints and priors, such as the non-negativity clamp used in this script, give direct control over reconstruction artifacts
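To illustrate how prior knowledge enters the objective, the sketch below implements a quadratic smoothness penalty on neighboring-pixel differences, its gradient, and a finite-difference check. This penalty is an assumed example (total variation is a common edge-preserving alternative); the script in this example uses no regularizer beyond the non-negativity clamp.

```python
import numpy as np


def smoothness_penalty(f):
    """R(f) = 0.5 * sum of squared differences between neighboring pixels."""
    dx = np.diff(f, axis=1)
    dy = np.diff(f, axis=0)
    return 0.5 * (np.sum(dx ** 2) + np.sum(dy ** 2))


def smoothness_grad(f):
    """Gradient of smoothness_penalty, i.e. D^T D f for the finite-difference operator D."""
    g = np.zeros_like(f)
    dx = np.diff(f, axis=1)   # dx[:, j] = f[:, j + 1] - f[:, j]
    g[:, 1:] += dx
    g[:, :-1] -= dx
    dy = np.diff(f, axis=0)   # dy[i, :] = f[i + 1, :] - f[i, :]
    g[1:, :] += dy
    g[:-1, :] -= dy
    return g


rng = np.random.default_rng(1)
f = rng.standard_normal((6, 7))

# Finite-difference check of the analytic gradient
eps = 1e-6
fd = np.zeros_like(f)
for idx in np.ndindex(f.shape):
    e = np.zeros_like(f)
    e[idx] = eps
    fd[idx] = (smoothness_penalty(f + e) - smoothness_penalty(f - e)) / (2 * eps)

print(np.max(np.abs(fd - smoothness_grad(f))))
print(smoothness_penalty(np.ones((6, 7))))  # 0.0: a constant image is not penalized
```

In the PyTorch script, the analogous change would be adding a term such as `lam * R(reco)` (written with torch operations, `lam` being a hypothetical weight) to the MSE loss before `backward()`; autograd then supplies the extra gradient term.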
Convergence Characteristics
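Typical convergence behavior (approximate iteration ranges; the exact numbers depend on the learning rate, image size, and data):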
Initial Phase (0-100 iterations): Rapid loss decrease, basic structure emerges
Refinement Phase (100-500 iterations): Fine details develop, slower convergence
Convergence Phase (500+ iterations): Minimal improvement; with noisy measurements and no regularization, further iterations can start fitting the noise
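The script records only the loss. A minimal sketch of two hypothetical monitoring helpers follows: `psnr` measures reconstruction quality against a known phantom (with `data_range=1.0` assumed to match a phantom clipped to [0, 1]), and `has_plateaued` is one possible stopping rule for the convergence phase. Neither is part of diffct or of the original script.

```python
import math

import numpy as np


def psnr(reco, reference, data_range=1.0):
    """Peak signal-to-noise ratio in dB."""
    mse = np.mean((np.asarray(reco) - np.asarray(reference)) ** 2)
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(data_range ** 2 / mse)


def has_plateaued(loss_values, window=50, rel_tol=1e-3):
    """Hypothetical rule: relative loss decrease over the last `window` iterations is below rel_tol."""
    if len(loss_values) <= window:
        return False
    old, new = loss_values[-1 - window], loss_values[-1]
    if old == 0:
        return True
    return (old - new) / abs(old) < rel_tol


reference = np.zeros((4, 4))
print(psnr(reference + 0.1, reference))  # about 20.0 dB, since MSE = 0.01

decaying = [0.9 ** k for k in range(200)]
flat = [1.0] * 100 + [0.99995] * 100
print(has_plateaued(decaying), has_plateaued(flat))  # False True
```

In the script, `has_plateaued(loss_values)` could be checked inside the training loop to stop early, and `psnr(reco, phantom_cpu)` could be reported after training.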
```python
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from diffct.differentiable import ParallelProjectorFunction


def shepp_logan_2d(Nx, Ny):
    """Simplified Shepp-Logan phantom built from five ellipses, clipped to [0, 1]."""
    Nx = int(Nx)
    Ny = int(Ny)
    phantom = np.zeros((Ny, Nx), dtype=np.float32)
    # (x0, y0, semi-axis a, semi-axis b, rotation in degrees, amplitude)
    ellipses = [
        (0.0, 0.0, 0.69, 0.92, 0, 1.0),
        (0.0, -0.0184, 0.6624, 0.8740, 0, -0.8),
        (0.22, 0.0, 0.11, 0.31, -18.0, -0.8),
        (-0.22, 0.0, 0.16, 0.41, 18.0, -0.8),
        (0.0, 0.35, 0.21, 0.25, 0, 0.7),
    ]
    cx = (Nx - 1) * 0.5
    cy = (Ny - 1) * 0.5
    for ix in range(Nx):
        for iy in range(Ny):
            # Normalized pixel coordinates, roughly in [-1, 1]
            xnorm = (ix - cx) / (Nx / 2)
            ynorm = (iy - cy) / (Ny / 2)
            val = 0.0
            for (x0, y0, a, b, angdeg, ampl) in ellipses:
                th = np.deg2rad(angdeg)
                xprime = (xnorm - x0) * np.cos(th) + (ynorm - y0) * np.sin(th)
                yprime = -(xnorm - x0) * np.sin(th) + (ynorm - y0) * np.cos(th)
                if xprime * xprime / (a * a) + yprime * yprime / (b * b) <= 1.0:
                    val += ampl
            phantom[iy, ix] = val
    phantom = np.clip(phantom, 0.0, 1.0)
    return phantom


class IterativeRecoModel(nn.Module):
    def __init__(self, volume_shape, angles, num_detectors, detector_spacing, voxel_spacing):
        super().__init__()
        # The unknown image, optimized directly as a learnable parameter
        self.reco = nn.Parameter(torch.zeros(volume_shape))
        self.angles = angles
        self.num_detectors = num_detectors
        self.detector_spacing = detector_spacing
        self.voxel_spacing = voxel_spacing

    def forward(self, x):
        # x is a fixed initial guess; the learnable image is added to it
        updated_reco = x + self.reco
        # Forward projection; autograd propagates gradients back to self.reco
        current_sino = ParallelProjectorFunction.apply(
            updated_reco, self.angles, self.num_detectors,
            self.detector_spacing, self.voxel_spacing)
        return current_sino, updated_reco


class Pipeline:
    def __init__(self, lr, volume_shape, angles, num_detectors, detector_spacing,
                 voxel_spacing, device, epochs=1000):
        self.epochs = epochs
        self.model = IterativeRecoModel(volume_shape, angles, num_detectors,
                                        detector_spacing, voxel_spacing).to(device)
        # AdamW: Adam with decoupled weight decay (PyTorch default weight_decay=0.01)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        self.loss = nn.MSELoss()

    def train(self, x, label):
        loss_values = []
        for epoch in range(self.epochs):
            self.optimizer.zero_grad()
            predictions, _ = self.model(x)
            loss_value = self.loss(predictions, label)
            loss_value.backward()
            self.optimizer.step()
            # Keep the learnable image non-negative (with a zero initial guess,
            # this makes the reconstruction itself non-negative)
            with torch.no_grad():
                self.model.reco.clamp_(min=0.0)
            loss_values.append(loss_value.item())

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss_values[-1]}")

        return loss_values, self.model


def main():
    Nx, Ny = 128, 128
    phantom_cpu = shepp_logan_2d(Nx, Ny)

    num_views = 360
    angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)

    num_detectors = 256
    detector_spacing = 0.5
    voxel_spacing = 1.0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32)
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)

    # Simulate the "measured" (noise-free) sinogram from the phantom
    real_sinogram = ParallelProjectorFunction.apply(phantom_torch, angles_torch,
                                                    num_detectors, detector_spacing, voxel_spacing)

    pipeline_instance = Pipeline(lr=1e-1,
                                 volume_shape=(Ny, Nx),
                                 angles=angles_torch,
                                 num_detectors=num_detectors,
                                 detector_spacing=detector_spacing,
                                 voxel_spacing=voxel_spacing,
                                 device=device, epochs=1000)

    ini_guess = torch.zeros_like(phantom_torch)

    loss_values, trained_model = pipeline_instance.train(ini_guess, real_sinogram)

    # Final image; no_grad avoids building an autograd graph for this extra pass
    with torch.no_grad():
        reco = trained_model(ini_guess)[1].squeeze().cpu().numpy()

    plt.figure()
    plt.plot(loss_values)
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(phantom_cpu, cmap="gray")
    plt.title("Original Phantom")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reco, cmap="gray")
    plt.title("Reconstructed")
    plt.axis("off")
    plt.show()


if __name__ == "__main__":
    main()
```