Parallel Beam Iterative Reconstruction

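This example demonstrates 2D parallel beam iterative reconstruction with the differentiable operators in diffct. Only ParallelProjectorFunction is called explicitly. The backprojection (its adjoint, the operation ParallelBackprojectorFunction provides) enters through the gradient computed in the backward pass.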

Overview

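Parallel beam iterative reconstruction solves the CT inverse problem by numerical optimization rather than closed-form inversion. Compared with analytical methods such as filtered backprojection (FBP), it can incorporate constraints and prior knowledge and copes better with noisy or incomplete data. This example shows how to: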

  • Formulate parallel beam CT reconstruction as an optimization problem

  • Use automatic differentiation for gradient computation

  • Apply gradient-based optimization with parallel beam operators

  • Monitor convergence and reconstruction quality

Mathematical Background

Parallel Beam Iterative Formulation

The parallel beam reconstruction problem is formulated as:

\[\hat{f} = \arg\min_f \|A_{\text{parallel}}(f) - p\|_2^2 + \lambda R(f)\]

where:

  • \(f\) is the unknown 2D image

  • \(A_{\text{parallel}}\) is the parallel beam forward projection operator (Radon transform)

  • \(p\) is the measured sinogram data

  • \(R(f)\) is an optional regularization term

  • \(\lambda\) is the regularization parameter (the example below uses no regularization, i.e. \(\lambda = 0\))
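To make the objective concrete, here is a minimal NumPy sketch that evaluates it for a tiny image. The helper project_two_angles is hypothetical: a toy projector with only two axis-aligned views and one-pixel detector bins (axis convention assumed) that stands in for \(A_{\text{parallel}}\). The regularizer is assumed to be Tikhonov, \(R(f) = \|f\|_2^2\).

```python
import numpy as np

def project_two_angles(f):
    """Hypothetical toy stand-in for A_parallel on an n x n image.

    View 1 sums along the columns, view 2 along the rows (assumed convention),
    with one detector bin per pixel. Returns a (2, n) sinogram.
    """
    return np.stack([f.sum(axis=0), f.sum(axis=1)])

def objective(f, p, lam=0.0):
    """||A(f) - p||_2^2 + lam * R(f), with R(f) = ||f||_2^2 assumed (Tikhonov)."""
    residual = project_two_angles(f) - p
    return float(np.sum(residual**2) + lam * np.sum(f**2))

f_true = np.zeros((4, 4))
f_true[1:3, 1:3] = 1.0          # 2x2 square of unit density
p = project_two_angles(f_true)  # noise-free "measured" sinogram

print(objective(f_true, p))            # 0.0: the true image fits the data exactly
print(objective(np.zeros((4, 4)), p))  # 16.0
```

With only two views many images fit the data equally well, which is one reason a regularizer \(R(f)\) is useful when views are scarce.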

Gradient-Based Optimization

Because the forward projector is linear, the gradient of the data fidelity term can be written with its adjoint:

\[\nabla_f \|A_{\text{parallel}}(f) - p\|_2^2 = 2A_{\text{parallel}}^T(A_{\text{parallel}}(f) - p)\]

where \(A_{\text{parallel}}^T\) is the parallel beam backprojection operator (adjoint of the forward projector).
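The adjoint relation can be checked numerically. The sketch below uses a hypothetical two-view toy projector (axis convention assumed) and its backprojection. It runs the dot-product test \(\langle A x, y\rangle = \langle x, A^T y\rangle\) and compares \(2A^T(A(f) - p)\) with a central finite-difference gradient.

```python
import numpy as np

def project(f):
    # Hypothetical toy projector: column sums and row sums (two views)
    return np.stack([f.sum(axis=0), f.sum(axis=1)])

def backproject(s):
    # Adjoint of project(): smear each detector value back along its ray
    return s[0][None, :] + s[1][:, None]

def loss(f, p):
    r = project(f) - p
    return np.sum(r**2)

def grad(f, p):
    return 2.0 * backproject(project(f) - p)

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 5))
y = rng.standard_normal((2, 5))
# Dot-product test: <A x, y> must equal <x, A^T y>
print(np.isclose(np.sum(project(x) * y), np.sum(x * backproject(y))))  # True

p = project(rng.standard_normal((5, 5)))
f = rng.standard_normal((5, 5))
eps = 1e-5
fd = np.zeros_like(f)
for idx in np.ndindex(f.shape):
    d = np.zeros_like(f)
    d[idx] = eps
    fd[idx] = (loss(f + d, p) - loss(f - d, p)) / (2 * eps)
print(np.allclose(fd, grad(f, p), atol=1e-6))  # True
```

Because the loss is quadratic, the central difference is exact up to rounding. A mismatch would point to a projector/backprojector pair that are not true adjoints.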

Automatic Differentiation

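PyTorch’s automatic differentiation computes this gradient by backpropagating through the differentiable operators, so the adjoint never has to be called by hand in the reconstruction loop. For the loss

\[L(f) = \|A_{\text{parallel}}(f) - p\|_2^2, \qquad \frac{\partial L}{\partial f} = 2A_{\text{parallel}}^T(A_{\text{parallel}}(f) - p)\]

Because the gradient is obtained automatically, the operators work with any PyTorch optimizer, such as Adam.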

Adam Optimizer

The Adam optimizer adapts the step size of each pixel using running estimates of the gradient’s first and second moments:

\[f^{(k+1)} = f^{(k)} - \alpha \cdot \frac{m^{(k)}}{1-\beta_1^k} \cdot \frac{1}{\sqrt{v^{(k)}/(1-\beta_2^k)} + \epsilon}\]

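where \(g^{(k)} = \nabla_f L(f^{(k)})\) is the gradient at iteration \(k\) (counting from \(k = 1\)), and \(m^{(k)} = \beta_1 m^{(k-1)} + (1-\beta_1)\,g^{(k)}\) and \(v^{(k)} = \beta_2 v^{(k-1)} + (1-\beta_2)\,(g^{(k)})^2\) are the biased first and second moment estimates (initialized to zero, with squares and square roots taken elementwise). The factors \(1-\beta_1^k\) and \(1-\beta_2^k\) correct this bias, \(\alpha\) is the learning rate, and \(\epsilon\) is a small constant that prevents division by zero. The example code uses AdamW, which additionally shrinks \(f\) by the factor \(1 - \alpha\lambda_{\text{wd}}\) at every step (decoupled weight decay; PyTorch’s default is \(\lambda_{\text{wd}} = 0.01\)).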
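A direct NumPy transcription of the update, as a sketch. The function name adam_step is hypothetical. The defaults \(\beta_1 = 0.9\), \(\beta_2 = 0.999\), \(\epsilon = 10^{-8}\) are the usual choices (also PyTorch’s), and \(\alpha = 0.1\) matches the learning rate in the example script.

```python
import numpy as np

def adam_step(f, g, m, v, k, alpha=1e-1, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update (k counts from 1); returns the updated (f, m, v)."""
    m = beta1 * m + (1.0 - beta1) * g      # biased first moment m^(k)
    v = beta2 * v + (1.0 - beta2) * g * g  # biased second moment v^(k)
    m_hat = m / (1.0 - beta1**k)           # bias-corrected estimates
    v_hat = v / (1.0 - beta2**k)
    f = f - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return f, m, v

f0 = np.zeros(3)
g0 = np.array([4.0, -0.01, 250.0])  # gradients of very different magnitude
f1, m1, v1 = adam_step(f0, g0, np.zeros(3), np.zeros(3), k=1)
print(f1)  # approximately [-0.1, 0.1, -0.1]
```

On the first step the bias corrections give \(\hat m = g\) and \(\hat v = g^2\), so every pixel moves by roughly \(\alpha\) regardless of its gradient magnitude. This is why Adam is largely insensitive to the overall scale of the loss.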

Implementation Steps

  1. Problem Setup: Define the unknown 2D image as a learnable tensor (nn.Parameter)

  2. Forward Model: Compute the predicted sinogram with ParallelProjectorFunction

  3. Loss Computation: Compute the mean squared error between the predicted and measured sinograms

  4. Gradient Computation: Backpropagate with automatic differentiation to obtain the gradient with respect to the image

  5. Parameter Update: Update the image with the optimizer (AdamW in the example code)

  6. Convergence Monitoring: Record the loss at every iteration and inspect the reconstructed image
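The six steps can be traced in a self-contained NumPy sketch. It is not the diffct pipeline: a random dense matrix is a hypothetical stand-in for the projector, the gradient in step 4 is written out explicitly as \(\tfrac{2}{N}A^T(Af - p)\) instead of coming from autograd, and the hyperparameters are illustrative.

```python
import numpy as np

rng = np.random.default_rng(42)

# Step 1: problem setup. A is a HYPOTHETICAL dense stand-in for the projector:
# a 4x4 image (16 unknowns) observed through 48 measurements.
n_pix, n_meas = 16, 48
A = rng.standard_normal((n_meas, n_pix))
f_true = np.zeros((4, 4))
f_true[1:3, 1:3] = 1.0
p = A @ f_true.ravel()   # noise-free "measured" data
f = np.zeros(n_pix)      # learnable image, zero initial guess

m = np.zeros(n_pix)
v = np.zeros(n_pix)
alpha, beta1, beta2, eps = 0.05, 0.9, 0.999, 1e-8
loss_history = []

for k in range(1, 501):
    residual = A @ f - p                     # Step 2: forward model
    loss = np.mean(residual**2)              # Step 3: MSE loss
    g = 2.0 / n_meas * (A.T @ residual)      # Step 4: gradient via the adjoint A^T
    m = beta1 * m + (1 - beta1) * g          # Step 5: Adam update
    v = beta2 * v + (1 - beta2) * g**2
    m_hat = m / (1 - beta1**k)
    v_hat = v / (1 - beta2**k)
    f -= alpha * m_hat / (np.sqrt(v_hat) + eps)
    loss_history.append(loss)                # Step 6: monitoring
    if k % 100 == 0:
        print(f"iter {k}: loss {loss:.3e}")
```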

Model Architecture

The reconstruction model consists of:

  • Parameterized Image: A learnable 2D tensor added to a fixed initial guess (zeros in the example); a ReLU enforces non-negativity on the returned image

  • Forward Projection: ParallelProjectorFunction for sinogram prediction

  • Loss Function: Mean squared error between predicted and measured sinograms
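Since nn.MSELoss averages over all sinogram entries by default (reduction='mean'), the loss and its gradient differ from the sum-of-squares formulation only by the constant \(1/N\), where \(N\) is the number of sinogram entries:

\[\mathcal{L}_{\text{MSE}}(f) = \frac{1}{N}\,\|A_{\text{parallel}}(f) - p\|_2^2, \qquad \nabla_f \mathcal{L}_{\text{MSE}} = \frac{2}{N}\,A_{\text{parallel}}^T\big(A_{\text{parallel}}(f) - p\big)\]

With 360 views and 256 detectors, \(N = 360 \times 256 = 92160\). Adam’s update is largely invariant to this constant rescaling (only \(\epsilon\) breaks the invariance). The factor therefore mainly matters for plain gradient descent, where it changes the effective step size.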

Advantages of Iterative Methods

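  • Noise Robustness: The data term and regularizer can be matched to the noise, and iterations can be stopped early, which typically suppresses noise better than FBP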

  • Regularization: Natural incorporation of prior knowledge

  • Incomplete Data: Effective with limited-angle or sparse-view acquisitions

  • Flexibility: Easy modification of cost functions and constraints

  • Artifact Reduction: Constraints such as non-negativity and priors such as smoothness can suppress streaks and other artifacts
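Sparse-view and limited-angle acquisitions can be emulated by subsampling the angle array used in the example script. The sketch below is minimal, and the subsampling factor and the 120° range are arbitrary illustrative choices. For parallel beam, the views at \(\theta\) and \(\theta + \pi\) are mirror images, so a complete data set needs only 180°, and any range below that is genuinely limited-angle.

```python
import math
import numpy as np

num_views = 360
angles_full = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)

# Sparse view: keep every 12th angle -> 30 views spread over the full range
angles_sparse = angles_full[::12]

# Limited angle: the first 120 views, i.e. only the range [0, 120) degrees
angles_limited = angles_full[: num_views // 3]

print(len(angles_sparse), len(angles_limited))  # 30 120
```

Either array can replace angles_np before it is converted to angles_torch. Both the simulated sinogram and the model’s forward projection then use the reduced set of views.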

Convergence Characteristics

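Convergence typically proceeds in three phases. The iteration counts are rough guides that depend on the learning rate, image size, and number of views: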

  1. Initial Phase (0-100 iterations): Rapid loss decrease, basic structure emerges

  2. Refinement Phase (100-500 iterations): Fine details develop, slower convergence

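  3. Convergence Phase (500+ iterations): Minimal further improvement; with noisy data, continued iterations start fitting the noise (overfitting), which regularization or early stopping counteracts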
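In simulation studies, where the phantom is known, reconstruction quality can be tracked alongside the loss. The sketch below defines a PSNR helper and a simple plateau test on the loss history. Both names and the window/tolerance values are hypothetical choices and are not part of diffct.

```python
import numpy as np

def psnr(reco, reference):
    """PSNR in dB, using the reference's dynamic range (max - min) as the peak."""
    mse = np.mean((reco - reference) ** 2)
    if mse == 0:
        return float("inf")
    data_range = float(reference.max() - reference.min())
    return 10.0 * np.log10(data_range**2 / mse)

def has_plateaued(loss_values, window=50, rel_tol=1e-3):
    """Hypothetical stopping rule: relative loss drop over the last `window` iterations < rel_tol."""
    if len(loss_values) <= window:
        return False
    old, new = loss_values[-window - 1], loss_values[-1]
    return (old - new) / max(old, 1e-12) < rel_tol

ref = np.zeros((4, 4))
ref[1:3, 1:3] = 1.0
print(f"{psnr(ref + 0.1, ref):.1f}")  # 20.0
print(has_plateaued([1.0] * 100))     # True
```

Inside the training loop, psnr could be evaluated every few epochs on current_reco.detach().cpu().numpy() against phantom_cpu. With noisy data the PSNR can peak and then decline while the loss keeps falling, which is the overfitting described for the convergence phase.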

2D Parallel Beam Iterative Example
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from diffct.differentiable import ParallelProjectorFunction


def shepp_logan_2d(Nx, Ny):
    """Simplified 2D Shepp-Logan phantom (5 ellipses), clipped to [0, 1]."""
    Nx = int(Nx)
    Ny = int(Ny)
    phantom = np.zeros((Ny, Nx), dtype=np.float32)
    # (center x, center y, semi-axis a, semi-axis b, rotation in degrees, intensity),
    # in coordinates normalized to roughly [-1, 1] across the image
    ellipses = [
        (0.0, 0.0, 0.69, 0.92, 0, 1.0),
        (0.0, -0.0184, 0.6624, 0.8740, 0, -0.8),
        (0.22, 0.0, 0.11, 0.31, -18.0, -0.8),
        (-0.22, 0.0, 0.16, 0.41, 18.0, -0.8),
        (0.0, 0.35, 0.21, 0.25, 0, 0.7),
    ]
    cx = (Nx - 1) * 0.5
    cy = (Ny - 1) * 0.5
    for ix in range(Nx):
        for iy in range(Ny):
            xnorm = (ix - cx) / (Nx / 2)
            ynorm = (iy - cy) / (Ny / 2)
            val = 0.0
            for (x0, y0, a, b, angdeg, ampl) in ellipses:
                th = np.deg2rad(angdeg)
                # Rotate the pixel position into the ellipse's own frame
                xprime = (xnorm - x0) * np.cos(th) + (ynorm - y0) * np.sin(th)
                yprime = -(xnorm - x0) * np.sin(th) + (ynorm - y0) * np.cos(th)
                if xprime * xprime / (a * a) + yprime * yprime / (b * b) <= 1.0:
                    val += ampl  # intensities of overlapping ellipses add up
            phantom[iy, ix] = val
    phantom = np.clip(phantom, 0.0, 1.0)
    return phantom


class IterativeRecoModel(nn.Module):
    def __init__(self, volume_shape, angles, num_detectors, detector_spacing, voxel_spacing):
        super().__init__()
        # The image estimate is the only learnable parameter
        self.reco = nn.Parameter(torch.zeros(volume_shape))
        self.angles = angles
        self.num_detectors = num_detectors
        self.detector_spacing = detector_spacing
        self.voxel_spacing = voxel_spacing
        # Non-negativity is applied only to the returned image: the projector sees the
        # unconstrained estimate (a ReLU before the projector would block all gradients
        # at the all-zero start)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x is a fixed initial guess; the estimate is x plus the learned correction
        updated_reco = x + self.reco
        current_sino = ParallelProjectorFunction.apply(updated_reco, self.angles,
                                                       self.num_detectors, self.detector_spacing, self.voxel_spacing)
        return current_sino, self.relu(updated_reco)


class Pipeline:
    def __init__(self, lr, volume_shape, angles, num_detectors, detector_spacing,
                 voxel_spacing, device, epochs=1000):
        self.epochs = epochs
        self.model = IterativeRecoModel(volume_shape, angles, num_detectors,
                                        detector_spacing, voxel_spacing).to(device)
        # AdamW is Adam with decoupled weight decay (PyTorch default weight_decay=0.01)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        self.loss = nn.MSELoss()

    def train(self, init_image, label):
        loss_values = []
        for epoch in range(self.epochs):
            self.optimizer.zero_grad()
            predictions, current_reco = self.model(init_image)
            loss_value = self.loss(predictions, label)
            loss_value.backward()  # autograd backpropagates through the projector
            self.optimizer.step()
            loss_values.append(loss_value.item())

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss_value.item():.6e}")

        return loss_values, self.model


def main():
    Nx, Ny = 128, 128
    phantom_cpu = shepp_logan_2d(Nx, Ny)

    # 360 views over [0, 2*pi); for parallel beam, views at theta and theta + pi are redundant
    num_views = 360
    angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)

    # Detector span 256 x 0.5 = 128 units, equal to the 128-pixel image width at voxel_spacing 1.0
    num_detectors = 256
    detector_spacing = 0.5
    voxel_spacing = 1.0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32)
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)

    # Simulate the noise-free "measured" sinogram with the same forward projector
    real_sinogram = ParallelProjectorFunction.apply(phantom_torch, angles_torch,
                                                    num_detectors, detector_spacing, voxel_spacing)

    pipeline_instance = Pipeline(lr=1e-1,
                                 volume_shape=(Ny, Nx),
                                 angles=angles_torch,
                                 num_detectors=num_detectors,
                                 detector_spacing=detector_spacing,
                                 voxel_spacing=voxel_spacing,
                                 device=device, epochs=1000)

    ini_guess = torch.zeros_like(phantom_torch)

    loss_values, trained_model = pipeline_instance.train(ini_guess, real_sinogram)

    # Final non-negative reconstruction; no gradients are needed here
    with torch.no_grad():
        reco = trained_model(ini_guess)[1].squeeze().cpu().numpy()

    plt.figure()
    plt.plot(loss_values)
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(phantom_cpu, cmap="gray")
    plt.title("Original Phantom")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reco, cmap="gray")
    plt.title("Reconstructed")
    plt.axis("off")
    plt.show()


if __name__ == "__main__":
    main()