Skip to main content

qsm_core/utils/
qsmart.rs

//! QSMART (Quantitative Susceptibility Mapping Artifact Reduction Technique)
//!
//! This module provides the complete QSMART pipeline including:
//! - Two-stage QSM reconstruction (whole ROI + tissue only)
//! - Offset adjustment for combining the two stages
//! - Integration of SDF background removal with iLSQR inversion
//!
//! Reference:
//! Yaghmaie, N., Syeda, W., et al. (2021).
//! "QSMART: Quantitative Susceptibility Mapping Artifact Reduction Technique."
//! NeuroImage, 231:117701. https://doi.org/10.1016/j.neuroimage.2020.117701
//!
//! Reference implementation: https://github.com/wtsyeda/QSMART
14
15use crate::fft::{fft3d, ifft3d};
16use crate::kernels::dipole::dipole_kernel;
17use num_complex::Complex64;
18
19/// Adjust offset between two-stage QSMART susceptibility maps
20///
21/// Combines chi_1 (whole ROI) and chi_2 (tissue only) with offset adjustment
22/// to ensure consistency with the original field data.
23///
24/// # Arguments
25/// * `removed_voxels` - Mask of removed voxels (mask * R_0 - vasc_only), indicates where stage 1 but not stage 2 was applied
26/// * `lfs_sdf` - Local field shift from stage 1 SDF (in ppm, will be scaled back)
27/// * `chi_1` - Susceptibility from stage 1 (whole ROI)
28/// * `chi_2` - Susceptibility from stage 2 (tissue only)
29/// * `nx`, `ny`, `nz` - Volume dimensions
30/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm
31/// * `b0_dir` - B0 field direction (unit vector)
32/// * `ppm` - PPM conversion factor (to scale lfs_sdf back)
33///
34/// # Returns
35/// Combined and offset-adjusted susceptibility map
36pub fn adjust_offset(
37    removed_voxels: &[f64],
38    lfs_sdf: &[f64],
39    chi_1: &[f64],
40    chi_2: &[f64],
41    nx: usize, ny: usize, nz: usize,
42    vsx: f64, vsy: f64, vsz: f64,
43    b0_dir: (f64, f64, f64),
44    ppm: f64,
45) -> Vec<f64> {
46    let n_total = nx * ny * nz;
47
48    // Scale lfs_sdf back (it was multiplied by ppm in the QSMART pipeline)
49    let lfs_scaled: Vec<f64> = lfs_sdf.iter().map(|&v| v / ppm).collect();
50
51    // Clean removed_voxels (clamp negatives to 0)
52    let removed_clean: Vec<f64> = removed_voxels.iter()
53        .map(|&v| if v < 0.0 { 0.0 } else { v })
54        .collect();
55
56    // Zero out chi_1 where removed_voxels is 0
57    let chi_1_masked: Vec<f64> = chi_1.iter()
58        .zip(removed_clean.iter())
59        .map(|(&c, &r)| if r > 0.0 { c } else { 0.0 })
60        .collect();
61
62    // Combined chi = chi_1_masked + chi_2
63    let combined_chi: Vec<f64> = chi_1_masked.iter()
64        .zip(chi_2.iter())
65        .map(|(&c1, &c2)| c1 + c2)
66        .collect();
67
68    // Get dipole kernel
69    let d_kernel = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, b0_dir);
70
71    // Compute offset using Fourier-space relationship
72    // x1 = ifft(D * fft(removed_voxels))
73    // x2 = lfs_sdf - ifft(D * fft(combined_chi))
74    // offset = real(x1' * x2) / real(x1' * x1)
75
76    // Convert to complex for FFT
77    let mut removed_complex: Vec<Complex64> = removed_clean.iter()
78        .map(|&v| Complex64::new(v, 0.0))
79        .collect();
80    let mut combined_complex: Vec<Complex64> = combined_chi.iter()
81        .map(|&v| Complex64::new(v, 0.0))
82        .collect();
83
84    // FFT of removed_voxels and combined_chi
85    fft3d(&mut removed_complex, nx, ny, nz);
86    fft3d(&mut combined_complex, nx, ny, nz);
87
88    // Multiply by dipole kernel
89    let mut d_removed: Vec<Complex64> = removed_complex.iter()
90        .zip(d_kernel.iter())
91        .map(|(&c, &d)| c * d)
92        .collect();
93
94    let mut d_combined: Vec<Complex64> = combined_complex.iter()
95        .zip(d_kernel.iter())
96        .map(|(&c, &d)| c * d)
97        .collect();
98
99    // Inverse FFT
100    ifft3d(&mut d_removed, nx, ny, nz);
101    ifft3d(&mut d_combined, nx, ny, nz);
102
103    // x1 = real part of ifft(D * fft(removed))
104    let x1: Vec<f64> = d_removed.iter().map(|c| c.re).collect();
105
106    // x2 = lfs_sdf - real(ifft(D * fft(combined_chi)))
107    let x2: Vec<f64> = lfs_scaled.iter()
108        .zip(d_combined.iter())
109        .map(|(&lfs, c)| lfs - c.re)
110        .collect();
111
112    // Compute offset: o = (x1' * x2) / (x1' * x1)
113    let x1_dot_x2: f64 = x1.iter().zip(x2.iter()).map(|(&a, &b)| a * b).sum();
114    let x1_dot_x1: f64 = x1.iter().map(|&a| a * a).sum();
115
116    let offset = if x1_dot_x1.abs() > 1e-10 {
117        x1_dot_x2 / x1_dot_x1
118    } else {
119        0.0
120    };
121
122    // Adjusted combined chi = combined_chi + offset * removed_voxels
123    let adjusted: Vec<f64> = combined_chi.iter()
124        .zip(removed_clean.iter())
125        .map(|(&c, &r)| c + offset * r)
126        .collect();
127
128    adjusted
129}
130
/// Tunable parameters for the complete QSMART pipeline.
///
/// Use [`Default::default`] for the built-in defaults and override individual
/// fields with struct-update syntax as needed.
#[derive(Debug, Clone)]
pub struct QsmartParams {
    /// PPM conversion factor, computed as `(gyro * field) / 1e6`.
    pub ppm: f64,
    /// First SDF sigma used in stage 1.
    pub sdf_sigma1_stage1: f64,
    /// Second SDF sigma used in stage 1.
    pub sdf_sigma2_stage1: f64,
    /// First SDF sigma used in stage 2.
    pub sdf_sigma1_stage2: f64,
    /// Second SDF sigma used in stage 2.
    pub sdf_sigma2_stage2: f64,
    /// Spatial radius used by the SDF morphological closing.
    pub sdf_spatial_radius: i32,
    /// Lower limit applied when clamping the SDF proximity map.
    pub sdf_lower_lim: f64,
    /// Curvature constant used by SDF.
    pub sdf_curv_constant: f64,
    /// Sphere radius for the bottom-hat step of vasculature detection.
    pub vasc_sphere_radius: i32,
    /// Range of scales `[min, max]` searched by the Frangi vessel filter.
    pub frangi_scale_range: [f64; 2],
    /// Ratio (step) between successive Frangi scales.
    pub frangi_scale_ratio: f64,
    /// Frangi `C` parameter.
    pub frangi_c: f64,
    /// Convergence tolerance for iLSQR.
    pub ilsqr_tol: f64,
    /// Maximum number of iLSQR iterations.
    pub ilsqr_max_iter: usize,
    /// Direction of the B0 field.
    pub b0_dir: (f64, f64, f64),
}
163
164impl Default for QsmartParams {
165    fn default() -> Self {
166        // Default values for 7T human brain
167        Self {
168            ppm: 2.675e8 * 7.0 / 1e6, // gyro * field / 1e6
169            sdf_sigma1_stage1: 10.0,
170            sdf_sigma2_stage1: 10.0,
171            sdf_sigma1_stage2: 8.0,
172            sdf_sigma2_stage2: 2.0,
173            sdf_spatial_radius: 8,
174            sdf_lower_lim: 0.6,
175            sdf_curv_constant: 500.0,
176            vasc_sphere_radius: 8,
177            // QSMART Demo defaults: FrangiScaleRange=[0.5,6], FrangiScaleRatio=0.5
178            frangi_scale_range: [0.5, 6.0],
179            frangi_scale_ratio: 0.5,
180            frangi_c: 500.0,
181            ilsqr_tol: 0.01,
182            ilsqr_max_iter: 50,
183            b0_dir: (0.0, 0.0, 1.0),
184        }
185    }
186}
187
188impl QsmartParams {
189    /// Create parameters for specific field strength
190    pub fn for_field_strength(field_tesla: f64) -> Self {
191        let gyro = 2.675e8; // Proton gyromagnetic ratio
192        Self {
193            ppm: gyro * field_tesla / 1e6,
194            ..Default::default()
195        }
196    }
197}
198
/// Result of QSMART pipeline
///
/// Bundles the final susceptibility map together with the intermediate
/// per-stage outputs. Derives `Clone` and `Debug` (like [`QsmartParams`]) so
/// results can be duplicated and inspected in logs and test failures.
#[derive(Clone, Debug)]
pub struct QsmartResult {
    /// Final combined and offset-adjusted susceptibility map
    pub chi_qsmart: Vec<f64>,
    /// Stage 1 susceptibility (whole ROI)
    pub chi_stage1: Vec<f64>,
    /// Stage 2 susceptibility (tissue only)
    pub chi_stage2: Vec<f64>,
    /// Local field from stage 1
    pub lfs_stage1: Vec<f64>,
    /// Local field from stage 2
    pub lfs_stage2: Vec<f64>,
    /// Vasculature mask (1 = tissue, 0 = vessel)
    pub vasc_mask: Vec<f64>,
}
214
/// Builds the removed-voxels map consumed by the offset adjustment.
///
/// Each output element is `mask * r_0 - vasc_only` at the same index, i.e.
/// the voxels processed in stage 1 but not in stage 2. The output is as long
/// as the shortest of the three inputs.
pub fn compute_removed_voxels(
    mask: &[f64],
    r_0: &[f64],
    vasc_only: &[f64],
) -> Vec<f64> {
    let len = mask.len().min(r_0.len()).min(vasc_only.len());
    let mut removed = Vec::with_capacity(len);

    let stage1_weights = mask.iter().zip(r_0);
    for ((m, r), v) in stage1_weights.zip(vasc_only) {
        removed.push(m * r - v);
    }

    removed
}
230
/// Compute the weighted mask passed to iLSQR in stage 1
///
/// Each output element is `mask * r_0` at the same index. (The stage 2
/// counterpart additionally multiplies by `vasc_only`.) The output is as long
/// as the shorter of the two inputs.
pub fn compute_weighted_mask_stage1(mask: &[f64], r_0: &[f64]) -> Vec<f64> {
    let mut weighted = Vec::with_capacity(mask.len().min(r_0.len()));
    for (m, r) in mask.iter().zip(r_0) {
        weighted.push(m * r);
    }
    weighted
}
237}
238
/// Compute the weighted mask passed to iLSQR in stage 2
///
/// Each output element is `mask * vasc_only * r_0` at the same index. The
/// output is as long as the shortest of the three inputs.
pub fn compute_weighted_mask_stage2(mask: &[f64], r_0: &[f64], vasc_only: &[f64]) -> Vec<f64> {
    let len = mask.len().min(r_0.len()).min(vasc_only.len());
    let mut weighted = Vec::with_capacity(len);

    for ((m, r), v) in mask.iter().zip(r_0).zip(vasc_only) {
        // Multiply in the order mask, vasc_only, r_0.
        let tissue = m * v;
        weighted.push(tissue * r);
    }

    weighted
}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the offset adjustment runs on a small cube and returns one
    /// finite value per voxel.
    #[test]
    fn test_offset_adjustment_basic() {
        let n = 8;
        let n_total = n * n * n;

        let removed = vec![1.0f64; n_total];
        let lfs = vec![0.0f64; n_total];
        let chi_1 = vec![0.1f64; n_total];
        let chi_2 = vec![0.1f64; n_total];

        let result = adjust_offset(
            &removed, &lfs, &chi_1, &chi_2,
            n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1.0
        );

        assert_eq!(result.len(), n_total);
        assert!(result.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_params_for_field_strength() {
        let params_7t = QsmartParams::for_field_strength(7.0);
        let params_3t = QsmartParams::for_field_strength(3.0);

        // 7T should have higher PPM than 3T
        assert!(params_7t.ppm > params_3t.ppm);

        // Only `ppm` depends on the field strength; the rest comes from Default
        let defaults = QsmartParams::default();
        assert_eq!(params_7t.ppm, defaults.ppm);
        assert_eq!(params_3t.ilsqr_max_iter, defaults.ilsqr_max_iter);
        assert_eq!(params_3t.frangi_scale_range, defaults.frangi_scale_range);
        assert_eq!(params_3t.b0_dir, defaults.b0_dir);
    }

    #[test]
    fn test_compute_removed_voxels() {
        let mask = [1.0, 1.0, 0.0, 1.0];
        let r_0 = [0.5, 1.0, 1.0, 0.25];
        let vasc_only = [0.0, 1.0, 0.0, 1.0];

        // mask * r_0 - vasc_only, element by element
        assert_eq!(
            compute_removed_voxels(&mask, &r_0, &vasc_only),
            vec![0.5, 0.0, 0.0, -0.75]
        );
    }

    #[test]
    fn test_compute_weighted_masks() {
        let mask = [1.0, 1.0, 0.0, 1.0];
        let r_0 = [0.5, 2.0, 1.0, 0.25];
        let vasc_only = [1.0, 0.0, 1.0, 1.0];

        // Stage 1: mask * r_0
        assert_eq!(
            compute_weighted_mask_stage1(&mask, &r_0),
            vec![0.5, 2.0, 0.0, 0.25]
        );
        // Stage 2: mask * vasc_only * r_0
        assert_eq!(
            compute_weighted_mask_stage2(&mask, &r_0, &vasc_only),
            vec![0.5, 0.0, 0.0, 0.25]
        );
    }
}