// qsm_core/bgremove/sharp.rs

//! SHARP background field removal
//!
//! Sophisticated Harmonic Artifact Reduction for Phase data.
//! Uses the spherical mean value property of harmonic functions to
//! separate the local field from the background field, which is harmonic
//! inside the region of interest.
//!
//! Reference:
//! Schweser, F., Deistung, A., Lehr, B.W., Reichenbach, J.R. (2011).
//! "Quantitative imaging of intrinsic magnetic tissue properties using MRI signal phase:
//! an approach to in vivo brain iron metabolism?"
//! NeuroImage, 54(4):2789-2807. https://doi.org/10.1016/j.neuroimage.2010.10.070
//!
//! Reference implementation: https://github.com/kamesy/QSM.jl

14use num_complex::Complex64;
15use crate::fft::{fft3d, ifft3d};
16use crate::kernels::smv::smv_kernel;
17
18/// SHARP background field removal
19///
20/// Uses spherical mean value (SMV) filtering to remove background field.
21/// The local field is obtained by deconvolving the SMV-filtered field.
22///
23/// # Arguments
24/// * `field` - Unwrapped total field (nx * ny * nz)
25/// * `mask` - Binary mask (nx * ny * nz), 1 = inside ROI
26/// * `nx`, `ny`, `nz` - Array dimensions
27/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm
28/// * `radius` - SMV kernel radius in mm
29/// * `threshold` - High-pass filter threshold (typically 0.05)
30///
31/// SHARP algorithm parameters
32#[derive(Clone, Debug)]
33pub struct SharpParams {
34    /// Deconvolution threshold
35    pub threshold: f64,
36    /// Kernel radius factor (multiplied by min voxel size to get radius in mm; default: 18.0)
37    pub radius_factor: f64,
38}
39
40impl Default for SharpParams {
41    fn default() -> Self {
42        Self {
43            threshold: 0.05,
44            radius_factor: 18.0,
45        }
46    }
47}
48
49/// # Returns
50/// (local_field, eroded_mask)
51pub fn sharp(
52    field: &[f64],
53    mask: &[u8],
54    nx: usize, ny: usize, nz: usize,
55    vsx: f64, vsy: f64, vsz: f64,
56    radius: f64,
57    threshold: f64,
58) -> (Vec<f64>, Vec<u8>) {
59    let n_total = nx * ny * nz;
60
61    // Generate SMV kernel and FFT it
62    let s_kernel = smv_kernel(nx, ny, nz, vsx, vsy, vsz, radius);
63
64    // FFT of SMV kernel
65    let mut s_complex: Vec<Complex64> = s_kernel.iter()
66        .map(|&x| Complex64::new(x, 0.0))
67        .collect();
68    fft3d(&mut s_complex, nx, ny, nz);
69
70    // S is the real part of FFT(smv_kernel)
71    // 1-S is the high-pass kernel
72    let s_fft: Vec<f64> = s_complex.iter().map(|c| c.re).collect();
73
74    // Erode mask: convolve mask with SMV kernel
75    // Voxels where convolution result < 1 are near boundary
76    let mask_f64: Vec<f64> = mask.iter().map(|&m| m as f64).collect();
77    let mut mask_complex: Vec<Complex64> = mask_f64.iter()
78        .map(|&x| Complex64::new(x, 0.0))
79        .collect();
80
81    fft3d(&mut mask_complex, nx, ny, nz);
82
83    // Convolve mask with SMV kernel
84    for i in 0..n_total {
85        mask_complex[i] *= s_complex[i].re;  // S is real
86    }
87
88    ifft3d(&mut mask_complex, nx, ny, nz);
89
90    // Eroded mask: values close to 1 are fully inside
91    let delta = 1.0 - 1e-7_f64.sqrt();  // ~1 - eps
92    let eroded_mask: Vec<u8> = mask_complex.iter()
93        .map(|c| if c.re > delta { 1 } else { 0 })
94        .collect();
95
96    // Apply SHARP:
97    // 1. Multiply field by (1-S) in k-space (high-pass filter)
98    // 2. Apply eroded mask
99    // 3. Divide by (1-S) with threshold (deconvolution)
100    // 4. Apply eroded mask
101
102    // FFT of field
103    let mut field_complex: Vec<Complex64> = field.iter()
104        .map(|&x| Complex64::new(x, 0.0))
105        .collect();
106    fft3d(&mut field_complex, nx, ny, nz);
107
108    // High-pass filter: multiply by (1-S)
109    for i in 0..n_total {
110        field_complex[i] *= 1.0 - s_fft[i];
111    }
112
113    // IFFT
114    ifft3d(&mut field_complex, nx, ny, nz);
115
116    // Apply eroded mask
117    for i in 0..n_total {
118        if eroded_mask[i] == 0 {
119            field_complex[i] = Complex64::new(0.0, 0.0);
120        }
121    }
122
123    // FFT again for deconvolution
124    fft3d(&mut field_complex, nx, ny, nz);
125
126    // Deconvolution: divide by (1-S) with threshold
127    for i in 0..n_total {
128        let one_minus_s = 1.0 - s_fft[i];
129        if one_minus_s.abs() < threshold {
130            field_complex[i] = Complex64::new(0.0, 0.0);
131        } else {
132            field_complex[i] /= one_minus_s;
133        }
134    }
135
136    // Final IFFT
137    ifft3d(&mut field_complex, nx, ny, nz);
138
139    // Apply eroded mask and extract real part
140    let local_field: Vec<f64> = field_complex.iter()
141        .enumerate()
142        .map(|(i, c)| if eroded_mask[i] == 1 { c.re } else { 0.0 })
143        .collect();
144
145    (local_field, eroded_mask)
146}
147
148/// SHARP with default parameters
149pub fn sharp_default(
150    field: &[f64],
151    mask: &[u8],
152    nx: usize, ny: usize, nz: usize,
153    vsx: f64, vsy: f64, vsz: f64,
154) -> (Vec<f64>, Vec<u8>) {
155    // Default radius: 18 * minimum voxel size
156    let radius = 18.0 * vsx.min(vsy).min(vsz);
157    sharp(field, mask, nx, ny, nz, vsx, vsy, vsz, radius, 0.05)
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `n^3` mask that is 1 on the cube `lo..hi` along every axis.
    fn cube_mask(n: usize, lo: usize, hi: usize) -> Vec<u8> {
        let inside = |i: usize| (lo..hi).contains(&i);
        let mut mask = Vec::with_capacity(n * n * n);
        for k in 0..n {
            for j in 0..n {
                for i in 0..n {
                    mask.push(u8::from(inside(i) && inside(j) && inside(k)));
                }
            }
        }
        mask
    }

    #[test]
    fn test_sharp_zero_field() {
        // Zero field should give zero local field
        let n = 16;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        // Use small radius for small test array
        let (local, _) = sharp(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 0.05);

        for &val in local.iter() {
            assert!(val.abs() < 1e-8, "Zero field should give zero local field, got {}", val);
        }
    }

    #[test]
    fn test_sharp_finite() {
        // Result should be finite (no NaN or Inf)
        let n = 16;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.01).collect();
        let mask = vec![1u8; n * n * n];

        let (local, eroded) = sharp(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 0.05);

        for (i, &val) in local.iter().enumerate() {
            assert!(val.is_finite(), "Local field should be finite at index {}", i);
        }

        // Eroded mask should have at least some voxels
        let eroded_count: usize = eroded.iter().map(|&m| m as usize).sum();
        assert!(eroded_count > 0, "Eroded mask should have some voxels");
    }

    #[test]
    fn test_sharp_eroded_mask_inside_mask() {
        // A cube mask must actually erode (an all-ones mask never does under
        // periodic convolution), stay within the input mask, and the local
        // field must be exactly zero outside the eroded mask.
        let n = 16;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.01).collect();
        let mask = cube_mask(n, 3, 13);

        let (local, eroded) = sharp(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 0.05);

        let mask_count: usize = mask.iter().map(|&m| m as usize).sum();
        let eroded_count: usize = eroded.iter().map(|&m| m as usize).sum();
        assert!(eroded_count > 0, "Interior of the cube should survive erosion");
        assert!(eroded_count < mask_count, "Cube boundary should be eroded");

        for (i, ((&e, &m), &val)) in eroded.iter().zip(&mask).zip(&local).enumerate() {
            assert!(e <= m, "Eroded mask leaves the input mask at index {}", i);
            assert!(val.is_finite(), "Local field should be finite at index {}", i);
            if e == 0 {
                assert_eq!(val, 0.0, "Local field must be zero outside eroded mask at {}", i);
            }
        }
    }

    #[test]
    fn test_sharp_empty_mask() {
        // An empty ROI erodes to nothing and yields an all-zero local field.
        let n = 16;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.01).collect();
        let mask = vec![0u8; n * n * n];

        let (local, eroded) = sharp(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 0.05);

        assert!(eroded.iter().all(|&m| m == 0), "Empty mask must erode to empty");
        assert!(local.iter().all(|&v| v == 0.0), "Empty mask must give zero local field");
    }
}