Skip to main content

qsm_core/utils/
bias_correction.rs

1//! Bias field correction (homogeneity correction)
2//!
3//! Implements the makehomogeneous algorithm for correcting RF receive field inhomogeneities.
4//! This uses the "boxsegment" approach with box filter Gaussian approximation.
5//!
6//! Reference:
7//! Eckstein, K., Trattnig, S., Robinson, S.D. (2019).
8//! "A Simple Homogeneity Correction for Neuroimaging at 7T."
9//! Proc. ISMRM 27th Annual Meeting.
10//!
11//! Reference implementation: https://github.com/korbinian90/MriResearchTools.jl
12
/// Tunable settings for the homogeneity (bias-field) correction.
///
/// `Default` yields `sigma_mm = 7.0` and `nbox = 3`.
#[derive(Clone, Debug)]
pub struct HomogeneityParams {
    /// Width (sigma) of the Gaussian smoothing, in millimetres. Default: 7.0.
    pub sigma_mm: f64,
    /// Box count; documented as the number of box-filter passes used to
    /// approximate the Gaussian. Default: 3.
    ///
    /// NOTE(review): `get_sensitivity`/`makehomogeneous` in this file take an
    /// `nbox` argument that is the *segmentation* box count per dimension
    /// (documented typical value 15), not a filter-pass count — confirm which
    /// meaning callers use before forwarding this field there.
    pub nbox: usize,
}

impl Default for HomogeneityParams {
    fn default() -> Self {
        HomogeneityParams {
            sigma_mm: 7.0,
            nbox: 3,
        }
    }
}
30
31use std::collections::VecDeque;
32
/// Flatten the voxel coordinate `(i, j, k)` into a linear offset for
/// column-major (Fortran-order) storage: `i` varies fastest, `k` slowest.
#[inline]
fn idx3d(i: usize, j: usize, k: usize, nx: usize, ny: usize) -> usize {
    (k * ny + j) * nx + i
}
38
39//=============================================================================
40// Box Filter Gaussian Approximation (matching MriResearchTools.jl)
41//=============================================================================
42
/// Box widths whose repeated application approximates a Gaussian of `sigma`.
///
/// Follows `getboxsizes` from MriResearchTools.jl. With `n` passes the ideal
/// width is `sqrt(12 σ² / n + 1)`. The first `m` passes use `wl`, the odd
/// integer just below that ideal, and the remaining passes use `wl + 2`. `m`
/// is the solution of `m (wl² - 1) + (n - m) ((wl + 2)² - 1) = 12 σ²`
/// (total variance equals `σ²`), rounded to the nearest integer.
///
/// Returns `n` zeros when `sigma <= 0`, and an empty vector when `n == 0`.
fn get_box_sizes(sigma: f64, n: usize) -> Vec<usize> {
    if n == 0 || sigma <= 0.0 {
        return vec![0; n];
    }

    let passes = n as f64;
    let twelve_var = 12.0 * sigma * sigma;

    let ideal_width = (twelve_var / passes + 1.0).sqrt();
    // Next lower odd integer: x - (x + 1) mod 2, forced odd after rounding.
    let mut narrow = (ideal_width - (ideal_width + 1.0) % 2.0).round() as usize;
    if narrow % 2 == 0 {
        narrow += 1;
    }
    let wide = narrow + 2;

    let w = narrow as f64;
    let ideal_narrow_count =
        (twelve_var - passes * w * w - 4.0 * passes * w - 3.0 * passes) / (-4.0 * w - 4.0);
    // A negative estimate saturates to 0, i.e. every pass uses the wide box.
    let narrow_count = ideal_narrow_count.round() as usize;

    let mut sizes = vec![wide; n];
    sizes
        .iter_mut()
        .take(narrow_count)
        .for_each(|size| *size = narrow);
    sizes
}
70
/// Make every box size usable for an image with the given dimensions.
///
/// For each axis that has a matching entry in `dims`, every size is first
/// rounded up to an odd number. It is then capped: a size above
/// `dims[axis] / 2` (integer division) is replaced by that half, bumped to the
/// next odd number if the half is even. Rows of `boxsizes` beyond `dims.len()`
/// are left untouched.
fn check_box_sizes(boxsizes: &mut [Vec<usize>], dims: &[usize]) {
    for (sizes, &extent) in boxsizes.iter_mut().zip(dims) {
        let half = extent / 2;
        // `x | 1` turns an even number into the next odd one and keeps odd ones.
        let odd_half = half | 1;
        for size in sizes.iter_mut() {
            *size |= 1;
            if *size > half {
                *size = odd_half;
            }
        }
    }
}
90
/// 1D box (moving-average) filter applied in place, matching Julia's `boxfilterline!`.
///
/// Each output sample `i` becomes the mean of the *original* samples in
/// `[i - r, i + r]`, with `r = boxsize / 2`. Near the ends the window is
/// truncated to the samples that exist (growing/shrinking window), so no
/// padding is introduced.
///
/// The effective window width is always the odd number `2 * r + 1`, so an even
/// `boxsize` behaves like `boxsize + 1`. Dividing the full window by an even
/// `boxsize` would bias the interior relative to the edges. A line of exactly
/// `boxsize` (even) samples would also be indexed past its end.
///
/// Lines shorter than the window, and `boxsize < 3`, are left unchanged.
fn box_filter_line(line: &mut [f64], boxsize: usize) {
    let r = boxsize / 2;
    let width = 2 * r + 1;
    let n = line.len();
    if boxsize < 3 || n < width {
        return;
    }

    // Original samples currently inside the window. Outputs overwrite `line`
    // from the left, so values leaving the window must be remembered here.
    let mut queue: VecDeque<f64> = VecDeque::with_capacity(width);

    // Seed with the left half (first r samples) of the first window.
    let mut lsum: f64 = line[..r].iter().sum();
    queue.extend(line[..r].iter().copied());

    // Leading edge: the window grows from r + 1 to `width` samples.
    for i in 0..=r {
        let incoming = line[i + r];
        lsum += incoming;
        queue.push_back(incoming);
        line[i] = lsum / (r + i + 1) as f64;
    }

    // Interior: full window sliding one sample at a time.
    for i in (r + 1)..(n - r) {
        let outgoing = queue
            .pop_front()
            .expect("queue holds one full window in the interior");
        let incoming = line[i + r];
        lsum += incoming - outgoing;
        queue.push_back(incoming);
        line[i] = lsum / width as f64;
    }

    // Trailing edge: the window shrinks back down to r + 1 samples.
    for i in (n - r)..n {
        let outgoing = queue
            .pop_front()
            .expect("queue still holds the remaining window samples");
        lsum -= outgoing;
        line[i] = lsum / (r + n - i) as f64;
    }
}
133
/// 1D weighted box filter applied in place to `line` and `weight`.
///
/// This is the weighted variant of Julia's `boxfilterline!`. Let
/// `r = boxsize / 2`. For every sample `i` whose full window `[i - r, i + r]`
/// lies inside the line:
///
/// * `line[i]` becomes the weighted mean `Σ w·x / Σ w`;
/// * `weight[i]` becomes `Σ w² / Σ w`.
///
/// Both are computed from the *original* samples and weights. The accumulators
/// start at `f64::EPSILON`, so a window of all-zero weights does not divide
/// by zero. The first and last `r` samples have no full window and are left
/// unchanged.
///
/// The effective window width is the odd number `2 * r + 1`; an even `boxsize`
/// is treated as `boxsize + 1`.
///
/// `weight` must have at least `line.len()` elements.
fn box_filter_line_weighted(line: &mut [f64], weight: &mut [f64], boxsize: usize) {
    let r = boxsize / 2;
    let width = 2 * r + 1;
    let n = line.len();
    if boxsize < 3 || n < width {
        return;
    }
    debug_assert!(weight.len() >= n, "weight line shorter than data line");

    // Original values currently inside the window (outputs overwrite the slices).
    let mut lq: VecDeque<f64> = VecDeque::with_capacity(width);
    let mut wq: VecDeque<f64> = VecDeque::with_capacity(width);

    let mut sum = f64::EPSILON; // slightly bigger than 0 to avoid division by 0
    let mut wsum = f64::EPSILON;
    let mut wsmooth = f64::EPSILON;

    // First full window, centred on sample r.
    for (&x, &w) in line[..width].iter().zip(&weight[..width]) {
        sum += x * w;
        wsum += w;
        wsmooth += w * w;
        lq.push_back(x);
        wq.push_back(w);
    }
    // Centre r has a full window as well, so it is written like every later centre.
    line[r] = sum / wsum;
    weight[r] = wsmooth / wsum;

    // Slide the window one sample at a time over the remaining interior.
    for i in (r + 1)..(n - r) {
        let w = weight[i + r];
        let l = line[i + r];
        let wold = wq.pop_front().expect("queue holds one full window");
        let lold = lq.pop_front().expect("queue holds one full window");
        wq.push_back(w);
        lq.push_back(l);

        sum += l * w - lold * wold;
        wsum += w - wold;
        line[i] = sum / wsum;
        wsmooth += w * w - wold * wold;
        weight[i] = wsmooth / wsum;
    }
}
175
/// 1D box filter with NaN handling (for masked smoothing), applied in place.
///
/// Matches Julia's `nanboxfilterline!`. NaN samples mark voxels outside the
/// mask. With `r = boxsize / 2` and window `orig[i..=i + 2r]` centred on output
/// `i`, the line is walked with a small state machine:
///
/// * `Nan`: waits for `width` consecutive valid incoming samples. Outputs are
///   left untouched meanwhile. When the run completes, the window sum is
///   recomputed from scratch and `Normal` begins.
/// * `Normal`: plain sliding mean over the full window.
/// * `Fill`: the incoming sample is NaN. The output is the mean of the window
///   without its missing rightmost sample and without its leftmost sample, so
///   the window stays centred. The missing sample is then extrapolated
///   linearly as `2 * line[i] - line[i - r]` and written both into the window
///   and into the output at `i + r`. After more than `r` consecutive fills the
///   filter falls back to `Nan`.
///
/// Samples before the first fully valid window stay unchanged, so NaNs outside
/// the mask remain NaN unless a `Fill` step overwrites them.
///
/// The effective window width is the odd number `2 * r + 1`; an even `boxsize`
/// is treated as `boxsize + 1`. For even sizes the old padded buffer was one
/// element too short and `orig[i + 2r]` indexed past its end.
fn nan_box_filter_line(line: &mut [f64], boxsize: usize) {
    let r = boxsize / 2;
    let width = 2 * r + 1;
    let n = line.len();
    if boxsize < 3 || n < width {
        return;
    }
    let maxfills = r;

    // orig[j] holds the original line[j - r], with r NaNs of padding at each
    // end, so orig[i..=i + 2r] is the window centred on output i and
    // orig[i + 2r] is the sample entering it.
    let mut orig = vec![f64::NAN; n + 2 * r];
    orig[r..r + n].copy_from_slice(line);

    #[derive(Clone, Copy)]
    enum Mode {
        Nan,
        Normal,
        Fill,
    }

    let mut mode = Mode::Nan;
    // Only read in Normal/Fill, and always (re)initialised on entering Normal.
    let mut lsum = 0.0;
    let mut nfills = 0usize;
    let mut nvalids = 0usize;

    for i in 0..n {
        let incoming = i + 2 * r;

        // State transitions, driven by whether the incoming sample is valid.
        match mode {
            Mode::Normal => {
                if orig[incoming].is_nan() {
                    mode = Mode::Fill;
                }
            }
            Mode::Nan => {
                if orig[incoming].is_nan() {
                    nvalids = 0;
                } else {
                    nvalids += 1;
                }
                if nvalids == width {
                    // orig[i..=i + 2r] is now entirely valid: restart the running sum.
                    mode = Mode::Normal;
                    lsum = orig[i..=incoming].iter().fold(0.0, |acc, &v| acc + v);
                    line[i] = lsum / width as f64;
                    continue;
                }
            }
            Mode::Fill => {
                if orig[incoming].is_nan() {
                    nfills += 1;
                    if nfills > maxfills {
                        mode = Mode::Nan;
                        nfills = 0;
                        lsum = 0.0;
                        nvalids = 0;
                    }
                } else {
                    mode = Mode::Normal;
                    nfills = 0;
                }
            }
        }

        // Perform the operation for the (possibly new) mode.
        match mode {
            Mode::Normal => {
                if i > 0 {
                    lsum += orig[incoming] - orig[i - 1];
                }
                line[i] = lsum / width as f64;
            }
            Mode::Fill => {
                if i > 0 {
                    lsum -= orig[i - 1];
                }
                // The window lacks its NaN right end; drop its left end too.
                line[i] = (lsum - orig[i]) / (width - 2) as f64;

                // Linearly extrapolate the missing incoming sample.
                let extrapolated = if i >= r {
                    2.0 * line[i] - line[i - r]
                } else {
                    line[i]
                };
                orig[incoming] = extrapolated;
                if i + r < n {
                    line[i + r] = extrapolated;
                }
                lsum += orig[incoming];
            }
            Mode::Nan => {
                // Leave the sample as it is (NaN stays NaN).
            }
        }
    }
}
278
279/// 3D Gaussian smoothing using box filter approximation
280///
281/// This matches MriResearchTools.jl's gaussiansmooth3d function.
282///
283/// Parameters:
284/// - data: input 3D data (will be copied)
285/// - sigma: sigma values for each dimension [sx, sy, sz]
286/// - mask: optional mask (None = no masking)
287/// - weight: optional weights (None = no weighting)
288/// - nbox: number of box filter passes (default 3, or 4 with mask)
289/// - nx, ny, nz: dimensions
290pub fn gaussian_smooth_3d(
291    data: &[f64],
292    sigma: [f64; 3],
293    mask: Option<&[u8]>,
294    mut weight: Option<&mut [f64]>,
295    nbox: usize,
296    nx: usize, ny: usize, nz: usize,
297) -> Vec<f64> {
298    let n_total = nx * ny * nz;
299    let mut result: Vec<f64> = data.iter().map(|&v| v as f64).collect();
300
301    // Calculate box sizes for each dimension
302    let mut boxsizes: Vec<Vec<usize>> = sigma.iter()
303        .map(|&s| get_box_sizes(s, nbox))
304        .collect();
305
306    check_box_sizes(&mut boxsizes, &[nx, ny, nz]);
307
308    // Apply mask: set masked-out voxels to NaN
309    if let Some(m) = mask {
310        for i in 0..n_total {
311            if m[i] == 0 {
312                result[i] = f64::NAN;
313            }
314        }
315    }
316
317    // Apply box filters for each pass and dimension
318    for ibox in 0..nbox {
319        // X direction
320        let bsize_x = boxsizes[0][ibox];
321        if nx > 1 && bsize_x >= 3 {
322            // Alternate direction for masked smoothing on even passes
323            let reverse = mask.is_some() && ibox % 2 == 1;
324
325            for k in 0..nz {
326                for j in 0..ny {
327                    let mut line: Vec<f64> = (0..nx).map(|i| {
328                        let idx = if reverse { nx - 1 - i } else { i };
329                        result[idx3d(idx, j, k, nx, ny)]
330                    }).collect();
331
332                    if mask.is_some() {
333                        nan_box_filter_line(&mut line, bsize_x);
334                    } else if let Some(ref mut w) = weight.as_deref_mut() {
335                        let mut wline: Vec<f64> = (0..nx).map(|i| {
336                            let idx = if reverse { nx - 1 - i } else { i };
337                            w[idx3d(idx, j, k, nx, ny)]
338                        }).collect();
339                        box_filter_line_weighted(&mut line, &mut wline, bsize_x);
340                        for i in 0..nx {
341                            let idx = if reverse { nx - 1 - i } else { i };
342                            w[idx3d(idx, j, k, nx, ny)] = wline[i];
343                        }
344                    } else {
345                        box_filter_line(&mut line, bsize_x);
346                    }
347
348                    for i in 0..nx {
349                        let idx = if reverse { nx - 1 - i } else { i };
350                        result[idx3d(idx, j, k, nx, ny)] = line[i];
351                    }
352                }
353            }
354        }
355
356        // Y direction
357        let bsize_y = boxsizes[1][ibox];
358        if ny > 1 && bsize_y >= 3 {
359            let reverse = mask.is_some() && ibox % 2 == 1;
360
361            for k in 0..nz {
362                for i in 0..nx {
363                    let mut line: Vec<f64> = (0..ny).map(|j| {
364                        let idx = if reverse { ny - 1 - j } else { j };
365                        result[idx3d(i, idx, k, nx, ny)]
366                    }).collect();
367
368                    if mask.is_some() {
369                        nan_box_filter_line(&mut line, bsize_y);
370                    } else if let Some(ref mut w) = weight.as_deref_mut() {
371                        let mut wline: Vec<f64> = (0..ny).map(|j| {
372                            let idx = if reverse { ny - 1 - j } else { j };
373                            w[idx3d(i, idx, k, nx, ny)]
374                        }).collect();
375                        box_filter_line_weighted(&mut line, &mut wline, bsize_y);
376                        for j in 0..ny {
377                            let idx = if reverse { ny - 1 - j } else { j };
378                            w[idx3d(i, idx, k, nx, ny)] = wline[j];
379                        }
380                    } else {
381                        box_filter_line(&mut line, bsize_y);
382                    }
383
384                    for j in 0..ny {
385                        let idx = if reverse { ny - 1 - j } else { j };
386                        result[idx3d(i, idx, k, nx, ny)] = line[j];
387                    }
388                }
389            }
390        }
391
392        // Z direction
393        let bsize_z = boxsizes[2][ibox];
394        if nz > 1 && bsize_z >= 3 {
395            let reverse = mask.is_some() && ibox % 2 == 1;
396
397            for j in 0..ny {
398                for i in 0..nx {
399                    let mut line: Vec<f64> = (0..nz).map(|k| {
400                        let idx = if reverse { nz - 1 - k } else { k };
401                        result[idx3d(i, j, idx, nx, ny)]
402                    }).collect();
403
404                    if mask.is_some() {
405                        nan_box_filter_line(&mut line, bsize_z);
406                    } else if let Some(ref mut w) = weight.as_deref_mut() {
407                        let mut wline: Vec<f64> = (0..nz).map(|k| {
408                            let idx = if reverse { nz - 1 - k } else { k };
409                            w[idx3d(i, j, idx, nx, ny)]
410                        }).collect();
411                        box_filter_line_weighted(&mut line, &mut wline, bsize_z);
412                        for k in 0..nz {
413                            let idx = if reverse { nz - 1 - k } else { k };
414                            w[idx3d(i, j, idx, nx, ny)] = wline[k];
415                        }
416                    } else {
417                        box_filter_line(&mut line, bsize_z);
418                    }
419
420                    for k in 0..nz {
421                        let idx = if reverse { nz - 1 - k } else { k };
422                        result[idx3d(i, j, idx, nx, ny)] = line[k];
423                    }
424                }
425            }
426        }
427    }
428
429    result
430}
431
432/// Simplified smoothing with explicit box sizes (for robustmask post-processing)
433pub fn gaussian_smooth_3d_boxsizes(
434    data: &[f64],
435    boxsizes: &[Vec<usize>],
436    nbox: usize,
437    nx: usize, ny: usize, nz: usize,
438) -> Vec<f64> {
439    let mut result = data.to_vec();
440
441    // Apply box filters for each pass and dimension
442    for ibox in 0..nbox {
443        // X direction
444        if nx > 1 && ibox < boxsizes[0].len() {
445            let bsize = boxsizes[0][ibox];
446            if bsize >= 3 {
447                for k in 0..nz {
448                    for j in 0..ny {
449                        let mut line: Vec<f64> = (0..nx).map(|i| result[idx3d(i, j, k, nx, ny)]).collect();
450                        box_filter_line(&mut line, bsize);
451                        for i in 0..nx {
452                            result[idx3d(i, j, k, nx, ny)] = line[i];
453                        }
454                    }
455                }
456            }
457        }
458
459        // Y direction
460        if ny > 1 && ibox < boxsizes[1].len() {
461            let bsize = boxsizes[1][ibox];
462            if bsize >= 3 {
463                for k in 0..nz {
464                    for i in 0..nx {
465                        let mut line: Vec<f64> = (0..ny).map(|j| result[idx3d(i, j, k, nx, ny)]).collect();
466                        box_filter_line(&mut line, bsize);
467                        for j in 0..ny {
468                            result[idx3d(i, j, k, nx, ny)] = line[j];
469                        }
470                    }
471                }
472            }
473        }
474
475        // Z direction
476        if nz > 1 && ibox < boxsizes[2].len() {
477            let bsize = boxsizes[2][ibox];
478            if bsize >= 3 {
479                for j in 0..ny {
480                    for i in 0..nx {
481                        let mut line: Vec<f64> = (0..nz).map(|k| result[idx3d(i, j, k, nx, ny)]).collect();
482                        box_filter_line(&mut line, bsize);
483                        for k in 0..nz {
484                            result[idx3d(i, j, k, nx, ny)] = line[k];
485                        }
486                    }
487                }
488            }
489        }
490    }
491
492    result
493}
494
495//=============================================================================
496// Connected Components and Hole Filling
497//=============================================================================
498
499/// Find connected component using flood fill (6-connectivity in 3D)
500fn flood_fill_component(
501    mask: &[u8],
502    visited: &mut [bool],
503    start: usize,
504    nx: usize, ny: usize, nz: usize,
505) -> Vec<usize> {
506    let mut component = Vec::new();
507    let mut stack = vec![start];
508
509    while let Some(idx) = stack.pop() {
510        if visited[idx] || mask[idx] != 0 {
511            continue;
512        }
513
514        visited[idx] = true;
515        component.push(idx);
516
517        // Get 3D coordinates
518        let k = idx / (nx * ny);
519        let rem = idx % (nx * ny);
520        let j = rem / nx;
521        let i = rem % nx;
522
523        // 6-connectivity neighbors
524        if i > 0 {
525            let n = idx3d(i - 1, j, k, nx, ny);
526            if !visited[n] && mask[n] == 0 { stack.push(n); }
527        }
528        if i + 1 < nx {
529            let n = idx3d(i + 1, j, k, nx, ny);
530            if !visited[n] && mask[n] == 0 { stack.push(n); }
531        }
532        if j > 0 {
533            let n = idx3d(i, j - 1, k, nx, ny);
534            if !visited[n] && mask[n] == 0 { stack.push(n); }
535        }
536        if j + 1 < ny {
537            let n = idx3d(i, j + 1, k, nx, ny);
538            if !visited[n] && mask[n] == 0 { stack.push(n); }
539        }
540        if k > 0 {
541            let n = idx3d(i, j, k - 1, nx, ny);
542            if !visited[n] && mask[n] == 0 { stack.push(n); }
543        }
544        if k + 1 < nz {
545            let n = idx3d(i, j, k + 1, nx, ny);
546            if !visited[n] && mask[n] == 0 { stack.push(n); }
547        }
548    }
549
550    component
551}
552
553/// Fill holes in a binary mask
554///
555/// Matches MriResearchTools.jl's fill_holes function.
556/// Fills connected components of zeros (holes) up to max_hole_size.
557/// Uses 6-connectivity for 3D.
558pub fn fill_holes(mask: &[u8], nx: usize, ny: usize, nz: usize, max_hole_size: usize) -> Vec<u8> {
559    let n_total = nx * ny * nz;
560    let mut result = mask.to_vec();
561    let mut visited = vec![false; n_total];
562
563    // Find all connected components of zeros (potential holes)
564    for idx in 0..n_total {
565        if mask[idx] == 0 && !visited[idx] {
566            let component = flood_fill_component(mask, &mut visited, idx, nx, ny, nz);
567
568            // Check if this component touches the boundary
569            let mut touches_boundary = false;
570            for &cidx in &component {
571                let k = cidx / (nx * ny);
572                let rem = cidx % (nx * ny);
573                let j = rem / nx;
574                let i = rem % nx;
575
576                if i == 0 || i == nx - 1 || j == 0 || j == ny - 1 || k == 0 || k == nz - 1 {
577                    touches_boundary = true;
578                    break;
579                }
580            }
581
582            // Fill if it's a hole (doesn't touch boundary) and small enough
583            if !touches_boundary && component.len() <= max_hole_size {
584                for cidx in component {
585                    result[cidx] = 1;
586                }
587            }
588        }
589    }
590
591    result
592}
593
594//=============================================================================
595// Robust Mask (matching MriResearchTools.jl)
596//=============================================================================
597
598/// Create robust mask from magnitude using quantile-based thresholding
599///
600/// This matches MriResearchTools.jl's robustmask function, including
601/// post-processing with smoothing and hole filling.
602pub fn robust_mask(mag: &[f64], nx: usize, ny: usize, nz: usize) -> Vec<u8> {
603    let n_total = nx * ny * nz;
604
605    // Collect valid (positive, finite) samples and sort
606    let mut samples: Vec<f64> = mag.iter()
607        .filter(|&&v| v.is_finite() && v > 0.0)
608        .copied()
609        .collect();
610
611    if samples.is_empty() {
612        return vec![0u8; n_total];
613    }
614
615    samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
616
617    let len = samples.len();
618
619    // Calculate quantiles
620    let q05_idx = ((0.05 * len as f64) as usize).min(len - 1);
621    let q15_idx = ((0.15 * len as f64) as usize).min(len - 1);
622    let q80_idx = ((0.80 * len as f64) as usize).min(len - 1);
623    let q99_idx = ((0.99 * len as f64) as usize).min(len - 1);
624
625    let q05 = samples[q05_idx];
626    let q15 = samples[q15_idx];
627    let q80 = samples[q80_idx];
628    let q99 = samples[q99_idx];
629
630    // Calculate high intensity mean (between 80th and 99th percentile)
631    let high_samples: Vec<f64> = samples.iter()
632        .filter(|&&v| v >= q80 && v <= q99)
633        .copied()
634        .collect();
635
636    let high_intensity = if high_samples.is_empty() {
637        q99
638    } else {
639        high_samples.iter().sum::<f64>() / high_samples.len() as f64
640    };
641
642    // Estimate noise level from low-intensity voxels
643    let low_samples: Vec<f64> = samples.iter()
644        .filter(|&&v| v <= q15)
645        .copied()
646        .collect();
647
648    let mut noise = if low_samples.is_empty() {
649        0.0
650    } else {
651        low_samples.iter().sum::<f64>() / low_samples.len() as f64
652    };
653
654    // If noise estimate is too high, try using lower percentile
655    if noise > high_intensity / 10.0 {
656        let very_low_samples: Vec<f64> = samples.iter()
657            .filter(|&&v| v <= q05)
658            .copied()
659            .collect();
660
661        noise = if very_low_samples.is_empty() {
662            0.0
663        } else {
664            very_low_samples.iter().sum::<f64>() / very_low_samples.len() as f64
665        };
666
667        if noise > high_intensity / 10.0 {
668            noise = 0.0;
669        }
670    }
671
672    // Calculate threshold: max(5*noise, high_intensity/5)
673    let threshold = (5.0 * noise).max(high_intensity / 5.0);
674
675    // Create initial binary mask
676    let mut mask_f64: Vec<f64> = mag.iter()
677        .map(|&v| if v > threshold { 1.0 } else { 0.0 })
678        .collect();
679
680    // Post-processing Step 1: Smooth with nbox=1, boxsize=5, threshold at 0.4
681    let boxsizes1 = vec![vec![5], vec![5], vec![5]];
682    mask_f64 = gaussian_smooth_3d_boxsizes(&mask_f64, &boxsizes1, 1, nx, ny, nz);
683    let mut mask: Vec<u8> = mask_f64.iter()
684        .map(|&v| if v > 0.4 { 1 } else { 0 })
685        .collect();
686
687    // Post-processing Step 2: Fill holes
688    let max_hole_size = n_total / 20;
689    mask = fill_holes(&mask, nx, ny, nz, max_hole_size);
690
691    // Post-processing Step 3: Smooth with nbox=2, boxsizes=[3,3], threshold at 0.6
692    mask_f64 = mask.iter().map(|&v| v as f64).collect();
693    let boxsizes2 = vec![vec![3, 3], vec![3, 3], vec![3, 3]];
694    mask_f64 = gaussian_smooth_3d_boxsizes(&mask_f64, &boxsizes2, 2, nx, ny, nz);
695    mask = mask_f64.iter()
696        .map(|&v| if v > 0.6 { 1 } else { 0 })
697        .collect();
698
699    mask
700}
701
702//=============================================================================
703// Box Segmentation
704//=============================================================================
705
706/// Box segmentation for finding tissue regions
707///
708/// Divides the image into nbox^3 boxes and identifies voxels that
709/// consistently appear in the high-intensity range across multiple boxes.
710fn box_segment(
711    image: &[f64],
712    mask: &[u8],
713    nbox: usize,
714    nx: usize, ny: usize, nz: usize,
715) -> Vec<u8> {
716    let n_total = nx * ny * nz;
717    let mut vote_count = vec![0u8; n_total];
718
719    // Calculate box shift (stride between box centers)
720    let box_shift_x = (nx + nbox - 1) / nbox;
721    let box_shift_y = (ny + nbox - 1) / nbox;
722    let box_shift_z = (nz + nbox - 1) / nbox;
723
724    // For each box center
725    let mut cz = 0;
726    while cz < nz {
727        let mut cy = 0;
728        while cy < ny {
729            let mut cx = 0;
730            while cx < nx {
731                // Calculate box bounds (2x box_shift around center)
732                let x_start = cx.saturating_sub(box_shift_x);
733                let x_end = (cx + box_shift_x).min(nx);
734                let y_start = cy.saturating_sub(box_shift_y);
735                let y_end = (cy + box_shift_y).min(ny);
736                let z_start = cz.saturating_sub(box_shift_z);
737                let z_end = (cz + box_shift_z).min(nz);
738
739                // Collect values in this box
740                let mut box_vals: Vec<f64> = Vec::new();
741                for z in z_start..z_end {
742                    for y in y_start..y_end {
743                        for x in x_start..x_end {
744                            let idx = idx3d(x, y, z, nx, ny);
745                            if mask[idx] > 0 && image[idx].is_finite() {
746                                box_vals.push(image[idx]);
747                            }
748                        }
749                    }
750                }
751
752                if box_vals.is_empty() {
753                    cx += box_shift_x;
754                    continue;
755                }
756
757                // Sort and find 90th percentile
758                box_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
759                let q90_idx = ((0.9 * box_vals.len() as f64) as usize).min(box_vals.len() - 1);
760                let q90 = box_vals[q90_idx];
761
762                // Define tissue range around 90th percentile
763                let width = 0.1;
764                let low = (1.0 - width) * q90;
765                let high = (1.0 + width) * q90;
766
767                // Vote for voxels in tissue range
768                for z in z_start..z_end {
769                    for y in y_start..y_end {
770                        for x in x_start..x_end {
771                            let idx = idx3d(x, y, z, nx, ny);
772                            if mask[idx] > 0 {
773                                let v = image[idx];
774                                if v > low && v < high {
775                                    vote_count[idx] = vote_count[idx].saturating_add(1);
776                                }
777                            }
778                        }
779                    }
780                }
781
782                cx += box_shift_x;
783            }
784            cy += box_shift_y;
785        }
786        cz += box_shift_z;
787    }
788
789    // Threshold: must be identified as tissue in >= 2 boxes
790    let mut segmented = vec![0u8; n_total];
791    for i in 0..n_total {
792        if vote_count[i] >= 2 && mask[i] > 0 {
793            segmented[i] = 1;
794        }
795    }
796
797    segmented
798}
799
800//=============================================================================
801// Fill and Smooth (with weighted smoothing)
802//=============================================================================
803
804/// Fill holes and smooth the lowpass field with weighted smoothing
805///
806/// Matches MriResearchTools.jl's fillandsmooth! function.
807/// Uses weighted smoothing where filled holes get weight 0.2.
808fn fill_and_smooth(
809    lowpass: &mut [f64],
810    stable_mean: f64,
811    sigma2: [f64; 3],
812    nx: usize, ny: usize, nz: usize,
813) {
814    let n_total = nx * ny * nz;
815
816    // Identify holes/outliers and create weight mask
817    // lowpassweight = 1.2 - lowpassmask (so holes get 0.2, normal get 1.2)
818    let mut weight = vec![1.2f64; n_total];
819
820    for i in 0..n_total {
821        if lowpass[i] < stable_mean / 4.0 ||
822           lowpass[i].is_nan() ||
823           lowpass[i] > 10.0 * stable_mean {
824            lowpass[i] = 3.0 * stable_mean;
825            weight[i] = 0.2; // Filled holes get less weight
826        }
827    }
828
829    // Apply weighted smoothing
830    let nbox = 3; // default for non-masked smoothing
831    let smoothed = gaussian_smooth_3d(lowpass, sigma2, None, Some(&mut weight), nbox, nx, ny, nz);
832    lowpass.copy_from_slice(&smoothed);
833}
834
835//=============================================================================
836// Main API
837//=============================================================================
838
839/// Get sensitivity (bias field) from magnitude
840///
841/// This estimates the RF receive field inhomogeneity (sensitivity map)
842/// that can be divided out to correct the image.
843pub fn get_sensitivity(
844    mag: &[f64],
845    nx: usize, ny: usize, nz: usize,
846    vx: f64, vy: f64, vz: f64,
847    sigma_mm: f64,
848    nbox: usize,
849) -> Vec<f64> {
850    let n_total = nx * ny * nz;
851
852    // Convert mm to voxels
853    let sigma = [sigma_mm / vx, sigma_mm / vy, sigma_mm / vz];
854
855    // Create initial mask (with full post-processing)
856    let mask = robust_mask(mag, nx, ny, nz);
857
858    // Box segmentation to find tissue
859    let segmentation = box_segment(mag, &mask, nbox, nx, ny, nz);
860
861    // Split sigma into two parts (matching MriResearchTools.jl)
862    let factor: f64 = 0.7;
863    let sigma1 = [
864        (1.0_f64 - factor * factor).sqrt() * sigma[0],
865        (1.0_f64 - factor * factor).sqrt() * sigma[1],
866        (1.0_f64 - factor * factor).sqrt() * sigma[2],
867    ];
868    let sigma2 = [
869        factor * sigma[0],
870        factor * sigma[1],
871        factor * sigma[2],
872    ];
873
874    // First smoothing with tissue mask (nbox=8 for masked smoothing)
875    let mut lowpass = gaussian_smooth_3d(mag, sigma1, Some(&segmentation), None, 8, nx, ny, nz);
876
877    // Calculate stable mean for filling
878    let mut sum = 0.0;
879    let mut count = 0usize;
880    for i in 0..n_total {
881        if mask[i] > 0 && mag[i].is_finite() {
882            sum += mag[i];
883            count += 1;
884        }
885    }
886    let stable_mean = if count > 0 { sum / count as f64 } else { 1.0 };
887
888    // Fill holes and apply weighted second smoothing
889    fill_and_smooth(&mut lowpass, stable_mean, sigma2, nx, ny, nz);
890
891    lowpass
892}
893
894/// Make magnitude homogeneous by dividing by bias field
895///
896/// This is the main entry point for bias field correction.
897///
898/// # Arguments
899/// * `mag` - Input magnitude data (nx * ny * nz)
900/// * `nx`, `ny`, `nz` - Dimensions
901/// * `vx`, `vy`, `vz` - Voxel sizes in mm
902/// * `sigma_mm` - Smoothing sigma in mm (default 7, will be clamped to 10% FOV)
903/// * `nbox` - Number of boxes per dimension for segmentation (default 15)
904///
905/// # Returns
906/// Bias-corrected magnitude
907pub fn makehomogeneous(
908    mag: &[f64],
909    nx: usize, ny: usize, nz: usize,
910    vx: f64, vy: f64, vz: f64,
911    sigma_mm: f64,
912    nbox: usize,
913) -> Vec<f64> {
914    let sensitivity = get_sensitivity(mag, nx, ny, nz, vx, vy, vz, sigma_mm, nbox);
915    let n_total = nx * ny * nz;
916
917    let mut result = vec![0.0; n_total];
918    for i in 0..n_total {
919        if sensitivity[i] > 1e-10 && !sensitivity[i].is_nan() {
920            result[i] = mag[i] / sensitivity[i];
921        } else {
922            result[i] = mag[i];
923        }
924    }
925
926    result
927}
928
/// Combine multi-echo magnitudes voxel-wise by root-sum-of-squares (RSS).
///
/// # Arguments
/// * `mags_flat` - Echo volumes stored back to back. Echo `e` occupies
///   `mags_flat[e * n_total..(e + 1) * n_total]`; any trailing data beyond
///   `n_echoes` volumes is ignored.
/// * `n_echoes` - Number of echoes to combine
/// * `n_total` - Voxels per echo (nx * ny * nz)
///
/// # Returns
/// `n_total` values, each `sqrt(sum_e mag_e[i]^2)`. All values are zero when
/// `n_echoes == 0`.
///
/// # Panics
/// Panics if `mags_flat` holds fewer than `n_echoes * n_total` values.
pub fn rss_combine(
    mags_flat: &[f64],
    n_echoes: usize,
    n_total: usize,
) -> Vec<f64> {
    let mut sum_sq = vec![0.0; n_total];

    // Accumulate squares echo by echo, so each voxel sums in echo order.
    for echo in 0..n_echoes {
        let start = echo * n_total;
        let volume = &mags_flat[start..start + n_total];
        for (acc, &v) in sum_sq.iter_mut().zip(volume) {
            *acc += v * v;
        }
    }

    sum_sq.into_iter().map(f64::sqrt).collect()
}
961
#[cfg(test)]
mod tests {
    use super::*;

    /// Linear (column-major) index of the central voxel of an `n^3` cube.
    fn center_index(n: usize) -> usize {
        idx3d(n / 2, n / 2, n / 2, n, n)
    }

    /// Coefficient of variation (population standard deviation / mean).
    fn coefficient_of_variation(vals: &[f64]) -> f64 {
        let len = vals.len() as f64;
        let mean = vals.iter().sum::<f64>() / len;
        let std = (vals.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / len).sqrt();
        std / mean
    }

    #[test]
    fn test_get_box_sizes() {
        // Test box size calculation matches Julia
        let sizes = get_box_sizes(5.0, 3);
        assert_eq!(sizes.len(), 3);
        // For sigma=5, n=3: wideal = sqrt(12*25/3 + 1) = sqrt(101) ≈ 10.05
        // All sizes should be odd and reasonable
        for &s in &sizes {
            assert!(s % 2 == 1, "Box size should be odd, got {}", s);
            assert!(s >= 3 && s <= 11, "Box size should be in reasonable range, got {}", s);
        }
    }

    #[test]
    fn test_box_filter_line() {
        // Simple test: uniform values should stay uniform
        let mut line = vec![1.0; 10];
        box_filter_line(&mut line, 3);
        for &v in &line {
            assert!((v - 1.0).abs() < 1e-10, "Uniform line should stay uniform");
        }
    }

    #[test]
    fn test_fill_holes_basic() {
        // 3x3x3 cube with a hole in the center
        let mut mask = vec![1u8; 27];
        mask[center_index(3)] = 0; // center voxel (index 13)

        let filled = fill_holes(&mask, 3, 3, 3, 5);
        assert_eq!(filled[center_index(3)], 1, "Center hole should be filled");
    }

    #[test]
    fn test_robust_mask_basic() {
        // A uniform, non-zero volume should yield a non-empty mask
        let mag = vec![100.0; 27];
        let mask = robust_mask(&mag, 3, 3, 3);
        let masked_count: usize = mask.iter().map(|&v| v as usize).sum();
        assert!(masked_count > 0, "Should have some masked voxels");
    }

    #[test]
    fn test_rss_combine() {
        // Two echoes, 4 voxels each
        let mags = vec![
            3.0, 0.0, 0.0, 5.0,  // echo 0
            4.0, 0.0, 0.0, 12.0, // echo 1
        ];
        let result = rss_combine(&mags, 2, 4);

        // sqrt(3^2 + 4^2) = 5
        assert!((result[0] - 5.0).abs() < 1e-10);
        // sqrt(0 + 0) = 0
        assert!((result[1] - 0.0).abs() < 1e-10);
        // sqrt(5^2 + 12^2) = 13
        assert!((result[3] - 13.0).abs() < 1e-10);
    }

    // =====================================================================
    // Helper: create a 3D sphere magnitude phantom with optional bias field
    // =====================================================================

    /// Create a 3D sphere phantom with an optional smooth bias field.
    /// Returns (magnitude_data, mask) where mask marks inside-sphere voxels.
    /// The bias (when enabled) is a linear ramp along x from 1.0 to ~1.5.
    fn make_sphere_phantom(n: usize, bias: bool) -> (Vec<f64>, Vec<u8>) {
        let center = n as f64 / 2.0;
        let radius = n as f64 / 2.0 - 2.0;
        let n_total = n * n * n;
        let mut mag = vec![0.0f64; n_total];
        let mut mask = vec![0u8; n_total];

        for k in 0..n {
            for j in 0..n {
                for i in 0..n {
                    let dx = i as f64 - center;
                    let dy = j as f64 - center;
                    let dz = k as f64 - center;
                    let dist = (dx * dx + dy * dy + dz * dz).sqrt();
                    let idx = idx3d(i, j, k, n, n);

                    if dist < radius {
                        // Base tissue intensity
                        let base = 100.0;
                        // Apply smooth bias field if requested
                        let bias_val = if bias {
                            1.0 + 0.5 * (i as f64 / n as f64)
                        } else {
                            1.0
                        };
                        mag[idx] = base * bias_val;
                        mask[idx] = 1;
                    } else {
                        // Low, deterministic "background noise" pattern
                        mag[idx] = 1.0 + 0.5 * ((i + j + k) % 3) as f64;
                    }
                }
            }
        }

        (mag, mask)
    }

    // =====================================================================
    // Tests for get_box_sizes edge cases
    // =====================================================================

    #[test]
    fn test_get_box_sizes_zero_sigma() {
        let sizes = get_box_sizes(0.0, 3);
        assert_eq!(sizes, vec![0, 0, 0]);
    }

    #[test]
    fn test_get_box_sizes_zero_n() {
        let sizes = get_box_sizes(5.0, 0);
        assert!(sizes.is_empty());
    }

    #[test]
    fn test_get_box_sizes_negative_sigma() {
        let sizes = get_box_sizes(-1.0, 3);
        assert_eq!(sizes, vec![0, 0, 0]);
    }

    #[test]
    fn test_get_box_sizes_large_sigma() {
        let sizes = get_box_sizes(20.0, 4);
        assert_eq!(sizes.len(), 4);
        for &s in &sizes {
            assert!(s % 2 == 1, "Box size should be odd, got {}", s);
            assert!(s >= 3, "Box size should be at least 3, got {}", s);
        }
    }

    // =====================================================================
    // Tests for check_box_sizes
    // =====================================================================

    #[test]
    fn test_check_box_sizes_clamps_to_half_image() {
        // Box size larger than half the image dimension should be clamped
        let mut boxsizes = vec![vec![99], vec![99], vec![99]];
        let dims = [10, 10, 10];
        check_box_sizes(&mut boxsizes, &dims);
        for bs in &boxsizes {
            for &b in bs {
                assert!(b <= dims[0], "Box size should be clamped, got {}", b);
                assert!(b % 2 == 1, "Box size should be odd, got {}", b);
            }
        }
    }

    #[test]
    fn test_check_box_sizes_makes_even_odd() {
        let mut boxsizes = vec![vec![4], vec![6], vec![8]];
        let dims = [100, 100, 100];
        check_box_sizes(&mut boxsizes, &dims);
        for bs in &boxsizes {
            for &b in bs {
                assert!(b % 2 == 1, "Box size should be odd, got {}", b);
            }
        }
    }

    // =====================================================================
    // Tests for box_filter_line edge cases
    // =====================================================================

    #[test]
    fn test_box_filter_line_too_small_boxsize() {
        // boxsize < 3 should be a no-op
        let mut line = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let original = line.clone();
        box_filter_line(&mut line, 1);
        assert_eq!(line, original);
    }

    #[test]
    fn test_box_filter_line_larger_than_data() {
        // boxsize > line length should be a no-op
        let mut line = vec![1.0, 2.0];
        let original = line.clone();
        box_filter_line(&mut line, 5);
        assert_eq!(line, original);
    }

    #[test]
    fn test_box_filter_line_smoothing_effect() {
        // A spike should be smoothed out
        let mut line = vec![0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        box_filter_line(&mut line, 5);
        // The spike at index 3 should be reduced
        assert!(line[3] < 10.0, "Spike should be smoothed");
        // Neighbors should have received some intensity
        assert!(line[2] > 0.0 || line[4] > 0.0, "Neighbors should gain intensity");
        // All values should be finite
        for &v in &line {
            assert!(v.is_finite(), "All values should be finite");
        }
    }

    // =====================================================================
    // Tests for box_filter_line_weighted
    // =====================================================================

    #[test]
    fn test_box_filter_line_weighted_too_small() {
        let mut line = vec![1.0, 2.0];
        let mut weight = vec![1.0, 1.0];
        let orig_l = line.clone();
        let orig_w = weight.clone();
        box_filter_line_weighted(&mut line, &mut weight, 1);
        assert_eq!(line, orig_l);
        assert_eq!(weight, orig_w);
    }

    #[test]
    fn test_box_filter_line_weighted_uniform() {
        let mut line = vec![5.0; 10];
        let mut weight = vec![1.0; 10];
        box_filter_line_weighted(&mut line, &mut weight, 3);
        // With uniform values and uniform weights, the middle portion
        // should remain close to 5.0
        for i in 2..8 {
            assert!(
                (line[i] - 5.0).abs() < 0.5,
                "Uniform weighted line should stay near 5.0, got {} at index {}",
                line[i], i
            );
        }
    }

    #[test]
    fn test_box_filter_line_weighted_varying_weights() {
        let mut line = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let mut weight = vec![1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0];
        box_filter_line_weighted(&mut line, &mut weight, 3);
        // All values should remain finite
        for &v in &line {
            assert!(v.is_finite(), "All values should be finite after weighted filter");
        }
        for &w in &weight {
            assert!(w.is_finite(), "All weights should be finite after weighted filter");
        }
    }

    // =====================================================================
    // Tests for nan_box_filter_line
    // =====================================================================

    #[test]
    fn test_nan_box_filter_line_too_small() {
        let mut line = vec![1.0, 2.0];
        let original = line.clone();
        nan_box_filter_line(&mut line, 5);
        assert_eq!(line, original);
    }

    #[test]
    fn test_nan_box_filter_line_no_nans() {
        // All valid data => the nan filter processes it and produces finite results.
        // Only finiteness is checked; exact values depend on the filter's edge handling.
        let mut line = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        nan_box_filter_line(&mut line, 3);
        for &v in &line {
            assert!(v.is_finite(), "All values should be finite after nan box filter");
        }
    }

    #[test]
    fn test_nan_box_filter_line_with_nans() {
        // Data with NaN gaps
        let mut line = vec![5.0, 5.0, 5.0, f64::NAN, f64::NAN, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0];
        nan_box_filter_line(&mut line, 3);
        // Non-NaN regions should still be mostly finite
        let finite_count = line.iter().filter(|v| v.is_finite()).count();
        assert!(finite_count >= 6, "Most values should be finite, got {}", finite_count);
    }

    #[test]
    fn test_nan_box_filter_line_all_nan_except_edges() {
        // Mostly NaN
        let mut line = vec![1.0; 12];
        for v in &mut line[2..10] {
            *v = f64::NAN;
        }
        nan_box_filter_line(&mut line, 3);
        // The call must not panic. The edge value may stay finite or become NaN
        // depending on gap handling, but it must never blow up to infinity.
        assert!(
            !line[0].is_infinite(),
            "Edge value should not become infinite, got {}",
            line[0]
        );
    }

    // =====================================================================
    // Tests for gaussian_smooth_3d (main 3D smoothing)
    // =====================================================================

    #[test]
    fn test_gaussian_smooth_3d_uniform_no_mask() {
        let n = 8;
        let data = vec![5.0; n * n * n];
        let result = gaussian_smooth_3d(&data, [2.0, 2.0, 2.0], None, None, 3, n, n, n);
        assert_eq!(result.len(), n * n * n);
        // Uniform data should stay nearly uniform after smoothing
        for &v in &result {
            assert!(
                (v - 5.0).abs() < 0.5,
                "Uniform data should stay near 5.0, got {}",
                v
            );
        }
    }

    #[test]
    fn test_gaussian_smooth_3d_with_mask() {
        let n = 10;
        let (mag, mask) = make_sphere_phantom(n, false);
        // Smooth with mask
        let result = gaussian_smooth_3d(&mag, [1.5, 1.5, 1.5], Some(&mask), None, 4, n, n, n);
        assert_eq!(result.len(), n * n * n);
        // Inside mask, values should still be finite
        for i in 0..result.len() {
            if mask[i] > 0 {
                assert!(result[i].is_finite(), "Masked voxel at {} should be finite", i);
            }
        }
    }

    #[test]
    fn test_gaussian_smooth_3d_with_weights() {
        let n = 8;
        let data = vec![10.0; n * n * n];
        let mut weight = vec![1.0; n * n * n];
        let result = gaussian_smooth_3d(
            &data, [1.5, 1.5, 1.5], None, Some(&mut weight), 3, n, n, n,
        );
        assert_eq!(result.len(), n * n * n);
        for &v in &result {
            assert!(v.is_finite(), "Result should be finite");
        }
    }

    // =====================================================================
    // Tests for gaussian_smooth_3d_boxsizes
    // =====================================================================

    #[test]
    fn test_gaussian_smooth_3d_boxsizes_uniform() {
        let n = 8;
        let data = vec![3.0; n * n * n];
        let boxsizes = vec![vec![3, 3], vec![3, 3], vec![3, 3]];
        let result = gaussian_smooth_3d_boxsizes(&data, &boxsizes, 2, n, n, n);
        assert_eq!(result.len(), n * n * n);
        for &v in &result {
            assert!(
                (v - 3.0).abs() < 0.5,
                "Uniform data should stay near 3.0, got {}",
                v
            );
        }
    }

    #[test]
    fn test_gaussian_smooth_3d_boxsizes_spike() {
        let n = 10;
        let n_total = n * n * n;
        let mut data = vec![0.0; n_total];
        // Place a spike at the center
        let center = center_index(n);
        data[center] = 100.0;
        let boxsizes = vec![vec![5, 5], vec![5, 5], vec![5, 5]];
        let result = gaussian_smooth_3d_boxsizes(&data, &boxsizes, 2, n, n, n);
        assert_eq!(result.len(), n_total);
        // Spike should be reduced
        assert!(
            result[center] < 100.0,
            "Spike should be smoothed, got {}",
            result[center]
        );
        // Smoothing must not introduce NaN/inf anywhere
        for &v in &result {
            assert!(v.is_finite(), "All values should be finite");
        }
    }

    // =====================================================================
    // Tests for fill_holes (more complex cases)
    // =====================================================================

    #[test]
    fn test_fill_holes_boundary_hole_not_filled() {
        // A hole touching the boundary should NOT be filled
        let mut mask = vec![1u8; 125]; // 5x5x5
        mask[0] = 0; // corner voxel - touches boundary
        let filled = fill_holes(&mask, 5, 5, 5, 100);
        assert_eq!(filled[0], 0, "Boundary hole should not be filled");
    }

    #[test]
    fn test_fill_holes_large_hole_not_filled() {
        // A hole larger than max_hole_size should NOT be filled
        let n = 7;
        let n_total = n * n * n;
        let mut mask = vec![1u8; n_total];
        // Create a large internal hole (3x3x3 = 27 voxels)
        for k in 2..5 {
            for j in 2..5 {
                for i in 2..5 {
                    mask[idx3d(i, j, k, n, n)] = 0;
                }
            }
        }
        let filled = fill_holes(&mask, n, n, n, 5); // max_hole_size=5, hole is 27
        // Hole should NOT be filled because it's too large
        let center = idx3d(3, 3, 3, n, n);
        assert_eq!(filled[center], 0, "Large hole should not be filled");
    }

    // =====================================================================
    // Tests for robust_mask (more comprehensive)
    // =====================================================================

    #[test]
    fn test_robust_mask_sphere() {
        let n = 12;
        let (mag, _) = make_sphere_phantom(n, false);
        let mask = robust_mask(&mag, n, n, n);
        assert_eq!(mask.len(), n * n * n);
        let masked_count: usize = mask.iter().map(|&v| v as usize).sum();
        // The sphere should produce some masked voxels
        assert!(masked_count > 0, "Should have masked voxels for sphere phantom");
        // Center should be masked (high intensity)
        assert_eq!(mask[center_index(n)], 1, "Center of sphere should be masked");
    }

    #[test]
    fn test_robust_mask_empty() {
        // All zero magnitude -> empty mask
        let mag = vec![0.0; 27];
        let mask = robust_mask(&mag, 3, 3, 3);
        let count: usize = mask.iter().map(|&v| v as usize).sum();
        assert_eq!(count, 0, "Zero magnitude should produce empty mask");
    }

    #[test]
    fn test_robust_mask_nan_values() {
        // Magnitude with NaN values
        let mut mag = vec![100.0; 125];
        mag[0] = f64::NAN;
        mag[10] = f64::NAN;
        mag[50] = f64::INFINITY;
        let mask = robust_mask(&mag, 5, 5, 5);
        assert_eq!(mask.len(), 125);
        // Should still produce a valid mask
        for &v in &mask {
            assert!(v == 0 || v == 1, "Mask values should be 0 or 1");
        }
    }

    // =====================================================================
    // Tests for box_segment
    // =====================================================================

    #[test]
    fn test_box_segment_uniform() {
        let n = 10;
        let (mag, mask) = make_sphere_phantom(n, false);
        let seg = box_segment(&mag, &mask, 3, n, n, n);
        assert_eq!(seg.len(), n * n * n);
        // In a uniform sphere, some interior voxels should be segmented as tissue
        let seg_count: usize = seg.iter().map(|&v| v as usize).sum();
        assert!(seg_count > 0, "Box segment should find some tissue voxels");
    }

    #[test]
    fn test_box_segment_empty_mask() {
        let n = 8;
        let mag = vec![100.0; n * n * n];
        let mask = vec![0u8; n * n * n]; // empty mask
        let seg = box_segment(&mag, &mask, 3, n, n, n);
        let count: usize = seg.iter().map(|&v| v as usize).sum();
        assert_eq!(count, 0, "Empty mask should produce no segmentation");
    }

    // =====================================================================
    // Tests for fill_and_smooth
    // =====================================================================

    #[test]
    fn test_fill_and_smooth_basic() {
        let n = 10;
        let n_total = n * n * n;
        let stable_mean = 100.0;
        let mut lowpass = vec![stable_mean; n_total];
        // Add some holes (very low values)
        lowpass[0] = 0.0;
        lowpass[100] = f64::NAN;
        lowpass[200] = 2000.0; // outlier > 10*stable_mean

        let sigma2 = [2.0, 2.0, 2.0];
        fill_and_smooth(&mut lowpass, stable_mean, sigma2, n, n, n);

        // All values should be finite after fill and smooth
        for (i, &v) in lowpass.iter().enumerate() {
            assert!(v.is_finite(), "Value at {} should be finite, got {}", i, v);
        }
    }

    #[test]
    fn test_fill_and_smooth_preserves_approximate_mean() {
        let n = 8;
        let n_total = n * n * n;
        let stable_mean = 50.0;
        let mut lowpass = vec![stable_mean; n_total];
        let sigma2 = [1.5, 1.5, 1.5];
        fill_and_smooth(&mut lowpass, stable_mean, sigma2, n, n, n);

        // Mean should be approximately preserved
        let mean: f64 = lowpass.iter().sum::<f64>() / n_total as f64;
        assert!(
            (mean - stable_mean).abs() < stable_mean * 0.5,
            "Mean should be approximately preserved, got {} vs {}",
            mean,
            stable_mean
        );
    }

    // =====================================================================
    // Tests for get_sensitivity
    // =====================================================================

    #[test]
    fn test_get_sensitivity_sphere() {
        let n = 12;
        let (mag, _) = make_sphere_phantom(n, true);
        let sensitivity = get_sensitivity(&mag, n, n, n, 1.0, 1.0, 1.0, 4.0, 5);
        assert_eq!(sensitivity.len(), n * n * n);
        // Sensitivity should be finite and positive at the sphere center
        let center = center_index(n);
        assert!(
            sensitivity[center].is_finite(),
            "Sensitivity at center should be finite"
        );
        assert!(
            sensitivity[center] > 0.0,
            "Sensitivity at center should be positive, got {}",
            sensitivity[center]
        );
    }

    #[test]
    fn test_get_sensitivity_output_size() {
        let n = 10;
        let (mag, _) = make_sphere_phantom(n, false);
        let sensitivity = get_sensitivity(&mag, n, n, n, 1.0, 1.0, 1.0, 3.0, 3);
        assert_eq!(sensitivity.len(), n * n * n);
    }

    // =====================================================================
    // Tests for makehomogeneous (main entry point)
    // =====================================================================

    #[test]
    fn test_makehomogeneous_output_size_and_finite() {
        let n = 12;
        let (mag, _) = make_sphere_phantom(n, true);
        let result = makehomogeneous(&mag, n, n, n, 1.0, 1.0, 1.0, 4.0, 5);
        assert_eq!(result.len(), n * n * n);
        // All output values should be finite
        for (i, &v) in result.iter().enumerate() {
            assert!(v.is_finite(), "Output at {} should be finite, got {}", i, v);
        }
    }

    #[test]
    fn test_makehomogeneous_reduces_bias() {
        let n = 12;
        let (mag_biased, mask) = make_sphere_phantom(n, true);
        let result = makehomogeneous(&mag_biased, n, n, n, 1.0, 1.0, 1.0, 4.0, 5);

        // Collect values inside the sphere
        let (original_vals, corrected_vals): (Vec<f64>, Vec<f64>) = (0..n * n * n)
            .filter(|&i| mask[i] > 0)
            .map(|i| (mag_biased[i], result[i]))
            .unzip();

        // The coefficient of variation should be reduced (or at least not much worse)
        let orig_cv = coefficient_of_variation(&original_vals);
        let corr_cv = coefficient_of_variation(&corrected_vals);

        assert!(
            corr_cv <= orig_cv * 2.0,
            "Corrected CV ({}) should not be much worse than original CV ({})",
            corr_cv,
            orig_cv
        );
    }

    #[test]
    fn test_makehomogeneous_no_bias_preserves() {
        let n = 12;
        let (mag, mask) = make_sphere_phantom(n, false);
        let result = makehomogeneous(&mag, n, n, n, 1.0, 1.0, 1.0, 4.0, 5);

        // With no bias, corrected values should still be positive inside sphere
        for i in 0..(n * n * n) {
            if mask[i] > 0 {
                assert!(
                    result[i] > 0.0,
                    "Inside sphere voxel {} should be positive, got {}",
                    i,
                    result[i]
                );
            }
        }
    }

    #[test]
    fn test_makehomogeneous_anisotropic_voxels() {
        let n = 12;
        let (mag, _) = make_sphere_phantom(n, true);
        // Use anisotropic voxel sizes
        let result = makehomogeneous(&mag, n, n, n, 0.5, 0.5, 2.0, 4.0, 5);
        assert_eq!(result.len(), n * n * n);
        for &v in &result {
            assert!(v.is_finite(), "All output values should be finite");
        }
    }

    // =====================================================================
    // Tests for flood_fill_component
    // =====================================================================

    #[test]
    fn test_flood_fill_component_single() {
        let mask = vec![1u8; 27]; // 3x3x3 all filled => no zeros
        let mut visited = vec![false; 27];
        // Start at a filled voxel -> should find nothing (mask[idx] != 0)
        let component = flood_fill_component(&mask, &mut visited, 0, 3, 3, 3);
        assert!(component.is_empty(), "All filled mask should have no zero component");
    }

    #[test]
    fn test_flood_fill_component_connected_zeros() {
        let n = 5;
        let n_total = n * n * n;
        let mut mask = vec![1u8; n_total];
        // Create a connected line of zeros along x at j=2, k=2
        for i in 1..4 {
            mask[idx3d(i, 2, 2, n, n)] = 0;
        }
        let mut visited = vec![false; n_total];
        let start = idx3d(1, 2, 2, n, n);
        let component = flood_fill_component(&mask, &mut visited, start, n, n, n);
        assert_eq!(component.len(), 3, "Should find 3 connected zero voxels");
    }

    // =====================================================================
    // Tests for gaussian_smooth_3d with reversed passes (mask + even pass)
    // =====================================================================

    #[test]
    fn test_gaussian_smooth_3d_mask_reverse_passes() {
        // Exercise the reverse pass code paths (mask.is_some() && ibox % 2 == 1)
        let n = 10;
        let (mag, mask) = make_sphere_phantom(n, false);
        // nbox=4 ensures we have even passes (ibox=1,3)
        let result = gaussian_smooth_3d(&mag, [1.5, 1.5, 1.5], Some(&mask), None, 4, n, n, n);
        assert_eq!(result.len(), n * n * n);
        // NaN smoothing may leave NaN near the mask edge, but the sphere
        // center is well inside the mask and must be finite.
        let center = center_index(n);
        assert!(
            result[center].is_finite(),
            "Center voxel should be finite, got {}",
            result[center]
        );
        // At least some values should be finite
        let finite_count = result.iter().filter(|v| v.is_finite()).count();
        assert!(finite_count > 0, "Should have some finite values");
    }
}