Skip to main content

qsm_core/bgremove/
sdf.rs

1//! Spatially Dependent Filtering (SDF) for QSMART
2//!
3//! SDF is the background field removal method used in QSMART. It uses variable-radius
4//! Gaussian filtering where the kernel size depends on the proximity to the brain boundary.
5//! This allows for aggressive filtering in the brain interior while preserving details
6//! near the surface.
7//!
8//! The algorithm includes optional curvature-based weighting to further reduce artifacts
9//! at highly curved brain regions.
10//!
11//! Reference:
12//! Yaghmaie, N., Syeda, W., et al. (2021).
13//! "QSMART: Quantitative Susceptibility Mapping Artifact Reduction Technique."
14//! NeuroImage, 231:117701. https://doi.org/10.1016/j.neuroimage.2020.117701
15//!
16//! Reference implementation: https://github.com/wtsyeda/QSMART
17
18use crate::utils::curvature::calculate_curvature_proximity;
19
/// Parameters for SDF background field removal
///
/// Use [`SdfParams::stage1`] / [`SdfParams::stage2`] for the QSMART presets;
/// [`Default`] is identical to the stage 1 preset.
#[derive(Clone, Debug, PartialEq)]
pub struct SdfParams {
    /// Sigma of the Gaussian that builds the initial proximity map
    /// (10 in the stage 1 preset, 8 in the stage 2 preset)
    pub sigma1: f64,
    /// Sigma of the Gaussian that builds the vasculature proximity map
    /// (10 in the stage 1 preset, where it mirrors `sigma1`; 2 in the stage 2 preset).
    /// It also enters the combined sigma `sqrt(sigma1^2 + sigma2^2)`.
    pub sigma2: f64,
    /// Spatial radius for morphological closing of indents (default 8 voxels)
    pub spatial_radius: i32,
    /// Lower limit for clamping proximity values (default 0.6)
    pub lower_lim: f64,
    /// Scaling constant for curvature (default 500)
    pub curv_constant: f64,
    /// Whether to use curvature-based edge refinement
    pub use_curvature: bool,
}

impl Default for SdfParams {
    /// Same values as [`SdfParams::stage1`]; delegating keeps the two from drifting apart.
    fn default() -> Self {
        Self::stage1()
    }
}

impl SdfParams {
    /// Create parameters for QSMART Stage 1
    ///
    /// MATLAB passes sigma1 for both sigma args in stage 1, hence `sigma2 == sigma1`.
    /// Stage 1 is run with an all-ones `vasc_only`, for which the vasculature
    /// proximity is effectively 1 everywhere, so `sigma2` only matters through
    /// the combined sigma.
    pub fn stage1() -> Self {
        Self {
            sigma1: 10.0,
            sigma2: 10.0,
            spatial_radius: 8,
            lower_lim: 0.6,
            curv_constant: 500.0,
            use_curvature: true,
        }
    }

    /// Create parameters for QSMART Stage 2
    ///
    /// Differs from stage 1 only in the two sigmas (8 and 2).
    pub fn stage2() -> Self {
        Self {
            sigma1: 8.0,
            sigma2: 2.0,
            ..Self::stage1()
        }
    }
}
79
80/// SDF background field removal
81///
82/// Removes background field from total field shift using spatially dependent filtering.
83///
84/// # Arguments
85/// * `tfs` - Total field shift (unwrapped phase / ppm)
86/// * `mask` - Binary brain mask (weighted by reliability if R_0 is incorporated)
87/// * `vasc_only` - Vasculature-only mask (1 = tissue, 0 = vessel). Pass all-ones for stage 1.
88/// * `nx`, `ny`, `nz` - Volume dimensions
89/// * `params` - SDF parameters
90///
91/// # Returns
92/// Local field shift (background removed)
93pub fn sdf(
94    tfs: &[f64],
95    mask: &[f64],
96    vasc_only: &[f64],
97    nx: usize, ny: usize, nz: usize,
98    params: &SdfParams,
99) -> Vec<f64> {
100    sdf_with_progress(tfs, mask, vasc_only, nx, ny, nz, params, |_, _| {})
101}
102
103/// SDF with progress callback
104pub fn sdf_with_progress<F>(
105    tfs: &[f64],
106    mask: &[f64],
107    vasc_only: &[f64],
108    nx: usize, ny: usize, nz: usize,
109    params: &SdfParams,
110    progress_callback: F,
111) -> Vec<f64>
112where
113    F: Fn(usize, usize),
114{
115    let n_total = nx * ny * nz;
116
117    // Convert mask to binary for morphological operations
118    let mask_binary: Vec<u8> = mask.iter().map(|&v| if v > 0.0 { 1 } else { 0 }).collect();
119
120    // Combined sigma for n calculation
121    let sigma = (params.sigma1 * params.sigma1 + params.sigma2 * params.sigma2).sqrt();
122    let n = if sigma > 0.0 { -sigma.ln() / 0.5_f64.ln() } else { 0.0 };
123
124    // Calculate initial proximity map (prox1)
125    // Gaussian smoothing of mask with anisotropic kernel [sigma1, 2*sigma1, 2*sigma1]
126    let prox1 = if params.sigma1 > 0.0 {
127        gaussian_smooth_3d_masked_f64(mask, mask, nx, ny, nz, &[params.sigma1, 2.0 * params.sigma1, 2.0 * params.sigma1])
128    } else {
129        mask.to_vec()
130    };
131
132    // Calculate curvature-based proximity if enabled
133    let prox = if params.use_curvature {
134        let (prox_curv, _curv_i) = calculate_curvature_proximity(
135            &mask_binary,
136            &prox1,
137            params.lower_lim,
138            params.curv_constant,
139            params.sigma1,
140            nx, ny, nz,
141        );
142        prox_curv
143    } else {
144        // Even without curvature, clamp proximity to lower_lim to prevent
145        // filter sizes from getting too small at the edges
146        // (matching calculate_curvature.m line 45: prox(prox < lowerLim & prox ~= 0) = lowerLim)
147        prox1.iter()
148            .zip(mask.iter())
149            .map(|(&p, &m)| {
150                if m > 0.0 && p > 0.0 && p < params.lower_lim {
151                    params.lower_lim
152                } else {
153                    p
154                }
155            })
156            .collect()
157    };
158
159    // Calculate vasculature proximity (prox2) for stage 2
160    let prox_final = if params.sigma2 > 0.0 {
161        let prox2 = gaussian_smooth_3d_masked_f64(vasc_only, mask, nx, ny, nz, &[params.sigma2, params.sigma2, params.sigma2]);
162        // Multiply prox * prox2
163        prox.iter().zip(prox2.iter()).map(|(&p, &p2)| p * p2).collect()
164    } else {
165        prox
166    };
167
168    // Calculate alpha = sigma * round(prox^n, 2)
169    // Alpha determines the local smoothing kernel size
170    let mut alpha: Vec<f64> = prox_final.iter()
171        .zip(mask.iter())
172        .map(|(&p, &m)| {
173            if m > 0.0 {
174                let val = sigma * (p.powf(n) * 100.0).round() / 100.0;
175                val
176            } else {
177                0.0
178            }
179        })
180        .collect();
181
182    // Set alpha=1 for vessel regions within mask
183    // (vasc_only=0 means vessel, matching sdf_curvature.m line 27)
184    for i in 0..n_total {
185        if mask[i] > 0.0 && vasc_only[i] == 0.0 {
186            alpha[i] = 1.0;
187        }
188    }
189
190    // Get unique alpha values and sort
191    let mut unique_alphas: Vec<f64> = alpha.iter()
192        .filter(|&&a| a > 0.0)
193        .copied()
194        .collect();
195    unique_alphas.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
196    unique_alphas.dedup();
197
198    let total_alphas = unique_alphas.len();
199
200    // Create index map: for each voxel, which alpha index does it belong to?
201    let mut alpha_index: Vec<usize> = vec![0; n_total];
202    for i in 0..n_total {
203        if alpha[i] > 0.0 {
204            // Find index in unique_alphas
205            let idx = unique_alphas.iter().position(|&a| (a - alpha[i]).abs() < 1e-10).unwrap_or(0);
206            alpha_index[i] = idx + 1; // 1-indexed to distinguish from background
207        }
208    }
209
210    // Apply spatially dependent filtering
211    // For each unique alpha, smooth and assign to corresponding voxels
212    let mut background = vec![0.0f64; n_total];
213
214    // Pre-compute filter size
215    let filter_size = 2 * (2.0 * sigma).ceil() as usize + 1;
216
217    for (alpha_idx, &current_alpha) in unique_alphas.iter().enumerate() {
218        progress_callback(alpha_idx, total_alphas);
219
220        // Compute smoothed field for this alpha
221        let smoothed: Vec<f64> = if current_alpha > 0.0 {
222            // Smooth tfs * mask with Gaussian kernel of size current_alpha
223            let weighted_tfs: Vec<f64> = tfs.iter()
224                .zip(mask.iter())
225                .map(|(&t, &m)| t * m)
226                .collect();
227
228            let num = gaussian_smooth_3d_with_filter_size(&weighted_tfs, nx, ny, nz, current_alpha, filter_size);
229            let denom = gaussian_smooth_3d_with_filter_size(&mask.to_vec(), nx, ny, nz, current_alpha, filter_size);
230
231            // Divide: num / denom
232            num.iter()
233                .zip(denom.iter())
234                .map(|(&n, &d)| if d.abs() > 1e-10 { n / d } else { 0.0 })
235                .collect()
236        } else {
237            // alpha=0: just use tfs*mask
238            tfs.iter().zip(mask.iter()).map(|(&t, &m)| t * m).collect()
239        };
240
241        // Assign to voxels with this alpha
242        for i in 0..n_total {
243            if alpha_index[i] == alpha_idx + 1 {
244                background[i] = smoothed[i];
245            }
246        }
247    }
248
249    progress_callback(total_alphas, total_alphas);
250
251    // Compute local field: (tfs - background) * mask
252    let local_field: Vec<f64> = tfs.iter()
253        .zip(background.iter())
254        .zip(mask.iter())
255        .map(|((&t, &b), &m)| (t - b) * m)
256        .collect();
257
258    local_field
259}
260
261/// SDF with curvature-based weighting (full QSMART pipeline)
262///
263/// This is the main entry point matching QSMART's sdf_curvature function.
264pub fn sdf_curvature(
265    tfs: &[f64],
266    mask: &[f64],
267    vasc_only: &[f64],
268    nx: usize, ny: usize, nz: usize,
269    params: &SdfParams,
270) -> Vec<f64> {
271    // Ensure curvature is enabled
272    let params_with_curv = SdfParams {
273        use_curvature: true,
274        ..params.clone()
275    };
276
277    sdf(tfs, mask, vasc_only, nx, ny, nz, &params_with_curv)
278}
279
280/// 3D Gaussian smoothing with specified filter size
281fn gaussian_smooth_3d_with_filter_size(
282    data: &[f64],
283    nx: usize, ny: usize, nz: usize,
284    sigma: f64,
285    filter_size: usize,
286) -> Vec<f64> {
287    if sigma <= 0.0 {
288        return data.to_vec();
289    }
290
291    // Create 1D Gaussian kernel
292    let kernel_radius = (filter_size - 1) / 2;
293    let mut kernel = vec![0.0f64; filter_size];
294
295    let mut sum = 0.0;
296    for i in 0..filter_size {
297        let x = i as f64 - kernel_radius as f64;
298        kernel[i] = (-x * x / (2.0 * sigma * sigma)).exp();
299        sum += kernel[i];
300    }
301
302    // Normalize
303    for k in kernel.iter_mut() {
304        *k /= sum;
305    }
306
307    // Apply separable convolution
308    let smoothed_x = convolve_1d_direction(data, nx, ny, nz, &kernel, 'x');
309    let smoothed_xy = convolve_1d_direction(&smoothed_x, nx, ny, nz, &kernel, 'y');
310    let smoothed_xyz = convolve_1d_direction(&smoothed_xy, nx, ny, nz, &kernel, 'z');
311
312    smoothed_xyz
313}
314
315/// Gaussian smoothing with anisotropic sigma and mask
316fn gaussian_smooth_3d_masked_f64(
317    data: &[f64],
318    mask: &[f64],
319    nx: usize, ny: usize, nz: usize,
320    sigmas: &[f64; 3],
321) -> Vec<f64> {
322    // Apply separable 1D convolutions
323    let smoothed_x = convolve_1d_direction_sigma(data, nx, ny, nz, sigmas[0], 'x');
324    let smoothed_xy = convolve_1d_direction_sigma(&smoothed_x, nx, ny, nz, sigmas[1], 'y');
325    let smoothed_xyz = convolve_1d_direction_sigma(&smoothed_xy, nx, ny, nz, sigmas[2], 'z');
326
327    // Apply mask
328    smoothed_xyz.iter()
329        .zip(mask.iter())
330        .map(|(&v, &m)| if m > 0.0 { v } else { 0.0 })
331        .collect()
332}
333
/// 1D convolution along one axis of a 3D volume, with replicate padding
///
/// The volume is stored with x varying fastest (`index = i + j * nx + k * nx * ny`).
/// `direction` selects the axis (`'x'`, `'y'` or `'z'`); samples that would fall
/// outside the volume are taken from the nearest voxel on that axis.
/// Intended to match MATLAB's imgaussfilt3 default behavior.
///
/// # Panics
/// Panics with "Invalid convolution direction" for any other `direction`.
fn convolve_1d_direction(
    data: &[f64],
    nx: usize, ny: usize, nz: usize,
    kernel: &[f64],
    direction: char,
) -> Vec<f64> {
    let half_width = ((kernel.len() - 1) / 2) as isize;

    // Length of the filtered axis and the distance between neighbours along it
    let (axis_len, axis_stride) = match direction {
        'x' => (nx, 1),
        'y' => (ny, nx),
        'z' => (nz, nx * ny),
        _ => panic!("Invalid convolution direction"),
    };
    let last = axis_len as isize - 1;

    let mut filtered = vec![0.0f64; nx * ny * nz];
    for (flat, slot) in filtered.iter_mut().enumerate() {
        // Coordinate of this voxel on the filtered axis, and the flat index
        // of the first voxel of the line it belongs to
        let pos = (flat / axis_stride) % axis_len;
        let line_start = flat - pos * axis_stride;

        *slot = kernel.iter().enumerate().fold(0.0, |acc, (tap, &weight)| {
            let shifted = pos as isize + (tap as isize - half_width);
            // Replicate padding: clamp onto the valid coordinate range
            let sample = shifted.max(0).min(last) as usize;
            acc + data[line_start + sample * axis_stride] * weight
        });
    }

    filtered
}
410
411/// 1D convolution with specified sigma
412fn convolve_1d_direction_sigma(
413    data: &[f64],
414    nx: usize, ny: usize, nz: usize,
415    sigma: f64,
416    direction: char,
417) -> Vec<f64> {
418    if sigma <= 0.0 {
419        return data.to_vec();
420    }
421
422    // Create 1D Gaussian kernel
423    // Match MATLAB's imgaussfilt3 default: filterSize = 2*ceil(2*sigma)+1
424    let kernel_radius = (2.0 * sigma).ceil() as usize;
425    let kernel_size = 2 * kernel_radius + 1;
426    let mut kernel = vec![0.0f64; kernel_size];
427
428    let mut sum = 0.0;
429    for i in 0..kernel_size {
430        let x = i as f64 - kernel_radius as f64;
431        kernel[i] = (-x * x / (2.0 * sigma * sigma)).exp();
432        sum += kernel[i];
433    }
434
435    // Normalize
436    for k in kernel.iter_mut() {
437        *k /= sum;
438    }
439
440    convolve_1d_direction(data, nx, ny, nz, &kernel, direction)
441}
442
443/// Simple SDF without curvature (faster, for testing)
444pub fn sdf_simple(
445    tfs: &[f64],
446    mask: &[f64],
447    nx: usize, ny: usize, nz: usize,
448    sigma1: f64,
449) -> Vec<f64> {
450    let vasc_only = vec![1.0f64; mask.len()];
451    let params = SdfParams {
452        sigma1,
453        sigma2: 0.0,
454        use_curvature: false,
455        ..Default::default()
456    };
457
458    sdf(tfs, mask, &vasc_only, nx, ny, nz, &params)
459}
460
461/// Default SDF parameters for stage 1
462pub fn sdf_default_stage1(
463    tfs: &[f64],
464    mask: &[f64],
465    nx: usize, ny: usize, nz: usize,
466) -> Vec<f64> {
467    let vasc_only = vec![1.0f64; mask.len()];
468    sdf(tfs, mask, &vasc_only, nx, ny, nz, &SdfParams::stage1())
469}
470
471/// Default SDF parameters for stage 2
472pub fn sdf_default_stage2(
473    tfs: &[f64],
474    mask: &[f64],
475    vasc_only: &[f64],
476    nx: usize, ny: usize, nz: usize,
477) -> Vec<f64> {
478    sdf(tfs, mask, vasc_only, nx, ny, nz, &SdfParams::stage2())
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// A spatially constant total field is all background, so the local field
    /// left after SDF should be (close to) zero everywhere.
    #[test]
    fn test_sdf_simple() {
        let side = 10;
        let voxel_count = side * side * side;

        let tfs = vec![1.0f64; voxel_count];
        let mask = vec![1.0f64; voxel_count];

        let lfs = sdf_simple(&tfs, &mask, side, side, side, 2.0);

        // Largest magnitude over the whole volume
        let max_lfs = lfs.iter().map(|v| v.abs()).fold(0.0f64, f64::max);
        assert!(max_lfs < 0.1, "Max LFS was {}", max_lfs);
    }

    /// Smoothing a constant volume with a normalised kernel must return the
    /// same constant.
    #[test]
    fn test_gaussian_smooth_constant() {
        let data = vec![5.0f64; 27];
        let smoothed = gaussian_smooth_3d_with_filter_size(&data, 3, 3, 3, 1.0, 5);

        assert!(smoothed.iter().all(|&v| (v - 5.0).abs() < 0.1));
    }
}