Cone Beam FDK Reconstruction

This example demonstrates 3D cone beam FDK (Feldkamp-Davis-Kress) reconstruction using the ConeProjectorFunction and ConeBackprojectorFunction from diffct.

Overview

The FDK algorithm is the most widely used analytical method for 3D cone beam CT reconstruction from a circular source orbit. This example shows how to:

  • Configure a 3D cone beam geometry with a 2D detector array

  • Generate cone beam projections from a 3D phantom

  • Apply cosine weighting and ramp filtering for FDK reconstruction

  • Perform 3D backprojection to reconstruct the volume

Mathematical Background

Cone Beam Geometry

Cone beam CT extends fan beam CT to 3D by using a point X-ray source and a 2D detector array. Key parameters:

  • SDD \(D_s\): Source-to-Detector Distance (distance from X-ray source to detector plane)

  • SID \(D_{sid}\): Source-to-Isocenter Distance (distance from X-ray source to rotation center)

  • Detector coordinates \((u, v)\): Horizontal and vertical detector positions

  • Cone angles \((\alpha, \beta)\): Horizontal and vertical beam divergence

3D Forward Projection

The cone beam projection at source angle \(\phi\) and detector position \((u, v)\) is:

\[p(\phi, u, v) = \int_0^{\infty} f\left(\vec{r}_s(\phi) + t \cdot \vec{d}(\phi, u, v)\right) dt\]

where \(\vec{r}_s(\phi)\) is the source position and \(\vec{d}(\phi, u, v)\) is the unit direction vector of the ray from the source to detector pixel \((u, v)\).
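A minimal NumPy sketch of this integral follows. It assumes one particular geometry convention, not necessarily diffct's: the source sits at \(-D_{sid}(\cos\phi, \sin\phi, 0)\), the central ray points along \((\cos\phi, \sin\phi, 0)\), and the flat detector lies \(D_s\) from the source with \(u\) along \((-\sin\phi, \cos\phi, 0)\) and \(v\) along \(z\). The helpers source_and_direction, line_integral and ball are hypothetical, and the integral is approximated by a midpoint Riemann sum that stops at the detector instead of running to infinity.

```python
import numpy as np

def source_and_direction(phi, u, v, sid, sdd):
    """Source position r_s(phi) and unit ray direction d(phi, u, v).

    Assumed (hypothetical) convention: source at -sid * (cos phi, sin phi, 0),
    central ray along e = (cos phi, sin phi, 0), flat detector sdd away from
    the source, u along (-sin phi, cos phi, 0) and v along z.
    """
    e = np.array([np.cos(phi), np.sin(phi), 0.0])
    e_u = np.array([-np.sin(phi), np.cos(phi), 0.0])
    e_v = np.array([0.0, 0.0, 1.0])
    r_s = -sid * e
    to_pixel = sdd * e + u * e_u + v * e_v          # source -> detector pixel (u, v)
    return r_s, to_pixel / np.linalg.norm(to_pixel)

def line_integral(f, phi, u, v, sid, sdd, dt=0.1):
    """Midpoint Riemann sum of p(phi, u, v) = int f(r_s + t d) dt.

    The object lies between source and detector, so integrating up to the
    source-to-pixel distance replaces the upper limit of infinity.
    """
    r_s, d = source_and_direction(phi, u, v, sid, sdd)
    length = np.hypot(sdd, np.hypot(u, v))          # source-to-pixel distance
    t = np.arange(0.0, length, dt) + dt / 2         # midpoints of each step
    points = r_s[None, :] + t[:, None] * d[None, :]
    return np.sum(f(points)) * dt

def ball(points, radius=50.0):
    """Unit-attenuation ball of the given radius centred at the isocenter."""
    return (np.sum(points ** 2, axis=1) <= radius ** 2).astype(float)

# The central ray crosses the ball through its centre
p_center = line_integral(ball, phi=0.3, u=0.0, v=0.0, sid=600.0, sdd=900.0)
print(p_center)  # close to 100.0 (chord length 2 * radius)
```

For the central ray the sum reproduces the chord length of the ball up to the step size dt; a ray that misses the ball integrates to zero.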

FDK Algorithm

The Feldkamp-Davis-Kress algorithm performs approximate 3D reconstruction in three steps:

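  1. Cosine Weighting: Multiply each detector value by the cosine of the angle between its ray and the central ray (this pre-weighting arises from the change of variables between divergent-beam and parallel-beam coordinates in the FDK derivation):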

    \[p_w(\phi, u, v) = p(\phi, u, v) \cdot \frac{D_s}{\sqrt{D_s^2 + u^2 + v^2}}\]
  2. Row-wise Ramp Filtering: Apply a 1D ramp filter along each detector row (u-direction):

    \[p_f(\phi, u, v) = \mathcal{F}_u^{-1}\{|\omega_u| \cdot \mathcal{F}_u\{p_w(\phi, u, v)\}\}\]

    Each detector row is filtered independently.

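  3. 3D Cone Beam Backprojection: Reconstruct the volume by distance-weighted backprojection:

    \[f(x,y,z) = \frac{1}{2}\int_0^{2\pi} \frac{D_{sid}^2}{(D_{sid} + x\cos\phi + y\sin\phi)^2}\, p_f(\phi, u_{xyz}, v_{xyz})\, d\phi\]

    where \(D_{sid} + x\cos\phi + y\sin\phi\) is the distance from the source to voxel \((x,y,z)\) measured along the central ray, the factor \(\frac{1}{2}\) compensates for the redundant data of a full \(2\pi\) scan, and the detector coordinates \((u_{xyz}, v_{xyz})\) hit by the ray through the voxel are:

    \[u_{xyz} = D_s \frac{-x\sin\phi + y\cos\phi}{D_{sid} + x\cos\phi + y\sin\phi}\]
    \[v_{xyz} = D_s \frac{z}{D_{sid} + x\cos\phi + y\sin\phi}\]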
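The three steps can be sketched in NumPy for a single projection and a single voxel. This is an illustration, not diffct's implementation: the function names are hypothetical, the projection is stored as a (u, v) array with u along axis 0, and the geometry assumes the central ray at angle \(\phi\) points along \((\cos\phi, \sin\phi, 0)\) with \(u\) along \((-\sin\phi, \cos\phi, 0)\) and \(v\) along \(z\). The voxel depth along the central ray uses the source-to-isocenter distance \(D_{sid}\), while the magnification onto the detector uses \(D_s\).

```python
import numpy as np

def cosine_weights(det_u, det_v, du, dv, sdd):
    """Step 1: w(u, v) = D_s / sqrt(D_s^2 + u^2 + v^2) on a centred (det_u, det_v) grid."""
    u = (np.arange(det_u) - (det_u - 1) / 2) * du
    v = (np.arange(det_v) - (det_v - 1) / 2) * dv
    uu, vv = np.meshgrid(u, v, indexing="ij")        # shape (det_u, det_v)
    return sdd / np.sqrt(sdd ** 2 + uu ** 2 + vv ** 2)

def ramp_filter_rows(proj):
    """Step 2: multiply by |omega_u| in the Fourier domain along u (axis 0)."""
    omega = 2.0 * np.pi * np.fft.fftfreq(proj.shape[0])   # radians per sample
    spectrum = np.fft.fft(proj, axis=0) * np.abs(omega)[:, None]
    return np.real(np.fft.ifft(spectrum, axis=0))

def voxel_to_detector(x, y, z, phi, sid, sdd):
    """Step 3 geometry: detector hit (u, v) and distance weight for voxel (x, y, z)."""
    depth = sid + x * np.cos(phi) + y * np.sin(phi)   # source-to-voxel distance along the central ray
    u = sdd * (-x * np.sin(phi) + y * np.cos(phi)) / depth
    v = sdd * z / depth
    weight = sid ** 2 / depth ** 2
    return u, v, weight

w = cosine_weights(det_u=256, det_v=256, du=1.0, dv=1.0, sdd=900.0)
filtered = ramp_filter_rows(np.ones((256, 256)))   # a constant row has no ramp response
u, v, weight = voxel_to_detector(10.0, -20.0, 5.0, phi=0.7, sid=600.0, sdd=900.0)
```

Because (u, v) is the intersection of the source-voxel ray with the detector plane, a voxel's contribution from one view is p_f interpolated at (u, v) times weight, and summing these contributions over views gives the discretized backprojection integral.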

Implementation Steps

  1. 3D Phantom Generation: Create a 3D Shepp-Logan phantom from 10 ellipsoids

  2. Cone Beam Projection: Generate 2D projections using ConeProjectorFunction

  3. Cosine Weighting: Multiply each projection by the weights \(D_s / \sqrt{D_s^2 + u^2 + v^2}\)

  4. Row-wise Filtering: Apply ramp filter to each detector row

  5. 3D Backprojection: Reconstruct volume using ConeBackprojectorFunction

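  6. Normalization: Scale by \(\frac{\pi}{N_{\text{angles}}}\), i.e. the angular step \(\frac{2\pi}{N_{\text{angles}}}\) times \(\frac{1}{2}\) for the redundant data of a full 360° scan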
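As a numerical check of the normalization, the sketch below uses the example's angular sampling (360 views, endpoint=False) and shows that multiplying a sum over views by \(\pi/N_{\text{angles}}\) approximates \(\frac{1}{2}\int_0^{2\pi} g(\phi)\,d\phi\). The test integrand \(g(\phi) = \cos^2\phi\) is an arbitrary choice.

```python
import math
import numpy as np

num_views = 360
angles = np.linspace(0, 2 * math.pi, num_views, endpoint=False)
scale = math.pi / num_views                 # = 0.5 * (2 * pi / num_views)

# Riemann sum of g(phi) = cos^2(phi), weighted like the FDK backprojection sum
approx = scale * np.sum(np.cos(angles) ** 2)
exact = 0.5 * math.pi                       # (1/2) * integral over [0, 2 pi] of cos^2 = pi / 2
print(approx, exact)
```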

3D Shepp-Logan Phantom

The 3D phantom replaces the ellipses of the 2D version with 10 ellipsoids representing anatomical structures:

  • Outer skull: Large ellipsoid encompassing the head

  • Brain tissue: Medium ellipsoids for different brain regions

  • Ventricles: Small ellipsoids representing fluid-filled cavities

  • Lesions: High-contrast features for reconstruction assessment

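Each ellipsoid is defined by its center \((x_0, y_0, z_0)\), semi-axes \((a, b, c)\), rotation angles (the code below uses only the rotation about the z axis), and an intensity value that is added inside the ellipsoid. Negative values carve out lower-density regions, and the summed volume is clipped to \([0, 1]\).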
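A single ellipsoid can be evaluated on its own as sketched below: shift the sample points to the ellipsoid center, rotate them about z into the ellipsoid's own frame, and test the quadric. The helper inside_ellipsoid is hypothetical, and rotating points by \(-\psi\) is one common sign convention assumed here. The phantom is then the sum of each ellipsoid's value over its mask.

```python
import numpy as np

def inside_ellipsoid(x, y, z, center, axes, psi):
    """Boolean mask: is (x, y, z) inside an ellipsoid rotated by psi about z?"""
    x0, y0, z0 = center
    a, b, c = axes
    xc, yc, zc = x - x0, y - y0, z - z0
    # Rotate points by -psi so they are expressed in the ellipsoid's own frame
    xr = np.cos(psi) * xc + np.sin(psi) * yc
    yr = -np.sin(psi) * xc + np.cos(psi) * yc
    return (xr / a) ** 2 + (yr / b) ** 2 + (zc / c) ** 2 <= 1.0

# An ellipsoid that is long along its own y axis, rotated by 90 degrees, now lies along x
axes = (0.1, 0.4, 0.2)
print(inside_ellipsoid(0.3, 0.0, 0.0, (0, 0, 0), axes, np.pi / 2))  # True
print(inside_ellipsoid(0.0, 0.3, 0.0, (0, 0, 0), axes, np.pi / 2))  # False
```

The function broadcasts, so passing coordinate grids instead of scalars returns a 3D mask like the one used in shepp_logan_3d.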

FDK Approximations and Limitations

The FDK algorithm makes several approximations:

  • Circular orbit: Assumes a circular source trajectory

  • Row-wise filtering: Applies the ramp filter only along detector rows

  • Small cone angle: Most accurate for small cone angles

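These approximations produce cone beam artifacts: FDK reduces to exact fan beam reconstruction in the central (midplane) slice, but away from that plane its errors grow with the cone angle. FDK nevertheless remains widely used because of its computational efficiency.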
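To put the small cone angle assumption in numbers, the sketch below computes the half cone angle of the example geometry (det_v = 256, dv = 1.0, sdd = 900, sid = 600) and checks that the top edge of the volume stays on the detector. It assumes the 128³ volume with 1.0 voxel spacing is centered on the isocenter, and it takes the worst case in which a top corner of the volume is closest to the source.

```python
import math

# Geometry of the example script
det_v, dv = 256, 1.0
sdd, sid = 900.0, 600.0
n_vox, voxel_spacing = 128, 1.0

half_height = det_v * dv / 2                        # detector half-height
half_cone = math.degrees(math.atan(half_height / sdd))

# Worst case for vertical coverage: the top edge of the volume at the corner
# closest to the source (depth = sid - sqrt(2) * half-width of the volume)
half_width = n_vox * voxel_spacing / 2
min_depth = sid - math.sqrt(2) * half_width
v_max = sdd * half_width / min_depth                # largest projected |v|

print(f"half cone angle: {half_cone:.2f} deg")
print(f"max projected |v|: {v_max:.1f} of {half_height:.1f}")
```

A half cone angle of roughly 8° is moderate, so cone beam artifacts should mostly appear in slices near the top and bottom of the volume rather than around the midplane.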

3D Cone Beam FDK Example
import math
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
from diffct.differentiable import ConeProjectorFunction, ConeBackprojectorFunction


def shepp_logan_3d(shape):
    """Return a 3D Shepp-Logan phantom of the given (Nz, Ny, Nx) shape, clipped to [0, 1]."""
    # Map voxel indices to normalized coordinates in [-1, 1]
    zz, yy, xx = np.mgrid[:shape[0], :shape[1], :shape[2]]
    xx = (xx - (shape[2] - 1) / 2) / ((shape[2] - 1) / 2)
    yy = (yy - (shape[1] - 1) / 2) / ((shape[1] - 1) / 2)
    zz = (zz - (shape[0] - 1) / 2) / ((shape[0] - 1) / 2)
    # Columns: x0, y0, z0, a, b, c, rotation about z, two further angles (unused here), value
    el_params = np.array([
        [0, 0, 0, 0.69, 0.92, 0.81, 0, 0, 0, 1],
        [0, -0.0184, 0, 0.6624, 0.874, 0.78, 0, 0, 0, -0.8],
        [0.22, 0, 0, 0.11, 0.31, 0.22, -np.pi/10.0, 0, 0, -0.2],
        [-0.22, 0, 0, 0.16, 0.41, 0.28, np.pi/10.0, 0, 0, -0.2],
        [0, 0.35, -0.15, 0.21, 0.25, 0.41, 0, 0, 0, 0.1],
        [0, 0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [0, -0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [-0.08, -0.605, 0, 0.046, 0.023, 0.05, 0, 0, 0, 0.1],
        [0, -0.605, 0, 0.023, 0.023, 0.02, 0, 0, 0, 0.1],
        [0.06, -0.605, 0, 0.023, 0.046, 0.02, 0, 0, 0, 0.1],
    ], dtype=np.float32)

    # Extract parameters for vectorization
    x_pos = el_params[:, 0][:, None, None, None]
    y_pos = el_params[:, 1][:, None, None, None]
    z_pos = el_params[:, 2][:, None, None, None]
    a_axis = el_params[:, 3][:, None, None, None]
    b_axis = el_params[:, 4][:, None, None, None]
    c_axis = el_params[:, 5][:, None, None, None]
    phi = el_params[:, 6][:, None, None, None]
    val = el_params[:, 9][:, None, None, None]

    # Broadcast grid to ellipsoid axis
    xc = xx[None, ...] - x_pos
    yc = yy[None, ...] - y_pos
    zc = zz[None, ...] - z_pos

    c = np.cos(phi)
    s = np.sin(phi)

    # Only rotation around z, so can vectorize:
    xp = c * xc - s * yc
    yp = s * xc + c * yc
    zp = zc

    mask = (
        (xp ** 2) / (a_axis ** 2)
        + (yp ** 2) / (b_axis ** 2)
        + (zp ** 2) / (c_axis ** 2)
        <= 1.0
    )

    # Use broadcasting to sum all ellipsoid contributions
    shepp_logan = np.sum(mask * val, axis=0)
    shepp_logan = np.clip(shepp_logan, 0, 1)
    return shepp_logan


def ramp_filter_3d(sinogram_tensor):
    """Ramp-filter a (views, u, v) sinogram along the detector-u axis (dim 1) via FFT."""
    device = sinogram_tensor.device
    num_views, num_det_u, num_det_v = sinogram_tensor.shape
    # |omega| with omega in radians per detector sample
    freqs = torch.fft.fftfreq(num_det_u, device=device)
    omega = 2.0 * torch.pi * freqs
    ramp = torch.abs(omega)
    ramp_3d = ramp.reshape(1, num_det_u, 1)
    sino_fft = torch.fft.fft(sinogram_tensor, dim=1)
    filtered_fft = sino_fft * ramp_3d
    filtered = torch.real(torch.fft.ifft(filtered_fft, dim=1))

    return filtered


def main():
    Nx, Ny, Nz = 128, 128, 128
    phantom_cpu = shepp_logan_3d((Nz, Ny, Nx))

    num_views = 360
    angles_np = np.linspace(0, 2*math.pi, num_views, endpoint=False).astype(np.float32)

    det_u, det_v = 256, 256
    du, dv = 1.0, 1.0
    sdd = 900.0
    sid = 600.0

    voxel_spacing = 1.0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Create a contiguous tensor first, then mark it as a leaf that requires grad,
    # so that phantom_torch.grad is populated by loss.backward()
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32).contiguous().requires_grad_(True)
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)

    # Forward projection: sinogram of shape (views, det_u, det_v)
    sinogram = ConeProjectorFunction.apply(
        phantom_torch, angles_torch, det_u, det_v, du, dv, sdd, sid, voxel_spacing
    )

    # --- FDK weighting and filtering ---
    # For FDK, projections must be weighted before filtering.
    # Weight = D / sqrt(D^2 + u^2 + v^2), where D is the source-to-detector
    # distance (sdd) and (u, v) are the physical detector coordinates.
    u_coords = (torch.arange(det_u, dtype=phantom_torch.dtype, device=device) - (det_u - 1) / 2) * du
    v_coords = (torch.arange(det_v, dtype=phantom_torch.dtype, device=device) - (det_v - 1) / 2) * dv

    # Reshape for broadcasting over sinogram of shape (views, u, v)
    u_coords = u_coords.view(1, det_u, 1)
    v_coords = v_coords.view(1, 1, det_v)

    weights = sdd / torch.sqrt(sdd**2 + u_coords**2 + v_coords**2)

    # Apply weights and then filter
    sino_weighted = sinogram * weights
    sinogram_filt = ramp_filter_3d(sino_weighted).contiguous()

    backprojection = ConeBackprojectorFunction.apply(
        sinogram_filt, angles_torch, Nz, Ny, Nx, du, dv, sdd, sid, voxel_spacing
    )
    # ReLU to ensure non-negativity
    reconstruction = F.relu(backprojection)

    # --- FDK normalization ---
    # The backprojection is a sum over all angles. To approximate the integral,
    # we need to multiply by the angular step d_beta.
    # The FDK formula also includes a factor of 1/2 when integrating over [0, 2*pi].
    # d_beta = 2 * pi / num_views
    # Normalization factor = (1/2) * d_beta = pi / num_views
    reconstruction = reconstruction * (math.pi / num_views)

    # Toy loss showing that gradients flow back to the input volume
    loss = torch.mean((reconstruction - phantom_torch)**2)
    loss.backward()

    print("Cone Beam Example with user-defined geometry:")
    print("Loss:", loss.item())
    print("Volume center voxel gradient:", phantom_torch.grad[Nz//2, Ny//2, Nx//2].item())
    print("Reconstruction shape:", reconstruction.shape)

    reconstruction_cpu = reconstruction.detach().cpu().numpy()
    sinogram_cpu = sinogram.detach().cpu().numpy()
    mid_slice = Nz // 2

    plt.figure(figsize=(12,4))
    plt.subplot(1,3,1)
    plt.imshow(phantom_cpu[mid_slice, :,:], cmap='gray')
    plt.title("Phantom mid-slice")
    plt.axis('off')
    plt.subplot(1,3,2)
    plt.imshow(sinogram_cpu[num_views//2].T, cmap='gray', origin='lower')  # Transpose (u, v) -> (v, u) so v runs vertically
    plt.title("Projection at mid-view")
    plt.axis('off')
    plt.subplot(1,3,3)
    plt.imshow(reconstruction_cpu[mid_slice, :, :], cmap='gray')
    plt.title("Recon mid-slice")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    # Print the value ranges of the phantom and the reconstruction
    print("Phantom data range:", phantom_cpu.min(), phantom_cpu.max())
    print("Reco data range:", reconstruction_cpu.min(), reconstruction_cpu.max())

if __name__ == "__main__":
    main()
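ramp_filter_3d multiplies the unpadded FFT of each row by \(|\omega|\), which amounts to a circular convolution: with 256 detector columns, a spike in the first column also produces a strong negative response in the last one. A common remedy is to zero-pad each row to at least twice its length before filtering and to crop afterwards. The NumPy sketch below shows that variant for the same (views, u, v) layout; the name ramp_filter_padded is hypothetical and it is not part of diffct. In PyTorch, torch.fft.fft accepts the same n argument for zero-padding, so the change carries over directly.

```python
import numpy as np

def ramp_filter_padded(sino):
    """Row-wise ramp filter with zero-padding along u; sino has shape (views, u, v)."""
    num_det_u = sino.shape[1]
    # Next power of two that is at least 2 * num_det_u, to avoid wrap-around
    n_pad = 1 << int(np.ceil(np.log2(2 * num_det_u)))
    omega = 2.0 * np.pi * np.fft.fftfreq(n_pad)
    ramp = np.abs(omega).reshape(1, n_pad, 1)
    spectrum = np.fft.fft(sino, n=n_pad, axis=1)     # n > row length => zero-padding
    filtered = np.real(np.fft.ifft(spectrum * ramp, axis=1))
    return filtered[:, :num_det_u, :]                # crop back to the detector size

sino = np.random.default_rng(0).random((4, 256, 8))
out = ramp_filter_padded(sino)
```

The ramp is still sampled directly in the frequency domain, as in ramp_filter_3d; only the wrap-around between the two ends of a detector row is suppressed.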