1use std::collections::HashMap;
20use std::f64::consts::PI;
21use delaunator::{triangulate, Point};
22
/// Per-voxel curvature estimates produced by `calculate_gaussian_curvature`.
///
/// Both curvature vectors are dense over the whole `nx * ny * nz` volume
/// (flat index `i + j * nx + k * nx * ny`); voxels that did not receive a mesh
/// estimate hold `0.0`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct CurvatureResult {
    /// Discrete Gaussian curvature (angle deficit over mixed area) per voxel.
    pub gaussian_curvature: Vec<f64>,
    /// Signed discrete mean curvature per voxel.
    pub mean_curvature: Vec<f64>,
    /// Flat indices of every surface voxel (mask voxels removed by a radius-1 erosion).
    pub surface_indices: Vec<usize>,
}
32
/// A 3-component point or vector; used both for voxel coordinates and for
/// triangle edge vectors.
#[derive(Debug, Clone, Copy)]
struct Point3D {
    x: f64,
    y: f64,
    z: f64,
}

impl Point3D {
    /// Builds a point from its three components.
    fn new(x: f64, y: f64, z: f64) -> Self {
        Self { x, y, z }
    }

    /// Component-wise difference `self - other`.
    fn sub(&self, other: &Point3D) -> Point3D {
        Self {
            x: self.x - other.x,
            y: self.y - other.y,
            z: self.z - other.z,
        }
    }

    /// Euclidean inner product.
    fn dot(&self, other: &Point3D) -> f64 {
        let (xx, yy, zz) = (self.x * other.x, self.y * other.y, self.z * other.z);
        xx + yy + zz
    }

    /// Right-handed cross product `self × other`.
    fn cross(&self, other: &Point3D) -> Point3D {
        let (ax, ay, az) = (self.x, self.y, self.z);
        let (bx, by, bz) = (other.x, other.y, other.z);
        Self {
            x: ay * bz - az * by,
            y: az * bx - ax * bz,
            z: ax * by - ay * bx,
        }
    }

    /// Euclidean length.
    fn norm(&self) -> f64 {
        self.dot(self).sqrt()
    }

    /// Unit vector in the same direction; the zero vector when the length is
    /// not above `1e-10` (including a NaN length).
    fn normalize(&self) -> Point3D {
        let len = self.norm();
        if len > 1e-10 {
            Self {
                x: self.x / len,
                y: self.y / len,
                z: self.z / len,
            }
        } else {
            Self::new(0.0, 0.0, 0.0)
        }
    }

    /// Multiplies every component by `s`.
    fn scale(&self, s: f64) -> Point3D {
        Self {
            x: self.x * s,
            y: self.y * s,
            z: self.z * s,
        }
    }

    /// Component-wise sum `self + other`.
    fn add(&self, other: &Point3D) -> Point3D {
        Self {
            x: self.x + other.x,
            y: self.y + other.y,
            z: self.z + other.z,
        }
    }
}
83
/// One mesh face, stored as three indices into a companion slice of points.
#[derive(Debug, Clone, Copy)]
struct Triangle {
    /// First corner.
    v0: usize,
    /// Second corner.
    v1: usize,
    /// Third corner.
    v2: usize,
}
91
92fn extract_surface_voxels(
97 mask: &[u8],
98 nx: usize, ny: usize, nz: usize,
99) -> Vec<usize> {
100 let eroded = erode_mask(mask, nx, ny, nz, 1);
101
102 let mut surface = Vec::new();
103 for i in 0..mask.len() {
104 if mask[i] != 0 && eroded[i] == 0 {
105 surface.push(i);
106 }
107 }
108
109 surface
110}
111
112fn triangulate_surface(
122 points: &[Point3D],
123) -> (Vec<Triangle>, Vec<bool>) {
124 if points.len() < 3 {
125 return (Vec::new(), vec![false; points.len()]);
126 }
127
128 let coords: Vec<Point> = points.iter()
130 .map(|p| Point { x: p.x, y: p.y })
131 .collect();
132
133 let result = triangulate(&coords);
135
136 let mut boundary = vec![false; points.len()];
138 for &idx in &result.hull {
139 boundary[idx] = true;
140 }
141
142 let mut triangles = Vec::with_capacity(result.triangles.len() / 3);
144 for i in (0..result.triangles.len()).step_by(3) {
145 triangles.push(Triangle {
146 v0: result.triangles[i],
147 v1: result.triangles[i + 1],
148 v2: result.triangles[i + 2],
149 });
150 }
151
152 (triangles, boundary)
153}
154
155fn compute_curvatures_from_mesh(
162 points: &[Point3D],
163 triangles: &[Triangle],
164 boundary: &[bool],
165) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
166 let n_points = points.len();
167 let mut gaussian_curvature = vec![0.0f64; n_points];
168 let mut mean_curvature = vec![0.0f64; n_points];
169 let mut angle_sum = vec![0.0f64; n_points];
170 let mut area_mixed = vec![0.0f64; n_points];
171 let mut mean_curv_vec = vec![Point3D::new(0.0, 0.0, 0.0); n_points];
172 let mut normal_vec = vec![Point3D::new(0.0, 0.0, 0.0); n_points];
173
174 for tri in triangles {
176 let p0 = &points[tri.v0];
177 let p1 = &points[tri.v1];
178 let p2 = &points[tri.v2];
179
180 let e01 = p1.sub(p0); let e12 = p2.sub(p1); let e20 = p0.sub(p2); let l01 = e01.norm();
186 let l12 = e12.norm();
187 let l20 = e20.norm();
188
189 if l01 < 1e-10 || l12 < 1e-10 || l20 < 1e-10 {
190 continue;
191 }
192
193 let cross = e01.cross(&e12.scale(-1.0));
195 let area = 0.5 * cross.norm();
196 if area < 1e-10 {
197 continue;
198 }
199
200 let face_normal = cross.normalize();
202
203 let cos_a0 = e01.normalize().dot(&e20.scale(-1.0).normalize());
205 let cos_a1 = e01.scale(-1.0).normalize().dot(&e12.normalize());
206 let cos_a2 = e12.scale(-1.0).normalize().dot(&e20.normalize());
207
208 let a0 = cos_a0.clamp(-1.0, 1.0).acos();
209 let a1 = cos_a1.clamp(-1.0, 1.0).acos();
210 let a2 = cos_a2.clamp(-1.0, 1.0).acos();
211
212 angle_sum[tri.v0] += a0;
214 angle_sum[tri.v1] += a1;
215 angle_sum[tri.v2] += a2;
216
217 let cot_a0 = cos_a0 / (1.0 - cos_a0 * cos_a0).sqrt().max(1e-10);
219 let cot_a1 = cos_a1 / (1.0 - cos_a1 * cos_a1).sqrt().max(1e-10);
220 let cot_a2 = cos_a2 / (1.0 - cos_a2 * cos_a2).sqrt().max(1e-10);
221
222 let obtuse_0 = a0 > PI / 2.0;
225 let obtuse_1 = a1 > PI / 2.0;
226 let obtuse_2 = a2 > PI / 2.0;
227
228 if obtuse_0 {
230 area_mixed[tri.v0] += area / 2.0;
231 } else if obtuse_1 || obtuse_2 {
232 area_mixed[tri.v0] += area / 4.0;
233 } else {
234 area_mixed[tri.v0] += (l20 * l20 * cot_a1 + l01 * l01 * cot_a2) / 8.0;
235 }
236
237 if obtuse_1 {
238 area_mixed[tri.v1] += area / 2.0;
239 } else if obtuse_0 || obtuse_2 {
240 area_mixed[tri.v1] += area / 4.0;
241 } else {
242 area_mixed[tri.v1] += (l01 * l01 * cot_a2 + l12 * l12 * cot_a0) / 8.0;
243 }
244
245 if obtuse_2 {
246 area_mixed[tri.v2] += area / 2.0;
247 } else if obtuse_0 || obtuse_1 {
248 area_mixed[tri.v2] += area / 4.0;
249 } else {
250 area_mixed[tri.v2] += (l12 * l12 * cot_a0 + l20 * l20 * cot_a1) / 8.0;
251 }
252
253 mean_curv_vec[tri.v0] = mean_curv_vec[tri.v0].add(&e01.scale(cot_a2).add(&e20.scale(-cot_a1)));
255 mean_curv_vec[tri.v1] = mean_curv_vec[tri.v1].add(&e12.scale(cot_a0).add(&e01.scale(-cot_a2)));
256 mean_curv_vec[tri.v2] = mean_curv_vec[tri.v2].add(&e20.scale(cot_a1).add(&e12.scale(-cot_a0)));
257
258 let perim = l12 + l20 + l01;
262 if perim > 1e-10 {
263 let incenter = p0.scale(l12).add(&p1.scale(l20)).add(&p2.scale(l01)).scale(1.0 / perim);
264
265 let w0 = 1.0 / p0.sub(&incenter).norm().max(1e-10);
266 let w1 = 1.0 / p1.sub(&incenter).norm().max(1e-10);
267 let w2 = 1.0 / p2.sub(&incenter).norm().max(1e-10);
268
269 normal_vec[tri.v0] = normal_vec[tri.v0].add(&face_normal.scale(w0));
270 normal_vec[tri.v1] = normal_vec[tri.v1].add(&face_normal.scale(w1));
271 normal_vec[tri.v2] = normal_vec[tri.v2].add(&face_normal.scale(w2));
272 }
273 }
274
275 for i in 0..n_points {
278 if boundary[i] {
279 continue;
281 }
282
283 if area_mixed[i] > 1e-10 {
284 gaussian_curvature[i] = (2.0 * PI - angle_sum[i]) / area_mixed[i];
286
287 let mc_vec = mean_curv_vec[i].scale(0.25 / area_mixed[i]);
289 let mc_mag = mc_vec.norm();
290
291 let n_vec = normal_vec[i].normalize();
293 let sign = if mc_vec.dot(&n_vec) < 0.0 { -1.0 } else { 1.0 };
294
295 mean_curvature[i] = sign * mc_mag;
296 }
297 }
298
299 (gaussian_curvature, mean_curvature, area_mixed)
300}
301
302pub fn calculate_curvature_proximity(
317 mask: &[u8],
318 prox1: &[f64],
319 lower_lim: f64,
320 curv_constant: f64,
321 sigma: f64,
322 nx: usize, ny: usize, nz: usize,
323) -> (Vec<f64>, Vec<f64>) {
324 let n_total = nx * ny * nz;
325
326 let surface_indices = extract_surface_voxels(mask, nx, ny, nz);
328
329 if surface_indices.is_empty() {
330 return (prox1.to_vec(), vec![1.0; n_total]);
331 }
332
333 let all_points: Vec<Point3D> = surface_indices
335 .iter()
336 .map(|&idx| {
337 let i = idx % nx;
338 let j = (idx / nx) % ny;
339 let k = idx / (nx * ny);
340 Point3D::new(i as f64, j as f64, k as f64)
341 })
342 .collect();
343
344 let mut xy_to_rep: HashMap<(usize, usize), usize> = HashMap::new();
350 let mut is_representative = vec![false; all_points.len()];
351 for (idx, p) in all_points.iter().enumerate() {
352 let key = (p.x as usize, p.y as usize);
353 xy_to_rep.entry(key).or_insert_with(|| {
354 is_representative[idx] = true;
355 idx
356 });
357 }
358
359 let rep_indices: Vec<usize> = (0..all_points.len())
361 .filter(|&i| is_representative[i])
362 .collect();
363 let mut orig_to_rep = vec![0usize; all_points.len()];
364 for (new_idx, &old_idx) in rep_indices.iter().enumerate() {
365 orig_to_rep[old_idx] = new_idx;
366 }
367 let rep_points: Vec<Point3D> = rep_indices.iter().map(|&i| all_points[i].clone()).collect();
368
369 let (triangles, boundary) = triangulate_surface(&rep_points);
371
372 let (gc, _mc, _amixed) = compute_curvatures_from_mesh(&rep_points, &triangles, &boundary);
374
375 let mut curv_i = vec![1.0f64; n_total];
377
378 let max_neg_gc = gc.iter()
380 .filter(|&&v| v < 0.0)
381 .map(|&v| v.abs())
382 .fold(1.0f64, |a, b| a.max(b));
383
384 for (orig_idx, &vol_idx) in surface_indices.iter().enumerate() {
388 if !is_representative[orig_idx] {
389 continue; }
391 let rep_idx = orig_to_rep[orig_idx];
392 let g = gc[rep_idx];
393 let scaled = if g < 0.0 {
394 g / max_neg_gc * curv_constant
395 } else if g > 0.0 {
396 1.0
397 } else {
398 0.0
400 };
401 curv_i[vol_idx] = scaled;
402 }
403
404 let sigmas = [sigma, 2.0 * sigma, 2.0 * sigma];
406 let prox3 = gaussian_smooth_3d_masked(&curv_i, mask, nx, ny, nz, &sigmas);
407
408 let prox3_clamped: Vec<f64> = prox3.iter().enumerate()
410 .map(|(i, &v)| {
411 if mask[i] == 0 {
412 0.0
413 } else if v < 0.5 && v != 0.0 {
414 0.5
415 } else {
416 v
417 }
418 })
419 .collect();
420
421 let mut prox: Vec<f64> = prox1.iter()
423 .zip(prox3_clamped.iter())
424 .map(|(&p1, &p3)| p1 * p3)
425 .collect();
426
427 let surface_mask = create_surface_mask(mask, nx, ny, nz);
433 let dilated_mask = dilate_mask(mask, nx, ny, nz, 5);
434
435 let mut prox4 = vec![0.0f64; n_total];
437 for i in 0..n_total {
438 if surface_mask[i] != 0 {
439 prox4[i] = prox[i];
440 }
441 }
442 for i in 0..n_total {
444 if prox4[i] == 0.0 {
445 prox4[i] = 1.0;
446 }
447 }
448 for i in 0..n_total {
450 if dilated_mask[i] != 0 && mask[i] == 0 {
451 prox4[i] = 0.0;
452 }
453 }
454
455 let prox4_smooth = gaussian_smooth_3d_masked(&prox4, &vec![1u8; n_total], nx, ny, nz, &[5.0, 10.0, 10.0]);
457
458 for i in 0..n_total {
460 if mask[i] == 0 {
461 prox[i] = 0.0;
462 } else if prox[i] < lower_lim && prox[i] != 0.0 {
463 prox[i] = lower_lim;
464 }
465 }
466
467 for i in 0..n_total {
469 prox[i] *= prox4_smooth[i];
470 }
471
472 (prox, curv_i)
473}
474
475fn create_surface_mask(mask: &[u8], nx: usize, ny: usize, nz: usize) -> Vec<u8> {
477 let eroded = erode_mask(mask, nx, ny, nz, 1);
478 let mut surface = vec![0u8; mask.len()];
479
480 for i in 0..mask.len() {
481 if mask[i] != 0 && eroded[i] == 0 {
482 surface[i] = 1;
483 }
484 }
485
486 surface
487}
488
/// Binary erosion of a flattened `nx * ny * nz` mask (index `i + j * nx + k * nx * ny`)
/// by a Euclidean ball of `radius` voxels.
///
/// A foreground (non-zero) voxel survives only if every offset `(dx, dy, dz)`
/// with `dx² + dy² + dz² <= radius²` lands on a foreground voxel inside the
/// volume. Positions outside the volume count as background, so voxels within
/// `radius` of the volume border are always removed. A negative `radius` gives
/// an empty ball, so every foreground voxel survives.
///
/// Returns a 0/1 byte volume of length `nx * ny * nz`.
///
/// # Panics
/// If `mask.len() < nx * ny * nz`.
fn erode_mask(mask: &[u8], nx: usize, ny: usize, nz: usize, radius: i32) -> Vec<u8> {
    let n_total = nx * ny * nz;
    let idx = |i: usize, j: usize, k: usize| i + j * nx + k * nx * ny;

    // The ball's offsets do not depend on the voxel being tested, so enumerate
    // them once instead of re-filtering the bounding cube for every voxel.
    let r2 = radius * radius;
    let mut ball = Vec::new();
    for dz in -radius..=radius {
        for dy in -radius..=radius {
            for dx in -radius..=radius {
                if dx * dx + dy * dy + dz * dz <= r2 {
                    ball.push((dx, dy, dz));
                }
            }
        }
    }

    let mut eroded = vec![0u8; n_total];
    for k in 0..nz {
        for j in 0..ny {
            for i in 0..nx {
                if mask[idx(i, j, k)] == 0 {
                    continue;
                }
                let all_inside = ball.iter().all(|&(dx, dy, dz)| {
                    let (ni, nj, nk) = (i as i32 + dx, j as i32 + dy, k as i32 + dz);
                    (0..nx as i32).contains(&ni)
                        && (0..ny as i32).contains(&nj)
                        && (0..nz as i32).contains(&nk)
                        && mask[idx(ni as usize, nj as usize, nk as usize)] != 0
                });
                if all_inside {
                    eroded[idx(i, j, k)] = 1;
                }
            }
        }
    }

    eroded
}
541
/// Binary dilation of a flattened `nx * ny * nz` mask (index `i + j * nx + k * nx * ny`)
/// by a Euclidean ball of `radius` voxels.
///
/// Every voxel at an offset `(dx, dy, dz)` with `dx² + dy² + dz² <= radius²`
/// from a foreground (non-zero) voxel is set; offsets that fall outside the
/// volume are dropped. A negative `radius` gives an empty ball and an all-zero result.
///
/// Returns a 0/1 byte volume of length `nx * ny * nz`.
///
/// # Panics
/// If `mask.len() < nx * ny * nz`.
fn dilate_mask(mask: &[u8], nx: usize, ny: usize, nz: usize, radius: i32) -> Vec<u8> {
    let n_total = nx * ny * nz;
    let idx = |i: usize, j: usize, k: usize| i + j * nx + k * nx * ny;

    // The ball's offsets are the same for every voxel; enumerate them once.
    let r2 = radius * radius;
    let mut ball = Vec::new();
    for dz in -radius..=radius {
        for dy in -radius..=radius {
            for dx in -radius..=radius {
                if dx * dx + dy * dy + dz * dz <= r2 {
                    ball.push((dx, dy, dz));
                }
            }
        }
    }

    let mut dilated = vec![0u8; n_total];
    for k in 0..nz {
        for j in 0..ny {
            for i in 0..nx {
                if mask[idx(i, j, k)] == 0 {
                    continue;
                }
                for &(dx, dy, dz) in &ball {
                    let (ni, nj, nk) = (i as i32 + dx, j as i32 + dy, k as i32 + dz);
                    if (0..nx as i32).contains(&ni)
                        && (0..ny as i32).contains(&nj)
                        && (0..nz as i32).contains(&nk)
                    {
                        dilated[idx(ni as usize, nj as usize, nk as usize)] = 1;
                    }
                }
            }
        }
    }

    dilated
}
581
582pub fn morphological_close(mask: &[u8], nx: usize, ny: usize, nz: usize, radius: i32) -> Vec<u8> {
584 let dilated = dilate_mask(mask, nx, ny, nz, radius);
585 erode_mask(&dilated, nx, ny, nz, radius)
586}
587
588fn gaussian_smooth_3d_masked(
590 data: &[f64],
591 mask: &[u8],
592 nx: usize, ny: usize, nz: usize,
593 sigmas: &[f64; 3],
594) -> Vec<f64> {
595 let smoothed_x = convolve_1d_direction_masked(data, mask, nx, ny, nz, sigmas[0], 'x');
597 let smoothed_xy = convolve_1d_direction_masked(&smoothed_x, mask, nx, ny, nz, sigmas[1], 'y');
598 let smoothed_xyz = convolve_1d_direction_masked(&smoothed_xy, mask, nx, ny, nz, sigmas[2], 'z');
599
600 smoothed_xyz.iter()
602 .enumerate()
603 .map(|(i, &v)| if mask[i] != 0 { v } else { 0.0 })
604 .collect()
605}
606
/// One-dimensional Gaussian convolution of a flattened `nx * ny * nz` volume
/// (index `i + j * nx + k * nx * ny`) along `direction` (`'x'`, `'y'` or `'z'`).
///
/// The kernel spans `ceil(2 * sigma)` voxels on each side and is normalised to
/// sum to one. Samples beyond the volume edge read the nearest edge voxel
/// (clamp-to-edge). `_mask` is accepted but unused. When `sigma <= 0.0`, an
/// unmodified copy of `data` is returned before `direction` is inspected.
///
/// # Panics
/// If `sigma > 0.0` and `direction` is not `'x'`, `'y'` or `'z'`.
fn convolve_1d_direction_masked(
    data: &[f64],
    _mask: &[u8],
    nx: usize,
    ny: usize,
    nz: usize,
    sigma: f64,
    direction: char,
) -> Vec<f64> {
    if sigma <= 0.0 {
        return data.to_vec();
    }

    let radius = (2.0 * sigma).ceil() as i32;
    let taps = 2 * radius + 1;
    let raw: Vec<f64> = (0..taps)
        .map(|t| {
            let x = (t - radius) as f64;
            (-x * x / (2.0 * sigma * sigma)).exp()
        })
        .collect();
    let total = raw.iter().fold(0.0, |acc, &w| acc + w);
    let weights: Vec<f64> = raw.iter().map(|&w| w / total).collect();

    // Distance between neighbouring samples along the axis, and the axis length.
    let (stride, extent) = match direction {
        'x' => (1, nx),
        'y' => (nx, ny),
        'z' => (nx * ny, nz),
        _ => panic!("Invalid convolution direction"),
    };
    let last = extent as i32 - 1;

    (0..nx * ny * nz)
        .map(|flat| {
            let pos = (flat / stride) % extent;
            let line_start = flat - pos * stride;
            weights.iter().enumerate().fold(0.0, |acc, (t, &w)| {
                let src = (pos as i32 + t as i32 - radius).max(0).min(last) as usize;
                acc + data[line_start + src * stride] * w
            })
        })
        .collect()
}
705
706pub fn calculate_gaussian_curvature(
709 mask: &[u8],
710 nx: usize, ny: usize, nz: usize,
711) -> CurvatureResult {
712 let n_total = nx * ny * nz;
713
714 let surface_indices = extract_surface_voxels(mask, nx, ny, nz);
716
717 if surface_indices.is_empty() {
718 return CurvatureResult {
719 gaussian_curvature: vec![0.0; n_total],
720 mean_curvature: vec![0.0; n_total],
721 surface_indices: Vec::new(),
722 };
723 }
724
725 let all_points: Vec<Point3D> = surface_indices
727 .iter()
728 .map(|&idx| {
729 let i = idx % nx;
730 let j = (idx / nx) % ny;
731 let k = idx / (nx * ny);
732 Point3D::new(i as f64, j as f64, k as f64)
733 })
734 .collect();
735
736 let mut xy_to_rep: HashMap<(usize, usize), usize> = HashMap::new();
738 let mut is_representative = vec![false; all_points.len()];
739 for (idx, p) in all_points.iter().enumerate() {
740 let key = (p.x as usize, p.y as usize);
741 xy_to_rep.entry(key).or_insert_with(|| {
742 is_representative[idx] = true;
743 idx
744 });
745 }
746 let rep_indices: Vec<usize> = (0..all_points.len())
747 .filter(|&i| is_representative[i])
748 .collect();
749 let mut orig_to_rep = vec![0usize; all_points.len()];
750 for (new_idx, &old_idx) in rep_indices.iter().enumerate() {
751 orig_to_rep[old_idx] = new_idx;
752 }
753 let rep_points: Vec<Point3D> = rep_indices.iter().map(|&i| all_points[i].clone()).collect();
754
755 let (triangles, boundary) = triangulate_surface(&rep_points);
757
758 let (gc_points, mc_points, _amixed) = compute_curvatures_from_mesh(&rep_points, &triangles, &boundary);
760
761 let mut gaussian_curvature = vec![0.0f64; n_total];
763 let mut mean_curvature = vec![0.0f64; n_total];
764
765 for (orig_idx, &vol_idx) in surface_indices.iter().enumerate() {
766 if is_representative[orig_idx] {
767 let rep_idx = orig_to_rep[orig_idx];
768 gaussian_curvature[vol_idx] = gc_points[rep_idx];
769 mean_curvature[vol_idx] = mc_points[rep_idx];
770 }
771 }
772
773 CurvatureResult {
774 gaussian_curvature,
775 mean_curvature,
776 surface_indices,
777 }
778}
779
#[cfg(test)]
mod tests {
    use super::*;

    /// Counts non-zero voxels in a byte mask.
    fn count_on(mask: &[u8]) -> usize {
        mask.iter().filter(|&&v| v != 0).count()
    }

    /// Solid ball of voxels strictly closer than `radius` to the point (n/2, n/2, n/2).
    fn make_sphere_mask(n: usize, radius: f64) -> Vec<u8> {
        let center = n as f64 / 2.0;
        let mut mask = vec![0u8; n * n * n];
        for k in 0..n {
            for j in 0..n {
                for i in 0..n {
                    let dx = i as f64 - center;
                    let dy = j as f64 - center;
                    let dz = k as f64 - center;
                    if (dx * dx + dy * dy + dz * dz).sqrt() < radius {
                        mask[i + j * n + k * n * n] = 1;
                    }
                }
            }
        }
        mask
    }

    #[test]
    fn test_extract_surface_basic() {
        // A lone voxel cannot survive erosion, so it is its own surface.
        let mut mask = vec![0u8; 27];
        mask[13] = 1;
        assert_eq!(extract_surface_voxels(&mask, 3, 3, 3), vec![13]);
    }

    #[test]
    fn test_erode_mask() {
        // Only the 3x3x3 core of a full 5x5x5 block keeps every face neighbour in-volume.
        let mask = vec![1u8; 125];
        let eroded = erode_mask(&mask, 5, 5, 5, 1);
        assert_eq!(count_on(&eroded), 27);
        assert_eq!(eroded[62], 1);
    }

    #[test]
    fn test_dilate_mask() {
        // Radius-1 ball = the voxel plus its 6 face neighbours.
        let mut mask = vec![0u8; 125];
        mask[62] = 1;
        assert_eq!(count_on(&dilate_mask(&mask, 5, 5, 5, 1)), 7);
    }

    #[test]
    fn test_point3d_sub() {
        let c = Point3D::new(3.0, 4.0, 5.0).sub(&Point3D::new(1.0, 1.0, 1.0));
        assert!((c.x - 2.0).abs() < 1e-10);
        assert!((c.y - 3.0).abs() < 1e-10);
        assert!((c.z - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_point3d_dot() {
        let d = Point3D::new(1.0, 2.0, 3.0).dot(&Point3D::new(4.0, 5.0, 6.0));
        assert!((d - 32.0).abs() < 1e-10);
    }

    #[test]
    fn test_point3d_cross() {
        let c = Point3D::new(1.0, 0.0, 0.0).cross(&Point3D::new(0.0, 1.0, 0.0));
        assert!(c.x.abs() < 1e-10);
        assert!(c.y.abs() < 1e-10);
        assert!((c.z - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_point3d_norm() {
        assert!((Point3D::new(3.0, 4.0, 0.0).norm() - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_point3d_normalize() {
        let n = Point3D::new(0.0, 0.0, 5.0).normalize();
        assert!(n.x.abs() < 1e-10);
        assert!(n.y.abs() < 1e-10);
        assert!((n.z - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_point3d_normalize_zero() {
        let n = Point3D::new(0.0, 0.0, 0.0).normalize();
        assert!(n.x.abs() < 1e-10);
        assert!(n.y.abs() < 1e-10);
        assert!(n.z.abs() < 1e-10);
    }

    #[test]
    fn test_point3d_scale_and_add() {
        let b = Point3D::new(1.0, 2.0, 3.0).scale(2.0);
        assert!((b.x - 2.0).abs() < 1e-10);
        assert!((b.y - 4.0).abs() < 1e-10);
        assert!((b.z - 6.0).abs() < 1e-10);

        let d = b.add(&Point3D::new(0.5, 0.5, 0.5));
        assert!((d.x - 2.5).abs() < 1e-10);
        assert!((d.y - 4.5).abs() < 1e-10);
        assert!((d.z - 6.5).abs() < 1e-10);
    }

    #[test]
    fn test_extract_surface_sphere() {
        let n = 10;
        let mask = make_sphere_mask(n, 3.5);
        let surface = extract_surface_voxels(&mask, n, n, n);

        assert!(!surface.is_empty(), "Sphere should have surface voxels");
        for &idx in &surface {
            assert_eq!(mask[idx], 1, "Surface voxel should be in mask");
        }
        assert!(
            surface.windows(2).all(|w| w[0] < w[1]),
            "Surface indices should be strictly ascending"
        );

        let mask_count = count_on(&mask);
        assert!(
            surface.len() < mask_count,
            "Surface ({}) should be smaller than total mask ({})",
            surface.len(),
            mask_count
        );
    }

    #[test]
    fn test_extract_surface_empty_mask() {
        let surface = extract_surface_voxels(&vec![0u8; 27], 3, 3, 3);
        assert!(surface.is_empty(), "Empty mask should have no surface voxels");
    }

    #[test]
    fn test_erode_mask_sphere() {
        let n = 10;
        let mask = make_sphere_mask(n, 4.0);
        let eroded = erode_mask(&mask, n, n, n, 1);

        let orig_count = count_on(&mask);
        let eroded_count = count_on(&eroded);
        assert!(
            eroded_count < orig_count,
            "Eroded sphere should be smaller: {} < {}",
            eroded_count,
            orig_count
        );
        assert!(eroded_count > 0, "Eroded sphere should not be empty");
        assert!(
            eroded.iter().zip(&mask).all(|(&e, &m)| e == 0 || m != 0),
            "Erosion must stay inside the original mask"
        );

        let center = n / 2 + (n / 2) * n + (n / 2) * n * n;
        assert_eq!(eroded[center], 1, "Center should survive erosion");
    }

    #[test]
    fn test_erode_mask_single_voxel() {
        let mut mask = vec![0u8; 125];
        mask[62] = 1;
        let eroded = erode_mask(&mask, 5, 5, 5, 1);
        assert_eq!(count_on(&eroded), 0, "Single voxel should be fully eroded");
    }

    #[test]
    fn test_dilate_mask_sphere() {
        let n = 10;
        let mask = make_sphere_mask(n, 3.0);
        let dilated = dilate_mask(&mask, n, n, n, 1);

        let orig_count = count_on(&mask);
        let dilated_count = count_on(&dilated);
        assert!(
            dilated_count > orig_count,
            "Dilated sphere should be larger: {} > {}",
            dilated_count,
            orig_count
        );
        assert!(
            mask.iter().zip(&dilated).all(|(&m, &d)| m == 0 || d == 1),
            "Dilation must contain the original mask"
        );
    }

    #[test]
    fn test_dilate_mask_radius_2() {
        // Integer points with dx² + dy² + dz² <= 4: 1 + 6 + 12 + 8 + 6 = 33.
        let mut mask = vec![0u8; 125];
        mask[62] = 1;
        let count = count_on(&dilate_mask(&mask, 5, 5, 5, 2));
        assert_eq!(count, 33, "Radius-2 dilation should produce 33 voxels, got {}", count);
    }

    #[test]
    fn test_morphological_close_fills_small_gaps() {
        let n = 10;
        let mut mask = make_sphere_mask(n, 4.0);
        let surface = extract_surface_voxels(&mask, n, n, n);
        if let Some(&first) = surface.first() {
            mask[first] = 0;
        }

        let closed = morphological_close(&mask, n, n, n, 1);
        let orig_count = count_on(&mask);
        let closed_count = count_on(&closed);
        assert!(
            closed_count >= orig_count,
            "Closing should not reduce mask size: {} vs {}",
            closed_count,
            orig_count
        );
        // The ball lies well inside the volume, so closing must be extensive.
        assert!(
            mask.iter().zip(&closed).all(|(&m, &c)| m == 0 || c == 1),
            "Closing must keep every original voxel away from the border"
        );
    }

    #[test]
    fn test_morphological_close_empty() {
        let closed = morphological_close(&vec![0u8; 27], 3, 3, 3, 1);
        assert_eq!(count_on(&closed), 0, "Closing empty mask should stay empty");
    }

    #[test]
    fn test_create_surface_mask_sphere() {
        let n = 10;
        let mask = make_sphere_mask(n, 4.0);
        let surface = create_surface_mask(&mask, n, n, n);
        let surface_count = count_on(&surface);
        let mask_count = count_on(&mask);

        assert!(surface_count > 0, "Surface mask should be non-empty");
        assert!(
            surface_count < mask_count,
            "Surface ({}) should be smaller than mask ({})",
            surface_count,
            mask_count
        );
        for (i, &s) in surface.iter().enumerate() {
            if s > 0 {
                assert_eq!(mask[i], 1, "Surface voxel should be in original mask");
            }
        }

        let flagged: Vec<usize> = (0..surface.len()).filter(|&i| surface[i] != 0).collect();
        assert_eq!(
            flagged,
            extract_surface_voxels(&mask, n, n, n),
            "Surface mask and surface index list must agree"
        );
    }

    #[test]
    fn test_triangulate_surface_few_points() {
        let points = vec![Point3D::new(0.0, 0.0, 0.0), Point3D::new(1.0, 1.0, 1.0)];
        let (triangles, boundary) = triangulate_surface(&points);
        assert!(triangles.is_empty(), "Less than 3 points should give no triangles");
        assert_eq!(boundary, vec![false, false]);
    }

    #[test]
    fn test_triangulate_surface_square_points() {
        let points = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.0, 1.0, 0.0),
            Point3D::new(1.0, 1.0, 0.0),
        ];
        let (triangles, boundary) = triangulate_surface(&points);
        assert_eq!(triangles.len(), 2, "4 points should produce 2 triangles");
        for &b in &boundary {
            assert!(b, "All 4 points should be on boundary");
        }
    }

    #[test]
    fn test_compute_curvatures_from_mesh_flat_surface() {
        // 3x3 planar grid; only the centre vertex (index 4) is interior.
        let points = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(2.0, 0.0, 0.0),
            Point3D::new(0.0, 1.0, 0.0),
            Point3D::new(1.0, 1.0, 0.0),
            Point3D::new(2.0, 1.0, 0.0),
            Point3D::new(0.0, 2.0, 0.0),
            Point3D::new(1.0, 2.0, 0.0),
            Point3D::new(2.0, 2.0, 0.0),
        ];
        let triangles = vec![
            Triangle { v0: 0, v1: 1, v2: 4 },
            Triangle { v0: 0, v1: 4, v2: 3 },
            Triangle { v0: 1, v1: 2, v2: 5 },
            Triangle { v0: 1, v1: 5, v2: 4 },
            Triangle { v0: 3, v1: 4, v2: 7 },
            Triangle { v0: 3, v1: 7, v2: 6 },
            Triangle { v0: 4, v1: 5, v2: 8 },
            Triangle { v0: 4, v1: 8, v2: 7 },
        ];
        let boundary = vec![true, true, true, true, false, true, true, true, true];

        let (gc, mc, _amixed) = compute_curvatures_from_mesh(&points, &triangles, &boundary);
        assert!(
            gc[4].abs() < 1e-6,
            "Flat surface should have ~0 Gaussian curvature, got {}",
            gc[4]
        );
        assert!(
            mc[4].abs() < 1e-6,
            "Flat surface should have ~0 mean curvature, got {}",
            mc[4]
        );
    }

    #[test]
    fn test_compute_curvatures_from_mesh_degenerate_triangle() {
        // Collinear points: zero area, so the triangle is skipped entirely.
        let points = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(2.0, 0.0, 0.0),
        ];
        let triangles = vec![Triangle { v0: 0, v1: 1, v2: 2 }];
        let boundary = vec![false, false, false];
        let (gc, mc, amixed) = compute_curvatures_from_mesh(&points, &triangles, &boundary);
        assert_eq!(gc, vec![0.0; 3]);
        assert_eq!(mc, vec![0.0; 3]);
        assert_eq!(amixed, vec![0.0; 3]);
    }

    #[test]
    fn test_compute_curvatures_from_mesh_boundary_zero() {
        let points = vec![
            Point3D::new(0.0, 0.0, 0.0),
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(0.5, 1.0, 1.0),
        ];
        let triangles = vec![Triangle { v0: 0, v1: 1, v2: 2 }];
        let boundary = vec![true, true, true];
        let (gc, mc, _amixed) = compute_curvatures_from_mesh(&points, &triangles, &boundary);
        for i in 0..3 {
            assert!(gc[i].abs() < 1e-10, "Boundary vertex GC should be 0");
            assert!(mc[i].abs() < 1e-10, "Boundary vertex MC should be 0");
        }
    }

    #[test]
    fn test_convolve_1d_direction_uniform() {
        let n = 8;
        let data = vec![5.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        for axis in ['x', 'y', 'z'] {
            let result = convolve_1d_direction_masked(&data, &mask, n, n, n, 1.0, axis);
            for &v in &result {
                assert!(
                    (v - 5.0).abs() < 0.1,
                    "{} convolution should preserve uniform data, got {}",
                    axis,
                    v
                );
            }
        }
    }

    #[test]
    fn test_convolve_1d_direction_zero_sigma() {
        let n = 5;
        let data = vec![3.0; n * n * n];
        let mask = vec![1u8; n * n * n];
        let result = convolve_1d_direction_masked(&data, &mask, n, n, n, 0.0, 'x');
        assert_eq!(result, data, "Zero sigma should return copy of input");
    }

    #[test]
    fn test_gaussian_smooth_3d_masked_uniform() {
        let n = 8;
        let data = vec![10.0; n * n * n];
        let mask = vec![1u8; n * n * n];
        let result = gaussian_smooth_3d_masked(&data, &mask, n, n, n, &[1.0, 1.0, 1.0]);
        assert_eq!(result.len(), n * n * n);
        for &v in &result {
            assert!(v.is_finite(), "Result should be finite");
            assert!((v - 10.0).abs() < 1.0, "Uniform data should stay near 10.0, got {}", v);
        }
    }

    #[test]
    fn test_gaussian_smooth_3d_masked_applies_mask() {
        let n = 8;
        let data = vec![10.0; n * n * n];
        let mut mask = vec![1u8; n * n * n];
        mask[..n * n * n / 2].iter_mut().for_each(|m| *m = 0);
        let result = gaussian_smooth_3d_masked(&data, &mask, n, n, n, &[1.0, 1.0, 1.0]);
        for (i, &v) in result.iter().enumerate() {
            if mask[i] == 0 {
                assert!(v.abs() < 1e-10, "Masked-out voxel should be 0, got {}", v);
            }
        }
    }

    #[test]
    fn test_calculate_gaussian_curvature_sphere() {
        let n = 12;
        let mask = make_sphere_mask(n, 4.5);
        let result = calculate_gaussian_curvature(&mask, n, n, n);

        assert_eq!(result.gaussian_curvature.len(), n * n * n);
        assert_eq!(result.mean_curvature.len(), n * n * n);
        assert!(!result.surface_indices.is_empty(), "Should have surface indices");

        for &idx in &result.surface_indices {
            assert!(
                result.gaussian_curvature[idx].is_finite(),
                "GC at surface index {} should be finite",
                idx
            );
            assert!(
                result.mean_curvature[idx].is_finite(),
                "MC at surface index {} should be finite",
                idx
            );
        }

        let surface_set: std::collections::HashSet<usize> =
            result.surface_indices.iter().copied().collect();
        for i in (0..n * n * n).filter(|i| !surface_set.contains(i)) {
            assert!(result.gaussian_curvature[i].abs() < 1e-10, "Non-surface GC should be 0");
            assert!(result.mean_curvature[i].abs() < 1e-10, "Non-surface MC should be 0");
        }
    }

    #[test]
    fn test_calculate_gaussian_curvature_empty_mask() {
        let n = 5;
        let result = calculate_gaussian_curvature(&vec![0u8; n * n * n], n, n, n);
        assert!(result.surface_indices.is_empty());
        assert!(result.gaussian_curvature.iter().all(|&v| v == 0.0));
        assert!(result.mean_curvature.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_calculate_gaussian_curvature_single_voxel() {
        // One surface point cannot be triangulated, so every estimate stays 0.
        let mut mask = vec![0u8; 125];
        mask[62] = 1;
        let result = calculate_gaussian_curvature(&mask, 5, 5, 5);
        assert_eq!(result.gaussian_curvature.len(), 125);
        assert_eq!(result.mean_curvature.len(), 125);
        assert_eq!(result.surface_indices, vec![62]);
        assert!(result.gaussian_curvature.iter().all(|&v| v == 0.0));
        assert!(result.mean_curvature.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_calculate_curvature_proximity_sphere() {
        let n = 12;
        let mask = make_sphere_mask(n, 4.5);
        let n_total = n * n * n;
        let prox1: Vec<f64> = mask.iter().map(|&v| v as f64).collect();

        let (prox, curv_i) =
            calculate_curvature_proximity(&mask, &prox1, 0.6, 500.0, 1.0, n, n, n);

        assert_eq!(prox.len(), n_total);
        assert_eq!(curv_i.len(), n_total);
        for (i, &v) in prox.iter().enumerate() {
            assert!(v.is_finite(), "Prox at {} should be finite, got {}", i, v);
        }
        for (i, &v) in curv_i.iter().enumerate() {
            assert!(v.is_finite(), "Curv_i at {} should be finite, got {}", i, v);
        }
    }

    #[test]
    fn test_calculate_curvature_proximity_empty_surface() {
        let n = 5;
        let n_total = n * n * n;
        let mask = vec![0u8; n_total];
        let prox1 = vec![1.0; n_total];

        let (prox, curv_i) =
            calculate_curvature_proximity(&mask, &prox1, 0.6, 500.0, 1.0, n, n, n);

        assert_eq!(prox, prox1, "Without a surface the input proximity is returned unchanged");
        assert_eq!(curv_i.len(), n_total);
        for &v in &curv_i {
            assert!((v - 1.0).abs() < 1e-10, "Empty surface should give curv_i=1.0");
        }
    }

    #[test]
    fn test_calculate_curvature_proximity_respects_mask() {
        let n = 12;
        let mask = make_sphere_mask(n, 4.5);
        let prox1: Vec<f64> = mask.iter().map(|&v| v as f64).collect();

        let (prox, _curv_i) =
            calculate_curvature_proximity(&mask, &prox1, 0.6, 500.0, 1.0, n, n, n);

        for (i, &v) in prox.iter().enumerate() {
            assert!(v.is_finite(), "Prox should be finite everywhere");
            if mask[i] == 0 {
                assert_eq!(v, 0.0, "Prox outside the mask should be 0 at {}", i);
            }
        }
    }

    #[test]
    fn test_calculate_curvature_proximity_varying_params() {
        let n = 12;
        let mask = make_sphere_mask(n, 4.5);
        let prox1: Vec<f64> = mask.iter().map(|&v| v as f64).collect();

        let (prox_a, _) = calculate_curvature_proximity(&mask, &prox1, 0.3, 100.0, 0.5, n, n, n);
        let (prox_b, _) = calculate_curvature_proximity(&mask, &prox1, 0.9, 1000.0, 2.0, n, n, n);

        assert!(prox_a.iter().all(|v| v.is_finite()));
        assert!(prox_b.iter().all(|v| v.is_finite()));
    }
}