Fan Beam Iterative Reconstruction

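This example demonstrates 2D fan beam iterative reconstruction using diffct's differentiable fan beam operators. Only FanProjectorFunction is called explicitly. The fan beam backprojection, which FanBackprojectorFunction also provides as a standalone operator, is applied by autograd in the projector's backward pass.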

Overview

Fan beam iterative reconstruction extends the optimization approach to the more realistic fan beam geometry. This example shows how to:

  • Formulate fan beam CT reconstruction as an optimization problem

  • Handle the divergent rays of a point-source geometry

  • Apply gradient-based optimization with fan beam operators

  • Monitor convergence and reconstruction quality

Mathematical Background

Fan Beam Iterative Formulation

The fan beam reconstruction problem is formulated as:

\[\hat{f} = \arg\min_f \|A_{\text{fan}}(f) - p\|_2^2 + \lambda R(f)\]

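where \(A_{\text{fan}}\) is the fan beam forward projection operator, which accounts for the divergent ray geometry. \(p\) is the measured sinogram, and \(R(f)\) is an optional regularizer weighted by \(\lambda \ge 0\).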
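The example below minimizes only the data term (\(\lambda = 0\)). One common choice for \(R(f)\) is anisotropic total variation, sketched here in PyTorch. The helper name `tv_2d` and the weight `lam` in the usage comment are hypothetical and are not part of diffct.

```python
import torch


def tv_2d(img: torch.Tensor) -> torch.Tensor:
    """Anisotropic total variation of a 2D image (hypothetical helper).

    Sums the absolute differences between vertically and horizontally
    adjacent pixels. It is differentiable almost everywhere, so autograd
    can backpropagate through it together with the projector.
    """
    dy = img[1:, :] - img[:-1, :]   # vertical neighbor differences
    dx = img[:, 1:] - img[:, :-1]   # horizontal neighbor differences
    return dy.abs().sum() + dx.abs().sum()


# Sketch of use in a training step (lam is a hypothetical weight):
#     loss = mse(predictions, label) + lam * tv_2d(updated_reco)
step = torch.zeros(4, 4)
step[:, 2:] = 1.0            # one vertical edge spanning 4 rows
print(tv_2d(step).item())    # 4.0
```

TV penalizes the total edge magnitude rather than its square. It therefore favors piecewise-constant images such as the Shepp-Logan phantom.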

Fan Beam Forward Model

The fan beam projection operator maps a 2D image \(f(x,y)\) to a sinogram \(p(\beta, u)\):

\[p(\beta, u) = \int_{\text{ray}(\beta, u)} f(x,y) \, dl\]

where the line integral runs along the ray from the point source at rotation angle \(\beta\) to detector position \(u\).
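The sketch below makes the ray parameterization concrete by building the ray for a source angle \(\beta\) and a detector coordinate \(u\). It assumes a convention that may differ from diffct's internals. The source moves on a circle of radius SID about the isocenter. A flat detector sits perpendicular to the central ray at distance SDD from the source, and \(u\) is measured from the detector center. `fan_ray` and `distance_to_isocenter` are hypothetical helpers.

```python
import numpy as np


def fan_ray(beta, u, sid, sdd):
    """Return (source, unit direction) of the ray hitting detector coordinate u.

    Hypothetical convention: source at distance sid from the isocenter,
    flat detector perpendicular to the central ray at distance sdd from
    the source, u measured along the detector from its center.
    """
    central = np.array([np.cos(beta), np.sin(beta)])   # isocenter -> source
    e_u = np.array([-np.sin(beta), np.cos(beta)])      # detector axis
    source = sid * central
    detector_point = source - sdd * central + u * e_u
    direction = detector_point - source
    return source, direction / np.linalg.norm(direction)


def distance_to_isocenter(source, direction):
    """Perpendicular distance from the origin to the ray (2D cross product)."""
    return abs(source[0] * direction[1] - source[1] * direction[0])


sid, sdd = 400.0, 600.0
src, d = fan_ray(beta=0.5, u=128.0, sid=sid, sdd=sdd)
print(distance_to_isocenter(src, d))        # equals sid*u/sqrt(sdd**2 + u**2)
print(np.degrees(np.arctan(128.0 / sdd)))   # fan angle of this ray in degrees
```

Under this convention, the ray hitting detector coordinate \(u\) makes a fan angle \(\gamma = \arctan(u/\text{SDD})\) with the central ray. It passes the isocenter at distance \(\text{SID}\,|u| / \sqrt{\text{SDD}^2 + u^2}\).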

Gradient Computation

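The gradient of the data term involves the fan beam backprojection operator, the adjoint of \(A_{\text{fan}}\). For \(L(f) = \|A_{\text{fan}}(f) - p\|_2^2\):

\[\frac{\partial L}{\partial f} = 2A_{\text{fan}}^T\left(A_{\text{fan}}(f) - p\right)\]

where \(A_{\text{fan}}^T\) is the fan beam backprojection operator. A regularizer adds \(\lambda \nabla R(f)\) to this gradient. PyTorch's nn.MSELoss averages over the \(M\) sinogram entries, so in the example below the gradient is scaled by \(1/M\). Autograd applies \(A_{\text{fan}}^T\) in the backward pass of FanProjectorFunction.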
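The gradient formula can be checked numerically. The sketch below replaces \(A_{\text{fan}}\) with a small random dense matrix (a stand-in, not diffct's projector). It compares PyTorch's autograd gradient with \(2A^T(Af - p)\), then shows the \(1/M\) factor that mean squared error introduces.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_rays, num_pixels = 40, 25

# Stand-in for A_fan: a dense random matrix (not diffct's projector)
A = torch.randn(num_rays, num_pixels, dtype=torch.float64)
p = torch.randn(num_rays, dtype=torch.float64)            # "measured" sinogram
f = torch.randn(num_pixels, dtype=torch.float64, requires_grad=True)

# Sum-of-squares data term L(f) = ||A f - p||_2^2
loss_sum = ((A @ f - p) ** 2).sum()
(grad_autograd,) = torch.autograd.grad(loss_sum, f)
grad_formula = 2.0 * A.T @ (A @ f.detach() - p)            # 2 A^T (A f - p)
print(torch.allclose(grad_autograd, grad_formula))         # True

# Mean squared error (what nn.MSELoss computes) divides by M = num_rays
loss_mse = F.mse_loss(A @ f, p)
(grad_mse,) = torch.autograd.grad(loss_mse, f)
print(torch.allclose(grad_mse, grad_formula / num_rays))   # True
```

With a matched projector/backprojector pair, the same identity holds when `A @ f` is replaced by the fan beam projector.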

Geometric Considerations

Fan beam geometry introduces complexities compared to parallel beam:

  • Ray Divergence: Non-parallel rays affect sampling density and conditioning

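  • Magnification Effects: An object at the isocenter is magnified by SDD/SID on the detector. Points nearer the source are magnified more, and points nearer the detector less.

  • Non-uniform Resolution: Spatial resolution varies with distance from the rotation center

  • Geometric Distortion: Each detector position corresponds to a ray at its own fan angle, so the mapping between detector coordinates, ray directions, and image coordinates must be handled consistently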
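Plugging the example's parameters into these relations gives a quick sanity check of magnification and field of view. The sketch assumes a flat detector centered on the central ray and an image centered on the isocenter (all offsets zero, as in the example). The formulas are standard fan beam geometry; they are not read from diffct's source.

```python
import math

# Parameters taken from the example below
sid, sdd = 400.0, 600.0
num_detectors, detector_spacing = 256, 1.0
Nx, voxel_spacing = 128, 1.0

# Magnification at the isocenter and detector pixel size projected back to it
magnification = sdd / sid                                  # 1.5
pixel_at_isocenter = detector_spacing / magnification      # ~0.667

# Half fan angle covered by the detector, and the radius of the circle
# around the isocenter that every view sees completely
half_fan = math.atan(0.5 * num_detectors * detector_spacing / sdd)
fov_radius = sid * math.sin(half_fan)                      # ~83.46

# Largest radius of the phantom's outer ellipse (semi-axis 0.92 in units
# of Nx/2 pixels) and the distance from the image center to a corner
phantom_radius = 0.92 * (Nx / 2) * voxel_spacing           # ~58.88
corner_radius = (Nx / 2) * voxel_spacing * math.sqrt(2)    # ~90.51

print(f"magnification {magnification:.2f}, FOV radius {fov_radius:.2f}")
print("phantom inside FOV:", phantom_radius < fov_radius)
print("image corners inside FOV:", corner_radius < fov_radius)
```

With these numbers, the phantom's outer ellipse fits inside the field of view that every view covers, but the image corners do not. Because the phantom is zero at the corners, this does not affect the example.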

Implementation Steps

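  1. Geometry Setup: Configure the fan beam parameters: source-to-isocenter distance (SID), source-to-detector distance (SDD), detector count and spacing, and projection angles

  2. Problem Formulation: Define the parameterized image and the fan beam forward model

  3. Loss Computation: Compute the mean squared error between the sinogram predicted by FanProjectorFunction and the measured sinogram

  4. Gradient Computation: Use automatic differentiation through the fan beam projector

  5. Optimization: Update the image with the AdamW optimizer and clamp it to non-negative values after each step

  6. Convergence Monitoring: Record the loss at every iteration and, when the ground truth is known, compare the reconstruction with it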
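Tracking reconstruction quality needs a measure beyond the loss. In a simulation where the phantom is known, RMSE and PSNR are simple choices. The helper below is a hypothetical NumPy sketch. It assumes `data_range=1.0` because the phantom is clipped to [0, 1].

```python
import numpy as np


def reconstruction_metrics(reco, reference, data_range=1.0):
    """Hypothetical helper: RMSE and PSNR (in dB) of reco against a known image."""
    reco = np.asarray(reco, dtype=np.float64)
    reference = np.asarray(reference, dtype=np.float64)
    mse = np.mean((reco - reference) ** 2)
    rmse = np.sqrt(mse)
    # PSNR is infinite for a perfect match
    psnr = np.inf if mse == 0 else 10.0 * np.log10(data_range ** 2 / mse)
    return rmse, psnr


phantom = np.zeros((8, 8))
phantom[2:6, 2:6] = 1.0
noisy = phantom + 0.05 * np.random.default_rng(0).standard_normal(phantom.shape)
print(reconstruction_metrics(noisy, phantom))
```

With the zero initial guess used in the example, `self.model.reco` is the current image estimate. Calling the helper on `self.model.reco.detach().cpu().numpy()` every few epochs would therefore track quality, provided the phantom is passed into the pipeline.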

Model Architecture

The fan beam reconstruction model consists of:

  • Parameterized Image: Learnable 2D tensor, initialized to zero and added to the initial guess, representing the unknown image

  • Fan Beam Forward Model: FanProjectorFunction with geometric parameters

  • Loss Function: Mean squared error between predicted and measured sinograms

Convergence Characteristics

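In this example the loss typically evolves in three phases. The iteration ranges are rough guides that depend on the learning rate and geometry: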

  1. Initial Convergence (0-100 iterations): Rapid loss decrease, basic structure

  2. Detail Refinement (100-500 iterations): Fine features develop, slower progress

  3. Final Convergence (500+ iterations): Minimal improvement, convergence plateau
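The plateau in the final phase can be detected automatically instead of always running a fixed 1000 iterations. Below is a hypothetical stopping rule, assuming the `loss_values` list collected in `Pipeline.train`. It stops when the relative loss decrease over the last `window` iterations falls below `rtol`. Both thresholds are illustrative, not tuned.

```python
def has_plateaued(loss_values, window=50, rtol=1e-3):
    """Hypothetical stopping rule for the loss curve.

    Returns True when the loss decreased by less than a fraction `rtol`
    over the last `window` iterations (a rising loss also returns True).
    """
    if len(loss_values) <= window:
        return False                     # not enough history yet
    old, new = loss_values[-window - 1], loss_values[-1]
    if old == 0.0:
        return True                      # already at zero loss
    return (old - new) / abs(old) < rtol


# Sketch of use inside the training loop:
#     loss_values.append(loss_value.item())
#     if has_plateaued(loss_values):
#         break
print(has_plateaued([1.0 / (k + 1) for k in range(1000)]))  # False: ~5% drop per 50 steps
print(has_plateaued([1.0] * 200))                           # True
```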

Challenges and Solutions

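  • Conditioning: The fan beam system matrix \(A_{\text{fan}}\) may be conditioned differently from its parallel beam counterpart. This affects how fast gradient-based methods converge and which step sizes are stable.

  • Geometric Artifacts: Proper data weighting, filtering, or regularization (the \(\lambda R(f)\) term) helps reduce artifacts

  • Parameter Tuning: The learning rate (0.1 in the example) may need adjustment for a different geometry, image size, or sinogram scale

  • Memory Usage: Similar to parallel beam: the image, the sinogram, their gradients, and AdamW's two moment buffers per pixel. The fan beam geometry mainly adds per-ray computation.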
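For conditioning and step-size tuning, the largest singular value \(\|A_{\text{fan}}\|_2\) is useful. The gradient of \(\|A_{\text{fan}}(f) - p\|_2^2\) is Lipschitz with constant \(2\|A_{\text{fan}}\|_2^2\). Plain gradient descent therefore converges for any fixed step below \(1/\|A_{\text{fan}}\|_2^2\). Adam-type optimizers, as used in the example, adapt per-pixel steps and do not follow this bound directly.

The sketch estimates \(\|A\|_2\) by power iteration on \(A^T A\), using only forward and adjoint callables, and is checked here with small NumPy matrices. With diffct, `forward` could wrap FanProjectorFunction and `adjoint` its autograd backward. That wiring is an assumption, not a diffct API.

```python
import numpy as np


def largest_singular_value(forward, adjoint, shape, num_iters=100, seed=0):
    """Estimate ||A||_2 via power iteration on A^T A (hypothetical helper).

    forward: callable x -> A x; adjoint: callable y -> A^T y.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape)
    x /= np.linalg.norm(x)
    for _ in range(num_iters):
        y = adjoint(forward(x))          # apply A^T A
        x = y / np.linalg.norm(y)        # renormalize to unit length
    # For the converged unit vector x, ||A x|| approximates sigma_max
    return np.linalg.norm(forward(x))


# Toy check with a diagonal matrix whose largest singular value is 3
A = np.diag([3.0, 2.0, 1.0])
sigma = largest_singular_value(lambda x: A @ x, lambda y: A.T @ y, shape=(3,))
gd_step = 1.0 / (2.0 * sigma**2)   # conservative gradient-descent step for ||A f - p||^2
print(round(sigma, 6))             # 3.0
```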

2D Fan Beam Iterative Example
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from diffct.differentiable import FanProjectorFunction


def shepp_logan_2d(Nx, Ny):
    # Simplified Shepp-Logan-style phantom built from five ellipses and
    # clipped to [0, 1]. Each entry is (x0, y0, a, b, rotation_deg, amplitude)
    # in coordinates normalized by half the image size.
    Nx = int(Nx)
    Ny = int(Ny)
    phantom = np.zeros((Ny, Nx), dtype=np.float32)
    ellipses = [
        (0.0, 0.0, 0.69, 0.92, 0, 1.0),
        (0.0, -0.0184, 0.6624, 0.8740, 0, -0.8),
        (0.22, 0.0, 0.11, 0.31, -18.0, -0.8),
        (-0.22, 0.0, 0.16, 0.41, 18.0, -0.8),
        (0.0, 0.35, 0.21, 0.25, 0, 0.7),
    ]
    cx = (Nx - 1)*0.5
    cy = (Ny - 1)*0.5
    for ix in range(Nx):
        for iy in range(Ny):
            xnorm = (ix - cx)/(Nx/2)
            ynorm = (iy - cy)/(Ny/2)
            val = 0.0
            for (x0, y0, a, b, angdeg, ampl) in ellipses:
                th = np.deg2rad(angdeg)
                xprime = (xnorm - x0)*np.cos(th) + (ynorm - y0)*np.sin(th)
                yprime = -(xnorm - x0)*np.sin(th) + (ynorm - y0)*np.cos(th)
                if xprime*xprime/(a*a) + yprime*yprime/(b*b) <= 1.0:
                    val += ampl
            phantom[iy, ix] = val
    phantom = np.clip(phantom, 0.0, 1.0)
    return phantom

class IterativeRecoModel(nn.Module):
    def __init__(self, volume_shape, angles,
                 num_detectors, detector_spacing,
                 sdd, sid, voxel_spacing,
                 backend="siddon"):
        super().__init__()
        # Learnable image offset; forward() adds it to the initial guess.
        self.reco = nn.Parameter(torch.zeros(volume_shape))
        self.angles = angles
        self.num_detectors = num_detectors
        self.detector_spacing = detector_spacing
        self.sdd = sdd
        self.sid = sid
        self.voxel_spacing = voxel_spacing
        self.backend = backend

    def forward(self, x):
        updated_reco = x + self.reco
        # ``backend`` is the last positional argument to
        # ``FanProjectorFunction.apply`` so we must also pass the three
        # default offsets (detector_offset / center_offset_x / center_offset_y)
        # to line up with the signature.
        current_sino = FanProjectorFunction.apply(
            updated_reco,
            self.angles,
            self.num_detectors,
            self.detector_spacing,
            self.sdd,
            self.sid,
            self.voxel_spacing,
            0.0,              # detector_offset
            0.0,              # center_offset_x
            0.0,              # center_offset_y
            self.backend,
        )
        return current_sino, updated_reco

class Pipeline:
    def __init__(self, lr, volume_shape, angles,
                 num_detectors, detector_spacing,
                 sdd, sid, voxel_spacing,
                 device, epochs=1000, backend="siddon"):
        self.epochs = epochs
        self.model = IterativeRecoModel(volume_shape, angles,
                                        num_detectors, detector_spacing,
                                        sdd, sid, voxel_spacing,
                                        backend=backend).to(device)

        self.optimizer = optim.AdamW(list(self.model.parameters()), lr=lr)
        self.loss = nn.MSELoss()

    def train(self, init_guess, label):
        loss_values = []
        for epoch in range(self.epochs):
            self.optimizer.zero_grad()
            predictions, _ = self.model(init_guess)
            loss_value = self.loss(predictions, label)
            loss_value.backward()
            self.optimizer.step()
            # Keep the image estimate non-negative after every update.
            with torch.no_grad():
                self.model.reco.clamp_(min=0.0)
            loss_values.append(loss_value.item())

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Loss: {loss_value.item()}")

        return loss_values, self.model

def main():
    Nx, Ny = 128, 128
    phantom_cpu = shepp_logan_2d(Nx, Ny)

    num_views = 360
    angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)

    num_detectors = 256
    detector_spacing = 1.0
    voxel_spacing = 1.0
    sdd = 600.0
    sid = 400.0

    # Forward projector backend used for BOTH the ground-truth sinogram
    # and the inner iterative loop. Using the same backend on both sides
    # is the cleanest setup: the loop then solves its own exact inverse
    # problem, and the backprojection that autograd applies in the backward
    # pass is the exact adjoint of the forward projector (a matched
    # scatter/gather kernel pair, verified by
    # tests/test_adjoint_inner_product.py). Options:
    #
    #   "siddon"           - ray-driven cell-constant Siddon. Fastest.
    #                        Good default when you want the shortest
    #                        iteration step.
    #   "sf"               - voxel-driven separable-footprint projector
    #                        (Long et al. SF-TR). Mass-conserving per
    #                        voxel, closed-form cell integral, ~3x
    #                        slower than siddon. Worth trying when you
    #                        want a physically-principled cell-
    #                        integrated forward model and don't mind
    #                        the per-iteration cost.
    projector_backend = "siddon"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32)
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)

    # Generate the "real" sinogram with the same backend as the inner
    # loop, so reconstruction targets what the loop can actually produce.
    real_sinogram = FanProjectorFunction.apply(
        phantom_torch, angles_torch,
        num_detectors, detector_spacing,
        sdd, sid, voxel_spacing,
        0.0, 0.0, 0.0,                  # detector_offset, center_offset_x, center_offset_y
        projector_backend,
    )

    pipeline_instance = Pipeline(lr=1e-1,
                                 volume_shape=(Ny,Nx),
                                 angles=angles_torch,
                                 num_detectors=num_detectors,
                                 detector_spacing=detector_spacing,
                                 sdd=sdd, voxel_spacing=voxel_spacing,
                                 sid=sid,
                                 device=device, epochs=1000,
                                 backend=projector_backend)

    ini_guess = torch.zeros_like(phantom_torch)

    loss_values, trained_model = pipeline_instance.train(ini_guess, real_sinogram)

    # Evaluate the final estimate without building an autograd graph.
    with torch.no_grad():
        reco = trained_model(ini_guess)[1].squeeze().cpu().numpy()

    plt.figure()
    plt.plot(loss_values)
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(phantom_cpu, cmap="gray")
    plt.title("Original Phantom")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reco, cmap="gray")
    plt.title("Reconstructed")
    plt.axis("off")
    plt.show()

if __name__ == "__main__":
    main()