Skip to main content

qsm_core/utils/
multi_echo.rs

1//! Multi-echo phase combination utilities
2//!
3//! Implements MCPC-3D-S (Multi-Channel Phase Combination - 3D - Smoothed) algorithm
4//! and weighted B0 calculation.
5//!
6//! Reference:
7//! Eckstein, K., Dymerska, B., Bachrata, B., Bogner, W., Poljanc, K., Trattnig, S.,
8//! Robinson, S.D. (2018). "Computationally Efficient Combination of Multi-channel Phase
9//! Data From Multi-echo Acquisitions (ASPIRE)."
10//! Magnetic Resonance in Medicine, 79:2996-3006. https://doi.org/10.1002/mrm.26963
11//!
12//! Reference implementation: https://github.com/korbinian90/MriResearchTools.jl
13
/// Tuning knobs for MCPC-3D-S phase-offset estimation.
#[derive(Clone, Debug)]
pub struct Mcpc3dsParams {
    /// Per-axis Gaussian width `[x, y, z]`, in voxels, used when smoothing the estimated
    /// phase offset.
    pub sigma: [f64; 3],
}

impl Default for Mcpc3dsParams {
    /// Isotropic smoothing of 4 voxels along every axis.
    fn default() -> Self {
        Mcpc3dsParams { sigma: [4.0; 3] }
    }
}
28
/// Parameters for multi-echo linear fit.
#[derive(Clone, Debug)]
pub struct LinearFitParams {
    /// Fit an intercept (constant phase offset) in addition to the slope, instead of
    /// forcing the fit through the origin.
    pub estimate_offset: bool,
    /// Percentile (0–100, not degrees) of the fit residuals used as the reliability cut-off:
    /// voxels whose residual is not below this percentile are flagged unreliable.
    /// Presumably forwarded to `multi_echo_linear_fit`, where 0 disables masking — confirm
    /// at call sites.
    pub reliability_threshold_percentile: f64,
}

impl Default for LinearFitParams {
    /// Estimate the offset and flag voxels at or above the 90th residual percentile.
    fn default() -> Self {
        Self {
            estimate_offset: true,
            reliability_threshold_percentile: 90.0,
        }
    }
}
46
47use std::f64::consts::PI;
48use crate::unwrap::romeo::calculate_weights_romeo;
49use crate::region_grow::grow_region_unwrap;
50
/// One full turn in radians (2π); used for phase wrapping and rad → Hz conversion.
const TWO_PI: f64 = std::f64::consts::TAU;
52
53/// Wrap angle to [-π, π]
54#[inline]
55fn wrap_to_pi(angle: f64) -> f64 {
56    let mut a = angle % TWO_PI;
57    if a > PI {
58        a -= TWO_PI;
59    } else if a < -PI {
60        a += TWO_PI;
61    }
62    a
63}
64
/// Linear offset of voxel `(i, j, k)` in a column-major (Fortran-order) volume where `i`
/// varies fastest, then `j`, then `k`.
#[inline(always)]
fn idx3d(i: usize, j: usize, k: usize, nx: usize, ny: usize) -> usize {
    (k * ny + j) * nx + i
}
70
/// B0 weighting types matching MriResearchTools.jl
///
/// Selects the per-echo weight used by `calculate_b0_weighted`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum B0WeightType {
    /// mag * TE - optimal for phase SNR (default)
    PhaseSNR,
    /// mag² * TE² - based on phase variance
    PhaseVar,
    /// Uniform weights
    Average,
    /// TE only
    TEs,
    /// Magnitude only
    Mag,
}

impl Default for B0WeightType {
    /// `PhaseSNR`, the documented default weighting.
    fn default() -> Self {
        B0WeightType::PhaseSNR
    }
}

impl B0WeightType {
    /// Parse a weighting name case-insensitively.
    ///
    /// Accepted names: `phase_snr`/`phasesnr`, `phase_var`/`phasevar`, `average`/`uniform`,
    /// `tes`/`te`, `mag`/`magnitude`. Unrecognised names fall back to
    /// [`B0WeightType::PhaseSNR`] instead of failing.
    // Kept as an inherent method for backward compatibility; `FromStr` is implemented as
    // well so `"mag".parse::<B0WeightType>()` also works.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Self {
        Self::parse_lenient(s)
    }

    /// Shared lenient parser behind both `from_str` entry points.
    fn parse_lenient(s: &str) -> Self {
        match s.to_lowercase().as_str() {
            "phase_snr" | "phasesnr" => B0WeightType::PhaseSNR,
            "phase_var" | "phasevar" => B0WeightType::PhaseVar,
            "average" | "uniform" => B0WeightType::Average,
            "tes" | "te" => B0WeightType::TEs,
            "mag" | "magnitude" => B0WeightType::Mag,
            _ => B0WeightType::PhaseSNR, // default
        }
    }
}

impl std::str::FromStr for B0WeightType {
    /// Parsing never fails: unknown names map to `PhaseSNR`.
    type Err = std::convert::Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(Self::parse_lenient(s))
    }
}
98
99/// 3D Gaussian smoothing for phase data (handles phase wrapping)
100///
101/// Implements gaussiansmooth3d_phase from MriResearchTools.jl
102/// Uses separable Gaussian filtering with phase-aware averaging
103///
104/// # Arguments
105/// * `phase` - Input phase data (nx * ny * nz)
106/// * `sigma` - Smoothing sigma in voxels [sx, sy, sz]
107/// * `mask` - Binary mask (1 = include, 0 = exclude)
108/// * `nx`, `ny`, `nz` - Dimensions
109///
110/// # Returns
111/// Smoothed phase data
112pub fn gaussian_smooth_3d_phase(
113    phase: &[f64],
114    sigma: [f64; 3],
115    mask: &[u8],
116    nx: usize, ny: usize, nz: usize,
117) -> Vec<f64> {
118    let n_total = nx * ny * nz;
119
120    // For phase smoothing, we smooth the complex representation
121    // and extract the angle to handle wrapping correctly
122    let mut real = vec![0.0; n_total];
123    let mut imag = vec![0.0; n_total];
124
125    // Convert phase to complex (unit vectors)
126    for i in 0..n_total {
127        if mask[i] > 0 {
128            real[i] = phase[i].cos();
129            imag[i] = phase[i].sin();
130        }
131    }
132
133    // Apply separable Gaussian smoothing to real and imaginary parts
134    let real_smoothed = gaussian_smooth_3d_separable(&real, sigma, mask, nx, ny, nz);
135    let imag_smoothed = gaussian_smooth_3d_separable(&imag, sigma, mask, nx, ny, nz);
136
137    // Convert back to phase
138    let mut result = vec![0.0; n_total];
139    for i in 0..n_total {
140        if mask[i] > 0 {
141            result[i] = imag_smoothed[i].atan2(real_smoothed[i]);
142        }
143    }
144
145    result
146}
147
148/// Separable 3D Gaussian smoothing
149fn gaussian_smooth_3d_separable(
150    data: &[f64],
151    sigma: [f64; 3],
152    mask: &[u8],
153    nx: usize, ny: usize, nz: usize,
154) -> Vec<f64> {
155    let n_total = nx * ny * nz;
156    let mut result = data.to_vec();
157    let mut temp = vec![0.0; n_total];
158
159    // X direction
160    if sigma[0] > 0.0 {
161        let kernel = make_gaussian_kernel(sigma[0]);
162        let half = kernel.len() / 2;
163
164        for k in 0..nz {
165            for j in 0..ny {
166                for i in 0..nx {
167                    let idx = idx3d(i, j, k, nx, ny);
168                    if mask[idx] == 0 {
169                        temp[idx] = 0.0;
170                        continue;
171                    }
172
173                    let mut sum = 0.0;
174                    let mut weight_sum = 0.0;
175
176                    for (ki, &kv) in kernel.iter().enumerate() {
177                        let ii = i as isize + ki as isize - half as isize;
178                        if ii >= 0 && ii < nx as isize {
179                            let nidx = idx3d(ii as usize, j, k, nx, ny);
180                            if mask[nidx] > 0 {
181                                sum += result[nidx] * kv;
182                                weight_sum += kv;
183                            }
184                        }
185                    }
186
187                    temp[idx] = if weight_sum > 0.0 { sum / weight_sum } else { 0.0 };
188                }
189            }
190        }
191        std::mem::swap(&mut result, &mut temp);
192    }
193
194    // Y direction
195    if sigma[1] > 0.0 {
196        let kernel = make_gaussian_kernel(sigma[1]);
197        let half = kernel.len() / 2;
198
199        for k in 0..nz {
200            for j in 0..ny {
201                for i in 0..nx {
202                    let idx = idx3d(i, j, k, nx, ny);
203                    if mask[idx] == 0 {
204                        temp[idx] = 0.0;
205                        continue;
206                    }
207
208                    let mut sum = 0.0;
209                    let mut weight_sum = 0.0;
210
211                    for (ki, &kv) in kernel.iter().enumerate() {
212                        let jj = j as isize + ki as isize - half as isize;
213                        if jj >= 0 && jj < ny as isize {
214                            let nidx = idx3d(i, jj as usize, k, nx, ny);
215                            if mask[nidx] > 0 {
216                                sum += result[nidx] * kv;
217                                weight_sum += kv;
218                            }
219                        }
220                    }
221
222                    temp[idx] = if weight_sum > 0.0 { sum / weight_sum } else { 0.0 };
223                }
224            }
225        }
226        std::mem::swap(&mut result, &mut temp);
227    }
228
229    // Z direction
230    if sigma[2] > 0.0 {
231        let kernel = make_gaussian_kernel(sigma[2]);
232        let half = kernel.len() / 2;
233
234        for k in 0..nz {
235            for j in 0..ny {
236                for i in 0..nx {
237                    let idx = idx3d(i, j, k, nx, ny);
238                    if mask[idx] == 0 {
239                        temp[idx] = 0.0;
240                        continue;
241                    }
242
243                    let mut sum = 0.0;
244                    let mut weight_sum = 0.0;
245
246                    for (ki, &kv) in kernel.iter().enumerate() {
247                        let kk = k as isize + ki as isize - half as isize;
248                        if kk >= 0 && kk < nz as isize {
249                            let nidx = idx3d(i, j, kk as usize, nx, ny);
250                            if mask[nidx] > 0 {
251                                sum += result[nidx] * kv;
252                                weight_sum += kv;
253                            }
254                        }
255                    }
256
257                    temp[idx] = if weight_sum > 0.0 { sum / weight_sum } else { 0.0 };
258                }
259            }
260        }
261        std::mem::swap(&mut result, &mut temp);
262    }
263
264    result
265}
266
/// Create a normalised, symmetric 1-D Gaussian kernel truncated at ±3σ.
///
/// The kernel has `2 * ceil(3σ) + 1` taps and sums to 1. A non-positive, NaN, or infinite
/// `sigma` returns the identity kernel `[1.0]` (no smoothing). Without this guard the taps
/// would be all NaN (0/0), or `2 * radius` would overflow.
fn make_gaussian_kernel(sigma: f64) -> Vec<f64> {
    if !(sigma > 0.0 && sigma.is_finite()) {
        return vec![1.0];
    }

    let radius = (3.0 * sigma).ceil() as usize;
    let two_sigma_sq = 2.0 * sigma * sigma;

    let mut kernel: Vec<f64> = (0..=2 * radius)
        .map(|i| {
            let x = i as f64 - radius as f64;
            (-x * x / two_sigma_sq).exp()
        })
        .collect();

    // Normalise so the taps sum to 1.
    let sum: f64 = kernel.iter().sum();
    kernel.iter_mut().for_each(|k| *k /= sum);

    kernel
}
289
290/// Compute Hermitian Inner Product (HIP) between two echoes
291///
292/// HIP = conj(echo1) * echo2 = mag1 * mag2 * exp(i * (phase2 - phase1))
293///
294/// Returns (hip_phase, hip_mag) where:
295/// - hip_phase = phase2 - phase1 (wrapped to [-π, π])
296/// - hip_mag = mag1 * mag2
297pub fn hermitian_inner_product(
298    phase1: &[f64], mag1: &[f64],
299    phase2: &[f64], mag2: &[f64],
300    mask: &[u8],
301    n: usize,
302) -> (Vec<f64>, Vec<f64>) {
303    let mut hip_phase = vec![0.0; n];
304    let mut hip_mag = vec![0.0; n];
305
306    for i in 0..n {
307        if mask[i] > 0 {
308            hip_phase[i] = wrap_to_pi(phase2[i] - phase1[i]);
309            hip_mag[i] = mag1[i] * mag2[i];
310        }
311    }
312
313    (hip_phase, hip_mag)
314}
315
316/// MCPC-3D-S phase offset estimation for single-coil multi-echo data
317///
318/// Implements the MCPC-3D-S algorithm from MriResearchTools.jl for single-coil data.
319/// This estimates and removes the phase offset (φ₀) from each echo.
320///
321/// # Arguments
322/// * `phases` - Phase data for all echoes, shape [n_echoes][nx*ny*nz]
323/// * `mags` - Magnitude data for all echoes, shape [n_echoes][nx*ny*nz]
324/// * `tes` - Echo times in ms
325/// * `mask` - Binary mask
326/// * `sigma` - Smoothing sigma in voxels [sx, sy, sz], default [10, 10, 5]
327/// * `echoes` - Which echoes to use for HIP calculation, default [0, 1] (first two)
328/// * `nx`, `ny`, `nz` - Dimensions
329///
330/// # Returns
331/// (corrected_phases, phase_offset) where:
332/// - corrected_phases: phases with offset removed
333/// - phase_offset: estimated phase offset
334pub fn mcpc3ds_single_coil(
335    phases: &[impl AsRef<[f64]>],
336    mags: &[impl AsRef<[f64]>],
337    tes: &[f64],
338    mask: &[u8],
339    sigma: [f64; 3],
340    echoes: [usize; 2],
341    nx: usize, ny: usize, nz: usize,
342) -> (Vec<Vec<f64>>, Vec<f64>) {
343    let n_echoes = phases.len();
344    let n_total = nx * ny * nz;
345
346    let e1 = echoes[0];
347    let e2 = echoes[1];
348
349    // ΔTE = TEs[echo2] - TEs[echo1]
350    let delta_te = tes[e2] - tes[e1];
351
352    // Compute HIP between the two echoes
353    // HIP = conj(echo1) * echo2, so hip_phase = phase2 - phase1
354    let (hip_phase, hip_mag) = hermitian_inner_product(
355        phases[e1].as_ref(), mags[e1].as_ref(),
356        phases[e2].as_ref(), mags[e2].as_ref(),
357        mask, n_total
358    );
359
360    // Weight for ROMEO = sqrt(|HIP|) - matches Julia: weight = sqrt.(abs.(hip))
361    let weight: Vec<f64> = hip_mag.iter().map(|&x| x.sqrt()).collect();
362    drop(hip_mag); // Free ~82 MB early
363
364    // Unwrap HIP phase using ROMEO (matching Julia line 48)
365    // Julia: phaseevolution = (TEs[echoes[1]] / ΔTE) .* romeo(angle.(hip); mag=weight, mask)
366    let unwrapped_hip = unwrap_with_romeo(&hip_phase, &weight, mask, nx, ny, nz);
367    drop(hip_phase); // Free ~82 MB early
368    drop(weight);    // Free ~82 MB early
369
370    // Phase evolution at TE1: (TE1 / ΔTE) * unwrapped_hip
371    // This gives the phase that would have evolved from TE=0 to TE=TE1
372    let scale = tes[e1] / delta_te;
373    let mut phase_offset = vec![0.0; n_total];
374    for i in 0..n_total {
375        if mask[i] > 0 {
376            // Phase offset = phase[echo1] - phase_evolution
377            // phase_evolution = scale * unwrapped_hip
378            // IMPORTANT: Do NOT wrap here! Julia line 49 does raw subtraction
379            phase_offset[i] = phases[e1].as_ref()[i] - scale * unwrapped_hip[i];
380        }
381    }
382    drop(unwrapped_hip); // Free ~82 MB early
383
384    // Smooth the phase offset (handles wrapping via complex representation)
385    // Julia line 51: po[:,:,:,icha] .= gaussiansmooth3d_phase(view(po,:,:,:,icha), sigma; mask)
386    let phase_offset_smoothed = gaussian_smooth_3d_phase(&phase_offset, sigma, mask, nx, ny, nz);
387    drop(phase_offset); // Free ~82 MB early
388
389    // Remove phase offset from all echoes
390    // Julia combinewithPO does: exp.(1im .* (phase - po)) then angle()
391    // This is equivalent to wrap_to_pi(phase - po)
392    let mut corrected_phases = Vec::with_capacity(n_echoes);
393    for e in 0..n_echoes {
394        let mut corrected = vec![0.0; n_total];
395        for i in 0..n_total {
396            if mask[i] > 0 {
397                corrected[i] = wrap_to_pi(phases[e].as_ref()[i] - phase_offset_smoothed[i]);
398            }
399        }
400        corrected_phases.push(corrected);
401    }
402
403    (corrected_phases, phase_offset_smoothed)
404}
405
406/// Unwrap phase using ROMEO algorithm
407fn unwrap_with_romeo(
408    phase: &[f64],
409    mag: &[f64],
410    mask: &[u8],
411    nx: usize, ny: usize, nz: usize,
412) -> Vec<f64> {
413    // Calculate ROMEO weights (no second echo)
414    let weights = calculate_weights_romeo(
415        phase, mag, None, // No second echo for single phase
416        0.0, 0.0, // TEs not used when phase2 is None
417        mask, nx, ny, nz
418    );
419
420    // Find seed point (center of mass of mask)
421    let (seed_i, seed_j, seed_k) = find_seed_point(mask, nx, ny, nz);
422
423    // Perform region growing unwrap
424    let mut unwrapped = phase.to_vec();
425    let mut work_mask = mask.to_vec();
426
427    grow_region_unwrap(
428        &mut unwrapped, &weights, &mut work_mask,
429        nx, ny, nz, seed_i, seed_j, seed_k
430    );
431
432    unwrapped
433}
434
435/// Find a good seed point (center of mass of the mask)
436fn find_seed_point(mask: &[u8], nx: usize, ny: usize, nz: usize) -> (usize, usize, usize) {
437    let mut sum_i = 0usize;
438    let mut sum_j = 0usize;
439    let mut sum_k = 0usize;
440    let mut count = 0usize;
441
442    for k in 0..nz {
443        for j in 0..ny {
444            for i in 0..nx {
445                let idx = idx3d(i, j, k, nx, ny);
446                if mask[idx] > 0 {
447                    sum_i += i;
448                    sum_j += j;
449                    sum_k += k;
450                    count += 1;
451                }
452            }
453        }
454    }
455
456    if count == 0 {
457        return (nx / 2, ny / 2, nz / 2);
458    }
459
460    (sum_i / count, sum_j / count, sum_k / count)
461}
462
463/// Calculate B0 field from unwrapped phase using weighted averaging
464///
465/// Implements calculateB0_unwrapped from MriResearchTools.jl
466///
467/// Formula: B0 = (1000 / 2π) * Σ(phase / TE * weight) / Σ(weight)
468///
469/// # Arguments
470/// * `unwrapped_phases` - Unwrapped phase for each echo [n_echoes][nx*ny*nz]
471/// * `mags` - Magnitude for each echo (used for some weighting types)
472/// * `tes` - Echo times in ms
473/// * `mask` - Binary mask
474/// * `weight_type` - Type of weighting to use
475/// * `n_total` - Total number of voxels
476///
477/// # Returns
478/// B0 field in Hz
479pub fn calculate_b0_weighted(
480    unwrapped_phases: &[impl AsRef<[f64]>],
481    mags: &[impl AsRef<[f64]>],
482    tes: &[f64],
483    mask: &[u8],
484    weight_type: B0WeightType,
485    n_total: usize,
486) -> Vec<f64> {
487    let n_echoes = unwrapped_phases.len();
488    let mut b0 = vec![0.0; n_total];
489
490    // Compute inline to avoid allocating per-echo weight arrays
491
492    // B0 = (1000 / 2π) * Σ(phase / TE * weight) / Σ(weight)
493    let scale = 1000.0 / TWO_PI;
494
495    for i in 0..n_total {
496        if mask[i] == 0 {
497            continue;
498        }
499
500        let mut weighted_sum = 0.0;
501        let mut weight_sum = 0.0;
502
503        for e in 0..n_echoes {
504            let te = tes[e];
505            let mag_val = mags[e].as_ref()[i];
506            let phase_over_te = unwrapped_phases[e].as_ref()[i] / te;
507
508            let w = match weight_type {
509                B0WeightType::PhaseSNR => mag_val * te,
510                B0WeightType::PhaseVar => mag_val * mag_val * te * te,
511                B0WeightType::Average => 1.0,
512                B0WeightType::TEs => te,
513                B0WeightType::Mag => mag_val,
514            };
515
516            weighted_sum += phase_over_te * w;
517            weight_sum += w;
518        }
519
520        if weight_sum > 1e-10 {
521            b0[i] = scale * weighted_sum / weight_sum;
522        }
523    }
524
525    b0
526}
527
528/// Full MCPC-3D-S + B0 calculation pipeline
529///
530/// This combines phase offset removal with weighted B0 calculation
531///
532/// # Arguments
533/// * `phases` - Wrapped phase for each echo
534/// * `mags` - Magnitude for each echo
535/// * `tes` - Echo times in ms
536/// * `mask` - Binary mask
537/// * `sigma` - Smoothing sigma for phase offset [sx, sy, sz]
538/// * `weight_type` - B0 weighting type
539/// * `nx`, `ny`, `nz` - Dimensions
540///
541/// # Returns
542/// (b0_hz, phase_offset, corrected_phases)
543pub fn mcpc3ds_b0_pipeline(
544    phases: &[impl AsRef<[f64]>],
545    mags: &[impl AsRef<[f64]>],
546    tes: &[f64],
547    mask: &[u8],
548    sigma: [f64; 3],
549    weight_type: B0WeightType,
550    nx: usize, ny: usize, nz: usize,
551) -> (Vec<f64>, Vec<f64>, Vec<Vec<f64>>) {
552    let n_total = nx * ny * nz;
553    let n_echoes = phases.len();
554
555    // Step 1: MCPC-3D-S to remove phase offset
556    let (corrected_phases, phase_offset) = mcpc3ds_single_coil(
557        phases, mags, tes, mask,
558        sigma, [0, 1], // use first two echoes
559        nx, ny, nz
560    );
561
562    // Step 2: Unwrap the corrected phases using ROMEO
563    // Each echo needs to be unwrapped independently
564    let mut unwrapped_phases = Vec::with_capacity(n_echoes);
565    for e in 0..n_echoes {
566        let unwrapped = unwrap_with_romeo(&corrected_phases[e], mags[e].as_ref(), mask, nx, ny, nz);
567        unwrapped_phases.push(unwrapped);
568    }
569
570    // Step 3: Align echoes to remove 2π ambiguities
571    // Use first echo as reference
572    for e in 1..n_echoes {
573        let te_ratio = tes[e] / tes[0];
574
575        // Calculate mean difference
576        let mut sum_diff = 0.0;
577        let mut count = 0;
578        for i in 0..n_total {
579            if mask[i] > 0 {
580                let expected = unwrapped_phases[0][i] * te_ratio;
581                sum_diff += unwrapped_phases[e][i] - expected;
582                count += 1;
583            }
584        }
585
586        if count > 0 {
587            let mean_diff = sum_diff / count as f64;
588            let correction = (mean_diff / TWO_PI).round() * TWO_PI;
589
590            if correction.abs() > 0.1 {
591                for i in 0..n_total {
592                    if mask[i] > 0 {
593                        unwrapped_phases[e][i] -= correction;
594                    }
595                }
596            }
597        }
598    }
599
600    // Step 4: Calculate B0 with weighted averaging
601    let b0 = calculate_b0_weighted(
602        &unwrapped_phases, mags, tes, mask,
603        weight_type, n_total
604    );
605
606    (b0, phase_offset, corrected_phases)
607}
608
609//=============================================================================
610// Multi-Echo Linear Fit
611//=============================================================================
612
/// Result of multi-echo linear fit
#[derive(Clone, Debug)]
pub struct LinearFitResult {
    /// Field map (fitted slope) in radians per unit of TE. This is rad/s when TEs are given
    /// in seconds; divide by 2π for Hz.
    pub field: Vec<f64>,
    /// Phase offset (intercept) in radians; 0 when the fit is forced through the origin
    pub phase_offset: Vec<f64>,
    /// Fit residual. It is the magnitude-weighted sum of squared residuals, divided by the
    /// sum of weights and multiplied by the number of echoes; 0 where no fit was computed.
    pub fit_residual: Vec<f64>,
    /// Reliability mask (1 = reliable, 0 = unreliable)
    pub reliability_mask: Vec<u8>,
}
624
625/// Multi-echo linear fit with magnitude weighting
626///
627/// Fits a linear model: phase = slope * TE + intercept
628/// using weighted least squares with magnitude as weights.
629///
630/// Based on QSM.jl multi_echo_linear_fit and QSMART echofit.m
631///
632/// # Arguments
633/// * `unwrapped_phases` - Unwrapped phase for each echo [n_echoes][nx*ny*nz]
634/// * `mags` - Magnitude for each echo [n_echoes][nx*ny*nz]
635/// * `tes` - Echo times in seconds
636/// * `mask` - Binary mask
637/// * `estimate_offset` - If true, estimate phase offset (intercept)
638/// * `reliability_threshold_percentile` - Percentile for reliability masking (0-100, 0=disable)
639///
640/// # Returns
641/// LinearFitResult containing field, phase_offset, fit_residual, reliability_mask
642pub fn multi_echo_linear_fit(
643    unwrapped_phases: &[impl AsRef<[f64]>],
644    mags: &[impl AsRef<[f64]>],
645    tes: &[f64],
646    mask: &[u8],
647    estimate_offset: bool,
648    reliability_threshold_percentile: f64,
649) -> LinearFitResult {
650    let n_echoes = unwrapped_phases.len();
651    let n_total = unwrapped_phases[0].as_ref().len();
652
653    let mut field = vec![0.0; n_total];
654    let mut phase_offset = vec![0.0; n_total];
655    let mut fit_residual = vec![0.0; n_total];
656
657    if estimate_offset {
658        // Weighted linear fit with intercept: phase = α + β * TE
659        // Using centered data approach for numerical stability
660        //
661        // β = Σ w*(TE - TE_mean)*(phase - phase_mean) / Σ w*(TE - TE_mean)²
662        // α = phase_mean - β * TE_mean (weighted means)
663
664        // Precompute weighted TE mean and sum of squared deviations
665        // (These are per-voxel because weights vary)
666        for v in 0..n_total {
667            if mask[v] == 0 {
668                continue;
669            }
670
671            // Compute weighted means
672            let mut sum_w = 0.0;
673            let mut sum_w_te = 0.0;
674            let mut sum_w_phase = 0.0;
675
676            for e in 0..n_echoes {
677                let w = mags[e].as_ref()[v];
678                sum_w += w;
679                sum_w_te += w * tes[e];
680                sum_w_phase += w * unwrapped_phases[e].as_ref()[v];
681            }
682
683            if sum_w < 1e-10 {
684                continue;
685            }
686
687            let te_mean = sum_w_te / sum_w;
688            let phase_mean = sum_w_phase / sum_w;
689
690            // Compute slope using centered data
691            let mut sum_w_te_centered_sq = 0.0;
692            let mut sum_w_te_centered_phase_centered = 0.0;
693
694            for e in 0..n_echoes {
695                let w = mags[e].as_ref()[v];
696                let te_centered = tes[e] - te_mean;
697                let phase_centered = unwrapped_phases[e].as_ref()[v] - phase_mean;
698                sum_w_te_centered_sq += w * te_centered * te_centered;
699                sum_w_te_centered_phase_centered += w * te_centered * phase_centered;
700            }
701
702            if sum_w_te_centered_sq > 1e-10 {
703                let slope = sum_w_te_centered_phase_centered / sum_w_te_centered_sq;
704                let intercept = phase_mean - slope * te_mean;
705                field[v] = slope;
706                phase_offset[v] = intercept;
707
708                // Compute weighted residual
709                let mut sum_w_resid_sq = 0.0;
710                for e in 0..n_echoes {
711                    let w = mags[e].as_ref()[v];
712                    let predicted = intercept + slope * tes[e];
713                    let diff = unwrapped_phases[e].as_ref()[v] - predicted;
714                    sum_w_resid_sq += w * diff * diff;
715                }
716                // Normalize by sum of weights and number of echoes (matching echofit.m)
717                fit_residual[v] = sum_w_resid_sq / sum_w * n_echoes as f64;
718            }
719        }
720    } else {
721        // Weighted linear fit through origin: phase = β * TE
722        // β = Σ w*TE*phase / Σ w*TE²
723        // (matching echofit.m line 40)
724
725        for v in 0..n_total {
726            if mask[v] == 0 {
727                continue;
728            }
729
730            let mut sum_w_te_phase = 0.0;
731            let mut sum_w_te_sq = 0.0;
732            let mut sum_w = 0.0;
733
734            for e in 0..n_echoes {
735                let w = mags[e].as_ref()[v];
736                let te = tes[e];
737                let phase = unwrapped_phases[e].as_ref()[v];
738                sum_w_te_phase += w * te * phase;
739                sum_w_te_sq += w * te * te;
740                sum_w += w;
741            }
742
743            if sum_w_te_sq > 1e-10 {
744                let slope = sum_w_te_phase / sum_w_te_sq;
745                field[v] = slope;
746
747                // Compute weighted residual
748                let mut sum_w_resid_sq = 0.0;
749                for e in 0..n_echoes {
750                    let w = mags[e].as_ref()[v];
751                    let predicted = slope * tes[e];
752                    let diff = unwrapped_phases[e].as_ref()[v] - predicted;
753                    sum_w_resid_sq += w * diff * diff;
754                }
755                // Normalize by sum of weights and number of echoes
756                if sum_w > 1e-10 {
757                    fit_residual[v] = sum_w_resid_sq / sum_w * n_echoes as f64;
758                }
759            }
760        }
761    }
762
763    // Create reliability mask based on fit residuals
764    let reliability_mask = if reliability_threshold_percentile > 0.0 {
765        compute_reliability_mask(&fit_residual, mask, reliability_threshold_percentile)
766    } else {
767        // All masked voxels are reliable
768        mask.to_vec()
769    };
770
771    LinearFitResult {
772        field,
773        phase_offset,
774        fit_residual,
775        reliability_mask,
776    }
777}
778
/// Compute reliability mask by thresholding fit residuals at a percentile.
///
/// Only in-mask residuals that are strictly positive and finite take part in the percentile
/// estimate. The threshold is the element at rank `floor(p / 100 * len)`, clamped to the
/// last element, of those residuals in ascending order. In-mask voxels whose residual is
/// strictly below the threshold are marked reliable (1); all others get 0. If no residual
/// qualifies, a copy of `mask` is returned.
///
/// NOTE: residuals are thresholded as-is. No spatial (e.g. Gaussian) smoothing is applied
/// here, and this function has no access to the volume dimensions.
fn compute_reliability_mask(
    fit_residual: &[f64],
    mask: &[u8],
    threshold_percentile: f64,
) -> Vec<u8> {
    let n_total = fit_residual.len();

    // Collect positive, finite in-mask residuals for the percentile estimate.
    let mut residuals: Vec<f64> = fit_residual
        .iter()
        .enumerate()
        .filter(|&(i, &r)| mask[i] > 0 && r > 0.0 && r.is_finite())
        .map(|(_, &r)| r)
        .collect();

    if residuals.is_empty() {
        return mask.to_vec();
    }

    let rank = ((threshold_percentile / 100.0) * residuals.len() as f64) as usize;
    let rank = rank.min(residuals.len() - 1);
    // O(n) selection: only the value at `rank` in sorted order is needed, not a full sort.
    let threshold = *residuals
        .select_nth_unstable_by(rank, |a, b| {
            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
        })
        .1;

    (0..n_total)
        .map(|i| u8::from(mask[i] > 0 && fit_residual[i] < threshold))
        .collect()
}
815
816/// Convert field from rad/s to Hz
817#[inline]
818pub fn field_to_hz(field: &[f64]) -> Vec<f64> {
819    field.iter().map(|&f| f / TWO_PI).collect()
820}
821
822#[cfg(test)]
823mod tests {
824    use super::*;
825
826    #[test]
827    fn test_wrap_to_pi() {
828        assert!((wrap_to_pi(0.0) - 0.0).abs() < 1e-10);
829        assert!((wrap_to_pi(PI) - PI).abs() < 1e-10);
830        assert!((wrap_to_pi(-PI) - (-PI)).abs() < 1e-10);
831        assert!((wrap_to_pi(3.0 * PI) - PI).abs() < 1e-10);
832        assert!((wrap_to_pi(-3.0 * PI) - (-PI)).abs() < 1e-10);
833    }
834
835    #[test]
836    fn test_gaussian_kernel() {
837        let kernel = make_gaussian_kernel(1.0);
838        let sum: f64 = kernel.iter().sum();
839        assert!((sum - 1.0).abs() < 1e-10);
840    }
841
842    #[test]
843    fn test_hip() {
844        let n = 8;
845        let phase1 = vec![0.1; n];
846        let phase2 = vec![0.3; n];
847        let mag1 = vec![1.0; n];
848        let mag2 = vec![1.0; n];
849        let mask = vec![1u8; n];
850
851        let (hip_phase, hip_mag) = hermitian_inner_product(&phase1, &mag1, &phase2, &mag2, &mask, n);
852
853        for i in 0..n {
854            assert!((hip_phase[i] - 0.2).abs() < 1e-10);
855            assert!((hip_mag[i] - 1.0).abs() < 1e-10);
856        }
857    }
858
859    // =========================================================================
860    // Helper to build synthetic multi-echo data on a small 3D grid
861    // =========================================================================
862
863    /// Build synthetic multi-echo phase/magnitude data.
864    ///
865    /// The phase at each voxel is: phase_offset + slope * TE
866    /// where `slope` is a spatially-varying linear ramp along x.
867    /// Magnitude is uniform (1.0) inside the mask.
868    fn make_synthetic_multi_echo(
869        nx: usize, ny: usize, nz: usize,
870        tes: &[f64],
871    ) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<u8>) {
872        let n = nx * ny * nz;
873        let n_echoes = tes.len();
874
875        // Constant phase offset (small, well within [-pi, pi])
876        let phase_offset_val = 0.3;
877        // Slope (rad/ms) as a function of x: gentle ramp so phases stay in [-pi, pi]
878        // Max slope = 0.05 rad/ms at x=nx-1, so max phase ~ 0.3 + 0.05*7*15 = 5.55
879        // which will wrap but that is fine.
880        let slope_scale = 0.05;
881
882        let mask = vec![1u8; n];
883        let mut phases: Vec<Vec<f64>> = Vec::with_capacity(n_echoes);
884        let mut mags: Vec<Vec<f64>> = Vec::with_capacity(n_echoes);
885
886        for e in 0..n_echoes {
887            let mut p = vec![0.0; n];
888            let m = vec![1.0; n]; // uniform magnitude
889            for k in 0..nz {
890                for j in 0..ny {
891                    for i in 0..nx {
892                        let idx = idx3d(i, j, k, nx, ny);
893                        let slope = slope_scale * i as f64;
894                        p[idx] = wrap_to_pi(phase_offset_val + slope * tes[e]);
895                    }
896                }
897            }
898            phases.push(p);
899            mags.push(m);
900        }
901
902        (phases, mags, mask)
903    }
904
905    // =========================================================================
906    // idx3d
907    // =========================================================================
908
909    #[test]
910    fn test_idx3d_basic() {
911        assert_eq!(idx3d(0, 0, 0, 4, 4), 0);
912        assert_eq!(idx3d(1, 0, 0, 4, 4), 1);
913        assert_eq!(idx3d(0, 1, 0, 4, 4), 4);
914        assert_eq!(idx3d(0, 0, 1, 4, 4), 16);
915        assert_eq!(idx3d(3, 3, 3, 4, 4), 63);
916    }
917
918    // =========================================================================
919    // B0WeightType::from_str
920    // =========================================================================
921
922    #[test]
923    fn test_b0_weight_type_from_str() {
924        assert_eq!(B0WeightType::from_str("phase_snr"), B0WeightType::PhaseSNR);
925        assert_eq!(B0WeightType::from_str("phasesnr"), B0WeightType::PhaseSNR);
926        assert_eq!(B0WeightType::from_str("PhaseSNR"), B0WeightType::PhaseSNR);
927        assert_eq!(B0WeightType::from_str("phase_var"), B0WeightType::PhaseVar);
928        assert_eq!(B0WeightType::from_str("phasevar"), B0WeightType::PhaseVar);
929        assert_eq!(B0WeightType::from_str("average"), B0WeightType::Average);
930        assert_eq!(B0WeightType::from_str("uniform"), B0WeightType::Average);
931        assert_eq!(B0WeightType::from_str("tes"), B0WeightType::TEs);
932        assert_eq!(B0WeightType::from_str("te"), B0WeightType::TEs);
933        assert_eq!(B0WeightType::from_str("mag"), B0WeightType::Mag);
934        assert_eq!(B0WeightType::from_str("magnitude"), B0WeightType::Mag);
935        // Unknown string should default to PhaseSNR
936        assert_eq!(B0WeightType::from_str("unknown"), B0WeightType::PhaseSNR);
937    }
938
939    // =========================================================================
940    // gaussian_smooth_3d_phase
941    // =========================================================================
942
943    #[test]
944    fn test_gaussian_smooth_3d_phase_uniform_input() {
945        let (nx, ny, nz) = (8, 8, 8);
946        let n = nx * ny * nz;
947        // Uniform phase should remain (approximately) constant after smoothing
948        let phase = vec![1.0; n];
949        let mask = vec![1u8; n];
950        let sigma = [1.0, 1.0, 1.0];
951
952        let smoothed = gaussian_smooth_3d_phase(&phase, sigma, &mask, nx, ny, nz);
953
954        assert_eq!(smoothed.len(), n);
955        for v in &smoothed {
956            assert!(v.is_finite(), "smoothed value must be finite");
957            assert!((v - 1.0).abs() < 0.05, "uniform phase should remain ~1.0, got {}", v);
958        }
959    }
960
961    #[test]
962    fn test_gaussian_smooth_3d_phase_zero_sigma() {
963        let (nx, ny, nz) = (4, 4, 4);
964        let n = nx * ny * nz;
965        let phase: Vec<f64> = (0..n).map(|i| wrap_to_pi(i as f64 * 0.1)).collect();
966        let mask = vec![1u8; n];
967        let sigma = [0.0, 0.0, 0.0];
968
969        let smoothed = gaussian_smooth_3d_phase(&phase, sigma, &mask, nx, ny, nz);
970
971        // With zero sigma, output should equal input (no smoothing applied)
972        assert_eq!(smoothed.len(), n);
973        for i in 0..n {
974            assert!((smoothed[i] - phase[i]).abs() < 1e-10,
975                "zero-sigma smoothing should be identity, voxel {}: got {} expected {}",
976                i, smoothed[i], phase[i]);
977        }
978    }
979
980    #[test]
981    fn test_gaussian_smooth_3d_phase_masked_zeros() {
982        let (nx, ny, nz) = (8, 8, 8);
983        let n = nx * ny * nz;
984        let phase = vec![0.5; n];
985        let mut mask = vec![1u8; n];
986        // Set half the voxels to 0
987        for i in 0..n / 2 {
988            mask[i] = 0;
989        }
990
991        let sigma = [1.0, 1.0, 1.0];
992        let smoothed = gaussian_smooth_3d_phase(&phase, sigma, &mask, nx, ny, nz);
993
994        assert_eq!(smoothed.len(), n);
995        // Masked-out voxels should remain 0
996        for i in 0..n / 2 {
997            assert_eq!(smoothed[i], 0.0, "masked-out voxel {} should be 0", i);
998        }
999        // Masked-in voxels should be finite
1000        for i in n / 2..n {
1001            assert!(smoothed[i].is_finite());
1002        }
1003    }
1004
1005    // =========================================================================
1006    // gaussian_smooth_3d_separable (tested indirectly through phase smoothing
1007    // but let's also exercise multi-axis sigma)
1008    // =========================================================================
1009
1010    #[test]
1011    fn test_gaussian_smooth_anisotropic_sigma() {
1012        let (nx, ny, nz) = (8, 8, 8);
1013        let n = nx * ny * nz;
1014        let phase: Vec<f64> = (0..n).map(|i| wrap_to_pi(0.3 * (i as f64))).collect();
1015        let mask = vec![1u8; n];
1016        let sigma = [2.0, 0.5, 1.0]; // anisotropic
1017
1018        let smoothed = gaussian_smooth_3d_phase(&phase, sigma, &mask, nx, ny, nz);
1019        assert_eq!(smoothed.len(), n);
1020        for v in &smoothed {
1021            assert!(v.is_finite());
1022            assert!(*v >= -PI && *v <= PI, "smoothed phase should be in [-pi, pi], got {}", v);
1023        }
1024    }
1025
1026    // =========================================================================
1027    // make_gaussian_kernel
1028    // =========================================================================
1029
1030    #[test]
1031    fn test_gaussian_kernel_symmetry() {
1032        let kernel = make_gaussian_kernel(2.0);
1033        let len = kernel.len();
1034        for i in 0..len / 2 {
1035            assert!((kernel[i] - kernel[len - 1 - i]).abs() < 1e-12,
1036                "kernel should be symmetric");
1037        }
1038    }
1039
1040    #[test]
1041    fn test_gaussian_kernel_peak_at_center() {
1042        let kernel = make_gaussian_kernel(1.5);
1043        let center = kernel.len() / 2;
1044        for (i, &v) in kernel.iter().enumerate() {
1045            if i != center {
1046                assert!(v <= kernel[center], "center should be peak");
1047            }
1048        }
1049    }
1050
1051    // =========================================================================
1052    // hermitian_inner_product (additional tests)
1053    // =========================================================================
1054
1055    #[test]
1056    fn test_hip_with_mask() {
1057        let n = 4;
1058        let phase1 = vec![0.5; n];
1059        let phase2 = vec![1.0; n];
1060        let mag1 = vec![2.0; n];
1061        let mag2 = vec![3.0; n];
1062        let mask = vec![1, 0, 1, 0];
1063
1064        let (hip_phase, hip_mag) = hermitian_inner_product(
1065            &phase1, &mag1, &phase2, &mag2, &mask, n
1066        );
1067
1068        // Masked-in voxels
1069        assert!((hip_phase[0] - 0.5).abs() < 1e-10);
1070        assert!((hip_mag[0] - 6.0).abs() < 1e-10);
1071        assert!((hip_phase[2] - 0.5).abs() < 1e-10);
1072        assert!((hip_mag[2] - 6.0).abs() < 1e-10);
1073
1074        // Masked-out voxels
1075        assert_eq!(hip_phase[1], 0.0);
1076        assert_eq!(hip_mag[1], 0.0);
1077        assert_eq!(hip_phase[3], 0.0);
1078        assert_eq!(hip_mag[3], 0.0);
1079    }
1080
1081    #[test]
1082    fn test_hip_wrapping() {
1083        // Test that phase difference wraps correctly
1084        let n = 1;
1085        let phase1 = vec![PI - 0.1];
1086        let phase2 = vec![-PI + 0.1];
1087        let mag1 = vec![1.0];
1088        let mag2 = vec![1.0];
1089        let mask = vec![1u8];
1090
1091        let (hip_phase, _) = hermitian_inner_product(
1092            &phase1, &mag1, &phase2, &mag2, &mask, n
1093        );
1094
1095        // phase2 - phase1 = (-PI + 0.1) - (PI - 0.1) = -2PI + 0.2 -> wraps to 0.2
1096        assert!((hip_phase[0] - 0.2).abs() < 1e-10,
1097            "HIP should wrap phase difference, got {}", hip_phase[0]);
1098    }
1099
1100    // =========================================================================
1101    // find_seed_point
1102    // =========================================================================
1103
1104    #[test]
1105    fn test_find_seed_point_full_mask() {
1106        let (nx, ny, nz) = (8, 8, 8);
1107        let mask = vec![1u8; nx * ny * nz];
1108        let (si, sj, sk) = find_seed_point(&mask, nx, ny, nz);
1109        // Center of mass of a fully-filled cube should be approximately center
1110        assert_eq!(si, 3); // mean of 0..7 = 3.5, integer division = 3
1111        assert_eq!(sj, 3);
1112        assert_eq!(sk, 3);
1113    }
1114
1115    #[test]
1116    fn test_find_seed_point_empty_mask() {
1117        let (nx, ny, nz) = (8, 8, 8);
1118        let mask = vec![0u8; nx * ny * nz];
1119        let (si, sj, sk) = find_seed_point(&mask, nx, ny, nz);
1120        // Fallback: center of volume
1121        assert_eq!(si, 4);
1122        assert_eq!(sj, 4);
1123        assert_eq!(sk, 4);
1124    }
1125
1126    #[test]
1127    fn test_find_seed_point_corner_mask() {
1128        let (nx, ny, nz) = (8, 8, 8);
1129        let mut mask = vec![0u8; nx * ny * nz];
1130        // Only set voxel (0,0,0)
1131        mask[idx3d(0, 0, 0, nx, ny)] = 1;
1132        let (si, sj, sk) = find_seed_point(&mask, nx, ny, nz);
1133        assert_eq!(si, 0);
1134        assert_eq!(sj, 0);
1135        assert_eq!(sk, 0);
1136    }
1137
1138    // =========================================================================
1139    // field_to_hz
1140    // =========================================================================
1141
1142    #[test]
1143    fn test_field_to_hz() {
1144        let field = vec![TWO_PI, -TWO_PI, 0.0, PI];
1145        let hz = field_to_hz(&field);
1146        assert!((hz[0] - 1.0).abs() < 1e-10);
1147        assert!((hz[1] - (-1.0)).abs() < 1e-10);
1148        assert!((hz[2] - 0.0).abs() < 1e-10);
1149        assert!((hz[3] - 0.5).abs() < 1e-10);
1150    }
1151
1152    // =========================================================================
1153    // calculate_b0_weighted
1154    // =========================================================================
1155
1156    #[test]
1157    fn test_calculate_b0_weighted_phase_snr() {
1158        // For a constant slope (rad/ms), all weight types should recover it.
1159        // phase[e] = slope * TE[e], so phase/TE = slope for each echo.
1160        // Weighted average of identical values = same value.
1161        // B0 = (1000 / 2pi) * slope (Hz)
1162        let n = 64;
1163        let tes = [5.0, 10.0, 15.0];
1164        let slope = 0.2; // rad/ms
1165        let mask = vec![1u8; n];
1166
1167        let phases: Vec<Vec<f64>> = tes.iter()
1168            .map(|&te| vec![slope * te; n])
1169            .collect();
1170        let mags: Vec<Vec<f64>> = tes.iter()
1171            .map(|_| vec![1.0; n])
1172            .collect();
1173
1174        let b0 = calculate_b0_weighted(&phases, &mags, &tes, &mask, B0WeightType::PhaseSNR, n);
1175
1176        let expected_hz = 1000.0 / TWO_PI * slope;
1177        assert_eq!(b0.len(), n);
1178        for v in &b0 {
1179            assert!(v.is_finite());
1180            assert!((v - expected_hz).abs() < 1e-8,
1181                "expected {} Hz, got {}", expected_hz, v);
1182        }
1183    }
1184
1185    #[test]
1186    fn test_calculate_b0_weighted_all_weight_types() {
1187        let n = 16;
1188        let tes = [5.0, 10.0, 15.0];
1189        let slope = 0.1;
1190        let mask = vec![1u8; n];
1191
1192        let phases: Vec<Vec<f64>> = tes.iter()
1193            .map(|&te| vec![slope * te; n])
1194            .collect();
1195        let mags: Vec<Vec<f64>> = tes.iter()
1196            .map(|_| vec![2.0; n])
1197            .collect();
1198
1199        let expected_hz = 1000.0 / TWO_PI * slope;
1200
1201        for wt in &[
1202            B0WeightType::PhaseSNR,
1203            B0WeightType::PhaseVar,
1204            B0WeightType::Average,
1205            B0WeightType::TEs,
1206            B0WeightType::Mag,
1207        ] {
1208            let b0 = calculate_b0_weighted(&phases, &mags, &tes, &mask, *wt, n);
1209            assert_eq!(b0.len(), n);
1210            for v in &b0 {
1211                assert!(v.is_finite(), "weight type {:?} produced non-finite", wt);
1212                assert!((v - expected_hz).abs() < 1e-8,
1213                    "weight type {:?}: expected {} Hz, got {}", wt, expected_hz, v);
1214            }
1215        }
1216    }
1217
1218    #[test]
1219    fn test_calculate_b0_weighted_masked_out() {
1220        let n = 8;
1221        let tes = [5.0, 10.0];
1222        let mask = vec![0u8; n]; // all masked out
1223
1224        let phases: Vec<Vec<f64>> = tes.iter()
1225            .map(|&te| vec![0.5 * te; n])
1226            .collect();
1227        let mags: Vec<Vec<f64>> = tes.iter()
1228            .map(|_| vec![1.0; n])
1229            .collect();
1230
1231        let b0 = calculate_b0_weighted(&phases, &mags, &tes, &mask, B0WeightType::PhaseSNR, n);
1232
1233        for v in &b0 {
1234            assert_eq!(*v, 0.0, "masked-out voxels should have B0=0");
1235        }
1236    }
1237
1238    #[test]
1239    fn test_calculate_b0_weighted_zero_magnitude() {
1240        // When magnitude is zero, weight is zero; result should be 0
1241        let n = 4;
1242        let tes = [5.0, 10.0, 15.0];
1243        let mask = vec![1u8; n];
1244
1245        let phases: Vec<Vec<f64>> = tes.iter()
1246            .map(|&te| vec![0.2 * te; n])
1247            .collect();
1248        let mags: Vec<Vec<f64>> = tes.iter()
1249            .map(|_| vec![0.0; n]) // zero magnitude
1250            .collect();
1251
1252        // PhaseSNR weight = mag * te = 0
1253        let b0 = calculate_b0_weighted(&phases, &mags, &tes, &mask, B0WeightType::PhaseSNR, n);
1254        for v in &b0 {
1255            assert_eq!(*v, 0.0, "zero-magnitude voxels should yield B0=0");
1256        }
1257
1258        // Average weight = 1.0, should still work
1259        let b0_avg = calculate_b0_weighted(&phases, &mags, &tes, &mask, B0WeightType::Average, n);
1260        let expected = 1000.0 / TWO_PI * 0.2;
1261        for v in &b0_avg {
1262            assert!((v - expected).abs() < 1e-8);
1263        }
1264    }
1265
1266    // =========================================================================
1267    // multi_echo_linear_fit
1268    // =========================================================================
1269
1270    #[test]
1271    fn test_multi_echo_linear_fit_no_offset() {
1272        // phase = slope * TE (no intercept)
1273        // Should recover the slope exactly.
1274        let n = 32;
1275        let tes = [0.005, 0.010, 0.015]; // in seconds
1276        let slope = 100.0; // rad/s
1277        let mask = vec![1u8; n];
1278
1279        let phases: Vec<Vec<f64>> = tes.iter()
1280            .map(|&te| vec![slope * te; n])
1281            .collect();
1282        let mags: Vec<Vec<f64>> = tes.iter()
1283            .map(|_| vec![1.0; n])
1284            .collect();
1285
1286        let result = multi_echo_linear_fit(
1287            &phases, &mags, &tes, &mask,
1288            false, // no offset estimation
1289            0.0,   // no reliability threshold
1290        );
1291
1292        assert_eq!(result.field.len(), n);
1293        assert_eq!(result.phase_offset.len(), n);
1294        assert_eq!(result.fit_residual.len(), n);
1295        assert_eq!(result.reliability_mask.len(), n);
1296
1297        for i in 0..n {
1298            assert!((result.field[i] - slope).abs() < 1e-6,
1299                "slope: expected {}, got {}", slope, result.field[i]);
1300            assert_eq!(result.phase_offset[i], 0.0,
1301                "offset should be 0 when estimate_offset=false");
1302            assert!(result.fit_residual[i] < 1e-10,
1303                "residual should be ~0 for perfect linear data");
1304            assert_eq!(result.reliability_mask[i], 1,
1305                "reliability should match mask when threshold=0");
1306        }
1307    }
1308
1309    #[test]
1310    fn test_multi_echo_linear_fit_with_offset() {
1311        // phase = intercept + slope * TE
1312        let n = 16;
1313        let tes = [0.005, 0.010, 0.015, 0.020];
1314        let slope = 200.0;     // rad/s
1315        let intercept = 0.5;   // rad
1316        let mask = vec![1u8; n];
1317
1318        let phases: Vec<Vec<f64>> = tes.iter()
1319            .map(|&te| vec![intercept + slope * te; n])
1320            .collect();
1321        let mags: Vec<Vec<f64>> = tes.iter()
1322            .map(|_| vec![1.0; n])
1323            .collect();
1324
1325        let result = multi_echo_linear_fit(
1326            &phases, &mags, &tes, &mask,
1327            true, // estimate offset
1328            0.0,
1329        );
1330
1331        for i in 0..n {
1332            assert!((result.field[i] - slope).abs() < 1e-4,
1333                "slope: expected {}, got {}", slope, result.field[i]);
1334            assert!((result.phase_offset[i] - intercept).abs() < 1e-4,
1335                "intercept: expected {}, got {}", intercept, result.phase_offset[i]);
1336            assert!(result.fit_residual[i] < 1e-8,
1337                "residual should be ~0 for perfect linear data, got {}", result.fit_residual[i]);
1338        }
1339    }
1340
1341    #[test]
1342    fn test_multi_echo_linear_fit_masked_out() {
1343        let n = 8;
1344        let tes = [0.005, 0.010, 0.015];
1345        let mask = vec![0u8; n];
1346
1347        let phases: Vec<Vec<f64>> = tes.iter()
1348            .map(|&te| vec![100.0 * te; n])
1349            .collect();
1350        let mags: Vec<Vec<f64>> = tes.iter()
1351            .map(|_| vec![1.0; n])
1352            .collect();
1353
1354        let result = multi_echo_linear_fit(&phases, &mags, &tes, &mask, true, 0.0);
1355
1356        for i in 0..n {
1357            assert_eq!(result.field[i], 0.0);
1358            assert_eq!(result.phase_offset[i], 0.0);
1359            assert_eq!(result.fit_residual[i], 0.0);
1360        }
1361    }
1362
1363    #[test]
1364    fn test_multi_echo_linear_fit_varying_slope() {
1365        // Each voxel has a different slope
1366        let n = 8;
1367        let tes = [0.005, 0.010, 0.015];
1368        let mask = vec![1u8; n];
1369
1370        let slopes: Vec<f64> = (0..n).map(|i| 50.0 * (i as f64 + 1.0)).collect();
1371
1372        let phases: Vec<Vec<f64>> = tes.iter()
1373            .map(|&te| {
1374                slopes.iter().map(|&s| s * te).collect()
1375            })
1376            .collect();
1377        let mags: Vec<Vec<f64>> = tes.iter()
1378            .map(|_| vec![1.0; n])
1379            .collect();
1380
1381        let result = multi_echo_linear_fit(&phases, &mags, &tes, &mask, false, 0.0);
1382
1383        for i in 0..n {
1384            assert!((result.field[i] - slopes[i]).abs() < 1e-6,
1385                "voxel {}: expected slope {}, got {}", i, slopes[i], result.field[i]);
1386        }
1387    }
1388
1389    #[test]
1390    fn test_multi_echo_linear_fit_with_reliability_threshold() {
1391        // Create data where some voxels have noisy fits
1392        let n = 100;
1393        let tes = [0.005, 0.010, 0.015];
1394        let mask = vec![1u8; n];
1395        let slope = 100.0;
1396
1397        let mut phases: Vec<Vec<f64>> = tes.iter()
1398            .map(|&te| vec![slope * te; n])
1399            .collect();
1400        let mags: Vec<Vec<f64>> = tes.iter()
1401            .map(|_| vec![1.0; n])
1402            .collect();
1403
1404        // Add large noise to last 10 voxels to increase their residuals
1405        for e in 0..tes.len() {
1406            for i in 90..100 {
1407                phases[e][i] += if e % 2 == 0 { 2.0 } else { -2.0 };
1408            }
1409        }
1410
1411        // Use 80th percentile threshold
1412        let result = multi_echo_linear_fit(&phases, &mags, &tes, &mask, false, 80.0);
1413
1414        assert_eq!(result.reliability_mask.len(), n);
1415
1416        // Clean voxels (0..90) have residual=0, noisy voxels (90..100) have residual>0.
1417        // The threshold is computed from non-zero residuals only.
1418        // Voxels with residual=0 satisfy 0 < threshold, but compute_reliability_mask
1419        // checks `fit_residual[i] < threshold` -- 0 < any positive threshold => reliable=1.
1420        // However, residual might not be exactly 0 due to floating point.
1421        // Just verify that the noisy voxels have higher residuals than clean ones.
1422        let max_clean_resid = result.fit_residual[0..90].iter()
1423            .cloned().fold(0.0f64, f64::max);
1424        let min_noisy_resid = result.fit_residual[90..100].iter()
1425            .cloned().fold(f64::INFINITY, f64::min);
1426        assert!(max_clean_resid < min_noisy_resid,
1427            "clean residuals ({}) should be less than noisy residuals ({})",
1428            max_clean_resid, min_noisy_resid);
1429
1430        // The reliability mask should exist and have valid values
1431        for &v in &result.reliability_mask {
1432            assert!(v == 0 || v == 1);
1433        }
1434    }
1435
1436    #[test]
1437    fn test_multi_echo_linear_fit_zero_magnitude() {
1438        let n = 4;
1439        let tes = [0.005, 0.010, 0.015];
1440        let mask = vec![1u8; n];
1441
1442        let phases: Vec<Vec<f64>> = tes.iter()
1443            .map(|&te| vec![100.0 * te; n])
1444            .collect();
1445        let mags: Vec<Vec<f64>> = tes.iter()
1446            .map(|_| vec![0.0; n]) // zero magnitude
1447            .collect();
1448
1449        // Should not crash; field should be 0 because sum_w_te_sq ~ 0
1450        let result = multi_echo_linear_fit(&phases, &mags, &tes, &mask, false, 0.0);
1451        for v in &result.field {
1452            assert!(v.is_finite());
1453        }
1454
1455        let result2 = multi_echo_linear_fit(&phases, &mags, &tes, &mask, true, 0.0);
1456        for v in &result2.field {
1457            assert!(v.is_finite());
1458        }
1459    }
1460
1461    // =========================================================================
1462    // compute_reliability_mask
1463    // =========================================================================
1464
1465    #[test]
1466    fn test_compute_reliability_mask_basic() {
1467        let n = 10;
1468        let mask = vec![1u8; n];
1469        // Residuals in ascending order: 0.1, 0.2, ..., 1.0
1470        let fit_residual: Vec<f64> = (1..=n).map(|i| i as f64 * 0.1).collect();
1471
1472        // 50th percentile: threshold ~ 0.5
1473        let reliability = compute_reliability_mask(&fit_residual, &mask, 50.0);
1474        assert_eq!(reliability.len(), n);
1475
1476        // Voxels with residual < threshold should be reliable
1477        let reliable_count: usize = reliability.iter().map(|&v| v as usize).sum();
1478        assert!(reliable_count > 0 && reliable_count < n,
1479            "some but not all should be reliable, got {}/{}", reliable_count, n);
1480    }
1481
1482    #[test]
1483    fn test_compute_reliability_mask_all_zero_residual() {
1484        let n = 5;
1485        let mask = vec![1u8; n];
1486        let fit_residual = vec![0.0; n];
1487
1488        // When all residuals are 0, the filter skips them (r > 0.0 check fails)
1489        // so residuals vec is empty and mask is returned as-is
1490        let reliability = compute_reliability_mask(&fit_residual, &mask, 50.0);
1491        assert_eq!(reliability, mask);
1492    }
1493
1494    // =========================================================================
1495    // mcpc3ds_single_coil
1496    // =========================================================================
1497
1498    #[test]
1499    fn test_mcpc3ds_single_coil_output_sizes() {
1500        let (nx, ny, nz) = (8, 8, 8);
1501        let n = nx * ny * nz;
1502        let tes = [5.0, 10.0, 15.0];
1503        let (phases, mags, mask) = make_synthetic_multi_echo(nx, ny, nz, &tes);
1504
1505        let sigma = [1.0, 1.0, 1.0];
1506        let (corrected, offset) = mcpc3ds_single_coil(
1507            &phases, &mags, &tes, &mask, sigma, [0, 1], nx, ny, nz,
1508        );
1509
1510        // Check sizes
1511        assert_eq!(corrected.len(), tes.len(), "should have one corrected phase per echo");
1512        for (e, cp) in corrected.iter().enumerate() {
1513            assert_eq!(cp.len(), n, "echo {} corrected phase should have {} voxels", e, n);
1514        }
1515        assert_eq!(offset.len(), n, "phase offset should have {} voxels", n);
1516    }
1517
1518    #[test]
1519    fn test_mcpc3ds_single_coil_finite_output() {
1520        let (nx, ny, nz) = (8, 8, 8);
1521        let tes = [5.0, 10.0, 15.0];
1522        let (phases, mags, mask) = make_synthetic_multi_echo(nx, ny, nz, &tes);
1523
1524        let sigma = [1.0, 1.0, 1.0];
1525        let (corrected, offset) = mcpc3ds_single_coil(
1526            &phases, &mags, &tes, &mask, sigma, [0, 1], nx, ny, nz,
1527        );
1528
1529        for v in &offset {
1530            assert!(v.is_finite(), "phase offset should be finite");
1531        }
1532        for cp in &corrected {
1533            for v in cp {
1534                assert!(v.is_finite(), "corrected phase should be finite");
1535            }
1536        }
1537    }
1538
1539    #[test]
1540    fn test_mcpc3ds_single_coil_corrected_in_range() {
1541        let (nx, ny, nz) = (8, 8, 8);
1542        let tes = [5.0, 10.0, 15.0];
1543        let (phases, mags, mask) = make_synthetic_multi_echo(nx, ny, nz, &tes);
1544
1545        let sigma = [1.0, 1.0, 1.0];
1546        let (corrected, _) = mcpc3ds_single_coil(
1547            &phases, &mags, &tes, &mask, sigma, [0, 1], nx, ny, nz,
1548        );
1549
1550        // Corrected phases should be in [-pi, pi] since wrap_to_pi is applied
1551        for cp in &corrected {
1552            for &v in cp {
1553                assert!(v >= -PI - 1e-10 && v <= PI + 1e-10,
1554                    "corrected phase should be in [-pi, pi], got {}", v);
1555            }
1556        }
1557    }
1558
1559    #[test]
1560    fn test_mcpc3ds_single_coil_uniform_phase() {
1561        // Uniform phase across all echoes => offset should be approximately that phase
1562        let (nx, ny, nz) = (8, 8, 8);
1563        let n = nx * ny * nz;
1564        let tes = [5.0, 10.0, 15.0];
1565
1566        // All echoes have constant phase 0.5 (no TE dependence)
1567        let phases: Vec<Vec<f64>> = (0..3).map(|_| vec![0.5; n]).collect();
1568        let mags: Vec<Vec<f64>> = (0..3).map(|_| vec![1.0; n]).collect();
1569        let mask = vec![1u8; n];
1570
1571        let sigma = [1.0, 1.0, 1.0];
1572        let (corrected, _offset) = mcpc3ds_single_coil(
1573            &phases, &mags, &tes, &mask, sigma, [0, 1], nx, ny, nz,
1574        );
1575
1576        // After removing offset, corrected phases should be close to 0
1577        for cp in &corrected {
1578            for &v in cp {
1579                assert!(v.abs() < 1.0,
1580                    "after offset removal of uniform phase, corrected should be ~0, got {}", v);
1581            }
1582        }
1583    }
1584
1585    // =========================================================================
1586    // mcpc3ds_b0_pipeline
1587    // =========================================================================
1588
1589    #[test]
1590    fn test_mcpc3ds_b0_pipeline_output_sizes() {
1591        let (nx, ny, nz) = (8, 8, 8);
1592        let n = nx * ny * nz;
1593        let tes = [5.0, 10.0, 15.0];
1594        let (phases, mags, mask) = make_synthetic_multi_echo(nx, ny, nz, &tes);
1595
1596        let sigma = [1.0, 1.0, 1.0];
1597        let (b0, offset, corrected) = mcpc3ds_b0_pipeline(
1598            &phases, &mags, &tes, &mask, sigma,
1599            B0WeightType::PhaseSNR, nx, ny, nz,
1600        );
1601
1602        assert_eq!(b0.len(), n, "B0 should have n voxels");
1603        assert_eq!(offset.len(), n, "offset should have n voxels");
1604        assert_eq!(corrected.len(), tes.len(), "corrected should have n_echoes entries");
1605        for cp in &corrected {
1606            assert_eq!(cp.len(), n);
1607        }
1608    }
1609
1610    #[test]
1611    fn test_mcpc3ds_b0_pipeline_finite_output() {
1612        let (nx, ny, nz) = (8, 8, 8);
1613        let tes = [5.0, 10.0, 15.0];
1614        let (phases, mags, mask) = make_synthetic_multi_echo(nx, ny, nz, &tes);
1615
1616        let sigma = [1.0, 1.0, 1.0];
1617        let (b0, offset, corrected) = mcpc3ds_b0_pipeline(
1618            &phases, &mags, &tes, &mask, sigma,
1619            B0WeightType::PhaseSNR, nx, ny, nz,
1620        );
1621
1622        for v in &b0 {
1623            assert!(v.is_finite(), "B0 should be finite");
1624        }
1625        for v in &offset {
1626            assert!(v.is_finite(), "offset should be finite");
1627        }
1628        for cp in &corrected {
1629            for v in cp {
1630                assert!(v.is_finite(), "corrected phase should be finite");
1631            }
1632        }
1633    }
1634
1635    #[test]
1636    fn test_mcpc3ds_b0_pipeline_different_weight_types() {
1637        let (nx, ny, nz) = (8, 8, 8);
1638        let tes = [5.0, 10.0, 15.0];
1639        let (phases, mags, mask) = make_synthetic_multi_echo(nx, ny, nz, &tes);
1640        let sigma = [1.0, 1.0, 1.0];
1641
1642        for wt in &[
1643            B0WeightType::PhaseSNR,
1644            B0WeightType::Average,
1645            B0WeightType::TEs,
1646            B0WeightType::Mag,
1647            B0WeightType::PhaseVar,
1648        ] {
1649            let (b0, _, _) = mcpc3ds_b0_pipeline(
1650                &phases, &mags, &tes, &mask, sigma, *wt, nx, ny, nz,
1651            );
1652            for v in &b0 {
1653                assert!(v.is_finite(),
1654                    "B0 with weight type {:?} should be finite", wt);
1655            }
1656        }
1657    }
1658
1659    // =========================================================================
1660    // wrap_to_pi edge cases
1661    // =========================================================================
1662
1663    #[test]
1664    fn test_wrap_to_pi_near_boundaries() {
1665        // Values just beyond PI and -PI
1666        let v1 = wrap_to_pi(PI + 0.001);
1667        assert!(v1 < PI && v1 > -PI, "should wrap back into range");
1668
1669        let v2 = wrap_to_pi(-PI - 0.001);
1670        assert!(v2 > -PI && v2 < PI, "should wrap back into range");
1671
1672        // Large positive and negative values
1673        let v3 = wrap_to_pi(100.0 * PI);
1674        assert!(v3 >= -PI && v3 <= PI, "should be in [-pi, pi], got {}", v3);
1675
1676        let v4 = wrap_to_pi(-100.0 * PI);
1677        assert!(v4 >= -PI && v4 <= PI, "should be in [-pi, pi], got {}", v4);
1678    }
1679
1680    // =========================================================================
1681    // unwrap_with_romeo (tested indirectly through mcpc3ds_single_coil
1682    // but let's also test directly)
1683    // =========================================================================
1684
1685    #[test]
1686    fn test_unwrap_with_romeo_smooth_data() {
1687        let (nx, ny, nz) = (8, 8, 8);
1688        let n = nx * ny * nz;
1689        // Smooth phase that doesn't need unwrapping
1690        let phase: Vec<f64> = (0..n).map(|i| {
1691            let x = (i % nx) as f64 / nx as f64;
1692            0.5 * x // small smooth phase
1693        }).collect();
1694        let mag = vec![1.0; n];
1695        let mask = vec![1u8; n];
1696
1697        let unwrapped = unwrap_with_romeo(&phase, &mag, &mask, nx, ny, nz);
1698
1699        assert_eq!(unwrapped.len(), n);
1700        for (i, &v) in unwrapped.iter().enumerate() {
1701            assert!(v.is_finite(), "unwrapped voxel {} should be finite", i);
1702        }
1703    }
1704
1705    // =========================================================================
1706    // Integration: linear fit on mcpc3ds output
1707    // =========================================================================
1708
1709    #[test]
1710    fn test_linear_fit_on_mcpc3ds_output() {
1711        let (nx, ny, nz) = (8, 8, 8);
1712        let n = nx * ny * nz;
1713        let tes = [5.0, 10.0, 15.0];
1714        let (phases, mags, mask) = make_synthetic_multi_echo(nx, ny, nz, &tes);
1715
1716        let sigma = [1.0, 1.0, 1.0];
1717        let (corrected, _offset) = mcpc3ds_single_coil(
1718            &phases, &mags, &tes, &mask, sigma, [0, 1], nx, ny, nz,
1719        );
1720
1721        // Run linear fit on corrected phases (tes in seconds for fit)
1722        let tes_s: Vec<f64> = tes.iter().map(|&t| t / 1000.0).collect();
1723        let result = multi_echo_linear_fit(
1724            &corrected, &mags, &tes_s, &mask, true, 0.0,
1725        );
1726
1727        assert_eq!(result.field.len(), n);
1728        assert_eq!(result.phase_offset.len(), n);
1729        assert_eq!(result.fit_residual.len(), n);
1730        assert_eq!(result.reliability_mask.len(), n);
1731
1732        for v in &result.field {
1733            assert!(v.is_finite(), "field should be finite");
1734        }
1735        for v in &result.phase_offset {
1736            assert!(v.is_finite(), "phase_offset should be finite");
1737        }
1738    }
1739}