1use num_complex::Complex64;
15use crate::fft::{fft3d, ifft3d};
16use crate::kernels::smv::smv_kernel;
17
/// Bundle of tunable settings for V-SHARP.
///
/// NOTE(review): nothing in this file reads this struct; the field
/// descriptions below are inferred from the matching values used by
/// `vsharp` / `vsharp_default` and should be confirmed against callers.
#[derive(Debug, Clone)]
pub struct VsharpParams {
    /// Presumably the cut-off handed to `vsharp` as `threshold` — TODO confirm.
    pub threshold: f64,
    /// Presumably the multiplier of the smallest voxel size that gives the
    /// largest radius (`vsharp_default` uses the same value, 18) — TODO confirm.
    pub max_radius_factor: f64,
    /// Presumably the multiplier of the largest voxel size that gives the
    /// smallest radius (`vsharp_default` uses the same value, 2) — TODO confirm.
    pub min_radius_factor: f64,
}

impl Default for VsharpParams {
    /// Returns `threshold = 0.001`, `max_radius_factor = 18.0`,
    /// `min_radius_factor = 2.0`.
    ///
    /// NOTE(review): `vsharp_default` passes a threshold of 0.05, not the
    /// 0.001 stored here — confirm which value is intended.
    fn default() -> Self {
        VsharpParams {
            threshold: 1e-3,
            max_radius_factor: 18.0,
            min_radius_factor: 2.0,
        }
    }
}
51
52pub fn vsharp(
55 field: &[f64],
56 mask: &[u8],
57 nx: usize, ny: usize, nz: usize,
58 vsx: f64, vsy: f64, vsz: f64,
59 radii: &[f64],
60 threshold: f64,
61) -> (Vec<f64>, Vec<u8>) {
62 if radii.is_empty() {
63 return (vec![0.0; nx * ny * nz], mask.to_vec());
64 }
65
66 if radii.len() == 1 {
68 return crate::bgremove::sharp::sharp(
69 field, mask, nx, ny, nz, vsx, vsy, vsz, radii[0], threshold
70 );
71 }
72
73 let n_total = nx * ny * nz;
74
75 let mut sorted_radii = radii.to_vec();
77 sorted_radii.sort_by(|a, b| b.partial_cmp(a).unwrap());
78
79 let mut field_complex: Vec<Complex64> = field.iter()
81 .map(|&x| Complex64::new(x, 0.0))
82 .collect();
83 fft3d(&mut field_complex, nx, ny, nz);
84 let field_fft = field_complex.clone();
85
86 let mut processed = vec![false; n_total];
88 let mut local_field = vec![0.0; n_total];
89 let mut final_mask = vec![0u8; n_total];
90
91 let delta = 1.0 - 1e-7_f64.sqrt();
93
94 let mut inverse_kernel: Option<Vec<f64>> = None;
96
97 for &radius in &sorted_radii {
98 let s_kernel = smv_kernel(nx, ny, nz, vsx, vsy, vsz, radius);
100
101 let mut s_complex: Vec<Complex64> = s_kernel.iter()
103 .map(|&x| Complex64::new(x, 0.0))
104 .collect();
105 fft3d(&mut s_complex, nx, ny, nz);
106 let s_fft: Vec<f64> = s_complex.iter().map(|c| c.re).collect();
107
108 if inverse_kernel.is_none() {
110 inverse_kernel = Some(s_fft.iter().map(|&s| {
111 let one_minus_s = 1.0 - s;
112 if one_minus_s.abs() < threshold {
113 0.0
114 } else {
115 1.0 / one_minus_s
116 }
117 }).collect());
118 }
119
120 let mask_f64: Vec<f64> = mask.iter().map(|&m| m as f64).collect();
122 let mut mask_complex: Vec<Complex64> = mask_f64.iter()
123 .map(|&x| Complex64::new(x, 0.0))
124 .collect();
125
126 fft3d(&mut mask_complex, nx, ny, nz);
127
128 for i in 0..n_total {
130 mask_complex[i] *= s_fft[i];
131 }
132
133 ifft3d(&mut mask_complex, nx, ny, nz);
134
135 let current_mask: Vec<bool> = mask_complex.iter()
137 .map(|c| c.re > delta)
138 .collect();
139
140 let mut filtered = field_fft.clone();
142 for i in 0..n_total {
143 filtered[i] *= 1.0 - s_fft[i];
144 }
145
146 ifft3d(&mut filtered, nx, ny, nz);
147
148 for i in 0..n_total {
151 if current_mask[i] && !processed[i] {
152 local_field[i] = filtered[i].re;
153 processed[i] = true;
154 final_mask[i] = 1;
155 }
156 }
157 }
158
159 if let Some(inv_kernel) = inverse_kernel {
161 let mut local_complex: Vec<Complex64> = local_field.iter()
162 .map(|&x| Complex64::new(x, 0.0))
163 .collect();
164
165 fft3d(&mut local_complex, nx, ny, nz);
166
167 for i in 0..n_total {
168 local_complex[i] *= inv_kernel[i];
169 }
170
171 ifft3d(&mut local_complex, nx, ny, nz);
172
173 for i in 0..n_total {
175 local_field[i] = if final_mask[i] == 1 { local_complex[i].re } else { 0.0 };
176 }
177 }
178
179 (local_field, final_mask)
180}
181
182pub fn vsharp_with_progress<F>(
186 field: &[f64],
187 mask: &[u8],
188 nx: usize, ny: usize, nz: usize,
189 vsx: f64, vsy: f64, vsz: f64,
190 radii: &[f64],
191 threshold: f64,
192 mut progress_callback: F,
193) -> (Vec<f64>, Vec<u8>)
194where
195 F: FnMut(usize, usize),
196{
197 if radii.is_empty() {
198 return (vec![0.0; nx * ny * nz], mask.to_vec());
199 }
200
201 if radii.len() == 1 {
203 progress_callback(1, 1);
204 return crate::bgremove::sharp::sharp(
205 field, mask, nx, ny, nz, vsx, vsy, vsz, radii[0], threshold
206 );
207 }
208
209 let n_total = nx * ny * nz;
210 let n_radii = radii.len();
211
212 let mut sorted_radii = radii.to_vec();
214 sorted_radii.sort_by(|a, b| b.partial_cmp(a).unwrap());
215
216 let mut field_complex: Vec<Complex64> = field.iter()
218 .map(|&x| Complex64::new(x, 0.0))
219 .collect();
220 fft3d(&mut field_complex, nx, ny, nz);
221 let field_fft = field_complex.clone();
222
223 let mut processed = vec![false; n_total];
225 let mut local_field = vec![0.0; n_total];
226 let mut final_mask = vec![0u8; n_total];
227
228 let delta = 1.0 - 1e-7_f64.sqrt();
229 let mut inverse_kernel: Option<Vec<f64>> = None;
230
231 for (idx, &radius) in sorted_radii.iter().enumerate() {
232 progress_callback(idx + 1, n_radii);
234
235 let s_kernel = smv_kernel(nx, ny, nz, vsx, vsy, vsz, radius);
237
238 let mut s_complex: Vec<Complex64> = s_kernel.iter()
240 .map(|&x| Complex64::new(x, 0.0))
241 .collect();
242 fft3d(&mut s_complex, nx, ny, nz);
243 let s_fft: Vec<f64> = s_complex.iter().map(|c| c.re).collect();
244
245 if inverse_kernel.is_none() {
247 inverse_kernel = Some(s_fft.iter().map(|&s| {
248 let one_minus_s = 1.0 - s;
249 if one_minus_s.abs() < threshold {
250 0.0
251 } else {
252 1.0 / one_minus_s
253 }
254 }).collect());
255 }
256
257 let mask_f64: Vec<f64> = mask.iter().map(|&m| m as f64).collect();
259 let mut mask_complex: Vec<Complex64> = mask_f64.iter()
260 .map(|&x| Complex64::new(x, 0.0))
261 .collect();
262
263 fft3d(&mut mask_complex, nx, ny, nz);
264
265 for i in 0..n_total {
266 mask_complex[i] *= s_fft[i];
267 }
268
269 ifft3d(&mut mask_complex, nx, ny, nz);
270
271 let current_mask: Vec<bool> = mask_complex.iter()
272 .map(|c| c.re > delta)
273 .collect();
274
275 let mut filtered = field_fft.clone();
277 for i in 0..n_total {
278 filtered[i] *= 1.0 - s_fft[i];
279 }
280
281 ifft3d(&mut filtered, nx, ny, nz);
282
283 for i in 0..n_total {
284 if current_mask[i] && !processed[i] {
285 local_field[i] = filtered[i].re;
286 processed[i] = true;
287 final_mask[i] = 1;
288 }
289 }
290 }
291
292 if let Some(inv_kernel) = inverse_kernel {
294 let mut local_complex: Vec<Complex64> = local_field.iter()
295 .map(|&x| Complex64::new(x, 0.0))
296 .collect();
297
298 fft3d(&mut local_complex, nx, ny, nz);
299
300 for i in 0..n_total {
301 local_complex[i] *= inv_kernel[i];
302 }
303
304 ifft3d(&mut local_complex, nx, ny, nz);
305
306 for i in 0..n_total {
307 local_field[i] = if final_mask[i] == 1 { local_complex[i].re } else { 0.0 };
308 }
309 }
310
311 (local_field, final_mask)
312}
313
314pub fn vsharp_default(
316 field: &[f64],
317 mask: &[u8],
318 nx: usize, ny: usize, nz: usize,
319 vsx: f64, vsy: f64, vsz: f64,
320) -> (Vec<f64>, Vec<u8>) {
321 let min_vox = vsx.min(vsy).min(vsz);
322 let max_vox = vsx.max(vsy).max(vsz);
323
324 let mut radii = Vec::new();
326 let mut r = 18.0 * min_vox;
327 while r >= 2.0 * max_vox {
328 radii.push(r);
329 r -= 2.0 * max_vox;
330 }
331
332 if radii.is_empty() {
333 radii.push(18.0 * min_vox);
334 }
335
336 vsharp(field, mask, nx, ny, nz, vsx, vsy, vsz, &radii, 0.05)
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Threshold passed to every explicit `vsharp` / `sharp` call below.
    const THRESHOLD: f64 = 0.05;

    /// Linear ramp `0.001 * i` over a cubic grid with edge length `n`.
    fn ramp_field(n: usize) -> Vec<f64> {
        (0..n * n * n).map(|i| (i as f64) * 0.001).collect()
    }

    /// Mask that selects every voxel of a cubic grid with edge length `n`.
    fn full_mask(n: usize) -> Vec<u8> {
        vec![1u8; n * n * n]
    }

    /// Number of voxels set in `mask` (sum of its entries).
    fn voxel_count(mask: &[u8]) -> usize {
        mask.iter().map(|&m| m as usize).sum()
    }

    /// Fails with the offending index if any value is NaN or infinite.
    fn assert_all_finite(values: &[f64], label: &str) {
        for (i, value) in values.iter().enumerate() {
            assert!(value.is_finite(), "{}: non-finite value at index {}", label, i);
        }
    }

    /// Fails if `local` is non-zero at a voxel that `final_mask` excludes.
    fn assert_zero_outside_mask(local: &[f64], final_mask: &[u8]) {
        for (i, (&value, &m)) in local.iter().zip(final_mask).enumerate() {
            if m == 0 {
                assert_eq!(value, 0.0, "Outside final mask should be zero (index {})", i);
            }
        }
    }

    #[test]
    fn test_vsharp_zero_field() {
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = full_mask(n);

        let radii = [4.0, 3.0, 2.0];
        let (local, _) = vsharp(&field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, THRESHOLD);

        for &val in &local {
            assert!(val.abs() < 1e-10);
        }
    }

    #[test]
    fn test_vsharp_preserves_more_than_sharp() {
        let n = 16;
        let field = vec![0.0; n * n * n];
        let mask = full_mask(n);

        let radii = [5.0, 4.0, 3.0, 2.0];
        let (_, vsharp_mask) = vsharp(&field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, THRESHOLD);

        // SHARP with the largest of the V-SHARP radii.
        let (_, sharp_mask) = crate::bgremove::sharp::sharp(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, 5.0, THRESHOLD,
        );

        let vsharp_count = voxel_count(&vsharp_mask);
        let sharp_count = voxel_count(&sharp_mask);

        assert!(
            vsharp_count >= sharp_count,
            "V-SHARP {} should preserve at least as many as SHARP {}",
            vsharp_count,
            sharp_count
        );
    }

    #[test]
    fn test_vsharp_nonuniform_voxels() {
        let n = 8;
        let field = ramp_field(n);
        let mask = full_mask(n);

        let radii = [4.0, 3.0, 2.0];
        let (local, final_mask) = vsharp(&field, &mask, n, n, n, 0.5, 1.0, 2.0, &radii, THRESHOLD);

        assert_all_finite(&local, "V-SHARP nonuniform voxels");
        assert!(
            voxel_count(&final_mask) > 0,
            "V-SHARP nonuniform: final mask should have some voxels"
        );
    }

    #[test]
    fn test_vsharp_single_radius() {
        let n = 8;
        let field = ramp_field(n);
        let mask = full_mask(n);

        let radii = [3.0];
        let (local, final_mask) = vsharp(&field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, THRESHOLD);

        assert_all_finite(&local, "V-SHARP single radius");

        // A single radius must reproduce plain SHARP.
        let (sharp_local, sharp_mask) = crate::bgremove::sharp::sharp(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, 3.0, THRESHOLD,
        );

        for i in 0..n * n * n {
            assert!(
                (local[i] - sharp_local[i]).abs() < 1e-10,
                "Single-radius V-SHARP should match SHARP at index {}",
                i
            );
        }

        assert_eq!(final_mask, sharp_mask, "Single-radius V-SHARP mask should match SHARP mask");
    }

    #[test]
    fn test_vsharp_empty_radii() {
        let n = 8;
        let field = ramp_field(n);
        let mask = full_mask(n);

        let (local, returned_mask) = vsharp(&field, &mask, n, n, n, 1.0, 1.0, 1.0, &[], THRESHOLD);

        assert_eq!(local.len(), n * n * n);
        for &val in &local {
            assert_eq!(val, 0.0, "Empty radii should return zero local field");
        }
        assert_eq!(returned_mask, mask, "Empty radii should return original mask");
    }

    #[test]
    fn test_vsharp_larger_volume() {
        let n = 16;

        // Field that grows linearly along z.
        let mut field = vec![0.0; n * n * n];
        for z in 0..n {
            for y in 0..n {
                for x in 0..n {
                    field[x + y * n + z * n * n] = (z as f64) * 0.1;
                }
            }
        }

        // Spherical mask centred in the volume.
        let center = (n / 2) as i32;
        let radius = (n / 3) as i32;
        let mut mask = vec![0u8; n * n * n];
        for z in 0..n {
            for y in 0..n {
                for x in 0..n {
                    let dx = x as i32 - center;
                    let dy = y as i32 - center;
                    let dz = z as i32 - center;
                    if dx * dx + dy * dy + dz * dz <= radius * radius {
                        mask[x + y * n + z * n * n] = 1;
                    }
                }
            }
        }

        let radii = [6.0, 4.0, 3.0, 2.0];
        let (local, final_mask) = vsharp(&field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, THRESHOLD);

        assert_eq!(local.len(), n * n * n);
        assert_all_finite(&local, "V-SHARP larger volume");
        assert!(
            voxel_count(&final_mask) > 0,
            "V-SHARP larger volume should have voxels in final mask"
        );
        assert_zero_outside_mask(&local, &final_mask);
    }

    #[test]
    fn test_vsharp_with_progress() {
        let n = 8;
        let field = ramp_field(n);
        let mask = full_mask(n);

        let radii = [4.0, 3.0, 2.0];
        let mut progress_calls = Vec::new();
        let (local, _) = vsharp_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, THRESHOLD,
            |idx, total| progress_calls.push((idx, total)),
        );

        assert_eq!(local.len(), n * n * n);
        // One call per radius, 1-based, each carrying the total count.
        assert_eq!(progress_calls, vec![(1, 3), (2, 3), (3, 3)]);
        assert_all_finite(&local, "V-SHARP with progress");
    }

    #[test]
    fn test_vsharp_with_progress_single_radius() {
        let n = 8;
        let field = ramp_field(n);
        let mask = full_mask(n);

        let radii = [3.0];
        let mut progress_calls = Vec::new();
        let (local, _) = vsharp_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii, THRESHOLD,
            |idx, total| progress_calls.push((idx, total)),
        );

        assert_eq!(local.len(), n * n * n);
        assert_eq!(
            progress_calls,
            vec![(1, 1)],
            "Progress should be called once for single radius"
        );
        assert_all_finite(&local, "V-SHARP with progress, single radius");
    }

    #[test]
    fn test_vsharp_with_progress_empty_radii() {
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = full_mask(n);

        let mut progress_calls = Vec::new();
        let (local, returned_mask) = vsharp_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &[], THRESHOLD,
            |idx, total| progress_calls.push((idx, total)),
        );

        for &val in &local {
            assert_eq!(val, 0.0);
        }
        assert_eq!(returned_mask, mask);
        assert!(
            progress_calls.is_empty(),
            "No radii means no progress callbacks"
        );
    }

    #[test]
    fn test_vsharp_default_wrapper() {
        let n = 8;
        let field = ramp_field(n);
        let mask = full_mask(n);

        let (local, final_mask) = vsharp_default(&field, &mask, n, n, n, 1.0, 1.0, 1.0);

        assert_eq!(local.len(), n * n * n);
        assert_eq!(final_mask.len(), n * n * n);
        assert_all_finite(&local, "V-SHARP default wrapper");
        assert!(voxel_count(&final_mask) <= n * n * n);
        // Unit voxels give several radii, so the multi-radius path runs and
        // everything outside the final mask must be zeroed.
        assert_zero_outside_mask(&local, &final_mask);
    }

    #[test]
    fn test_vsharp_unsorted_radii() {
        let n = 8;
        let field = ramp_field(n);
        let mask = full_mask(n);

        let radii_sorted = [4.0, 3.0, 2.0];
        let radii_unsorted = [2.0, 4.0, 3.0];

        let (local_sorted, mask_sorted) = vsharp(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii_sorted, THRESHOLD,
        );
        let (local_unsorted, mask_unsorted) = vsharp(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, &radii_unsorted, THRESHOLD,
        );

        assert_eq!(mask_sorted, mask_unsorted, "Sorted and unsorted radii should give same mask");
        for i in 0..n * n * n {
            assert!(
                (local_sorted[i] - local_unsorted[i]).abs() < 1e-10,
                "Results should match at index {}",
                i
            );
        }
    }
}