Parallel Beam Iterative Reconstruction

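This example demonstrates 2D parallel beam iterative reconstruction using the differentiable ParallelProjectorFunction and ParallelBackprojectorFunction from diffct. The script calls ParallelProjectorFunction directly. Backprojection, the adjoint of projection, enters through the gradient during backpropagation.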

Overview

Parallel beam iterative reconstruction solves the CT inverse problem by numerical optimization rather than by a closed-form inversion such as filtered backprojection (FBP). This example shows how to:

  • Formulate parallel beam CT reconstruction as an optimization problem

  • Use automatic differentiation for gradient computation

  • Apply gradient-based optimization with parallel beam operators

  • Monitor convergence and reconstruction quality

Mathematical Background

Parallel Beam Iterative Formulation

The parallel beam reconstruction problem is formulated as:

\[\hat{f} = \arg\min_f \|A_{\text{parallel}}(f) - p\|_2^2 + \lambda R(f)\]

where:

  • \(f\) is the unknown 2D image

  • \(A_{\text{parallel}}\) is the parallel beam forward projection operator (a discretized Radon transform)

  • \(p\) is the measured sinogram data

  • \(R(f)\) is an optional regularization term

  • \(\lambda\) is the regularization parameter

Gradient-Based Optimization

The gradient of the data fidelity term is computed using the adjoint operator:

\[\nabla_f \|A_{\text{parallel}}(f) - p\|_2^2 = 2A_{\text{parallel}}^T(A_{\text{parallel}}(f) - p)\]

where \(A_{\text{parallel}}^T\) is the parallel beam backprojection operator (adjoint of the forward projector).
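The gradient identity above can be checked numerically. The sketch below assumes a small random dense matrix `A` as a stand-in for \(A_{\text{parallel}}\) (it is not a real projector) and compares PyTorch's autograd gradient with the closed form \(2A^T(Af - p)\).

```python
import torch

# Toy setup (assumption): a dense random matrix stands in for A_parallel
torch.manual_seed(0)
A = torch.randn(20, 9, dtype=torch.float64)   # 20 "sinogram bins" x 9 "pixels"
p = torch.randn(20, dtype=torch.float64)      # "measured" data
f = torch.randn(9, dtype=torch.float64, requires_grad=True)

# Data fidelity ||A f - p||_2^2 (a sum of squares, not a mean)
loss = torch.sum((A @ f - p) ** 2)
loss.backward()

# Closed-form gradient 2 A^T (A f - p)
analytic = 2 * A.T @ (A @ f.detach() - p)
print(torch.allclose(f.grad, analytic))
```

Note that nn.MSELoss, used in the example code, averages over all sinogram samples, so its gradient is this expression divided by the number of samples.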

Automatic Differentiation

PyTorch’s automatic differentiation computes gradients through the differentiable operators:

\[\frac{\partial L}{\partial f} = \frac{\partial}{\partial f} \|A_{\text{parallel}}(f) - p\|_2^2\]

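Because the gradient is obtained simply by calling backward() on the loss, any optimizer in torch.optim, such as Adam, can drive the reconstruction without a hand-derived update rule.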
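A differentiable projector pairs a forward pass (projection) with a backward pass that applies the adjoint (backprojection) to the incoming gradient. The sketch below shows this pattern with a custom torch.autograd.Function wrapping a hypothetical dense system matrix; `MatrixProjector` is an illustrative name, and this is not diffct's implementation, whose internals are not shown here.

```python
import torch


class MatrixProjector(torch.autograd.Function):
    """Hypothetical differentiable projector: sinogram = A @ image.flatten()."""

    @staticmethod
    def forward(ctx, image, A):
        ctx.save_for_backward(A)
        ctx.image_shape = image.shape
        return A @ image.reshape(-1)

    @staticmethod
    def backward(ctx, grad_sino):
        (A,) = ctx.saved_tensors
        # Adjoint ("backprojection"): apply A^T to the sinogram-space gradient
        grad_image = (A.T @ grad_sino).reshape(ctx.image_shape)
        # No gradient is needed for the system matrix itself
        return grad_image, None


# gradcheck compares the analytic backward against finite differences
torch.manual_seed(0)
A = torch.randn(12, 16, dtype=torch.float64)
image = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(MatrixProjector.apply, (image, A))
print(ok)
```

Any loss built from `MatrixProjector.apply(image, A)` can then be minimized with a torch.optim optimizer, exactly as the example does with ParallelProjectorFunction.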

Adam Optimizer

The Adam optimizer adapts the step size of each pixel using running estimates of the gradient's first and second moments:

\[f^{(k+1)} = f^{(k)} - \alpha \cdot \frac{m^{(k)}}{1-\beta_1^k} \cdot \frac{1}{\sqrt{v^{(k)}/(1-\beta_2^k)} + \epsilon}\]

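where \(g^{(k)} = \nabla_f L(f^{(k)})\) is the gradient at iteration \(k\), \(m^{(k)} = \beta_1 m^{(k-1)} + (1-\beta_1) g^{(k)}\) and \(v^{(k)} = \beta_2 v^{(k-1)} + (1-\beta_2) (g^{(k)})^2\) are the biased first and second moment estimates (initialized to zero, square taken elementwise), \(\alpha\) is the learning rate, and \(\epsilon\) is a small constant for numerical stability. The example code uses torch.optim.AdamW, which applies this update plus decoupled weight decay (PyTorch's default weight_decay is 0.01).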
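To make the update rule concrete, the sketch below implements one Adam step directly from the formula and checks it against torch.optim.Adam on a toy quadratic loss. It assumes PyTorch's default \(\beta_1 = 0.9\), \(\beta_2 = 0.999\), \(\epsilon = 10^{-8}\) and no weight decay; `adam_step` is a hypothetical helper, not part of diffct.

```python
import torch


def adam_step(f, g, m, v, k, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update following the formula above (k starts at 1)."""
    m = beta1 * m + (1 - beta1) * g          # biased first moment
    v = beta2 * v + (1 - beta2) * g * g      # biased second moment
    m_hat = m / (1 - beta1 ** k)             # bias corrections
    v_hat = v / (1 - beta2 ** k)
    f = f - lr * m_hat / (torch.sqrt(v_hat) + eps)
    return f, m, v


# Toy problem: minimize ||f - target||^2 with both implementations
torch.manual_seed(0)
target = torch.randn(5, dtype=torch.float64)

f_torch = torch.zeros(5, dtype=torch.float64, requires_grad=True)
opt = torch.optim.Adam([f_torch], lr=0.1)

f_manual = torch.zeros(5, dtype=torch.float64)
m = torch.zeros(5, dtype=torch.float64)
v = torch.zeros(5, dtype=torch.float64)

for k in range(1, 21):
    opt.zero_grad()
    loss = torch.sum((f_torch - target) ** 2)
    loss.backward()
    opt.step()

    g = 2 * (f_manual - target)   # analytic gradient of the same loss
    f_manual, m, v = adam_step(f_manual, g, m, v, k, lr=0.1)

print(torch.allclose(f_torch.detach(), f_manual))
```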

Implementation Steps

  1. Problem Setup: Define the 2D image as a learnable tensor (nn.Parameter)

  2. Forward Model: Compute the predicted sinogram with ParallelProjectorFunction

  3. Loss Computation: Compute the mean squared error between the predicted and measured sinograms

  4. Gradient Computation: Backpropagate the loss with automatic differentiation

  5. Parameter Update: Take an AdamW step, then clamp the image to non-negative values

  6. Convergence Monitoring: Track loss and reconstruction quality
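The example code records only the loss. Because the phantom is known in a simulation, reconstruction quality can also be tracked, for instance with the peak signal-to-noise ratio (PSNR). The sketch below is a hypothetical helper (not part of diffct) that assumes image intensities in \([0, 1]\), as produced by shepp_logan_2d.

```python
import torch


def psnr(reco: torch.Tensor, reference: torch.Tensor, data_range: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB between a reconstruction and a reference."""
    mse = torch.mean((reco - reference) ** 2)
    if mse.item() == 0.0:
        return float("inf")   # identical images
    return float(10.0 * torch.log10(data_range ** 2 / mse))


# Example: a uniform error of 0.1 gives MSE = 0.01, i.e. about 20 dB
ref = torch.zeros(4, 4)
print(psnr(torch.full((4, 4), 0.1), ref))
```

Inside the training loop, calling `psnr(current_reco.detach(), phantom_torch)` every few epochs would give a quality curve alongside the loss curve.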

Model Architecture

The reconstruction model consists of:

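  • Parameterized Image: Learnable 2D tensor (nn.Parameter) added to a fixed initial guess; the sum is the current reconstruction

  • Forward Projection: ParallelProjectorFunction maps the current reconstruction to a predicted sinogram

  • Loss Function: Mean squared error (nn.MSELoss) between predicted and measured sinograms, i.e. the data term above divided by the number of sinogram samples, with no regularization (\(\lambda = 0\))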
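The same model structure can be tried without diffct or a GPU. In the sketch below, a hypothetical random dense matrix replaces ParallelProjectorFunction, and plain Adam is used instead of AdamW; `ToyRecoModel` and all sizes are illustrative assumptions, not part of the library.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class ToyRecoModel(nn.Module):
    """Mirrors IterativeRecoModel, with a dense matrix standing in for the projector."""

    def __init__(self, volume_shape, system_matrix):
        super().__init__()
        self.reco = nn.Parameter(torch.zeros(volume_shape))   # learnable image update
        self.register_buffer("A", system_matrix)              # fixed "projector"

    def forward(self, x):
        updated_reco = x + self.reco                           # initial guess + update
        sino = self.A @ updated_reco.reshape(-1)               # predicted "sinogram"
        return sino, updated_reco


torch.manual_seed(0)
shape = (8, 8)
A = torch.rand(96, 64)                  # hypothetical: 96 rays, 64 pixels
truth = torch.rand(shape)
measured = A @ truth.reshape(-1)

model = ToyRecoModel(shape, A)
opt = optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
x0 = torch.zeros(shape)

losses = []
for _ in range(200):
    opt.zero_grad()
    sino, _ = model(x0)
    loss = loss_fn(sino, measured)
    loss.backward()
    opt.step()
    with torch.no_grad():
        model.reco.clamp_(min=0.0)      # non-negativity, as in the example
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.6f}")
```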

Advantages of Iterative Methods

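  • Noise Robustness: Regularization and weighted data terms can suppress noise, whereas the ramp filter in FBP amplifies high-frequency noise

  • Regularization: Prior knowledge enters directly through the term \(\lambda R(f)\), for example a smoothness or sparsity penalty

  • Incomplete Data: Often more effective than FBP with limited-angle or sparse-view acquisitions

  • Flexibility: Cost functions and constraints can be changed by editing the loss, without deriving a new inversion formula

  • Artifact Reduction: Constraints such as non-negativity (applied in this example by clamping) and regularization help control reconstruction artifacts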
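As one illustration of the \(\lambda R(f)\) term, the sketch below defines an anisotropic total variation (TV) penalty, which favors piecewise-constant images such as the Shepp-Logan phantom. The helper name `total_variation` and the weight `lam` are hypothetical; the example code itself uses no regularization.

```python
import torch


def total_variation(img: torch.Tensor) -> torch.Tensor:
    """Anisotropic total variation of a 2D image: sum of absolute neighbor differences."""
    dx = img[:, 1:] - img[:, :-1]   # horizontal finite differences
    dy = img[1:, :] - img[:-1, :]   # vertical finite differences
    return dx.abs().sum() + dy.abs().sum()


# A 4x4 image with a vertical edge of height 1 has TV = 4 (one jump per row)
img = torch.zeros(4, 4)
img[:, 2:] = 1.0
print(total_variation(img).item())

# Usage sketch inside Pipeline.train (lam is a hypothetical weight to tune):
# loss_value = self.loss(predictions, label) + lam * total_variation(current_reco)
```

The absolute value is non-differentiable at zero; PyTorch uses a subgradient of 0 there, which works in practice, though a smoothed variant such as \(\sqrt{d^2 + \delta}\) is a common alternative.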

Convergence Characteristics

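Typical convergence behavior (iteration counts are approximate and depend on the problem and learning rate):

  1. Initial Phase (0-100 iterations): Rapid loss decrease, basic structure emerges

  2. Refinement Phase (100-500 iterations): Fine details develop, convergence slows

  3. Convergence Phase (500+ iterations): Minimal improvement; with noisy data, further iterations can start fitting the noise (semi-convergence), so early stopping or regularization is advisable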
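The example runs a fixed 1000 epochs. A simple alternative is to stop once the loss plateaus. The sketch below is a hypothetical stopping rule (`should_stop`, `window`, and `rel_tol` are illustrative choices, not diffct parameters); note that a loss plateau does not detect noise fitting, which requires a quality measure or held-out data.

```python
def should_stop(loss_values, window=50, rel_tol=1e-3):
    """Return True when the loss improved by less than rel_tol (relative)
    over the last `window` iterations."""
    if len(loss_values) <= window:
        return False                      # not enough history yet
    old, new = loss_values[-window - 1], loss_values[-1]
    if old == 0:
        return True                       # already at zero loss
    return (old - new) / abs(old) < rel_tol


# Usage sketch in Pipeline.train, after loss_values.append(...):
# if should_stop(loss_values):
#     break

print(should_stop([1.0] * 100))                      # flat loss -> stop
print(should_stop([100.0 - k for k in range(60)]))   # still decreasing -> continue
```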

2D Parallel Beam Iterative Example
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from diffct.differentiable import ParallelProjectorFunction


def shepp_logan_2d(Nx, Ny):
    """Simplified Shepp-Logan phantom on an Ny x Nx grid, clipped to [0, 1]."""
    Nx = int(Nx)
    Ny = int(Ny)
    phantom = np.zeros((Ny, Nx), dtype=np.float32)
    # Each ellipse: (x0, y0, semi-axis a, semi-axis b, rotation in degrees, amplitude)
    ellipses = [
        (0.0, 0.0, 0.69, 0.92, 0, 1.0),
        (0.0, -0.0184, 0.6624, 0.8740, 0, -0.8),
        (0.22, 0.0, 0.11, 0.31, -18.0, -0.8),
        (-0.22, 0.0, 0.16, 0.41, 18.0, -0.8),
        (0.0, 0.35, 0.21, 0.25, 0, 0.7),
    ]
    cx = (Nx - 1)*0.5
    cy = (Ny - 1)*0.5
    for ix in range(Nx):
        for iy in range(Ny):
            # Normalized pixel coordinates, roughly in [-1, 1]
            xnorm = (ix - cx)/(Nx/2)
            ynorm = (iy - cy)/(Ny/2)
            val = 0.0
            for (x0, y0, a, b, angdeg, ampl) in ellipses:
                th = np.deg2rad(angdeg)
                # Rotate into the ellipse's local frame and test membership
                xprime = (xnorm - x0)*np.cos(th) + (ynorm - y0)*np.sin(th)
                yprime = -(xnorm - x0)*np.sin(th) + (ynorm - y0)*np.cos(th)
                if xprime*xprime/(a*a) + yprime*yprime/(b*b) <= 1.0:
                    val += ampl
            phantom[iy, ix] = val
    phantom = np.clip(phantom, 0.0, 1.0)
    return phantom


class IterativeRecoModel(nn.Module):
    def __init__(self, volume_shape, angles, num_detectors, detector_spacing, voxel_spacing):
        super().__init__()
        # Learnable image update, initialized to zero
        self.reco = nn.Parameter(torch.zeros(volume_shape))
        self.angles = angles
        self.num_detectors = num_detectors
        self.detector_spacing = detector_spacing
        self.voxel_spacing = voxel_spacing

    def forward(self, x):
        # Current reconstruction = fixed initial guess + learnable update
        updated_reco = x + self.reco
        # Differentiable forward projection to a predicted sinogram
        current_sino = ParallelProjectorFunction.apply(updated_reco, self.angles,
                                                       self.num_detectors, self.detector_spacing, self.voxel_spacing)
        return current_sino, updated_reco


class Pipeline:
    def __init__(self, lr, volume_shape, angles, num_detectors, detector_spacing,
                 voxel_spacing, device, epochs=1000):
        self.epochs = epochs
        self.model = IterativeRecoModel(volume_shape, angles, num_detectors,
                                        detector_spacing, voxel_spacing).to(device)
        # AdamW: Adam with decoupled weight decay (PyTorch default weight_decay=0.01)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        # Mean squared error between predicted and measured sinograms
        self.loss = nn.MSELoss()

    def train(self, initial_guess, label):
        loss_values = []
        for epoch in range(self.epochs):
            self.optimizer.zero_grad()
            predictions, current_reco = self.model(initial_guess)
            loss_value = self.loss(predictions, label)
            loss_value.backward()   # backpropagate through the projector
            self.optimizer.step()
            # Enforce non-negativity of the learnable update
            # (with a zero initial guess, this is the reconstruction itself)
            with torch.no_grad():
                self.model.reco.clamp_(min=0.0)
            loss_values.append(loss_value.item())

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss_value.item():.6e}")

        return loss_values, self.model


def main():
    Nx, Ny = 128, 128
    phantom_cpu = shepp_logan_2d(Nx, Ny)

    # 360 views over [0, 2*pi); for parallel beam, [0, pi) would already suffice
    num_views = 360
    angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)

    num_detectors = 256
    detector_spacing = 0.5
    voxel_spacing = 1.0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32)
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)

    # Simulate the noise-free "measured" sinogram from the phantom
    real_sinogram = ParallelProjectorFunction.apply(phantom_torch, angles_torch,
                                                    num_detectors, detector_spacing, voxel_spacing)

    pipeline_instance = Pipeline(lr=1e-1,
                                 volume_shape=(Ny, Nx),
                                 angles=angles_torch,
                                 num_detectors=num_detectors,
                                 detector_spacing=detector_spacing,
                                 voxel_spacing=voxel_spacing,
                                 device=device, epochs=1000)

    ini_guess = torch.zeros_like(phantom_torch)

    loss_values, trained_model = pipeline_instance.train(ini_guess, real_sinogram)

    # Final reconstruction; no gradient tracking is needed here
    with torch.no_grad():
        _, reco_torch = trained_model(ini_guess)
    reco = reco_torch.cpu().numpy()

    plt.figure()
    plt.plot(loss_values)
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(phantom_cpu, cmap="gray")
    plt.title("Original Phantom")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reco, cmap="gray")
    plt.title("Reconstructed")
    plt.axis("off")
    plt.show()


if __name__ == "__main__":
    main()