Parallel Beam Iterative Reconstruction
This example demonstrates 2D parallel beam iterative reconstruction using the differentiable ParallelProjectorFunction and ParallelBackprojectorFunction from diffct.
Overview
Parallel beam iterative reconstruction solves the CT inverse problem through optimization, offering advantages over analytical methods like FBP. This example shows how to:
Formulate parallel beam CT reconstruction as an optimization problem
Use automatic differentiation for gradient computation
Apply gradient-based optimization with parallel beam operators
Monitor convergence and reconstruction quality
Mathematical Background
Parallel Beam Iterative Formulation
The parallel beam reconstruction problem is formulated as:

\[ \hat{f} = \arg\min_{f} \; \frac{1}{2}\left\| A_{\text{parallel}} f - p \right\|_2^2 + \lambda R(f) \]

where:
- \(\hat{f}\) is the reconstructed image
- \(f\) is the unknown 2D image
- \(A_{\text{parallel}}\) is the parallel beam forward projection operator (Radon transform)
- \(p\) is the measured sinogram data
- \(R(f)\) is an optional regularization term
- \(\lambda\) is the regularization parameter
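As a small sketch of this objective, the function below evaluates \(\frac{1}{2}\|A f - p\|_2^2 + \lambda R(f)\) with a random dense matrix standing in for \(A_{\text{parallel}}\). The dense matrix, the toy sizes, and the Tikhonov choice \(R(f) = \|f\|_2^2\) are assumptions for illustration only; a real CT projector such as diffct's is typically applied matrix-free.

```python
import torch


def objective(f, A, p, lam=0.0, reg=None):
    """J(f) = 0.5 * ||A f - p||_2^2 + lam * R(f).

    A is a dense matrix used as a stand-in for A_parallel (assumption);
    f is the flattened image and p the flattened sinogram.
    """
    residual = A @ f - p
    data_term = 0.5 * torch.sum(residual ** 2)
    reg_term = lam * reg(f) if reg is not None else torch.zeros((), dtype=f.dtype)
    return data_term + reg_term


def tikhonov(f):
    """Illustrative regularizer R(f) = ||f||_2^2 (an assumption, not from the example)."""
    return torch.sum(f ** 2)


# Toy problem: 6 "rays" through a 4-pixel image
torch.manual_seed(0)
A = torch.rand(6, 4)
f_true = torch.tensor([1.0, 0.0, 0.5, 0.25])
p = A @ f_true

print(objective(f_true, A, p))                              # no data mismatch at the true image
print(objective(torch.zeros(4), A, p, lam=0.1, reg=tikhonov))
```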
Gradient-Based Optimization
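The gradient of the data fidelity term is computed using the adjoint operator:

\[ \nabla_f \left( \frac{1}{2}\left\| A_{\text{parallel}} f - p \right\|_2^2 \right) = A_{\text{parallel}}^T \left( A_{\text{parallel}} f - p \right) \]

where \(A_{\text{parallel}}^T\) is the parallel beam backprojection operator (the adjoint of the forward projector).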
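The gradient formula can be checked numerically on a toy problem. The sketch below compares PyTorch's autograd gradient of the data term with the closed form \(A^T(Af - p)\), using a random dense matrix as a stand-in for the projector (an assumption; its transpose then plays the role of the backprojector).

```python
import torch

torch.manual_seed(0)
A = torch.rand(8, 5, dtype=torch.float64)   # dense stand-in for A_parallel (assumption)
p = torch.rand(8, dtype=torch.float64)      # "measured" sinogram
f = torch.rand(5, dtype=torch.float64, requires_grad=True)

# Data fidelity term 0.5 * ||A f - p||^2, differentiated by autograd
loss = 0.5 * torch.sum((A @ f - p) ** 2)
loss.backward()

# Closed-form gradient A^T (A f - p); A.T acts as the backprojector
grad_closed_form = A.T @ (A @ f.detach() - p)
print(torch.allclose(f.grad, grad_closed_form))
```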
Automatic Differentiation
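PyTorch's automatic differentiation computes gradients through the differentiable operators using the chain rule. With the predicted sinogram \(p_{\text{pred}} = A_{\text{parallel}} f\) and a loss \(\mathcal{L}(p_{\text{pred}}, p)\):

\[ \frac{\partial \mathcal{L}}{\partial f} = A_{\text{parallel}}^T \, \frac{\partial \mathcal{L}}{\partial p_{\text{pred}}} \]

The backward pass of ParallelProjectorFunction supplies this backprojection, so the gradient does not have to be coded by hand. The example uses nn.MSELoss, the mean of the squared residuals rather than half their sum; this only rescales the gradient by a constant factor. Automatic differentiation also enables seamless integration with optimizers such as Adam.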
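The mechanism can be illustrated with a minimal custom `torch.autograd.Function` whose forward applies a matrix \(A\) and whose backward applies \(A^T\). `ToyProjector` is a hypothetical name and the dense matrix is an assumption; this sketches the general pattern, not diffct's actual implementation. `torch.autograd.gradcheck` compares the hand-written backward against finite differences.

```python
import torch


class ToyProjector(torch.autograd.Function):
    """Hypothetical differentiable projector: forward applies A,
    backward applies the adjoint A^T (the "backprojector")."""

    @staticmethod
    def forward(ctx, image, A):
        ctx.save_for_backward(A)
        return A @ image

    @staticmethod
    def backward(ctx, grad_sino):
        (A,) = ctx.saved_tensors
        # Chain rule: dL/df = A^T dL/dp_pred; no gradient is needed for A
        return A.T @ grad_sino, None


torch.manual_seed(0)
A = torch.rand(8, 5, dtype=torch.float64)
f = torch.rand(5, dtype=torch.float64, requires_grad=True)

# gradcheck returns True when the analytic backward matches finite differences
print(torch.autograd.gradcheck(lambda x: ToyProjector.apply(x, A), (f,)))
```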
Adam Optimizer
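The Adam optimizer adapts per-parameter step sizes using running gradient statistics. With \(g^{(k)} = \nabla_f \mathcal{L}(f^{(k-1)})\) at iteration \(k = 1, 2, \dots\):

\[
\begin{aligned}
m^{(k)} &= \beta_1 m^{(k-1)} + (1-\beta_1)\, g^{(k)} \\
v^{(k)} &= \beta_2 v^{(k-1)} + (1-\beta_2)\, \big(g^{(k)}\big)^2 \\
\hat{m}^{(k)} &= \frac{m^{(k)}}{1-\beta_1^k}, \qquad \hat{v}^{(k)} = \frac{v^{(k)}}{1-\beta_2^k} \\
f^{(k)} &= f^{(k-1)} - \alpha \, \frac{\hat{m}^{(k)}}{\sqrt{\hat{v}^{(k)}} + \epsilon}
\end{aligned}
\]

where \(m^{(k)}\) and \(v^{(k)}\) are biased first and second moment estimates, \(\hat{m}^{(k)}\) and \(\hat{v}^{(k)}\) their bias-corrected versions, \(\alpha\) the learning rate, \(\beta_1, \beta_2\) the decay rates, and \(\epsilon\) a small constant for numerical stability; all operations are element-wise. The example code uses optim.AdamW, which adds decoupled weight decay to this update.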
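The update equations translate directly into code. The sketch below implements one Adam step (`adam_step` is a hypothetical helper) and checks it against `torch.optim.Adam` with default hyperparameters on the toy objective \(\sum_i f_i^2\), chosen only for illustration. AdamW with `weight_decay=0` would give the same update.

```python
import torch


def adam_step(f, grad, m, v, k, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update at iteration k (1-based), following the equations above."""
    m = beta1 * m + (1 - beta1) * grad          # biased first moment
    v = beta2 * v + (1 - beta2) * grad ** 2     # biased second moment
    m_hat = m / (1 - beta1 ** k)                # bias corrections
    v_hat = v / (1 - beta2 ** k)
    f = f - lr * m_hat / (torch.sqrt(v_hat) + eps)
    return f, m, v


torch.manual_seed(0)
f0 = torch.randn(5, dtype=torch.float64)

# Reference: PyTorch's Adam on the same toy objective sum(f**2)
param = f0.clone().requires_grad_(True)
opt = torch.optim.Adam([param], lr=1e-2)

f, m, v = f0.clone(), torch.zeros_like(f0), torch.zeros_like(f0)
for k in range(1, 4):
    f, m, v = adam_step(f, 2 * f, m, v, k, lr=1e-2)   # gradient of sum(f**2) is 2f
    opt.zero_grad()
    (param ** 2).sum().backward()
    opt.step()

print(torch.allclose(f, param.detach()))
```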
Implementation Steps
Problem Setup: Define the 2D image as a learnable tensor (nn.Parameter)
Forward Model: Compute the predicted sinogram using ParallelProjectorFunction
Loss Computation: Calculate the mean squared error between predicted and measured sinograms
Gradient Computation: Use automatic differentiation to compute the gradient with respect to the image
Parameter Update: Apply the AdamW optimizer (Adam with decoupled weight decay) for iterative improvement
Convergence Monitoring: Record the loss at every iteration and compare the final image with the phantom
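The six steps above can be sketched end to end on a toy problem. Here a random dense matrix replaces ParallelProjectorFunction and the image is a 16-element vector; both are assumptions made so the snippet runs without diffct or a GPU. The numbered comments map to the steps.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_pix, n_rays = 16, 40
A = torch.rand(n_rays, n_pix)             # dense stand-in for the projector (assumption)
f_true = torch.rand(n_pix)
p_measured = A @ f_true                   # noise-free "measured" sinogram

f = nn.Parameter(torch.zeros(n_pix))      # 1. learnable image
opt = torch.optim.Adam([f], lr=1e-2)      # optimizer for step 5
mse = nn.MSELoss()
history = []
for it in range(2000):
    opt.zero_grad()
    p_pred = A @ f                        # 2. forward model
    loss = mse(p_pred, p_measured)        # 3. loss
    loss.backward()                       # 4. gradient via autograd
    opt.step()                            # 5. parameter update
    history.append(loss.item())           # 6. convergence monitoring

print(history[0], history[-1])
```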
Model Architecture
The reconstruction model consists of:
Parameterized Image: Learnable 2D tensor representing the unknown image
Forward Projection: ParallelProjectorFunction for sinogram prediction
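Loss Function: Mean squared error between predicted and measured sinograms
Non-negativity: A ReLU applied to the returned image (the forward projection uses the unconstrained values)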
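In the example script the ReLU only acts on the image returned by forward, so the projector and the loss see the unconstrained values. One common alternative, sketched here with a toy dense operator (an assumption), is projected gradient descent: after every optimizer step the image is clamped to be non-negative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
A = torch.rand(30, 12)                    # dense stand-in for the projector (assumption)
f_true = torch.rand(12)
p = A @ f_true

f = nn.Parameter(torch.zeros(12))
initial = nn.functional.mse_loss(A @ f.detach(), p).item()
opt = torch.optim.Adam([f], lr=1e-2)
for _ in range(500):
    opt.zero_grad()
    loss = nn.functional.mse_loss(A @ f, p)
    loss.backward()
    opt.step()
    with torch.no_grad():
        f.clamp_(min=0.0)                 # project onto the non-negative orthant

print(bool((f >= 0).all()), initial, loss.item())
```

Applying `torch.relu` before the projector is another option, but with a zero initial image the ReLU gradient is zero everywhere in PyTorch, so the optimization would not start.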
Advantages of Iterative Methods
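Noise Robustness: Noise models and regularization can be built into the objective, whereas the ramp filter in FBP amplifies high-frequency noise
Regularization: Prior knowledge (e.g. smoothness or sparsity) enters directly through the term \(\lambda R(f)\)
Incomplete Data: Effective with limited-angle or sparse-view acquisitions, where FBP typically produces streak artifacts
Flexibility: Cost functions and constraints are easy to modify
Artifact Reduction: Artifacts can be controlled through the choice of cost function, constraints, and regularizer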
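To illustrate the regularization point, the sketch below defines a smoothed isotropic total-variation penalty, one common choice of \(R(f)\) that favors piecewise-constant images. The function name `total_variation`, the smoothing constant `eps`, and any weight `lam` are illustrative assumptions.

```python
import torch


def total_variation(img, eps=1e-8):
    """Smoothed isotropic total variation of a 2D image.

    eps keeps the square root differentiable where the local gradient is zero.
    """
    dx = img[:, 1:] - img[:, :-1]   # horizontal differences, shape (H, W-1)
    dy = img[1:, :] - img[:-1, :]   # vertical differences, shape (H-1, W)
    # Crop both to the common (H-1, W-1) grid before combining
    return torch.sqrt(dx[:-1, :] ** 2 + dy[:, :-1] ** 2 + eps).sum()


# The penalty is differentiable, so autograd handles it like the data term
img = torch.rand(8, 8, requires_grad=True)
penalty = total_variation(img)
penalty.backward()
print(penalty.item(), img.grad.shape)
```

In the example script it could be added to the data term, e.g. `loss_value = self.loss(predictions, label) + lam * total_variation(current_reco)`, with a hypothetical weight `lam` that would need tuning.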
Convergence Characteristics
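Convergence typically proceeds in three phases; the iteration counts below are rough guides that depend on the learning rate, image size, and noise level:
Initial Phase (0-100 iterations): Rapid loss decrease; the basic structure emerges
Refinement Phase (100-500 iterations): Fine details develop; convergence slows
Convergence Phase (500+ iterations): Minimal improvement; without regularization, the image may start fitting measurement noise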
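Because these iteration ranges vary from problem to problem, a simple stopping rule based on the recorded losses can replace a fixed iteration count. The helper `has_converged` is hypothetical, and `window` and `rel_tol` are illustrative values, not tuned ones.

```python
def has_converged(loss_history, window=50, rel_tol=1e-4):
    """Return True when the loss improved by less than rel_tol (relative)
    over the last `window` iterations."""
    if len(loss_history) <= window:
        return False                      # not enough history yet
    old, new = loss_history[-window - 1], loss_history[-1]
    return (old - new) <= rel_tol * abs(old)


print(has_converged([1.0] * 100))                          # flat loss: converged
print(has_converged([1.0 / (k + 1) for k in range(100)]))  # still decreasing
```

In `Pipeline.train` it could be called right after `loss_values.append(...)` as `if has_converged(loss_values): break`.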
```python
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from diffct.differentiable import ParallelProjectorFunction


def shepp_logan_2d(Nx, Ny):
    """Simplified Shepp-Logan phantom built from five ellipses, clipped to [0, 1]."""
    Nx = int(Nx)
    Ny = int(Ny)
    phantom = np.zeros((Ny, Nx), dtype=np.float32)
    # Each ellipse: (x0, y0, semi-axis a, semi-axis b, rotation in degrees, amplitude)
    ellipses = [
        (0.0, 0.0, 0.69, 0.92, 0, 1.0),
        (0.0, -0.0184, 0.6624, 0.8740, 0, -0.8),
        (0.22, 0.0, 0.11, 0.31, -18.0, -0.8),
        (-0.22, 0.0, 0.16, 0.41, 18.0, -0.8),
        (0.0, 0.35, 0.21, 0.25, 0, 0.7),
    ]
    cx = (Nx - 1) * 0.5
    cy = (Ny - 1) * 0.5
    for ix in range(Nx):
        for iy in range(Ny):
            # Normalized pixel coordinates in roughly [-1, 1]
            xnorm = (ix - cx) / (Nx / 2)
            ynorm = (iy - cy) / (Ny / 2)
            val = 0.0
            for (x0, y0, a, b, angdeg, ampl) in ellipses:
                th = np.deg2rad(angdeg)
                # Rotate the point into the ellipse's local frame
                xprime = (xnorm - x0) * np.cos(th) + (ynorm - y0) * np.sin(th)
                yprime = -(xnorm - x0) * np.sin(th) + (ynorm - y0) * np.cos(th)
                if xprime * xprime / (a * a) + yprime * yprime / (b * b) <= 1.0:
                    val += ampl
            phantom[iy, ix] = val
    phantom = np.clip(phantom, 0.0, 1.0)
    return phantom


class IterativeRecoModel(nn.Module):
    def __init__(self, volume_shape, angles, num_detectors, detector_spacing, voxel_spacing):
        super().__init__()
        # The unknown image, optimized directly as a learnable tensor
        self.reco = nn.Parameter(torch.zeros(volume_shape))
        self.angles = angles
        self.num_detectors = num_detectors
        self.detector_spacing = detector_spacing
        self.voxel_spacing = voxel_spacing
        # Non-negativity: applied only to the returned image, not to the projected one
        self.relu = nn.ReLU()

    def forward(self, x):
        updated_reco = x + self.reco
        # Differentiable forward projection; gradients flow back through the backprojector
        current_sino = ParallelProjectorFunction.apply(
            updated_reco, self.angles, self.num_detectors,
            self.detector_spacing, self.voxel_spacing)
        return current_sino, self.relu(updated_reco)


class Pipeline:
    def __init__(self, lr, volume_shape, angles, num_detectors, detector_spacing,
                 voxel_spacing, device, epochs=1000):
        self.epochs = epochs
        self.model = IterativeRecoModel(volume_shape, angles, num_detectors,
                                        detector_spacing, voxel_spacing).to(device)
        # AdamW: Adam with decoupled weight decay (PyTorch default weight_decay=1e-2)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        self.loss = nn.MSELoss()

    def train(self, x, label):
        loss_values = []
        for epoch in range(self.epochs):
            self.optimizer.zero_grad()
            predictions, current_reco = self.model(x)
            loss_value = self.loss(predictions, label)  # MSE between sinograms
            loss_value.backward()                        # autograd computes dL/d(reco)
            self.optimizer.step()
            loss_values.append(loss_value.item())

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss_value.item()}")

        return loss_values, self.model


def main():
    Nx, Ny = 128, 128
    phantom_cpu = shepp_logan_2d(Nx, Ny)

    # 360 projection angles over [0, 2*pi)
    num_views = 360
    angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)

    num_detectors = 256
    detector_spacing = 0.5
    voxel_spacing = 1.0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32)
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)

    # Simulate the "measured" (noise-free) sinogram from the phantom
    real_sinogram = ParallelProjectorFunction.apply(phantom_torch, angles_torch,
                                                    num_detectors, detector_spacing,
                                                    voxel_spacing)

    pipeline_instance = Pipeline(lr=1e-1,
                                 volume_shape=(Ny, Nx),
                                 angles=angles_torch,
                                 num_detectors=num_detectors,
                                 detector_spacing=detector_spacing,
                                 voxel_spacing=voxel_spacing,
                                 device=device, epochs=1000)

    ini_guess = torch.zeros_like(phantom_torch)

    loss_values, trained_model = pipeline_instance.train(ini_guess, real_sinogram)

    # Final non-negative image; no gradient tracking needed for evaluation
    with torch.no_grad():
        reco = trained_model(ini_guess)[1].squeeze().cpu().numpy()

    plt.figure()
    plt.plot(loss_values)
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(phantom_cpu, cmap="gray")
    plt.title("Original Phantom")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reco, cmap="gray")
    plt.title("Reconstructed")
    plt.axis("off")
    plt.show()


if __name__ == "__main__":
    main()
```
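The script records only the loss, which measures agreement in the sinogram domain. To quantify reconstruction quality in the image domain, the reconstruction can be compared with the phantom, for example with RMSE and PSNR as sketched below. The helper names are hypothetical, and the PSNR data range defaults to that of the reference image (an assumption; other conventions exist).

```python
import numpy as np


def rmse(reference, reconstruction):
    """Root-mean-square error between two images of the same shape."""
    return float(np.sqrt(np.mean((reference - reconstruction) ** 2)))


def psnr(reference, reconstruction, data_range=None):
    """Peak signal-to-noise ratio in dB (infinite if the images are identical)."""
    if data_range is None:
        data_range = float(reference.max() - reference.min())
    mse = np.mean((reference - reconstruction) ** 2)
    return float(10.0 * np.log10(data_range ** 2 / mse))


ref = np.zeros((4, 4))
ref[0, 0] = 1.0
print(rmse(ref, ref + 0.1), psnr(ref, ref + 0.1))
```

For example, at the end of `main()` one could add `print(rmse(phantom_cpu, reco), psnr(phantom_cpu, reco))`.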