Skip to main content

qsm_core/bgremove/
ismv.rs

1//! Iterative Spherical Mean Value (iSMV) background field removal
2//!
3//! Iterative approach that avoids mask erosion by iteratively
4//! correcting boundary values.
5//!
6//! Reference:
7//! Wen, Y., Zhou, D., Liu, T., Spincemaille, P., Wang, Y. (2014).
8//! "An iterative spherical mean value method for background field removal in MRI."
9//! Magnetic Resonance in Medicine, 72(4):1065-1071. https://doi.org/10.1002/mrm.24998
10//!
11//! Reference implementation: https://github.com/kamesy/QSM.jl
12
13use num_complex::Complex64;
14use crate::fft::{fft3d, ifft3d};
15use crate::kernels::smv::smv_kernel;
16
/// Tunable parameters for iSMV background field removal.
///
/// The fields correspond to the trailing arguments of `ismv`: `tol` and
/// `max_iter` are passed through unchanged, while the SMV kernel radius (in mm)
/// is derived from `radius_factor` via [`IsmvParams::kernel_radius`].
///
/// The [`Default`] values are `tol = 1e-3`, `max_iter = 500` and
/// `radius_factor = 2.0`, the same values `ismv_default` uses.
#[derive(Clone, Debug, PartialEq)]
pub struct IsmvParams {
    /// Convergence tolerance. When passed as `ismv`'s `tol`, iteration stops once
    /// the change between successive iterates drops to `tol` times the norm of
    /// the input field restricted to the eroded mask.
    pub tol: f64,
    /// Maximum iterations
    pub max_iter: usize,
    /// Kernel radius factor (multiplied by max voxel size; default: 2.0)
    pub radius_factor: f64,
}

impl IsmvParams {
    /// SMV kernel radius in mm for the given voxel sizes (mm):
    /// `radius_factor * max(vsx, vsy, vsz)`.
    pub fn kernel_radius(&self, vsx: f64, vsy: f64, vsz: f64) -> f64 {
        self.radius_factor * vsx.max(vsy).max(vsz)
    }
}

impl Default for IsmvParams {
    fn default() -> Self {
        Self { tol: 1e-3, max_iter: 500, radius_factor: 2.0 }
    }
}
44
45/// # Returns
46/// Tuple of (local field, eroded mask)
47pub fn ismv(
48    field: &[f64],
49    mask: &[u8],
50    nx: usize, ny: usize, nz: usize,
51    vsx: f64, vsy: f64, vsz: f64,
52    radius: f64,
53    tol: f64,
54    max_iter: usize,
55) -> (Vec<f64>, Vec<u8>) {
56    let n_total = nx * ny * nz;
57
58    // Generate SMV kernel
59    let smv = smv_kernel(nx, ny, nz, vsx, vsy, vsz, radius);
60
61    // FFT of SMV kernel
62    let mut smv_complex: Vec<Complex64> = smv.iter()
63        .map(|&x| Complex64::new(x, 0.0))
64        .collect();
65    fft3d(&mut smv_complex, nx, ny, nz);
66    let smv_fft = smv_complex;
67
68    // Convert mask to f64
69    let m0: Vec<f64> = mask.iter()
70        .map(|&m| if m != 0 { 1.0 } else { 0.0 })
71        .collect();
72
73    // Erode mask using SMV
74    let eroded_mask = erode_mask(&m0, &smv_fft, nx, ny, nz);
75
76    // Boundary mask: original mask minus eroded mask
77    let boundary: Vec<f64> = m0.iter()
78        .zip(eroded_mask.iter())
79        .map(|(&m, &e)| m - e)
80        .collect();
81
82    // Initialize: f = field
83    let mut f: Vec<f64> = field.to_vec();
84
85    // f0 = eroded_mask * field (for residual calculation)
86    let mut f0: Vec<f64> = field.iter()
87        .zip(eroded_mask.iter())
88        .map(|(&fi, &m)| fi * m)
89        .collect();
90
91    // Boundary correction: bc = boundary * field
92    let bc: Vec<f64> = field.iter()
93        .zip(boundary.iter())
94        .map(|(&fi, &b)| fi * b)
95        .collect();
96
97    // Initial residual norm
98    let mut nr = vec_norm(&f0);
99    let eps = tol * nr;
100
101    // iSMV iterations
102    for _iter in 0..max_iter {
103        if nr <= eps {
104            break;
105        }
106
107        // f = SMV(f)
108        let mut f_complex: Vec<Complex64> = f.iter()
109            .map(|&x| Complex64::new(x, 0.0))
110            .collect();
111
112        fft3d(&mut f_complex, nx, ny, nz);
113
114        for i in 0..n_total {
115            f_complex[i] *= smv_fft[i];
116        }
117
118        ifft3d(&mut f_complex, nx, ny, nz);
119
120        // f = eroded_mask * f + bc
121        for i in 0..n_total {
122            f[i] = eroded_mask[i] * f_complex[i].re + bc[i];
123        }
124
125        // Compute residual: ||f0 - f||
126        let mut residual_sq = 0.0;
127        for i in 0..n_total {
128            let diff = f0[i] - f[i];
129            residual_sq += diff * diff;
130            f0[i] = f[i];
131        }
132        nr = residual_sq.sqrt();
133    }
134
135    // Compute local field: m * (field - f)
136    let mut local_field = vec![0.0; n_total];
137    for i in 0..n_total {
138        if mask[i] != 0 {
139            local_field[i] = field[i] - f[i];
140        }
141    }
142
143    // Convert eroded mask to u8
144    let eroded_mask_u8: Vec<u8> = eroded_mask.iter()
145        .map(|&m| if m > 0.5 { 1 } else { 0 })
146        .collect();
147
148    (local_field, eroded_mask_u8)
149}
150
151/// Erode mask using SMV convolution
152fn erode_mask(mask: &[f64], smv_fft: &[Complex64], nx: usize, ny: usize, nz: usize) -> Vec<f64> {
153    let n_total = nx * ny * nz;
154    let delta = 1.0 - 1e-10;
155
156    let mut m_complex: Vec<Complex64> = mask.iter()
157        .map(|&x| Complex64::new(x, 0.0))
158        .collect();
159
160    fft3d(&mut m_complex, nx, ny, nz);
161
162    for i in 0..n_total {
163        m_complex[i] *= smv_fft[i];
164    }
165
166    ifft3d(&mut m_complex, nx, ny, nz);
167
168    // Threshold: eroded where SMV(mask) > delta
169    m_complex.iter()
170        .map(|c| if c.re > delta { 1.0 } else { 0.0 })
171        .collect()
172}
173
/// Euclidean (L2) norm of a slice: square root of the sum of squared entries.
fn vec_norm(v: &[f64]) -> f64 {
    let squares = v.iter().copied().map(|component| component * component);
    f64::sqrt(squares.sum())
}
178
179/// iSMV with progress callback
180///
181/// Same as `ismv` but calls `progress_callback(iteration, max_iter)` each iteration.
182pub fn ismv_with_progress<F>(
183    field: &[f64],
184    mask: &[u8],
185    nx: usize, ny: usize, nz: usize,
186    vsx: f64, vsy: f64, vsz: f64,
187    radius: f64,
188    tol: f64,
189    max_iter: usize,
190    mut progress_callback: F,
191) -> (Vec<f64>, Vec<u8>)
192where
193    F: FnMut(usize, usize),
194{
195    let n_total = nx * ny * nz;
196
197    // Generate SMV kernel
198    let smv = smv_kernel(nx, ny, nz, vsx, vsy, vsz, radius);
199
200    // FFT of SMV kernel
201    let mut smv_complex: Vec<Complex64> = smv.iter()
202        .map(|&x| Complex64::new(x, 0.0))
203        .collect();
204    fft3d(&mut smv_complex, nx, ny, nz);
205    let smv_fft = smv_complex;
206
207    // Convert mask to f64
208    let m0: Vec<f64> = mask.iter()
209        .map(|&m| if m != 0 { 1.0 } else { 0.0 })
210        .collect();
211
212    // Erode mask using SMV
213    let eroded_mask = erode_mask(&m0, &smv_fft, nx, ny, nz);
214
215    // Boundary mask: original mask minus eroded mask
216    let boundary: Vec<f64> = m0.iter()
217        .zip(eroded_mask.iter())
218        .map(|(&m, &e)| m - e)
219        .collect();
220
221    // Initialize: f = field
222    let mut f: Vec<f64> = field.to_vec();
223
224    // f0 = eroded_mask * field (for residual calculation)
225    let mut f0: Vec<f64> = field.iter()
226        .zip(eroded_mask.iter())
227        .map(|(&fi, &m)| fi * m)
228        .collect();
229
230    // Boundary correction: bc = boundary * field
231    let bc: Vec<f64> = field.iter()
232        .zip(boundary.iter())
233        .map(|(&fi, &b)| fi * b)
234        .collect();
235
236    // Initial residual norm
237    let mut nr = vec_norm(&f0);
238    let eps = tol * nr;
239
240    // iSMV iterations
241    for iter in 0..max_iter {
242        // Report progress
243        progress_callback(iter + 1, max_iter);
244
245        if nr <= eps {
246            progress_callback(iter + 1, iter + 1);
247            break;
248        }
249
250        // f = SMV(f)
251        let mut f_complex: Vec<Complex64> = f.iter()
252            .map(|&x| Complex64::new(x, 0.0))
253            .collect();
254
255        fft3d(&mut f_complex, nx, ny, nz);
256
257        for i in 0..n_total {
258            f_complex[i] *= smv_fft[i];
259        }
260
261        ifft3d(&mut f_complex, nx, ny, nz);
262
263        // f = eroded_mask * f + bc
264        for i in 0..n_total {
265            f[i] = eroded_mask[i] * f_complex[i].re + bc[i];
266        }
267
268        // Compute residual: ||f0 - f||
269        let mut residual_sq = 0.0;
270        for i in 0..n_total {
271            let diff = f0[i] - f[i];
272            residual_sq += diff * diff;
273            f0[i] = f[i];
274        }
275        nr = residual_sq.sqrt();
276    }
277
278    // Compute local field: m * (field - f)
279    let mut local_field = vec![0.0; n_total];
280    for i in 0..n_total {
281        if mask[i] != 0 {
282            local_field[i] = field[i] - f[i];
283        }
284    }
285
286    // Convert eroded mask to u8
287    let eroded_mask_u8: Vec<u8> = eroded_mask.iter()
288        .map(|&m| if m > 0.5 { 1 } else { 0 })
289        .collect();
290
291    (local_field, eroded_mask_u8)
292}
293
294/// iSMV with default parameters
295pub fn ismv_default(
296    field: &[f64],
297    mask: &[u8],
298    nx: usize, ny: usize, nz: usize,
299    vsx: f64, vsy: f64, vsz: f64,
300) -> (Vec<f64>, Vec<u8>) {
301    let radius = 2.0 * vsx.max(vsy).max(vsz);
302    ismv(
303        field, mask, nx, ny, nz, vsx, vsy, vsz,
304        radius,
305        1e-3,  // tol
306        500    // max_iter
307    )
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Cubic 0/1 mask with a one-voxel zero border.
    fn bordered_mask(n: usize) -> Vec<u8> {
        let mut mask = vec![0u8; n * n * n];
        for z in 1..(n - 1) {
            for y in 1..(n - 1) {
                for x in 1..(n - 1) {
                    mask[x + y * n + z * n * n] = 1;
                }
            }
        }
        mask
    }

    #[test]
    fn test_ismv_zero_field() {
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        let (local, eroded) = ismv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            2.0, 1e-3, 10
        );

        for &val in local.iter() {
            assert!(val.abs() < 1e-10, "Zero field should give zero local field");
        }

        // Some voxels should be in eroded mask
        let eroded_count: usize = eroded.iter().map(|&m| m as usize).sum();
        assert!(eroded_count > 0, "Eroded mask should have some voxels");
    }

    #[test]
    fn test_ismv_finite() {
        let n = 8;
        let field: Vec<f64> = (0..n*n*n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];

        let (local, _eroded) = ismv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            2.0, 1e-3, 20
        );

        for (i, &val) in local.iter().enumerate() {
            assert!(val.is_finite(), "Local field should be finite at index {}", i);
        }
    }

    #[test]
    fn test_ismv_preserves_interior() {
        // iSMV should preserve some of the mask interior
        let n = 16;
        let field = vec![0.1; n * n * n];

        // Create a spherical mask
        let mut mask = vec![0u8; n * n * n];
        let center = n / 2;
        let radius = n / 3;

        for i in 0..n {
            for j in 0..n {
                for k in 0..n {
                    let di = (i as i32) - (center as i32);
                    let dj = (j as i32) - (center as i32);
                    let dk = (k as i32) - (center as i32);
                    if di*di + dj*dj + dk*dk <= (radius * radius) as i32 {
                        mask[i * n * n + j * n + k] = 1;
                    }
                }
            }
        }

        let mask_count: usize = mask.iter().map(|&m| m as usize).sum();

        // Use small radius for less erosion
        let (_, eroded) = ismv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            1.5, 1e-3, 50
        );

        let eroded_count: usize = eroded.iter().map(|&m| m as usize).sum();

        // Eroded mask must not be larger than the original
        assert!(eroded_count <= mask_count, "Eroded mask should not be larger than original");
        // Should preserve at least some interior voxels
        assert!(eroded_count > 0, "Eroded mask should have some voxels");
    }

    #[test]
    fn test_ismv_convergence() {
        let n = 8;
        let field: Vec<f64> = (0..n*n*n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];

        // Run with more iterations and tight tolerance
        let (local_many, _) = ismv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            2.0, 1e-6, 100
        );

        // Run with fewer iterations
        let (local_few, _) = ismv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            2.0, 1e-6, 5
        );

        // Both should be finite
        for (i, &val) in local_many.iter().enumerate() {
            assert!(val.is_finite(), "iSMV many iters: finite at index {}", i);
        }
        for (i, &val) in local_few.iter().enumerate() {
            assert!(val.is_finite(), "iSMV few iters: finite at index {}", i);
        }

        // More iterations may give a different (more converged) result,
        // or the same one if already converged
        let diff_norm: f64 = local_many.iter()
            .zip(local_few.iter())
            .map(|(&a, &b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt();

        // The difference should be finite (no NaN/Inf)
        assert!(diff_norm.is_finite(), "Difference between runs should be finite");
    }

    #[test]
    fn test_ismv_different_radius() {
        let n = 8;
        let field: Vec<f64> = (0..n*n*n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];

        // Small radius
        let (local_small, eroded_small) = ismv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            1.5, 1e-3, 20
        );

        // Larger radius
        let (local_large, eroded_large) = ismv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            3.0, 1e-3, 20
        );

        // Both should produce finite results
        for (i, &val) in local_small.iter().enumerate() {
            assert!(val.is_finite(), "iSMV small radius: finite at index {}", i);
        }
        for (i, &val) in local_large.iter().enumerate() {
            assert!(val.is_finite(), "iSMV large radius: finite at index {}", i);
        }

        // Larger radius should erode at least as much
        let small_count: usize = eroded_small.iter().map(|&m| m as usize).sum();
        let large_count: usize = eroded_large.iter().map(|&m| m as usize).sum();
        assert!(
            large_count <= small_count,
            "Larger radius should erode at least as much: large={}, small={}",
            large_count, small_count
        );
    }

    #[test]
    fn test_ismv_larger_volume() {
        // Test with 16x16x16 volume and spherical mask
        let n = 16;

        // Create a field with a linear background
        let mut field = vec![0.0; n * n * n];
        for z in 0..n {
            for y in 0..n {
                for x in 0..n {
                    field[x + y * n + z * n * n] = (z as f64) * 0.1;
                }
            }
        }

        // Spherical mask
        let mut mask = vec![0u8; n * n * n];
        let center = n / 2;
        let radius = n / 3;
        for z in 0..n {
            for y in 0..n {
                for x in 0..n {
                    let dx = (x as i32) - (center as i32);
                    let dy = (y as i32) - (center as i32);
                    let dz = (z as i32) - (center as i32);
                    if dx * dx + dy * dy + dz * dz <= (radius * radius) as i32 {
                        mask[x + y * n + z * n * n] = 1;
                    }
                }
            }
        }

        let (local, eroded) = ismv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            2.0, 1e-3, 50
        );

        assert_eq!(local.len(), n * n * n);
        for &val in &local {
            assert!(val.is_finite());
        }

        // Eroded mask should be non-empty
        let eroded_count: usize = eroded.iter().map(|&m| m as usize).sum();
        assert!(eroded_count > 0, "Eroded mask should have some voxels");

        // Outside original mask should be zero
        for i in 0..n * n * n {
            if mask[i] == 0 {
                assert_eq!(local[i], 0.0, "Outside mask should be zero");
            }
        }
    }

    #[test]
    fn test_ismv_with_progress() {
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];

        let mut progress_calls = Vec::new();
        let (local, _) = ismv_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            2.0, 1e-3, 20,
            |iter, max| { progress_calls.push((iter, max)); }
        );

        assert_eq!(local.len(), n * n * n);
        assert!(!progress_calls.is_empty(), "Progress callback should be called");
        for &val in &local {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_ismv_progress_contract() {
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];
        let max_iter = 20;

        let mut calls = Vec::new();
        let _ = ismv_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            2.0, 1e-3, max_iter,
            |iter, max| calls.push((iter, max))
        );

        // Every report is 1-based and never exceeds its own total or max_iter
        for &(iter, max) in &calls {
            assert!(iter >= 1 && iter <= max, "bad progress report ({}, {})", iter, max);
            assert!(max <= max_iter, "total {} exceeds max_iter {}", max, max_iter);
        }
        // The final report always signals completion (iter == total)
        let &(last_iter, last_max) = calls.last().expect("callback should be called");
        assert_eq!(last_iter, last_max, "final progress call should mark completion");

        // With zero iterations the callback is never invoked
        let mut count = 0usize;
        let _ = ismv_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            2.0, 1e-3, 0,
            |_, _| count += 1
        );
        assert_eq!(count, 0, "max_iter = 0 should not report progress");
    }

    #[test]
    fn test_ismv_matches_with_progress() {
        // ismv and ismv_with_progress must produce identical results
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();
        let mask = bordered_mask(n);

        let (local_a, eroded_a) = ismv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            1.5, 1e-3, 30
        );
        let (local_b, eroded_b) = ismv_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            1.5, 1e-3, 30,
            |_, _| {}
        );

        assert_eq!(local_a, local_b, "local fields should match");
        assert_eq!(eroded_a, eroded_b, "eroded masks should match");
    }

    #[test]
    fn test_ismv_boundary_local_field_is_zero() {
        // The boundary shell (mask minus eroded mask) is pinned to the input field,
        // so the local field there is exactly zero.
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();
        let mask = bordered_mask(n);

        let (local, eroded) = ismv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            1.5, 1e-3, 30
        );

        for i in 0..n * n * n {
            if mask[i] != 0 && eroded[i] == 0 {
                assert_eq!(local[i], 0.0, "Boundary local field should be zero at index {}", i);
            }
        }
    }

    #[test]
    fn test_ismv_default_wrapper() {
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];

        let (local, eroded) = ismv_default(&field, &mask, n, n, n, 1.0, 1.0, 1.0);

        assert_eq!(local.len(), n * n * n);
        for &val in &local {
            assert!(val.is_finite());
        }
        let eroded_count: usize = eroded.iter().map(|&m| m as usize).sum();
        assert!(eroded_count > 0, "Default iSMV should produce non-empty eroded mask");
    }

    #[test]
    fn test_ismv_anisotropic_voxels() {
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];

        // Anisotropic voxel sizes
        let (local, eroded) = ismv(
            &field, &mask, n, n, n, 0.5, 1.0, 2.0,
            3.0, 1e-3, 20
        );

        for &val in &local {
            assert!(val.is_finite());
        }
        // Sanity bound only: the eroded mask cannot exceed the volume
        let count: usize = eroded.iter().map(|&m| m as usize).sum();
        assert!(count <= n * n * n);
    }

    #[test]
    fn test_ismv_tight_convergence() {
        // Test tight convergence to ensure the convergence check branch is hit
        let n = 8;
        let field = vec![0.5; n * n * n]; // Constant field
        let mask = vec![1u8; n * n * n];

        let (local, _) = ismv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            2.0, 1e-12, 200 // Very tight tolerance, many iterations allowed
        );

        for &val in &local {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_ismv_with_background_mask() {
        // Test where some voxels are outside the mask
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();

        // Mask with zero border
        let mask = bordered_mask(n);

        let (local, eroded) = ismv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            1.5, 1e-3, 30
        );

        for &val in &local {
            assert!(val.is_finite());
        }

        // Outside mask should be zero
        for i in 0..n * n * n {
            if mask[i] == 0 {
                assert_eq!(local[i], 0.0, "Outside mask should be zero at index {}", i);
            }
        }

        // Eroded mask should be subset of original mask
        for i in 0..n * n * n {
            if eroded[i] != 0 {
                assert_eq!(mask[i], 1, "Eroded voxel must be inside original mask");
            }
        }
    }
}
626                assert_eq!(mask[i], 1, "Eroded voxel must be inside original mask");
627            }
628        }
629    }
630}