// qsm_core/inversion/medi.rs

//! MEDI (Morphology Enabled Dipole Inversion) L1 regularization
//!
//! Gauss-Newton optimization with L1 TV regularization and
//! morphology-based gradient weighting from magnitude images.
//!
//! Features:
//! - **Per-direction gradient masks** (mx, my, mz) matching the original MEDI formulation
//! - Adaptive edge detection with configurable percentage threshold (default: 30% edges)
//! - SNR-based data weighting using noise standard deviation maps
//! - Optional SMV (Spherical Mean Value) preprocessing
//! - Optional merit-based outlier adjustment (MERIT)
//! - **Optimized with f32 single precision for WASM performance**
//! - **Buffer reuse to minimize allocations**
//! - **Standard CG convergence (relative tolerance)**
//! - **Linear extrapolation boundary conditions** matching MATLAB's gradf
//!
//! Reference:
//! Liu, T., Liu, J., de Rochefort, L., Spincemaille, P., Khalidov, I., Ledoux, J.R.,
//! Wang, Y. (2011). "Morphology enabled dipole inversion (MEDI) from a single-angle
//! acquisition: comparison with COSMOS in human brain imaging."
//! Magnetic Resonance in Medicine, 66(3):777-783. https://doi.org/10.1002/mrm.22816
//!
//! Liu, J., Liu, T., de Rochefort, L., Ledoux, J., Khalidov, I., Chen, W., Tsiouris, A.J.,
//! Wisnieff, C., Spincemaille, P., Prince, M.R., Wang, Y. (2012).
//! "Morphology enabled dipole inversion for quantitative susceptibility mapping using
//! structural consistency between the magnitude image and the susceptibility map."
//! NeuroImage, 59(3):2560-2568.
//!
//! Reference implementation: https://github.com/huawu02/MEDI_toolbox

31use num_complex::Complex32;
32use crate::fft::Fft3dWorkspaceF32;
33use crate::kernels::dipole::dipole_kernel_f32;
34use crate::kernels::smv::smv_kernel_f32;
35use crate::utils::simd_ops::{
36    dot_product_f32, norm_squared_f32, axpy_f32, xpby_f32,
37    apply_gradient_weights_f32, compute_p_weights_f32, combine_terms_f32, negate_f32,
38};
// Note: Uses fgrad_periodic_inplace_f32 / bdiv_periodic_inplace_f32 (periodic BCs)
// for the MEDI inner loop matching MATLAB's gradfp_mex / gradfp_adj_mex.
// Uses fgrad_linext_inplace_f32 (linear extrapolation BCs) only for gradient mask
// computation, matching MATLAB's gradf_mex.
43
/// MEDI algorithm parameters
///
/// Defaults (see the `Default` impl below) match the MATLAB MEDI toolbox
/// settings documented on `medi_l1`.
#[derive(Clone, Debug)]
pub struct MediParams {
    /// Regularization weight balancing the TV term against data fidelity (default: 7.5e-5)
    pub lambda: f64,
    /// Enable MERIT (outlier adjustment) (default: false)
    pub merit: bool,
    /// Enable SMV preprocessing (default: true)
    pub smv: bool,
    /// SMV radius in mm (default: 5.0)
    pub smv_radius: f64,
    /// Data weighting mode (1 = SNR, 0 = uniform; default: 1)
    pub data_weighting: i32,
    /// Fraction of voxels considered edges (0.0-1.0, default: 0.3)
    pub percentage: f64,
    /// CG convergence tolerance, relative to ||b|| (default: 0.01)
    pub cg_tol: f64,
    /// Maximum CG iterations per Gauss-Newton step (default: 10)
    pub cg_max_iter: usize,
    /// Maximum outer (Gauss-Newton) iterations (default: 30)
    pub max_iter: usize,
    /// Outer convergence tolerance on relative chi update (default: 0.1)
    pub tol: f64,
}
68
69impl Default for MediParams {
70    fn default() -> Self {
71        Self {
72            lambda: 7.5e-5,
73            merit: false,
74            smv: true,
75            smv_radius: 5.0,
76            data_weighting: 1,
77            percentage: 0.3,
78            cg_tol: 0.01,
79            cg_max_iter: 10,
80            max_iter: 30,
81            tol: 0.1,
82        }
83    }
84}
85
/// Workspace for MEDI operations - holds all reusable buffers (f32 version)
/// Uses single precision for ~2x speedup on WASM
///
/// All buffers are allocated once in `new` and reused across Gauss-Newton
/// and CG iterations to avoid per-iteration allocations.
pub struct MediWorkspace {
    /// Total number of voxels (nx * ny * nz)
    pub n_total: usize,
    /// Grid dimensions
    pub nx: usize,
    pub ny: usize,
    pub nz: usize,
    /// Voxel sizes in mm
    pub vsx: f32,
    pub vsy: f32,
    pub vsz: f32,

    // FFT workspace with cached plans (f32)
    pub fft_ws: Fft3dWorkspaceF32,

    // Gradient buffers (3 components)
    pub gx: Vec<f32>,
    pub gy: Vec<f32>,
    pub gz: Vec<f32>,

    // Weighted gradient buffers (mask * P * mask * gradient)
    pub reg_x: Vec<f32>,
    pub reg_y: Vec<f32>,
    pub reg_z: Vec<f32>,

    // Divergence buffer (regularization term output)
    pub div_buf: Vec<f32>,

    // Complex buffers for FFT operations
    pub complex_buf: Vec<Complex32>,
    pub complex_buf2: Vec<Complex32>,

    // Real buffer for dipole result
    pub dipole_buf: Vec<f32>,

    // CG solver buffers: residual, search direction, A*p
    pub cg_r: Vec<f32>,
    pub cg_p: Vec<f32>,
    pub cg_ap: Vec<f32>,
}
125
126impl MediWorkspace {
127    /// Create a new MEDI workspace for the given dimensions
128    pub fn new(nx: usize, ny: usize, nz: usize, vsx: f32, vsy: f32, vsz: f32) -> Self {
129        let n_total = nx * ny * nz;
130
131        Self {
132            n_total,
133            nx, ny, nz,
134            vsx, vsy, vsz,
135            fft_ws: Fft3dWorkspaceF32::new(nx, ny, nz),
136            gx: vec![0.0; n_total],
137            gy: vec![0.0; n_total],
138            gz: vec![0.0; n_total],
139            reg_x: vec![0.0; n_total],
140            reg_y: vec![0.0; n_total],
141            reg_z: vec![0.0; n_total],
142            div_buf: vec![0.0; n_total],
143            complex_buf: vec![Complex32::new(0.0, 0.0); n_total],
144            complex_buf2: vec![Complex32::new(0.0, 0.0); n_total],
145            dipole_buf: vec![0.0; n_total],
146            cg_r: vec![0.0; n_total],
147            cg_p: vec![0.0; n_total],
148            cg_ap: vec![0.0; n_total],
149        }
150    }
151}
152
/// Apply dipole convolution: out = real(ifft(D * fft(x)))
///
/// Thin wrapper that delegates to the FFT workspace's fused dipole path;
/// `complex_buf` is scratch storage for the forward/inverse transforms.
#[inline]
fn apply_dipole_conv(
    fft_ws: &mut Fft3dWorkspaceF32,
    x: &[f32],
    d_kernel: &[f32],
    out: &mut [f32],
    complex_buf: &mut [Complex32],
) {
    fft_ws.apply_dipole_inplace(x, d_kernel, out, complex_buf);
}
164
/// MEDI operator buffers - separate struct to allow split borrowing
///
/// Wraps mutable views into the `MediWorkspace` buffers so that the
/// operator can be called with `&mut fft_ws` alongside these buffers
/// (a single `&mut MediWorkspace` could not be split across the call).
struct MediOpBuffers<'a> {
    // Gradient components of the current vector
    gx: &'a mut [f32],
    gy: &'a mut [f32],
    gz: &'a mut [f32],
    // Mask/P-weighted gradient components
    reg_x: &'a mut [f32],
    reg_y: &'a mut [f32],
    reg_z: &'a mut [f32],
    // Divergence of the weighted gradient (regularization term)
    div_buf: &'a mut [f32],
    // Real-valued dipole convolution result / scratch
    dipole_buf: &'a mut [f32],
    // Complex FFT scratch buffers
    complex_buf: &'a mut [Complex32],
    complex_buf2: &'a mut [Complex32],
}
178
179/// Apply MEDI operator in-place: out = fidelity(dx) + lambda*reg(dx)
180/// This is the hot path - called many times per Gauss-Newton iteration
181/// Uses per-direction gradient masks (mx, my, mz) matching MATLAB MEDI
182/// SIMD-accelerated for element-wise operations
183#[inline]
184fn apply_medi_operator_core(
185    fft_ws: &mut Fft3dWorkspaceF32,
186    bufs: &mut MediOpBuffers,
187    n: usize,
188    nx: usize, ny: usize, nz: usize,
189    vsx: f32, vsy: f32, vsz: f32,
190    dx: &[f32],
191    w: &[Complex32],
192    d_kernel: &[f32],
193    mx: &[f32],  // Per-direction gradient mask for x
194    my: &[f32],  // Per-direction gradient mask for y
195    mz: &[f32],  // Per-direction gradient mask for z
196    vr: &[f32],
197    lambda: f32,
198    out: &mut [f32],
199) {
200    // 1. Compute gradient of dx (in-place into gx, gy, gz) - periodic BCs matching MATLAB gradfp_mex
201    fgrad_periodic_inplace_f32(bufs.gx, bufs.gy, bufs.gz, dx, nx, ny, nz, vsx, vsy, vsz);
202
203    // 2. Apply per-direction weights: reg_i = m_i * P * m_i * g_i (SIMD accelerated)
204    // MATLAB: ux = mx .* P .* mx .* ux; uy = my .* P .* my .* uy; uz = mz .* P .* mz .* uz;
205    apply_gradient_weights_f32(
206        bufs.reg_x, bufs.reg_y, bufs.reg_z,
207        mx, my, mz, vr,
208        bufs.gx, bufs.gy, bufs.gz,
209    );
210
211    // 3. Compute divergence (in-place into div_buf) - periodic BCs matching MATLAB gradfp_adj_mex
212    bdiv_periodic_inplace_f32(bufs.div_buf, bufs.reg_x, bufs.reg_y, bufs.reg_z, nx, ny, nz, vsx, vsy, vsz);
213
214    // 4. Fidelity term: D^T(|w|^2 * D(dx))
215    apply_dipole_conv(fft_ws, dx, d_kernel, bufs.dipole_buf, bufs.complex_buf);
216
217    // Multiply by |w|^2 and convert to complex
218    for i in 0..n {
219        let w_mag_sq = w[i].norm_sqr();
220        bufs.complex_buf2[i] = Complex32::new(bufs.dipole_buf[i] * w_mag_sq, 0.0);
221    }
222
223    // Apply D^T (which is D for real symmetric kernel)
224    fft_ws.fft3d(bufs.complex_buf2);
225    for i in 0..n {
226        bufs.complex_buf2[i] *= d_kernel[i];
227    }
228    fft_ws.ifft3d(bufs.complex_buf2);
229
230    // 5. Combine: out = lambda*div_buf + real(complex_buf2) (matching MATLAB: y = D + R)
231    // Extract real parts for SIMD operation
232    for i in 0..n {
233        bufs.dipole_buf[i] = bufs.complex_buf2[i].re;
234    }
235    combine_terms_f32(out, bufs.div_buf, bufs.dipole_buf, lambda);
236}
237
/// Conjugate gradient solver with buffer reuse
/// Solves Ax = b where A is the MEDI operator
///
/// The optional progress callback receives (cg_iter, max_iter) for each CG iteration.
/// Uses per-direction gradient masks (mx, my, mz) matching MATLAB MEDI.
///
/// `x` is overwritten with the solution; the iteration always starts from x = 0.
/// Convergence criterion: ||r|| < tol * ||b|| (standard relative tolerance).
#[inline]
fn cg_solve_medi<F>(
    ws: &mut MediWorkspace,
    w: &[Complex32],
    d_kernel: &[f32],
    mx: &[f32],  // Per-direction gradient mask for x
    my: &[f32],  // Per-direction gradient mask for y
    mz: &[f32],  // Per-direction gradient mask for z
    vr: &[f32],
    lambda: f32,
    b: &[f32],
    x: &mut [f32],
    tol: f32,
    max_iter: usize,
    mut progress_callback: F,
) where
    F: FnMut(usize, usize),
{
    let n = ws.n_total;
    let (nx, ny, nz) = (ws.nx, ws.ny, ws.nz);
    let (vsx, vsy, vsz) = (ws.vsx, ws.vsy, ws.vsz);

    // Initialize x to zero
    x.fill(0.0);

    // r = b - A*x = b (since x=0)
    ws.cg_r.copy_from_slice(b);

    // p = r
    ws.cg_p.copy_from_slice(&ws.cg_r);

    // rsold = r·r (SIMD accelerated)
    let mut rsold: f32 = norm_squared_f32(&ws.cg_r);

    // b_norm for relative tolerance (SIMD accelerated)
    let b_norm: f32 = norm_squared_f32(b).sqrt();
    if b_norm < 1e-10 {
        return; // b is zero, x=0 is the solution
    }

    // Buffer for p (to avoid borrow conflict)
    // NOTE(review): allocated per solve; could live in MediWorkspace to keep the
    // "buffer reuse" promise from the module docs — confirm before changing.
    let mut p_copy = vec![0.0f32; n];

    for cg_iter in 0..max_iter {
        // Report CG progress
        progress_callback(cg_iter + 1, max_iter);

        // Copy p to avoid borrow conflict
        p_copy.copy_from_slice(&ws.cg_p);

        // ap = A*p - use split borrowing
        {
            let mut bufs = MediOpBuffers {
                gx: &mut ws.gx,
                gy: &mut ws.gy,
                gz: &mut ws.gz,
                reg_x: &mut ws.reg_x,
                reg_y: &mut ws.reg_y,
                reg_z: &mut ws.reg_z,
                div_buf: &mut ws.div_buf,
                dipole_buf: &mut ws.dipole_buf,
                complex_buf: &mut ws.complex_buf,
                complex_buf2: &mut ws.complex_buf2,
            };
            apply_medi_operator_core(
                &mut ws.fft_ws, &mut bufs, n, nx, ny, nz, vsx, vsy, vsz,
                &p_copy, w, d_kernel, mx, my, mz, vr, lambda, &mut ws.cg_ap
            );
        }

        // pap = p·ap (SIMD accelerated)
        let pap: f32 = dot_product_f32(&ws.cg_p, &ws.cg_ap);

        // Guard against breakdown (p is A-orthogonal to itself / numerically dead)
        if pap.abs() < 1e-15 {
            break;
        }

        let alpha = rsold / pap;

        // x = x + alpha*p (SIMD accelerated)
        axpy_f32(x, alpha, &ws.cg_p);

        // r = r - alpha*ap (SIMD accelerated)
        axpy_f32(&mut ws.cg_r, -alpha, &ws.cg_ap);

        // rsnew = r·r (SIMD accelerated)
        let rsnew: f32 = norm_squared_f32(&ws.cg_r);
        let residual = rsnew.sqrt();

        // Check convergence (relative to ||b||)
        if residual < tol * b_norm {
            break;
        }

        let beta = rsnew / rsold;

        // p = r + beta*p (SIMD accelerated)
        xpby_f32(&mut ws.cg_p, &ws.cg_r, beta);

        rsold = rsnew;
    }
}
345
/// MEDI L1 dipole inversion with full options (OPTIMIZED f32 VERSION)
///
/// # Arguments
/// * `local_field` - Local field/phase (RDF) in radians (nx * ny * nz)
/// * `n_std` - Noise standard deviation map (same size as local_field)
/// * `magnitude` - Magnitude image for gradient weighting (nx * ny * nz)
/// * `mask` - Binary mask (nx * ny * nz), 1 = brain
/// * `nx`, `ny`, `nz` - Array dimensions
/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm
/// * `lambda` - Regularization parameter (default: 7.5e-5, matching MATLAB MEDI)
/// * `bdir` - B0 field direction (default: (0, 0, 1))
/// * `merit` - Enable iterative merit-based outlier adjustment (default: false)
/// * `smv` - Enable SMV preprocessing within MEDI (default: false)
/// * `smv_radius` - SMV radius in mm (default: 5.0)
/// * `data_weighting` - Data weighting mode: 0=uniform, 1=SNR (default: 1)
/// * `percentage` - Fraction of voxels considered edges (default: 0.3 = 30%, matching MATLAB gpct=30)
/// * `cg_tol` - CG solver tolerance (default: 0.01)
/// * `cg_max_iter` - CG maximum iterations (default: 10, matching MATLAB)
/// * `max_iter` - Maximum Gauss-Newton iterations (default: 30, matching MATLAB)
/// * `tol` - Convergence tolerance (default: 0.1)
///
/// # Returns
/// Susceptibility map (in same units as input field)
#[allow(clippy::too_many_arguments)]
pub fn medi_l1(
    local_field: &[f64],
    n_std: &[f64],
    magnitude: &[f64],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
    lambda: f64,
    bdir: (f64, f64, f64),
    merit: bool,
    smv: bool,
    smv_radius: f64,
    data_weighting: i32,
    percentage: f64,
    cg_tol: f64,
    cg_max_iter: usize,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let n_total = nx * ny * nz;

    // Convert to f32 for internal computation (much faster on WASM)
    let vsx_f32 = vsx as f32;
    let vsy_f32 = vsy as f32;
    let vsz_f32 = vsz as f32;
    let lambda_f32 = lambda as f32;
    let bdir_f32 = (bdir.0 as f32, bdir.1 as f32, bdir.2 as f32);
    let smv_radius_f32 = smv_radius as f32;
    let percentage_f32 = percentage as f32;
    let cg_tol_f32 = cg_tol as f32;
    let tol_f32 = tol as f32;

    // Convert input arrays to f32
    let local_field_f32: Vec<f32> = local_field.iter().map(|&v| v as f32).collect();
    let n_std_f32: Vec<f32> = n_std.iter().map(|&v| v as f32).collect();
    let magnitude_f32: Vec<f32> = magnitude.iter().map(|&v| v as f32).collect();

    // Create workspace - this allocates all buffers ONCE
    let mut ws = MediWorkspace::new(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);

    // Working copies that may be modified by SMV preprocessing
    let mut rdf: Vec<f32> = local_field_f32.clone();
    let mut work_mask: Vec<u8> = mask.to_vec();
    let mut tempn: Vec<f32> = n_std_f32.clone();

    // Apply mask to N_std (zero noise estimate outside the brain)
    for i in 0..n_total {
        if mask[i] == 0 {
            tempn[i] = 0.0;
        }
    }

    // Generate dipole kernel
    let mut d_kernel = dipole_kernel_f32(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, bdir_f32);

    // SMV preprocessing (optional). `sphere_k` keeps the kernel FFT alive for
    // the MERIT noise recomputation inside the main loop.
    let sphere_k = if smv {
        let sk = smv_kernel_f32(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, smv_radius_f32);

        // FFT of sphere kernel for convolution
        let mut sk_fft: Vec<Complex32> = sk.iter()
            .map(|&v| Complex32::new(v, 0.0))
            .collect();
        ws.fft_ws.fft3d(&mut sk_fft);

        // Erode mask: SMV(mask) > 0.999
        let mask_f32: Vec<f32> = work_mask.iter().map(|&m| m as f32).collect();
        let smv_mask = apply_smv_kernel_ws(&mask_f32, &sk_fft, &mut ws);
        for i in 0..n_total {
            work_mask[i] = if smv_mask[i] > 0.999 { 1 } else { 0 };
        }

        // Modify dipole kernel: D = (1 - SphereK) * D
        for i in 0..n_total {
            d_kernel[i] *= 1.0 - sk[i];
        }

        // Modify RDF: RDF = RDF - SMV(RDF)
        let smv_rdf = apply_smv_kernel_ws(&rdf, &sk_fft, &mut ws);
        for i in 0..n_total {
            rdf[i] -= smv_rdf[i];
            if work_mask[i] == 0 {
                rdf[i] = 0.0;
            }
        }

        // Modify noise: tempn = sqrt(SMV(tempn^2) + tempn^2)
        let tempn_sq: Vec<f32> = tempn.iter().map(|&t| t * t).collect();
        let smv_tempn_sq = apply_smv_kernel_ws(&tempn_sq, &sk_fft, &mut ws);
        for i in 0..n_total {
            tempn[i] = (smv_tempn_sq[i] + tempn_sq[i]).sqrt();
        }

        Some(sk_fft)
    } else {
        None
    };

    // Compute data weighting
    let mut m = dataterm_mask_f32(data_weighting, &tempn, &work_mask);

    // b0 = m * exp(i * RDF)
    let mut b0: Vec<Complex32> = rdf.iter()
        .zip(m.iter())
        .map(|(&f, &mi)| {
            let phase = Complex32::new(0.0, f);
            mi * phase.exp()
        })
        .collect();

    // Compute per-direction gradient weighting masks from magnitude edges
    // Returns (mx, my, mz) - separate masks for each gradient direction (matching MATLAB MEDI)
    let (w_gx, w_gy, w_gz) = gradient_mask_f32(&magnitude_f32, &work_mask, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, percentage_f32);

    // Fallback: if any mask is all zeros, use magnitude image (matching MATLAB)
    let w_gx = if w_gx.iter().any(|&v| v != 0.0) { w_gx } else { magnitude_f32.clone() };
    let w_gy = if w_gy.iter().any(|&v| v != 0.0) { w_gy } else { magnitude_f32.clone() };
    let w_gz = if w_gz.iter().any(|&v| v != 0.0) { w_gz } else { magnitude_f32.clone() };

    // Initialize susceptibility
    let mut chi = vec![0.0f32; n_total];
    let mut dx = vec![0.0f32; n_total];  // Reusable buffer for CG solution
    let mut rhs = vec![0.0f32; n_total]; // Reusable buffer for RHS
    let mut vr = vec![0.0f32; n_total];  // Reusable buffer for Vr (P in MATLAB)
    let mut w: Vec<Complex32> = vec![Complex32::new(0.0, 0.0); n_total]; // Reusable buffer for w
    let mut chi_prev = vec![0.0f32; n_total]; // Reusable buffer for convergence check
    let mut badpoint = vec![0.0f32; n_total];
    let mut n_std_work: Vec<f32> = n_std_f32.clone();

    // MATLAB: beta = sqrt(eps(class(f))) where eps for f64 ≈ 2.22e-16, so sqrt(eps) ≈ 1.49e-8.
    // This is a regularization parameter for the P weight denominator, not a precision limit.
    // Using the same value as MATLAB (1.49e-8) is fine in f32 (representable, well above f32 eps).
    let beta = 1.49e-8_f32;

    // Gauss-Newton iterations
    for _iter in 0..max_iter {
        // Save chi_prev for convergence check
        chi_prev.copy_from_slice(&chi);

        // Compute P = 1 / sqrt(|m * grad(chi)|^2 + beta) using per-direction masks (SIMD accelerated)
        // MATLAB: P = 1 ./ sqrt(ux.*ux + uy.*uy + uz.*uz + beta);
        // where ux = mx .* grad_x(chi), uy = my .* grad_y(chi), uz = mz .* grad_z(chi)
        // Uses periodic BCs matching MATLAB's grad_ (which calls gradfp_mex)
        fgrad_periodic_inplace_f32(
            &mut ws.gx, &mut ws.gy, &mut ws.gz,
            &chi, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32,
        );

        compute_p_weights_f32(&mut vr, &w_gx, &w_gy, &w_gz, &ws.gx, &ws.gy, &ws.gz, beta);

        // Compute w = m * exp(i * D*chi) using workspace
        apply_dipole_conv(&mut ws.fft_ws, &chi, &d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
        for i in 0..n_total {
            let phase = Complex32::new(0.0, ws.dipole_buf[i]);
            w[i] = m[i] * phase.exp();
        }

        // Compute right-hand side using workspace
        compute_rhs_inplace(&chi, &w, &b0, &d_kernel, &w_gx, &w_gy, &w_gz, &vr, lambda_f32, &mut rhs, &mut ws);

        // Negate for CG (solving A*dx = -b) (SIMD accelerated)
        negate_f32(&mut rhs);

        // Solve A*dx = rhs using optimized CG with buffer reuse (no progress reporting)
        cg_solve_medi(&mut ws, &w, &d_kernel, &w_gx, &w_gy, &w_gz, &vr, lambda_f32, &rhs, &mut dx, cg_tol_f32, cg_max_iter, |_, _| {});

        // Update: chi = chi + dx (SIMD accelerated)
        axpy_f32(&mut chi, 1.0, &dx);

        // Check convergence (SIMD accelerated). Note: rel_change is computed
        // here but only tested AFTER the optional MERIT block below, so MERIT
        // runs even on the final (converged) iteration.
        let norm_dx_sq = norm_squared_f32(&dx);
        let norm_chi_sq = norm_squared_f32(&chi_prev);
        let rel_change = norm_dx_sq.sqrt() / (norm_chi_sq.sqrt() + 1e-6);

        // Merit adjustment (optional): down-weight voxels whose residual is an
        // outlier (> 6 standard deviations) by inflating their noise estimate.
        if merit {
            // Compute residual: wres = m * exp(i * D*chi) - b0
            apply_dipole_conv(&mut ws.fft_ws, &chi, &d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
            let mut wres: Vec<Complex32> = ws.dipole_buf.iter()
                .zip(m.iter())
                .zip(b0.iter())
                .map(|((&dc, &mi), &b0i)| {
                    let phase = Complex32::new(0.0, dc);
                    mi * phase.exp() - b0i
                })
                .collect();

            // Subtract mean over mask
            let mask_count = work_mask.iter().filter(|&&m| m != 0).count() as f32;
            if mask_count > 0.0 {
                let mean_wres: Complex32 = wres.iter()
                    .zip(work_mask.iter())
                    .filter(|(_, &m)| m != 0)
                    .map(|(w, _)| w)
                    .sum::<Complex32>() / mask_count;

                for i in 0..n_total {
                    if work_mask[i] != 0 {
                        wres[i] -= mean_wres;
                    }
                }
            }

            // Compute factor = std(abs(wres[mask])) * 6
            let abs_wres: Vec<f32> = wres.iter()
                .zip(work_mask.iter())
                .filter(|(_, &m)| m != 0)
                .map(|(w, _)| w.norm())
                .collect();

            if !abs_wres.is_empty() {
                let mean_abs: f32 = abs_wres.iter().sum::<f32>() / abs_wres.len() as f32;
                let var: f32 = abs_wres.iter()
                    .map(|&v| (v - mean_abs).powi(2))
                    .sum::<f32>() / abs_wres.len() as f32;
                let factor = var.sqrt() * 6.0;

                if factor > 1e-10 {
                    // Normalize wres by factor
                    let mut wres_norm: Vec<f32> = wres.iter()
                        .map(|w| w.norm() / factor)
                        .collect();

                    // Clamp values < 1 to 1 (inliers keep their noise unchanged)
                    for v in wres_norm.iter_mut() {
                        if *v < 1.0 {
                            *v = 1.0;
                        }
                    }

                    // Mark bad points and update noise
                    for i in 0..n_total {
                        if wres_norm[i] > 1.0 {
                            badpoint[i] = 1.0;
                        }
                        if work_mask[i] != 0 {
                            n_std_work[i] *= wres_norm[i].powi(2);
                        }
                    }

                    // Recompute tempn (re-applying SMV noise propagation if enabled)
                    tempn = n_std_work.clone();
                    if let Some(ref sk_fft) = sphere_k {
                        let tempn_sq: Vec<f32> = tempn.iter().map(|&t| t * t).collect();
                        let smv_tempn_sq = apply_smv_kernel_ws(&tempn_sq, sk_fft, &mut ws);
                        for i in 0..n_total {
                            tempn[i] = (smv_tempn_sq[i] + tempn_sq[i]).sqrt();
                        }
                    }

                    // Recompute data weighting and b0
                    m = dataterm_mask_f32(data_weighting, &tempn, &work_mask);
                    b0 = rdf.iter()
                        .zip(m.iter())
                        .map(|(&f, &mi)| {
                            let phase = Complex32::new(0.0, f);
                            mi * phase.exp()
                        })
                        .collect();
                }
            }
        }

        if rel_change < tol_f32 {
            break;
        }
    }

    // Suppress unused variable warning
    // NOTE(review): badpoint tracks MERIT outlier voxels but is never returned;
    // presumably kept for parity with the MATLAB toolbox — confirm whether it
    // should be exposed to callers.
    let _ = badpoint;

    // Apply mask and convert back to f64
    chi.iter()
        .zip(mask.iter())
        .map(|(&c, &m)| if m == 0 { 0.0 } else { c as f64 })
        .collect()
}
647
648/// Apply SMV kernel using workspace buffers (f32)
649fn apply_smv_kernel_ws(
650    x: &[f32],
651    sk_fft: &[Complex32],
652    ws: &mut MediWorkspace,
653) -> Vec<f32> {
654    let n_total = ws.n_total;
655
656    // Copy to complex buffer
657    for (c, &r) in ws.complex_buf.iter_mut().zip(x.iter()) {
658        *c = Complex32::new(r, 0.0);
659    }
660
661    ws.fft_ws.fft3d(&mut ws.complex_buf);
662
663    for i in 0..n_total {
664        ws.complex_buf[i] *= sk_fft[i];
665    }
666
667    ws.fft_ws.ifft3d(&mut ws.complex_buf);
668
669    ws.complex_buf.iter().map(|c| c.re).collect()
670}
671
672/// Compute RHS in-place using workspace buffers (f32)
673/// Uses per-direction gradient masks (mx, my, mz) matching MATLAB MEDI
674/// SIMD-accelerated for element-wise operations
675fn compute_rhs_inplace(
676    chi: &[f32],
677    w: &[Complex32],
678    b0: &[Complex32],
679    d_kernel: &[f32],
680    mx: &[f32],  // Per-direction gradient mask for x
681    my: &[f32],  // Per-direction gradient mask for y
682    mz: &[f32],  // Per-direction gradient mask for z
683    vr: &[f32],
684    lambda: f32,
685    rhs: &mut [f32],
686    ws: &mut MediWorkspace,
687) {
688    let n = ws.n_total;
689
690    // Regularization term: div(m * P * m * grad(chi)) for each direction
691    // MATLAB: b = lam .* gradAdj_(ux, uy, uz, vsz);
692    // where ux = mx .* P .* mx .* grad_x(chi), etc.
693    // Uses periodic BCs matching MATLAB's gradfp_mex / gradfp_adj_mex
694    fgrad_periodic_inplace_f32(
695        &mut ws.gx, &mut ws.gy, &mut ws.gz,
696        chi, ws.nx, ws.ny, ws.nz,
697        ws.vsx, ws.vsy, ws.vsz,
698    );
699
700    // Apply per-direction weights: ux = mx * P * mx * gx (SIMD accelerated)
701    apply_gradient_weights_f32(
702        &mut ws.reg_x, &mut ws.reg_y, &mut ws.reg_z,
703        mx, my, mz, vr,
704        &ws.gx, &ws.gy, &ws.gz,
705    );
706
707    bdiv_periodic_inplace_f32(
708        &mut ws.div_buf,
709        &ws.reg_x, &ws.reg_y, &ws.reg_z,
710        ws.nx, ws.ny, ws.nz,
711        ws.vsx, ws.vsy, ws.vsz,
712    );
713
714    // Data term: D^T(conj(w) * (-i) * (w - b0))
715    // MATLAB: b = b + real(ifft3(conj(D) .* fft3(1i .* w2 .* (exp(1i.*(f - Dx)) - 1))));
716    for i in 0..n {
717        let diff = w[i] - b0[i];
718        let conj_w = w[i].conj();
719        let neg_i = Complex32::new(0.0, -1.0);
720        ws.complex_buf2[i] = conj_w * neg_i * diff;
721    }
722
723    // Apply D^T (which is D for real symmetric kernel)
724    ws.fft_ws.fft3d(&mut ws.complex_buf2);
725
726    for i in 0..n {
727        ws.complex_buf2[i] *= d_kernel[i];
728    }
729
730    ws.fft_ws.ifft3d(&mut ws.complex_buf2);
731
732    // Extract real parts for SIMD combine operation
733    for i in 0..n {
734        ws.dipole_buf[i] = ws.complex_buf2[i].re;
735    }
736
737    // Combine terms: rhs = lambda * reg_term + data_term (SIMD accelerated, matching MATLAB)
738    combine_terms_f32(rhs, &ws.div_buf, &ws.dipole_buf, lambda);
739}
740
/// Generate data weighting mask (f32)
///
/// # Arguments
/// * `mode` - 0 for uniform weighting; any other value selects SNR weighting
/// * `n_std` - Noise standard deviation map
/// * `mask` - Binary mask (nonzero = inside ROI)
///
/// # Returns
/// Per-voxel weights; zero outside the mask. In SNR mode, weights are
/// 1/N_std normalized so the mean over the ROI is 1.
///
/// Fix: the original indexed `w` and `mask` with `0..n_std.len()`, which
/// panicked if `mask` was shorter than `n_std` (the `zip` that built `w`
/// truncates to the shorter input). All loops now iterate by `zip`, so
/// mismatched lengths truncate instead of panicking; behavior for
/// equal-length inputs is unchanged.
fn dataterm_mask_f32(mode: i32, n_std: &[f32], mask: &[u8]) -> Vec<f32> {
    if mode == 0 {
        // Uniform weighting: 1 inside the mask, 0 outside.
        return mask
            .iter()
            .map(|&m| if m != 0 { 1.0 } else { 0.0 })
            .collect();
    }

    // SNR weighting: w = 1/N_std inside the ROI (0 for degenerate sigma).
    let mut w: Vec<f32> = n_std
        .iter()
        .zip(mask)
        .map(|(&sigma, &m)| if m != 0 && sigma > 1e-10 { 1.0 / sigma } else { 0.0 })
        .collect();

    // ROI count and sum, accumulated in the same order as the original.
    let (count, sum) = w.iter().zip(mask).fold((0usize, 0.0f32), |(c, s), (&wi, &m)| {
        if m != 0 { (c + 1, s + wi) } else { (c, s) }
    });

    // Normalize so the mean over the ROI is 1 (skip if ROI empty or mean ~0).
    if count > 0 {
        let mean = sum / count as f32;
        if mean > 1e-10 {
            for wi in w.iter_mut() {
                *wi /= mean;
            }
        }
    }

    // Guarantee exact zeros outside the mask.
    for (wi, &m) in w.iter_mut().zip(mask) {
        if m == 0 {
            *wi = 0.0;
        }
    }

    w
}
794
795/// Generate per-direction gradient weighting masks (f32)
796///
797/// Computes separate edge masks for each gradient direction from magnitude image,
798/// matching the MATLAB MEDI implementation (gradientMaskMedi.m).
799/// Returns (mx, my, mz) where each mask is 1 (regularize) for non-edges, 0 for edges.
800///
801/// # Arguments
802/// * `magnitude` - Magnitude image
803/// * `mask` - Binary mask
804/// * `nx`, `ny`, `nz` - Array dimensions
805/// * `vsx`, `vsy`, `vsz` - Voxel sizes
806/// * `percentage` - Percentage of voxels considered to be edges (0.0-1.0, e.g., 0.3 = 30% edges)
807///
808/// # Returns
809/// Tuple of (mx, my, mz) per-direction binary gradient masks
810pub(crate) fn gradient_mask_f32(
811    magnitude: &[f32],
812    mask: &[u8],
813    nx: usize, ny: usize, nz: usize,
814    vsx: f32, vsy: f32, vsz: f32,
815    percentage: f32,
816) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
817    let n_total = nx * ny * nz;
818
819    // Normalize magnitude by max value within mask (matching MATLAB)
820    let mag_max = magnitude.iter()
821        .zip(mask.iter())
822        .filter(|(_, &m)| m != 0)
823        .map(|(&v, _)| v.abs())
824        .fold(0.0_f32, f32::max);
825
826    let mag_normalized: Vec<f32> = magnitude.iter()
827        .zip(mask.iter())
828        .map(|(&m, &msk)| {
829            if msk != 0 && mag_max > 1e-10 {
830                m / mag_max
831            } else {
832                0.0
833            }
834        })
835        .collect();
836
837    // Compute gradient of normalized magnitude (using linear extrapolation BCs)
838    let (gx, gy, gz) = fgrad_linext_f32(&mag_normalized, nx, ny, nz, vsx, vsy, vsz);
839
840    // Take absolute values of each gradient direction
841    let abs_gx: Vec<f32> = gx.iter().map(|&v| v.abs()).collect();
842    let abs_gy: Vec<f32> = gy.iter().map(|&v| v.abs()).collect();
843    let abs_gz: Vec<f32> = gz.iter().map(|&v| v.abs()).collect();
844
845    // Collect all gradient values within mask for threshold computation
846    let mut all_grads: Vec<f32> = Vec::with_capacity(3 * n_total);
847    for i in 0..n_total {
848        if mask[i] != 0 {
849            all_grads.push(abs_gx[i]);
850            all_grads.push(abs_gy[i]);
851            all_grads.push(abs_gz[i]);
852        }
853    }
854
855    if all_grads.is_empty() {
856        return (vec![1.0; n_total], vec![1.0; n_total], vec![1.0; n_total]);
857    }
858
859    // Sort to find percentile threshold (100 - percentage)
860    // MATLAB: thr = prctile([mx(mask); my(mask); mz(mask)], 100 - p);
861    all_grads.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
862    let percentile_idx = ((1.0 - percentage) * (all_grads.len() - 1) as f32) as usize;
863    let threshold = all_grads[percentile_idx.min(all_grads.len() - 1)];
864
865    // Create per-direction masks: 1 where gradient < threshold (non-edges), 0 at edges
866    // MATLAB: mx = mx < thr; my = my < thr; mz = mz < thr;
867    let mx: Vec<f32> = abs_gx.iter()
868        .zip(mask.iter())
869        .map(|(&g, &m)| if m != 0 && g < threshold { 1.0 } else { 0.0 })
870        .collect();
871
872    let my: Vec<f32> = abs_gy.iter()
873        .zip(mask.iter())
874        .map(|(&g, &m)| if m != 0 && g < threshold { 1.0 } else { 0.0 })
875        .collect();
876
877    let mz: Vec<f32> = abs_gz.iter()
878        .zip(mask.iter())
879        .map(|(&g, &m)| if m != 0 && g < threshold { 1.0 } else { 0.0 })
880        .collect();
881
882    (mx, my, mz)
883}
884
885/// Forward difference gradient with linear extrapolation boundary conditions (f32)
886/// Matches MATLAB's gradf behavior: dx(end) = dx(end-1)
887pub(crate) fn fgrad_linext_f32(
888    x: &[f32],
889    nx: usize, ny: usize, nz: usize,
890    vsx: f32, vsy: f32, vsz: f32,
891) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
892    let n_total = nx * ny * nz;
893    let mut gx = vec![0.0f32; n_total];
894    let mut gy = vec![0.0f32; n_total];
895    let mut gz = vec![0.0f32; n_total];
896    fgrad_linext_inplace_f32(&mut gx, &mut gy, &mut gz, x, nx, ny, nz, vsx, vsy, vsz);
897    (gx, gy, gz)
898}
899
900/// Forward difference gradient with linear extrapolation boundary conditions (f32, in-place)
901/// Matches MATLAB's gradf behavior: dx(end) = dx(end-1)
#[inline]
pub(crate) fn fgrad_linext_inplace_f32(
    gx: &mut [f32], gy: &mut [f32], gz: &mut [f32],
    x: &[f32],
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
) {
    // Forward difference is (next - here) / spacing; precompute reciprocals.
    let inv_x = 1.0 / vsx;
    let inv_y = 1.0 / vsy;
    let inv_z = 1.0 / vsz;
    let plane = nx * ny;

    // Loop order k -> j -> i so the boundary copies below read values that
    // were written earlier in the same sweep.
    for k in 0..nz {
        for j in 0..ny {
            for i in 0..nx {
                let idx = i + j * nx + k * plane;
                let here = x[idx];

                // x: forward diff; at the far face replicate the previous
                // column (MATLAB gradf: dx(end,:,:) = dx(end-1,:,:)).
                gx[idx] = if i + 1 < nx {
                    (x[idx + 1] - here) * inv_x
                } else if i > 0 {
                    gx[idx - 1]
                } else {
                    0.0
                };

                // y: same scheme along the row stride.
                gy[idx] = if j + 1 < ny {
                    (x[idx + nx] - here) * inv_y
                } else if j > 0 {
                    gy[idx - nx]
                } else {
                    0.0
                };

                // z: same scheme along the plane stride.
                gz[idx] = if k + 1 < nz {
                    (x[idx + plane] - here) * inv_z
                } else if k > 0 {
                    gz[idx - plane]
                } else {
                    0.0
                };
            }
        }
    }
}
953
954
955/// Forward difference gradient with periodic boundary conditions (f32, in-place)
956/// Matches MATLAB's gradfp_mex used inside MEDI iterations.
957/// At boundaries, wraps around: dx(end) = (x(1) - x(end)) / h
#[inline]
pub(crate) fn fgrad_periodic_inplace_f32(
    gx: &mut [f32], gy: &mut [f32], gz: &mut [f32],
    x: &[f32],
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
) {
    // Forward difference scaled by reciprocal spacing; neighbors wrap around
    // each axis (periodic BCs, matching MATLAB's gradfp_mex).
    let inv_x = 1.0 / vsx;
    let inv_y = 1.0 / vsy;
    let inv_z = 1.0 / vsz;
    let plane = nx * ny;

    for k in 0..nz {
        let k_base = k * plane;
        for j in 0..ny {
            let row = k_base + j * nx;
            for i in 0..nx {
                let idx = row + i;
                let here = x[idx];

                // Forward neighbor along each axis, wrapping at the far face.
                let xn = if i + 1 < nx { x[idx + 1] } else { x[row] };
                let yn = if j + 1 < ny { x[idx + nx] } else { x[k_base + i] };
                let zn = if k + 1 < nz { x[idx + plane] } else { x[i + j * nx] };

                gx[idx] = (xn - here) * inv_x;
                gy[idx] = (yn - here) * inv_y;
                gz[idx] = (zn - here) * inv_z;
            }
        }
    }
}
995
996/// Backward divergence with periodic boundary conditions (f32, in-place)
997/// Adjoint of fgrad_periodic_inplace_f32, matching MATLAB's gradfp_adj_mex.
998/// At boundaries, wraps around: at i=0, uses gx(end) instead of zero.
#[inline]
pub(crate) fn bdiv_periodic_inplace_f32(
    div: &mut [f32],
    gx: &[f32], gy: &[f32], gz: &[f32],
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
) {
    // Negated reciprocals: this computes the adjoint of the periodic forward
    // gradient (backward divergence), matching MATLAB's gradfp_adj_mex.
    let sx = -1.0 / vsx;
    let sy = -1.0 / vsy;
    let sz = -1.0 / vsz;
    let plane = nx * ny;

    for k in 0..nz {
        let k_base = k * plane;
        for j in 0..ny {
            let row = k_base + j * nx;
            for i in 0..nx {
                let idx = row + i;

                // Backward neighbor along each axis, wrapping to the far face
                // at the near boundary instead of using zero.
                let px = if i > 0 { gx[idx - 1] } else { gx[row + nx - 1] };
                let py = if j > 0 { gy[idx - nx] } else { gy[k_base + (ny - 1) * nx + i] };
                let pz = if k > 0 { gz[idx - plane] } else { gz[i + j * nx + (nz - 1) * plane] };

                div[idx] = (gx[idx] - px) * sx + (gy[idx] - py) * sy + (gz[idx] - pz) * sz;
            }
        }
    }
}
1037
/// MEDI L1 with progress callback (OPTIMIZED f32 VERSION)
///
/// Same as `medi_l1` but calls `progress_callback(iteration, max_iter)` each iteration.
///
/// Outer loop: Gauss-Newton linearization of the nonlinear data term; inner
/// loop: conjugate-gradient solve for the update `dx`. All heavy math runs in
/// f32 (workspace buffers allocated once); the public API stays f64.
///
/// # Arguments
/// * `local_field` - Local field / RDF, length `nx*ny*nz`
/// * `n_std` - Noise standard deviation map (same length)
/// * `magnitude` - Magnitude image used to build per-direction edge masks
/// * `mask` - Binary ROI mask (non-zero = inside)
/// * `nx`, `ny`, `nz` - Array dimensions
/// * `vsx`, `vsy`, `vsz` - Voxel sizes
/// * `lambda` - Regularization balance parameter (MATLAB default 7.5e-5)
/// * `bdir` - B0 field direction vector
/// * `merit` - Enable MERIT outlier re-weighting of the noise map
/// * `smv` - Enable SMV preprocessing (mask erosion, kernel/RDF/noise updates)
/// * `smv_radius` - Radius passed to `smv_kernel_f32` when `smv` is on
/// * `data_weighting` - Mode for `dataterm_mask_f32` (1 = SNR weighting)
/// * `percentage` - Fraction of voxels treated as edges (e.g. 0.3 = 30%)
/// * `cg_tol`, `cg_max_iter` - Inner CG tolerance and iteration cap
/// * `max_iter`, `tol` - Outer GN iteration cap and relative-change stop criterion
/// * `progress_callback` - Called as `(current_step, total_steps)` where
///   `total_steps = max_iter * cg_max_iter`
///
/// # Returns
/// Susceptibility map `chi` as `Vec<f64>`, zeroed outside `mask`.
#[allow(clippy::too_many_arguments)]
pub fn medi_l1_with_progress<F>(
    local_field: &[f64],
    n_std: &[f64],
    magnitude: &[f64],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
    lambda: f64,
    bdir: (f64, f64, f64),
    merit: bool,
    smv: bool,
    smv_radius: f64,
    data_weighting: i32,
    percentage: f64,
    cg_tol: f64,
    cg_max_iter: usize,
    max_iter: usize,
    tol: f64,
    mut progress_callback: F,
) -> Vec<f64>
where
    F: FnMut(usize, usize),
{
    let n_total = nx * ny * nz;

    // Convert to f32 for internal computation
    let vsx_f32 = vsx as f32;
    let vsy_f32 = vsy as f32;
    let vsz_f32 = vsz as f32;
    let lambda_f32 = lambda as f32;
    let bdir_f32 = (bdir.0 as f32, bdir.1 as f32, bdir.2 as f32);
    let smv_radius_f32 = smv_radius as f32;
    let percentage_f32 = percentage as f32;
    let cg_tol_f32 = cg_tol as f32;
    let tol_f32 = tol as f32;

    // Convert input arrays to f32
    let local_field_f32: Vec<f32> = local_field.iter().map(|&v| v as f32).collect();
    let n_std_f32: Vec<f32> = n_std.iter().map(|&v| v as f32).collect();
    let magnitude_f32: Vec<f32> = magnitude.iter().map(|&v| v as f32).collect();

    // Create workspace - allocates all buffers ONCE
    let mut ws = MediWorkspace::new(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);

    // Working copies that may be modified by SMV preprocessing
    let mut rdf: Vec<f32> = local_field_f32.clone();
    let mut work_mask: Vec<u8> = mask.to_vec();
    let mut tempn: Vec<f32> = n_std_f32.clone();

    // Apply mask to N_std
    for i in 0..n_total {
        if mask[i] == 0 {
            tempn[i] = 0.0;
        }
    }

    // Generate dipole kernel
    let mut d_kernel = dipole_kernel_f32(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, bdir_f32);

    // SMV preprocessing (optional)
    // Keeps the kernel FFT around (Some) so MERIT can recompute the noise
    // map consistently inside the iteration loop.
    let sphere_k = if smv {
        let sk = smv_kernel_f32(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, smv_radius_f32);

        // FFT of sphere kernel for convolution
        let mut sk_fft: Vec<Complex32> = sk.iter()
            .map(|&v| Complex32::new(v, 0.0))
            .collect();
        ws.fft_ws.fft3d(&mut sk_fft);

        // Erode mask: SMV(mask) > 0.999
        let mask_f32: Vec<f32> = work_mask.iter().map(|&m| m as f32).collect();
        let smv_mask = apply_smv_kernel_ws(&mask_f32, &sk_fft, &mut ws);
        for i in 0..n_total {
            work_mask[i] = if smv_mask[i] > 0.999 { 1 } else { 0 };
        }

        // Modify dipole kernel: D = (1 - SphereK) * D
        for i in 0..n_total {
            d_kernel[i] *= 1.0 - sk[i];
        }

        // Modify RDF: RDF = RDF - SMV(RDF)
        let smv_rdf = apply_smv_kernel_ws(&rdf, &sk_fft, &mut ws);
        for i in 0..n_total {
            rdf[i] -= smv_rdf[i];
            if work_mask[i] == 0 {
                rdf[i] = 0.0;
            }
        }

        // Modify noise: tempn = sqrt(SMV(tempn^2) + tempn^2)
        let tempn_sq: Vec<f32> = tempn.iter().map(|&t| t * t).collect();
        let smv_tempn_sq = apply_smv_kernel_ws(&tempn_sq, &sk_fft, &mut ws);
        for i in 0..n_total {
            tempn[i] = (smv_tempn_sq[i] + tempn_sq[i]).sqrt();
        }

        Some(sk_fft)
    } else {
        None
    };

    // Compute data weighting
    let mut m = dataterm_mask_f32(data_weighting, &tempn, &work_mask);

    // b0 = m * exp(i * RDF)
    let mut b0: Vec<Complex32> = rdf.iter()
        .zip(m.iter())
        .map(|(&f, &mi)| {
            let phase = Complex32::new(0.0, f);
            mi * phase.exp()
        })
        .collect();

    // Compute per-direction gradient weighting masks from magnitude edges
    // Returns (mx, my, mz) - separate masks for each gradient direction (matching MATLAB MEDI)
    let (w_gx, w_gy, w_gz) = gradient_mask_f32(&magnitude_f32, &work_mask, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, percentage_f32);

    // Fallback: if any mask is all zeros, use magnitude image (matching MATLAB)
    let w_gx = if w_gx.iter().any(|&v| v != 0.0) { w_gx } else { magnitude_f32.clone() };
    let w_gy = if w_gy.iter().any(|&v| v != 0.0) { w_gy } else { magnitude_f32.clone() };
    let w_gz = if w_gz.iter().any(|&v| v != 0.0) { w_gz } else { magnitude_f32.clone() };

    // Initialize susceptibility and reusable buffers
    let mut chi = vec![0.0f32; n_total];
    let mut dx = vec![0.0f32; n_total];
    let mut rhs = vec![0.0f32; n_total];
    let mut vr = vec![0.0f32; n_total];
    let mut w: Vec<Complex32> = vec![Complex32::new(0.0, 0.0); n_total];
    let mut chi_prev = vec![0.0f32; n_total];
    let mut badpoint = vec![0.0f32; n_total];
    let mut n_std_work: Vec<f32> = n_std_f32.clone();

    // MATLAB: beta = sqrt(eps(class(f))) where eps for f64 ≈ 2.22e-16, so sqrt(eps) ≈ 1.49e-8.
    // This is a regularization parameter for the P weight denominator, not a precision limit.
    // Using the same value as MATLAB (1.49e-8) is fine in f32 (representable, well above f32 eps).
    let beta = 1.49e-8_f32;

    // Total progress = GN iterations * CG iterations per GN
    let total_steps = max_iter * cg_max_iter;

    // Gauss-Newton iterations
    for iter in 0..max_iter {
        chi_prev.copy_from_slice(&chi);

        // Compute P = 1 / sqrt(|m * grad(chi)|^2 + beta) using per-direction masks (SIMD accelerated)
        // MATLAB: P = 1 ./ sqrt(ux.*ux + uy.*uy + uz.*uz + beta);
        // Uses periodic BCs matching MATLAB's grad_ (which calls gradfp_mex)
        fgrad_periodic_inplace_f32(
            &mut ws.gx, &mut ws.gy, &mut ws.gz,
            &chi, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32,
        );

        compute_p_weights_f32(&mut vr, &w_gx, &w_gy, &w_gz, &ws.gx, &ws.gy, &ws.gz, beta);

        // Compute w = m * exp(i * D*chi)
        apply_dipole_conv(&mut ws.fft_ws, &chi, &d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
        for i in 0..n_total {
            let phase = Complex32::new(0.0, ws.dipole_buf[i]);
            w[i] = m[i] * phase.exp();
        }

        // Compute right-hand side
        compute_rhs_inplace(&chi, &w, &b0, &d_kernel, &w_gx, &w_gy, &w_gz, &vr, lambda_f32, &mut rhs, &mut ws);

        // Negate for CG (solving A*dx = -b) (SIMD accelerated)
        negate_f32(&mut rhs);

        // Solve A*dx = rhs using optimized CG with combined progress reporting
        // Progress = (gn_iter * cg_max_iter + cg_iter) / (max_iter * cg_max_iter)
        let gn_iter = iter;
        cg_solve_medi(
            &mut ws, &w, &d_kernel, &w_gx, &w_gy, &w_gz, &vr, lambda_f32, &rhs, &mut dx, cg_tol_f32, cg_max_iter,
            |cg_iter, cg_total| {
                let current = gn_iter * cg_total + cg_iter;
                progress_callback(current, total_steps);
            }
        );

        // Update: chi = chi + dx (SIMD accelerated)
        axpy_f32(&mut chi, 1.0, &dx);

        // Check convergence (SIMD accelerated)
        // Relative change compares |dx| against the pre-update |chi_prev|;
        // the 1e-6 floor guards the first iteration where chi_prev is zero.
        let norm_dx_sq = norm_squared_f32(&dx);
        let norm_chi_sq = norm_squared_f32(&chi_prev);
        let rel_change = norm_dx_sq.sqrt() / (norm_chi_sq.sqrt() + 1e-6);

        // Merit adjustment (optional)
        if merit {
            // Compute residual: wres = m * exp(i * D*chi) - b0
            apply_dipole_conv(&mut ws.fft_ws, &chi, &d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
            let mut wres: Vec<Complex32> = ws.dipole_buf.iter()
                .zip(m.iter())
                .zip(b0.iter())
                .map(|((&dc, &mi), &b0i)| {
                    let phase = Complex32::new(0.0, dc);
                    mi * phase.exp() - b0i
                })
                .collect();

            // Subtract mean over mask
            let mask_count = work_mask.iter().filter(|&&m| m != 0).count() as f32;
            if mask_count > 0.0 {
                let mean_wres: Complex32 = wres.iter()
                    .zip(work_mask.iter())
                    .filter(|(_, &m)| m != 0)
                    .map(|(w, _)| w)
                    .sum::<Complex32>() / mask_count;

                for i in 0..n_total {
                    if work_mask[i] != 0 {
                        wres[i] -= mean_wres;
                    }
                }
            }

            // Compute factor = std(abs(wres[mask])) * 6
            let abs_wres: Vec<f32> = wres.iter()
                .zip(work_mask.iter())
                .filter(|(_, &m)| m != 0)
                .map(|(w, _)| w.norm())
                .collect();

            if !abs_wres.is_empty() {
                // NOTE: population std (divide by N), matching MATLAB's std(x, 1) form —
                // TODO confirm against the reference toolbox if parity issues appear.
                let mean_abs: f32 = abs_wres.iter().sum::<f32>() / abs_wres.len() as f32;
                let var: f32 = abs_wres.iter()
                    .map(|&v| (v - mean_abs).powi(2))
                    .sum::<f32>() / abs_wres.len() as f32;
                let factor = var.sqrt() * 6.0;

                if factor > 1e-10 {
                    // Normalize wres by factor
                    let mut wres_norm: Vec<f32> = wres.iter()
                        .map(|w| w.norm() / factor)
                        .collect();

                    // Clamp values < 1 to 1
                    for v in wres_norm.iter_mut() {
                        if *v < 1.0 {
                            *v = 1.0;
                        }
                    }

                    // Mark bad points and update noise
                    // Outliers (residual > 6*std) get their noise std inflated
                    // quadratically, reducing their data-term weight next iteration.
                    for i in 0..n_total {
                        if wres_norm[i] > 1.0 {
                            badpoint[i] = 1.0;
                        }
                        if work_mask[i] != 0 {
                            n_std_work[i] *= wres_norm[i].powi(2);
                        }
                    }

                    // Recompute tempn
                    tempn = n_std_work.clone();
                    if let Some(ref sk_fft) = sphere_k {
                        let tempn_sq: Vec<f32> = tempn.iter().map(|&t| t * t).collect();
                        let smv_tempn_sq = apply_smv_kernel_ws(&tempn_sq, sk_fft, &mut ws);
                        for i in 0..n_total {
                            tempn[i] = (smv_tempn_sq[i] + tempn_sq[i]).sqrt();
                        }
                    }

                    // Recompute data weighting and b0
                    m = dataterm_mask_f32(data_weighting, &tempn, &work_mask);
                    b0 = rdf.iter()
                        .zip(m.iter())
                        .map(|(&f, &mi)| {
                            let phase = Complex32::new(0.0, f);
                            mi * phase.exp()
                        })
                        .collect();
                }
            }
        }

        if rel_change < tol_f32 {
            // Report completion on early convergence
            progress_callback(total_steps, total_steps);
            break;
        }
    }

    // Suppress unused variable warning
    // (badpoint is tracked for parity with MATLAB MERIT but not returned here.)
    let _ = badpoint;

    // Apply mask and convert back to f64
    chi.iter()
        .zip(mask.iter())
        .map(|(&c, &m)| if m == 0 { 0.0 } else { c as f64 })
        .collect()
}
1334
1335/// MEDI with default parameters (backward compatible)
1336pub fn medi_l1_default(
1337    local_field: &[f64],
1338    mask: &[u8],
1339    magnitude: &[f64],
1340    nx: usize, ny: usize, nz: usize,
1341    vsx: f64, vsy: f64, vsz: f64,
1342) -> Vec<f64> {
1343    // Create uniform noise std (no SNR weighting)
1344    let n_std = vec![1.0; local_field.len()];
1345
1346    medi_l1(
1347        local_field,
1348        &n_std,
1349        magnitude,
1350        mask,
1351        nx, ny, nz,
1352        vsx, vsy, vsz,
1353        7.5e-5,            // lambda (matching MATLAB default)
1354        (0.0, 0.0, 1.0),   // bdir
1355        false,             // merit
1356        false,             // smv
1357        5.0,               // smv_radius
1358        1,                 // data_weighting (SNR mode)
1359        0.3,               // percentage (30% edges, matching MATLAB gpct=30)
1360        0.01,              // cg_tol
1361        10,                // cg_max_iter (matching MATLAB default)
1362        30,                // max_iter (matching MATLAB default)
1363        0.1,               // tol
1364    )
1365}
1366
1367#[cfg(test)]
1368mod tests {
1369    use super::*;
1370
1371    #[test]
1372    fn test_dataterm_mask_uniform() {
1373        let n_std = vec![1.0f32; 27];
1374        let mask = vec![1u8; 27];
1375
1376        let w = dataterm_mask_f32(0, &n_std, &mask);
1377
1378        for &wi in w.iter() {
1379            assert!((wi - 1.0).abs() < 1e-10);
1380        }
1381    }
1382
1383    #[test]
1384    fn test_dataterm_mask_snr() {
1385        let n_std = vec![2.0f32; 27];
1386        let mask = vec![1u8; 27];
1387
1388        let w = dataterm_mask_f32(1, &n_std, &mask);
1389
1390        // Mean should be 1
1391        let mean: f32 = w.iter().sum::<f32>() / 27.0;
1392        assert!((mean - 1.0).abs() < 1e-5);
1393    }
1394
1395    #[test]
1396    fn test_gradient_mask_constant() {
1397        // Constant magnitude should have no edges (all gradients are zero)
1398        let mag = vec![1.0f32; 8 * 8 * 8];
1399        let mask = vec![1u8; 8 * 8 * 8];
1400
1401        let (mx, my, mz) = gradient_mask_f32(&mag, &mask, 8, 8, 8, 1.0, 1.0, 1.0, 0.3);
1402
1403        // All should be binary masks (0 or 1)
1404        for i in 0..(8 * 8 * 8) {
1405            assert!(mx[i] == 0.0 || mx[i] == 1.0, "mx should be binary, got {}", mx[i]);
1406            assert!(my[i] == 0.0 || my[i] == 1.0, "my should be binary, got {}", my[i]);
1407            assert!(mz[i] == 0.0 || mz[i] == 1.0, "mz should be binary, got {}", mz[i]);
1408        }
1409    }
1410
1411    #[test]
1412    fn test_medi_zero_field() {
1413        let n = 8;
1414        let field = vec![0.0; n * n * n];
1415        let mask = vec![1u8; n * n * n];
1416        let mag = vec![1.0; n * n * n];
1417        let n_std = vec![1.0; n * n * n];
1418
1419        let chi = medi_l1(
1420            &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1421            1000.0, (0.0, 0.0, 1.0), false, false, 5.0, 1, 0.9, 0.1, 10, 3, 0.1
1422        );
1423
1424        for &val in chi.iter() {
1425            assert!(val.abs() < 1e-4, "Zero field should give near-zero chi, got {}", val);
1426        }
1427    }
1428
1429    #[test]
1430    fn test_medi_finite() {
1431        let n = 8;
1432        let field: Vec<f64> = (0..n*n*n).map(|i| (i as f64) * 0.001).collect();
1433        let mask = vec![1u8; n * n * n];
1434        let mag = vec![1.0; n * n * n];
1435        let n_std = vec![1.0; n * n * n];
1436
1437        let chi = medi_l1(
1438            &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1439            1000.0, (0.0, 0.0, 1.0), false, false, 5.0, 1, 0.9, 0.1, 10, 3, 0.1
1440        );
1441
1442        for (i, &val) in chi.iter().enumerate() {
1443            assert!(val.is_finite(), "Chi should be finite at index {}", i);
1444        }
1445    }
1446
1447    #[test]
1448    fn test_medi_with_smv() {
1449        let n = 8;
1450        let field: Vec<f64> = (0..n*n*n).map(|i| (i as f64) * 0.001).collect();
1451        let mask = vec![1u8; n * n * n];
1452        let mag = vec![1.0; n * n * n];
1453        let n_std = vec![1.0; n * n * n];
1454
1455        // Test with SMV enabled
1456        let chi = medi_l1(
1457            &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1458            1000.0, (0.0, 0.0, 1.0), false, true, 2.0, 1, 0.9, 0.1, 10, 3, 0.1
1459        );
1460
1461        for (i, &val) in chi.iter().enumerate() {
1462            assert!(val.is_finite(), "Chi with SMV should be finite at index {}", i);
1463        }
1464    }
1465
1466    #[test]
1467    fn test_medi_mask() {
1468        let n = 8;
1469        let field: Vec<f64> = (0..n*n*n).map(|i| (i as f64) * 0.001).collect();
1470        let mut mask = vec![1u8; n * n * n];
1471        let mag = vec![1.0; n * n * n];
1472        let n_std = vec![1.0; n * n * n];
1473        mask[0] = 0;
1474        mask[10] = 0;
1475
1476        let chi = medi_l1(
1477            &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1478            1000.0, (0.0, 0.0, 1.0), false, false, 5.0, 1, 0.9, 0.1, 10, 3, 0.1
1479        );
1480
1481        assert_eq!(chi[0], 0.0, "Masked voxel should be zero");
1482        assert_eq!(chi[10], 0.0, "Masked voxel should be zero");
1483    }
1484
1485    /// Debug test: run MEDI step-by-step on real data, saving intermediates
1486    /// for comparison with Octave reference (other/medi_debug_octave.m).
1487    /// Run with: cargo test --release test_medi_debug -- --ignored --nocapture
1488    #[test]
1489    #[ignore]
1490    fn test_medi_debug() {
1491        let data_path = "/home/ashley/OUT/2bgRemoved.nii";
1492        if !std::path::Path::new(data_path).exists() {
1493            eprintln!("Skipping: {} not found", data_path);
1494            return;
1495        }
1496
1497        let outdir = "/home/ashley/OUT/debug";
1498        std::fs::create_dir_all(outdir).ok();
1499
1500        // Read NIfTI
1501        let bytes = std::fs::read(data_path).unwrap();
1502        let nifti_data = crate::nifti_io::load_nifti(&bytes).unwrap();
1503        let (nx, ny, nz) = nifti_data.dims;
1504        let (vsx, vsy, vsz) = nifti_data.voxel_size;
1505
1506        let n_total = nx * ny * nz;
1507        let vsx_f32 = vsx as f32;
1508        let vsy_f32 = vsy as f32;
1509        let vsz_f32 = vsz as f32;
1510
1511        eprintln!("Data: {}x{}x{}, voxel: {}x{}x{}", nx, ny, nz, vsx, vsy, vsz);
1512
1513        // Convert to f32
1514        let local_field: Vec<f32> = nifti_data.data.iter().map(|&v| v as f32).collect();
1515
1516        // Create mask from non-zero voxels
1517        let mask: Vec<u8> = local_field.iter()
1518            .map(|&v| if v.abs() > 1e-10 { 1 } else { 0 })
1519            .collect();
1520        let mask_count: usize = mask.iter().filter(|&&m| m != 0).count();
1521        eprintln!("Mask voxels: {} / {}", mask_count, n_total);
1522
1523        // Save inputs
1524        save_f32_raw(&local_field, &format!("{}/f_rust.raw", outdir));
1525        let mask_f32: Vec<f32> = mask.iter().map(|&m| m as f32).collect();
1526        save_f32_raw(&mask_f32, &format!("{}/mask_rust.raw", outdir));
1527
1528        // Parameters (matching MATLAB defaults)
1529        let lambda: f32 = 7.5e-5;
1530        let beta: f32 = 1.49e-8;
1531        let bdir = (0.0f32, 0.0f32, 1.0f32);
1532        let cg_tol: f32 = 0.1;  // Match MATLAB tolcg default
1533        let cg_max_iter: usize = 10;
1534
1535        // Dipole kernel
1536        let d_kernel = crate::kernels::dipole::dipole_kernel_f32(
1537            nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, bdir,
1538        );
1539        save_f32_raw(&d_kernel, &format!("{}/D_rust.raw", outdir));
1540        eprintln!("D: min={} max={} D[0]={}", fmin(&d_kernel), fmax(&d_kernel), d_kernel[0]);
1541
1542        // Workspace
1543        let mut ws = MediWorkspace::new(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);
1544
1545        // Data weighting: uniform m = mask (matching w=ones in Octave)
1546        let m: Vec<f32> = mask.iter().map(|&m| if m != 0 { 1.0 } else { 0.0 }).collect();
1547
1548        // b0 = m * exp(i * f)
1549        let b0: Vec<Complex32> = local_field.iter().zip(m.iter())
1550            .map(|(&f, &mi)| {
1551                let phase = Complex32::new(0.0, f);
1552                mi * phase.exp()
1553            })
1554            .collect();
1555
1556        // Gradient mask: uniform magnitude -> mx=my=mz=mask
1557        let w_gx: Vec<f32> = m.clone();
1558        let w_gy: Vec<f32> = m.clone();
1559        let w_gz: Vec<f32> = m.clone();
1560
1561        // ===== ITERATION 1 (chi = 0) =====
1562        eprintln!("\n=== Iteration 1 ===");
1563        let mut chi = vec![0.0f32; n_total];
1564        let mut dx = vec![0.0f32; n_total];
1565        let mut rhs = vec![0.0f32; n_total];
1566        let mut vr = vec![0.0f32; n_total];
1567        let mut w: Vec<Complex32> = vec![Complex32::new(0.0, 0.0); n_total];
1568
1569        // P weights (chi=0 -> gradient=0 -> P = 1/sqrt(beta))
1570        fgrad_periodic_inplace_f32(
1571            &mut ws.gx, &mut ws.gy, &mut ws.gz,
1572            &chi, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32,
1573        );
1574        compute_p_weights_f32(&mut vr, &w_gx, &w_gy, &w_gz, &ws.gx, &ws.gy, &ws.gz, beta);
1575        save_f32_raw(&vr, &format!("{}/P1_rust.raw", outdir));
1576        eprintln!("P1: min={} max={} mean={}", fmin(&vr), fmax(&vr), fmean(&vr));
1577
1578        // w = m * exp(i * D*chi) = m (since chi=0)
1579        apply_dipole_conv(&mut ws.fft_ws, &chi, &d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
1580        for i in 0..n_total {
1581            let phase = Complex32::new(0.0, ws.dipole_buf[i]);
1582            w[i] = m[i] * phase.exp();
1583        }
1584
1585        // RHS
1586        compute_rhs_inplace(
1587            &chi, &w, &b0, &d_kernel,
1588            &w_gx, &w_gy, &w_gz, &vr, lambda,
1589            &mut rhs, &mut ws,
1590        );
1591        save_f32_raw(&rhs, &format!("{}/rhs1_rust.raw", outdir));
1592        eprintln!("RHS1: min={} max={} norm={}", fmin(&rhs), fmax(&rhs), fnorm(&rhs));
1593
1594        // Negate for CG
1595        negate_f32(&mut rhs);
1596
1597        // CG solve (with iteration-level residual logging)
1598        let mut cg_residuals: Vec<f32> = Vec::new();
1599        {
1600            // Manual CG to capture residuals (matching cg_solve_medi but with logging)
1601            let n = ws.n_total;
1602            let (nx, ny, nz) = (ws.nx, ws.ny, ws.nz);
1603            let (vsx, vsy, vsz) = (ws.vsx, ws.vsy, ws.vsz);
1604
1605            dx.fill(0.0);
1606            ws.cg_r.copy_from_slice(&rhs);
1607            ws.cg_p.copy_from_slice(&ws.cg_r);
1608            let mut rsold: f32 = norm_squared_f32(&ws.cg_r);
1609            let b_norm: f32 = norm_squared_f32(&rhs).sqrt();
1610
1611            let mut p_copy = vec![0.0f32; n];
1612            let mut prev_residual = rsold.sqrt();
1613
1614            for cg_iter in 0..cg_max_iter {
1615                let residual_before = rsold.sqrt();
1616                cg_residuals.push(residual_before);
1617
1618                p_copy.copy_from_slice(&ws.cg_p);
1619                {
1620                    let mut bufs = MediOpBuffers {
1621                        gx: &mut ws.gx, gy: &mut ws.gy, gz: &mut ws.gz,
1622                        reg_x: &mut ws.reg_x, reg_y: &mut ws.reg_y, reg_z: &mut ws.reg_z,
1623                        div_buf: &mut ws.div_buf, dipole_buf: &mut ws.dipole_buf,
1624                        complex_buf: &mut ws.complex_buf, complex_buf2: &mut ws.complex_buf2,
1625                    };
1626                    apply_medi_operator_core(
1627                        &mut ws.fft_ws, &mut bufs, n, nx, ny, nz, vsx, vsy, vsz,
1628                        &p_copy, &w, &d_kernel, &w_gx, &w_gy, &w_gz, &vr, lambda, &mut ws.cg_ap,
1629                    );
1630                }
1631
1632                let pap: f32 = dot_product_f32(&ws.cg_p, &ws.cg_ap);
1633                if pap.abs() < 1e-15 { break; }
1634                let alpha = rsold / pap;
1635
1636                axpy_f32(&mut dx, alpha, &ws.cg_p);
1637                axpy_f32(&mut ws.cg_r, -alpha, &ws.cg_ap);
1638
1639                let rsnew: f32 = norm_squared_f32(&ws.cg_r);
1640                let residual = rsnew.sqrt();
1641
1642                eprintln!("  CG iter {}: res={:.6e}, alpha={:.6e}, pap={:.6e}",
1643                    cg_iter + 1, residual, alpha, pap);
1644
1645                if residual < cg_tol * b_norm { break; }
1646
1647                // No stall detection in this debug version (matching MATLAB)
1648
1649                let beta_cg = rsnew / rsold;
1650                xpby_f32(&mut ws.cg_p, &ws.cg_r, beta_cg);
1651                rsold = rsnew;
1652                prev_residual = residual;
1653            }
1654        }
1655        save_f32_raw(&dx, &format!("{}/dx1_rust.raw", outdir));
1656        eprintln!("dx1: min={} max={} norm={}", fmin(&dx), fmax(&dx), fnorm(&dx));
1657
1658        // Update chi
1659        axpy_f32(&mut chi, 1.0, &dx);
1660        save_f32_raw(&chi, &format!("{}/chi1_rust.raw", outdir));
1661        eprintln!("chi1: min={} max={} norm={}", fmin(&chi), fmax(&chi), fnorm(&chi));
1662
1663        // ===== ITERATION 2 =====
1664        eprintln!("\n=== Iteration 2 ===");
1665
1666        // P weights
1667        fgrad_periodic_inplace_f32(
1668            &mut ws.gx, &mut ws.gy, &mut ws.gz,
1669            &chi, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32,
1670        );
1671        compute_p_weights_f32(&mut vr, &w_gx, &w_gy, &w_gz, &ws.gx, &ws.gy, &ws.gz, beta);
1672        save_f32_raw(&vr, &format!("{}/P2_rust.raw", outdir));
1673        eprintln!("P2: min={} max={} mean={}", fmin(&vr), fmax(&vr), fmean(&vr));
1674
1675        // w = m * exp(i * D*chi)
1676        apply_dipole_conv(&mut ws.fft_ws, &chi, &d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
1677        for i in 0..n_total {
1678            let phase = Complex32::new(0.0, ws.dipole_buf[i]);
1679            w[i] = m[i] * phase.exp();
1680        }
1681
1682        // RHS
1683        compute_rhs_inplace(
1684            &chi, &w, &b0, &d_kernel,
1685            &w_gx, &w_gy, &w_gz, &vr, lambda,
1686            &mut rhs, &mut ws,
1687        );
1688        save_f32_raw(&rhs, &format!("{}/rhs2_rust.raw", outdir));
1689        eprintln!("RHS2: min={} max={} norm={}", fmin(&rhs), fmax(&rhs), fnorm(&rhs));
1690
1691        negate_f32(&mut rhs);
1692
1693        // CG solve (no stall detection)
1694        {
1695            let n = ws.n_total;
1696            let (nx, ny, nz) = (ws.nx, ws.ny, ws.nz);
1697            let (vsx, vsy, vsz) = (ws.vsx, ws.vsy, ws.vsz);
1698
1699            dx.fill(0.0);
1700            ws.cg_r.copy_from_slice(&rhs);
1701            ws.cg_p.copy_from_slice(&ws.cg_r);
1702            let mut rsold: f32 = norm_squared_f32(&ws.cg_r);
1703            let b_norm: f32 = norm_squared_f32(&rhs).sqrt();
1704            let mut p_copy = vec![0.0f32; n];
1705
1706            for cg_iter in 0..cg_max_iter {
1707                p_copy.copy_from_slice(&ws.cg_p);
1708                {
1709                    let mut bufs = MediOpBuffers {
1710                        gx: &mut ws.gx, gy: &mut ws.gy, gz: &mut ws.gz,
1711                        reg_x: &mut ws.reg_x, reg_y: &mut ws.reg_y, reg_z: &mut ws.reg_z,
1712                        div_buf: &mut ws.div_buf, dipole_buf: &mut ws.dipole_buf,
1713                        complex_buf: &mut ws.complex_buf, complex_buf2: &mut ws.complex_buf2,
1714                    };
1715                    apply_medi_operator_core(
1716                        &mut ws.fft_ws, &mut bufs, n, nx, ny, nz, vsx, vsy, vsz,
1717                        &p_copy, &w, &d_kernel, &w_gx, &w_gy, &w_gz, &vr, lambda, &mut ws.cg_ap,
1718                    );
1719                }
1720                let pap: f32 = dot_product_f32(&ws.cg_p, &ws.cg_ap);
1721                if pap.abs() < 1e-15 { break; }
1722                let alpha = rsold / pap;
1723                axpy_f32(&mut dx, alpha, &ws.cg_p);
1724                axpy_f32(&mut ws.cg_r, -alpha, &ws.cg_ap);
1725                let rsnew: f32 = norm_squared_f32(&ws.cg_r);
1726                let residual = rsnew.sqrt();
1727                eprintln!("  CG iter {}: res={:.6e}", cg_iter + 1, residual);
1728                if residual < cg_tol * b_norm { break; }
1729                let beta_cg = rsnew / rsold;
1730                xpby_f32(&mut ws.cg_p, &ws.cg_r, beta_cg);
1731                rsold = rsnew;
1732            }
1733        }
1734        save_f32_raw(&dx, &format!("{}/dx2_rust.raw", outdir));
1735        eprintln!("dx2: min={} max={} norm={}", fmin(&dx), fmax(&dx), fnorm(&dx));
1736
1737        axpy_f32(&mut chi, 1.0, &dx);
1738        save_f32_raw(&chi, &format!("{}/chi2_rust.raw", outdir));
1739        eprintln!("chi2: min={} max={} norm={}", fmin(&chi), fmax(&chi), fnorm(&chi));
1740
1741        eprintln!("\nDone. Intermediates saved to {}", outdir);
1742    }
1743
1744    fn save_f32_raw(data: &[f32], path: &str) {
1745        use std::io::Write;
1746        let mut file = std::fs::File::create(path).unwrap();
1747        for &val in data {
1748            file.write_all(&val.to_le_bytes()).unwrap();
1749        }
1750    }
1751
1752    fn fmin(data: &[f32]) -> f32 { data.iter().cloned().fold(f32::MAX, f32::min) }
1753    fn fmax(data: &[f32]) -> f32 { data.iter().cloned().fold(f32::MIN, f32::max) }
1754    fn fmean(data: &[f32]) -> f32 { data.iter().sum::<f32>() / data.len() as f32 }
1755    fn fnorm(data: &[f32]) -> f32 { data.iter().map(|&v| v * v).sum::<f32>().sqrt() }
1756
1757    #[test]
1758    fn test_medi_l1_small() {
1759        // 8x8x8 volume with synthetic local field
1760        let n = 8;
1761        let n_total = n * n * n;
1762
1763        // Create a synthetic local field from a dipole-like source
1764        let mut field = vec![0.0f64; n_total];
1765        let center = n / 2;
1766        for z in 0..n {
1767            for y in 0..n {
1768                for x in 0..n {
1769                    let idx = x + y * n + z * n * n;
1770                    let dx = (x as f64) - (center as f64);
1771                    let dy = (y as f64) - (center as f64);
1772                    let dz = (z as f64) - (center as f64);
1773                    let r2 = dx*dx + dy*dy + dz*dz;
1774                    if r2 > 1.0 {
1775                        // Dipole field pattern: (3*cos^2(theta) - 1) / r^3
1776                        let r = r2.sqrt();
1777                        let cos_theta = dz / r;
1778                        field[idx] = (3.0 * cos_theta * cos_theta - 1.0) / (r * r * r) * 0.01;
1779                    }
1780                }
1781            }
1782        }
1783
1784        let mask = vec![1u8; n_total];
1785        let mag = vec![1.0f64; n_total];
1786        let n_std = vec![1.0f64; n_total];
1787
1788        // Run MEDI with few Gauss-Newton iterations to exercise the main loop
1789        let chi = medi_l1(
1790            &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1791            1e-4,              // lambda
1792            (0.0, 0.0, 1.0),   // bdir
1793            false,             // merit
1794            false,             // smv
1795            5.0,               // smv_radius
1796            1,                 // data_weighting (SNR)
1797            0.3,               // percentage
1798            0.01,              // cg_tol
1799            10,                // cg_max_iter
1800            5,                 // max_iter (enough to exercise the loop)
1801            0.1,               // tol
1802        );
1803
1804        // Result should be finite and same size
1805        assert_eq!(chi.len(), n_total);
1806        for (i, &val) in chi.iter().enumerate() {
1807            assert!(val.is_finite(), "MEDI L1 chi should be finite at index {}", i);
1808        }
1809
1810        // Chi should be non-trivial (not all zero) for a dipole field input
1811        let chi_norm: f64 = chi.iter().map(|&v| v * v).sum::<f64>().sqrt();
1812        assert!(chi_norm > 1e-10, "MEDI L1 should produce non-zero susceptibility for dipole field, got norm={}", chi_norm);
1813    }
1814
1815    #[test]
1816    fn test_medi_weight_types() {
1817        let n = 8;
1818        let n_total = n * n * n;
1819
1820        let field: Vec<f64> = (0..n_total).map(|i| (i as f64) * 0.001).collect();
1821        let mask = vec![1u8; n_total];
1822        let mag = vec![1.0f64; n_total];
1823        let n_std = vec![1.0f64; n_total];
1824
1825        // Test uniform data weighting (mode 0)
1826        let chi_uniform = medi_l1(
1827            &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1828            1000.0, (0.0, 0.0, 1.0), false, false, 5.0,
1829            0,     // uniform data weighting
1830            0.9, 0.1, 10, 3, 0.1
1831        );
1832
1833        // Test SNR data weighting (mode 1) with varying noise
1834        let n_std_varying: Vec<f64> = (0..n_total).map(|i| 0.5 + (i as f64) * 0.01).collect();
1835        let chi_snr = medi_l1(
1836            &field, &n_std_varying, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1837            1000.0, (0.0, 0.0, 1.0), false, false, 5.0,
1838            1,     // SNR data weighting
1839            0.9, 0.1, 10, 3, 0.1
1840        );
1841
1842        // Both should be finite
1843        for (i, &val) in chi_uniform.iter().enumerate() {
1844            assert!(val.is_finite(), "Uniform weighting chi should be finite at {}", i);
1845        }
1846        for (i, &val) in chi_snr.iter().enumerate() {
1847            assert!(val.is_finite(), "SNR weighting chi should be finite at {}", i);
1848        }
1849
1850        // Uniform and SNR weighting should give different results
1851        let diff_norm: f64 = chi_uniform.iter()
1852            .zip(chi_snr.iter())
1853            .map(|(&a, &b)| (a - b).powi(2))
1854            .sum::<f64>()
1855            .sqrt();
1856        // They may or may not differ depending on input, so just check finiteness
1857        assert!(diff_norm.is_finite(), "Difference between weight modes should be finite");
1858    }
1859
1860    #[test]
1861    fn test_medi_l1_with_progress_small() {
1862        let n = 8;
1863        let n_total = n * n * n;
1864
1865        let field: Vec<f64> = (0..n_total).map(|i| (i as f64) * 0.001).collect();
1866        let mask = vec![1u8; n_total];
1867        let mag = vec![1.0f64; n_total];
1868        let n_std = vec![1.0f64; n_total];
1869
1870        let mut progress_calls = 0usize;
1871        let chi = medi_l1_with_progress(
1872            &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1873            1000.0, (0.0, 0.0, 1.0), false, false, 5.0,
1874            1, 0.9, 0.1, 10, 3, 0.1,
1875            |_iter, _max| { progress_calls += 1; },
1876        );
1877
1878        assert_eq!(chi.len(), n_total);
1879        for &val in &chi {
1880            assert!(val.is_finite(), "medi_l1_with_progress output should be finite");
1881        }
1882        assert!(progress_calls > 0, "progress callback should be called at least once");
1883    }
1884
1885    #[test]
1886    fn test_medi_l1_with_merit() {
1887        let n = 8;
1888        let n_total = n * n * n;
1889
1890        let mut field = vec![0.0f64; n_total];
1891        let center = n / 2;
1892        for z in 0..n {
1893            for y in 0..n {
1894                for x in 0..n {
1895                    let idx = x + y * n + z * n * n;
1896                    let dx = (x as f64) - (center as f64);
1897                    let dy = (y as f64) - (center as f64);
1898                    let dz = (z as f64) - (center as f64);
1899                    let r2 = dx * dx + dy * dy + dz * dz;
1900                    if r2 > 1.0 {
1901                        let r = r2.sqrt();
1902                        let cos_theta = dz / r;
1903                        field[idx] = (3.0 * cos_theta * cos_theta - 1.0) / (r * r * r) * 0.01;
1904                    }
1905                }
1906            }
1907        }
1908
1909        let mask = vec![1u8; n_total];
1910        let mag = vec![1.0f64; n_total];
1911        let n_std = vec![1.0f64; n_total];
1912
1913        // Test with merit=true to cover the merit adjustment code path
1914        let chi = medi_l1(
1915            &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1916            1e-4, (0.0, 0.0, 1.0),
1917            true,  // merit enabled
1918            false, 5.0, 1, 0.3, 0.01, 10, 5, 0.1,
1919        );
1920
1921        assert_eq!(chi.len(), n_total);
1922        for &val in &chi {
1923            assert!(val.is_finite(), "MEDI with merit should produce finite results");
1924        }
1925    }
1926
1927    #[test]
1928    fn test_medi_l1_with_smv() {
1929        let n = 8;
1930        let n_total = n * n * n;
1931
1932        let field: Vec<f64> = (0..n_total).map(|i| (i as f64) * 0.001).collect();
1933        let mask = vec![1u8; n_total];
1934        let mag = vec![1.0f64; n_total];
1935        let n_std = vec![1.0f64; n_total];
1936
1937        // Test with smv=true to cover the SMV-weighted code path
1938        let chi = medi_l1(
1939            &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1940            1000.0, (0.0, 0.0, 1.0),
1941            false,
1942            true,  // smv enabled
1943            3.0,   // smv_radius
1944            1, 0.3, 0.01, 10, 3, 0.1,
1945        );
1946
1947        assert_eq!(chi.len(), n_total);
1948        for &val in &chi {
1949            assert!(val.is_finite(), "MEDI with SMV should produce finite results");
1950        }
1951    }
1952}