1use num_complex::Complex64;
14use crate::fft::{fft3d, ifft3d};
15use crate::kernels::smv::smv_kernel;
16
/// Tuning parameters for iSMV background field removal.
///
/// The default values (`tol = 1e-3`, `max_iter = 500`, `radius_factor = 2.0`)
/// match the constants used by `ismv_default`.
#[derive(Clone, Debug, PartialEq)]
pub struct IsmvParams {
    /// Relative stopping tolerance, intended as the `tol` argument of `ismv`.
    pub tol: f64,
    /// Maximum number of iterations, intended as the `max_iter` argument of `ismv`.
    pub max_iter: usize,
    /// SMV kernel radius expressed as a multiple of the largest voxel dimension
    /// (see [`IsmvParams::radius`]).
    pub radius_factor: f64,
}

impl IsmvParams {
    /// Returns the SMV kernel radius for voxel size `(vsx, vsy, vsz)`:
    /// `radius_factor` times the largest of the three voxel dimensions.
    pub fn radius(&self, vsx: f64, vsy: f64, vsz: f64) -> f64 {
        self.radius_factor * vsx.max(vsy).max(vsz)
    }
}

impl Default for IsmvParams {
    fn default() -> Self {
        Self { tol: 1e-3, max_iter: 500, radius_factor: 2.0 }
    }
}
44
45pub fn ismv(
48 field: &[f64],
49 mask: &[u8],
50 nx: usize, ny: usize, nz: usize,
51 vsx: f64, vsy: f64, vsz: f64,
52 radius: f64,
53 tol: f64,
54 max_iter: usize,
55) -> (Vec<f64>, Vec<u8>) {
56 let n_total = nx * ny * nz;
57
58 let smv = smv_kernel(nx, ny, nz, vsx, vsy, vsz, radius);
60
61 let mut smv_complex: Vec<Complex64> = smv.iter()
63 .map(|&x| Complex64::new(x, 0.0))
64 .collect();
65 fft3d(&mut smv_complex, nx, ny, nz);
66 let smv_fft = smv_complex;
67
68 let m0: Vec<f64> = mask.iter()
70 .map(|&m| if m != 0 { 1.0 } else { 0.0 })
71 .collect();
72
73 let eroded_mask = erode_mask(&m0, &smv_fft, nx, ny, nz);
75
76 let boundary: Vec<f64> = m0.iter()
78 .zip(eroded_mask.iter())
79 .map(|(&m, &e)| m - e)
80 .collect();
81
82 let mut f: Vec<f64> = field.to_vec();
84
85 let mut f0: Vec<f64> = field.iter()
87 .zip(eroded_mask.iter())
88 .map(|(&fi, &m)| fi * m)
89 .collect();
90
91 let bc: Vec<f64> = field.iter()
93 .zip(boundary.iter())
94 .map(|(&fi, &b)| fi * b)
95 .collect();
96
97 let mut nr = vec_norm(&f0);
99 let eps = tol * nr;
100
101 for _iter in 0..max_iter {
103 if nr <= eps {
104 break;
105 }
106
107 let mut f_complex: Vec<Complex64> = f.iter()
109 .map(|&x| Complex64::new(x, 0.0))
110 .collect();
111
112 fft3d(&mut f_complex, nx, ny, nz);
113
114 for i in 0..n_total {
115 f_complex[i] *= smv_fft[i];
116 }
117
118 ifft3d(&mut f_complex, nx, ny, nz);
119
120 for i in 0..n_total {
122 f[i] = eroded_mask[i] * f_complex[i].re + bc[i];
123 }
124
125 let mut residual_sq = 0.0;
127 for i in 0..n_total {
128 let diff = f0[i] - f[i];
129 residual_sq += diff * diff;
130 f0[i] = f[i];
131 }
132 nr = residual_sq.sqrt();
133 }
134
135 let mut local_field = vec![0.0; n_total];
137 for i in 0..n_total {
138 if mask[i] != 0 {
139 local_field[i] = field[i] - f[i];
140 }
141 }
142
143 let eroded_mask_u8: Vec<u8> = eroded_mask.iter()
145 .map(|&m| if m > 0.5 { 1 } else { 0 })
146 .collect();
147
148 (local_field, eroded_mask_u8)
149}
150
151fn erode_mask(mask: &[f64], smv_fft: &[Complex64], nx: usize, ny: usize, nz: usize) -> Vec<f64> {
153 let n_total = nx * ny * nz;
154 let delta = 1.0 - 1e-10;
155
156 let mut m_complex: Vec<Complex64> = mask.iter()
157 .map(|&x| Complex64::new(x, 0.0))
158 .collect();
159
160 fft3d(&mut m_complex, nx, ny, nz);
161
162 for i in 0..n_total {
163 m_complex[i] *= smv_fft[i];
164 }
165
166 ifft3d(&mut m_complex, nx, ny, nz);
167
168 m_complex.iter()
170 .map(|c| if c.re > delta { 1.0 } else { 0.0 })
171 .collect()
172}
173
/// Euclidean (L2) norm of `v`: the square root of the sum of squares.
/// Returns zero for an empty slice.
fn vec_norm(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().copied().map(|x| x * x).sum();
    sum_of_squares.sqrt()
}
178
179pub fn ismv_with_progress<F>(
183 field: &[f64],
184 mask: &[u8],
185 nx: usize, ny: usize, nz: usize,
186 vsx: f64, vsy: f64, vsz: f64,
187 radius: f64,
188 tol: f64,
189 max_iter: usize,
190 mut progress_callback: F,
191) -> (Vec<f64>, Vec<u8>)
192where
193 F: FnMut(usize, usize),
194{
195 let n_total = nx * ny * nz;
196
197 let smv = smv_kernel(nx, ny, nz, vsx, vsy, vsz, radius);
199
200 let mut smv_complex: Vec<Complex64> = smv.iter()
202 .map(|&x| Complex64::new(x, 0.0))
203 .collect();
204 fft3d(&mut smv_complex, nx, ny, nz);
205 let smv_fft = smv_complex;
206
207 let m0: Vec<f64> = mask.iter()
209 .map(|&m| if m != 0 { 1.0 } else { 0.0 })
210 .collect();
211
212 let eroded_mask = erode_mask(&m0, &smv_fft, nx, ny, nz);
214
215 let boundary: Vec<f64> = m0.iter()
217 .zip(eroded_mask.iter())
218 .map(|(&m, &e)| m - e)
219 .collect();
220
221 let mut f: Vec<f64> = field.to_vec();
223
224 let mut f0: Vec<f64> = field.iter()
226 .zip(eroded_mask.iter())
227 .map(|(&fi, &m)| fi * m)
228 .collect();
229
230 let bc: Vec<f64> = field.iter()
232 .zip(boundary.iter())
233 .map(|(&fi, &b)| fi * b)
234 .collect();
235
236 let mut nr = vec_norm(&f0);
238 let eps = tol * nr;
239
240 for iter in 0..max_iter {
242 progress_callback(iter + 1, max_iter);
244
245 if nr <= eps {
246 progress_callback(iter + 1, iter + 1);
247 break;
248 }
249
250 let mut f_complex: Vec<Complex64> = f.iter()
252 .map(|&x| Complex64::new(x, 0.0))
253 .collect();
254
255 fft3d(&mut f_complex, nx, ny, nz);
256
257 for i in 0..n_total {
258 f_complex[i] *= smv_fft[i];
259 }
260
261 ifft3d(&mut f_complex, nx, ny, nz);
262
263 for i in 0..n_total {
265 f[i] = eroded_mask[i] * f_complex[i].re + bc[i];
266 }
267
268 let mut residual_sq = 0.0;
270 for i in 0..n_total {
271 let diff = f0[i] - f[i];
272 residual_sq += diff * diff;
273 f0[i] = f[i];
274 }
275 nr = residual_sq.sqrt();
276 }
277
278 let mut local_field = vec![0.0; n_total];
280 for i in 0..n_total {
281 if mask[i] != 0 {
282 local_field[i] = field[i] - f[i];
283 }
284 }
285
286 let eroded_mask_u8: Vec<u8> = eroded_mask.iter()
288 .map(|&m| if m > 0.5 { 1 } else { 0 })
289 .collect();
290
291 (local_field, eroded_mask_u8)
292}
293
294pub fn ismv_default(
296 field: &[f64],
297 mask: &[u8],
298 nx: usize, ny: usize, nz: usize,
299 vsx: f64, vsy: f64, vsz: f64,
300) -> (Vec<f64>, Vec<u8>) {
301 let radius = 2.0 * vsx.max(vsy).max(vsz);
302 ismv(
303 field, mask, nx, ny, nz, vsx, vsy, vsz,
304 radius,
305 1e-3, 500 )
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Linear ramp `0.000, 0.001, 0.002, ...` over an `n^3` grid.
    fn ramp_field(n: usize) -> Vec<f64> {
        (0..n * n * n).map(|i| (i as f64) * 0.001).collect()
    }

    /// Ball mask of radius `n / 3` voxels centred on voxel `(n/2, n/2, n/2)`,
    /// stored with x varying fastest.
    fn sphere_mask(n: usize) -> Vec<u8> {
        let center = (n / 2) as i64;
        let radius = (n / 3) as i64;
        let mut mask = vec![0u8; n * n * n];
        for z in 0..n {
            for y in 0..n {
                for x in 0..n {
                    let dx = x as i64 - center;
                    let dy = y as i64 - center;
                    let dz = z as i64 - center;
                    if dx * dx + dy * dy + dz * dz <= radius * radius {
                        mask[x + y * n + z * n * n] = 1;
                    }
                }
            }
        }
        mask
    }

    /// Number of non-zero voxels in a 0/1 mask.
    fn voxel_count(mask: &[u8]) -> usize {
        mask.iter().map(|&m| m as usize).sum()
    }

    /// Asserts every value is finite, naming the offending index on failure.
    fn assert_all_finite(values: &[f64], label: &str) {
        for (i, &val) in values.iter().enumerate() {
            assert!(val.is_finite(), "{}: non-finite value at index {}", label, i);
        }
    }

    #[test]
    fn test_ismv_zero_field() {
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        let (local, eroded) = ismv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 1e-3, 10);

        for &val in &local {
            assert!(val.abs() < 1e-10, "Zero field should give zero local field");
        }
        assert!(voxel_count(&eroded) > 0, "Eroded mask should have some voxels");
    }

    #[test]
    fn test_ismv_finite() {
        let n = 8;
        let field = ramp_field(n);
        let mask = vec![1u8; n * n * n];

        let (local, _) = ismv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 1e-3, 20);

        assert_eq!(local.len(), n * n * n);
        assert_all_finite(&local, "Local field");
    }

    #[test]
    fn test_ismv_preserves_interior() {
        let n = 16;
        let field = vec![0.1; n * n * n];
        let mask = sphere_mask(n);

        let (_, eroded) = ismv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 1.5, 1e-3, 50);

        let eroded_count = voxel_count(&eroded);
        assert!(eroded_count > 0, "Eroded mask should have some voxels");
        assert!(
            eroded_count <= voxel_count(&mask),
            "Eroded mask should not be larger than the original"
        );
        for (i, (&e, &m)) in eroded.iter().zip(&mask).enumerate() {
            if e != 0 {
                assert_eq!(m, 1, "Eroded voxel {} must lie inside the original mask", i);
            }
        }
    }

    #[test]
    fn test_ismv_convergence() {
        let n = 8;
        let field = ramp_field(n);
        let mask = vec![1u8; n * n * n];

        let (local_many, eroded_many) =
            ismv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 1e-6, 100);
        let (local_few, eroded_few) =
            ismv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 1e-6, 5);

        assert_all_finite(&local_many, "iSMV many iters");
        assert_all_finite(&local_few, "iSMV few iters");
        // Erosion depends only on the mask and the kernel, never on the iteration count.
        assert_eq!(eroded_many, eroded_few);
    }

    #[test]
    fn test_ismv_different_radius() {
        let n = 8;
        let field = ramp_field(n);
        let mask = vec![1u8; n * n * n];

        let (local_small, eroded_small) =
            ismv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 1.5, 1e-3, 20);
        let (local_large, eroded_large) =
            ismv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 3.0, 1e-3, 20);

        assert_all_finite(&local_small, "iSMV small radius");
        assert_all_finite(&local_large, "iSMV large radius");

        let small_count = voxel_count(&eroded_small);
        let large_count = voxel_count(&eroded_large);
        assert!(
            large_count <= small_count,
            "Larger radius should erode more: large={}, small={}",
            large_count, small_count
        );
    }

    #[test]
    fn test_ismv_larger_volume() {
        let n = 16;
        // Field varies linearly along z only.
        let field: Vec<f64> = (0..n * n * n).map(|i| ((i / (n * n)) as f64) * 0.1).collect();
        let mask = sphere_mask(n);

        let (local, eroded) = ismv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 1e-3, 50);

        assert_eq!(local.len(), n * n * n);
        assert_all_finite(&local, "Larger volume");
        assert!(voxel_count(&eroded) > 0, "Eroded mask should have some voxels");
        for (i, (&val, &m)) in local.iter().zip(&mask).enumerate() {
            if m == 0 {
                assert_eq!(val, 0.0, "Outside mask should be zero at index {}", i);
            }
        }
    }

    #[test]
    fn test_ismv_with_progress() {
        let n = 8;
        let max_iter = 20;
        let field = ramp_field(n);
        let mask = vec![1u8; n * n * n];

        let mut progress_calls: Vec<(usize, usize)> = Vec::new();
        let (local, _) = ismv_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 1e-3, max_iter,
            |iter, max| progress_calls.push((iter, max)),
        );

        assert_eq!(local.len(), n * n * n);
        assert_all_finite(&local, "iSMV with progress");
        assert_eq!(
            progress_calls.first(),
            Some(&(1, max_iter)),
            "First call should report pass 1 of max_iter"
        );
        for pair in progress_calls.windows(2) {
            assert!(
                pair[0].0 <= pair[1].0,
                "Iteration counter must not decrease: {:?}",
                progress_calls
            );
        }
        for &(iter, total) in &progress_calls {
            assert!((1..=max_iter).contains(&iter), "Iteration {} out of range", iter);
            assert!(
                total == max_iter || total == iter,
                "Unexpected total {} for iteration {}",
                total, iter
            );
        }
    }

    #[test]
    fn test_ismv_with_progress_reports_convergence() {
        // A zero field meets the stopping rule before any update, so pass 1 is
        // announced and completion is signalled immediately.
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        let mut progress_calls: Vec<(usize, usize)> = Vec::new();
        let _ = ismv_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 1e-3, 20,
            |iter, max| progress_calls.push((iter, max)),
        );

        assert_eq!(progress_calls, vec![(1, 20), (1, 1)]);
    }

    #[test]
    fn test_ismv_matches_with_progress() {
        let n = 8;
        let field = ramp_field(n);
        let mask = sphere_mask(n);

        let plain = ismv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 1e-3, 20);
        let reported = ismv_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 1e-3, 20,
            |_, _| {},
        );

        assert_eq!(plain, reported);
    }

    #[test]
    fn test_ismv_default_wrapper() {
        let n = 8;
        let field = ramp_field(n);
        let mask = vec![1u8; n * n * n];

        let (local, eroded) = ismv_default(&field, &mask, n, n, n, 1.0, 1.0, 1.0);

        assert_eq!(local.len(), n * n * n);
        assert_all_finite(&local, "Default iSMV");
        assert!(
            voxel_count(&eroded) > 0,
            "Default iSMV should produce non-empty eroded mask"
        );

        // Defaults: radius = 2 * largest voxel size, tol = 1e-3, max_iter = 500.
        let explicit = ismv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 1e-3, 500);
        assert_eq!((local, eroded), explicit);
    }

    #[test]
    fn test_ismv_anisotropic_voxels() {
        let n = 8;
        let field = ramp_field(n);
        let mask = vec![1u8; n * n * n];

        let (local, eroded) = ismv(&field, &mask, n, n, n, 0.5, 1.0, 2.0, 3.0, 1e-3, 20);

        assert_all_finite(&local, "Anisotropic iSMV");
        assert_eq!(eroded.len(), n * n * n);
        assert!(eroded.iter().all(|&m| m <= 1), "Eroded mask must be binary");
    }

    #[test]
    fn test_ismv_tight_convergence() {
        let n = 8;
        let field = vec![0.5; n * n * n];
        let mask = vec![1u8; n * n * n];

        let (local, _) = ismv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 2.0, 1e-12, 200);

        assert_all_finite(&local, "Tight convergence");
    }

    #[test]
    fn test_ismv_with_background_mask() {
        let n = 8;
        let field = ramp_field(n);

        // Cube mask leaving a one-voxel background shell on every face.
        let mut mask = vec![0u8; n * n * n];
        for z in 1..(n - 1) {
            for y in 1..(n - 1) {
                for x in 1..(n - 1) {
                    mask[x + y * n + z * n * n] = 1;
                }
            }
        }

        let (local, eroded) = ismv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 1.5, 1e-3, 30);

        assert_all_finite(&local, "Background mask");
        for (i, &m) in mask.iter().enumerate() {
            if m == 0 {
                assert_eq!(local[i], 0.0, "Outside mask should be zero at index {}", i);
            }
            if eroded[i] != 0 {
                assert_eq!(m, 1, "Eroded voxel must be inside original mask");
            }
        }
    }
}