// qsm_core/bet/mesh.rs

//! Mesh utilities for BET

/// Build a padded neighbor matrix for vectorized neighbor lookups.
///
/// Two vertices are neighbors when they share an edge of at least one face.
/// Neighbors appear in the order their edges are first encountered while
/// walking `faces`.
///
/// Returns `(neighbor_matrix, neighbor_counts)` where:
/// - `neighbor_matrix[i]` contains the indices of vertex `i`'s neighbors
///   (padded with `usize::MAX`)
/// - `neighbor_counts[i]` is the number of valid neighbors for vertex `i`
///
/// Every row has the same length: the larger of `max_neighbors` and the highest
/// neighbor count actually found. `max_neighbors` is therefore a minimum row
/// width, not a cap.
///
/// Degenerate faces that repeat a vertex index never make a vertex its own
/// neighbor.
///
/// # Panics
///
/// May panic with an out-of-bounds index if a face references a vertex index
/// `>= n_vertices`.
pub fn build_neighbor_matrix(
    n_vertices: usize,
    faces: &[[usize; 3]],
    max_neighbors: usize,
) -> (Vec<Vec<usize>>, Vec<usize>) {
    let mut neighbor_lists: Vec<Vec<usize>> = vec![Vec::new(); n_vertices];

    for &[v0, v1, v2] in faces {
        // Edges v0-v1, v1-v2, v2-v0, recorded in both directions.
        for &(a, b) in &[(v0, v1), (v1, v2), (v2, v0)] {
            // A repeated index within a face would otherwise add a self-loop.
            if a == b {
                continue;
            }
            if !neighbor_lists[a].contains(&b) {
                neighbor_lists[a].push(b);
            }
            if !neighbor_lists[b].contains(&a) {
                neighbor_lists[b].push(a);
            }
        }
    }

    // Rows must be wide enough for the busiest vertex even when the caller's
    // `max_neighbors` hint is too small.
    let widest = neighbor_lists.iter().map(Vec::len).max().unwrap_or(0);
    let row_len = widest.max(max_neighbors);

    let neighbor_counts: Vec<usize> = neighbor_lists.iter().map(Vec::len).collect();
    let neighbor_matrix: Vec<Vec<usize>> = neighbor_lists
        .into_iter()
        .map(|mut row| {
            row.resize(row_len, usize::MAX);
            row
        })
        .collect();

    (neighbor_matrix, neighbor_counts)
}

/// Compute a unit normal at each vertex by averaging adjacent face normals.
///
/// Each face normal is the cross product `(v1 - v0) x (v2 - v0)` scaled to unit
/// length, and is added with equal weight to the face's three vertices. The
/// accumulated vector at each vertex is then normalized.
///
/// The result points outward only if faces are wound counter-clockwise when
/// viewed from outside the surface (right-hand rule on the cross product).
///
/// Faces whose cross product has length `<= 1e-10` contribute nothing. A vertex
/// whose accumulated normal has length `<= 1e-10` is left un-normalized; a
/// vertex touched by no non-degenerate face stays `[0.0, 0.0, 0.0]`.
///
/// # Panics
///
/// Panics if a face references an index `>= vertices.len()`.
pub fn compute_vertex_normals(vertices: &[[f64; 3]], faces: &[[usize; 3]]) -> Vec<[f64; 3]> {
    // Lengths at or below this are treated as zero to avoid dividing by ~0.
    const MIN_LENGTH: f64 = 1e-10;

    fn length(v: &[f64; 3]) -> f64 {
        (v[0].powi(2) + v[1].powi(2) + v[2].powi(2)).sqrt()
    }

    let mut normals: Vec<[f64; 3]> = vec![[0.0, 0.0, 0.0]; vertices.len()];

    for &[i0, i1, i2] in faces {
        let (v0, v1, v2) = (vertices[i0], vertices[i1], vertices[i2]);

        // Edge vectors sharing v0
        let e1 = [v1[0] - v0[0], v1[1] - v0[1], v1[2] - v0[2]];
        let e2 = [v2[0] - v0[0], v2[1] - v0[1], v2[2] - v0[2]];

        // Face normal (cross product e1 x e2)
        let cross = [
            e1[1] * e2[2] - e1[2] * e2[1],
            e1[2] * e2[0] - e1[0] * e2[2],
            e1[0] * e2[1] - e1[1] * e2[0],
        ];

        // Written as `len > MIN_LENGTH` so that a NaN length is skipped too.
        let len = length(&cross);
        if len > MIN_LENGTH {
            let unit = [cross[0] / len, cross[1] / len, cross[2] / len];
            for &idx in &[i0, i1, i2] {
                normals[idx][0] += unit[0];
                normals[idx][1] += unit[1];
                normals[idx][2] += unit[2];
            }
        }
    }

    // Normalize the accumulated sum at every vertex
    for n in normals.iter_mut() {
        let len = length(n);
        if len > MIN_LENGTH {
            n[0] /= len;
            n[1] /= len;
            n[2] /= len;
        }
    }

    normals
}

/// Compute mean edge length for vertices in voxel coordinates (converts to mm).
///
/// Each coordinate difference is multiplied by the matching `voxel_size`
/// component before the Euclidean length is taken.
///
/// Edges are visited per face (three per triangle), so an edge shared by two
/// faces is counted twice.
///
/// Returns `1.0` when `faces` is empty.
///
/// # Panics
///
/// Panics if a face references an index `>= vertices.len()`.
pub fn compute_mean_edge_length(vertices: &[[f64; 3]], faces: &[[usize; 3]], voxel_size: &[f64; 3]) -> f64 {
    if faces.is_empty() {
        // Neutral fallback so callers can divide by the result safely.
        return 1.0;
    }

    let mut total_length = 0.0;

    for &[i0, i1, i2] in faces {
        // Edges v0-v1, v1-v2, v2-v0
        for &(a, b) in &[(i0, i1), (i1, i2), (i2, i0)] {
            let dx = (vertices[b][0] - vertices[a][0]) * voxel_size[0];
            let dy = (vertices[b][1] - vertices[a][1]) * voxel_size[1];
            let dz = (vertices[b][2] - vertices[a][2]) * voxel_size[2];
            total_length += (dx * dx + dy * dy + dz * dz).sqrt();
        }
    }

    let edge_count = 3 * faces.len();
    total_length / edge_count as f64
}

/// Compute mean edge length for vertices already in mm coordinates.
///
/// Edges are visited per face (three per triangle), so an edge shared by two
/// faces is counted twice.
///
/// Returns `1.0` when `faces` is empty.
///
/// # Panics
///
/// Panics if a face references an index `>= vertices_mm.len()`.
pub fn compute_mean_edge_length_mm(vertices_mm: &[[f64; 3]], faces: &[[usize; 3]]) -> f64 {
    if faces.is_empty() {
        // Neutral fallback so callers can divide by the result safely.
        return 1.0;
    }

    let mut total_length = 0.0;

    for &[i0, i1, i2] in faces {
        // Edges v0-v1, v1-v2, v2-v0
        for &(a, b) in &[(i0, i1), (i1, i2), (i2, i0)] {
            let dx = vertices_mm[b][0] - vertices_mm[a][0];
            let dy = vertices_mm[b][1] - vertices_mm[a][1];
            let dz = vertices_mm[b][2] - vertices_mm[a][2];
            total_length += (dx * dx + dy * dy + dz * dz).sqrt();
        }
    }

    let edge_count = 3 * faces.len();
    total_length / edge_count as f64
}

/// Euclidean distance between two 3D points.
///
/// The result is in the same unit as the inputs; callers in this module pass
/// vertices that are already in mm.
fn vertex_distance(v1: &[f64; 3], v2: &[f64; 3]) -> f64 {
    let dx = v2[0] - v1[0];
    let dy = v2[1] - v1[1];
    let dz = v2[2] - v1[2];
    (dx * dx + dy * dy + dz * dz).sqrt()
}

167/// Compute self-intersection heuristic by comparing vertex distances
168/// between current and original mesh.
169///
170/// This is based on FSL-BET2's self_intersection method which:
171/// 1. Computes mean edge length for both meshes (ml, mlo)
172/// 2. For vertex pairs that are currently close (< ml apart), checks if they've
173///    gotten significantly closer than they were in the original mesh
174/// 3. Accumulates squared differences in normalized distances
175///
176/// Returns a scalar value where > 4000 indicates likely self-intersection.
177/// The threshold of 4000 matches FSL-BET2's self_intersection_threshold.
178///
179/// Note: Vertices are expected to be in mm coordinates.
180pub fn self_intersection_heuristic(
181    current_vertices: &[[f64; 3]],
182    original_vertices: &[[f64; 3]],
183    faces: &[[usize; 3]],
184    _voxel_size: &[f64; 3], // kept for API compatibility, not used (vertices are in mm)
185) -> f64 {
186    if current_vertices.len() != original_vertices.len() {
187        return f64::MAX;
188    }
189
190    let n = current_vertices.len();
191
192    // Compute mean edge length for normalization (like FSL's ml and mlo)
193    // Vertices are in mm, so use the mm version
194    let ml = compute_mean_edge_length_mm(current_vertices, faces);
195    let mlo = compute_mean_edge_length_mm(original_vertices, faces);
196
197    if ml < 1e-10 || mlo < 1e-10 {
198        return f64::MAX;
199    }
200
201    let ml_sq = ml * ml;
202    let mut intersection = 0.0;
203
204    // FSL compares all vertex pairs, but only counts pairs where current distance < ml
205    // This detects when non-adjacent vertices have gotten too close (mesh folding)
206    // For efficiency, we sample a subset of pairs for large meshes
207    let step = if n > 500 { (n / 500).max(1) } else { 1 };
208
209    for i in (0..n).step_by(step) {
210        for j in (i + 1..n).step_by(step) {
211            // Current distance squared (vertices already in mm)
212            let dx = current_vertices[j][0] - current_vertices[i][0];
213            let dy = current_vertices[j][1] - current_vertices[i][1];
214            let dz = current_vertices[j][2] - current_vertices[i][2];
215            let curr_dist_sq = dx * dx + dy * dy + dz * dz;
216
217            // Only consider pairs that are currently close (< ml apart)
218            // This is the key insight from FSL - we're looking for folding
219            if curr_dist_sq < ml_sq {
220                let curr_dist = curr_dist_sq.sqrt();
221                let orig_dist = vertex_distance(&original_vertices[i], &original_vertices[j]);
222
223                // Normalize distances
224                let dist = curr_dist / ml;
225                let disto = orig_dist / mlo;
226
227                // Accumulate squared difference
228                let diff = dist - disto;
229                intersection += diff * diff;
230            }
231        }
232    }
233
234    intersection
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bet::icosphere::create_icosphere;

    #[test]
    fn test_neighbor_matrix() {
        let (vertices, faces) = create_icosphere(1);
        let (neighbor_matrix, neighbor_counts) = build_neighbor_matrix(vertices.len(), &faces, 6);

        assert_eq!(neighbor_matrix.len(), vertices.len());
        assert_eq!(neighbor_counts.len(), vertices.len());

        // Each vertex should have at least 1 neighbor
        for &count in &neighbor_counts {
            assert!(count >= 1);
        }
    }

    #[test]
    fn test_vertex_normals() {
        let (vertices, faces) = create_icosphere(1);
        let normals = compute_vertex_normals(&vertices, &faces);

        assert_eq!(normals.len(), vertices.len());

        // Normals should be unit length and point outward (same direction as vertex)
        for (v, n) in vertices.iter().zip(normals.iter()) {
            let norm = (n[0].powi(2) + n[1].powi(2) + n[2].powi(2)).sqrt();
            assert!((norm - 1.0).abs() < 1e-6, "Normal not unit length");

            // Dot product with vertex should be positive (outward pointing)
            let dot = v[0] * n[0] + v[1] * n[1] + v[2] * n[2];
            assert!(dot > 0.9, "Normal not pointing outward");
        }
    }

    #[test]
    fn test_neighbor_matrix_symmetry() {
        // Neighbor relationships must be symmetric: if A neighbors B then B neighbors A
        let (vertices, faces) = create_icosphere(2);
        let n = vertices.len();
        let (neighbor_matrix, neighbor_counts) = build_neighbor_matrix(n, &faces, 6);

        for i in 0..n {
            for j in 0..neighbor_counts[i] {
                let ni = neighbor_matrix[i][j];
                // ni should have i in its neighbor list
                let found = (0..neighbor_counts[ni]).any(|k| neighbor_matrix[ni][k] == i);
                assert!(found, "Asymmetric neighbors: {} -> {} but not reverse", i, ni);
            }
        }
    }

    #[test]
    fn test_neighbor_matrix_single_triangle() {
        // Minimal mesh: one triangle with 3 vertices
        let faces = vec![[0, 1, 2]];
        let (neighbor_matrix, neighbor_counts) = build_neighbor_matrix(3, &faces, 6);

        assert_eq!(neighbor_counts.len(), 3);
        // Each vertex in a single triangle should have exactly 2 neighbors
        for &count in &neighbor_counts {
            assert_eq!(count, 2);
        }

        // Padding should be usize::MAX
        for row in &neighbor_matrix {
            for &val in row.iter().skip(2) {
                assert_eq!(val, usize::MAX);
            }
        }
    }

    #[test]
    fn test_neighbor_matrix_max_neighbors_padding() {
        // When max_neighbors > actual neighbors, rows should be padded
        let faces = vec![[0, 1, 2]];
        let (neighbor_matrix, _neighbor_counts) = build_neighbor_matrix(3, &faces, 10);

        // Each row should have at least 10 entries (max_neighbors)
        for row in &neighbor_matrix {
            assert!(row.len() >= 10);
        }
    }

    #[test]
    fn test_compute_mean_edge_length_voxel() {
        // A unit right triangle (vertices in voxel coords, voxel_size = [2, 2, 2])
        let vertices: Vec<[f64; 3]> = vec![
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
        ];
        let faces = vec![[0, 1, 2]];
        let voxel_size = [2.0, 2.0, 2.0];

        let mel = compute_mean_edge_length(&vertices, &faces, &voxel_size);
        // Edge lengths in mm: 2.0, 2.0, 2*sqrt(2) = 2.828...
        // Mean = (2 + 2 + 2.828) / 3 = 2.276
        let expected = (2.0 + 2.0 + 2.0 * 2.0_f64.sqrt()) / 3.0;
        assert!(mel.is_finite());
        assert!((mel - expected).abs() < 1e-10, "mel={}, expected={}", mel, expected);
    }

    #[test]
    fn test_compute_mean_edge_length_mm_unit_triangle() {
        let vertices_mm: Vec<[f64; 3]> = vec![
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
        ];
        let faces = vec![[0, 1, 2]];

        let mel = compute_mean_edge_length_mm(&vertices_mm, &faces);
        // Edge lengths: 1.0, 1.0, sqrt(2) = 1.414
        // Mean = (1 + 1 + 1.414) / 3 = 1.138
        let expected = (1.0 + 1.0 + 2.0_f64.sqrt()) / 3.0;
        assert!((mel - expected).abs() < 1e-10, "mel={}, expected={}", mel, expected);
    }

    #[test]
    fn test_compute_mean_edge_length_mm_empty_faces() {
        let vertices_mm: Vec<[f64; 3]> = vec![[0.0, 0.0, 0.0]];
        let faces: Vec<[usize; 3]> = vec![];

        let mel = compute_mean_edge_length_mm(&vertices_mm, &faces);
        // With no faces, should return fallback of 1.0
        assert!((mel - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_compute_mean_edge_length_empty_faces() {
        let vertices: Vec<[f64; 3]> = vec![[0.0, 0.0, 0.0]];
        let faces: Vec<[usize; 3]> = vec![];
        let voxel_size = [1.0, 1.0, 1.0];

        let mel = compute_mean_edge_length(&vertices, &faces, &voxel_size);
        assert!((mel - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_compute_mean_edge_length_mm_icosphere() {
        // Icosphere with subdivision 2, scaled by radius 5.0
        let (unit_verts, faces) = create_icosphere(2);
        let radius = 5.0;
        let vertices_mm: Vec<[f64; 3]> = unit_verts
            .iter()
            .map(|v| [v[0] * radius, v[1] * radius, v[2] * radius])
            .collect();

        let mel = compute_mean_edge_length_mm(&vertices_mm, &faces);
        assert!(mel > 0.0 && mel.is_finite(), "mel should be positive and finite, got {}", mel);
        // For a sphere of radius 5, edges should be a fraction of the radius
        assert!(mel < radius, "Edge length should be smaller than radius");
    }

    #[test]
    fn test_vertex_normals_degenerate_face() {
        // Include a degenerate triangle (zero area) alongside a valid one
        let vertices: Vec<[f64; 3]> = vec![
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.5, 0.0, 0.0], // collinear with v0 and v1
        ];
        let faces = vec![
            [0, 1, 2], // valid triangle
            [0, 1, 3], // degenerate (collinear points)
        ];

        let normals = compute_vertex_normals(&vertices, &faces);
        assert_eq!(normals.len(), 4);
        // The valid triangle's normal contribution should still produce finite normals
        for n in &normals {
            assert!(n[0].is_finite() && n[1].is_finite() && n[2].is_finite());
        }
    }

    #[test]
    fn test_vertex_normals_icosphere_subdivision_2() {
        // Higher subdivision level
        let (vertices, faces) = create_icosphere(2);
        let normals = compute_vertex_normals(&vertices, &faces);

        assert_eq!(normals.len(), vertices.len());
        for (v, n) in vertices.iter().zip(normals.iter()) {
            let norm = (n[0].powi(2) + n[1].powi(2) + n[2].powi(2)).sqrt();
            assert!((norm - 1.0).abs() < 1e-6, "Normal not unit length: {}", norm);

            let dot = v[0] * n[0] + v[1] * n[1] + v[2] * n[2];
            assert!(dot > 0.9, "Normal not pointing outward, dot={}", dot);
        }
    }

    #[test]
    fn test_vertex_distance() {
        let v1 = [0.0, 0.0, 0.0];
        let v2 = [3.0, 4.0, 0.0];
        let d = vertex_distance(&v1, &v2);
        assert!((d - 5.0).abs() < 1e-10, "Expected 5.0, got {}", d);

        // Same point
        let d0 = vertex_distance(&v1, &v1);
        assert!((d0 - 0.0).abs() < 1e-10);

        // Offset points that differ only in x and y
        let v3 = [1.0, 2.0, 3.0];
        let v4 = [4.0, 6.0, 3.0];
        let d34 = vertex_distance(&v3, &v4);
        let expected = (9.0 + 16.0 + 0.0_f64).sqrt();
        assert!((d34 - expected).abs() < 1e-10);
    }

    #[test]
    fn test_self_intersection_heuristic_no_intersection() {
        // An icosphere that hasn't changed should have very low score
        let (vertices, faces) = create_icosphere(1);
        let radius = 5.0;
        let scaled: Vec<[f64; 3]> = vertices
            .iter()
            .map(|v| [v[0] * radius, v[1] * radius, v[2] * radius])
            .collect();

        let score = self_intersection_heuristic(&scaled, &scaled, &faces, &[1.0, 1.0, 1.0]);
        // Identical meshes should yield 0 (no differences)
        assert!((score - 0.0).abs() < 1e-6, "Expected ~0, got {}", score);
    }

    #[test]
    fn test_self_intersection_heuristic_expanded_mesh() {
        // Slightly expand the mesh uniformly -- should not indicate self-intersection
        let (vertices, faces) = create_icosphere(1);
        let radius = 5.0;
        let original: Vec<[f64; 3]> = vertices
            .iter()
            .map(|v| [v[0] * radius, v[1] * radius, v[2] * radius])
            .collect();

        let expanded: Vec<[f64; 3]> = vertices
            .iter()
            .map(|v| [v[0] * radius * 1.1, v[1] * radius * 1.1, v[2] * radius * 1.1])
            .collect();

        let score = self_intersection_heuristic(&expanded, &original, &faces, &[1.0, 1.0, 1.0]);
        assert!(score.is_finite(), "Score should be finite, got {}", score);
        // Uniform expansion should have a very low score (no folding)
        assert!(score < 4000.0, "Uniform expansion should not trigger self-intersection, score={}", score);
    }

    #[test]
    fn test_self_intersection_heuristic_mismatched_vertices() {
        let faces = vec![[0, 1, 2]];
        let v1: Vec<[f64; 3]> = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let v2: Vec<[f64; 3]> = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]; // wrong size

        let score = self_intersection_heuristic(&v1, &v2, &faces, &[1.0, 1.0, 1.0]);
        assert_eq!(score, f64::MAX);
    }

    #[test]
    fn test_self_intersection_heuristic_collapsed_mesh() {
        // Collapse all vertices to the center -- extreme folding
        let (vertices, faces) = create_icosphere(1);
        let radius = 5.0;
        let original: Vec<[f64; 3]> = vertices
            .iter()
            .map(|v| [v[0] * radius, v[1] * radius, v[2] * radius])
            .collect();

        // Collapse to a single point
        let collapsed: Vec<[f64; 3]> = vec![[0.0, 0.0, 0.0]; vertices.len()];

        let score = self_intersection_heuristic(&collapsed, &original, &faces, &[1.0, 1.0, 1.0]);
        // All vertices at same point means ml ~ 0, which should give MAX
        assert_eq!(score, f64::MAX, "Collapsed mesh should have MAX score");
    }

    #[test]
    fn test_compute_mean_edge_length_anisotropic_voxel() {
        // Test with anisotropic voxel sizes
        let vertices: Vec<[f64; 3]> = vec![
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
        ];
        let faces = vec![[0, 1, 2]];
        let voxel_size = [1.0, 1.0, 3.0]; // z is 3x larger

        let mel = compute_mean_edge_length(&vertices, &faces, &voxel_size);
        // Edge 0->1: (1*1, 0, 0) -> length 1.0
        // Edge 1->2: (-1*1, 0, 1*3) -> length sqrt(1+9) = sqrt(10) = 3.162
        // Edge 2->0: (0, 0, -1*3) -> length 3.0
        // Mean = (1.0 + 3.162 + 3.0) / 3 = 2.387
        let expected = (1.0 + 10.0_f64.sqrt() + 3.0) / 3.0;
        assert!((mel - expected).abs() < 1e-6, "mel={}, expected={}", mel, expected);
    }
}