1use num_complex::Complex64;
15use crate::fft::{fft3d, ifft3d};
16use crate::kernels::smv::smv_kernel;
17
/// Tunable parameters for SHARP processing.
///
/// The [`Default`] values (`threshold = 0.05`, `radius_factor = 18.0`) are the
/// same constants that `sharp_default` applies.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SharpParams {
    /// Truncation threshold. Presumably the value handed to `sharp` as its
    /// `threshold` argument; confirm against the call sites.
    pub threshold: f64,
    /// Kernel radius expressed as a multiple of a voxel size. `sharp_default`
    /// multiplies the same factor by the smallest voxel dimension.
    pub radius_factor: f64,
}
39
40impl Default for SharpParams {
41 fn default() -> Self {
42 Self {
43 threshold: 0.05,
44 radius_factor: 18.0,
45 }
46 }
47}
48
49pub fn sharp(
52 field: &[f64],
53 mask: &[u8],
54 nx: usize, ny: usize, nz: usize,
55 vsx: f64, vsy: f64, vsz: f64,
56 radius: f64,
57 threshold: f64,
58) -> (Vec<f64>, Vec<u8>) {
59 let n_total = nx * ny * nz;
60
61 let s_kernel = smv_kernel(nx, ny, nz, vsx, vsy, vsz, radius);
63
64 let mut s_complex: Vec<Complex64> = s_kernel.iter()
66 .map(|&x| Complex64::new(x, 0.0))
67 .collect();
68 fft3d(&mut s_complex, nx, ny, nz);
69
70 let s_fft: Vec<f64> = s_complex.iter().map(|c| c.re).collect();
73
74 let mask_f64: Vec<f64> = mask.iter().map(|&m| m as f64).collect();
77 let mut mask_complex: Vec<Complex64> = mask_f64.iter()
78 .map(|&x| Complex64::new(x, 0.0))
79 .collect();
80
81 fft3d(&mut mask_complex, nx, ny, nz);
82
83 for i in 0..n_total {
85 mask_complex[i] *= s_complex[i].re; }
87
88 ifft3d(&mut mask_complex, nx, ny, nz);
89
90 let delta = 1.0 - 1e-7_f64.sqrt(); let eroded_mask: Vec<u8> = mask_complex.iter()
93 .map(|c| if c.re > delta { 1 } else { 0 })
94 .collect();
95
96 let mut field_complex: Vec<Complex64> = field.iter()
104 .map(|&x| Complex64::new(x, 0.0))
105 .collect();
106 fft3d(&mut field_complex, nx, ny, nz);
107
108 for i in 0..n_total {
110 field_complex[i] *= 1.0 - s_fft[i];
111 }
112
113 ifft3d(&mut field_complex, nx, ny, nz);
115
116 for i in 0..n_total {
118 if eroded_mask[i] == 0 {
119 field_complex[i] = Complex64::new(0.0, 0.0);
120 }
121 }
122
123 fft3d(&mut field_complex, nx, ny, nz);
125
126 for i in 0..n_total {
128 let one_minus_s = 1.0 - s_fft[i];
129 if one_minus_s.abs() < threshold {
130 field_complex[i] = Complex64::new(0.0, 0.0);
131 } else {
132 field_complex[i] /= one_minus_s;
133 }
134 }
135
136 ifft3d(&mut field_complex, nx, ny, nz);
138
139 let local_field: Vec<f64> = field_complex.iter()
141 .enumerate()
142 .map(|(i, c)| if eroded_mask[i] == 1 { c.re } else { 0.0 })
143 .collect();
144
145 (local_field, eroded_mask)
146}
147
148pub fn sharp_default(
150 field: &[f64],
151 mask: &[u8],
152 nx: usize, ny: usize, nz: usize,
153 vsx: f64, vsy: f64, vsz: f64,
154) -> (Vec<f64>, Vec<u8>) {
155 let radius = 18.0 * vsx.min(vsy).min(vsz);
157 sharp(field, mask, nx, ny, nz, vsx, vsy, vsz, radius, 0.05)
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `sharp` on a fully masked 16x16x16 volume with unit voxels,
    /// radius 2.0 and threshold 0.05, building the field from `make_field`.
    fn run_on_full_cube(make_field: impl Fn(usize) -> f64) -> (Vec<f64>, Vec<u8>) {
        let side = 16;
        let voxels = side * side * side;
        let field: Vec<f64> = (0..voxels).map(make_field).collect();
        let mask = vec![1u8; voxels];
        sharp(&field, &mask, side, side, side, 1.0, 1.0, 1.0, 2.0, 0.05)
    }

    /// An all-zero input must produce an (almost) all-zero output.
    #[test]
    fn test_sharp_zero_field() {
        let (local, _) = run_on_full_cube(|_| 0.0);

        local.iter().for_each(|&val| {
            assert!(val.abs() < 1e-8, "Zero field should give zero local field, got {}", val);
        });
    }

    /// A linear ramp must produce finite values and a non-empty eroded mask.
    #[test]
    fn test_sharp_finite() {
        let (local, eroded) = run_on_full_cube(|i| (i as f64) * 0.01);

        local.iter().enumerate().for_each(|(i, val)| {
            assert!(val.is_finite(), "Local field should be finite at index {}", i);
        });

        let eroded_count = eroded.iter().map(|&m| usize::from(m)).sum::<usize>();
        assert!(eroded_count > 0, "Eroded mask should have some voxels");
    }
}