Cone Beam FDK Reconstruction
This example demonstrates 3D cone beam FDK (Feldkamp-Davis-Kress) reconstruction using the ConeProjectorFunction and ConeBackprojectorFunction from diffct.
Overview
The FDK algorithm is the standard analytical method for 3D cone beam CT reconstruction. This example shows how to:
Configure 3D cone beam geometry with 2D detector array
Generate cone beam projections from a 3D phantom
Apply cosine weighting and ramp filtering for FDK reconstruction
Perform 3D backprojection to reconstruct the volume
Mathematical Background
Cone Beam Geometry
Cone beam CT extends fan beam geometry to 3D using a point X-ray source and a 2D detector array. Key parameters:
SDD \(D_s\): Source-to-Detector Distance (distance from X-ray source to detector plane)
SID \(D_{sid}\): Source-to-Isocenter Distance (distance from X-ray source to rotation center)
Detector coordinates \((u, v)\): Horizontal and vertical detector positions
Cone angles \((\alpha, \beta)\): Horizontal and vertical beam divergence
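To make the relationship between these parameters concrete, the sketch below derives the magnification and the half fan and cone angles for a flat detector centered on the central ray. The helper `cone_geometry_summary` is hypothetical and not part of diffct; the values passed in are the ones used in the example script (SDD 900, SID 600, a 256 x 256 detector with unit pixel pitch).

```python
import math


def cone_geometry_summary(sdd, sid, det_u, det_v, du, dv):
    """Derive a few quantities from the geometry parameters.

    Hypothetical helper for illustration only. Assumes a flat detector
    centered on the central ray (no detector offsets).
    """
    half_width = 0.5 * det_u * du    # half detector extent along u
    half_height = 0.5 * det_v * dv   # half detector extent along v
    half_fan = math.atan2(half_width, sdd)    # horizontal half angle
    half_cone = math.atan2(half_height, sdd)  # vertical half angle
    return {
        # Magnification of an object at the isocenter onto the detector
        "magnification": sdd / sid,
        "half_fan_deg": math.degrees(half_fan),
        "half_cone_deg": math.degrees(half_cone),
        # Radius of the cylinder around the rotation axis seen by every view
        "fov_radius": sid * math.sin(half_fan),
    }


summary = cone_geometry_summary(sdd=900.0, sid=600.0,
                                det_u=256, det_v=256, du=1.0, dv=1.0)
for name, value in summary.items():
    print(f"{name}: {value:.3f}")
```

A small half cone angle from this kind of estimate is a quick indication of how well the FDK approximation is likely to hold for a given setup.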
3D Forward Projection
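The cone beam projection at source angle \(\phi\) and detector position \((u, v)\) is the line integral of the attenuation \(f\) along the ray from the source to that detector position:
\[p(\phi, u, v) = \int_0^{\infty} f\left(\vec{r}_s(\phi) + t\,\vec{d}(\phi, u, v)\right) dt\]
where \(\vec{r}_s(\phi)\) is the source position and \(\vec{d}(\phi, u, v)\) is the ray direction vector.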
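As a rough illustration of this line integral, the NumPy sketch below builds a source position and a unit ray direction for a circular orbit and approximates the integral with a Riemann sum. The coordinate convention (source at \(-D_{sid}(\cos\phi, \sin\phi, 0)\), detector u axis along \((-\sin\phi, \cos\phi, 0)\), v axis along z) is an assumption chosen for this sketch, and all function names are hypothetical. It is not how ConeProjectorFunction is implemented.

```python
import numpy as np


def source_position(phi, sid):
    """Source on a circle of radius sid in the z = 0 plane (assumed convention)."""
    return -sid * np.array([np.cos(phi), np.sin(phi), 0.0])


def ray_direction(phi, u, v, sdd):
    """Unit vector from the source towards detector position (u, v)."""
    e_c = np.array([np.cos(phi), np.sin(phi), 0.0])   # central ray
    e_u = np.array([-np.sin(phi), np.cos(phi), 0.0])  # detector u axis
    e_v = np.array([0.0, 0.0, 1.0])                   # detector v axis
    d = sdd * e_c + u * e_u + v * e_v
    return d / np.linalg.norm(d)


def line_integral(f, phi, u, v, sdd, sid, t_max, n_samples=20001):
    """Approximate the integral of f along one ray with a Riemann sum."""
    t = np.linspace(0.0, t_max, n_samples)
    points = source_position(phi, sid) + t[:, None] * ray_direction(phi, u, v, sdd)
    return np.sum(f(points)) * (t[1] - t[0])


def sphere(points, radius=50.0):
    """Unit-density sphere centered at the isocenter."""
    return (np.sum(points**2, axis=1) <= radius**2).astype(np.float64)


# The central ray crosses the full diameter of the sphere at every angle.
central = line_integral(sphere, 0.0, 0.0, 0.0, sdd=900.0, sid=600.0, t_max=900.0)
rotated = line_integral(sphere, 1.0, 0.0, 0.0, sdd=900.0, sid=600.0, t_max=900.0)
print(central, rotated)
```

Both values should be close to the sphere diameter, up to the sampling step of the Riemann sum.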
FDK Algorithm
The Feldkamp-Davis-Kress algorithm performs approximate 3D reconstruction in three steps:
Cosine Weighting: Compensate for ray divergence and \(1/r^2\) intensity falloff:
\[p_w(\phi, u, v) = p(\phi, u, v) \cdot \frac{D_s}{\sqrt{D_s^2 + u^2 + v^2}}\]
Row-wise Ramp Filtering: Apply a 1D ramp filter along detector rows (u-direction):
\[p_f(\phi, u, v) = \mathcal{F}_u^{-1}\{|\omega_u| \cdot \mathcal{F}_u\{p_w(\phi, u, v)\}\}\]
Each detector row is filtered independently.
3D Cone Beam Backprojection: Reconstruct the volume using weighted backprojection:
\[f(x,y,z) = \int_0^{2\pi} \frac{D_s^2}{(D_s + x\cos\phi + y\sin\phi)^2} p_f(\phi, u_{xyz}, v_{xyz}) d\phi\]
where the detector coordinates \((u_{xyz}, v_{xyz})\) for voxel \((x,y,z)\) are:
\[u_{xyz} = D_s \frac{-x\sin\phi + y\cos\phi}{D_s + x\cos\phi + y\sin\phi}\]
\[v_{xyz} = D_s \frac{z}{D_s + x\cos\phi + y\sin\phi}\]
The backprojection expressions use a single distance \(D_s\), which is exact when detector coordinates are expressed on a virtual detector through the isocenter. For a physical detector at distance \(D_s\) with the isocenter at \(D_{sid}\), the depth term in the denominators is \(D_{sid} + x\cos\phi + y\sin\phi\).
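The first two steps and the voxel-to-detector mapping can be sketched in plain NumPy as follows. This is a minimal illustration under stated assumptions, not the diffct implementation: the function names are hypothetical, the projection is stored as `(det_u, det_v)`, the ramp is sampled as \(|\omega|\) in cycles per unit length (scaling conventions vary between implementations), and the rows are zero-padded to twice their length to limit wrap-around. `voxel_to_detector` takes the detector distance and the source-to-isocenter distance separately; passing the same value for both reproduces the formulas above as written.

```python
import numpy as np


def cosine_weights(det_u, det_v, du, dv, sdd):
    """FDK cosine weights D / sqrt(D^2 + u^2 + v^2) on a centered detector."""
    u = (np.arange(det_u) - (det_u - 1) / 2.0) * du
    v = (np.arange(det_v) - (det_v - 1) / 2.0) * dv
    uu, vv = np.meshgrid(u, v, indexing="ij")  # shape (det_u, det_v)
    return sdd / np.sqrt(sdd**2 + uu**2 + vv**2)


def ramp_filter_rows(proj, du, axis=0):
    """Multiply the spectrum along `axis` by |frequency| (ramp filter)."""
    n = proj.shape[axis]
    n_pad = 2 * n  # zero-padding reduces circular-convolution wrap-around
    ramp_shape = [1] * proj.ndim
    ramp_shape[axis] = n_pad
    ramp = np.abs(np.fft.fftfreq(n_pad, d=du)).reshape(ramp_shape)
    spectrum = np.fft.fft(proj, n=n_pad, axis=axis)
    filtered = np.real(np.fft.ifft(spectrum * ramp, axis=axis))
    return np.take(filtered, np.arange(n), axis=axis)  # crop the padding


def voxel_to_detector(x, y, z, phi, det_dist, src_dist):
    """Detector coordinates and distance weight for a voxel at angle phi."""
    depth = src_dist + x * np.cos(phi) + y * np.sin(phi)
    u = det_dist * (-x * np.sin(phi) + y * np.cos(phi)) / depth
    v = det_dist * z / depth
    weight = src_dist**2 / depth**2
    return u, v, weight


# One projection containing a single bright line at u index 32.
proj = np.zeros((64, 32))
proj[32, :] = 1.0
weighted = proj * cosine_weights(64, 32, 1.0, 1.0, sdd=900.0)
filtered = ramp_filter_rows(weighted, du=1.0, axis=0)
print(filtered[31:34, 0])  # negative lobe, positive peak, negative lobe
```

The negative side lobes around the peak are what cancel the blur that plain backprojection would otherwise introduce.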
Implementation Steps
3D Phantom Generation: Create 3D Shepp-Logan phantom with 10 ellipsoids
Cone Beam Projection: Generate 2D projections using ConeProjectorFunction
Cosine Weighting: Apply distance-dependent weights from cone_cosine_weights
Row-wise Filtering: Apply ramp_filter_1d to each detector row
3D Backprojection: Reconstruct the volume using cone_weighted_backproject
Normalization: Scale by \(\frac{\pi}{N_{\text{angles}}}\) factor
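One way to read the \(\frac{\pi}{N_{\text{angles}}}\) factor: discretizing the angular integral over \([0, 2\pi)\) gives a step of \(2\pi / N_{\text{angles}}\), and a full rotation measures each ray twice, which the standard FDK formula accounts for with a factor of \(1/2\). The sketch below computes such per-view weights. The function `full_scan_angle_weights` is a hypothetical stand-in written for this illustration and makes no claim about how diffct's `angular_integration_weights` is implemented.

```python
import numpy as np


def full_scan_angle_weights(angles):
    """Per-view integration weights for sorted angles covering [0, 2*pi).

    Hypothetical helper: angular step to the next view (wrapping the last
    view around to the first), halved for full-scan redundancy.
    """
    angles = np.asarray(angles, dtype=np.float64)
    wrapped = np.concatenate([angles, angles[:1] + 2.0 * np.pi])
    d_phi = np.diff(wrapped)
    return 0.5 * d_phi


num_views = 360
angles = np.linspace(0.0, 2.0 * np.pi, num_views, endpoint=False)
weights = full_scan_angle_weights(angles)
print(weights[0], np.pi / num_views)
```

For uniformly spaced views every weight reduces to \(\pi / N_{\text{angles}}\), and the weights sum to \(\pi\).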
3D Shepp-Logan Phantom
The 3D phantom extends the 2D version with 10 ellipsoids representing anatomical structures:
Outer skull: Large ellipsoid encompassing the head
Brain tissue: Medium ellipsoids for different brain regions
Ventricles: Small ellipsoids representing fluid-filled cavities
Lesions: High-contrast features for reconstruction assessment
Each ellipsoid is defined by center position \((x_0, y_0, z_0)\), semi-axes \((a, b, c)\), rotation angles, and attenuation coefficient.
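A single ellipsoid of this kind can be rasterized with a point-in-ellipsoid test on a normalized grid. The sketch below is a simplified, hypothetical helper (`ellipsoid_mask`) that, like the example script, supports rotation about the z axis only and uses coordinates in \([-1, 1]\) along each axis.

```python
import numpy as np


def ellipsoid_mask(n, center, semi_axes, phi_z=0.0):
    """Boolean (n, n, n) mask, indexed [z, y, x], of one rotated ellipsoid."""
    coords = np.linspace(-1.0, 1.0, n)
    z, y, x = np.meshgrid(coords, coords, coords, indexing="ij")
    # Shift so the ellipsoid center is at the origin
    xc, yc, zc = x - center[0], y - center[1], z - center[2]
    # Rotate the in-plane coordinates about the z axis
    c, s = np.cos(phi_z), np.sin(phi_z)
    xp = c * xc - s * yc
    yp = s * xc + c * yc
    a, b, c_axis = semi_axes
    return (xp / a) ** 2 + (yp / b) ** 2 + (zc / c_axis) ** 2 <= 1.0


# A sphere of radius 0.5 and an elongated ellipsoid rotated by 90 degrees.
sphere = ellipsoid_mask(65, center=(0.0, 0.0, 0.0), semi_axes=(0.5, 0.5, 0.5))
rotated = ellipsoid_mask(41, center=(0.0, 0.0, 0.0),
                         semi_axes=(0.8, 0.2, 0.2), phi_z=np.pi / 2)
print(sphere.mean(), rotated[20, 32, 20], rotated[20, 20, 32])
```

Summing `mask * value` over all ellipsoids, as the example script does, then gives the phantom volume.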
FDK Approximations and Limitations
The FDK algorithm makes several approximations:
Circular orbit: Assumes circular source trajectory
Row-wise filtering: Ramp filtering only along detector rows
Small cone angle: Most accurate for limited cone angles
These approximations introduce cone beam artifacts for large cone angles, but FDK remains widely used due to computational efficiency.
import math
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
from diffct.differentiable import (
    ConeProjectorFunction,
    angular_integration_weights,
    cone_cosine_weights,
    cone_weighted_backproject,
    ramp_filter_1d,
)


def shepp_logan_3d(shape):
    zz, yy, xx = np.mgrid[:shape[0], :shape[1], :shape[2]]
    xx = (xx - (shape[2] - 1) / 2) / ((shape[2] - 1) / 2)
    yy = (yy - (shape[1] - 1) / 2) / ((shape[1] - 1) / 2)
    zz = (zz - (shape[0] - 1) / 2) / ((shape[0] - 1) / 2)
    el_params = np.array([
        [0, 0, 0, 0.69, 0.92, 0.81, 0, 0, 0, 1],
        [0, -0.0184, 0, 0.6624, 0.874, 0.78, 0, 0, 0, -0.8],
        [0.22, 0, 0, 0.11, 0.31, 0.22, -np.pi/10.0, 0, 0, -0.2],
        [-0.22, 0, 0, 0.16, 0.41, 0.28, np.pi/10.0, 0, 0, -0.2],
        [0, 0.35, -0.15, 0.21, 0.25, 0.41, 0, 0, 0, 0.1],
        [0, 0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [0, -0.1, 0.25, 0.046, 0.046, 0.05, 0, 0, 0, 0.1],
        [-0.08, -0.605, 0, 0.046, 0.023, 0.05, 0, 0, 0, 0.1],
        [0, -0.605, 0, 0.023, 0.023, 0.02, 0, 0, 0, 0.1],
        [0.06, -0.605, 0, 0.023, 0.046, 0.02, 0, 0, 0, 0.1],
    ], dtype=np.float32)

    # Extract parameters for vectorization
    x_pos = el_params[:, 0][:, None, None, None]
    y_pos = el_params[:, 1][:, None, None, None]
    z_pos = el_params[:, 2][:, None, None, None]
    a_axis = el_params[:, 3][:, None, None, None]
    b_axis = el_params[:, 4][:, None, None, None]
    c_axis = el_params[:, 5][:, None, None, None]
    phi = el_params[:, 6][:, None, None, None]
    val = el_params[:, 9][:, None, None, None]

    # Broadcast grid to ellipsoid axis
    xc = xx[None, ...] - x_pos
    yc = yy[None, ...] - y_pos
    zc = zz[None, ...] - z_pos

    c = np.cos(phi)
    s = np.sin(phi)

    # Only rotation around z, so can vectorize:
    xp = c * xc - s * yc
    yp = s * xc + c * yc
    zp = zc

    mask = (
        (xp ** 2) / (a_axis ** 2)
        + (yp ** 2) / (b_axis ** 2)
        + (zp ** 2) / (c_axis ** 2)
        <= 1.0
    )

    # Use broadcasting to sum all ellipsoid contributions
    shepp_logan = np.sum(mask * val, axis=0)
    shepp_logan = np.clip(shepp_logan, 0, 1)
    return shepp_logan


def main():
    Nx, Ny, Nz = 128, 128, 128
    phantom_cpu = shepp_logan_3d((Nz, Ny, Nx))

    num_views = 360
    angles_np = np.linspace(0, 2*math.pi, num_views, endpoint=False).astype(np.float32)

    det_u, det_v = 256, 256
    du, dv = 1.0, 1.0
    detector_offset_u = 0.0
    detector_offset_v = 0.0
    sdd = 900.0
    sid = 600.0

    voxel_spacing = 1.0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32).contiguous()
    angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)

    sinogram = ConeProjectorFunction.apply(phantom_torch, angles_torch,
                                           det_u, det_v, du, dv,
                                           sdd, sid, voxel_spacing)

    # --- FDK weighting and filtering ---
    # 1) FDK cosine pre-weighting
    weights = cone_cosine_weights(
        det_u,
        det_v,
        du,
        dv,
        sdd,
        detector_offset_u=detector_offset_u,
        detector_offset_v=detector_offset_v,
        device=device,
        dtype=phantom_torch.dtype,
    ).unsqueeze(0)
    sino_weighted = sinogram * weights

    # 2) Ramp filter along detector-u rows
    sinogram_filt = ramp_filter_1d(sino_weighted, dim=1).contiguous()

    # 3) Angle-integration weights
    d_beta = angular_integration_weights(angles_torch, redundant_full_scan=True).view(-1, 1, 1)
    sinogram_filt = sinogram_filt * d_beta

    # 4) Weighted cone-beam backprojection
    reconstruction = F.relu(
        cone_weighted_backproject(
            sinogram_filt,
            angles_torch,
            Nz,
            Ny,
            Nx,
            du,
            dv,
            sdd,
            sid,
            voxel_spacing=voxel_spacing,
            detector_offset_u=detector_offset_u,
            detector_offset_v=detector_offset_v,
        )
    )

    loss = torch.mean((reconstruction - phantom_torch)**2)

    print("Cone Beam Example with user-defined geometry:")
    print("Loss:", loss.item())
    print("Reconstruction shape:", reconstruction.shape)

    reconstruction_cpu = reconstruction.detach().cpu().numpy()
    sinogram_cpu = sinogram.detach().cpu().numpy()
    mid_slice = Nz // 2

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(phantom_cpu[mid_slice, :, :], cmap='gray')
    plt.title("Phantom mid-slice")
    plt.axis('off')
    plt.subplot(1, 3, 2)
    plt.imshow(sinogram_cpu[num_views//2].T, cmap='gray', origin='lower')  # Transpose for correct orientation
    plt.title("Sinogram mid-view")
    plt.axis('off')
    plt.subplot(1, 3, 3)
    plt.imshow(reconstruction_cpu[mid_slice, :, :], cmap='gray')
    plt.title("Recon mid-slice")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    # print data range of the phantom and reco
    print("Phantom data range:", phantom_cpu.min(), phantom_cpu.max())
    print("Reco data range:", reconstruction_cpu.min(), reconstruction_cpu.max())


if __name__ == "__main__":
    main()