Skip to main content

qsm_core/inversion/
ilsqr.rs

//! iLSQR: Iterative LSQR for QSM with streaking artifact removal
//!
//! Reference:
//! Li, W., Wang, N., Yu, F., Han, H., Cao, W., Romero, R., Tantiwongkosi, B.,
//! Duong, T.Q., Liu, C. (2015). "A method for estimating and removing streaking
//! artifacts in quantitative susceptibility mapping."
//! NeuroImage, 108:111-122. https://doi.org/10.1016/j.neuroimage.2014.12.043
//!
//! Reference implementation: https://github.com/kamesy/QSM.m
//!
//! The algorithm consists of 4 steps:
//! 1. Initial LSQR solution with Laplacian-based weights
//! 2. FastQSM estimate using sign(D) approximation
//! 3. Streaking artifact estimation using LSMR
//! 4. Artifact subtraction
16
/// Settings for the iLSQR algorithm.
///
/// The [`Default`] implementation yields `tol = 0.01` and `max_iter = 50`.
#[derive(Clone, Debug)]
pub struct IlsqrParams {
    /// Convergence tolerance (default: 0.01)
    pub tol: f64,
    /// Maximum iterations (default: 50)
    pub max_iter: usize,
}

impl Default for IlsqrParams {
    /// Builds the parameter set with the documented default values.
    fn default() -> Self {
        let tol = 0.01;
        let max_iter = 50;
        Self { tol, max_iter }
    }
}
34
35use std::cell::RefCell;
36use num_complex::Complex64;
37use crate::fft::Fft3dWorkspace;
38use crate::kernels::dipole::dipole_kernel;
39use crate::kernels::smv::smv_kernel;
40use crate::utils::gradient::{fgrad, bdiv};
41
42// ============================================================================
43// LSQR Solver
44// ============================================================================
45
46/// LSQR iterative solver for Ax = b
47///
48/// Solves the least squares problem min ||Ax - b||² using the LSQR algorithm.
49/// Based on Paige & Saunders (1982).
50///
51/// # Arguments
52/// * `apply_a` - Function that computes A*x
53/// * `apply_at` - Function that computes A^T*x
54/// * `b` - Right-hand side vector
55/// * `tol` - Convergence tolerance
56/// * `max_iter` - Maximum iterations
57///
58/// # Returns
59/// Solution vector x
60pub fn lsqr<F, G>(
61    apply_a: F,
62    apply_at: G,
63    b: &[f64],
64    tol: f64,
65    max_iter: usize,
66) -> Vec<f64>
67where
68    F: Fn(&[f64]) -> Vec<f64>,
69    G: Fn(&[f64]) -> Vec<f64>,
70{
71    // Initialize
72    let mut u = b.to_vec();
73    let mut beta = norm(&u);
74
75    if beta > 0.0 {
76        scale_inplace(&mut u, 1.0 / beta);
77    }
78
79    let mut v = apply_at(&u);
80    let n = v.len();
81    let mut alpha = norm(&v);
82
83    if alpha > 0.0 {
84        scale_inplace(&mut v, 1.0 / alpha);
85    }
86
87    let mut w = v.clone();
88    let mut x = vec![0.0; n];
89
90    let mut phi_bar = beta;
91    let mut rho_bar = alpha;
92
93    let bnorm = beta;
94
95    for _iter in 0..max_iter {
96        // Bidiagonalization
97        let mut u_new = apply_a(&v);
98        axpy(&mut u_new, -alpha, &u);
99        beta = norm(&u_new);
100
101        if beta > 0.0 {
102            scale_inplace(&mut u_new, 1.0 / beta);
103        }
104        u = u_new;
105
106        let mut v_new = apply_at(&u);
107        axpy(&mut v_new, -beta, &v);
108        alpha = norm(&v_new);
109
110        if alpha > 0.0 {
111            scale_inplace(&mut v_new, 1.0 / alpha);
112        }
113        v = v_new;
114
115        // Construct and apply rotation
116        let rho = (rho_bar * rho_bar + beta * beta).sqrt();
117        let c = rho_bar / rho;
118        let s = beta / rho;
119        let theta = s * alpha;
120        rho_bar = -c * alpha;
121        let phi = c * phi_bar;
122        phi_bar = s * phi_bar;
123
124        // Update x and w
125        let t1 = phi / rho;
126        let t2 = -theta / rho;
127
128        for i in 0..n {
129            x[i] += t1 * w[i];
130            w[i] = v[i] + t2 * w[i];
131        }
132
133        // Check convergence
134        let rel_residual = phi_bar / (bnorm + 1e-20);
135
136        if rel_residual < tol {
137            break;
138        }
139    }
140
141    x
142}
143
144// ============================================================================
145// LSQR Solver (Complex)
146// ============================================================================
147
148/// Complex norm
149fn norm_complex(x: &[Complex64]) -> f64 {
150    x.iter().map(|c| c.norm_sqr()).sum::<f64>().sqrt()
151}
152
153/// Complex scale in place
154fn scale_complex_inplace(x: &mut [Complex64], s: f64) {
155    for v in x.iter_mut() {
156        *v *= s;
157    }
158}
159
160/// Complex axpy: y += a * x
161fn axpy_complex(y: &mut [Complex64], a: f64, x: &[Complex64]) {
162    for (yi, xi) in y.iter_mut().zip(x.iter()) {
163        *yi += a * xi;
164    }
165}
166
167/// LSQR iterative solver for Ax = b (complex version)
168///
169/// Solves the least squares problem min ||Ax - b||² using the LSQR algorithm.
170/// Based on Paige & Saunders (1982), with convergence tests matching MATLAB's lsqr.
171///
172/// Convergence tests (matching MATLAB):
173/// 1. ||r|| / ||b|| <= btol + atol * ||A|| * ||x|| / ||b||  (residual test)
174/// 2. ||A'r|| / (||A|| * ||r||) <= atol  (normal equations test)
175pub fn lsqr_complex<F, G>(
176    apply_a: F,
177    apply_ah: G,
178    b: &[Complex64],
179    tol: f64,
180    max_iter: usize,
181    verbose: bool,
182) -> Vec<Complex64>
183where
184    F: Fn(&[Complex64]) -> Vec<Complex64>,
185    G: Fn(&[Complex64]) -> Vec<Complex64>,
186{
187    // Initialize: beta_1 * u_1 = b
188    let mut u = b.to_vec();
189    let mut beta = norm_complex(&u);
190
191    if beta > 0.0 {
192        scale_complex_inplace(&mut u, 1.0 / beta);
193    }
194
195    // alpha_1 * v_1 = A^H * u_1
196    let mut v = apply_ah(&u);
197    let n = v.len();
198    let mut alpha = norm_complex(&v);
199
200    if alpha > 0.0 {
201        scale_complex_inplace(&mut v, 1.0 / alpha);
202    }
203
204    let mut w = v.clone();
205    let mut x = vec![Complex64::new(0.0, 0.0); n];
206
207    let mut phi_bar = beta;
208    let mut rho_bar = alpha;
209
210    let bnorm = beta;
211    let atol = tol;
212    let btol = tol;
213
214    // Track ||A|| estimate
215    let mut norm_a2 = alpha * alpha;
216
217    // ||x|| estimate using plane rotations (matches MATLAB's built-in lsqr xxnorm)
218    // Verified: produces identical values to exact norm_complex(&x)
219    let mut xxnorm = 0.0;
220    let mut z_sol = 0.0;
221    let mut cs2 = -1.0;
222    let mut sn2 = 0.0;
223
224    if alpha * beta == 0.0 {
225        return x;
226    }
227
228    for _iter in 0..max_iter {
229        // Bidiagonalization step
230        let mut u_new = apply_a(&v);
231        axpy_complex(&mut u_new, -alpha, &u);
232        beta = norm_complex(&u_new);
233
234        if beta > 0.0 {
235            scale_complex_inplace(&mut u_new, 1.0 / beta);
236        }
237        u = u_new;
238
239        let mut v_new = apply_ah(&u);
240        axpy_complex(&mut v_new, -beta, &v);
241        alpha = norm_complex(&v_new);
242
243        if alpha > 0.0 {
244            scale_complex_inplace(&mut v_new, 1.0 / alpha);
245        }
246        v = v_new;
247
248        // Construct and apply Givens rotation
249        let rho = (rho_bar * rho_bar + beta * beta).sqrt();
250        let c = rho_bar / rho;
251        let s = beta / rho;
252        let theta = s * alpha;
253        rho_bar = -c * alpha;
254        let phi = c * phi_bar;
255        phi_bar = s * phi_bar;
256
257        // ||x|| estimation via plane rotations (MATLAB's xxnorm approach)
258        let delta = sn2 * rho;
259        let gambar = -cs2 * rho;
260        let rhs = phi - delta * z_sol;
261        let zbar = rhs / gambar;
262        let xnorm = (xxnorm + zbar * zbar).sqrt();
263        let gamma = (gambar * gambar + theta * theta).sqrt();
264        cs2 = gambar / gamma;
265        sn2 = theta / gamma;
266        z_sol = rhs / gamma;
267        xxnorm += z_sol * z_sol;
268
269        // Update x and w
270        let t1 = phi / rho;
271        let t2 = -theta / rho;
272        for i in 0..n {
273            x[i] += t1 * w[i];
274            w[i] = v[i] + t2 * w[i];
275        }
276
277        // Estimate norms for convergence tests
278        let normr = phi_bar;
279        let norm_ar = alpha * (c * phi_bar).abs();
280
281        norm_a2 += beta * beta + alpha * alpha;
282        let norm_a = norm_a2.sqrt();
283
284        // Convergence tests (matching MATLAB's lsqr)
285        let test1 = normr / (bnorm + 1e-20);
286        let test2 = norm_ar / ((norm_a * normr) + 1e-20);
287        let rtol = btol + atol * norm_a * xnorm / (bnorm + 1e-20);
288
289        if verbose {
290            eprintln!("  LSQR iter {:>3}: ||r||/||b||={:.6e}  ||A'r||/(||A||·||r||)={:.6e}  rtol={:.6e}",
291                _iter + 1, test1, test2, rtol);
292        }
293
294        if test2 <= atol || test1 <= rtol {
295            if verbose {
296                eprintln!("  LSQR converged at iteration {} (test1={:.4e}, test2={:.4e})",
297                    _iter + 1, test1, test2);
298            }
299            break;
300        }
301    }
302
303    x
304}
305
306// ============================================================================
307// LSMR Solver
308// ============================================================================
309
310/// LSMR iterative solver for Ax = b
311///
312/// Solves the least squares problem min ||Ax - b||² using the LSMR algorithm.
313/// Based on Fong & Saunders (2011). More stable than LSQR for ill-conditioned problems.
314///
315/// # Arguments
316/// * `apply_a` - Function that computes A*x
317/// * `apply_at` - Function that computes A^T*x
318/// * `b` - Right-hand side vector
319/// * `n` - Size of solution vector
320/// * `atol` - Absolute tolerance
321/// * `btol` - Relative tolerance
322/// * `max_iter` - Maximum iterations
323/// * `verbose` - Print progress
324///
325/// # Returns
326/// Solution vector x
327pub fn lsmr<F, G>(
328    apply_a: F,
329    apply_at: G,
330    b: &[f64],
331    n: usize,
332    atol: f64,
333    btol: f64,
334    max_iter: usize,
335    _verbose: bool,
336) -> Vec<f64>
337where
338    F: Fn(&[f64]) -> Vec<f64>,
339    G: Fn(&[f64]) -> Vec<f64>,
340{
341    // Reference: Fong & Saunders (2011), "LSMR: An iterative algorithm for
342    // sparse least-squares problems", SIAM J. Sci. Comput.
343    // Based on the official MATLAB implementation by Fong & Saunders.
344
345    // Initialize: beta*u = b, alpha*v = A'*u
346    let mut u = b.to_vec();
347    let mut beta = norm(&u);
348
349    if beta > 0.0 {
350        scale_inplace(&mut u, 1.0 / beta);
351    }
352
353    let mut v = apply_at(&u);
354    let mut alpha = norm(&v);
355
356    if alpha > 0.0 {
357        scale_inplace(&mut v, 1.0 / alpha);
358    }
359
360    // Initialize variables (matching MATLAB reference variable names)
361    let mut alpha_bar = alpha;
362    let mut zeta_bar = alpha * beta;
363    let mut rho = 1.0;
364    let mut rho_bar = 1.0;
365    let mut c_bar = 1.0;
366    let mut s_bar = 0.0;
367
368    let mut h = v.clone();
369    let mut h_bar = vec![0.0; n];
370    let mut x = vec![0.0; n];
371
372    // Variables for ||r|| estimation
373    let normb = beta;
374    let mut betadd = beta;
375    let mut betad = 0.0;
376    let mut rhodold = 1.0;
377    let mut tautildeold = 0.0;
378    let mut thetatilde = 0.0;
379    let mut zeta = 0.0;
380    let d = 0.0;
381
382    // Variables for ||A|| and cond(A) estimation
383    let mut norm_a2 = alpha * alpha;
384    let mut maxrbar = 0.0f64;
385    let mut minrbar = 1e100f64;
386    let conlim = 1e8;
387    let ctol = if conlim > 0.0 { 1.0 / conlim } else { 0.0 };
388
389    // Early exit if A'b = 0
390    if alpha * beta == 0.0 {
391        return x;
392    }
393
394    for _iter in 0..max_iter {
395        // Bidiagonalization
396        let mut u_new = apply_a(&v);
397        axpy(&mut u_new, -alpha, &u);
398        beta = norm(&u_new);
399
400        if beta > 0.0 {
401            scale_inplace(&mut u_new, 1.0 / beta);
402        }
403        u = u_new;
404
405        let mut v_new = apply_at(&u);
406        axpy(&mut v_new, -beta, &v);
407        alpha = norm(&v_new);
408
409        if alpha > 0.0 {
410            scale_inplace(&mut v_new, 1.0 / alpha);
411        }
412        v = v_new;
413
414        // Construct rotation Q_i (undamped: alphahat = alphabar)
415        let rho_old = rho;
416        rho = (alpha_bar * alpha_bar + beta * beta).sqrt();
417        let c = alpha_bar / rho;
418        let s = beta / rho;
419        let theta_new = s * alpha;
420        alpha_bar = c * alpha;
421
422        // Construct rotation Qbar_i
423        let rho_bar_old = rho_bar;
424        let zeta_old = zeta;
425        let theta_bar = s_bar * rho;
426        let rho_temp = c_bar * rho;
427        rho_bar = (rho_temp * rho_temp + theta_new * theta_new).sqrt();
428        c_bar = rho_temp / rho_bar;
429        s_bar = theta_new / rho_bar;
430        zeta = c_bar * zeta_bar;
431        zeta_bar = -s_bar * zeta_bar;
432
433        // Update h_bar, x, h
434        for i in 0..n {
435            h_bar[i] = h[i] - (theta_bar * rho / (rho_old * rho_bar_old)) * h_bar[i];
436            x[i] += (zeta / (rho * rho_bar)) * h_bar[i];
437            h[i] = v[i] - (theta_new / rho) * h[i];
438        }
439
440        // Estimate ||r|| (from reference implementation)
441        // For undamped case: chat=1, shat=0, so betaacute=betadd, betacheck=0
442        let betaacute = betadd;      // chat * betadd (chat=1 for undamped)
443        // betacheck = 0 for undamped (shat=0), so d += 0
444        let betahat = c * betaacute;
445        betadd = -s * betaacute;
446
447        let thetatildeold = thetatilde;
448        let rhotildeold = (rhodold * rhodold + theta_bar * theta_bar).sqrt();
449        let ctildeold = rhodold / rhotildeold;
450        let stildeold = theta_bar / rhotildeold;
451        thetatilde = stildeold * rho_bar;
452        rhodold = ctildeold * rho_bar;
453        betad = -stildeold * betad + ctildeold * betahat;
454
455        tautildeold = (zeta_old - thetatildeold * tautildeold) / rhotildeold;
456        let taud = (zeta - thetatilde * tautildeold) / rhodold;
457        // d += betacheck^2 = 0 for undamped case
458        let normr = (d + (betad - taud).powi(2) + betadd * betadd).sqrt();
459
460        // Estimate ||A||
461        norm_a2 += beta * beta;
462        let norm_a = norm_a2.sqrt();
463        norm_a2 += alpha * alpha;
464
465        // Estimate cond(A) (matching MATLAB reference)
466        maxrbar = maxrbar.max(rho_bar_old);
467        if _iter > 0 {
468            minrbar = minrbar.min(rho_bar_old);
469        }
470        let cond_a = maxrbar.max(rho_temp) / minrbar.min(rho_temp);
471
472        // Convergence tests (matching reference implementation)
473        let norm_ar = zeta_bar.abs();
474        let normx = norm(&x);
475
476        let test1 = normr / (normb + 1e-20);
477        let test2 = norm_ar / ((norm_a * normr) + 1e-20);
478        let test3 = 1.0 / (cond_a + 1e-20);
479        let rtol = btol + atol * norm_a * normx / (normb + 1e-20);
480
481        if _verbose {
482            eprintln!("  LSMR iter {:>3}: ||r||/||b||={:.6e}  ||A'r||/(||A||·||r||)={:.6e}  1/condA={:.6e}  rtol={:.6e}",
483                _iter + 1, test1, test2, test3, rtol);
484        }
485
486        // Test3 (condition number) checked first, then test2, then test1
487        // matching MATLAB priority where later tests override earlier istop
488        if test3 <= ctol || test2 <= atol || test1 <= rtol {
489            if _verbose {
490                let reason = if test1 <= rtol { "test1 (residual)"
491                } else if test2 <= atol { "test2 (||A'r||)"
492                } else { "test3 (cond(A))" };
493                eprintln!("  LSMR converged at iteration {} via {}", _iter + 1, reason);
494            }
495            break;
496        }
497    }
498
499    x
500}
501
502// ============================================================================
503// Weight Functions
504// ============================================================================
505
506/// Compute Laplacian of a 3D field using mask-adaptive finite differences
507///
508/// Matches MATLAB's lap1_mex.c: uses central differences where both neighbors
509/// are in the mask, forward/backward one-sided stencils near mask boundaries,
510/// and zero contribution where neither neighbor is in the mask.
511fn compute_laplacian(
512    f: &[f64],
513    mask: &[u8],
514    nx: usize, ny: usize, nz: usize,
515    vsx: f64, vsy: f64, vsz: f64,
516) -> Vec<f64> {
517    let n_total = nx * ny * nz;
518    let mut lap = vec![0.0; n_total];
519
520    let hx = 1.0 / (vsx * vsx);
521    let hy = 1.0 / (vsy * vsy);
522    let hz = 1.0 / (vsz * vsz);
523
524    let nxny = nx * ny;
525
526    for k in 0..nz {
527        for j in 0..ny {
528            let jk_offset = j * nx + k * nxny;
529            for i in 0..nx {
530                let l = i + jk_offset;
531
532                if mask[l] == 0 {
533                    continue;
534                }
535
536                // X-axis contribution
537                lap[l] += hx * lap1_axis(f, mask, l, 1, nx, i, nx);
538
539                // Y-axis contribution
540                lap[l] += hy * lap1_axis(f, mask, l, nx, nxny, j * nx, nxny);
541
542                // Z-axis contribution
543                lap[l] += hz * lap1_axis(f, mask, l, nxny, n_total, k * nxny, n_total);
544            }
545        }
546    }
547
548    lap
549}
550
551/// Compute second derivative along one axis using mask-adaptive stencil.
552///
553/// Matches MATLAB's lap1_mex.c logic:
554/// - `idx = 2*G[l+a] + G[l-a]` selects the stencil type:
555///   3 = central, 2 = forward, 1 = backward, 0 = zero
556/// - At domain boundaries: i=0 → forward, i=N-1 → backward
557///
558/// # Arguments
559/// * `f` - field values
560/// * `mask` - binary mask
561/// * `l` - linear index of current voxel
562/// * `a` - stride for this axis (1 for x, nx for y, nx*ny for z)
563/// * `n_axis` - total extent for this axis (nx for x, nx*ny for y, nx*ny*nz for z)
564/// * `coord` - axis coordinate as linear offset (i for x, j*nx for y, k*nx*ny for z)
565/// * `n_total` - total number of voxels (only used for z boundary detection)
566#[inline]
567fn lap1_axis(
568    f: &[f64],
569    mask: &[u8],
570    l: usize,
571    a: usize,
572    n_axis: usize,
573    coord: usize,
574    n_total: usize,
575) -> f64 {
576    // Determine the stencil type based on mask of neighbors and boundary
577    // MATLAB: (i-1) < NXX ? 2*G[l+a]+G[l-a] : (i==0)*2 + (i==NX)
578    // where NXX = N_axis_size - 2 (using size_t underflow trick for boundary detection)
579    let n_end = n_axis - a; // corresponds to NX, NY, NZ in MATLAB (last element coord)
580    let n_interior = n_axis - 2 * a; // corresponds to NXX, NYY, NZZ
581
582    let stencil = if coord.wrapping_sub(a) < n_interior {
583        // Interior: check mask neighbors
584        2 * (mask[l + a] as u8) + (mask[l - a] as u8)
585    } else {
586        // Boundary: first → forward(2), last → backward(1)
587        if coord == 0 { 2 } else if coord == n_end { 1 } else { 0 }
588    };
589
590    match stencil {
591        3 => {
592            // Central: u[l-a] - 2u[l] + u[l+a]
593            f[l - a] - 2.0 * f[l] + f[l + a]
594        }
595        2 => {
596            // Forward one-sided
597            lap1_forward(f, mask, l, a, n_axis, coord, n_total)
598        }
599        1 => {
600            // Backward one-sided
601            lap1_backward(f, mask, l, a, n_axis, coord, n_total)
602        }
603        _ => 0.0, // Neither neighbor in mask
604    }
605}
606
/// Forward one-sided second derivative (matching MATLAB's fd/ff functions)
///
/// Picks the widest stencil whose samples lie on the axis and in the mask:
/// four-point, then three-point, then the plain forward difference. The
/// sample at `l + a` is read without a mask or bounds check.
#[inline]
fn lap1_forward(
    f: &[f64],
    mask: &[u8],
    l: usize,
    a: usize,
    n_axis: usize,
    coord: usize,
    _n_total: usize,
) -> f64 {
    // True when the sample `steps` strides ahead is still on this axis
    let on_axis = |steps: usize| coord + steps * a < n_axis;
    let two_ahead = l + 2 * a;
    let three_ahead = l + 3 * a;

    // Four-point stencil: 2u - 5u[+a] + 4u[+2a] - u[+3a]
    if on_axis(3) && mask[two_ahead] != 0 && mask[three_ahead] != 0 {
        return 2.0 * f[l] - 5.0 * f[l + a] + 4.0 * f[two_ahead] - f[three_ahead];
    }

    // Three-point stencil: u - 2u[+a] + u[+2a]
    if on_axis(2) && mask[two_ahead] != 0 {
        return f[l] - 2.0 * f[l + a] + f[two_ahead];
    }

    // Fallback: u[+a] - u
    f[l + a] - f[l]
}
631
/// Backward one-sided second derivative (matching MATLAB's bd/bf functions)
///
/// Mirror image of the forward variant: picks the widest stencil whose samples
/// lie on the axis and in the mask. The sample at `l - a` is read without a
/// mask or bounds check.
#[inline]
fn lap1_backward(
    f: &[f64],
    mask: &[u8],
    l: usize,
    a: usize,
    n_axis: usize,
    coord: usize,
    _n_total: usize,
) -> f64 {
    // True when the sample `steps` strides behind is still on this axis; a
    // coordinate that would go negative wraps around and fails the comparison.
    let on_axis = |steps: usize| coord.wrapping_sub(steps * a) < n_axis;

    // Four-point stencil: -u[-3a] + 4u[-2a] - 5u[-a] + 2u
    if on_axis(3) && mask[l - 3 * a] != 0 && mask[l - 2 * a] != 0 {
        return -f[l - 3 * a] + 4.0 * f[l - 2 * a] - 5.0 * f[l - a] + 2.0 * f[l];
    }

    // Three-point stencil: u[-2a] - 2u[-a] + u
    if on_axis(2) && mask[l - 2 * a] != 0 {
        return f[l - 2 * a] - 2.0 * f[l - a] + f[l];
    }

    // Fallback: u[-a] - u
    f[l - a] - f[l]
}
656
657/// Laplacian weights for iLSQR (Equation 7)
658///
659/// Weights based on Laplacian magnitude with percentile-based thresholding.
660fn laplacian_weights_ilsqr(
661    f: &[f64],
662    mask: &[u8],
663    nx: usize, ny: usize, nz: usize,
664    vsx: f64, vsy: f64, vsz: f64,
665    pmin: f64,
666    pmax: f64,
667) -> Vec<f64> {
668    let n_total = nx * ny * nz;
669    let mut w = vec![0.0; n_total];
670
671    // Compute Laplacian
672    let lap = compute_laplacian(f, mask, nx, ny, nz, vsx, vsy, vsz);
673
674    // Collect masked Laplacian values for percentile calculation
675    let mut masked_lap: Vec<f64> = lap.iter()
676        .zip(mask.iter())
677        .filter(|(_, &m)| m > 0)
678        .map(|(&l, _)| l)
679        .collect();
680
681    if masked_lap.is_empty() {
682        return w;
683    }
684
685    // Sort for percentile calculation (MATLAB: prctile with linear interpolation)
686    masked_lap.sort_by(|a, b| a.partial_cmp(b).unwrap());
687
688    let thr_min = prctile(&masked_lap, pmin);
689    let thr_max = prctile(&masked_lap, pmax);
690
691    let range = thr_max - thr_min;
692
693    // Apply weights (Equation 7)
694    for i in 0..n_total {
695        if mask[i] == 0 {
696            continue;
697        }
698
699        let l = lap[i];
700
701        if l < thr_min {
702            w[i] = 1.0;
703        } else if l > thr_max {
704            w[i] = 0.0;
705        } else if range > 1e-10 {
706            w[i] = (thr_max - l) / range;
707        }
708    }
709
710    w
711}
712
713/// K-space weights for FastQSM (Equation 10)
714///
715/// Weights based on |D|^n with percentile normalization.
716fn dipole_kspace_weights_ilsqr(
717    d: &[f64],
718    n_exp: f64,
719    pa: f64,
720    pb: f64,
721) -> Vec<f64> {
722    let len = d.len();
723    let mut w = vec![0.0; len];
724
725    // Compute |D|^n
726    for i in 0..len {
727        w[i] = d[i].abs().powf(n_exp);
728    }
729
730    // Percentile on ALL values (matching MATLAB: prctile(vec(w), [pa, pb]))
731    let mut vals: Vec<f64> = w.to_vec();
732    vals.sort_by(|a, b| a.partial_cmp(b).unwrap());
733
734    if vals.is_empty() {
735        return vec![0.0; len];
736    }
737
738    let ab_min = prctile(&vals, pa);
739    let ab_max = prctile(&vals, pb);
740
741    let range = ab_max - ab_min;
742
743    // Normalize to [0, 1]
744    for i in 0..len {
745        if range > 1e-20 {
746            w[i] = (w[i] - ab_min) / range;
747        }
748        w[i] = w[i].max(0.0).min(1.0);
749    }
750
751    w
752}
753
/// Mask-adaptive forward gradient (matching MATLAB's gradfm_mex)
///
/// Per axis and per masked voxel: forward difference when the next voxel
/// exists and is in the mask, otherwise backward difference when the previous
/// voxel exists and is in the mask, otherwise 0. Differences are divided by
/// the voxel size of the axis. Voxels outside the mask keep a zero gradient.
///
/// Returns the x, y and z components, each with `nx * ny * nz` entries in
/// x-fastest order.
fn fgrad_masked(
    f: &[f64],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let plane = nx * ny;
    let n_total = plane * nz;

    let inv_x = 1.0 / vsx;
    let inv_y = 1.0 / vsy;
    let inv_z = 1.0 / vsz;

    // One-sided difference along the axis with the given stride. The
    // `has_*` flags say whether a neighbour exists inside the volume; the
    // mask is only consulted for neighbours that exist.
    let one_sided = |l: usize, stride: usize, has_next: bool, has_prev: bool, inv_h: f64| -> f64 {
        if has_next && mask[l + stride] != 0 {
            inv_h * (f[l + stride] - f[l])
        } else if has_prev && mask[l - stride] != 0 {
            inv_h * (f[l] - f[l - stride])
        } else {
            0.0
        }
    };

    let mut dx = vec![0.0; n_total];
    let mut dy = vec![0.0; n_total];
    let mut dz = vec![0.0; n_total];

    for l in 0..n_total {
        if mask[l] == 0 {
            continue;
        }

        // Grid coordinates of the linear index
        let i = l % nx;
        let j = (l / nx) % ny;
        let k = l / plane;

        dx[l] = one_sided(l, 1, i + 1 < nx, i > 0, inv_x);
        dy[l] = one_sided(l, nx, j + 1 < ny, j > 0, inv_y);
        dz[l] = one_sided(l, plane, k + 1 < nz, k > 0, inv_z);
    }

    (dx, dy, dz)
}
815
816/// Gradient weights for streaking artifact estimation (Equation 15)
817fn gradient_weights_ilsqr(
818    x: &[f64],
819    mask: &[u8],
820    nx: usize, ny: usize, nz: usize,
821    vsx: f64, vsy: f64, vsz: f64,
822    pmin: f64,
823    pmax: f64,
824) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
825    // MATLAB uses gradf(x, mask, vsz) — mask-adaptive forward differences
826    let (gx, gy, gz) = fgrad_masked(x, mask, nx, ny, nz, vsx, vsy, vsz);
827
828    // Apply percentile-based weights to each component
829    let wx = gradient_weights_component(&gx, mask, pmin, pmax);
830    let wy = gradient_weights_component(&gy, mask, pmin, pmax);
831    let wz = gradient_weights_component(&gz, mask, pmin, pmax);
832
833    (wx, wy, wz)
834}
835
836fn gradient_weights_component(
837    g: &[f64],
838    mask: &[u8],
839    pmin: f64,
840    pmax: f64,
841) -> Vec<f64> {
842    let len = g.len();
843    let mut w = vec![0.0; len];
844
845    // Collect masked gradient values
846    let mut masked_g: Vec<f64> = g.iter()
847        .zip(mask.iter())
848        .filter(|(_, &m)| m > 0)
849        .map(|(&v, _)| v)
850        .collect();
851
852    if masked_g.is_empty() {
853        return w;
854    }
855
856    masked_g.sort_by(|a, b| a.partial_cmp(b).unwrap());
857
858    let thr_min = prctile(&masked_g, pmin);
859    let thr_max = prctile(&masked_g, pmax);
860
861    let range = thr_max - thr_min;
862
863    for i in 0..len {
864        if mask[i] == 0 {
865            continue;
866        }
867
868        let v = g[i];
869
870        if v < thr_min {
871            w[i] = 1.0;
872        } else if v > thr_max {
873            w[i] = 0.0;
874        } else if range > 1e-10 {
875            w[i] = (thr_max - v) / range;
876        }
877
878        // Apply mask
879        w[i] *= mask[i] as f64;
880    }
881
882    w
883}
884
885// ============================================================================
886// Helper Functions
887// ============================================================================
888
/// Euclidean norm of a real vector
fn norm(x: &[f64]) -> f64 {
    let sum_sq: f64 = x.iter().map(|v| v * v).sum();
    sum_sq.sqrt()
}
892
/// Multiply every element of `x` by `s`, in place
fn scale_inplace(x: &mut [f64], s: f64) {
    x.iter_mut().for_each(|elem| *elem *= s);
}
898
/// Real axpy: y += a * x
///
/// Only the overlapping prefix is updated when the slices differ in length.
fn axpy(y: &mut [f64], a: f64, x: &[f64]) {
    for (dst, src) in y.iter_mut().zip(x) {
        *dst += a * *src;
    }
}
904
/// Element-wise product of two slices
///
/// The result has the length of the shorter input.
fn multiply_elementwise(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut product = Vec::with_capacity(a.len().min(b.len()));
    for (lhs, rhs) in a.iter().zip(b) {
        product.push(lhs * rhs);
    }
    product
}
908
/// Element-wise sign: 1 for positive, -1 for negative, 0 otherwise
///
/// "Otherwise" covers both zeros and NaN.
fn sign_array(x: &[f64]) -> Vec<f64> {
    let mut signs = Vec::with_capacity(x.len());
    for value in x {
        let sign = match value.partial_cmp(&0.0) {
            Some(std::cmp::Ordering::Greater) => 1.0,
            Some(std::cmp::Ordering::Less) => -1.0,
            _ => 0.0,
        };
        signs.push(sign);
    }
    signs
}
916
/// Percentile of a sorted slice with linear interpolation
///
/// Input must be sorted in ascending order. The p-th percentile is read at the
/// fractional position `p / 100 * (n - 1)` and interpolated linearly between
/// the two neighbouring samples. `p` is clamped to [0, 100], so out-of-range
/// requests return the minimum or maximum instead of extrapolating or
/// indexing past the end. A NaN `p` yields NaN.
///
/// Returns 0.0 for an empty slice and the single element for a slice of one.
///
/// NOTE(review): intended to mirror MATLAB's prctile, which places sample i at
/// 100*(i-0.5)/n rather than at 100*(i-1)/(n-1). Results can therefore differ
/// slightly for small n; confirm whether exact parity is required.
fn prctile(sorted: &[f64], p: f64) -> f64 {
    let n = sorted.len();
    match n {
        0 => return 0.0,
        1 => return sorted[0],
        _ => {}
    }

    // Keep the fractional position inside [0, n - 1]
    let p = p.clamp(0.0, 100.0);
    let h = (p / 100.0) * (n - 1) as f64;
    let lo = h.floor() as usize;
    let hi = (lo + 1).min(n - 1);
    let frac = h - lo as f64;
    sorted[lo] + frac * (sorted[hi] - sorted[lo])
}
930
931// ============================================================================
932// Step 1: Initial LSQR Solution
933// ============================================================================
934
935/// Step 1: Initial LSQR solution with Laplacian weights
936fn lsqr_step(
937    f: &[f64],
938    mask: &[u8],
939    d: &[f64],
940    nx: usize, ny: usize, nz: usize,
941    vsx: f64, vsy: f64, vsz: f64,
942    workspace: &mut Fft3dWorkspace,
943) -> Vec<f64> {
944
945    // Laplacian weight parameters (from QSM.m)
946    let pmin = 60.0;
947    let pmax = 99.9;
948    let tol_lsqr = 0.01;
949    let maxit_lsqr = 50;
950
951    // Compute Laplacian weights (Equation 7)
952    let w = laplacian_weights_ilsqr(f, mask, nx, ny, nz, vsx, vsy, vsz, pmin, pmax);
953
954    // Compute b = D * FFT(w .* f) - b is COMPLEX
955    let wf: Vec<Complex64> = w.iter().zip(f.iter())
956        .map(|(&wi, &fi)| Complex64::new(wi * fi, 0.0))
957        .collect();
958
959    let mut wf_fft = wf.clone();
960    workspace.fft3d(&mut wf_fft);
961
962    // b = D .* FFT(w .* f) - keep as complex!
963    let b: Vec<Complex64> = wf_fft.iter().zip(d.iter())
964        .map(|(wfi, &di)| wfi * di)
965        .collect();
966
967    // Define A*x operator: D * FFT(w .* real(IFFT(D .* x)))
968    // Works with complex vectors throughout
969    // Reuse a single workspace across all LSQR iterations to avoid repeated allocation
970    let lsqr_ws = RefCell::new(Fft3dWorkspace::new(nx, ny, nz));
971    let apply_a = |x: &[Complex64]| -> Vec<Complex64> {
972        // D .* x (in k-space) - x is complex, D is real
973        let dx: Vec<Complex64> = x.iter().zip(d.iter())
974            .map(|(xi, &di)| xi * di)
975            .collect();
976
977        // IFFT(D .* x)
978        let mut dx_ifft = dx.clone();
979        let mut ws = lsqr_ws.borrow_mut();
980        ws.ifft3d(&mut dx_ifft);
981
982        // w .* real(IFFT(D .* x)) - take real part here as per MATLAB reference
983        let wdx: Vec<Complex64> = w.iter().zip(dx_ifft.iter())
984            .map(|(&wi, dxi)| Complex64::new(wi * dxi.re, 0.0))
985            .collect();
986
987        // FFT(w .* ...)
988        let mut wdx_fft = wdx.clone();
989        ws.fft3d(&mut wdx_fft);
990
991        // D .* FFT(...)
992        wdx_fft.iter().zip(d.iter())
993            .map(|(wdxi, &di)| wdxi * di)
994            .collect()
995    };
996
997    // A^H is same as A for this Hermitian operator (D is real, w is real)
998    let apply_ah = |x: &[Complex64]| -> Vec<Complex64> {
999        apply_a(x)
1000    };
1001
1002    // Solve with complex LSQR
1003    let x_lsqr = lsqr_complex(apply_a, apply_ah, &b, tol_lsqr, maxit_lsqr, false);
1004
1005    // IFFT to get result in image space
1006    let mut x_ifft = x_lsqr;
1007    workspace.ifft3d(&mut x_ifft);
1008
1009    // Apply mask and take real part
1010    x_ifft.iter().zip(mask.iter())
1011        .map(|(xi, &mi)| if mi > 0 { xi.re } else { 0.0 })
1012        .collect()
1013}
1014
1015// ============================================================================
1016// Step 2: FastQSM
1017// ============================================================================
1018
1019/// Step 2: FastQSM estimate
1020fn fastqsm_step(
1021    f: &[f64],
1022    mask: &[u8],
1023    d: &[f64],
1024    nx: usize, ny: usize, nz: usize,
1025    vsx: f64, vsy: f64, vsz: f64,
1026    workspace: &mut Fft3dWorkspace,
1027) -> Vec<f64> {
1028    let n_total = nx * ny * nz;
1029
1030    // FFT of field
1031    let f_complex: Vec<Complex64> = f.iter()
1032        .map(|&v| Complex64::new(v, 0.0))
1033        .collect();
1034
1035    let mut f_fft = f_complex;
1036    workspace.fft3d(&mut f_fft);
1037
1038    // Equation (8): x = sign(D) .* F
1039    let sign_d = sign_array(d);
1040    let x: Vec<Complex64> = f_fft.iter().zip(sign_d.iter())
1041        .map(|(fi, &si)| fi * si)
1042        .collect();
1043
1044    // K-space weights (Equation 10)
1045    let pa = 1.0;
1046    let pb = 30.0;
1047    let n_exp = 0.001;
1048    let wfs = dipole_kspace_weights_ilsqr(d, n_exp, pa, pb);
1049
1050    // SMV kernel for smoothing (Equation 9)
1051    let r_smv = 3.0;
1052    let h = smv_kernel(nx, ny, nz, vsx, vsy, vsz, r_smv);
1053
1054    // FFT of SMV kernel — take real part to match MATLAB: real(fft3(ifftshift(h)))
1055    let h_complex: Vec<Complex64> = h.iter()
1056        .map(|&v| Complex64::new(v, 0.0))
1057        .collect();
1058    let mut h_fft_complex = h_complex;
1059    workspace.fft3d(&mut h_fft_complex);
1060    let h_fft: Vec<f64> = h_fft_complex.iter().map(|c| c.re).collect();
1061
1062    // Equation (9): Apply weighted combination
1063    // x = FFT(mask .* IFFT(wfs .* x + (1-wfs) .* (h .* x)))
1064    let mut x_filtered: Vec<Complex64> = x.iter()
1065        .zip(wfs.iter())
1066        .zip(h_fft.iter())
1067        .map(|((xi, &wi), &hi)| {
1068            xi * wi + xi * hi * (1.0 - wi)
1069        })
1070        .collect();
1071
1072    workspace.ifft3d(&mut x_filtered);
1073
1074    // Apply mask
1075    for (xi, &mi) in x_filtered.iter_mut().zip(mask.iter()) {
1076        if mi == 0 {
1077            *xi = Complex64::new(0.0, 0.0);
1078        } else {
1079            *xi = Complex64::new(xi.re, 0.0);
1080        }
1081    }
1082
1083    workspace.fft3d(&mut x_filtered);
1084
1085    // Equation (11): Apply again
1086    let mut x_filtered2: Vec<Complex64> = x_filtered.iter()
1087        .zip(wfs.iter())
1088        .zip(h_fft.iter())
1089        .map(|((xi, &wi), &hi)| {
1090            xi * wi + xi * hi * (1.0 - wi)
1091        })
1092        .collect();
1093
1094    workspace.ifft3d(&mut x_filtered2);
1095
1096    let x_fs: Vec<f64> = x_filtered2.iter().zip(mask.iter())
1097        .map(|(xi, &mi)| if mi > 0 { xi.re } else { 0.0 })
1098        .collect();
1099
1100    // Equation (12): TKD for comparison
1101    let t0 = 1.0 / 8.0;
1102    let mut inv_d = vec![0.0; n_total];
1103    for i in 0..n_total {
1104        if d[i].abs() < t0 {
1105            inv_d[i] = d[i].signum() / t0;
1106        } else {
1107            inv_d[i] = 1.0 / d[i];
1108        }
1109    }
1110
1111    let x_tkd_fft: Vec<Complex64> = f_fft.iter().zip(inv_d.iter())
1112        .map(|(fi, &idi)| fi * idi)
1113        .collect();
1114
1115    let mut x_tkd_complex = x_tkd_fft;
1116    workspace.ifft3d(&mut x_tkd_complex);
1117
1118    let x_tkd: Vec<f64> = x_tkd_complex.iter().zip(mask.iter())
1119        .map(|(xi, &mi)| if mi > 0 { xi.re } else { 0.0 })
1120        .collect();
1121
1122    // Equations (13-14): Linear regression to scale FastQSM
1123    // Solve: xtkd ≈ a * xfs + b
1124    // MATLAB reference uses ALL voxels (including zeros outside mask) for the regression
1125    let sum_xfs: f64 = x_fs.iter().map(|&v| v).sum();
1126    let sum_xtkd: f64 = x_tkd.iter().map(|&v| v).sum();
1127    let sum_xfs2: f64 = x_fs.iter().map(|&v| v * v).sum();
1128    let sum_xfs_xtkd: f64 = x_fs.iter().zip(x_tkd.iter())
1129        .map(|(&xf, &xt)| xf * xt)
1130        .sum();
1131
1132    let n_all: f64 = n_total as f64;
1133
1134    // Solve 2x2 system: [sum_xfs2, sum_xfs; sum_xfs, n] * [a; b] = [sum_xfs_xtkd; sum_xtkd]
1135    let det = sum_xfs2 * n_all - sum_xfs * sum_xfs;
1136
1137    let (a, b) = if det.abs() > 1e-20 {
1138        let a = (n_all * sum_xfs_xtkd - sum_xfs * sum_xtkd) / det;
1139        let b = (sum_xfs2 * sum_xtkd - sum_xfs * sum_xfs_xtkd) / det;
1140        (a, b)
1141    } else {
1142        (1.0, 0.0)
1143    };
1144
1145    // Equation (14): x = a * xfs + b
1146    x_fs.iter().zip(mask.iter())
1147        .map(|(&xf, &mi)| if mi > 0 { a * xf + b } else { 0.0 })
1148        .collect()
1149}
1150
1151// ============================================================================
1152// Step 3: Streaking Artifact Estimation
1153// ============================================================================
1154
1155/// Step 3: Estimate streaking artifacts using LSMR
1156fn susceptibility_artifacts_step(
1157    x0: &[f64],
1158    xfs: &[f64],
1159    mask: &[u8],
1160    d: &[f64],
1161    nx: usize, ny: usize, nz: usize,
1162    vsx: f64, vsy: f64, vsz: f64,
1163    tol: f64,
1164    maxit: usize,
1165    _workspace: &mut Fft3dWorkspace,
1166) -> Vec<f64> {
1167    let n_total = nx * ny * nz;
1168
1169    // Gradient weights (Equation 15)
1170    let pmin = 50.0;
1171    let pmax = 70.0;
1172    let (wx, wy, wz) = gradient_weights_ilsqr(xfs, mask, nx, ny, nz, vsx, vsy, vsz, pmin, pmax);
1173
1174    // Ill-conditioned mask (Equation 4)
1175    let thr = 0.1;
1176    let mic: Vec<f64> = d.iter().map(|&di| if di.abs() < thr { 1.0 } else { 0.0 }).collect();
1177
1178    // Compute gradient of x0 (Equation 3)
1179    let (dx, dy, dz) = fgrad(x0, nx, ny, nz, vsx, vsy, vsz);
1180
1181    // b = [wx .* dx; wy .* dy; wz .* dz] (concatenated)
1182    let bx = multiply_elementwise(&wx, &dx);
1183    let by = multiply_elementwise(&wy, &dy);
1184    let bz = multiply_elementwise(&wz, &dz);
1185
1186    let mut b = Vec::with_capacity(3 * n_total);
1187    b.extend_from_slice(&bx);
1188    b.extend_from_slice(&by);
1189    b.extend_from_slice(&bz);
1190
1191    // Define forward operator A and adjoint A^T
1192    // Reuse a single workspace across all LSMR iterations to avoid repeated allocation
1193    let lsmr_ws = RefCell::new(Fft3dWorkspace::new(nx, ny, nz));
1194    let apply_a = |x_in: &[f64]| -> Vec<f64> {
1195        // x_in is in image space
1196        // Apply Mic in k-space
1197        let x_complex: Vec<Complex64> = x_in.iter()
1198            .map(|&v| Complex64::new(v, 0.0))
1199            .collect();
1200
1201        let mut x_fft = x_complex;
1202        let mut ws = lsmr_ws.borrow_mut();
1203        ws.fft3d(&mut x_fft);
1204
1205        // Apply ill-conditioned mask
1206        let x_mic: Vec<Complex64> = x_fft.iter().zip(mic.iter())
1207            .map(|(xi, &mi)| xi * mi)
1208            .collect();
1209
1210        let mut x_ifft = x_mic;
1211        ws.ifft3d(&mut x_ifft);
1212
1213        let x_filtered: Vec<f64> = x_ifft.iter().map(|xi| xi.re).collect();
1214
1215        // Compute gradient
1216        let (gx, gy, gz) = fgrad(&x_filtered, nx, ny, nz, vsx, vsy, vsz);
1217
1218        // Apply weights and concatenate
1219        let mut result = Vec::with_capacity(3 * n_total);
1220        result.extend(wx.iter().zip(gx.iter()).map(|(&w, &g)| w * g));
1221        result.extend(wy.iter().zip(gy.iter()).map(|(&w, &g)| w * g));
1222        result.extend(wz.iter().zip(gz.iter()).map(|(&w, &g)| w * g));
1223
1224        result
1225    };
1226
1227    // Define adjoint operator A^T
1228    let apply_at = |y_in: &[f64]| -> Vec<f64> {
1229        // y_in is [yx; yy; yz] concatenated (3 * n_total)
1230        let yx = &y_in[0..n_total];
1231        let yy = &y_in[n_total..2*n_total];
1232        let yz = &y_in[2*n_total..3*n_total];
1233
1234        // Apply weights
1235        let wyx: Vec<f64> = wx.iter().zip(yx.iter()).map(|(&w, &y)| w * y).collect();
1236        let wyy: Vec<f64> = wy.iter().zip(yy.iter()).map(|(&w, &y)| w * y).collect();
1237        let wyz: Vec<f64> = wz.iter().zip(yz.iter()).map(|(&w, &y)| w * y).collect();
1238
1239        // Adjoint of forward gradient = -div (bdiv returns +div, so negate)
1240        // MATLAB's gradfp_adj_mex uses h = -1/voxel_size, including the negation.
1241        // Rust's bdiv uses h = +1/voxel_size, so we negate here.
1242        let div = bdiv(&wyx, &wyy, &wyz, nx, ny, nz, vsx, vsy, vsz);
1243
1244        // Apply Mic in k-space
1245        let div_complex: Vec<Complex64> = div.iter()
1246            .map(|&v| Complex64::new(-v, 0.0))
1247            .collect();
1248
1249        let mut div_fft = div_complex;
1250        let mut ws = lsmr_ws.borrow_mut();
1251        ws.fft3d(&mut div_fft);
1252
1253        let div_mic: Vec<Complex64> = div_fft.iter().zip(mic.iter())
1254            .map(|(di, &mi)| di * mi)
1255            .collect();
1256
1257        let mut div_ifft = div_mic;
1258        ws.ifft3d(&mut div_ifft);
1259
1260        div_ifft.iter().map(|di| di.re).collect()
1261    };
1262
1263    // Solve with LSMR
1264    let xsa = lsmr(apply_a, apply_at, &b, n_total, tol, tol, maxit, false);
1265
1266    // Apply mask
1267    xsa.iter().zip(mask.iter())
1268        .map(|(&x, &m)| if m > 0 { x } else { 0.0 })
1269        .collect()
1270}
1271
1272// ============================================================================
1273// Main iLSQR Algorithm
1274// ============================================================================
1275
1276/// iLSQR: A method for estimating and removing streaking artifacts in QSM
1277///
1278/// # Arguments
1279/// * `field` - Unwrapped local field/tissue phase (nx * ny * nz)
1280/// * `mask` - Binary mask of region of interest
1281/// * `nx`, `ny`, `nz` - Array dimensions
1282/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm
1283/// * `bdir` - B0 field direction (bx, by, bz)
1284/// * `tol` - Stopping tolerance for LSMR solver
1285/// * `maxit` - Maximum iterations for LSMR
1286///
1287/// # Returns
1288/// Tuple of (susceptibility, streaking_artifacts, fast_qsm, initial_lsqr)
1289pub fn ilsqr(
1290    field: &[f64],
1291    mask: &[u8],
1292    nx: usize, ny: usize, nz: usize,
1293    vsx: f64, vsy: f64, vsz: f64,
1294    bdir: (f64, f64, f64),
1295    tol: f64,
1296    maxit: usize,
1297) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
1298    // Generate dipole kernel
1299    let d = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
1300
1301    // Create FFT workspace
1302    let mut workspace = Fft3dWorkspace::new(nx, ny, nz);
1303
1304    // Step 1: Initial LSQR solution
1305    let xlsqr = lsqr_step(field, mask, &d, nx, ny, nz, vsx, vsy, vsz, &mut workspace);
1306
1307    // Step 2: FastQSM estimate
1308    let xfs = fastqsm_step(field, mask, &d, nx, ny, nz, vsx, vsy, vsz, &mut workspace);
1309
1310    // Step 3: Estimate streaking artifacts
1311    let xsa = susceptibility_artifacts_step(
1312        &xlsqr, &xfs, mask, &d,
1313        nx, ny, nz, vsx, vsy, vsz,
1314        tol, maxit, &mut workspace
1315    );
1316
1317    // Step 4: Subtract artifacts
1318    let chi: Vec<f64> = xlsqr.iter().zip(xsa.iter()).zip(mask.iter())
1319        .map(|((&xl, &xs), &m)| if m > 0 { xl - xs } else { 0.0 })
1320        .collect();
1321
1322    (chi, xsa, xfs, xlsqr)
1323}
1324
1325/// Simplified iLSQR returning only the final susceptibility map
1326pub fn ilsqr_simple(
1327    field: &[f64],
1328    mask: &[u8],
1329    nx: usize, ny: usize, nz: usize,
1330    vsx: f64, vsy: f64, vsz: f64,
1331    bdir: (f64, f64, f64),
1332    tol: f64,
1333    maxit: usize,
1334) -> Vec<f64> {
1335    let (chi, _, _, _) = ilsqr(field, mask, nx, ny, nz, vsx, vsy, vsz, bdir, tol, maxit);
1336    chi
1337}
1338
1339/// iLSQR with progress callback
1340pub fn ilsqr_with_progress<F>(
1341    field: &[f64],
1342    mask: &[u8],
1343    nx: usize, ny: usize, nz: usize,
1344    vsx: f64, vsy: f64, vsz: f64,
1345    bdir: (f64, f64, f64),
1346    tol: f64,
1347    maxit: usize,
1348    mut progress_callback: F,
1349) -> Vec<f64>
1350where
1351    F: FnMut(usize, usize),
1352{
1353    // Generate dipole kernel
1354    let d = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
1355
1356    // Create FFT workspace
1357    let mut workspace = Fft3dWorkspace::new(nx, ny, nz);
1358
1359    progress_callback(1, 4);
1360
1361    // Step 1: Initial LSQR solution
1362    let xlsqr = lsqr_step(field, mask, &d, nx, ny, nz, vsx, vsy, vsz, &mut workspace);
1363
1364    progress_callback(2, 4);
1365
1366    // Step 2: FastQSM estimate
1367    let xfs = fastqsm_step(field, mask, &d, nx, ny, nz, vsx, vsy, vsz, &mut workspace);
1368
1369    progress_callback(3, 4);
1370
1371    // Step 3: Estimate streaking artifacts
1372    let xsa = susceptibility_artifacts_step(
1373        &xlsqr, &xfs, mask, &d,
1374        nx, ny, nz, vsx, vsy, vsz,
1375        tol, maxit, &mut workspace
1376    );
1377
1378    progress_callback(4, 4);
1379
1380    // Step 4: Subtract artifacts
1381    xlsqr.iter().zip(xsa.iter()).zip(mask.iter())
1382        .map(|((&xl, &xs), &m)| if m > 0 { xl - xs } else { 0.0 })
1383        .collect()
1384}
1385
1386#[cfg(test)]
1387mod tests {
1388    use super::*;
1389
1390    #[test]
1391    fn test_lsqr_simple() {
1392        // Test LSQR on a simple diagonal system
1393        let n = 10;
1394        let diag: Vec<f64> = (1..=n).map(|i| i as f64).collect();
1395        let b: Vec<f64> = diag.iter().map(|&d| d * 2.0).collect();  // x = [2, 2, 2, ...]
1396
1397        let apply_a = |x: &[f64]| -> Vec<f64> {
1398            x.iter().zip(diag.iter()).map(|(&xi, &di)| xi * di).collect()
1399        };
1400
1401        let x = lsqr(apply_a, apply_a, &b, 1e-10, 100);
1402
1403        for (i, &xi) in x.iter().enumerate() {
1404            assert!((xi - 2.0).abs() < 1e-6, "x[{}] = {}, expected 2.0", i, xi);
1405        }
1406    }
1407
1408    #[test]
1409    fn test_norm() {
1410        let x = vec![3.0, 4.0];
1411        assert!((norm(&x) - 5.0).abs() < 1e-10);
1412    }
1413
1414    #[test]
1415    fn test_sign_array() {
1416        let x = vec![-2.0, 0.0, 3.0];
1417        let s = sign_array(&x);
1418        assert_eq!(s, vec![-1.0, 0.0, 1.0]);
1419    }
1420
1421    #[test]
1422    fn test_lsqr_complex_diagonal() {
1423        // Test complex LSQR on a diagonal system: A = diag(1, 2, 3), b = [1+i, 4+2i, 9+3i]
1424        // Expected solution: x = [1+i, 2+i, 3+i]
1425        let diag = vec![1.0, 2.0, 3.0];
1426        let expected = vec![
1427            Complex64::new(1.0, 1.0),
1428            Complex64::new(2.0, 1.0),
1429            Complex64::new(3.0, 1.0),
1430        ];
1431        let b: Vec<Complex64> = expected.iter().zip(diag.iter())
1432            .map(|(&xi, &di)| xi * di)
1433            .collect();
1434
1435        let diag_a = diag.clone();
1436        let diag_ah = diag.clone();
1437        let apply_a = move |x: &[Complex64]| -> Vec<Complex64> {
1438            x.iter().zip(diag_a.iter()).map(|(&xi, &di)| xi * di).collect()
1439        };
1440        let apply_ah = move |x: &[Complex64]| -> Vec<Complex64> {
1441            x.iter().zip(diag_ah.iter()).map(|(&xi, &di)| xi * di).collect()
1442        };
1443
1444        let x = lsqr_complex(apply_a, apply_ah, &b, 1e-10, 100, false);
1445
1446        for (i, (xi, ei)) in x.iter().zip(expected.iter()).enumerate() {
1447            assert!((xi.re - ei.re).abs() < 1e-6,
1448                "x[{}].re = {}, expected {}", i, xi.re, ei.re);
1449            assert!((xi.im - ei.im).abs() < 1e-6,
1450                "x[{}].im = {}, expected {}", i, xi.im, ei.im);
1451        }
1452    }
1453
1454    #[test]
1455    fn test_lsmr_diagonal() {
1456        // Test the LSMR solver (inside ilsqr.rs) exercises all code paths.
1457        // Use a well-conditioned diagonal system: A = diag(1, 1, 1) (identity)
1458        // b = [3, 5, 7], expected x = [3, 5, 7]
1459        let b = vec![3.0, 5.0, 7.0];
1460
1461        let apply_a = |x: &[f64]| -> Vec<f64> { x.to_vec() };
1462        let apply_at = |x: &[f64]| -> Vec<f64> { x.to_vec() };
1463
1464        let x = lsmr(apply_a, apply_at, &b, 3, 1e-6, 1e-6, 200, false);
1465
1466        // Verify that the solver returns finite values and the output has correct length
1467        assert_eq!(x.len(), 3);
1468        for (i, &xi) in x.iter().enumerate() {
1469            assert!(xi.is_finite(), "x[{}] = {} is not finite", i, xi);
1470        }
1471
1472        // Compute residual: ||Ax - b|| should be reduced from ||b||
1473        let residual: f64 = x.iter().zip(b.iter())
1474            .map(|(&xi, &bi)| (xi - bi).powi(2))
1475            .sum::<f64>()
1476            .sqrt();
1477        let bnorm: f64 = b.iter().map(|&bi| bi * bi).sum::<f64>().sqrt();
1478        assert!(residual < bnorm,
1479            "residual {} should be less than ||b|| = {}", residual, bnorm);
1480    }
1481
1482    #[test]
1483    fn test_laplacian_weights() {
1484        // Test laplacian_weights_ilsqr on a small 4x4x4 volume with a uniform field
1485        // inside a mask. A constant field has zero Laplacian, so weights should be 1.0.
1486        let (nx, ny, nz) = (4, 4, 4);
1487        let n_total = nx * ny * nz;
1488        let mut mask = vec![0u8; n_total];
1489        let mut field = vec![0.0; n_total];
1490
1491        // Create a sphere mask and constant field inside
1492        for k in 0..nz {
1493            for j in 0..ny {
1494                for i in 0..nx {
1495                    let idx = i + j * nx + k * nx * ny;
1496                    let ci = i as f64 - 1.5;
1497                    let cj = j as f64 - 1.5;
1498                    let ck = k as f64 - 1.5;
1499                    let r2 = ci * ci + cj * cj + ck * ck;
1500                    if r2 < 2.5 {
1501                        mask[idx] = 1;
1502                        field[idx] = 5.0; // constant field => Laplacian is 0
1503                    }
1504                }
1505            }
1506        }
1507
1508        let w = laplacian_weights_ilsqr(&field, &mask, nx, ny, nz, 1.0, 1.0, 1.0, 10.0, 90.0);
1509
1510        // All weights should be finite and in [0, 1]
1511        for (i, &wi) in w.iter().enumerate() {
1512            assert!(wi.is_finite(), "weight[{}] is not finite", i);
1513            assert!(wi >= 0.0 && wi <= 1.0, "weight[{}] = {} out of [0,1]", i, wi);
1514        }
1515
1516        // Masked-out voxels should have weight 0
1517        for i in 0..n_total {
1518            if mask[i] == 0 {
1519                assert_eq!(w[i], 0.0, "weight outside mask should be 0 at index {}", i);
1520            }
1521        }
1522    }
1523
1524    #[test]
1525    fn test_dipole_kspace_weights() {
1526        // Test dipole_kspace_weights_ilsqr with synthetic dipole values
1527        let d = vec![0.0, 0.01, 0.1, 0.3, 0.5, 0.7, 1.0, -0.5, -1.0, 0.0];
1528
1529        let w = dipole_kspace_weights_ilsqr(&d, 1.0, 1.0, 90.0);
1530
1531        // All weights should be in [0, 1]
1532        for (i, &wi) in w.iter().enumerate() {
1533            assert!(wi >= 0.0 && wi <= 1.0,
1534                "weight[{}] = {} out of [0,1]", i, wi);
1535        }
1536
1537        // Zero dipole values should produce weight 0 (or very small) since |0|^n = 0
1538        assert!(w[0] <= 1e-10, "weight at D=0 should be ~0, got {}", w[0]);
1539
1540        // The largest |D| values should have weight near 1.0
1541        // d[6]=1.0 and d[8]=-1.0 have the largest |D|
1542        assert!(w[6] > 0.5, "weight at |D|=1.0 should be large, got {}", w[6]);
1543        assert!(w[8] > 0.5, "weight at |D|=1.0 should be large, got {}", w[8]);
1544    }
1545
1546    #[test]
1547    fn test_gradient_weights() {
1548        // Test gradient_weights_ilsqr on a small 4x4x4 volume
1549        let (nx, ny, nz) = (4, 4, 4);
1550        let n_total = nx * ny * nz;
1551
1552        // Create a mask (all ones for simplicity)
1553        let mask = vec![1u8; n_total];
1554
1555        // Create a field with a linear gradient in x
1556        let mut field = vec![0.0; n_total];
1557        for k in 0..nz {
1558            for j in 0..ny {
1559                for i in 0..nx {
1560                    let idx = i + j * nx + k * nx * ny;
1561                    field[idx] = i as f64; // linear in x
1562                }
1563            }
1564        }
1565
1566        let (wx, wy, wz) = gradient_weights_ilsqr(
1567            &field, &mask, nx, ny, nz, 1.0, 1.0, 1.0, 10.0, 90.0
1568        );
1569
1570        // All weights should be finite and in [0, 1]
1571        for i in 0..n_total {
1572            assert!(wx[i].is_finite(), "wx[{}] is not finite", i);
1573            assert!(wy[i].is_finite(), "wy[{}] is not finite", i);
1574            assert!(wz[i].is_finite(), "wz[{}] is not finite", i);
1575            assert!(wx[i] >= 0.0 && wx[i] <= 1.0, "wx[{}] = {} out of [0,1]", i, wx[i]);
1576            assert!(wy[i] >= 0.0 && wy[i] <= 1.0, "wy[{}] = {} out of [0,1]", i, wy[i]);
1577            assert!(wz[i] >= 0.0 && wz[i] <= 1.0, "wz[{}] = {} out of [0,1]", i, wz[i]);
1578        }
1579
1580        // The y and z gradients are zero for this field, so wy and wz should reflect
1581        // that all gradient values are identical (zero). Check they are well-defined.
1582        let wy_sum: f64 = wy.iter().sum();
1583        let wz_sum: f64 = wz.iter().sum();
1584        assert!(wy_sum.is_finite(), "wy sum is not finite");
1585        assert!(wz_sum.is_finite(), "wz sum is not finite");
1586    }
1587
1588    #[test]
1589    fn test_ilsqr_small() {
1590        // Run ilsqr_simple on a small 8x8x8 volume with a sphere mask
1591        // and synthetic local field data. This exercises the full pipeline:
1592        // lsqr_step, fastqsm_step, susceptibility_artifacts_step.
1593        let (nx, ny, nz) = (8, 8, 8);
1594        let n_total = nx * ny * nz;
1595        let vsx = 1.0;
1596        let vsy = 1.0;
1597        let vsz = 1.0;
1598        let bdir = (0.0, 0.0, 1.0);
1599
1600        // Create a sphere mask centered in the volume
1601        let mut mask = vec![0u8; n_total];
1602        let cx = (nx as f64 - 1.0) / 2.0;
1603        let cy = (ny as f64 - 1.0) / 2.0;
1604        let cz = (nz as f64 - 1.0) / 2.0;
1605        let radius = 3.0;
1606
1607        for k in 0..nz {
1608            for j in 0..ny {
1609                for i in 0..nx {
1610                    let idx = i + j * nx + k * nx * ny;
1611                    let di = i as f64 - cx;
1612                    let dj = j as f64 - cy;
1613                    let dk = k as f64 - cz;
1614                    if di * di + dj * dj + dk * dk < radius * radius {
1615                        mask[idx] = 1;
1616                    }
1617                }
1618            }
1619        }
1620
1621        // Create synthetic local field: a simple dipole-like pattern
1622        // Use a small susceptibility source and forward-model through the dipole kernel
1623        let mut field = vec![0.0; n_total];
1624        for k in 0..nz {
1625            for j in 0..ny {
1626                for i in 0..nx {
1627                    let idx = i + j * nx + k * nx * ny;
1628                    if mask[idx] > 0 {
1629                        let di = i as f64 - cx;
1630                        let dj = j as f64 - cy;
1631                        let dk = k as f64 - cz;
1632                        // Simulate a simple field variation
1633                        field[idx] = 0.01 * (dk * dk - di * di - dj * dj)
1634                            / (di * di + dj * dj + dk * dk + 1.0);
1635                    }
1636                }
1637            }
1638        }
1639
1640        let tol = 0.1;
1641        let maxit = 5; // Few iterations for speed
1642
1643        let chi = ilsqr_simple(&field, &mask, nx, ny, nz, vsx, vsy, vsz, bdir, tol, maxit);
1644
1645        // Check output dimensions
1646        assert_eq!(chi.len(), n_total, "output size mismatch");
1647
1648        // Check all values are finite
1649        for (i, &v) in chi.iter().enumerate() {
1650            assert!(v.is_finite(), "chi[{}] = {} is not finite", i, v);
1651        }
1652
1653        // Check mask is respected: outside mask should be zero
1654        for i in 0..n_total {
1655            if mask[i] == 0 {
1656                assert_eq!(chi[i], 0.0, "chi outside mask should be 0 at index {}", i);
1657            }
1658        }
1659
1660        // Check that the result is not all zeros inside the mask
1661        let inside_sum: f64 = chi.iter().zip(mask.iter())
1662            .filter(|(_, &m)| m > 0)
1663            .map(|(&v, _)| v.abs())
1664            .sum();
1665        assert!(inside_sum > 0.0, "chi should not be all zeros inside the mask");
1666    }
1667}