// qsm_core/kernels/smv.rs

//! Spherical Mean Value (SMV) kernel
//!
//! Creates a binary sphere kernel used for background field removal methods
//! like SHARP and V-SHARP.

/// Generate a Spherical Mean Value (SMV) kernel in image space.
///
/// Creates a binary sphere of the given physical radius, normalized so that
/// its entries sum to 1. The sphere is centered at voxel `(0, 0, 0)` with
/// periodic wraparound for FFT compatibility: along an axis of length `n`,
/// indices `i > n / 2` are treated as the negative offset `i - n`.
///
/// # Arguments
/// * `nx`, `ny`, `nz` - Array dimensions
/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm
/// * `radius` - Sphere radius in mm; a voxel is included when the distance
///   from its center to the origin is `<= radius`
///
/// # Returns
/// Flattened kernel of length `nx * ny * nz`. Voxel `(i, j, k)` is stored at
/// index `i + j * nx + k * nx * ny` (`i` varies fastest: column-major /
/// Fortran order for an `(nx, ny, nz)` array, equivalently C order for an
/// `(nz, ny, nx)` array). Voxels inside the sphere hold `1 / count`, all
/// others hold `0`. If no voxel qualifies (an empty array or a NaN radius),
/// the result is all zeros.
///
/// # Panics
/// Panics if `nx * ny * nz` overflows `usize`.
pub fn smv_kernel(
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
    radius: f64,
) -> Vec<f64> {
    let n_total = nx
        .checked_mul(ny)
        .and_then(|v| v.checked_mul(nz))
        .expect("smv_kernel: nx * ny * nz overflows usize");
    let r_squared = radius * radius;

    // Squared physical distance from the origin along one axis, per index,
    // with wraparound so the upper half of the axis maps to negative offsets.
    let axis_dist_sq = |n: usize, voxel_size: f64| -> Vec<f64> {
        (0..n)
            .map(|i| {
                let offset = if i <= n / 2 { i as f64 } else { (i as i64 - n as i64) as f64 };
                let d = offset * voxel_size;
                d * d
            })
            .collect()
    };
    let dx_sq = axis_dist_sq(nx, vsx);
    let dy_sq = axis_dist_sq(ny, vsy);
    let dz_sq = axis_dist_sq(nz, vsz);

    // Loop nesting (k outermost, i innermost) matches the storage order
    // `i + j * nx + k * nx * ny`, so pushing in sequence fills each index.
    let mut s = Vec::with_capacity(n_total);
    // Integer counter: exact for any volume size, unlike a float accumulator.
    let mut count: usize = 0;
    for &dz2 in &dz_sq {
        for &dy2 in &dy_sq {
            for &dx2 in &dx_sq {
                if dx2 + dy2 + dz2 <= r_squared {
                    s.push(1.0);
                    count += 1;
                } else {
                    s.push(0.0);
                }
            }
        }
    }

    // Normalize so the kernel sums to 1.
    if count > 0 {
        let norm = 1.0 / count as f64;
        for val in s.iter_mut() {
            *val *= norm;
        }
    }

    s
}

// ============================================================================
// F32 (Single Precision) SMV Kernel Functions
// ============================================================================

/// Generate a Spherical Mean Value (SMV) kernel in image space (f32 version
/// for WASM performance).
///
/// Creates a binary sphere of the given physical radius, normalized so that
/// its entries sum to 1. The sphere is centered at voxel `(0, 0, 0)` with
/// periodic wraparound for FFT compatibility: along an axis of length `n`,
/// indices `i > n / 2` are treated as the negative offset `i - n`.
/// Distances are computed in `f32`.
///
/// # Arguments
/// * `nx`, `ny`, `nz` - Array dimensions
/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm
/// * `radius` - Sphere radius in mm; a voxel is included when the distance
///   from its center to the origin is `<= radius`
///
/// # Returns
/// Flattened kernel of length `nx * ny * nz`. Voxel `(i, j, k)` is stored at
/// index `i + j * nx + k * nx * ny` (`i` varies fastest: column-major /
/// Fortran order for an `(nx, ny, nz)` array, equivalently C order for an
/// `(nz, ny, nx)` array). Voxels inside the sphere hold `1 / count`, all
/// others hold `0`. If no voxel qualifies (an empty array or a NaN radius),
/// the result is all zeros.
///
/// # Panics
/// Panics if `nx * ny * nz` overflows `usize`.
pub fn smv_kernel_f32(
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
    radius: f32,
) -> Vec<f32> {
    let n_total = nx
        .checked_mul(ny)
        .and_then(|v| v.checked_mul(nz))
        .expect("smv_kernel_f32: nx * ny * nz overflows usize");
    let r_squared = radius * radius;

    // Squared physical distance from the origin along one axis, per index,
    // with wraparound so the upper half of the axis maps to negative offsets.
    let axis_dist_sq = |n: usize, voxel_size: f32| -> Vec<f32> {
        (0..n)
            .map(|i| {
                let offset = if i <= n / 2 { i as f32 } else { (i as i64 - n as i64) as f32 };
                let d = offset * voxel_size;
                d * d
            })
            .collect()
    };
    let dx_sq = axis_dist_sq(nx, vsx);
    let dy_sq = axis_dist_sq(ny, vsy);
    let dz_sq = axis_dist_sq(nz, vsz);

    // Loop nesting (k outermost, i innermost) matches the storage order
    // `i + j * nx + k * nx * ny`, so pushing in sequence fills each index.
    let mut s = Vec::with_capacity(n_total);
    // Integer counter: an f32 accumulator stops incrementing at 2^24
    // (16_777_216), which would break normalization for very large spheres.
    let mut count: usize = 0;
    for &dz2 in &dz_sq {
        for &dy2 in &dy_sq {
            for &dx2 in &dx_sq {
                if dx2 + dy2 + dz2 <= r_squared {
                    s.push(1.0f32);
                    count += 1;
                } else {
                    s.push(0.0f32);
                }
            }
        }
    }

    // Normalize so the kernel sums to 1.
    if count > 0 {
        let norm = 1.0 / count as f32;
        for val in s.iter_mut() {
            *val *= norm;
        }
    }

    s
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Flat indices of the strictly positive kernel entries, in storage order.
    fn nonzero_indices<T: Copy + PartialOrd + Default>(s: &[T]) -> Vec<usize> {
        s.iter()
            .enumerate()
            .filter(|&(_, &v)| v > T::default())
            .map(|(i, _)| i)
            .collect()
    }

    #[test]
    fn test_smv_kernel_normalization() {
        let s = smv_kernel(16, 16, 16, 1.0, 1.0, 1.0, 3.0);
        let sum: f64 = s.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10, "SMV kernel should sum to 1, got {}", sum);
    }

    #[test]
    fn test_smv_kernel_center() {
        let s = smv_kernel(16, 16, 16, 1.0, 1.0, 1.0, 2.0);
        // Center voxel (0,0,0) should be non-zero
        assert!(s[0] > 0.0, "Center voxel should be in sphere");
    }

    #[test]
    fn test_smv_kernel_radius() {
        // With radius 0.5 and unit voxels, the nearest neighbours (distance 1)
        // fall outside the sphere, so only the center voxel is included.
        let s = smv_kernel(8, 8, 8, 1.0, 1.0, 1.0, 0.5);
        let count: usize = s.iter().filter(|&&v| v > 0.0).count();
        assert_eq!(count, 1, "Radius 0.5 should only include center");
    }

    #[test]
    fn test_smv_kernel_layout_and_wraparound() {
        // Radius 1 with unit voxels: the center plus its six face neighbours.
        // Negative offsets wrap to the end of each axis; index = i + j*nx + k*nx*ny.
        let s = smv_kernel(8, 8, 8, 1.0, 1.0, 1.0, 1.0);
        assert_eq!(nonzero_indices(&s), vec![0, 1, 7, 8, 56, 64, 448]);
    }

    #[test]
    fn test_smv_kernel_anisotropic_voxels() {
        // z neighbours sit 2 mm away and xy diagonals sqrt(2) mm away, both > 1.2 mm.
        let s = smv_kernel(8, 8, 8, 1.0, 1.0, 2.0, 1.2);
        assert_eq!(nonzero_indices(&s), vec![0, 1, 7, 8, 56]);
    }

    #[test]
    fn test_smv_kernel_degenerate_inputs() {
        assert!(smv_kernel(0, 4, 4, 1.0, 1.0, 1.0, 2.0).is_empty());
        let nan = smv_kernel(4, 4, 4, 1.0, 1.0, 1.0, f64::NAN);
        assert_eq!(nan.len(), 64);
        assert!(nan.iter().all(|&v| v == 0.0), "NaN radius should yield an all-zero kernel");
    }

    #[test]
    fn test_smv_kernel_f32_matches_f64() {
        let s64 = smv_kernel(16, 16, 16, 1.0, 1.0, 1.0, 3.0);
        let s32 = smv_kernel_f32(16, 16, 16, 1.0, 1.0, 1.0, 3.0);
        assert_eq!(nonzero_indices(&s64), nonzero_indices(&s32));
        let sum: f32 = s32.iter().sum();
        assert!((sum - 1.0).abs() < 1e-4, "f32 SMV kernel should sum to 1, got {}", sum);
    }
}