Skip to main content

qsm_core/bgremove/
smv.rs

//! Simple Spherical Mean Value (SMV) background field removal
//!
//! Basic SMV filtering: subtracts the spherical mean of the field.
//! Simpler than SHARP (no deconvolution step).
//!
//! local_field = field - SMV(field)
//!
//! Reference:
//! Schweser, F., Deistung, A., Lehr, B.W., Reichenbach, J.R. (2011).
//! "Quantitative imaging of intrinsic magnetic tissue properties using MRI signal phase."
//! NeuroImage, 54(4):2789-2807. https://doi.org/10.1016/j.neuroimage.2010.10.070
//!
//! Reference implementation: https://github.com/kamesy/QSM.jl
14
15use num_complex::Complex64;
16use crate::fft::{fft3d, ifft3d};
17use crate::kernels::smv::smv_kernel;
18
19/// Simple SMV background field removal
20///
21/// Computes: local_field = field - SMV(field)
22///
23/// # Arguments
24/// * `field` - Unwrapped total field (nx * ny * nz)
25/// * `mask` - Binary mask (nx * ny * nz), 1 = inside ROI
26/// * `nx`, `ny`, `nz` - Array dimensions
27/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm
28/// * `radius` - SMV kernel radius in mm
29///
30/// # Returns
31/// (local_field, eroded_mask)
32pub fn smv(
33    field: &[f64],
34    mask: &[u8],
35    nx: usize, ny: usize, nz: usize,
36    vsx: f64, vsy: f64, vsz: f64,
37    radius: f64,
38) -> (Vec<f64>, Vec<u8>) {
39    let n_total = nx * ny * nz;
40
41    // Generate SMV kernel and FFT it
42    let s_kernel = smv_kernel(nx, ny, nz, vsx, vsy, vsz, radius);
43
44    // FFT of SMV kernel
45    let mut s_complex: Vec<Complex64> = s_kernel.iter()
46        .map(|&x| Complex64::new(x, 0.0))
47        .collect();
48    fft3d(&mut s_complex, nx, ny, nz);
49    let s_fft = s_complex;
50
51    // Erode mask: convolve mask with SMV kernel
52    let mask_f64: Vec<f64> = mask.iter().map(|&m| m as f64).collect();
53    let mut mask_complex: Vec<Complex64> = mask_f64.iter()
54        .map(|&x| Complex64::new(x, 0.0))
55        .collect();
56
57    fft3d(&mut mask_complex, nx, ny, nz);
58
59    // Convolve mask with SMV kernel
60    for i in 0..n_total {
61        mask_complex[i] *= s_fft[i].re;
62    }
63
64    ifft3d(&mut mask_complex, nx, ny, nz);
65
66    // Eroded mask: values close to 1 are fully inside
67    let delta = 1.0 - 1e-10;
68    let eroded_mask: Vec<u8> = mask_complex.iter()
69        .map(|c| if c.re > delta { 1 } else { 0 })
70        .collect();
71
72    // Compute SMV(field) = background field estimate
73    let mut field_complex: Vec<Complex64> = field.iter()
74        .map(|&x| Complex64::new(x, 0.0))
75        .collect();
76    fft3d(&mut field_complex, nx, ny, nz);
77
78    // Multiply by SMV kernel in k-space
79    for i in 0..n_total {
80        field_complex[i] *= s_fft[i].re;
81    }
82
83    ifft3d(&mut field_complex, nx, ny, nz);
84
85    // Local field = field - SMV(field), within eroded mask
86    let local_field: Vec<f64> = field.iter()
87        .zip(field_complex.iter())
88        .enumerate()
89        .map(|(i, (&f, smv_f))| {
90            if eroded_mask[i] == 1 {
91                f - smv_f.re
92            } else {
93                0.0
94            }
95        })
96        .collect();
97
98    (local_field, eroded_mask)
99}
100
101/// Simple SMV with default parameters
102pub fn smv_default(
103    field: &[f64],
104    mask: &[u8],
105    nx: usize, ny: usize, nz: usize,
106    vsx: f64, vsy: f64, vsz: f64,
107) -> (Vec<f64>, Vec<u8>) {
108    // Default radius: 5mm (typical for brain imaging)
109    let radius = 5.0;
110    smv(field, mask, nx, ny, nz, vsx, vsy, vsz, radius)
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Edge length of the cubic volume shared by every test.
    const DIM: usize = 16;
    /// Number of voxels in the `DIM`-cubed volume.
    const VOXELS: usize = DIM * DIM * DIM;

    /// An all-zero field under a full mask must yield an all-zero local field.
    #[test]
    fn test_smv_zero_field() {
        let zero_field = vec![0.0; VOXELS];
        let full_mask = vec![1u8; VOXELS];

        let (local, _) = smv(&zero_field, &full_mask, DIM, DIM, DIM, 1.0, 1.0, 1.0, 2.0);

        local.iter().for_each(|&val| {
            assert!(val.abs() < 1e-8, "Zero field should give zero local field, got {}", val);
        });
    }

    /// A ramp over the flat voxel index must give finite output everywhere
    /// and leave at least one voxel in the eroded mask.
    #[test]
    fn test_smv_finite() {
        let ramp: Vec<f64> = (0..VOXELS).map(|i| (i as f64) * 0.01).collect();
        let full_mask = vec![1u8; VOXELS];

        let (local, eroded) = smv(&ramp, &full_mask, DIM, DIM, DIM, 1.0, 1.0, 1.0, 2.0);

        // Report the first offending voxel, if any.
        if let Some(i) = local.iter().position(|val| !val.is_finite()) {
            panic!("Local field should be finite at index {}", i);
        }

        let eroded_count = eroded.iter().map(|&m| usize::from(m)).sum::<usize>();
        assert!(eroded_count > 0, "Eroded mask should have some voxels");
    }
}