1use num_complex::Complex64;
16use crate::fft::{fft3d, ifft3d};
17use crate::kernels::smv::smv_kernel;
18
19pub fn smv(
33 field: &[f64],
34 mask: &[u8],
35 nx: usize, ny: usize, nz: usize,
36 vsx: f64, vsy: f64, vsz: f64,
37 radius: f64,
38) -> (Vec<f64>, Vec<u8>) {
39 let n_total = nx * ny * nz;
40
41 let s_kernel = smv_kernel(nx, ny, nz, vsx, vsy, vsz, radius);
43
44 let mut s_complex: Vec<Complex64> = s_kernel.iter()
46 .map(|&x| Complex64::new(x, 0.0))
47 .collect();
48 fft3d(&mut s_complex, nx, ny, nz);
49 let s_fft = s_complex;
50
51 let mask_f64: Vec<f64> = mask.iter().map(|&m| m as f64).collect();
53 let mut mask_complex: Vec<Complex64> = mask_f64.iter()
54 .map(|&x| Complex64::new(x, 0.0))
55 .collect();
56
57 fft3d(&mut mask_complex, nx, ny, nz);
58
59 for i in 0..n_total {
61 mask_complex[i] *= s_fft[i].re;
62 }
63
64 ifft3d(&mut mask_complex, nx, ny, nz);
65
66 let delta = 1.0 - 1e-10;
68 let eroded_mask: Vec<u8> = mask_complex.iter()
69 .map(|c| if c.re > delta { 1 } else { 0 })
70 .collect();
71
72 let mut field_complex: Vec<Complex64> = field.iter()
74 .map(|&x| Complex64::new(x, 0.0))
75 .collect();
76 fft3d(&mut field_complex, nx, ny, nz);
77
78 for i in 0..n_total {
80 field_complex[i] *= s_fft[i].re;
81 }
82
83 ifft3d(&mut field_complex, nx, ny, nz);
84
85 let local_field: Vec<f64> = field.iter()
87 .zip(field_complex.iter())
88 .enumerate()
89 .map(|(i, (&f, smv_f))| {
90 if eroded_mask[i] == 1 {
91 f - smv_f.re
92 } else {
93 0.0
94 }
95 })
96 .collect();
97
98 (local_field, eroded_mask)
99}
100
101pub fn smv_default(
103 field: &[f64],
104 mask: &[u8],
105 nx: usize, ny: usize, nz: usize,
106 vsx: f64, vsy: f64, vsz: f64,
107) -> (Vec<f64>, Vec<u8>) {
108 let radius = 5.0;
110 smv(field, mask, nx, ny, nz, vsx, vsy, vsz, radius)
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Edge length of the cubic test volume.
    const GRID: usize = 16;

    /// Runs `smv` on a `GRID`³ volume. It uses unit voxels, an all-ones mask and a radius of 2.
    fn run_with_full_mask(field: &[f64]) -> (Vec<f64>, Vec<u8>) {
        let full_mask = vec![1u8; GRID * GRID * GRID];
        smv(field, &full_mask, GRID, GRID, GRID, 1.0, 1.0, 1.0, 2.0)
    }

    #[test]
    fn test_smv_zero_field() {
        let zeros = vec![0.0; GRID * GRID * GRID];

        let (local, _eroded) = run_with_full_mask(&zeros);

        for val in local {
            assert!(val.abs() < 1e-8, "Zero field should give zero local field, got {}", val);
        }
    }

    #[test]
    fn test_smv_finite() {
        let ramp: Vec<f64> = (0..GRID * GRID * GRID).map(|k| k as f64 * 0.01).collect();

        let (local, eroded) = run_with_full_mask(&ramp);

        for (idx, val) in local.iter().enumerate() {
            assert!(val.is_finite(), "Local field should be finite at index {}", idx);
        }

        let kept = eroded.iter().filter(|&&m| m != 0).count();
        assert!(kept > 0, "Eroded mask should have some voxels");
    }
}