Skip to main content

qsm_core/bet/
evolution.rs

1//! BET surface evolution algorithm
2//!
3//! Based on: Smith, S.M. (2002) "Fast robust automated brain extraction"
4//! Human Brain Mapping, 17(3):143-155
5//!
6//! Aligned with FSL-BET2 implementation.
7
8use super::icosphere::create_icosphere;
9use super::mesh::{build_neighbor_matrix, compute_vertex_normals, compute_mean_edge_length_mm, self_intersection_heuristic};
10use std::collections::VecDeque;
11
/// Image-derived statistics that drive the surface evolution (mirrors FSL's `bet_parameters`).
struct BetParameters {
    /// Robust minimum: 2nd percentile of the strictly positive intensities.
    t2: f64,
    /// Robust maximum: 98th percentile of the strictly positive intensities.
    t98: f64,
    /// Background/tissue threshold, `t2 + 0.1 * (t98 - t2)`.
    t: f64,
    /// Median intensity inside the initial brain sphere; seeds `Imin` and caps
    /// `Imax` during normal sampling, so it is critical for surface evolution.
    tm: f64,
    /// Intensity-weighted centre of gravity, in voxel coordinates.
    cog: [f64; 3],
    /// The same centre of gravity expressed in mm (used by the z-gradient term).
    cog_mm: [f64; 3],
    /// Estimated brain radius in mm.
    radius: f64,
}
22
23/// Estimate brain parameters from the image (matches FSL-BET2's adjust_initial_mesh)
24fn estimate_brain_parameters(
25    data: &[f64],
26    nx: usize, ny: usize, nz: usize,
27    voxel_size: &[f64; 3],
28) -> BetParameters {
29    // Collect non-zero values
30    let nonzero: Vec<f64> = data.iter().copied().filter(|&v| v > 0.0).collect();
31
32    if nonzero.is_empty() {
33        let cog = [(nx as f64) / 2.0, (ny as f64) / 2.0, (nz as f64) / 2.0];
34        let cog_mm = [cog[0] * voxel_size[0], cog[1] * voxel_size[1], cog[2] * voxel_size[2]];
35        return BetParameters { t2: 0.0, t98: 1.0, t: 0.1, tm: 0.5, cog, cog_mm, radius: 50.0 };
36    }
37
38    // Sort for percentiles
39    let mut sorted = nonzero.clone();
40    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
41
42    let t2 = percentile(&sorted, 2.0);
43    let t98 = percentile(&sorted, 98.0);
44    let t = t2 + 0.1 * (t98 - t2);
45
46    // Find center of gravity (weighted by intensity, like FSL)
47    let mut sum_x = 0.0;
48    let mut sum_y = 0.0;
49    let mut sum_z = 0.0;
50    let mut sum_weight = 0.0;
51    let mut n_voxels = 0usize;
52
53    // Use Fortran order: index = x + y*nx + z*nx*ny
54    for k in 0..nz {
55        for j in 0..ny {
56            for i in 0..nx {
57                let idx = i + j * nx + k * nx * ny;
58                let val = data[idx];
59                if val > t {
60                    // FSL: weight = min(c, t98 - t2) where c = val - t2
61                    let c = (val - t2).min(t98 - t2);
62                    sum_x += (i as f64) * c;
63                    sum_y += (j as f64) * c;
64                    sum_z += (k as f64) * c;
65                    sum_weight += c;
66                    n_voxels += 1;
67                }
68            }
69        }
70    }
71
72    let cog = if sum_weight > 0.0 {
73        [sum_x / sum_weight, sum_y / sum_weight, sum_z / sum_weight]
74    } else {
75        [(nx as f64) / 2.0, (ny as f64) / 2.0, (nz as f64) / 2.0]
76    };
77
78    // Estimate brain radius
79    let voxel_volume = voxel_size[0] * voxel_size[1] * voxel_size[2];
80    let brain_volume = (n_voxels as f64) * voxel_volume;
81    let radius = (3.0 * brain_volume / (4.0 * std::f64::consts::PI)).powf(1.0 / 3.0);
82
83    // Compute tm: median intensity within a sphere centered at COG with radius = brain radius
84    // This is critical for proper intensity-based surface evolution (FSL bet2.cpp lines 385-403)
85    let cog_mm = [cog[0] * voxel_size[0], cog[1] * voxel_size[1], cog[2] * voxel_size[2]];
86    let radius_sq = radius * radius;
87
88    let mut within_brain_values: Vec<f64> = Vec::new();
89    for k in 0..nz {
90        for j in 0..ny {
91            for i in 0..nx {
92                let idx = i + j * nx + k * nx * ny;
93                let val = data[idx];
94                // Only consider voxels with intensity between t2 and t98
95                if val > t2 && val < t98 {
96                    // Check if within sphere of radius centered at COG
97                    let px = (i as f64) * voxel_size[0];
98                    let py = (j as f64) * voxel_size[1];
99                    let pz = (k as f64) * voxel_size[2];
100                    let dx = px - cog_mm[0];
101                    let dy = py - cog_mm[1];
102                    let dz = pz - cog_mm[2];
103                    let dist_sq = dx * dx + dy * dy + dz * dz;
104                    if dist_sq < radius_sq {
105                        within_brain_values.push(val);
106                    }
107                }
108            }
109        }
110    }
111
112    // Compute median (tm)
113    let tm = if within_brain_values.is_empty() {
114        (t2 + t98) / 2.0 // fallback
115    } else {
116        within_brain_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
117        let mid = within_brain_values.len() / 2;
118        within_brain_values[mid]
119    };
120
121    BetParameters { t2, t98, t, tm, cog, cog_mm, radius }
122}
123
/// Nearest-rank percentile of an ascending-sorted slice.
///
/// `p` is given in percent. The fractional rank `p / 100 * (len - 1)` is
/// rounded to the nearest index, and ranks past the end clamp to the last
/// element. An empty slice yields `0.0`.
fn percentile(sorted: &[f64], p: f64) -> f64 {
    match sorted.last() {
        None => 0.0,
        Some(&last) => {
            let last_index = sorted.len() - 1;
            let rank = (p / 100.0 * last_index as f64).round() as usize;
            sorted.get(rank).copied().unwrap_or(last)
        }
    }
}
132
/// Trilinearly interpolate `data` (Fortran order: `x + y*nx + z*nx*ny`) at
/// the voxel coordinates `(x, y, z)`.
///
/// Each coordinate is first clamped into `[0, n - 1]`, so points outside the
/// volume take the value of the nearest boundary sample.
fn sample_intensity(data: &[f64], nx: usize, ny: usize, nz: usize, x: f64, y: f64, z: f64) -> f64 {
    // Clamp one coordinate and split it into (lower index, upper index, weight);
    // the upper index is clamped to the last sample on that axis.
    let axis = |coord: f64, n: usize| -> (usize, usize, f64) {
        let clamped = coord.max(0.0).min((n - 1) as f64);
        let lo = clamped.floor() as usize;
        (lo, (lo + 1).min(n - 1), clamped - lo as f64)
    };
    let lerp = |a: f64, b: f64, w: f64| a * (1.0 - w) + b * w;
    let at = |i: usize, j: usize, k: usize| data[i + j * nx + k * nx * ny];

    let (x0, x1, xd) = axis(x, nx);
    let (y0, y1, yd) = axis(y, ny);
    let (z0, z1, zd) = axis(z, nz);

    // Collapse along x, then y, then z.
    let c00 = lerp(at(x0, y0, z0), at(x1, y0, z0), xd);
    let c01 = lerp(at(x0, y0, z1), at(x1, y0, z1), xd);
    let c10 = lerp(at(x0, y1, z0), at(x1, y1, z0), xd);
    let c11 = lerp(at(x0, y1, z1), at(x1, y1, z1), xd);

    let c0 = lerp(c00, c10, yd);
    let c1 = lerp(c01, c11, yd);

    lerp(c0, c1, zd)
}
172
/// Sample min/max intensities along the inward normal (matches FSL-BET2's step_of_computation).
///
/// Probes, in order (depths measured in units of `normal` along `-normal`):
/// 1. depth 1: updates both `Imin` and `Imax`; if out of bounds, returns
///    immediately with `success = false`.
/// 2. depth `d1 - 1` (= 6): updates `Imin` only; if out of bounds, returns
///    immediately with `success = false`.
/// 3. a walk stepping `dscale` (= min(voxel sizes, 1.0)) from the depth-1 point,
///    while a counter `gi` runs from 2 in `dscale` increments for as long as
///    `gi < d1` (= 7). Every in-bounds sample updates `Imin`; samples taken
///    while `gi < d2` (= 3) also update `Imax`. Out-of-bounds samples in the
///    walk are skipped, not fatal.
///
/// "In bounds" means every voxel coordinate lies in `[0, n - 1)`.
///
/// Initial values: Imin = tm, Imax = t.
/// Final clamps (successful path only): Imin >= t2, Imax <= tm.
///
/// NOTE(review): the k-th walk sample sits at depth `1 + k*dscale`, while `gi`
/// is `2 + (k-1)*dscale`. The two agree only when `dscale == 1`, so for
/// sub-millimetre voxels the walk stops short of 7 mm. Confirm this matches
/// FSL's incremental stepping.
///
/// point_mm: vertex position in mm coordinates
/// normal: normal vector, assumed unit length and outward pointing (one unit is
///         treated as 1 mm)
/// voxel_size: voxel dimensions in mm, used to map probe points to voxel coordinates
///
/// Returns (i_min, i_max, success). success=false means a mandatory probe was
/// out of bounds; the returned intensities are then unclamped, and the caller
/// should use f3=0 (like FSL does).
fn sample_intensities_fsl(
    data: &[f64],
    nx: usize, ny: usize, nz: usize,
    point_mm: &[f64; 3],
    normal: &[f64; 3],
    voxel_size: &[f64; 3],
    t2: f64,
    t: f64,
    tm: f64,
) -> (f64, f64, bool) {
    let d1 = 7.0; // max search distance for Imin (mm)
    let d2 = 3.0; // max search distance for Imax (mm)
    // Step length of the walk: the finest voxel dimension, never above 1 mm.
    let dscale = voxel_size[0].min(voxel_size[1]).min(voxel_size[2]).min(1.0);

    // Initialize like FSL does
    let mut i_min = tm;
    let mut i_max = t;

    // Convert mm to voxel helper
    let mm_to_voxel = |p_mm: &[f64; 3]| -> [f64; 3] {
        [p_mm[0] / voxel_size[0], p_mm[1] / voxel_size[1], p_mm[2] / voxel_size[2]]
    };

    // Check if voxel position is in bounds (upper bound is exclusive at n - 1)
    let in_bounds = |v: &[f64; 3]| -> bool {
        v[0] >= 0.0 && v[0] < (nx - 1) as f64 &&
        v[1] >= 0.0 && v[1] < (ny - 1) as f64 &&
        v[2] >= 0.0 && v[2] < (nz - 1) as f64
    };

    // Starting position in mm (1mm inward along normal)
    let mut p_mm = [
        point_mm[0] - normal[0],
        point_mm[1] - normal[1],
        point_mm[2] - normal[2],
    ];
    let mut p_vox = mm_to_voxel(&p_mm);

    // Check if starting point is in bounds (FSL: first bounds check)
    if !in_bounds(&p_vox) {
        // FSL: if first bounds check fails, f3 = 0 (no intensity force)
        return (i_min, i_max, false);
    }

    let im = sample_intensity(data, nx, ny, nz, p_vox[0], p_vox[1], p_vox[2]);
    i_min = i_min.min(im);
    i_max = i_max.max(im);

    // Check far point at d1-1 (FSL: second bounds check)
    let p_far_mm = [
        point_mm[0] - (d1 - 1.0) * normal[0],
        point_mm[1] - (d1 - 1.0) * normal[1],
        point_mm[2] - (d1 - 1.0) * normal[2],
    ];
    let p_far_vox = mm_to_voxel(&p_far_mm);

    if !in_bounds(&p_far_vox) {
        // FSL: if second bounds check fails, f3 = 0 (no intensity force)
        return (i_min, i_max, false);
    }

    // The far sample contributes to Imin only.
    let im = sample_intensity(data, nx, ny, nz, p_far_vox[0], p_far_vox[1], p_far_vox[2]);
    i_min = i_min.min(im);

    // Walk inward from the 1mm point in dscale steps while gi < d1.
    // p_mm is advanced incrementally, so its depth is 1 + k*dscale after k steps.
    let mut gi = 2.0;
    while gi < d1 {
        p_mm[0] -= normal[0] * dscale;
        p_mm[1] -= normal[1] * dscale;
        p_mm[2] -= normal[2] * dscale;
        p_vox = mm_to_voxel(&p_mm);

        // Out-of-bounds walk samples are simply skipped.
        if in_bounds(&p_vox) {
            let im = sample_intensity(data, nx, ny, nz, p_vox[0], p_vox[1], p_vox[2]);
            i_min = i_min.min(im);

            // Only update Imax for samples within d2
            if gi < d2 {
                i_max = i_max.max(im);
            }
        }

        gi += dscale;
    }

    // Clamp like FSL does (this is critical for sinus exclusion)
    i_min = i_min.max(t2);    // Imin can't go below noise floor
    i_max = i_max.min(tm);    // Imax can't go above median brain intensity

    (i_min, i_max, true)
}
275
276/// Convert surface mesh to binary mask using flood fill
277///
278/// vertices_mm: vertex positions in mm coordinates
279/// voxel_size: voxel dimensions in mm
280fn surface_to_mask(
281    vertices_mm: &[[f64; 3]],
282    faces: &[[usize; 3]],
283    nx: usize, ny: usize, nz: usize,
284    voxel_size: &[f64; 3],
285) -> Vec<u8> {
286    // Convert mm vertices to voxel coordinates
287    let vertices: Vec<[f64; 3]> = vertices_mm
288        .iter()
289        .map(|v| [
290            v[0] / voxel_size[0],
291            v[1] / voxel_size[1],
292            v[2] / voxel_size[2],
293        ])
294        .collect();
295
296    let mininc = 0.5;
297
298    // Start with all 1s (outside)
299    let mut grid: Vec<u8> = vec![1; nx * ny * nz];
300    // Fortran order: x + y*nx + z*nx*ny
301    let idx = |i: usize, j: usize, k: usize| i + j * nx + k * nx * ny;
302
303    // Draw mesh surface as 0s
304    for &[i0, i1, i2] in faces {
305        let v0 = vertices[i0];
306        let v1 = vertices[i1];
307        let v2 = vertices[i2];
308
309        // Edge from v1 to v0
310        let edge = [v0[0] - v1[0], v0[1] - v1[1], v0[2] - v1[2]];
311        let edge_len = (edge[0].powi(2) + edge[1].powi(2) + edge[2].powi(2)).sqrt();
312
313        if edge_len < 0.001 {
314            continue;
315        }
316
317        let edge_dir = [edge[0] / edge_len, edge[1] / edge_len, edge[2] / edge_len];
318        let n_edge_steps = (edge_len / mininc).ceil() as usize + 1;
319
320        for j in 0..n_edge_steps {
321            let d = (j as f64) * mininc;
322            let p_edge = if d > edge_len {
323                v0
324            } else {
325                [v1[0] + d * edge_dir[0], v1[1] + d * edge_dir[1], v1[2] + d * edge_dir[2]]
326            };
327
328            // Draw segment from p_edge to v2
329            let seg = [v2[0] - p_edge[0], v2[1] - p_edge[1], v2[2] - p_edge[2]];
330            let seg_len = (seg[0].powi(2) + seg[1].powi(2) + seg[2].powi(2)).sqrt();
331
332            if seg_len < 0.001 {
333                let ix = p_edge[0].round() as isize;
334                let iy = p_edge[1].round() as isize;
335                let iz = p_edge[2].round() as isize;
336                if ix >= 0 && ix < nx as isize && iy >= 0 && iy < ny as isize && iz >= 0 && iz < nz as isize {
337                    grid[idx(ix as usize, iy as usize, iz as usize)] = 0;
338                }
339                continue;
340            }
341
342            let seg_dir = [seg[0] / seg_len, seg[1] / seg_len, seg[2] / seg_len];
343            let n_seg_steps = (seg_len / mininc).ceil() as usize + 1;
344
345            for k in 0..n_seg_steps {
346                let sd = (k as f64) * mininc;
347                let p = if sd > seg_len {
348                    v2
349                } else {
350                    [p_edge[0] + sd * seg_dir[0], p_edge[1] + sd * seg_dir[1], p_edge[2] + sd * seg_dir[2]]
351                };
352
353                let ix = p[0].round() as isize;
354                let iy = p[1].round() as isize;
355                let iz = p[2].round() as isize;
356
357                if ix >= 0 && ix < nx as isize && iy >= 0 && iy < ny as isize && iz >= 0 && iz < nz as isize {
358                    grid[idx(ix as usize, iy as usize, iz as usize)] = 0;
359                }
360            }
361        }
362    }
363
364    // Flood fill from center of mesh
365    let mut center = [0.0, 0.0, 0.0];
366    for v in &vertices {
367        center[0] += v[0];
368        center[1] += v[1];
369        center[2] += v[2];
370    }
371    center[0] /= vertices.len() as f64;
372    center[1] /= vertices.len() as f64;
373    center[2] /= vertices.len() as f64;
374
375    let mut cx = center[0].round() as isize;
376    let mut cy = center[1].round() as isize;
377    let mut cz = center[2].round() as isize;
378
379    cx = cx.max(0).min(nx as isize - 1);
380    cy = cy.max(0).min(ny as isize - 1);
381    cz = cz.max(0).min(nz as isize - 1);
382
383    // If center is on surface, find nearby interior point
384    if grid[idx(cx as usize, cy as usize, cz as usize)] == 0 {
385        'search: for dx in -5..=5 {
386            for dy in -5..=5 {
387                for dz in -5..=5 {
388                    let nx_ = cx + dx;
389                    let ny_ = cy + dy;
390                    let nz_ = cz + dz;
391                    if nx_ >= 0 && nx_ < nx as isize && ny_ >= 0 && ny_ < ny as isize && nz_ >= 0 && nz_ < nz as isize {
392                        if grid[idx(nx_ as usize, ny_ as usize, nz_ as usize)] == 1 {
393                            cx = nx_;
394                            cy = ny_;
395                            cz = nz_;
396                            break 'search;
397                        }
398                    }
399                }
400            }
401        }
402    }
403
404    // BFS flood fill
405    let mut queue: VecDeque<(usize, usize, usize)> = VecDeque::new();
406    let cx = cx as usize;
407    let cy = cy as usize;
408    let cz = cz as usize;
409    grid[idx(cx, cy, cz)] = 0;
410    queue.push_back((cx, cy, cz));
411
412    let neighbors: [(isize, isize, isize); 6] = [
413        (-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, -1), (0, 0, 1)
414    ];
415
416    while let Some((x, y, z)) = queue.pop_front() {
417        for &(dx, dy, dz) in &neighbors {
418            let nx_ = x as isize + dx;
419            let ny_ = y as isize + dy;
420            let nz_ = z as isize + dz;
421
422            if nx_ >= 0 && nx_ < nx as isize && ny_ >= 0 && ny_ < ny as isize && nz_ >= 0 && nz_ < nz as isize {
423                let ni = idx(nx_ as usize, ny_ as usize, nz_ as usize);
424                if grid[ni] == 1 {
425                    grid[ni] = 0;
426                    queue.push_back((nx_ as usize, ny_ as usize, nz_ as usize));
427                }
428            }
429        }
430    }
431
432    // Invert: 0 = brain (inside + surface), we want 1 = brain
433    for v in grid.iter_mut() {
434        *v = if *v == 0 { 1 } else { 0 };
435    }
436
437    // Fill holes using simple morphological closing
438    fill_holes(&mut grid, nx, ny, nz);
439
440    grid
441}
442
/// Fill enclosed cavities in a binary mask (Fortran order), in place.
///
/// Background voxels (0) that are 6-connected to the volume boundary are
/// exterior and stay 0. Every other background voxel is enclosed by the mask
/// and becomes 1.
fn fill_holes(mask: &mut [u8], nx: usize, ny: usize, nz: usize) {
    const STEPS: [[isize; 3]; 6] = [
        [-1, 0, 0], [1, 0, 0], [0, -1, 0], [0, 1, 0], [0, 0, -1], [0, 0, 1],
    ];

    let plane = nx * ny;
    let on_border = |x: usize, y: usize, z: usize| {
        x == 0 || y == 0 || z == 0 || x + 1 == nx || y + 1 == ny || z + 1 == nz
    };

    // Seed the exterior with every background voxel on the volume boundary.
    let mut exterior = vec![false; plane * nz];
    let mut frontier: VecDeque<(usize, usize, usize)> = VecDeque::new();
    for z in 0..nz {
        for y in 0..ny {
            for x in 0..nx {
                let at = x + y * nx + z * plane;
                if on_border(x, y, z) && mask[at] == 0 {
                    exterior[at] = true;
                    frontier.push_back((x, y, z));
                }
            }
        }
    }

    // Grow the exterior through 6-connected background voxels.
    while let Some((x, y, z)) = frontier.pop_front() {
        for &[dx, dy, dz] in STEPS.iter() {
            let (qx, qy, qz) = (x as isize + dx, y as isize + dy, z as isize + dz);
            let inside = (0..nx as isize).contains(&qx)
                && (0..ny as isize).contains(&qy)
                && (0..nz as isize).contains(&qz);
            if !inside {
                continue;
            }
            let at = qx as usize + qy as usize * nx + qz as usize * plane;
            if mask[at] == 0 && !exterior[at] {
                exterior[at] = true;
                frontier.push_back((qx as usize, qy as usize, qz as usize));
            }
        }
    }

    // Background not reachable from the boundary is a hole: fill it.
    for (i, m) in mask.iter_mut().enumerate() {
        if *m == 0 && !exterior[i] {
            *m = 1;
        }
    }
}
493
/// Self-intersection score above which an evolution pass is rejected and
/// re-run from the initial sphere with stronger smoothing (value from FSL-BET2).
const SELF_INTERSECTION_THRESHOLD: f64 = 4_000.0;

/// Maximum number of recovery passes after the initial pass (value from FSL-BET2).
const MAX_PASSES: usize = 10;
499
/// Run a single pass of surface evolution, moving `vertices` (mm) in place.
///
/// Each of the `iterations` steps moves every vertex by the sum of three terms:
/// * u1: half the tangential component of the offset from the vertex to the
///   mean of its neighbours (keeps vertices evenly spaced);
/// * u2: the normal component of that offset, scaled by the smoothness weight
///   `f2 = (1 + tanh(F * (rinv - E))) / 2`. Here `rinv = 2|s_n| / l²`, `l` is the
///   mean edge length, and E/F derive from `rmin = 3.33 * smoothness_factor`
///   and `rmax = 10 * smoothness_factor`;
/// * u3: an intensity term along the normal, built from `sample_intensities_fsl`
///   and the (optionally z-graded) threshold `bt`. It is zero when sampling
///   reports failure.
///
/// All updates are computed from the current positions and then applied together.
///
/// `pass > 0` marks a self-intersection recovery pass. In that case, when the
/// smoothing update points along the normal (`dv_dot_n > 0`), `f2` is multiplied
/// by `incfactor` (10^(pass+1), tapered after 75% of iterations) and capped at 1.
///
/// If `progress_callback` is `Some`, it receives `(iteration, iterations)` every
/// `max(iterations / 20, 1)` iterations. The mean edge length `l` is recomputed
/// after the updates of every 100th iteration (including iteration 0).
/// Sampling failures are counted and reported on stderr at the end.
///
/// NOTE(review): the z-gradient term divides by `bp.radius`; a zero radius would
/// produce non-finite thresholds. Confirm upstream guarantees `radius > 0`.
///
/// Returns nothing; the evolved mesh is left in `vertices`.
fn evolution_pass(
    data: &[f64],
    nx: usize, ny: usize, nz: usize,
    voxel_size: &[f64; 3],
    bp: &BetParameters,
    vertices: &mut Vec<[f64; 3]>,
    faces: &[[usize; 3]],
    neighbor_matrix: &[Vec<usize>],
    neighbor_counts: &[usize],
    bt: f64,
    smoothness_factor: f64,
    gradient_threshold: f64,
    iterations: usize,
    pass: usize,
    progress_callback: &mut Option<&mut dyn FnMut(usize, usize)>,
) {
    let n_vertices = vertices.len();

    // BET parameters (from FSL) - adjusted by smoothness_factor
    let rmin = 3.33 * smoothness_factor;
    let rmax = 10.0 * smoothness_factor;
    // E and F shape the sigmoid mapping local curvature (rinv) to smoothing weight f2.
    let e = (1.0 / rmin + 1.0 / rmax) / 2.0;
    let f = 6.0 / (1.0 / rmin - 1.0 / rmax);
    let normal_max_update_fraction = 0.5;
    let lambda_fit = 0.1;

    // Initial mean edge length
    // Vertices are in mm, so use the mm version
    let mut l = compute_mean_edge_length_mm(vertices, faces);

    // Smoothing increase factor for recovery passes (FSL: 10^(pass+1))
    let base_increase = if pass > 0 { 10.0_f64.powi((pass + 1) as i32) } else { 1.0 };

    // Report progress roughly 20 times per pass (at least every iteration)
    let progress_interval = (iterations / 20).max(1);

    // Number of vertex updates for which intensity sampling failed (u3 = 0)
    let mut sample_fail_count: usize = 0;

    for iteration in 0..iterations {
        // Report progress periodically
        if let Some(ref mut cb) = progress_callback {
            if iteration % progress_interval == 0 {
                cb(iteration, iterations);
            }
        }

        // Compute increase factor with tapering in later iterations (FSL: after 75%)
        let incfactor = if pass > 0 && iteration > (0.75 * iterations as f64) as usize {
            let t = iteration as f64 / iterations as f64;
            4.0 * (1.0 - t) * (base_increase - 1.0) + 1.0
        } else {
            base_increase
        };

        // Compute vertex normals
        let normals = compute_vertex_normals(vertices, faces);

        // Compute updates for each vertex (applied only after all are computed)
        let mut updates: Vec<[f64; 3]> = vec![[0.0, 0.0, 0.0]; n_vertices];

        for i in 0..n_vertices {
            let v = vertices[i];
            let n = normals[i];

            // Compute mean neighbor position
            let mut mean_neighbor = [0.0, 0.0, 0.0];
            let count = neighbor_counts[i];
            for j in 0..count {
                let ni = neighbor_matrix[i][j];
                mean_neighbor[0] += vertices[ni][0];
                mean_neighbor[1] += vertices[ni][1];
                mean_neighbor[2] += vertices[ni][2];
            }
            if count > 0 {
                mean_neighbor[0] /= count as f64;
                mean_neighbor[1] /= count as f64;
                mean_neighbor[2] /= count as f64;
            }

            // Vector from vertex to mean neighbor
            let dv = [mean_neighbor[0] - v[0], mean_neighbor[1] - v[1], mean_neighbor[2] - v[2]];

            // Dot product with normal
            let dv_dot_n = dv[0] * n[0] + dv[1] * n[1] + dv[2] * n[2];

            // Normal component
            let sn = [dv_dot_n * n[0], dv_dot_n * n[1], dv_dot_n * n[2]];

            // Tangential component
            let st = [dv[0] - sn[0], dv[1] - sn[1], dv[2] - sn[2]];

            // Force 1: Tangential (vertex spacing)
            let u1 = [st[0] * 0.5, st[1] * 0.5, st[2] * 0.5];

            // Force 2: Normal (smoothness); rinv is a local inverse-radius-of-curvature estimate
            let sn_mag = dv_dot_n.abs();
            let rinv = (2.0 * sn_mag) / (l * l);
            let mut f2 = (1.0 + (f * (rinv - e)).tanh()) * 0.5;

            // In recovery passes, increase smoothing for outward-pointing updates (FSL behavior)
            if pass > 0 && dv_dot_n > 0.0 {
                f2 *= incfactor;
                f2 = f2.min(1.0);
            }

            let u2 = [f2 * sn[0], f2 * sn[1], f2 * sn[2]];

            // Force 3: Intensity-based (using FSL-style sampling with tm)
            let (i_min, i_max, sample_ok) = sample_intensities_fsl(
                data, nx, ny, nz, &v, &n, voxel_size, bp.t2, bp.t, bp.tm
            );

            // FSL: if sampling fails (out of bounds), f3 = 0 (no intensity force)
            let u3 = if sample_ok {
                // Apply z-gradient to local threshold (FSL's -g option)
                let local_bt = if gradient_threshold.abs() > 1e-10 {
                    // Vertex is already in mm coordinates
                    let z_offset = (v[2] - bp.cog_mm[2]) / bp.radius;
                    (bt + gradient_threshold * z_offset).clamp(0.0, 1.0)
                } else {
                    bt
                };

                // Compute local threshold and force (matches FSL exactly)
                let t_l = (i_max - bp.t2) * local_bt + bp.t2;
                let f3 = if i_max - bp.t2 > 0.0 {
                    2.0 * (i_min - t_l) / (i_max - bp.t2)
                } else {
                    2.0 * (i_min - t_l)
                };
                // Scale to a fraction of the current mean edge length
                let f3 = f3 * normal_max_update_fraction * lambda_fit * l;

                [f3 * n[0], f3 * n[1], f3 * n[2]]
            } else {
                // Sampling failed - use f3 = 0 like FSL does
                sample_fail_count += 1;
                [0.0, 0.0, 0.0]
            };

            // Combined update
            updates[i] = [u1[0] + u2[0] + u3[0], u1[1] + u2[1] + u3[1], u1[2] + u2[2] + u3[2]];
        }

        // Apply updates
        for i in 0..n_vertices {
            vertices[i][0] += updates[i][0];
            vertices[i][1] += updates[i][1];
            vertices[i][2] += updates[i][2];
        }

        // Update edge length periodically
        if iteration % 100 == 0 {
            l = compute_mean_edge_length_mm(vertices, faces);
        }
    }

    // Debug: report sampling failures
    if sample_fail_count > 0 {
        let total_samples = iterations * n_vertices;
        let fail_pct = 100.0 * sample_fail_count as f64 / total_samples as f64;
        eprintln!("[BET] Sampling fallback used: {} / {} ({:.1}%)",
                  sample_fail_count, total_samples, fail_pct);
    }
}
668
/// BET algorithm parameters.
///
/// Groups the tuning knobs of the BET pipeline. `Default` gives fractional
/// intensity 0.5, smoothness 1.0, no z-gradient, 1000 iterations and
/// icosphere subdivision level 4.
#[derive(Clone, Debug, PartialEq)]
pub struct BetParams {
    /// Fractional intensity threshold (0.0-1.0, smaller = larger brain)
    pub fractional_intensity: f64,
    /// Surface smoothness factor (default 1.0, larger = smoother)
    pub smoothness: f64,
    /// Z-gradient for the threshold (-1 to 1, positive = larger at bottom)
    pub gradient_threshold: f64,
    /// Number of surface evolution iterations per pass
    pub iterations: usize,
    /// Icosphere subdivision level of the initial mesh
    pub subdivisions: usize,
}

impl Default for BetParams {
    fn default() -> Self {
        Self {
            fractional_intensity: 0.5,
            smoothness: 1.0,
            gradient_threshold: 0.0,
            iterations: 1000,
            subdivisions: 4,
        }
    }
}
707
708/// # Returns
709/// Binary mask (1 = brain, 0 = background)
710pub fn run_bet(
711    data: &[f64],
712    nx: usize, ny: usize, nz: usize,
713    vsx: f64, vsy: f64, vsz: f64,
714    fractional_intensity: f64,
715    smoothness_factor: f64,
716    gradient_threshold: f64,
717    iterations: usize,
718    subdivisions: usize,
719) -> Vec<u8> {
720    let voxel_size = [vsx, vsy, vsz];
721
722    // Step 1: Estimate brain parameters
723    let bp = estimate_brain_parameters(data, nx, ny, nz, &voxel_size);
724
725    // Step 2: Create icosphere
726    let (unit_vertices, faces) = create_icosphere(subdivisions);
727    let n_vertices = unit_vertices.len();
728
729    // Scale and position sphere in mm coordinates (start at 50% of estimated radius)
730    // Like FSL, we work entirely in mm - voxel conversion only happens at sampling/masking
731    let initial_radius_mm = bp.radius * 0.5;
732
733    let initial_vertices: Vec<[f64; 3]> = unit_vertices
734        .iter()
735        .map(|v| [
736            v[0] * initial_radius_mm + bp.cog_mm[0],
737            v[1] * initial_radius_mm + bp.cog_mm[1],
738            v[2] * initial_radius_mm + bp.cog_mm[2],
739        ])
740        .collect();
741
742    // Build neighbor structure
743    let (neighbor_matrix, neighbor_counts) = build_neighbor_matrix(n_vertices, &faces, 6);
744
745    // FSL power transform: bt = pow(f, 0.275)
746    // This raises 0.5 -> 0.826, which makes the surface more aggressive in expanding
747    let bt = fractional_intensity.powf(0.275);
748
749    // Multi-pass evolution with self-intersection recovery (like FSL-BET2)
750    let mut vertices = initial_vertices.clone();
751    let mut pass = 0;
752
753    loop {
754        // Run evolution pass
755        evolution_pass(
756            data, nx, ny, nz, &voxel_size, &bp,
757            &mut vertices, &faces,
758            &neighbor_matrix, &neighbor_counts,
759            bt, smoothness_factor, gradient_threshold,
760            iterations, pass,
761            &mut None,
762        );
763
764        // Check for self-intersection
765        let si_score = self_intersection_heuristic(&vertices, &initial_vertices, &faces, &voxel_size);
766        let has_self_intersection = si_score > SELF_INTERSECTION_THRESHOLD;
767
768        if has_self_intersection {
769            eprintln!("[BET] Self-intersection detected (score={:.0}, threshold={}), pass {}",
770                      si_score, SELF_INTERSECTION_THRESHOLD, pass + 1);
771        }
772
773        // Exit if no self-intersection or max passes reached
774        if !has_self_intersection || pass >= MAX_PASSES {
775            if pass > 0 {
776                eprintln!("[BET] Completed after {} recovery pass(es)", pass);
777            }
778            break;
779        }
780
781        // Reset to original mesh and try again with higher smoothing
782        vertices = initial_vertices.clone();
783        pass += 1;
784    }
785
786    // Step 4: Convert surface to binary mask
787    surface_to_mask(&vertices, &faces, nx, ny, nz, &voxel_size)
788}
789
790/// Run BET brain extraction with progress callback
791///
792/// Same as run_bet but calls progress_callback(iteration, total_iterations) periodically
793pub fn run_bet_with_progress<F>(
794    data: &[f64],
795    nx: usize, ny: usize, nz: usize,
796    vsx: f64, vsy: f64, vsz: f64,
797    fractional_intensity: f64,
798    smoothness_factor: f64,
799    gradient_threshold: f64,
800    iterations: usize,
801    subdivisions: usize,
802    mut progress_callback: F,
803) -> Vec<u8>
804where
805    F: FnMut(usize, usize),
806{
807    let voxel_size = [vsx, vsy, vsz];
808
809    // Step 1: Estimate brain parameters
810    progress_callback(0, iterations);
811    let bp = estimate_brain_parameters(data, nx, ny, nz, &voxel_size);
812
813    // Step 2: Create icosphere
814    let (unit_vertices, faces) = create_icosphere(subdivisions);
815    let n_vertices = unit_vertices.len();
816
817    // Scale and position sphere in mm coordinates (start at 50% of estimated radius)
818    let initial_radius_mm = bp.radius * 0.5;
819
820    let initial_vertices: Vec<[f64; 3]> = unit_vertices
821        .iter()
822        .map(|v| [
823            v[0] * initial_radius_mm + bp.cog_mm[0],
824            v[1] * initial_radius_mm + bp.cog_mm[1],
825            v[2] * initial_radius_mm + bp.cog_mm[2],
826        ])
827        .collect();
828
829    let (neighbor_matrix, neighbor_counts) = build_neighbor_matrix(n_vertices, &faces, 6);
830
831    // FSL power transform: bt = pow(f, 0.275)
832    let bt = fractional_intensity.powf(0.275);
833
834    // Multi-pass evolution with self-intersection recovery (like FSL-BET2)
835    let mut vertices = initial_vertices.clone();
836    let mut pass = 0;
837
838    loop {
839        // Run evolution pass
840        let mut cb: Option<&mut dyn FnMut(usize, usize)> = Some(&mut progress_callback);
841        evolution_pass(
842            data, nx, ny, nz, &voxel_size, &bp,
843            &mut vertices, &faces,
844            &neighbor_matrix, &neighbor_counts,
845            bt, smoothness_factor, gradient_threshold,
846            iterations, pass,
847            &mut cb,
848        );
849
850        // Check for self-intersection
851        let si_score = self_intersection_heuristic(&vertices, &initial_vertices, &faces, &voxel_size);
852        let has_self_intersection = si_score > SELF_INTERSECTION_THRESHOLD;
853
854        if has_self_intersection {
855            eprintln!("[BET] Self-intersection detected (score={:.0}, threshold={}), pass {}",
856                      si_score, SELF_INTERSECTION_THRESHOLD, pass + 1);
857        }
858
859        // Exit if no self-intersection or max passes reached
860        if !has_self_intersection || pass >= MAX_PASSES {
861            if pass > 0 {
862                eprintln!("[BET] Completed after {} recovery pass(es)", pass);
863            }
864            break;
865        }
866
867        // Reset to original mesh and try again with higher smoothing
868        vertices = initial_vertices.clone();
869        pass += 1;
870    }
871
872    // Final progress update
873    progress_callback(iterations, iterations);
874
875    // Step 4: Convert surface to binary mask
876    surface_to_mask(&vertices, &faces, nx, ny, nz, &voxel_size)
877}
878
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: create a cubic `n`^3 volume (Fortran order) containing a sphere
    /// centred at `n / 2` on every axis.
    ///
    /// Voxels within `radius` (in voxels) of the centre get an intensity that
    /// falls off linearly from `intensity` at the centre to `0.5 * intensity`
    /// at the sphere edge; all other voxels are 0. Returns `(data, n, n, n)`.
    fn make_sphere_volume(n: usize, radius: f64, intensity: f64) -> (Vec<f64>, usize, usize, usize) {
        let center = (n as f64) / 2.0;
        let mut data = vec![0.0; n * n * n];
        for k in 0..n {
            for j in 0..n {
                for i in 0..n {
                    let di = (i as f64) - center;
                    let dj = (j as f64) - center;
                    let dk = (k as f64) - center;
                    let dist = (di * di + dj * dj + dk * dk).sqrt();
                    if dist <= radius {
                        // Smoothly varying intensity: bright at center, dimmer at edge
                        data[i + j * n + k * n * n] = intensity * (1.0 - 0.5 * dist / radius);
                    }
                }
            }
        }
        (data, n, n, n)
    }

    #[test]
    fn test_estimate_brain_parameters() {
        let nx = 10;
        let ny = 10;
        let nz = 10;
        let mut data = vec![0.0; nx * ny * nz];

        // Create a sphere with varying intensity (like a real brain)
        // Fortran order: index = i + j*nx + k*nx*ny
        for k in 0..nz {
            for j in 0..ny {
                for i in 0..nx {
                    let di = (i as f64) - 5.0;
                    let dj = (j as f64) - 5.0;
                    let dk = (k as f64) - 5.0;
                    let dist = (di*di + dj*dj + dk*dk).sqrt();
                    if dist <= 4.0 {
                        // Intensity varies from 50 to 150 based on distance from center
                        data[i + j * nx + k * nx * ny] = 150.0 - dist * 25.0;
                    }
                }
            }
        }

        let bp = estimate_brain_parameters(&data, nx, ny, nz, &[1.0, 1.0, 1.0]);

        assert!(bp.t2 >= 0.0);
        assert!(bp.t98 >= bp.t2); // Allow equal for edge cases
        assert!((bp.cog[0] - 5.0).abs() < 1.0);
        assert!((bp.cog[1] - 5.0).abs() < 1.0);
        assert!((bp.cog[2] - 5.0).abs() < 1.0);
        assert!(bp.radius > 0.0);
        assert!(bp.tm > bp.t2 && bp.tm < bp.t98); // tm should be between t2 and t98
        // With 1 mm voxels, cog_mm must equal cog on every axis
        for axis in 0..3 {
            assert!(
                (bp.cog_mm[axis] - bp.cog[axis]).abs() < 1e-10,
                "axis {}: cog_mm={} cog={}", axis, bp.cog_mm[axis], bp.cog[axis]
            );
        }
    }

    #[test]
    fn test_estimate_brain_parameters_empty_data() {
        // All zeros should trigger the nonzero.is_empty() branch
        let nx = 4;
        let ny = 4;
        let nz = 4;
        let data = vec![0.0; nx * ny * nz];
        let bp = estimate_brain_parameters(&data, nx, ny, nz, &[1.0, 1.0, 1.0]);

        assert!((bp.t2 - 0.0).abs() < 1e-10);
        assert!((bp.t98 - 1.0).abs() < 1e-10);
        assert!((bp.t - 0.1).abs() < 1e-10);
        assert!((bp.tm - 0.5).abs() < 1e-10);
        assert!((bp.cog[0] - 2.0).abs() < 1e-10);
        assert!((bp.cog[1] - 2.0).abs() < 1e-10);
        assert!((bp.cog[2] - 2.0).abs() < 1e-10);
        assert!((bp.radius - 50.0).abs() < 1e-10);
    }

    #[test]
    fn test_estimate_brain_parameters_anisotropic_voxels() {
        let (data, nx, ny, nz) = make_sphere_volume(12, 4.0, 200.0);
        // A different size on every axis, so mixing up the per-axis scaling
        // (e.g. using voxel_size[0] everywhere) would be caught.
        let voxel_size = [1.0, 2.0, 3.0];
        let bp = estimate_brain_parameters(&data, nx, ny, nz, &voxel_size);

        // cog_mm should be cog scaled by the matching voxel dimension
        for axis in 0..3 {
            assert!(
                (bp.cog_mm[axis] - bp.cog[axis] * voxel_size[axis]).abs() < 1e-10,
                "axis {}: cog_mm={} cog={} voxel={}",
                axis, bp.cog_mm[axis], bp.cog[axis], voxel_size[axis]
            );
        }
        assert!(bp.radius > 0.0);
        assert!(bp.t98 > bp.t2);
    }

    #[test]
    fn test_sample_intensity() {
        let data = vec![
            0.0, 1.0, 2.0, 3.0,
            4.0, 5.0, 6.0, 7.0,
        ];
        let val = sample_intensity(&data, 2, 2, 2, 0.5, 0.5, 0.5);
        // Trilinear interpolation of cube corners
        assert!((val - 3.5).abs() < 0.01);
    }

    #[test]
    fn test_sample_intensity_at_corners() {
        // 2x2x2 cube: values 0..7
        let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        // Sampling at integer corners should return exact values
        assert!((sample_intensity(&data, 2, 2, 2, 0.0, 0.0, 0.0) - 0.0).abs() < 1e-10);
        assert!((sample_intensity(&data, 2, 2, 2, 1.0, 0.0, 0.0) - 1.0).abs() < 1e-10);
        assert!((sample_intensity(&data, 2, 2, 2, 0.0, 1.0, 0.0) - 2.0).abs() < 1e-10);
        assert!((sample_intensity(&data, 2, 2, 2, 1.0, 1.0, 0.0) - 3.0).abs() < 1e-10);
        assert!((sample_intensity(&data, 2, 2, 2, 0.0, 0.0, 1.0) - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_sample_intensity_clamping() {
        // Out-of-bounds coordinates should be clamped
        let data = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0];
        // Negative coordinates should clamp to 0
        let val = sample_intensity(&data, 2, 2, 2, -1.0, -1.0, -1.0);
        assert!((val - 10.0).abs() < 1e-10, "Expected 10.0, got {}", val);
        // Beyond max should clamp to max
        let val = sample_intensity(&data, 2, 2, 2, 5.0, 5.0, 5.0);
        assert!((val - 80.0).abs() < 1e-10, "Expected 80.0, got {}", val);
    }

    #[test]
    fn test_sample_intensity_larger_volume() {
        // 4x4x4 volume with known pattern
        let n = 4;
        let mut data = vec![0.0; n * n * n];
        for k in 0..n {
            for j in 0..n {
                for i in 0..n {
                    data[i + j * n + k * n * n] = (i + j + k) as f64;
                }
            }
        }
        // At center (1.5, 1.5, 1.5), each corner = i+j+k
        // Average of corners: (1+1+1=3, 2+1+1=4, 1+2+1=4, 2+2+1=5, 1+1+2=4, 2+1+2=5, 1+2+2=5, 2+2+2=6) / 8 = 4.5
        let val = sample_intensity(&data, n, n, n, 1.5, 1.5, 1.5);
        assert!((val - 4.5).abs() < 1e-10, "Expected 4.5, got {}", val);
    }

    #[test]
    fn test_percentile() {
        let sorted = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let p0 = percentile(&sorted, 0.0);
        assert!((p0 - 1.0).abs() < 1e-10);

        let p50 = percentile(&sorted, 50.0);
        // For an even-length list the median lies between the two middle
        // elements (5.0 and 6.0). Nearest-rank and linear-interpolation
        // conventions all land in that closed range, so accept any of them.
        assert!((5.0..=6.0).contains(&p50), "p50={}", p50);

        let p100 = percentile(&sorted, 100.0);
        assert!((p100 - 10.0).abs() < 1e-10);

        // Empty array
        let empty: Vec<f64> = vec![];
        assert!((percentile(&empty, 50.0) - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_power_transform() {
        // Verify power transform matches FSL
        let f = 0.5_f64;
        let bt = f.powf(0.275);
        // 0.5^0.275 ~ 0.826
        assert!((bt - 0.826).abs() < 0.01);
    }

    #[test]
    fn test_sample_intensities_fsl_basic() {
        // Create a 16x16x16 uniform volume
        let n = 16;
        let data = vec![100.0; n * n * n];
        let voxel_size = [1.0, 1.0, 1.0];
        // Point at the center of the volume; the surface normal points
        // outward along +z, so inward sampling walks toward -z.
        let point_mm = [8.0, 8.0, 8.0];
        let normal = [0.0, 0.0, 1.0];

        let t2 = 10.0;
        let t = 20.0;
        let tm = 100.0;

        let (i_min, i_max, ok) = sample_intensities_fsl(
            &data, n, n, n, &point_mm, &normal, &voxel_size, t2, t, tm,
        );

        assert!(ok, "Sampling should succeed for center point");
        assert!(i_min.is_finite());
        assert!(i_max.is_finite());
        // Uniform volume: i_min should be clamped to >= t2
        assert!(i_min >= t2, "i_min={} should be >= t2={}", i_min, t2);
        // i_max should be clamped to <= tm
        assert!(i_max <= tm, "i_max={} should be <= tm={}", i_max, tm);
    }

    #[test]
    fn test_sample_intensities_fsl_out_of_bounds() {
        let n = 8;
        let data = vec![100.0; n * n * n];
        let voxel_size = [1.0, 1.0, 1.0];

        // Point at the low-x edge with normal -x: the first inward step
        // (point - normal) lands at x = 1.5, inside the grid. Whether the full
        // sample succeeds depends on how far the sampler walks, so this call
        // only checks that sampling near the edge does not panic.
        let point_mm = [0.5, 0.5, 0.5];
        let normal = [-1.0, 0.0, 0.0];
        let (_, _, _ok) = sample_intensities_fsl(
            &data, n, n, n, &point_mm, &normal, &voxel_size, 10.0, 20.0, 100.0,
        );

        // Same point with normal +x: the first inward step is
        // point - normal = (-0.5, 0.5, 0.5), which is outside the grid.
        let point_mm2 = [0.5, 0.5, 0.5];
        let normal2 = [1.0, 0.0, 0.0];
        let (_, _, ok2) = sample_intensities_fsl(
            &data, n, n, n, &point_mm2, &normal2, &voxel_size, 10.0, 20.0, 100.0,
        );
        assert!(!ok2, "Sampling should fail when initial step goes out of bounds");
    }

    #[test]
    fn test_sample_intensities_fsl_varying_intensity() {
        // Create a volume with intensity gradient along z
        let n = 16;
        let mut data = vec![0.0; n * n * n];
        for k in 0..n {
            for j in 0..n {
                for i in 0..n {
                    data[i + j * n + k * n * n] = 50.0 + (k as f64) * 10.0;
                }
            }
        }
        let voxel_size = [1.0, 1.0, 1.0];
        let point_mm = [8.0, 8.0, 8.0];
        // Normal pointing in -z direction (inward sampling goes into higher z)
        let normal = [0.0, 0.0, -1.0];

        let t2 = 50.0;
        let t = 55.0;
        let tm = 120.0;

        let (i_min, i_max, ok) = sample_intensities_fsl(
            &data, n, n, n, &point_mm, &normal, &voxel_size, t2, t, tm,
        );

        assert!(ok, "Sampling should succeed");
        assert!(i_min.is_finite() && i_max.is_finite());
        assert!(i_min >= t2);
        assert!(i_max <= tm);
    }

    #[test]
    fn test_fill_holes_no_holes() {
        // A solid 3x3x3 cube with no holes should remain unchanged
        let nx = 5;
        let ny = 5;
        let nz = 5;
        let mut mask = vec![0u8; nx * ny * nz];
        let idx = |i: usize, j: usize, k: usize| i + j * nx + k * nx * ny;

        // Place a solid 3x3x3 block in the center
        for k in 1..4 {
            for j in 1..4 {
                for i in 1..4 {
                    mask[idx(i, j, k)] = 1;
                }
            }
        }

        let original = mask.clone();
        fill_holes(&mut mask, nx, ny, nz);

        assert_eq!(mask, original, "Solid mask should not change");
    }

    #[test]
    fn test_fill_holes_with_interior_hole() {
        // A 7x7x7 volume with a shell of 1s and a hole (0) inside
        let n = 7;
        let idx = |i: usize, j: usize, k: usize| i + j * n + k * n * n;
        let mut mask = vec![0u8; n * n * n];

        // Create a hollow cube: 1 on shell from (1,1,1) to (5,5,5), 0 inside
        for k in 1..6 {
            for j in 1..6 {
                for i in 1..6 {
                    mask[idx(i, j, k)] = 1;
                }
            }
        }
        // Carve out the interior
        for k in 2..5 {
            for j in 2..5 {
                for i in 2..5 {
                    mask[idx(i, j, k)] = 0;
                }
            }
        }

        // The interior (2..5)^3 should be 0 before fill_holes
        assert_eq!(mask[idx(3, 3, 3)], 0, "Center should be 0 before fill");

        fill_holes(&mut mask, n, n, n);

        // After fill_holes, interior holes should be filled
        assert_eq!(mask[idx(3, 3, 3)], 1, "Center hole should be filled");
        assert_eq!(mask[idx(2, 2, 2)], 1, "Interior corner should be filled");

        // Exterior should still be 0
        assert_eq!(mask[idx(0, 0, 0)], 0, "Exterior should remain 0");
        assert_eq!(mask[idx(6, 6, 6)], 0, "Exterior should remain 0");
    }

    #[test]
    fn test_fill_holes_exterior_connected() {
        // All zeros (no brain): nothing to fill
        let n = 5;
        let mut mask = vec![0u8; n * n * n];
        fill_holes(&mut mask, n, n, n);
        assert!(mask.iter().all(|&v| v == 0), "All-zero mask should stay all-zero");
    }

    #[test]
    fn test_surface_to_mask_small_sphere() {
        // Create a small icosphere mesh centered in a 16x16x16 grid
        let (unit_verts, faces) = create_icosphere(2);
        let n = 16;
        let center = (n as f64) / 2.0;
        let radius = 5.0; // 5 voxel radius
        let voxel_size = [1.0, 1.0, 1.0];

        // Place vertices in mm (= voxel coords since voxel_size is 1)
        let vertices_mm: Vec<[f64; 3]> = unit_verts
            .iter()
            .map(|v| [
                v[0] * radius + center,
                v[1] * radius + center,
                v[2] * radius + center,
            ])
            .collect();

        let mask = surface_to_mask(&vertices_mm, &faces, n, n, n, &voxel_size);

        assert_eq!(mask.len(), n * n * n);

        // Center should be inside the mask
        let idx = |i: usize, j: usize, k: usize| i + j * n + k * n * n;
        assert_eq!(mask[idx(8, 8, 8)], 1, "Center should be brain");

        // Corners should be outside the mask
        assert_eq!(mask[idx(0, 0, 0)], 0, "Corner should be background");
        assert_eq!(mask[idx(15, 15, 15)], 0, "Corner should be background");

        // Count brain voxels -- should be a reasonable fraction of total
        let brain_count: usize = mask.iter().map(|&v| v as usize).sum();
        assert!(brain_count > 0, "Should have some brain voxels");
        // Sphere volume ~ 4/3 * pi * 5^3 = 524 voxels; total = 4096
        // Allow generous range
        assert!(brain_count > 100 && brain_count < 2000,
                "Brain voxels ({}) should be roughly sphere-like", brain_count);
    }

    #[test]
    fn test_surface_to_mask_anisotropic_voxels() {
        let (unit_verts, faces) = create_icosphere(1);
        let n = 16;
        // A different size on every axis. A 6 mm sphere then becomes an
        // ellipsoid with semi-axes of 3, 4 and 6 voxels around voxel (8, 8, 8),
        // which stays well inside the 16^3 grid.
        let voxel_size = [2.0, 1.5, 1.0];
        let center_mm = [(n as f64) * voxel_size[0] / 2.0,
                         (n as f64) * voxel_size[1] / 2.0,
                         (n as f64) * voxel_size[2] / 2.0];
        let radius_mm = 6.0;

        let vertices_mm: Vec<[f64; 3]> = unit_verts
            .iter()
            .map(|v| [
                v[0] * radius_mm + center_mm[0],
                v[1] * radius_mm + center_mm[1],
                v[2] * radius_mm + center_mm[2],
            ])
            .collect();

        let mask = surface_to_mask(&vertices_mm, &faces, n, n, n, &voxel_size);
        assert_eq!(mask.len(), n * n * n);

        let brain_count: usize = mask.iter().map(|&v| v as usize).sum();
        assert!(brain_count > 0, "Should have brain voxels with anisotropic voxels");
    }

    #[test]
    fn test_run_bet_small_synthetic_volume() {
        // Create a 16x16x16 volume with a bright sphere
        let (data, nx, ny, nz) = make_sphere_volume(16, 6.0, 200.0);

        let mask = run_bet(
            &data,
            nx, ny, nz,
            1.0, 1.0, 1.0, // voxel size
            0.5,            // fractional_intensity
            1.0,            // smoothness_factor
            0.0,            // gradient_threshold
            50,             // iterations (few for speed)
            1,              // subdivisions (low for speed)
        );

        assert_eq!(mask.len(), nx * ny * nz);

        // Should produce a valid binary mask
        assert!(mask.iter().all(|&v| v == 0 || v == 1), "Mask should be binary");

        // Should have some brain voxels
        let brain_count: usize = mask.iter().map(|&v| v as usize).sum();
        assert!(brain_count > 0, "BET should extract some brain voxels");

        // Center should be brain
        let idx = |i: usize, j: usize, k: usize| i + j * nx + k * nx * ny;
        assert_eq!(mask[idx(8, 8, 8)], 1, "Center of sphere should be brain");
    }

    #[test]
    fn test_run_bet_with_gradient_threshold() {
        // Exercise the gradient threshold code path
        let (data, nx, ny, nz) = make_sphere_volume(16, 6.0, 200.0);

        let mask = run_bet(
            &data,
            nx, ny, nz,
            1.0, 1.0, 1.0,
            0.5,
            1.0,
            0.3,  // non-zero gradient threshold
            50,
            1,
        );

        assert_eq!(mask.len(), nx * ny * nz);
        assert!(mask.iter().all(|&v| v == 0 || v == 1));
        let brain_count: usize = mask.iter().map(|&v| v as usize).sum();
        assert!(brain_count > 0, "BET with gradient should extract brain voxels");
    }

    #[test]
    fn test_run_bet_with_progress() {
        // Test the progress callback variant
        let (data, nx, ny, nz) = make_sphere_volume(16, 6.0, 200.0);
        let iterations = 50;

        let mut progress_calls = Vec::new();
        let mask = run_bet_with_progress(
            &data,
            nx, ny, nz,
            1.0, 1.0, 1.0,
            0.5,
            1.0,
            0.0,
            iterations,
            1,
            |current, total| {
                progress_calls.push((current, total));
            },
        );

        assert_eq!(mask.len(), nx * ny * nz);
        assert!(mask.iter().all(|&v| v == 0 || v == 1));

        // Progress callback should have been called at least twice (start + end)
        assert!(progress_calls.len() >= 2,
                "Progress callback should be called at least twice, got {} calls",
                progress_calls.len());

        // First call is made before any evolution: (0, iterations)
        assert_eq!(progress_calls[0], (0, iterations), "Initial progress should be (0, iterations)");

        // Last call reports completion: (iterations, iterations)
        let last = progress_calls.last().unwrap();
        assert_eq!(*last, (iterations, iterations), "Final progress should be complete");
    }

    #[test]
    fn test_run_bet_different_fractional_intensities() {
        let (data, nx, ny, nz) = make_sphere_volume(16, 6.0, 200.0);

        // In BET a higher fractional_intensity tends to give a smaller brain.
        // The mesh here is too coarse and the volume too small to rely on
        // that ordering, so only the validity of both masks is checked.
        let mask_small = run_bet(&data, nx, ny, nz, 1.0, 1.0, 1.0, 0.7, 1.0, 0.0, 50, 1);
        let mask_large = run_bet(&data, nx, ny, nz, 1.0, 1.0, 1.0, 0.3, 1.0, 0.0, 50, 1);

        let count_small: usize = mask_small.iter().map(|&v| v as usize).sum();
        let count_large: usize = mask_large.iter().map(|&v| v as usize).sum();

        // Both should produce valid masks
        assert!(count_small > 0);
        assert!(count_large > 0);
    }

    #[test]
    fn test_run_bet_anisotropic_voxels() {
        let (data, nx, ny, nz) = make_sphere_volume(16, 6.0, 200.0);

        let mask = run_bet(
            &data,
            nx, ny, nz,
            1.0, 1.5, 2.0, // a different voxel size on every axis
            0.5,
            1.0,
            0.0,
            50,
            1,
        );

        assert_eq!(mask.len(), nx * ny * nz);
        assert!(mask.iter().all(|&v| v == 0 || v == 1));
    }

    #[test]
    fn test_evolution_pass_basic() {
        // Directly test evolution_pass with a small icosphere
        let n = 16;
        let (data, nx, ny, nz) = make_sphere_volume(n, 6.0, 200.0);
        let voxel_size = [1.0, 1.0, 1.0];
        let bp = estimate_brain_parameters(&data, nx, ny, nz, &voxel_size);

        let (unit_verts, faces) = create_icosphere(1);
        let n_vertices = unit_verts.len();
        let initial_radius_mm = bp.radius * 0.5;

        let mut vertices: Vec<[f64; 3]> = unit_verts
            .iter()
            .map(|v| [
                v[0] * initial_radius_mm + bp.cog_mm[0],
                v[1] * initial_radius_mm + bp.cog_mm[1],
                v[2] * initial_radius_mm + bp.cog_mm[2],
            ])
            .collect();

        let (neighbor_matrix, neighbor_counts) = build_neighbor_matrix(n_vertices, &faces, 6);
        let bt = 0.5_f64.powf(0.275);

        let vertices_before = vertices.clone();

        evolution_pass(
            &data, nx, ny, nz, &voxel_size, &bp,
            &mut vertices, &faces,
            &neighbor_matrix, &neighbor_counts,
            bt, 1.0, 0.0,
            10, // few iterations
            0,  // first pass
            &mut None,
        );

        // Vertices should have moved
        let mut any_moved = false;
        for (before, after) in vertices_before.iter().zip(vertices.iter()) {
            let dist = ((after[0] - before[0]).powi(2) +
                       (after[1] - before[1]).powi(2) +
                       (after[2] - before[2]).powi(2)).sqrt();
            if dist > 1e-10 {
                any_moved = true;
            }
            // All coordinates should be finite
            assert!(after[0].is_finite() && after[1].is_finite() && after[2].is_finite());
        }
        assert!(any_moved, "Evolution should move at least some vertices");
    }

    #[test]
    fn test_evolution_pass_recovery_pass() {
        // Test with pass > 0 to exercise recovery smoothing code paths
        let n = 16;
        let (data, nx, ny, nz) = make_sphere_volume(n, 6.0, 200.0);
        let voxel_size = [1.0, 1.0, 1.0];
        let bp = estimate_brain_parameters(&data, nx, ny, nz, &voxel_size);

        let (unit_verts, faces) = create_icosphere(1);
        let n_vertices = unit_verts.len();
        let initial_radius_mm = bp.radius * 0.5;

        let mut vertices: Vec<[f64; 3]> = unit_verts
            .iter()
            .map(|v| [
                v[0] * initial_radius_mm + bp.cog_mm[0],
                v[1] * initial_radius_mm + bp.cog_mm[1],
                v[2] * initial_radius_mm + bp.cog_mm[2],
            ])
            .collect();

        let (neighbor_matrix, neighbor_counts) = build_neighbor_matrix(n_vertices, &faces, 6);
        let bt = 0.5_f64.powf(0.275);

        // Run with pass=1 and enough iterations to hit the 75% tapering code
        evolution_pass(
            &data, nx, ny, nz, &voxel_size, &bp,
            &mut vertices, &faces,
            &neighbor_matrix, &neighbor_counts,
            bt, 1.0, 0.0,
            20,  // enough iterations to hit 75% mark
            1,   // recovery pass
            &mut None,
        );

        // All coordinates should remain finite
        for v in &vertices {
            assert!(v[0].is_finite() && v[1].is_finite() && v[2].is_finite());
        }
    }

    #[test]
    fn test_evolution_pass_with_gradient() {
        // Exercise the z-gradient threshold branch
        let n = 16;
        let (data, nx, ny, nz) = make_sphere_volume(n, 6.0, 200.0);
        let voxel_size = [1.0, 1.0, 1.0];
        let bp = estimate_brain_parameters(&data, nx, ny, nz, &voxel_size);

        let (unit_verts, faces) = create_icosphere(1);
        let n_vertices = unit_verts.len();
        let initial_radius_mm = bp.radius * 0.5;

        let mut vertices: Vec<[f64; 3]> = unit_verts
            .iter()
            .map(|v| [
                v[0] * initial_radius_mm + bp.cog_mm[0],
                v[1] * initial_radius_mm + bp.cog_mm[1],
                v[2] * initial_radius_mm + bp.cog_mm[2],
            ])
            .collect();

        let (neighbor_matrix, neighbor_counts) = build_neighbor_matrix(n_vertices, &faces, 6);
        let bt = 0.5_f64.powf(0.275);

        evolution_pass(
            &data, nx, ny, nz, &voxel_size, &bp,
            &mut vertices, &faces,
            &neighbor_matrix, &neighbor_counts,
            bt, 1.0, 0.5, // gradient_threshold = 0.5
            10, 0,
            &mut None,
        );

        for v in &vertices {
            assert!(v[0].is_finite() && v[1].is_finite() && v[2].is_finite());
        }
    }

    #[test]
    fn test_evolution_pass_with_progress_callback() {
        let n = 16;
        let (data, nx, ny, nz) = make_sphere_volume(n, 6.0, 200.0);
        let voxel_size = [1.0, 1.0, 1.0];
        let bp = estimate_brain_parameters(&data, nx, ny, nz, &voxel_size);

        let (unit_verts, faces) = create_icosphere(1);
        let n_vertices = unit_verts.len();
        let initial_radius_mm = bp.radius * 0.5;

        let mut vertices: Vec<[f64; 3]> = unit_verts
            .iter()
            .map(|v| [
                v[0] * initial_radius_mm + bp.cog_mm[0],
                v[1] * initial_radius_mm + bp.cog_mm[1],
                v[2] * initial_radius_mm + bp.cog_mm[2],
            ])
            .collect();

        let (neighbor_matrix, neighbor_counts) = build_neighbor_matrix(n_vertices, &faces, 6);
        let bt = 0.5_f64.powf(0.275);

        let mut calls = 0usize;
        let mut callback = |_iter: usize, _total: usize| {
            calls += 1;
        };
        let mut cb: Option<&mut dyn FnMut(usize, usize)> = Some(&mut callback);

        evolution_pass(
            &data, nx, ny, nz, &voxel_size, &bp,
            &mut vertices, &faces,
            &neighbor_matrix, &neighbor_counts,
            bt, 1.0, 0.0,
            20, 0,
            &mut cb,
        );

        assert!(calls > 0, "Progress callback should have been called");
    }
}