Skip to main content

qsm_core/bgremove/
vsharp.rs

//! V-SHARP background field removal
//!
//! Variable-kernel SHARP uses multiple SMV kernel radii
//! to preserve more brain tissue at the mask edges while still
//! removing background fields.
//!
//! Reference:
//! Wu, B., Li, W., Guidon, A., Liu, C. (2012).
//! "Whole brain susceptibility mapping using compressed sensing."
//! Magnetic Resonance in Medicine, 67(1):137-147. https://doi.org/10.1002/mrm.23000
//!
//! Reference implementation: https://github.com/kamesy/QSM.jl
13
14use num_complex::Complex64;
15use crate::fft::{fft3d, ifft3d};
16use crate::kernels::smv::smv_kernel;
17
/// V-SHARP algorithm parameters.
///
/// Kernel radii are expressed as multiples of the voxel size rather than in
/// absolute millimetres.
#[derive(Clone, Debug)]
pub struct VsharpParams {
    /// Deconvolution threshold for the truncated inverse of the largest
    /// kernel's high-pass filter `1 - S` (default: 0.001).
    ///
    /// NOTE(review): `vsharp_default` uses 0.05 instead — confirm which value
    /// is intended as the default.
    pub threshold: f64,
    /// Largest kernel radius as a multiple of the smallest voxel size:
    /// radius in mm = `max_radius_factor * min(vsx, vsy, vsz)` (default: 18.0).
    pub max_radius_factor: f64,
    /// Smallest kernel radius as a multiple of the largest voxel size:
    /// radius in mm = `min_radius_factor * max(vsx, vsy, vsz)` (default: 2.0).
    pub min_radius_factor: f64,
}

impl Default for VsharpParams {
    fn default() -> Self {
        Self {
            threshold: 0.001,
            max_radius_factor: 18.0,
            min_radius_factor: 2.0,
        }
    }
}
51
52/// # Returns
53/// (local_field, eroded_mask)
54pub fn vsharp(
55    field: &[f64],
56    mask: &[u8],
57    nx: usize, ny: usize, nz: usize,
58    vsx: f64, vsy: f64, vsz: f64,
59    radii: &[f64],
60    threshold: f64,
61) -> (Vec<f64>, Vec<u8>) {
62    if radii.is_empty() {
63        return (vec![0.0; nx * ny * nz], mask.to_vec());
64    }
65
66    // If only one radius, use regular SHARP
67    if radii.len() == 1 {
68        return crate::bgremove::sharp::sharp(
69            field, mask, nx, ny, nz, vsx, vsy, vsz, radii[0], threshold
70        );
71    }
72
73    let n_total = nx * ny * nz;
74
75    // Sort radii from largest to smallest
76    let mut sorted_radii = radii.to_vec();
77    sorted_radii.sort_by(|a, b| b.partial_cmp(a).unwrap());
78
79    // FFT of field
80    let mut field_complex: Vec<Complex64> = field.iter()
81        .map(|&x| Complex64::new(x, 0.0))
82        .collect();
83    fft3d(&mut field_complex, nx, ny, nz);
84    let field_fft = field_complex.clone();
85
86    // Track which voxels have been processed and final mask
87    let mut processed = vec![false; n_total];
88    let mut local_field = vec![0.0; n_total];
89    let mut final_mask = vec![0u8; n_total];
90
91    // Process each radius (large to small)
92    let delta = 1.0 - 1e-7_f64.sqrt();
93
94    // Store the first (largest) kernel's inverse for deconvolution
95    let mut inverse_kernel: Option<Vec<f64>> = None;
96
97    for &radius in &sorted_radii {
98        // Generate SMV kernel
99        let s_kernel = smv_kernel(nx, ny, nz, vsx, vsy, vsz, radius);
100
101        // FFT of SMV kernel
102        let mut s_complex: Vec<Complex64> = s_kernel.iter()
103            .map(|&x| Complex64::new(x, 0.0))
104            .collect();
105        fft3d(&mut s_complex, nx, ny, nz);
106        let s_fft: Vec<f64> = s_complex.iter().map(|c| c.re).collect();
107
108        // Store inverse of first (largest) kernel
109        if inverse_kernel.is_none() {
110            inverse_kernel = Some(s_fft.iter().map(|&s| {
111                let one_minus_s = 1.0 - s;
112                if one_minus_s.abs() < threshold {
113                    0.0
114                } else {
115                    1.0 / one_minus_s
116                }
117            }).collect());
118        }
119
120        // Erode mask for this radius
121        let mask_f64: Vec<f64> = mask.iter().map(|&m| m as f64).collect();
122        let mut mask_complex: Vec<Complex64> = mask_f64.iter()
123            .map(|&x| Complex64::new(x, 0.0))
124            .collect();
125
126        fft3d(&mut mask_complex, nx, ny, nz);
127
128        // Convolve mask with SMV kernel
129        for i in 0..n_total {
130            mask_complex[i] *= s_fft[i];
131        }
132
133        ifft3d(&mut mask_complex, nx, ny, nz);
134
135        // Current eroded mask
136        let current_mask: Vec<bool> = mask_complex.iter()
137            .map(|c| c.re > delta)
138            .collect();
139
140        // Apply high-pass filter: multiply by (1-S)
141        let mut filtered = field_fft.clone();
142        for i in 0..n_total {
143            filtered[i] *= 1.0 - s_fft[i];
144        }
145
146        ifft3d(&mut filtered, nx, ny, nz);
147
148        // For voxels that are in current mask but not yet processed,
149        // store the filtered value
150        for i in 0..n_total {
151            if current_mask[i] && !processed[i] {
152                local_field[i] = filtered[i].re;
153                processed[i] = true;
154                final_mask[i] = 1;
155            }
156        }
157    }
158
159    // Deconvolution with largest kernel's inverse
160    if let Some(inv_kernel) = inverse_kernel {
161        let mut local_complex: Vec<Complex64> = local_field.iter()
162            .map(|&x| Complex64::new(x, 0.0))
163            .collect();
164
165        fft3d(&mut local_complex, nx, ny, nz);
166
167        for i in 0..n_total {
168            local_complex[i] *= inv_kernel[i];
169        }
170
171        ifft3d(&mut local_complex, nx, ny, nz);
172
173        // Apply final mask
174        for i in 0..n_total {
175            local_field[i] = if final_mask[i] == 1 { local_complex[i].re } else { 0.0 };
176        }
177    }
178
179    (local_field, final_mask)
180}
181
182/// V-SHARP with progress callback
183///
184/// Same as `vsharp` but calls `progress_callback(radius_index, total_radii)` for each radius.
185pub fn vsharp_with_progress<F>(
186    field: &[f64],
187    mask: &[u8],
188    nx: usize, ny: usize, nz: usize,
189    vsx: f64, vsy: f64, vsz: f64,
190    radii: &[f64],
191    threshold: f64,
192    mut progress_callback: F,
193) -> (Vec<f64>, Vec<u8>)
194where
195    F: FnMut(usize, usize),
196{
197    if radii.is_empty() {
198        return (vec![0.0; nx * ny * nz], mask.to_vec());
199    }
200
201    // If only one radius, use regular SHARP
202    if radii.len() == 1 {
203        progress_callback(1, 1);
204        return crate::bgremove::sharp::sharp(
205            field, mask, nx, ny, nz, vsx, vsy, vsz, radii[0], threshold
206        );
207    }
208
209    let n_total = nx * ny * nz;
210    let n_radii = radii.len();
211
212    // Sort radii from largest to smallest
213    let mut sorted_radii = radii.to_vec();
214    sorted_radii.sort_by(|a, b| b.partial_cmp(a).unwrap());
215
216    // FFT of field
217    let mut field_complex: Vec<Complex64> = field.iter()
218        .map(|&x| Complex64::new(x, 0.0))
219        .collect();
220    fft3d(&mut field_complex, nx, ny, nz);
221    let field_fft = field_complex.clone();
222
223    // Track which voxels have been processed and final mask
224    let mut processed = vec![false; n_total];
225    let mut local_field = vec![0.0; n_total];
226    let mut final_mask = vec![0u8; n_total];
227
228    let delta = 1.0 - 1e-7_f64.sqrt();
229    let mut inverse_kernel: Option<Vec<f64>> = None;
230
231    for (idx, &radius) in sorted_radii.iter().enumerate() {
232        // Report progress
233        progress_callback(idx + 1, n_radii);
234
235        // Generate SMV kernel
236        let s_kernel = smv_kernel(nx, ny, nz, vsx, vsy, vsz, radius);
237
238        // FFT of SMV kernel
239        let mut s_complex: Vec<Complex64> = s_kernel.iter()
240            .map(|&x| Complex64::new(x, 0.0))
241            .collect();
242        fft3d(&mut s_complex, nx, ny, nz);
243        let s_fft: Vec<f64> = s_complex.iter().map(|c| c.re).collect();
244
245        // Store inverse of first (largest) kernel
246        if inverse_kernel.is_none() {
247            inverse_kernel = Some(s_fft.iter().map(|&s| {
248                let one_minus_s = 1.0 - s;
249                if one_minus_s.abs() < threshold {
250                    0.0
251                } else {
252                    1.0 / one_minus_s
253                }
254            }).collect());
255        }
256
257        // Erode mask for this radius
258        let mask_f64: Vec<f64> = mask.iter().map(|&m| m as f64).collect();
259        let mut mask_complex: Vec<Complex64> = mask_f64.iter()
260            .map(|&x| Complex64::new(x, 0.0))
261            .collect();
262
263        fft3d(&mut mask_complex, nx, ny, nz);
264
265        for i in 0..n_total {
266            mask_complex[i] *= s_fft[i];
267        }
268
269        ifft3d(&mut mask_complex, nx, ny, nz);
270
271        let current_mask: Vec<bool> = mask_complex.iter()
272            .map(|c| c.re > delta)
273            .collect();
274
275        // Apply high-pass filter
276        let mut filtered = field_fft.clone();
277        for i in 0..n_total {
278            filtered[i] *= 1.0 - s_fft[i];
279        }
280
281        ifft3d(&mut filtered, nx, ny, nz);
282
283        for i in 0..n_total {
284            if current_mask[i] && !processed[i] {
285                local_field[i] = filtered[i].re;
286                processed[i] = true;
287                final_mask[i] = 1;
288            }
289        }
290    }
291
292    // Deconvolution
293    if let Some(inv_kernel) = inverse_kernel {
294        let mut local_complex: Vec<Complex64> = local_field.iter()
295            .map(|&x| Complex64::new(x, 0.0))
296            .collect();
297
298        fft3d(&mut local_complex, nx, ny, nz);
299
300        for i in 0..n_total {
301            local_complex[i] *= inv_kernel[i];
302        }
303
304        ifft3d(&mut local_complex, nx, ny, nz);
305
306        for i in 0..n_total {
307            local_field[i] = if final_mask[i] == 1 { local_complex[i].re } else { 0.0 };
308        }
309    }
310
311    (local_field, final_mask)
312}
313
314/// V-SHARP with default parameters
315pub fn vsharp_default(
316    field: &[f64],
317    mask: &[u8],
318    nx: usize, ny: usize, nz: usize,
319    vsx: f64, vsy: f64, vsz: f64,
320) -> (Vec<f64>, Vec<u8>) {
321    let min_vox = vsx.min(vsy).min(vsz);
322    let max_vox = vsx.max(vsy).max(vsz);
323
324    // Default radii: 18*min_vox down to 2*max_vox in steps of 2*max_vox
325    let mut radii = Vec::new();
326    let mut r = 18.0 * min_vox;
327    while r >= 2.0 * max_vox {
328        radii.push(r);
329        r -= 2.0 * max_vox;
330    }
331
332    if radii.is_empty() {
333        radii.push(18.0 * min_vox);
334    }
335
336    vsharp(field, mask, nx, ny, nz, vsx, vsy, vsz, &radii, 0.05)
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Field increasing linearly with the flat voxel index.
    fn ramp_field(len: usize) -> Vec<f64> {
        (0..len).map(|i| (i as f64) * 0.001).collect()
    }

    /// Linear background field along z on an n^3 grid.
    fn z_ramp_field(n: usize) -> Vec<f64> {
        let mut field = vec![0.0; n * n * n];
        for z in 0..n {
            for y in 0..n {
                for x in 0..n {
                    field[x + y * n + z * n * n] = (z as f64) * 0.1;
                }
            }
        }
        field
    }

    /// Binary spherical mask of `radius` voxels centred in an n^3 grid.
    fn spherical_mask(n: usize, radius: usize) -> Vec<u8> {
        let center = n as i64 / 2;
        let r2 = (radius * radius) as i64;
        let mut mask = vec![0u8; n * n * n];
        for z in 0..n {
            for y in 0..n {
                for x in 0..n {
                    let dx = x as i64 - center;
                    let dy = y as i64 - center;
                    let dz = z as i64 - center;
                    if dx * dx + dy * dy + dz * dz <= r2 {
                        mask[x + y * n + z * n * n] = 1;
                    }
                }
            }
        }
        mask
    }

    #[test]
    fn test_vsharp_zero_field() {
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        let radii = vec![4.0, 3.0, 2.0];
        let (local, _) = vsharp(&field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, 0.05);

        for &val in local.iter() {
            assert!(val.abs() < 1e-10);
        }
    }

    #[test]
    fn test_vsharp_preserves_more_than_sharp() {
        let n = 16;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        // V-SHARP with multiple radii
        let radii = vec![5.0, 4.0, 3.0, 2.0];
        let (_, vsharp_mask) = vsharp(&field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, 0.05);

        // SHARP with single large radius
        let (_, sharp_mask) = crate::bgremove::sharp::sharp(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, 5.0, 0.05
        );

        let vsharp_count: usize = vsharp_mask.iter().map(|&m| m as usize).sum();
        let sharp_count: usize = sharp_mask.iter().map(|&m| m as usize).sum();

        // V-SHARP should preserve at least as many voxels as SHARP
        assert!(vsharp_count >= sharp_count,
            "V-SHARP {} should preserve at least as many as SHARP {}",
            vsharp_count, sharp_count);
    }

    #[test]
    fn test_vsharp_nonuniform_voxels() {
        let n = 8;
        let field = ramp_field(n * n * n);
        let mask = vec![1u8; n * n * n];

        // Anisotropic voxel sizes
        let radii = vec![4.0, 3.0, 2.0];
        let (local, final_mask) = vsharp(
            &field, &mask, n, n, n, 0.5, 1.0, 2.0, &radii, 0.05
        );

        for (i, &val) in local.iter().enumerate() {
            assert!(val.is_finite(), "V-SHARP nonuniform voxels: finite at index {}", i);
        }

        let mask_count: usize = final_mask.iter().map(|&m| m as usize).sum();
        assert!(mask_count > 0, "V-SHARP nonuniform: final mask should have some voxels");
    }

    #[test]
    fn test_vsharp_single_radius() {
        let n = 8;
        let field = ramp_field(n * n * n);
        let mask = vec![1u8; n * n * n];

        // Single radius should delegate to SHARP
        let radii = vec![3.0];
        let (local, final_mask) = vsharp(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, 0.05
        );

        for (i, &val) in local.iter().enumerate() {
            assert!(val.is_finite(), "V-SHARP single radius: finite at index {}", i);
        }

        let (sharp_local, sharp_mask) = crate::bgremove::sharp::sharp(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, 3.0, 0.05
        );

        for i in 0..n * n * n {
            assert!(
                (local[i] - sharp_local[i]).abs() < 1e-10,
                "Single-radius V-SHARP should match SHARP at index {}", i
            );
        }

        assert_eq!(final_mask, sharp_mask, "Single-radius V-SHARP mask should match SHARP mask");
    }

    #[test]
    fn test_vsharp_empty_radii() {
        // Empty radii should return zeros and the original mask
        let n = 8;
        let field = ramp_field(n * n * n);
        let mask = vec![1u8; n * n * n];

        let (local, returned_mask) = vsharp(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &[], 0.05
        );

        for &val in &local {
            assert_eq!(val, 0.0, "Empty radii should return zero local field");
        }
        assert_eq!(returned_mask, mask, "Empty radii should return original mask");
    }

    #[test]
    fn test_vsharp_larger_volume() {
        // 16x16x16 volume with spherical mask and a linear background field in z
        let n = 16;
        let field = z_ramp_field(n);
        let mask = spherical_mask(n, n / 3);

        let radii = vec![6.0, 4.0, 3.0, 2.0];
        let (local, final_mask) = vsharp(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, 0.05
        );

        assert_eq!(local.len(), n * n * n);
        for &val in &local {
            assert!(val.is_finite(), "V-SHARP larger volume values should be finite");
        }

        let mask_count: usize = final_mask.iter().map(|&m| m as usize).sum();
        assert!(mask_count > 0, "V-SHARP larger volume should have voxels in final mask");

        for i in 0..n * n * n {
            if final_mask[i] == 0 {
                assert_eq!(local[i], 0.0, "Outside final mask should be zero");
            }
        }
    }

    #[test]
    fn test_vsharp_with_progress() {
        let n = 8;
        let field = ramp_field(n * n * n);
        let mask = vec![1u8; n * n * n];

        let radii = vec![4.0, 3.0, 2.0];
        let mut progress_calls = Vec::new();
        let (local, _) = vsharp_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, 0.05,
            |idx, total| { progress_calls.push((idx, total)); }
        );

        assert_eq!(local.len(), n * n * n);
        // One 1-based report per radius.
        assert_eq!(progress_calls, vec![(1, 3), (2, 3), (3, 3)]);
        for &val in &local {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_vsharp_with_progress_matches_vsharp() {
        let n = 16;
        let field = z_ramp_field(n);
        let mask = spherical_mask(n, n / 3);
        let radii = vec![6.0, 4.0, 3.0, 2.0];

        let (plain_local, plain_mask) = vsharp(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, 0.05
        );
        let (progress_local, progress_mask) = vsharp_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, 0.05, |_, _| {}
        );

        assert_eq!(plain_mask, progress_mask, "Masks should match with and without progress");
        for i in 0..n * n * n {
            assert!(
                (plain_local[i] - progress_local[i]).abs() < 1e-12,
                "Local fields should match at index {}", i
            );
        }
    }

    #[test]
    fn test_vsharp_with_progress_single_radius() {
        let n = 8;
        let field = ramp_field(n * n * n);
        let mask = vec![1u8; n * n * n];

        let radii = vec![3.0];
        let mut progress_calls = Vec::new();
        let (local, _) = vsharp_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, 0.05,
            |idx, total| { progress_calls.push((idx, total)); }
        );

        assert_eq!(local.len(), n * n * n);
        assert_eq!(progress_calls, vec![(1, 1)], "Single radius reports exactly once");
        for &val in &local {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_vsharp_with_progress_empty_radii() {
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        let mut progress_calls = Vec::new();
        let (local, returned_mask) = vsharp_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &[], 0.05,
            |idx, total| { progress_calls.push((idx, total)); }
        );

        for &val in &local {
            assert_eq!(val, 0.0);
        }
        assert_eq!(returned_mask, mask);
        assert!(progress_calls.is_empty(), "Empty radii should not report progress");
    }

    #[test]
    fn test_vsharp_default_wrapper() {
        let n = 8;
        let field = ramp_field(n * n * n);
        let mask = vec![1u8; n * n * n];

        let (local, final_mask) = vsharp_default(&field, &mask, n, n, n, 1.0, 1.0, 1.0);

        assert_eq!(local.len(), n * n * n);
        for &val in &local {
            assert!(val.is_finite());
        }
        // May be 0 if volume is too small for default radii, but should not crash
        let count: usize = final_mask.iter().map(|&m| m as usize).sum();
        assert!(count <= n * n * n);
    }

    #[test]
    fn test_vsharp_unsorted_radii() {
        // Radii given in arbitrary order - should be sorted internally
        let n = 8;
        let field = ramp_field(n * n * n);
        let mask = vec![1u8; n * n * n];

        let radii_sorted = vec![4.0, 3.0, 2.0];
        let radii_unsorted = vec![2.0, 4.0, 3.0];

        let (local_sorted, mask_sorted) = vsharp(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii_sorted, 0.05
        );
        let (local_unsorted, mask_unsorted) = vsharp(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii_unsorted, 0.05
        );

        assert_eq!(mask_sorted, mask_unsorted, "Sorted and unsorted radii should give same mask");
        for i in 0..n * n * n {
            assert!(
                (local_sorted[i] - local_unsorted[i]).abs() < 1e-10,
                "Results should match at index {}",
                i
            );
        }
    }
}
607}