Cone Beam FDK Reconstruction

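This example demonstrates 3D cone beam FDK (Feldkamp-Davis-Kress) reconstruction with diffct. It uses ConeProjectorFunction for the forward projection and the helpers cone_cosine_weights, ramp_filter_1d, angular_integration_weights and cone_weighted_backproject for the FDK reconstruction.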

Overview

The FDK algorithm is the standard approximate analytical method for 3D cone beam CT reconstruction from a circular source orbit. This example shows how to:

  • Configure 3D cone beam geometry with 2D detector array

  • Generate cone beam projections from a 3D phantom

  • Apply cosine weighting and ramp filtering for FDK reconstruction

  • Perform 3D backprojection to reconstruct the volume

Mathematical Background

Cone Beam Geometry

Cone beam CT extends fan beam CT to 3D by pairing a point X-ray source with a 2D detector array. Key parameters:

  • SDD \(D_s\): Source-to-Detector Distance (distance from X-ray source to detector plane)

  • SID \(D_{sid}\): Source-to-Isocenter Distance (distance from X-ray source to rotation center)

  • Detector coordinates \((u, v)\): Horizontal and vertical detector positions

  • Cone angles \((\alpha, \beta)\): Horizontal and vertical beam divergence
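
To make these parameters concrete, here is a standard-library sketch that converts a detector position \((u, v)\) into ray angles, using the SDD and SID values from the example code further down. The helper `ray_angles` is hypothetical, and measuring \(\beta\) as the elevation of the ray above the central plane is one common convention, assumed here. The sketch also checks that \(\cos\alpha\cos\beta = D_s/\sqrt{D_s^2+u^2+v^2}\), the factor that reappears in the FDK cosine weighting.

```python
import math


def ray_angles(u, v, sdd):
    """Return (alpha, beta) in radians for the ray hitting detector point (u, v).

    Hypothetical helper. Assumed convention: a flat detector centered on the
    central ray at distance sdd from the source; alpha is the in-plane (fan)
    angle and beta is the elevation of the ray above the central plane.
    """
    alpha = math.atan2(u, sdd)
    beta = math.atan2(v, math.hypot(sdd, u))
    return alpha, beta


sdd, sid = 900.0, 600.0             # SDD and SID used in the example code
magnification = sdd / sid           # magnification of the isocenter onto the detector (1.5 here)
pixel_at_iso = 1.0 / magnification  # size of a 1.0-unit detector pixel at the isocenter

u, v = 100.0, 50.0
alpha, beta = ray_angles(u, v, sdd)
cos_gamma = sdd / math.sqrt(sdd**2 + u**2 + v**2)  # cosine of the angle to the central ray

print(f"alpha = {math.degrees(alpha):.3f} deg, beta = {math.degrees(beta):.3f} deg")
print(math.isclose(math.cos(alpha) * math.cos(beta), cos_gamma))  # True
```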

3D Forward Projection

The cone beam projection at source angle \(\phi\) and detector position \((u, v)\) is:

\[p(\phi, u, v) = \int_0^{\infty} f\left(\vec{r}_s(\phi) + t \cdot \vec{d}(\phi, u, v)\right) dt\]

where \(\vec{r}_s(\phi)\) is the source position and \(\vec{d}(\phi, u, v)\) is the unit direction vector of the ray from the source to detector position \((u, v)\), so that \(t\) measures path length.
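
The line integral can be checked numerically. The NumPy sketch below builds \(\vec{r}_s(\phi)\) and a unit \(\vec{d}(\phi, u, v)\) under an assumed convention (source at \(-D_{sid}(\cos\phi, \sin\phi, 0)\), central ray along \(+(\cos\phi, \sin\phi, 0)\), \(u\)-axis along \((-\sin\phi, \cos\phi, 0)\), \(v\)-axis along \(z\)), samples the integral for a uniform ball centered at the isocenter, and compares it with the exact chord length. All function names are hypothetical and are not part of diffct.

```python
import numpy as np


def source_and_direction(phi, u, v, sdd, sid):
    """Source position r_s(phi) and unit ray direction d(phi, u, v).

    Assumed convention: source at -sid * e, where e = (cos phi, sin phi, 0) is
    the central-ray direction; u runs along (-sin phi, cos phi, 0), v along z.
    """
    e = np.array([np.cos(phi), np.sin(phi), 0.0])
    e_u = np.array([-np.sin(phi), np.cos(phi), 0.0])
    e_v = np.array([0.0, 0.0, 1.0])
    r_s = -sid * e
    d = sdd * e + u * e_u + v * e_v      # vector from the source to detector point (u, v)
    return r_s, d / np.linalg.norm(d)


def ball_projection_numeric(phi, u, v, sdd, sid, radius, mu, n_samples=20001):
    """Riemann-sum approximation of p(phi, u, v) for a uniform ball at the origin."""
    r_s, d = source_and_direction(phi, u, v, sdd, sid)
    t = np.linspace(0.0, 2.0 * sid, n_samples)   # covers the ball as long as radius < sid
    points = r_s[None, :] + t[:, None] * d[None, :]
    inside = np.sum(points**2, axis=1) <= radius**2
    dt = t[1] - t[0]
    return mu * np.count_nonzero(inside) * dt


def ball_projection_exact(phi, u, v, sdd, sid, radius, mu):
    """Exact line integral: mu times the chord length of the ray through the ball."""
    r_s, d = source_and_direction(phi, u, v, sdd, sid)
    dist = np.linalg.norm(np.cross(r_s, d))      # distance from the origin to the ray
    return 2.0 * mu * np.sqrt(max(radius**2 - dist**2, 0.0))


args = dict(phi=0.3, u=20.0, v=-15.0, sdd=900.0, sid=600.0, radius=50.0, mu=0.02)
print(ball_projection_numeric(**args), ball_projection_exact(**args))
```

The two values agree to within roughly one sampling step times \(\mu\), which is a quick sanity check that \(\vec{d}\) is normalized and the geometry is consistent.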

FDK Algorithm

The Feldkamp-Davis-Kress algorithm performs approximate 3D reconstruction in three steps:

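  1. Cosine Weighting: Weight each detector value by the cosine of the angle between its ray and the central ray. This factor arises when each tilted fan of rays is reconstructed with the fan beam filtered backprojection formula; the projections are already line integrals, so no intensity falloff is being corrected: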

    \[p_w(\phi, u, v) = p(\phi, u, v) \cdot \frac{D_s}{\sqrt{D_s^2 + u^2 + v^2}}\]
  2. Row-wise Ramp Filtering: Apply 1D ramp filter along detector rows (u-direction):

    \[p_f(\phi, u, v) = \mathcal{F}_u^{-1}\{|\omega_u| \cdot \mathcal{F}_u\{p_w(\phi, u, v)\}\}\]

    Each detector row is filtered independently.

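  3. 3D Cone Beam Backprojection: Reconstruct the volume by distance-weighted backprojection:

    \[f(x,y,z) = \frac{1}{2}\int_0^{2\pi} \frac{D_{sid} D_s}{(D_{sid} + x\cos\phi + y\sin\phi)^2}\, p_f(\phi, u_{xyz}, v_{xyz})\, d\phi\]

    The factor \(\frac{1}{2}\) compensates for the double coverage of a full \(2\pi\) scan. The denominator \(D_{sid} + x\cos\phi + y\sin\phi\) is the distance from the source to voxel \((x,y,z)\) measured along the central ray, and the numerator \(D_{sid} D_s\) (rather than \(D_{sid}^2\)) reflects ramp filtering in physical detector coordinates instead of on a virtual detector at the isocenter. The detector coordinates \((u_{xyz}, v_{xyz})\) of the ray through voxel \((x,y,z)\) are:

    \[u_{xyz} = D_s \frac{-x\sin\phi + y\cos\phi}{D_{sid} + x\cos\phi + y\sin\phi}\]
    \[v_{xyz} = D_s \frac{z}{D_{sid} + x\cos\phi + y\sin\phi}\]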
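
A NumPy sketch of the three steps for a single projection follows. The helpers are hypothetical stand-ins, not diffct's cone_cosine_weights, ramp_filter_1d or cone_weighted_backproject. The detector is assumed centered with no offset, and the ramp is a plain unapodized \(|f|\) filter (\(f\) in cycles per unit length) applied in the frequency domain with zero padding. `voxel_to_detector` divides by the source-to-voxel distance along the central ray, \(D_{sid} + x\cos\phi + y\sin\phi\), because that distance (not \(D_s\)) sets a voxel's magnification onto the detector.

```python
import numpy as np


def cosine_weights(det_u, det_v, du, dv, sdd):
    """FDK pre-weights D_s / sqrt(D_s^2 + u^2 + v^2), shape (det_u, det_v).

    Assumes a detector centered on the central ray (no offsets).
    """
    u = (np.arange(det_u) - (det_u - 1) / 2) * du
    v = (np.arange(det_v) - (det_v - 1) / 2) * dv
    uu, vv = np.meshgrid(u, v, indexing="ij")
    return sdd / np.sqrt(sdd**2 + uu**2 + vv**2)


def ramp_filter_rows(proj, du, axis=0):
    """Filter every detector row (along `axis`, the u direction) with |f|.

    Zero padding to at least twice the row length limits the circular
    wrap-around introduced by the FFT; other axes are left untouched.
    """
    n = proj.shape[axis]
    n_pad = 1 << int(np.ceil(np.log2(2 * n)))
    pad_width = [(0, 0)] * proj.ndim
    pad_width[axis] = (0, n_pad - n)
    padded = np.pad(proj, pad_width)

    ramp_shape = [1] * proj.ndim
    ramp_shape[axis] = n_pad
    ramp = np.abs(np.fft.fftfreq(n_pad, d=du)).reshape(ramp_shape)

    filtered = np.real(np.fft.ifft(np.fft.fft(padded, axis=axis) * ramp, axis=axis))
    return np.take(filtered, np.arange(n), axis=axis)   # drop the padding


def voxel_to_detector(x, y, z, phi, sdd, sid):
    """Detector coordinates (u, v) of the ray through voxel (x, y, z), plus the
    source-to-voxel distance along the central ray, big_u."""
    big_u = sid + x * np.cos(phi) + y * np.sin(phi)
    u = sdd * (-x * np.sin(phi) + y * np.cos(phi)) / big_u
    v = sdd * z / big_u
    return u, v, big_u


# Toy projection on an odd-sized detector so that one pixel sits at u = v = 0.
sdd, sid = 900.0, 600.0
proj = np.random.default_rng(0).random((9, 7))
w = cosine_weights(9, 7, 1.0, 1.0, sdd)
proj_filtered = ramp_filter_rows(proj * w, du=1.0, axis=0)

u, v, big_u = voxel_to_detector(10.0, -5.0, 3.0, phi=0.4, sdd=sdd, sid=sid)
backprojection_weight = 1.0 / big_u**2   # proportional to the weight; constant factor omitted
```

Sampling `proj_filtered` at the non-integer position \((u, v)\) would need interpolation (for example bilinear), which this sketch leaves out.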

Implementation Steps

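  1. 3D Phantom Generation: Create a 3D Shepp-Logan phantom from 10 ellipsoids

  2. Cone Beam Projection: Generate 2D projections using ConeProjectorFunction

  3. Cosine Weighting: Multiply each projection by the FDK cosine weights from cone_cosine_weights

  4. Row-wise Filtering: Apply the ramp filter along the detector u-direction with ramp_filter_1d

  5. Angular Weighting: Multiply each filtered view by its angular integration weight from angular_integration_weights; for a full scan with \(N_{\text{angles}}\) evenly spaced views this plays the role of the \(\frac{\pi}{N_{\text{angles}}}\) normalization

  6. 3D Backprojection: Reconstruct the volume using cone_weighted_backproject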
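
For a full \(2\pi\) scan, the \(\frac{\pi}{N_{\text{angles}}}\) normalization is the angular step \(2\pi/N_{\text{angles}}\) times the factor \(\frac{1}{2}\) that compensates for the double coverage of a full rotation. A minimal NumPy sketch of that idea, extended to unevenly spaced views by crediting each view with half the angular interval between the midpoints to its neighbors. `full_scan_angular_weights` is a hypothetical helper; whether diffct's angular_integration_weights returns exactly these values is an assumption, not something this example confirms.

```python
import numpy as np


def full_scan_angular_weights(angles):
    """Per-view weights d_phi / 2 for a full 2*pi scan (hypothetical helper).

    Each view is credited with the angular interval between the midpoints to
    its neighbors (wrapping around at 2*pi), multiplied by 1/2 for the double
    coverage of a full rotation. Weights are returned in the input order.
    """
    a = np.mod(np.asarray(angles, dtype=np.float64), 2.0 * np.pi)
    order = np.argsort(a)
    s = a[order]
    gap_next = np.roll(s, -1) - s
    gap_next[-1] += 2.0 * np.pi           # the last view wraps around to the first
    gap_prev = np.roll(gap_next, 1)
    w_sorted = 0.5 * 0.5 * (gap_prev + gap_next)
    w = np.empty_like(w_sorted)
    w[order] = w_sorted
    return w


angles = np.linspace(0.0, 2.0 * np.pi, 360, endpoint=False)
w = full_scan_angular_weights(angles)
print(w.sum())                       # approximately pi
print(np.allclose(w, np.pi / 360))   # True for evenly spaced views
```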

3D Shepp-Logan Phantom

The 3D phantom extends the 2D version with 10 ellipsoids representing anatomical structures:

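  • Outer skull: Large ellipsoid (value 1.0) encompassing the head

  • Brain tissue: Slightly smaller ellipsoid (value -0.8) that lowers the interior to 0.2, leaving a bright skull shell

  • Ventricles: Two ellipsoids (value -0.2), rotated by \(\pm\pi/10\) about the z-axis, representing fluid-filled cavities

  • Small features: Six smaller ellipsoids (value +0.1) providing low-contrast detail for assessing reconstruction quality

Each ellipsoid is defined by its center position \((x_0, y_0, z_0)\), semi-axes \((a, b, c)\), rotation angles, and an additive attenuation value; where ellipsoids overlap, their values are summed. In this example only the rotation about the z-axis is nonzero, and it is the only rotation the code applies.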
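
A single ellipsoid's indicator can be written on its own, following the same rotation convention as shepp_logan_3d in the example code. The sketch below, built around a hypothetical `ellipsoid_mask` helper, checks that the voxelized volume of the outer skull ellipsoid approaches \(\frac{4}{3}\pi abc\) on a normalized grid.

```python
import numpy as np


def ellipsoid_mask(xx, yy, zz, center, semi_axes, phi):
    """True where (xx, yy, zz) lies inside an ellipsoid rotated about the z-axis.

    Hypothetical helper using the same coordinate rotation as shepp_logan_3d
    (x' = cos(phi) x - sin(phi) y, y' = sin(phi) x + cos(phi) y), which places
    the a semi-axis along (cos phi, -sin phi, 0).
    """
    x0, y0, z0 = center
    a, b, c = semi_axes
    xc, yc, zc = xx - x0, yy - y0, zz - z0
    cos_p, sin_p = np.cos(phi), np.sin(phi)
    xp = cos_p * xc - sin_p * yc
    yp = sin_p * xc + cos_p * yc
    return (xp / a) ** 2 + (yp / b) ** 2 + (zc / c) ** 2 <= 1.0


# Normalized [-1, 1]^3 grid, built the same way as in shepp_logan_3d.
n = 96
zz, yy, xx = np.mgrid[:n, :n, :n]
half = (n - 1) / 2
xx, yy, zz = (xx - half) / half, (yy - half) / half, (zz - half) / half

skull = ellipsoid_mask(xx, yy, zz, (0.0, 0.0, 0.0), (0.69, 0.92, 0.81), 0.0)
voxel_volume = (1.0 / half) ** 3     # grid spacing is 2 / (n - 1) = 1 / half
print(skull.sum() * voxel_volume, 4.0 / 3.0 * np.pi * 0.69 * 0.92 * 0.81)
```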

FDK Approximations and Limitations

The FDK algorithm makes several approximations:

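  • Circular orbit: Designed for a circular source trajectory, which does not collect complete data for points outside the central plane (Tuy's data sufficiency condition is not satisfied)

  • Row-wise filtering: Ramp filtering is applied only along detector rows, treating each tilted fan of rays as if it were a 2D fan beam slice

  • Small cone angle: Exact only in the central plane \(z = 0\); accuracy degrades as the cone angle increases

These approximations introduce cone beam artifacts that grow with the cone angle, but FDK remains widely used because it is computationally efficient and straightforward to implement.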
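
To put numbers on the cone angle for the example code, this standard-library sketch computes the detector's half cone angle and checks that the inscribed cylinder of the reconstructed volume projects inside the detector from every view. It assumes a centered detector with no offsets and reuses the sizes and distances from the example code, in units of voxel_spacing.

```python
import math

# Geometry from the example code (lengths in units of voxel_spacing).
sdd, sid = 900.0, 600.0
det_u = det_v = 256
du = dv = 1.0
n_vox, voxel = 128, 1.0

# Half cone angle subtended by the detector in the v (axial) direction.
half_cone = math.atan((det_v * dv / 2) / sdd)

# Inscribed cylinder of the volume: radius r in x-y, half-height h in z.
r = h = n_vox * voxel / 2
# Largest |v|: top rim on the side nearest the source (distance sid - r).
v_max = sdd * h / (sid - r)
# Largest |u|: ray tangent to the cylinder, tan(angle) = r / sqrt(sid^2 - r^2).
u_max = sdd * r / math.sqrt(sid**2 - r**2)

print(f"half cone angle: {math.degrees(half_cone):.2f} deg")
print(f"max |u| = {u_max:.1f} (detector half-width {det_u * du / 2})")
print(f"max |v| = {v_max:.1f} (detector half-height {det_v * dv / 2})")
```

A half cone angle of about 8 degrees is modest, so FDK's approximations should cost little accuracy for this geometry.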

3D Cone Beam FDK Example
```python
import math
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
from diffct.differentiable import (
    ConeProjectorFunction,
    angular_integration_weights,
    cone_cosine_weights,
    cone_weighted_backproject,
    ramp_filter_1d,
)


def shepp_logan_3d(shape):
    """Build a 3D Shepp-Logan phantom with the given (Nz, Ny, Nx) shape.

    Voxel coordinates are normalized to [-1, 1] along every axis.
    """
    zz, yy, xx = np.mgrid[:shape[0], :shape[1], :shape[2]]
    xx = (xx - (shape[2] - 1) / 2) / ((shape[2] - 1) / 2)
    yy = (yy - (shape[1] - 1) / 2) / ((shape[1] - 1) / 2)
    zz = (zz - (shape[0] - 1) / 2) / ((shape[0] - 1) / 2)

    # One row per ellipsoid:
    # x0, y0, z0, semi-axes a, b, c, rotation about z (phi),
    # two unused rotation angles, additive attenuation value
    el_params = np.array([
        [0, 0, 0, 0.69, 0.92, 0.81, 0, 0, 0, 1],
        [0, -0.0184, 0, 0.6624, 0.874, 0.78, 0, 0, 0, -0.8],
        [0.22, 0, 0, 0.11, 0.31, 0.22, -np.pi/10.0, 0, 0, -0.2],
        [-0.22, 0, 0, 0.16, 0.41, 0.28, np.pi/10.0, 0, 0, -0.2],
        [0, 0.35, -0.15, 0.21, 0.25, 0.41, 0, 0, 0, 0.1],
        [0, 0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [0, -0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [-0.08, -0.605, 0, 0.046, 0.023, 0.05, 0, 0, 0, 0.1],
        [0, -0.605, 0, 0.023, 0.023, 0.02, 0, 0, 0, 0.1],
        [0.06, -0.605, 0, 0.023, 0.046, 0.02, 0, 0, 0, 0.1],
    ], dtype=np.float32)

    # Extract parameters, shaped (10, 1, 1, 1) to broadcast against the grid
    x_pos = el_params[:, 0][:, None, None, None]
    y_pos = el_params[:, 1][:, None, None, None]
    z_pos = el_params[:, 2][:, None, None, None]
    a_axis = el_params[:, 3][:, None, None, None]
    b_axis = el_params[:, 4][:, None, None, None]
    c_axis = el_params[:, 5][:, None, None, None]
    phi = el_params[:, 6][:, None, None, None]
    val = el_params[:, 9][:, None, None, None]

    # Voxel coordinates relative to each ellipsoid center, shape (10, Nz, Ny, Nx)
    xc = xx[None, ...] - x_pos
    yc = yy[None, ...] - y_pos
    zc = zz[None, ...] - z_pos

    c = np.cos(phi)
    s = np.sin(phi)

    # Only the rotation about z is used, so all ellipsoids rotate in one vectorized step
    xp = c * xc - s * yc
    yp = s * xc + c * yc
    zp = zc

    # Inside test for every ellipsoid at once
    mask = (
        (xp ** 2) / (a_axis ** 2)
        + (yp ** 2) / (b_axis ** 2)
        + (zp ** 2) / (c_axis ** 2)
        <= 1.0
    )

    # Sum the contributions of all ellipsoids, then clip to [0, 1]
    shepp_logan = np.sum(mask * val, axis=0)
    shepp_logan = np.clip(shepp_logan, 0, 1)
    return shepp_logan


def main():
    # Volume size in voxels
    Nx, Ny, Nz = 128, 128, 128
    phantom_cpu = shepp_logan_3d((Nz, Ny, Nx))

    # Full circular scan: 360 evenly spaced views over 2*pi
    num_views = 360
    angles_np = np.linspace(0, 2*math.pi, num_views, endpoint=False).astype(np.float32)

    # Detector size (pixels), pixel pitch, offsets, and source distances
    det_u, det_v = 256, 256
    du, dv = 1.0, 1.0
    detector_offset_u = 0.0
    detector_offset_v = 0.0
    sdd = 900.0  # source-to-detector distance
    sid = 600.0  # source-to-isocenter distance

    voxel_spacing = 1.0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32).contiguous()
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)

    # Forward projection: one 2D cone beam projection per view
    sinogram = ConeProjectorFunction.apply(phantom_torch, angles_torch,
                                           det_u, det_v, du, dv,
                                           sdd, sid, voxel_spacing)

    # --- FDK weighting and filtering ---
    # 1) FDK cosine pre-weighting (unsqueeze adds a view axis for broadcasting)
    weights = cone_cosine_weights(
        det_u,
        det_v,
        du,
        dv,
        sdd,
        detector_offset_u=detector_offset_u,
        detector_offset_v=detector_offset_v,
        device=device,
        dtype=phantom_torch.dtype,
    ).unsqueeze(0)
    sino_weighted = sinogram * weights

    # 2) Ramp filter along detector-u rows
    sinogram_filt = ramp_filter_1d(sino_weighted, dim=1).contiguous()

    # 3) Angular integration weights for a full (redundant) 2*pi scan
    d_beta = angular_integration_weights(angles_torch, redundant_full_scan=True).view(-1, 1, 1)
    sinogram_filt = sinogram_filt * d_beta

    # 4) Weighted cone-beam backprojection; ReLU clips negative values
    reconstruction = F.relu(
        cone_weighted_backproject(
            sinogram_filt,
            angles_torch,
            Nz,
            Ny,
            Nx,
            du,
            dv,
            sdd,
            sid,
            voxel_spacing=voxel_spacing,
            detector_offset_u=detector_offset_u,
            detector_offset_v=detector_offset_v,
        )
    )

    # Mean squared error against the ground-truth phantom
    loss = torch.mean((reconstruction - phantom_torch)**2)

    print("Cone Beam Example with user-defined geometry:")
    print("Loss:", loss.item())
    print("Reconstruction shape:", reconstruction.shape)

    reconstruction_cpu = reconstruction.detach().cpu().numpy()
    sinogram_cpu = sinogram.detach().cpu().numpy()
    mid_slice = Nz // 2

    plt.figure(figsize=(12,4))
    plt.subplot(1,3,1)
    plt.imshow(phantom_cpu[mid_slice, :,:], cmap='gray')
    plt.title("Phantom mid-slice")
    plt.axis('off')
    plt.subplot(1,3,2)
    plt.imshow(sinogram_cpu[num_views//2].T, cmap='gray', origin='lower') # Transpose for correct orientation
    plt.title("Projection at mid view")
    plt.axis('off')
    plt.subplot(1,3,3)
    plt.imshow(reconstruction_cpu[mid_slice, :, :], cmap='gray')
    plt.title("Recon mid-slice")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    # Print the data ranges of the phantom and the reconstruction
    print("Phantom data range:", phantom_cpu.min(), phantom_cpu.max())
    print("Reco data range:", reconstruction_cpu.min(), reconstruction_cpu.max())

if __name__ == "__main__":
    main()
```