Skip to main content

qsm_core/utils/
curvature.rs

1//! Surface Curvature Calculation for QSMART
2//!
3//! This module computes Gaussian and mean curvatures at the surface of a 3D binary mask,
4//! based on the discrete differential geometry approach.
5//!
6//! The curvatures are used in QSMART to weight the spatially-dependent filtering
7//! near brain boundaries to reduce artifacts.
8//!
9//! Uses 2D Delaunay triangulation (via delaunator crate) matching MATLAB's approach:
10//! `tri = delaunay(x, y)` - triangulates on x,y coordinates with z as height.
11//!
12//! Reference:
13//! Meyer, M., Desbrun, M., Schröder, P., Barr, A.H. (2003).
14//! "Discrete Differential-Geometry Operators for Triangulated 2-Manifolds."
15//! Visualization and Mathematics III, 35-57. https://doi.org/10.1007/978-3-662-05105-4_2
16//!
17//! Reference implementation: https://www.mathworks.com/matlabcentral/fileexchange/61136-curvatures
18
19use std::collections::HashMap;
20use std::f64::consts::PI;
21use delaunator::{triangulate, Point};
22
/// Per-voxel curvature maps produced by [`calculate_gaussian_curvature`].
///
/// Both curvature vectors cover the full volume (`nx * ny * nz` entries) and hold 0
/// for every voxel that did not receive a curvature value.
#[derive(Debug, Clone, PartialEq)]
pub struct CurvatureResult {
    /// Gaussian curvature at surface voxels (full volume, 0 for non-surface)
    pub gaussian_curvature: Vec<f64>,
    /// Mean curvature at surface voxels (full volume, 0 for non-surface)
    pub mean_curvature: Vec<f64>,
    /// Linear indices of the surface voxels, in increasing order
    pub surface_indices: Vec<usize>,
}
32
/// Minimal 3D vector with `f64` components.
///
/// Offers only the handful of vector operations the curvature code needs.
#[derive(Clone, Copy, Debug)]
struct Point3D {
    x: f64,
    y: f64,
    z: f64,
}

impl Point3D {
    /// Builds a point from its three components.
    fn new(x: f64, y: f64, z: f64) -> Self {
        Point3D { x, y, z }
    }

    /// Component-wise difference `self - other`.
    fn sub(&self, other: &Point3D) -> Point3D {
        Point3D {
            x: self.x - other.x,
            y: self.y - other.y,
            z: self.z - other.z,
        }
    }

    /// Euclidean inner product.
    fn dot(&self, other: &Point3D) -> f64 {
        self.x * other.x + self.y * other.y + self.z * other.z
    }

    /// Right-handed cross product `self × other`.
    fn cross(&self, other: &Point3D) -> Point3D {
        let cx = self.y * other.z - self.z * other.y;
        let cy = self.z * other.x - self.x * other.z;
        let cz = self.x * other.y - self.y * other.x;
        Point3D { x: cx, y: cy, z: cz }
    }

    /// Euclidean length.
    fn norm(&self) -> f64 {
        self.dot(self).sqrt()
    }

    /// Unit vector in the same direction, or the zero vector when the length
    /// is not greater than `1e-10`.
    fn normalize(&self) -> Point3D {
        match self.norm() {
            len if len > 1e-10 => Point3D {
                x: self.x / len,
                y: self.y / len,
                z: self.z / len,
            },
            _ => Point3D { x: 0.0, y: 0.0, z: 0.0 },
        }
    }

    /// Multiplies every component by `s`.
    fn scale(&self, s: f64) -> Point3D {
        Point3D {
            x: self.x * s,
            y: self.y * s,
            z: self.z * s,
        }
    }

    /// Component-wise sum `self + other`.
    fn add(&self, other: &Point3D) -> Point3D {
        Point3D {
            x: self.x + other.x,
            y: self.y + other.y,
            z: self.z + other.z,
        }
    }
}
83
/// One mesh face, stored as three indices into the vertex array.
#[derive(Debug, Clone, Copy)]
struct Triangle {
    /// Index of the first corner
    v0: usize,
    /// Index of the second corner
    v1: usize,
    /// Index of the third corner
    v2: usize,
}
91
92/// Extract surface voxels from a binary mask
93///
94/// Matches MATLAB's approach: curvMask = mask - imerode(mask, strel('sphere',1))
95/// Surface voxels are those in the mask but not in the eroded mask.
96fn extract_surface_voxels(
97    mask: &[u8],
98    nx: usize, ny: usize, nz: usize,
99) -> Vec<usize> {
100    let eroded = erode_mask(mask, nx, ny, nz, 1);
101
102    let mut surface = Vec::new();
103    for i in 0..mask.len() {
104        if mask[i] != 0 && eroded[i] == 0 {
105            surface.push(i);
106        }
107    }
108
109    surface
110}
111
112/// 2D Delaunay triangulation of surface points
113///
114/// This matches MATLAB's approach: `tri = delaunay(x, y)`
115/// Triangulates on x,y coordinates, treating z as a height field.
116///
117/// Points must have unique (x,y) coordinates (caller should deduplicate first).
118///
119/// Returns (triangles, boundary_flags) where boundary_flags[i] is true
120/// if vertex i is on the convex hull boundary (matching MATLAB's freeBoundary).
121fn triangulate_surface(
122    points: &[Point3D],
123) -> (Vec<Triangle>, Vec<bool>) {
124    if points.len() < 3 {
125        return (Vec::new(), vec![false; points.len()]);
126    }
127
128    // Convert to delaunator's Point format (2D: x, y only)
129    let coords: Vec<Point> = points.iter()
130        .map(|p| Point { x: p.x, y: p.y })
131        .collect();
132
133    // Run 2D Delaunay triangulation
134    let result = triangulate(&coords);
135
136    // Identify boundary vertices (convex hull of the 2D triangulation)
137    let mut boundary = vec![false; points.len()];
138    for &idx in &result.hull {
139        boundary[idx] = true;
140    }
141
142    // Convert triangles
143    let mut triangles = Vec::with_capacity(result.triangles.len() / 3);
144    for i in (0..result.triangles.len()).step_by(3) {
145        triangles.push(Triangle {
146            v0: result.triangles[i],
147            v1: result.triangles[i + 1],
148            v2: result.triangles[i + 2],
149        });
150    }
151
152    (triangles, boundary)
153}
154
155/// Compute Gaussian and mean curvatures using discrete differential geometry
156///
157/// Based on Meyer et al., "Discrete differential-geometry operators for triangulated 2-manifolds"
158///
159/// Boundary vertices (on the triangulation free boundary) get GC=0, MC=0
160/// to match MATLAB's curvatures.m behavior.
161fn compute_curvatures_from_mesh(
162    points: &[Point3D],
163    triangles: &[Triangle],
164    boundary: &[bool],
165) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
166    let n_points = points.len();
167    let mut gaussian_curvature = vec![0.0f64; n_points];
168    let mut mean_curvature = vec![0.0f64; n_points];
169    let mut angle_sum = vec![0.0f64; n_points];
170    let mut area_mixed = vec![0.0f64; n_points];
171    let mut mean_curv_vec = vec![Point3D::new(0.0, 0.0, 0.0); n_points];
172    let mut normal_vec = vec![Point3D::new(0.0, 0.0, 0.0); n_points];
173
174    // Process each triangle
175    for tri in triangles {
176        let p0 = &points[tri.v0];
177        let p1 = &points[tri.v1];
178        let p2 = &points[tri.v2];
179
180        // Edge vectors
181        let e01 = p1.sub(p0); // v0 -> v1
182        let e12 = p2.sub(p1); // v1 -> v2
183        let e20 = p0.sub(p2); // v2 -> v0
184
185        let l01 = e01.norm();
186        let l12 = e12.norm();
187        let l20 = e20.norm();
188
189        if l01 < 1e-10 || l12 < 1e-10 || l20 < 1e-10 {
190            continue;
191        }
192
193        // Triangle area
194        let cross = e01.cross(&e12.scale(-1.0));
195        let area = 0.5 * cross.norm();
196        if area < 1e-10 {
197            continue;
198        }
199
200        // Triangle normal
201        let face_normal = cross.normalize();
202
203        // Angles at each vertex
204        let cos_a0 = e01.normalize().dot(&e20.scale(-1.0).normalize());
205        let cos_a1 = e01.scale(-1.0).normalize().dot(&e12.normalize());
206        let cos_a2 = e12.scale(-1.0).normalize().dot(&e20.normalize());
207
208        let a0 = cos_a0.clamp(-1.0, 1.0).acos();
209        let a1 = cos_a1.clamp(-1.0, 1.0).acos();
210        let a2 = cos_a2.clamp(-1.0, 1.0).acos();
211
212        // Accumulate angle sums for Gaussian curvature
213        angle_sum[tri.v0] += a0;
214        angle_sum[tri.v1] += a1;
215        angle_sum[tri.v2] += a2;
216
217        // Compute cotangent weights for mean curvature
218        let cot_a0 = cos_a0 / (1.0 - cos_a0 * cos_a0).sqrt().max(1e-10);
219        let cot_a1 = cos_a1 / (1.0 - cos_a1 * cos_a1).sqrt().max(1e-10);
220        let cot_a2 = cos_a2 / (1.0 - cos_a2 * cos_a2).sqrt().max(1e-10);
221
222        // Compute A_mixed for each vertex
223        // Check if any angle is obtuse
224        let obtuse_0 = a0 > PI / 2.0;
225        let obtuse_1 = a1 > PI / 2.0;
226        let obtuse_2 = a2 > PI / 2.0;
227
228        // Add contribution to A_mixed for each vertex
229        if obtuse_0 {
230            area_mixed[tri.v0] += area / 2.0;
231        } else if obtuse_1 || obtuse_2 {
232            area_mixed[tri.v0] += area / 4.0;
233        } else {
234            area_mixed[tri.v0] += (l20 * l20 * cot_a1 + l01 * l01 * cot_a2) / 8.0;
235        }
236
237        if obtuse_1 {
238            area_mixed[tri.v1] += area / 2.0;
239        } else if obtuse_0 || obtuse_2 {
240            area_mixed[tri.v1] += area / 4.0;
241        } else {
242            area_mixed[tri.v1] += (l01 * l01 * cot_a2 + l12 * l12 * cot_a0) / 8.0;
243        }
244
245        if obtuse_2 {
246            area_mixed[tri.v2] += area / 2.0;
247        } else if obtuse_0 || obtuse_1 {
248            area_mixed[tri.v2] += area / 4.0;
249        } else {
250            area_mixed[tri.v2] += (l12 * l12 * cot_a0 + l20 * l20 * cot_a1) / 8.0;
251        }
252
253        // Mean curvature vector contribution
254        mean_curv_vec[tri.v0] = mean_curv_vec[tri.v0].add(&e01.scale(cot_a2).add(&e20.scale(-cot_a1)));
255        mean_curv_vec[tri.v1] = mean_curv_vec[tri.v1].add(&e12.scale(cot_a0).add(&e01.scale(-cot_a2)));
256        mean_curv_vec[tri.v2] = mean_curv_vec[tri.v2].add(&e20.scale(cot_a1).add(&e12.scale(-cot_a0)));
257
258        // Accumulate face normal for vertex normal using incenter-based distance weighting
259        // Matches MATLAB: wi = 1/norm(incenter - vertex); n_vec += wi * faceNormal
260        // Incenter = (a*P0 + b*P1 + c*P2) / (a+b+c) where a=|P1P2|, b=|P0P2|, c=|P0P1|
261        let perim = l12 + l20 + l01;
262        if perim > 1e-10 {
263            let incenter = p0.scale(l12).add(&p1.scale(l20)).add(&p2.scale(l01)).scale(1.0 / perim);
264
265            let w0 = 1.0 / p0.sub(&incenter).norm().max(1e-10);
266            let w1 = 1.0 / p1.sub(&incenter).norm().max(1e-10);
267            let w2 = 1.0 / p2.sub(&incenter).norm().max(1e-10);
268
269            normal_vec[tri.v0] = normal_vec[tri.v0].add(&face_normal.scale(w0));
270            normal_vec[tri.v1] = normal_vec[tri.v1].add(&face_normal.scale(w1));
271            normal_vec[tri.v2] = normal_vec[tri.v2].add(&face_normal.scale(w2));
272        }
273    }
274
275    // Compute final curvature values
276    // Skip boundary vertices (GC=0, MC=0) matching MATLAB's freeBoundary check
277    for i in 0..n_points {
278        if boundary[i] {
279            // Boundary vertices get zero curvature (unreliable)
280            continue;
281        }
282
283        if area_mixed[i] > 1e-10 {
284            // Gaussian curvature: K = (2π - Σθ) / A_mixed
285            gaussian_curvature[i] = (2.0 * PI - angle_sum[i]) / area_mixed[i];
286
287            // Mean curvature: H = |mean_curv_vec| / (4 * A_mixed)
288            let mc_vec = mean_curv_vec[i].scale(0.25 / area_mixed[i]);
289            let mc_mag = mc_vec.norm();
290
291            // Determine sign from dot product with normal
292            let n_vec = normal_vec[i].normalize();
293            let sign = if mc_vec.dot(&n_vec) < 0.0 { -1.0 } else { 1.0 };
294
295            mean_curvature[i] = sign * mc_mag;
296        }
297    }
298
299    (gaussian_curvature, mean_curvature, area_mixed)
300}
301
302/// Calculate proximity maps using curvature at the brain surface
303///
304/// This is the main entry point matching QSMART's calculate_curvature function.
305///
306/// # Arguments
307/// * `mask` - Binary brain mask
308/// * `prox1` - Initial proximity map from Gaussian smoothing
309/// * `lower_lim` - Clamping value for proximity (default 0.6)
310/// * `curv_constant` - Scaling constant for curvature (default 500)
311/// * `sigma` - Kernel size for smoothing curvature
312/// * `nx`, `ny`, `nz` - Volume dimensions
313///
314/// # Returns
315/// Modified proximity map incorporating curvature-based edge weighting
316pub fn calculate_curvature_proximity(
317    mask: &[u8],
318    prox1: &[f64],
319    lower_lim: f64,
320    curv_constant: f64,
321    sigma: f64,
322    nx: usize, ny: usize, nz: usize,
323) -> (Vec<f64>, Vec<f64>) {
324    let n_total = nx * ny * nz;
325
326    // Extract surface voxels
327    let surface_indices = extract_surface_voxels(mask, nx, ny, nz);
328
329    if surface_indices.is_empty() {
330        return (prox1.to_vec(), vec![1.0; n_total]);
331    }
332
333    // Convert surface indices to 3D points
334    let all_points: Vec<Point3D> = surface_indices
335        .iter()
336        .map(|&idx| {
337            let i = idx % nx;
338            let j = (idx / nx) % ny;
339            let k = idx / (nx * ny);
340            Point3D::new(i as f64, j as f64, k as f64)
341        })
342        .collect();
343
344    // Deduplicate (x,y) coordinates before triangulation.
345    // MATLAB's delaunay (via Qhull) suppresses duplicate (x,y) points, keeping
346    // the first occurrence (smallest z). Duplicate vertices get GC=Inf → curvI=1.0.
347    // We explicitly dedup to avoid degenerate zero-area triangles that would
348    // corrupt curvature values.
349    let mut xy_to_rep: HashMap<(usize, usize), usize> = HashMap::new();
350    let mut is_representative = vec![false; all_points.len()];
351    for (idx, p) in all_points.iter().enumerate() {
352        let key = (p.x as usize, p.y as usize);
353        xy_to_rep.entry(key).or_insert_with(|| {
354            is_representative[idx] = true;
355            idx
356        });
357    }
358
359    // Build representative point array and index mapping
360    let rep_indices: Vec<usize> = (0..all_points.len())
361        .filter(|&i| is_representative[i])
362        .collect();
363    let mut orig_to_rep = vec![0usize; all_points.len()];
364    for (new_idx, &old_idx) in rep_indices.iter().enumerate() {
365        orig_to_rep[old_idx] = new_idx;
366    }
367    let rep_points: Vec<Point3D> = rep_indices.iter().map(|&i| all_points[i].clone()).collect();
368
369    // Triangulate unique representatives via Qhull (same library MATLAB uses)
370    let (triangles, boundary) = triangulate_surface(&rep_points);
371
372    // Compute curvatures on representative points
373    let (gc, _mc, _amixed) = compute_curvatures_from_mesh(&rep_points, &triangles, &boundary);
374
375    // Create full curvature volume
376    let mut curv_i = vec![1.0f64; n_total];
377
378    // Find max negative curvature for scaling
379    let max_neg_gc = gc.iter()
380        .filter(|&&v| v < 0.0)
381        .map(|&v| v.abs())
382        .fold(1.0f64, |a, b| a.max(b));
383
384    // Scale and assign curvature values for representative vertices only.
385    // Non-representative (duplicate x,y) vertices keep curvI=1.0, matching MATLAB
386    // where suppressed duplicates get GC=Inf → scaledGC=1.0.
387    for (orig_idx, &vol_idx) in surface_indices.iter().enumerate() {
388        if !is_representative[orig_idx] {
389            continue; // duplicate (x,y) → curvI=1.0
390        }
391        let rep_idx = orig_to_rep[orig_idx];
392        let g = gc[rep_idx];
393        let scaled = if g < 0.0 {
394            g / max_neg_gc * curv_constant
395        } else if g > 0.0 {
396            1.0
397        } else {
398            // GC == 0: boundary vertices, flat regions
399            0.0
400        };
401        curv_i[vol_idx] = scaled;
402    }
403
404    // Smooth the curvature map
405    let sigmas = [sigma, 2.0 * sigma, 2.0 * sigma];
406    let prox3 = gaussian_smooth_3d_masked(&curv_i, mask, nx, ny, nz, &sigmas);
407
408    // Clamp prox3 values
409    let prox3_clamped: Vec<f64> = prox3.iter().enumerate()
410        .map(|(i, &v)| {
411            if mask[i] == 0 {
412                0.0
413            } else if v < 0.5 && v != 0.0 {
414                0.5
415            } else {
416                v
417            }
418        })
419        .collect();
420
421    // Multiply with initial proximity
422    let mut prox: Vec<f64> = prox1.iter()
423        .zip(prox3_clamped.iter())
424        .map(|(&p1, &p3)| p1 * p3)
425        .collect();
426
427    // Edge proximity calculation (prox4)
428    // Matches MATLAB order of operations:
429    //   prox4 = prox .* (mask - imerode(mask, strel('sphere',1)));
430    //   prox4(prox4==0) = 1;
431    //   prox4((imdilate(mask, strel('sphere',5)) - mask)==1) = 0;
432    let surface_mask = create_surface_mask(mask, nx, ny, nz);
433    let dilated_mask = dilate_mask(mask, nx, ny, nz, 5);
434
435    // Step 1: prox4 = prox * surface_mask (surface voxels get prox, rest get 0)
436    let mut prox4 = vec![0.0f64; n_total];
437    for i in 0..n_total {
438        if surface_mask[i] != 0 {
439            prox4[i] = prox[i];
440        }
441    }
442    // Step 2: set ALL zero-valued voxels to 1 (interior + outside + surface with prox==0)
443    for i in 0..n_total {
444        if prox4[i] == 0.0 {
445            prox4[i] = 1.0;
446        }
447    }
448    // Step 3: set dilated shell outside mask to 0
449    for i in 0..n_total {
450        if dilated_mask[i] != 0 && mask[i] == 0 {
451            prox4[i] = 0.0;
452        }
453    }
454
455    // Smooth prox4
456    let prox4_smooth = gaussian_smooth_3d_masked(&prox4, &vec![1u8; n_total], nx, ny, nz, &[5.0, 10.0, 10.0]);
457
458    // Clamp proximity values
459    for i in 0..n_total {
460        if mask[i] == 0 {
461            prox[i] = 0.0;
462        } else if prox[i] < lower_lim && prox[i] != 0.0 {
463            prox[i] = lower_lim;
464        }
465    }
466
467    // Edge refinement
468    for i in 0..n_total {
469        prox[i] *= prox4_smooth[i];
470    }
471
472    (prox, curv_i)
473}
474
475/// Create a surface mask (boundary voxels)
476fn create_surface_mask(mask: &[u8], nx: usize, ny: usize, nz: usize) -> Vec<u8> {
477    let eroded = erode_mask(mask, nx, ny, nz, 1);
478    let mut surface = vec![0u8; mask.len()];
479
480    for i in 0..mask.len() {
481        if mask[i] != 0 && eroded[i] == 0 {
482            surface[i] = 1;
483        }
484    }
485
486    surface
487}
488
/// Erode a binary mask with a spherical structuring element.
///
/// A voxel survives (value 1) only if it is set and every voxel within Euclidean
/// distance `radius` of it is inside the volume and set. Positions outside the
/// volume count as background, so voxels closer than `radius` to the volume border
/// never survive.
fn erode_mask(mask: &[u8], nx: usize, ny: usize, nz: usize, radius: i32) -> Vec<u8> {
    let plane = nx * ny;
    let (width, height, depth) = (nx as i32, ny as i32, nz as i32);

    // Offsets of the discrete ball, computed once instead of per voxel.
    let mut ball: Vec<(i32, i32, i32)> = Vec::new();
    for dz in -radius..=radius {
        for dy in -radius..=radius {
            for dx in -radius..=radius {
                if dx * dx + dy * dy + dz * dz <= radius * radius {
                    ball.push((dx, dy, dz));
                }
            }
        }
    }

    // True when the whole ball centred at (i, j, k) lies inside the mask.
    let fully_covered = |i: usize, j: usize, k: usize| -> bool {
        ball.iter().all(|&(dx, dy, dz)| {
            let (x, y, z) = (i as i32 + dx, j as i32 + dy, k as i32 + dz);
            (0..width).contains(&x)
                && (0..height).contains(&y)
                && (0..depth).contains(&z)
                && mask[x as usize + y as usize * nx + z as usize * plane] != 0
        })
    };

    let mut eroded = vec![0u8; nx * ny * nz];
    for (voxel, slot) in eroded.iter_mut().enumerate() {
        if mask[voxel] == 0 {
            continue;
        }
        // Decode the linear index back into (i, j, k).
        let (i, j, k) = (voxel % nx, (voxel / nx) % ny, voxel / plane);
        if fully_covered(i, j, k) {
            *slot = 1;
        }
    }

    eroded
}
541
/// Dilate a binary mask with a spherical structuring element.
///
/// Every voxel within Euclidean distance `radius` of a set voxel is set to 1 in the
/// result; positions that fall outside the volume are ignored.
fn dilate_mask(mask: &[u8], nx: usize, ny: usize, nz: usize, radius: i32) -> Vec<u8> {
    let plane = nx * ny;
    let n_total = plane * nz;
    let (width, height, depth) = (nx as i32, ny as i32, nz as i32);

    // Offsets of the discrete ball, computed once instead of per voxel.
    let mut ball: Vec<(i32, i32, i32)> = Vec::new();
    for dz in -radius..=radius {
        for dy in -radius..=radius {
            for dx in -radius..=radius {
                if dx * dx + dy * dy + dz * dz <= radius * radius {
                    ball.push((dx, dy, dz));
                }
            }
        }
    }

    let mut dilated = vec![0u8; n_total];
    for voxel in (0..n_total).filter(|&v| mask[v] != 0) {
        // Decode the linear index back into (i, j, k).
        let (i, j, k) = (voxel % nx, (voxel / nx) % ny, voxel / plane);

        // Stamp the ball around this voxel, clipping at the volume border.
        for &(dx, dy, dz) in &ball {
            let (x, y, z) = (i as i32 + dx, j as i32 + dy, k as i32 + dz);
            let inside = (0..width).contains(&x)
                && (0..height).contains(&y)
                && (0..depth).contains(&z);
            if inside {
                dilated[x as usize + y as usize * nx + z as usize * plane] = 1;
            }
        }
    }

    dilated
}
581
582/// Morphological closing (dilation followed by erosion)
583pub fn morphological_close(mask: &[u8], nx: usize, ny: usize, nz: usize, radius: i32) -> Vec<u8> {
584    let dilated = dilate_mask(mask, nx, ny, nz, radius);
585    erode_mask(&dilated, nx, ny, nz, radius)
586}
587
588/// 3D Gaussian smoothing with anisotropic sigma
589fn gaussian_smooth_3d_masked(
590    data: &[f64],
591    mask: &[u8],
592    nx: usize, ny: usize, nz: usize,
593    sigmas: &[f64; 3],
594) -> Vec<f64> {
595    // Apply separable 1D convolutions
596    let smoothed_x = convolve_1d_direction_masked(data, mask, nx, ny, nz, sigmas[0], 'x');
597    let smoothed_xy = convolve_1d_direction_masked(&smoothed_x, mask, nx, ny, nz, sigmas[1], 'y');
598    let smoothed_xyz = convolve_1d_direction_masked(&smoothed_xy, mask, nx, ny, nz, sigmas[2], 'z');
599
600    // Apply mask
601    smoothed_xyz.iter()
602        .enumerate()
603        .map(|(i, &v)| if mask[i] != 0 { v } else { 0.0 })
604        .collect()
605}
606
/// Convolve a volume with a 1D Gaussian kernel along one axis.
///
/// The kernel spans `2*ceil(2*sigma)+1` taps (MATLAB's imgaussfilt3 default filter
/// size) and is normalised to sum to 1. Samples beyond the volume border are taken
/// from the nearest voxel along the axis (replicate padding, as in imgaussfilt3).
///
/// * `direction` - `'x'`, `'y'` or `'z'`
/// * `_mask` - accepted for signature symmetry; not used by the convolution
///
/// A non-positive `sigma` returns a copy of `data` unchanged.
///
/// # Panics
/// Panics with "Invalid convolution direction" for any other `direction`
/// (when `sigma` is positive).
fn convolve_1d_direction_masked(
    data: &[f64],
    _mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    sigma: f64,
    direction: char,
) -> Vec<f64> {
    if sigma <= 0.0 {
        return data.to_vec();
    }

    // Unnormalised Gaussian weights, then normalise by their sum.
    let kernel_radius = (2.0 * sigma).ceil() as i32;
    let taps = 2 * kernel_radius + 1;
    let mut kernel: Vec<f64> = (0..taps)
        .map(|tap| {
            let x = (tap - kernel_radius) as f64;
            (-x * x / (2.0 * sigma * sigma)).exp()
        })
        .collect();
    let weight_total = kernel.iter().fold(0.0, |acc, &w| acc + w);
    for weight in kernel.iter_mut() {
        *weight /= weight_total;
    }

    // Extent of the chosen axis and the linear-index step between neighbours on it.
    let (axis_len, stride) = match direction {
        'x' => (nx, 1),
        'y' => (ny, nx),
        'z' => (nz, nx * ny),
        _ => panic!("Invalid convolution direction"),
    };
    let last = axis_len as i32 - 1;

    let mut result = vec![0.0f64; nx * ny * nz];
    for (voxel, out) in result.iter_mut().enumerate() {
        // Position of this voxel along the axis and the index of its line's first voxel.
        let coord = (voxel / stride) % axis_len;
        let line_start = voxel - coord * stride;

        let mut acc = 0.0;
        for (tap, &weight) in kernel.iter().enumerate() {
            let offset = tap as i32 - kernel_radius;
            // Replicate padding: clamp the sample position into the valid range.
            let sample = (coord as i32 + offset).max(0).min(last) as usize;
            acc += data[line_start + sample * stride] * weight;
        }
        *out = acc;
    }

    result
}
705
706/// Simple Gaussian curvature calculation for mask boundary
707/// Returns full volume with curvature values at surface voxels
708pub fn calculate_gaussian_curvature(
709    mask: &[u8],
710    nx: usize, ny: usize, nz: usize,
711) -> CurvatureResult {
712    let n_total = nx * ny * nz;
713
714    // Extract surface voxels
715    let surface_indices = extract_surface_voxels(mask, nx, ny, nz);
716
717    if surface_indices.is_empty() {
718        return CurvatureResult {
719            gaussian_curvature: vec![0.0; n_total],
720            mean_curvature: vec![0.0; n_total],
721            surface_indices: Vec::new(),
722        };
723    }
724
725    // Convert surface indices to 3D points
726    let all_points: Vec<Point3D> = surface_indices
727        .iter()
728        .map(|&idx| {
729            let i = idx % nx;
730            let j = (idx / nx) % ny;
731            let k = idx / (nx * ny);
732            Point3D::new(i as f64, j as f64, k as f64)
733        })
734        .collect();
735
736    // Deduplicate (x,y) — same logic as calculate_curvature_proximity
737    let mut xy_to_rep: HashMap<(usize, usize), usize> = HashMap::new();
738    let mut is_representative = vec![false; all_points.len()];
739    for (idx, p) in all_points.iter().enumerate() {
740        let key = (p.x as usize, p.y as usize);
741        xy_to_rep.entry(key).or_insert_with(|| {
742            is_representative[idx] = true;
743            idx
744        });
745    }
746    let rep_indices: Vec<usize> = (0..all_points.len())
747        .filter(|&i| is_representative[i])
748        .collect();
749    let mut orig_to_rep = vec![0usize; all_points.len()];
750    for (new_idx, &old_idx) in rep_indices.iter().enumerate() {
751        orig_to_rep[old_idx] = new_idx;
752    }
753    let rep_points: Vec<Point3D> = rep_indices.iter().map(|&i| all_points[i].clone()).collect();
754
755    // Triangulate unique representatives via Qhull
756    let (triangles, boundary) = triangulate_surface(&rep_points);
757
758    // Compute curvatures on representatives
759    let (gc_points, mc_points, _amixed) = compute_curvatures_from_mesh(&rep_points, &triangles, &boundary);
760
761    // Create full volumes — only representative vertices get curvature values
762    let mut gaussian_curvature = vec![0.0f64; n_total];
763    let mut mean_curvature = vec![0.0f64; n_total];
764
765    for (orig_idx, &vol_idx) in surface_indices.iter().enumerate() {
766        if is_representative[orig_idx] {
767            let rep_idx = orig_to_rep[orig_idx];
768            gaussian_curvature[vol_idx] = gc_points[rep_idx];
769            mean_curvature[vol_idx] = mc_points[rep_idx];
770        }
771    }
772
773    CurvatureResult {
774        gaussian_curvature,
775        mean_curvature,
776        surface_indices,
777    }
778}
779
780#[cfg(test)]
781mod tests {
782    use super::*;
783
784    #[test]
785    fn test_extract_surface_basic() {
786        // 3x3x3 cube with center filled
787        let mut mask = vec![0u8; 27];
788        mask[13] = 1; // Center voxel
789
790        let surface = extract_surface_voxels(&mask, 3, 3, 3);
791        assert_eq!(surface.len(), 1);
792        assert_eq!(surface[0], 13);
793    }
794
795    #[test]
796    fn test_erode_mask() {
797        // 5x5x5 solid cube
798        let mask = vec![1u8; 125];
799        let eroded = erode_mask(&mask, 5, 5, 5, 1);
800
801        // Center 3x3x3 should remain
802        let count: usize = eroded.iter().map(|&v| v as usize).sum();
803        assert!(count > 0);
804        assert!(count < 125);
805    }
806
807    #[test]
808    fn test_dilate_mask() {
809        // Single center voxel in 5x5x5
810        let mut mask = vec![0u8; 125];
811        mask[62] = 1; // Center
812
813        let dilated = dilate_mask(&mask, 5, 5, 5, 1);
814
815        // Should expand to 6-connectivity
816        let count: usize = dilated.iter().map(|&v| v as usize).sum();
817        assert!(count >= 7); // At least 7 voxels (center + 6 neighbors)
818    }
819
820    // =====================================================================
821    // Helper: create a 3D sphere mask
822    // =====================================================================
823
824    /// Create a solid sphere mask centered in an n x n x n volume.
825    fn make_sphere_mask(n: usize, radius: f64) -> Vec<u8> {
826        let center = n as f64 / 2.0;
827        let n_total = n * n * n;
828        let mut mask = vec![0u8; n_total];
829
830        for k in 0..n {
831            for j in 0..n {
832                for i in 0..n {
833                    let dx = i as f64 - center;
834                    let dy = j as f64 - center;
835                    let dz = k as f64 - center;
836                    let dist = (dx * dx + dy * dy + dz * dz).sqrt();
837                    if dist < radius {
838                        mask[i + j * n + k * n * n] = 1;
839                    }
840                }
841            }
842        }
843
844        mask
845    }
846
847    // =====================================================================
848    // Tests for Point3D operations
849    // =====================================================================
850
851    #[test]
852    fn test_point3d_sub() {
853        let a = Point3D::new(3.0, 4.0, 5.0);
854        let b = Point3D::new(1.0, 1.0, 1.0);
855        let c = a.sub(&b);
856        assert!((c.x - 2.0).abs() < 1e-10);
857        assert!((c.y - 3.0).abs() < 1e-10);
858        assert!((c.z - 4.0).abs() < 1e-10);
859    }
860
861    #[test]
862    fn test_point3d_dot() {
863        let a = Point3D::new(1.0, 2.0, 3.0);
864        let b = Point3D::new(4.0, 5.0, 6.0);
865        let d = a.dot(&b);
866        assert!((d - 32.0).abs() < 1e-10); // 1*4 + 2*5 + 3*6 = 32
867    }
868
869    #[test]
870    fn test_point3d_cross() {
871        let a = Point3D::new(1.0, 0.0, 0.0);
872        let b = Point3D::new(0.0, 1.0, 0.0);
873        let c = a.cross(&b);
874        assert!((c.x - 0.0).abs() < 1e-10);
875        assert!((c.y - 0.0).abs() < 1e-10);
876        assert!((c.z - 1.0).abs() < 1e-10);
877    }
878
879    #[test]
880    fn test_point3d_norm() {
881        let p = Point3D::new(3.0, 4.0, 0.0);
882        assert!((p.norm() - 5.0).abs() < 1e-10);
883    }
884
885    #[test]
886    fn test_point3d_normalize() {
887        let p = Point3D::new(0.0, 0.0, 5.0);
888        let n = p.normalize();
889        assert!((n.x - 0.0).abs() < 1e-10);
890        assert!((n.y - 0.0).abs() < 1e-10);
891        assert!((n.z - 1.0).abs() < 1e-10);
892    }
893
894    #[test]
895    fn test_point3d_normalize_zero() {
896        let p = Point3D::new(0.0, 0.0, 0.0);
897        let n = p.normalize();
898        assert!((n.x).abs() < 1e-10);
899        assert!((n.y).abs() < 1e-10);
900        assert!((n.z).abs() < 1e-10);
901    }
902
903    #[test]
904    fn test_point3d_scale_and_add() {
905        let a = Point3D::new(1.0, 2.0, 3.0);
906        let b = a.scale(2.0);
907        assert!((b.x - 2.0).abs() < 1e-10);
908        assert!((b.y - 4.0).abs() < 1e-10);
909        assert!((b.z - 6.0).abs() < 1e-10);
910
911        let c = Point3D::new(0.5, 0.5, 0.5);
912        let d = b.add(&c);
913        assert!((d.x - 2.5).abs() < 1e-10);
914        assert!((d.y - 4.5).abs() < 1e-10);
915        assert!((d.z - 6.5).abs() < 1e-10);
916    }
917
918    // =====================================================================
919    // Tests for extract_surface_voxels
920    // =====================================================================
921
922    #[test]
923    fn test_extract_surface_sphere() {
924        let n = 10;
925        let mask = make_sphere_mask(n, 3.5);
926        let surface = extract_surface_voxels(&mask, n, n, n);
927
928        // Surface should be non-empty
929        assert!(!surface.is_empty(), "Sphere should have surface voxels");
930
931        // All surface indices should be within the mask
932        for &idx in &surface {
933            assert_eq!(mask[idx], 1, "Surface voxel should be in mask");
934        }
935
936        // Surface count should be less than total mask count
937        let mask_count: usize = mask.iter().map(|&v| v as usize).sum();
938        assert!(
939            surface.len() < mask_count,
940            "Surface ({}) should be smaller than total mask ({})",
941            surface.len(),
942            mask_count
943        );
944    }
945
946    #[test]
947    fn test_extract_surface_empty_mask() {
948        let mask = vec![0u8; 27];
949        let surface = extract_surface_voxels(&mask, 3, 3, 3);
950        assert!(surface.is_empty(), "Empty mask should have no surface voxels");
951    }
952
953    // =====================================================================
954    // Tests for erode_mask (more thorough)
955    // =====================================================================
956
957    #[test]
958    fn test_erode_mask_sphere() {
959        let n = 10;
960        let mask = make_sphere_mask(n, 4.0);
961        let eroded = erode_mask(&mask, n, n, n, 1);
962
963        let orig_count: usize = mask.iter().map(|&v| v as usize).sum();
964        let eroded_count: usize = eroded.iter().map(|&v| v as usize).sum();
965        assert!(
966            eroded_count < orig_count,
967            "Eroded sphere should be smaller: {} < {}",
968            eroded_count,
969            orig_count
970        );
971        assert!(eroded_count > 0, "Eroded sphere should not be empty");
972
973        // Center should still be in eroded mask
974        let center = n / 2 + (n / 2) * n + (n / 2) * n * n;
975        assert_eq!(eroded[center], 1, "Center should survive erosion");
976    }
977
978    #[test]
979    fn test_erode_mask_single_voxel() {
980        // A single voxel should be eroded away
981        let mut mask = vec![0u8; 125];
982        mask[62] = 1; // center of 5x5x5
983        let eroded = erode_mask(&mask, 5, 5, 5, 1);
984        let count: usize = eroded.iter().map(|&v| v as usize).sum();
985        assert_eq!(count, 0, "Single voxel should be fully eroded");
986    }
987
988    // =====================================================================
989    // Tests for dilate_mask (more thorough)
990    // =====================================================================
991
992    #[test]
993    fn test_dilate_mask_sphere() {
994        let n = 10;
995        let mask = make_sphere_mask(n, 3.0);
996        let dilated = dilate_mask(&mask, n, n, n, 1);
997
998        let orig_count: usize = mask.iter().map(|&v| v as usize).sum();
999        let dilated_count: usize = dilated.iter().map(|&v| v as usize).sum();
1000        assert!(
1001            dilated_count > orig_count,
1002            "Dilated sphere should be larger: {} > {}",
1003            dilated_count,
1004            orig_count
1005        );
1006    }
1007
1008    #[test]
1009    fn test_dilate_mask_radius_2() {
1010        let mut mask = vec![0u8; 125];
1011        mask[62] = 1; // center of 5x5x5
1012        let dilated = dilate_mask(&mask, 5, 5, 5, 2);
1013        let count: usize = dilated.iter().map(|&v| v as usize).sum();
1014        // Should be more than radius=1 dilation
1015        assert!(count > 7, "Radius-2 dilation should produce more than 7 voxels, got {}", count);
1016    }
1017
1018    // =====================================================================
1019    // Tests for morphological_close
1020    // =====================================================================
1021
1022    #[test]
1023    fn test_morphological_close_fills_small_gaps() {
1024        let n = 10;
1025        let mut mask = make_sphere_mask(n, 4.0);
1026        // Remove a surface voxel to create a small gap
1027        let surface = extract_surface_voxels(&mask, n, n, n);
1028        if !surface.is_empty() {
1029            mask[surface[0]] = 0;
1030        }
1031
1032        let closed = morphological_close(&mask, n, n, n, 1);
1033        let orig_count: usize = mask.iter().map(|&v| v as usize).sum();
1034        let closed_count: usize = closed.iter().map(|&v| v as usize).sum();
1035        // Closing should recover the gap or at least not shrink significantly
1036        assert!(
1037            closed_count >= orig_count,
1038            "Closing should not reduce mask size: {} vs {}",
1039            closed_count,
1040            orig_count
1041        );
1042    }
1043
1044    #[test]
1045    fn test_morphological_close_empty() {
1046        let mask = vec![0u8; 27];
1047        let closed = morphological_close(&mask, 3, 3, 3, 1);
1048        let count: usize = closed.iter().map(|&v| v as usize).sum();
1049        assert_eq!(count, 0, "Closing empty mask should stay empty");
1050    }
1051
1052    // =====================================================================
1053    // Tests for create_surface_mask
1054    // =====================================================================
1055
1056    #[test]
1057    fn test_create_surface_mask_sphere() {
1058        let n = 10;
1059        let mask = make_sphere_mask(n, 4.0);
1060        let surface = create_surface_mask(&mask, n, n, n);
1061        let surface_count: usize = surface.iter().map(|&v| v as usize).sum();
1062        let mask_count: usize = mask.iter().map(|&v| v as usize).sum();
1063
1064        assert!(surface_count > 0, "Surface mask should be non-empty");
1065        assert!(
1066            surface_count < mask_count,
1067            "Surface ({}) should be smaller than mask ({})",
1068            surface_count,
1069            mask_count
1070        );
1071
1072        // Every surface voxel should be in the original mask
1073        for i in 0..surface.len() {
1074            if surface[i] > 0 {
1075                assert_eq!(mask[i], 1, "Surface voxel should be in original mask");
1076            }
1077        }
1078    }
1079
1080    // =====================================================================
1081    // Tests for triangulate_surface
1082    // =====================================================================
1083
1084    #[test]
1085    fn test_triangulate_surface_few_points() {
1086        // Less than 3 points should return empty triangulation
1087        let points = vec![Point3D::new(0.0, 0.0, 0.0), Point3D::new(1.0, 1.0, 1.0)];
1088        let (triangles, boundary) = triangulate_surface(&points);
1089        assert!(triangles.is_empty(), "Less than 3 points should give no triangles");
1090        assert_eq!(boundary.len(), 2);
1091    }
1092
1093    #[test]
1094    fn test_triangulate_surface_square_points() {
1095        // Four points forming a square in XY
1096        let points = vec![
1097            Point3D::new(0.0, 0.0, 0.0),
1098            Point3D::new(1.0, 0.0, 0.0),
1099            Point3D::new(0.0, 1.0, 0.0),
1100            Point3D::new(1.0, 1.0, 0.0),
1101        ];
1102        let (triangles, boundary) = triangulate_surface(&points);
1103        // Should produce 2 triangles from 4 points
1104        assert_eq!(triangles.len(), 2, "4 points should produce 2 triangles");
1105        // All 4 points are on the convex hull
1106        for &b in &boundary {
1107            assert!(b, "All 4 points should be on boundary");
1108        }
1109    }
1110
1111    // =====================================================================
1112    // Tests for compute_curvatures_from_mesh
1113    // =====================================================================
1114
1115    #[test]
1116    fn test_compute_curvatures_from_mesh_flat_surface() {
1117        // A flat grid of points (z=0) should have zero curvature
1118        let points = vec![
1119            Point3D::new(0.0, 0.0, 0.0),
1120            Point3D::new(1.0, 0.0, 0.0),
1121            Point3D::new(2.0, 0.0, 0.0),
1122            Point3D::new(0.0, 1.0, 0.0),
1123            Point3D::new(1.0, 1.0, 0.0),
1124            Point3D::new(2.0, 1.0, 0.0),
1125            Point3D::new(0.0, 2.0, 0.0),
1126            Point3D::new(1.0, 2.0, 0.0),
1127            Point3D::new(2.0, 2.0, 0.0),
1128        ];
1129
1130        // Create triangulation for the 3x3 grid
1131        let triangles = vec![
1132            Triangle { v0: 0, v1: 1, v2: 4 },
1133            Triangle { v0: 0, v1: 4, v2: 3 },
1134            Triangle { v0: 1, v1: 2, v2: 5 },
1135            Triangle { v0: 1, v1: 5, v2: 4 },
1136            Triangle { v0: 3, v1: 4, v2: 7 },
1137            Triangle { v0: 3, v1: 7, v2: 6 },
1138            Triangle { v0: 4, v1: 5, v2: 8 },
1139            Triangle { v0: 4, v1: 8, v2: 7 },
1140        ];
1141
1142        // All boundary except center vertex (index 4)
1143        let boundary = vec![true, true, true, true, false, true, true, true, true];
1144
1145        let (gc, mc, _amixed) = compute_curvatures_from_mesh(&points, &triangles, &boundary);
1146
1147        // Center vertex (not boundary) on flat surface should have ~zero curvature
1148        assert!(
1149            gc[4].abs() < 1e-6,
1150            "Flat surface should have ~0 Gaussian curvature, got {}",
1151            gc[4]
1152        );
1153        assert!(
1154            mc[4].abs() < 1e-6,
1155            "Flat surface should have ~0 mean curvature, got {}",
1156            mc[4]
1157        );
1158    }
1159
1160    #[test]
1161    fn test_compute_curvatures_from_mesh_degenerate_triangle() {
1162        // Degenerate triangle (collinear points) should not crash
1163        let points = vec![
1164            Point3D::new(0.0, 0.0, 0.0),
1165            Point3D::new(1.0, 0.0, 0.0),
1166            Point3D::new(2.0, 0.0, 0.0), // collinear
1167        ];
1168        let triangles = vec![Triangle { v0: 0, v1: 1, v2: 2 }];
1169        let boundary = vec![false, false, false];
1170        let (gc, mc, _amixed) = compute_curvatures_from_mesh(&points, &triangles, &boundary);
1171        // Should not crash; values may be zero because area is zero
1172        assert_eq!(gc.len(), 3);
1173        assert_eq!(mc.len(), 3);
1174    }
1175
1176    #[test]
1177    fn test_compute_curvatures_from_mesh_boundary_zero() {
1178        // Boundary vertices should have zero curvature
1179        let points = vec![
1180            Point3D::new(0.0, 0.0, 0.0),
1181            Point3D::new(1.0, 0.0, 0.0),
1182            Point3D::new(0.5, 1.0, 1.0),
1183        ];
1184        let triangles = vec![Triangle { v0: 0, v1: 1, v2: 2 }];
1185        let boundary = vec![true, true, true]; // all boundary
1186        let (gc, mc, _amixed) = compute_curvatures_from_mesh(&points, &triangles, &boundary);
1187        for i in 0..3 {
1188            assert!((gc[i]).abs() < 1e-10, "Boundary vertex GC should be 0");
1189            assert!((mc[i]).abs() < 1e-10, "Boundary vertex MC should be 0");
1190        }
1191    }
1192
1193    // =====================================================================
1194    // Tests for convolve_1d_direction_masked
1195    // =====================================================================
1196
1197    #[test]
1198    fn test_convolve_1d_direction_uniform() {
1199        let n = 8;
1200        let data = vec![5.0; n * n * n];
1201        let mask = vec![1u8; n * n * n];
1202
1203        let result_x = convolve_1d_direction_masked(&data, &mask, n, n, n, 1.0, 'x');
1204        let result_y = convolve_1d_direction_masked(&data, &mask, n, n, n, 1.0, 'y');
1205        let result_z = convolve_1d_direction_masked(&data, &mask, n, n, n, 1.0, 'z');
1206
1207        // Uniform data should stay uniform after convolution
1208        for &v in &result_x {
1209            assert!((v - 5.0).abs() < 0.1, "X convolution should preserve uniform data, got {}", v);
1210        }
1211        for &v in &result_y {
1212            assert!((v - 5.0).abs() < 0.1, "Y convolution should preserve uniform data, got {}", v);
1213        }
1214        for &v in &result_z {
1215            assert!((v - 5.0).abs() < 0.1, "Z convolution should preserve uniform data, got {}", v);
1216        }
1217    }
1218
1219    #[test]
1220    fn test_convolve_1d_direction_zero_sigma() {
1221        let n = 5;
1222        let data = vec![3.0; n * n * n];
1223        let mask = vec![1u8; n * n * n];
1224
1225        let result = convolve_1d_direction_masked(&data, &mask, n, n, n, 0.0, 'x');
1226        assert_eq!(result, data, "Zero sigma should return copy of input");
1227    }
1228
1229    // =====================================================================
1230    // Tests for gaussian_smooth_3d_masked
1231    // =====================================================================
1232
1233    #[test]
1234    fn test_gaussian_smooth_3d_masked_uniform() {
1235        let n = 8;
1236        let data = vec![10.0; n * n * n];
1237        let mask = vec![1u8; n * n * n];
1238        let sigmas = [1.0, 1.0, 1.0];
1239        let result = gaussian_smooth_3d_masked(&data, &mask, n, n, n, &sigmas);
1240        assert_eq!(result.len(), n * n * n);
1241        for &v in &result {
1242            assert!(v.is_finite(), "Result should be finite");
1243            assert!((v - 10.0).abs() < 1.0, "Uniform data should stay near 10.0, got {}", v);
1244        }
1245    }
1246
1247    #[test]
1248    fn test_gaussian_smooth_3d_masked_applies_mask() {
1249        let n = 8;
1250        let data = vec![10.0; n * n * n];
1251        let mut mask = vec![1u8; n * n * n];
1252        // Zero out half the mask
1253        for i in 0..(n * n * n / 2) {
1254            mask[i] = 0;
1255        }
1256        let sigmas = [1.0, 1.0, 1.0];
1257        let result = gaussian_smooth_3d_masked(&data, &mask, n, n, n, &sigmas);
1258        // Masked-out voxels should be 0
1259        for i in 0..result.len() {
1260            if mask[i] == 0 {
1261                assert!((result[i]).abs() < 1e-10, "Masked-out voxel should be 0, got {}", result[i]);
1262            }
1263        }
1264    }
1265
1266    // =====================================================================
1267    // Tests for calculate_gaussian_curvature (main public function)
1268    // =====================================================================
1269
1270    #[test]
1271    fn test_calculate_gaussian_curvature_sphere() {
1272        let n = 12;
1273        let mask = make_sphere_mask(n, 4.5);
1274        let result = calculate_gaussian_curvature(&mask, n, n, n);
1275
1276        assert_eq!(result.gaussian_curvature.len(), n * n * n);
1277        assert_eq!(result.mean_curvature.len(), n * n * n);
1278        assert!(!result.surface_indices.is_empty(), "Should have surface indices");
1279
1280        // Surface curvature values should be finite
1281        for &idx in &result.surface_indices {
1282            assert!(
1283                result.gaussian_curvature[idx].is_finite(),
1284                "GC at surface index {} should be finite",
1285                idx
1286            );
1287            assert!(
1288                result.mean_curvature[idx].is_finite(),
1289                "MC at surface index {} should be finite",
1290                idx
1291            );
1292        }
1293
1294        // Non-surface voxels should have zero curvature
1295        let surface_set: std::collections::HashSet<usize> =
1296            result.surface_indices.iter().cloned().collect();
1297        for i in 0..(n * n * n) {
1298            if !surface_set.contains(&i) {
1299                assert!(
1300                    (result.gaussian_curvature[i]).abs() < 1e-10,
1301                    "Non-surface GC should be 0"
1302                );
1303                assert!(
1304                    (result.mean_curvature[i]).abs() < 1e-10,
1305                    "Non-surface MC should be 0"
1306                );
1307            }
1308        }
1309    }
1310
1311    #[test]
1312    fn test_calculate_gaussian_curvature_empty_mask() {
1313        let n = 5;
1314        let mask = vec![0u8; n * n * n];
1315        let result = calculate_gaussian_curvature(&mask, n, n, n);
1316        assert!(result.surface_indices.is_empty());
1317        assert!(result.gaussian_curvature.iter().all(|&v| v == 0.0));
1318        assert!(result.mean_curvature.iter().all(|&v| v == 0.0));
1319    }
1320
1321    #[test]
1322    fn test_calculate_gaussian_curvature_single_voxel() {
1323        let mut mask = vec![0u8; 125];
1324        mask[62] = 1; // single voxel in center of 5x5x5
1325        let result = calculate_gaussian_curvature(&mask, 5, 5, 5);
1326        // Single voxel is its own surface after erosion removes it
1327        // Result depends on whether erosion removes it entirely
1328        assert_eq!(result.gaussian_curvature.len(), 125);
1329        assert_eq!(result.mean_curvature.len(), 125);
1330    }
1331
1332    // =====================================================================
1333    // Tests for calculate_curvature_proximity (main entry point)
1334    // =====================================================================
1335
1336    #[test]
1337    fn test_calculate_curvature_proximity_sphere() {
1338        let n = 12;
1339        let mask = make_sphere_mask(n, 4.5);
1340        let n_total = n * n * n;
1341
1342        // Create an initial proximity map (all 1.0 inside mask)
1343        let prox1: Vec<f64> = mask.iter().map(|&v| v as f64).collect();
1344
1345        let (prox, curv_i) = calculate_curvature_proximity(
1346            &mask, &prox1, 0.6, 500.0, 1.0, n, n, n,
1347        );
1348
1349        assert_eq!(prox.len(), n_total);
1350        assert_eq!(curv_i.len(), n_total);
1351
1352        // All prox values should be finite
1353        for (i, &v) in prox.iter().enumerate() {
1354            assert!(v.is_finite(), "Prox at {} should be finite, got {}", i, v);
1355        }
1356
1357        // All curv_i values should be finite
1358        for (i, &v) in curv_i.iter().enumerate() {
1359            assert!(v.is_finite(), "Curv_i at {} should be finite, got {}", i, v);
1360        }
1361    }
1362
1363    #[test]
1364    fn test_calculate_curvature_proximity_empty_surface() {
1365        let n = 5;
1366        let mask = vec![0u8; n * n * n];
1367        let n_total = n * n * n;
1368        let prox1 = vec![1.0; n_total];
1369
1370        let (prox, curv_i) = calculate_curvature_proximity(
1371            &mask, &prox1, 0.6, 500.0, 1.0, n, n, n,
1372        );
1373
1374        // With empty mask, should return prox1 and all-ones curv_i
1375        assert_eq!(prox.len(), n_total);
1376        assert_eq!(curv_i.len(), n_total);
1377        for &v in &curv_i {
1378            assert!((v - 1.0).abs() < 1e-10, "Empty surface should give curv_i=1.0");
1379        }
1380    }
1381
1382    #[test]
1383    fn test_calculate_curvature_proximity_respects_mask() {
1384        let n = 12;
1385        let mask = make_sphere_mask(n, 4.5);
1386        let n_total = n * n * n;
1387        let prox1: Vec<f64> = mask.iter().map(|&v| v as f64).collect();
1388
1389        let (prox, _curv_i) = calculate_curvature_proximity(
1390            &mask, &prox1, 0.6, 500.0, 1.0, n, n, n,
1391        );
1392
1393        // Outside mask, proximity should be 0 (due to prox1 being 0 there)
1394        // or small from smoothing bleed
1395        for i in 0..n_total {
1396            assert!(prox[i].is_finite(), "Prox should be finite everywhere");
1397        }
1398    }
1399
1400    #[test]
1401    fn test_calculate_curvature_proximity_varying_params() {
1402        let n = 12;
1403        let mask = make_sphere_mask(n, 4.5);
1404        let prox1: Vec<f64> = mask.iter().map(|&v| v as f64).collect();
1405
1406        // Different lower_lim and curv_constant
1407        let (prox_a, _) = calculate_curvature_proximity(
1408            &mask, &prox1, 0.3, 100.0, 0.5, n, n, n,
1409        );
1410        let (prox_b, _) = calculate_curvature_proximity(
1411            &mask, &prox1, 0.9, 1000.0, 2.0, n, n, n,
1412        );
1413
1414        // Both should produce finite results
1415        for &v in &prox_a {
1416            assert!(v.is_finite());
1417        }
1418        for &v in &prox_b {
1419            assert!(v.is_finite());
1420        }
1421    }
1422}