Cone Beam FDK Reconstruction
This example demonstrates 3D cone beam FDK (Feldkamp-Davis-Kress) reconstruction using the ConeProjectorFunction and ConeBackprojectorFunction from diffct.
Overview
The FDK algorithm is the standard analytical method for 3D cone beam CT reconstruction. This example shows how to:
Configure 3D cone beam geometry with 2D detector array
Generate cone beam projections from a 3D phantom
Apply cosine weighting and ramp filtering for FDK reconstruction
Perform 3D backprojection to reconstruct the volume
Mathematical Background
Cone Beam Geometry
Cone beam CT extends fan-beam CT to 3D by pairing a point X-ray source with a 2D detector array. Key parameters:
SDD \(D_s\): Source-to-Detector Distance (distance from X-ray source to detector plane)
SID \(D_{sid}\): Source-to-Isocenter Distance (distance from X-ray source to rotation center)
Detector coordinates \((u, v)\): Horizontal and vertical detector positions
Cone angles \((\alpha, \beta)\): Horizontal and vertical beam divergence
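To make these parameters concrete, the sketch below derives a few quantities from them: the magnification \(D_s / D_{sid}\), the detector pixel size scaled back to the isocenter, and the horizontal and vertical half-angles \((\alpha, \beta)\) subtended by the outermost pixel centers. The helper names `detector_coords` and `geometry_summary` are hypothetical; the pixel-centering convention and the numbers (SDD 900, SID 600, a 256×256 detector with unit spacing) are taken from the example script further down.

```python
import numpy as np


def detector_coords(n, spacing):
    """Centered detector coordinates: pixel i sits at (i - (n - 1) / 2) * spacing."""
    return (np.arange(n) - (n - 1) / 2) * spacing


def geometry_summary(sdd, sid, det_u, det_v, du, dv):
    """Hypothetical helper: magnification, isocenter pixel size, half-angles."""
    mag = sdd / sid                       # object at isocenter -> detector
    iso_px = (du / mag, dv / mag)         # detector pixel size at the isocenter
    u = detector_coords(det_u, du)
    v = detector_coords(det_v, dv)
    # Half-angles (degrees) subtended by the outermost pixel centers.
    alpha = np.degrees(np.arctan(np.abs(u).max() / sdd))  # horizontal (fan)
    beta = np.degrees(np.arctan(np.abs(v).max() / sdd))   # vertical (cone)
    return mag, iso_px, alpha, beta


mag, iso_px, alpha, beta = geometry_summary(900.0, 600.0, 256, 256, 1.0, 1.0)
print(f"magnification {mag:.2f}, isocenter pixel {iso_px[0]:.4f} x {iso_px[1]:.4f}")
print(f"half-angles: alpha {alpha:.2f} deg, beta {beta:.2f} deg")
```

For this geometry the vertical half-angle is about 8 degrees, which is the regime where FDK's small-cone-angle approximation is usually considered acceptable.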
3D Forward Projection
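The cone beam projection at source angle \(\phi\) and detector position \((u, v)\) is the line integral of the attenuation \(f\) along the ray from the source to that detector pixel:
\[p(\phi, u, v) = \int_0^{\infty} f\big(\vec{r}_s(\phi) + t\,\vec{d}(\phi, u, v)\big)\,dt\]
where \(\vec{r}_s(\phi)\) is the source position and \(\vec{d}(\phi, u, v)\) is the unit ray direction vector.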
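The line integral can be approximated by marching along the ray and summing sampled voxel values. The NumPy sketch below does this for a single detector pixel with nearest-voxel lookup; it is an illustration, not how diffct's `ConeProjectorFunction` is implemented. The source/detector orientation is an assumption: the source sits at \(-D_{sid}(\cos\phi, \sin\phi, 0)\), so the source-to-voxel distance along the central ray is \(D_{sid} + x\cos\phi + y\sin\phi\), the same sign convention as the backprojection formula. The function names are hypothetical.

```python
import numpy as np


def source_and_direction(phi, u, v, sdd, sid):
    """Source position and unit ray direction for detector pixel (u, v).

    Assumed convention: source at -sid * (cos phi, sin phi, 0), central ray
    along (cos phi, sin phi, 0), detector plane a distance sdd from the
    source with u along (-sin phi, cos phi, 0) and v along z.
    """
    e = np.array([np.cos(phi), np.sin(phi), 0.0])       # central ray
    e_u = np.array([-np.sin(phi), np.cos(phi), 0.0])    # detector u axis
    e_v = np.array([0.0, 0.0, 1.0])                     # detector v axis
    r_s = -sid * e
    d = sdd * e + u * e_u + v * e_v                     # source -> pixel
    return r_s, d / np.linalg.norm(d)


def ray_integral(vol, phi, u, v, sdd, sid, spacing=1.0, step=0.25):
    """Approximate p(phi, u, v) with nearest-voxel sampling along the ray.

    vol is indexed [z, y, x] and centered on the isocenter, like the
    phantom in the example script.
    """
    r_s, d = source_and_direction(phi, u, v, sdd, sid)
    nz, ny, nx = vol.shape
    dims_xyz = np.array([nx, ny, nz])
    # Every voxel lies within radius R of the isocenter, so only the
    # segment t in [sid - R, sid + R] can intersect the volume.
    R = 0.5 * spacing * np.linalg.norm(dims_xyz)
    t = np.arange(max(sid - R, 0.0), sid + R, step) + 0.5 * step  # midpoints
    pts = r_s[None, :] + t[:, None] * d[None, :]
    # Continuous voxel index (voxel centers at integers), rounded to nearest.
    idx = np.rint(pts / spacing + (dims_xyz - 1) / 2).astype(int)
    inside = np.all((idx >= 0) & (idx < dims_xyz), axis=1)
    ix, iy, iz = idx[inside].T
    return vol[iz, iy, ix].sum() * step


cube = np.ones((64, 64, 64))
p_center = ray_integral(cube, 0.0, 0.0, 0.0, sdd=900.0, sid=600.0)
print(p_center)  # close to the cube width, 64
```

The central ray through a uniform 64-voxel cube should integrate to roughly its width; the nearest-voxel rounding makes the result accurate only to about one step.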
FDK Algorithm
The Feldkamp-Davis-Kress algorithm performs approximate 3D reconstruction in three steps:
Cosine Weighting: Scale each detector value by the cosine of the angle between its ray and the central ray, which accounts for the obliquity of off-center rays in the divergent beam:
\[p_w(\phi, u, v) = p(\phi, u, v) \cdot \frac{D_s}{\sqrt{D_s^2 + u^2 + v^2}}\]
Row-wise Ramp Filtering: Apply a 1D ramp filter along each detector row (the u-direction):
\[p_f(\phi, u, v) = \mathcal{F}_u^{-1}\{|\omega_u| \cdot \mathcal{F}_u\{p_w(\phi, u, v)\}\}\]
Each detector row is filtered independently.
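The two preprocessing steps can be sketched in NumPy on an array shaped `(views, det_u, det_v)`, the layout the example script uses. Unlike the script's `ramp_filter_3d`, this version optionally zero-pads each row to a power of two at least twice its length, which reduces (but does not eliminate) the wrap-around error of FFT-based circular convolution. The function names are hypothetical, and the ramp is expressed in cycles per sample, as in the script, with no physical scaling by `du`.

```python
import numpy as np


def cosine_weight(proj, du, dv, sdd):
    """Multiply each projection by D_s / sqrt(D_s^2 + u^2 + v^2)."""
    _, n_u, n_v = proj.shape
    u = (np.arange(n_u) - (n_u - 1) / 2) * du
    v = (np.arange(n_v) - (n_v - 1) / 2) * dv
    w = sdd / np.sqrt(sdd**2 + u[:, None] ** 2 + v[None, :] ** 2)
    return proj * w[None, :, :]


def ramp_filter_rows(proj, pad=True):
    """Apply |omega| along the u axis (axis=1), one detector row at a time."""
    n_u = proj.shape[1]
    n_fft = int(2 ** np.ceil(np.log2(2 * n_u))) if pad else n_u
    omega = 2 * np.pi * np.fft.fftfreq(n_fft)
    # np.fft.fft with n > n_u zero-pads the end of each row.
    spec = np.fft.fft(proj, n=n_fft, axis=1) * np.abs(omega)[None, :, None]
    return np.real(np.fft.ifft(spec, axis=1))[:, :n_u, :]


proj = np.random.default_rng(0).normal(size=(4, 16, 8))
filtered = ramp_filter_rows(cosine_weight(proj, 1.0, 1.0, 900.0))
print(filtered.shape)
```

Because the ramp is zero at \(\omega_u = 0\), an unpadded constant row filters to zero, a quick check that the DC term is removed.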
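3D Cone Beam Backprojection: Reconstruct the volume by distance-weighted backprojection over a full \(2\pi\) orbit:
\[f(x,y,z) = \frac{1}{2}\int_0^{2\pi} \frac{D_{sid}^2}{U(x,y,\phi)^2}\, p_f(\phi, u_{xyz}, v_{xyz})\, d\phi, \qquad U(x,y,\phi) = D_{sid} + x\cos\phi + y\sin\phi\]
where \(U\) is the distance from the source to the voxel measured along the central ray, and the factor \(\frac{1}{2}\) compensates for the redundancy of a full orbit, in which each line through the central plane is measured twice. The detector coordinates \((u_{xyz}, v_{xyz})\) of voxel \((x,y,z)\) are:
\[u_{xyz} = D_s \frac{-x\sin\phi + y\cos\phi}{D_{sid} + x\cos\phi + y\sin\phi}\]
\[v_{xyz} = D_s \frac{z}{D_{sid} + x\cos\phi + y\sin\phi}\]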
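One view's contribution to the backprojection can be sketched voxel by voxel: map each voxel to its detector position, look up the filtered value, and multiply by \(D_{sid}^2 / U^2\). The sketch below assumes the convention \(U = D_{sid} + x\cos\phi + y\sin\phi\), uses \(D_s\) (SDD) as the magnification numerator, and does nearest-pixel lookup. It is not diffct's `ConeBackprojectorFunction`; `project_voxel` and `backproject_view` are hypothetical names. Summing `backproject_view` over all views and multiplying by \(\pi / N_{\text{angles}}\) approximates the integral.

```python
import numpy as np


def project_voxel(x, y, z, phi, sdd, sid):
    """Detector coordinates (u, v) and FDK distance weight for voxel (x, y, z)."""
    U = sid + x * np.cos(phi) + y * np.sin(phi)   # source-to-voxel, central ray
    u = sdd * (-x * np.sin(phi) + y * np.cos(phi)) / U
    v = sdd * z / U
    weight = sid**2 / U**2
    return u, v, weight


def backproject_view(proj, phi, xs, ys, zs, du, dv, sdd, sid):
    """weight * p_f(u, v) on a grid indexed [z, y, x]; proj has shape (det_u, det_v)."""
    z, y, x = np.meshgrid(zs, ys, xs, indexing="ij")
    u, v, w = project_voxel(x, y, z, phi, sdd, sid)
    n_u, n_v = proj.shape
    # Nearest detector pixel, using the centered-pixel convention of the script.
    iu = np.rint(u / du + (n_u - 1) / 2).astype(int)
    iv = np.rint(v / dv + (n_v - 1) / 2).astype(int)
    ok = (iu >= 0) & (iu < n_u) & (iv >= 0) & (iv < n_v)
    out = np.zeros(x.shape)
    out[ok] = w[ok] * proj[iu[ok], iv[ok]]        # zero where the ray misses
    return out


print(project_voxel(0.0, 10.0, 20.0, 0.0, 900.0, 600.0))
```

For a voxel 10 units off-axis in y at \(\phi = 0\), the magnification \(900/600 = 1.5\) places it at \(u = 15\); its height of 20 maps to \(v = 30\).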
Implementation Steps
3D Phantom Generation: Create a 3D Shepp-Logan phantom from 10 ellipsoids
Cone Beam Projection: Generate 2D projections using ConeProjectorFunction
Cosine Weighting: Apply the distance-dependent weights \(D_s / \sqrt{D_s^2 + u^2 + v^2}\)
Row-wise Filtering: Apply the ramp filter to each detector row
3D Backprojection: Reconstruct the volume using ConeBackprojectorFunction
Normalization: Scale by \(\frac{\pi}{N_{\text{angles}}} = \frac{1}{2}\Delta\phi\), where \(\Delta\phi = \frac{2\pi}{N_{\text{angles}}}\) is the angular step
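The normalization step replaces the integral over \(\phi\) with a Riemann sum over equally spaced views, so the discrete backprojection sum is multiplied by \(\Delta\phi\), and the FDK factor \(\frac{1}{2}\) turns that into \(\pi / N_{\text{angles}}\). The short check below (the helper name `fdk_angular_scale` is hypothetical) applies the same factor to \(\cos^2\phi\), whose half-integral over \([0, 2\pi]\) is exactly \(\pi/2\).

```python
import math
import numpy as np


def fdk_angular_scale(num_views):
    """(1/2) * dphi for num_views equally spaced angles over [0, 2*pi)."""
    dphi = 2.0 * math.pi / num_views
    return 0.5 * dphi


num_views = 360
phis = np.linspace(0, 2 * math.pi, num_views, endpoint=False)
scale = fdk_angular_scale(num_views)              # equals pi / num_views
approx = scale * np.sum(np.cos(phis) ** 2)        # should match pi / 2
print(scale, approx)
```

The angles are generated with `endpoint=False`, as in the example script, so \(\phi = 2\pi\) is not counted twice.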
3D Shepp-Logan Phantom
The 3D phantom extends the 2D version with 10 ellipsoids representing anatomical structures:
Outer skull: Large ellipsoid encompassing the head
Brain tissue: Medium ellipsoids for different brain regions
Ventricles: Small ellipsoids representing fluid-filled cavities
Lesions: High-contrast features for reconstruction assessment
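Each ellipsoid is defined by its center \((x_0, y_0, z_0)\), semi-axes \((a, b, c)\), three rotation angles, and an additive intensity value. Overlapping ellipsoids sum, and the result is clipped to \([0, 1]\). The example code applies only the rotation about the z axis.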
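The per-ellipsoid membership test used by `shepp_logan_3d` can be isolated as a small function: shift to the center, rotate the in-plane coordinates by the z-rotation angle, and compare the normalized quadratic form with 1. The name `inside_ellipsoid` is hypothetical; it mirrors the vectorized code in the example and works on scalars or NumPy arrays.

```python
import numpy as np


def inside_ellipsoid(x, y, z, center, axes, phi):
    """True where (x, y, z) lies inside an ellipsoid rotated by phi about z."""
    x0, y0, z0 = center
    a, b, c = axes
    xc, yc, zc = x - x0, y - y0, z - z0
    cos_p, sin_p = np.cos(phi), np.sin(phi)
    # Same rotation as shepp_logan_3d: only the x-y plane is rotated.
    xp = cos_p * xc - sin_p * yc
    yp = sin_p * xc + cos_p * yc
    return (xp / a) ** 2 + (yp / b) ** 2 + (zc / c) ** 2 <= 1.0


# Outer ellipsoid of the example phantom: semi-axes (0.69, 0.92, 0.81).
print(inside_ellipsoid(0.68, 0.0, 0.0, (0, 0, 0), (0.69, 0.92, 0.81), 0.0))
print(inside_ellipsoid(0.70, 0.0, 0.0, (0, 0, 0), (0.69, 0.92, 0.81), 0.0))
```

The phantom value at a point is then the sum of the intensity values of every ellipsoid that contains it, clipped to \([0, 1]\).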
FDK Approximations and Limitations
The FDK algorithm makes several approximations:
Circular orbit: Assumes circular source trajectory
Row-wise filtering: Ramp filtering only along detector rows
Small cone angle: Most accurate for limited cone angles
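A circular orbit does not satisfy Tuy's data sufficiency condition, so FDK is exact only in the central plane (z = 0) and for objects that do not vary along z; elsewhere it produces cone beam artifacts that grow with the cone angle. FDK nevertheless remains widely used because it needs only one filtering pass and one backprojection.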
```python
import math
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
from diffct.differentiable import ConeProjectorFunction, ConeBackprojectorFunction


def shepp_logan_3d(shape):
    zz, yy, xx = np.mgrid[:shape[0], :shape[1], :shape[2]]
    # Normalize voxel indices to [-1, 1] along each axis
    xx = (xx - (shape[2] - 1) / 2) / ((shape[2] - 1) / 2)
    yy = (yy - (shape[1] - 1) / 2) / ((shape[1] - 1) / 2)
    zz = (zz - (shape[0] - 1) / 2) / ((shape[0] - 1) / 2)
    # Columns: x0, y0, z0, a, b, c, three rotation angles, intensity value
    el_params = np.array([
        [0, 0, 0, 0.69, 0.92, 0.81, 0, 0, 0, 1],
        [0, -0.0184, 0, 0.6624, 0.874, 0.78, 0, 0, 0, -0.8],
        [0.22, 0, 0, 0.11, 0.31, 0.22, -np.pi/10.0, 0, 0, -0.2],
        [-0.22, 0, 0, 0.16, 0.41, 0.28, np.pi/10.0, 0, 0, -0.2],
        [0, 0.35, -0.15, 0.21, 0.25, 0.41, 0, 0, 0, 0.1],
        [0, 0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [0, -0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [-0.08, -0.605, 0, 0.046, 0.023, 0.05, 0, 0, 0, 0.1],
        [0, -0.605, 0, 0.023, 0.023, 0.02, 0, 0, 0, 0.1],
        [0.06, -0.605, 0, 0.023, 0.046, 0.02, 0, 0, 0, 0.1],
    ], dtype=np.float32)

    # Extract parameters for vectorization
    x_pos = el_params[:, 0][:, None, None, None]
    y_pos = el_params[:, 1][:, None, None, None]
    z_pos = el_params[:, 2][:, None, None, None]
    a_axis = el_params[:, 3][:, None, None, None]
    b_axis = el_params[:, 4][:, None, None, None]
    c_axis = el_params[:, 5][:, None, None, None]
    phi = el_params[:, 6][:, None, None, None]
    val = el_params[:, 9][:, None, None, None]

    # Broadcast grid to ellipsoid axis
    xc = xx[None, ...] - x_pos
    yc = yy[None, ...] - y_pos
    zc = zz[None, ...] - z_pos

    c = np.cos(phi)
    s = np.sin(phi)

    # Only the rotation about z (column 6) is applied, so it vectorizes directly
    xp = c * xc - s * yc
    yp = s * xc + c * yc
    zp = zc

    mask = (
        (xp ** 2) / (a_axis ** 2)
        + (yp ** 2) / (b_axis ** 2)
        + (zp ** 2) / (c_axis ** 2)
        <= 1.0
    )

    # Use broadcasting to sum all ellipsoid contributions
    shepp_logan = np.sum(mask * val, axis=0)
    shepp_logan = np.clip(shepp_logan, 0, 1)
    return shepp_logan


def ramp_filter_3d(sinogram_tensor):
    # Ramp filter |omega| applied along the detector u axis (dim=1)
    device = sinogram_tensor.device
    num_views, num_det_u, num_det_v = sinogram_tensor.shape
    freqs = torch.fft.fftfreq(num_det_u, device=device)
    omega = 2.0 * torch.pi * freqs
    ramp = torch.abs(omega)
    ramp_3d = ramp.reshape(1, num_det_u, 1)
    sino_fft = torch.fft.fft(sinogram_tensor, dim=1)
    filtered_fft = sino_fft * ramp_3d
    filtered = torch.real(torch.fft.ifft(filtered_fft, dim=1))

    return filtered


def main():
    Nx, Ny, Nz = 128, 128, 128
    phantom_cpu = shepp_logan_3d((Nz, Ny, Nx))

    num_views = 360
    angles_np = np.linspace(0, 2*math.pi, num_views, endpoint=False).astype(np.float32)

    det_u, det_v = 256, 256
    du, dv = 1.0, 1.0
    sdd = 900.0
    sid = 600.0

    voxel_spacing = 1.0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32, requires_grad=True).contiguous()
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)

    sinogram = ConeProjectorFunction.apply(phantom_torch, angles_torch,
                                           det_u, det_v, du, dv,
                                           sdd, sid, voxel_spacing)

    # --- FDK weighting and filtering ---
    # For FDK, projections must be weighted before filtering.
    # Weight = D / sqrt(D^2 + u^2 + v^2), where D is the source-to-detector
    # distance (sdd) and (u, v) are physical detector coordinates.
    u_coords = (torch.arange(det_u, dtype=phantom_torch.dtype, device=device) - (det_u - 1) / 2) * du
    v_coords = (torch.arange(det_v, dtype=phantom_torch.dtype, device=device) - (det_v - 1) / 2) * dv

    # Reshape for broadcasting over sinogram of shape (views, u, v)
    u_coords = u_coords.view(1, det_u, 1)
    v_coords = v_coords.view(1, 1, det_v)

    weights = sdd / torch.sqrt(sdd**2 + u_coords**2 + v_coords**2)

    # Apply weights and then filter
    sino_weighted = sinogram * weights
    sinogram_filt = ramp_filter_3d(sino_weighted).contiguous()

    reconstruction = F.relu(ConeBackprojectorFunction.apply(sinogram_filt, angles_torch, Nz, Ny, Nx,
                                                             du, dv, sdd, sid, voxel_spacing))  # ReLU to ensure non-negativity

    # --- FDK normalization ---
    # The backprojection is a sum over all angles. To approximate the integral,
    # we multiply by the angular step d_phi.
    # The FDK formula also includes a factor of 1/2 when integrating over [0, 2*pi].
    # d_phi = 2 * pi / num_views
    # Normalization factor = (1/2) * d_phi = pi / num_views
    reconstruction = reconstruction * (math.pi / num_views)

    loss = torch.mean((reconstruction - phantom_torch)**2)
    loss.backward()

    print("Cone Beam Example with user-defined geometry:")
    print("Loss:", loss.item())
    print("Volume center voxel gradient:", phantom_torch.grad[Nz//2, Ny//2, Nx//2].item())
    print("Reconstruction shape:", reconstruction.shape)

    reconstruction_cpu = reconstruction.detach().cpu().numpy()
    sinogram_cpu = sinogram.detach().cpu().numpy()
    mid_slice = Nz // 2

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(phantom_cpu[mid_slice, :, :], cmap='gray')
    plt.title("Phantom mid-slice")
    plt.axis('off')
    plt.subplot(1, 3, 2)
    plt.imshow(sinogram_cpu[num_views//2].T, cmap='gray', origin='lower')  # Transpose to (v, u) for display
    plt.title("Projection at mid view")
    plt.axis('off')
    plt.subplot(1, 3, 3)
    plt.imshow(reconstruction_cpu[mid_slice, :, :], cmap='gray')
    plt.title("Recon mid-slice")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    # Print the data range of the phantom and the reconstruction
    print("Phantom data range:", phantom_cpu.min(), phantom_cpu.max())
    print("Reco data range:", reconstruction_cpu.min(), reconstruction_cpu.max())


if __name__ == "__main__":
    main()
```