// qsm_core/inversion/nltv.rs
1//! Nonlinear Total Variation (NLTV) regularized dipole inversion
2//!
3//! NLTV extends standard TV by using iteratively reweighted minimization,
4//! which produces sharper edges and better preserves fine details.
5//!
6//! The method solves:
7//! min_x ||Dx - f||₂² + λ Σ w_i |∇x|_i
8//!
9//! where weights w_i are iteratively updated based on the current solution.
10//!
11//! Reference:
12//! Kames, C., Wiggermann, V., Rauscher, A. (2018).
13//! "Rapid two-step dipole inversion for susceptibility mapping with sparsity priors."
14//! NeuroImage, 167:276-283. https://doi.org/10.1016/j.neuroimage.2017.11.018
15//!
16//! Reference implementation: https://github.com/kamesy/QSM.jl
17
18use num_complex::Complex64;
19use crate::fft::Fft3dWorkspace;
20use crate::kernels::dipole::dipole_kernel;
21use crate::kernels::laplacian::laplacian_kernel;
22use crate::utils::gradient::{bdiv_inplace, fgrad_inplace};
23
/// Soft-thresholding (shrinkage) operator with a per-voxel weight.
///
/// Shrinks `x` toward zero by an effective threshold `threshold * weight`;
/// inputs whose magnitude does not exceed that threshold map to exactly 0.
#[inline]
fn weighted_shrink(x: f64, threshold: f64, weight: f64) -> f64 {
    let t = threshold * weight;
    // Closed form of the three-branch shrinkage: sign(x) * max(|x| - t, 0).
    x.signum() * (x.abs() - t).max(0.0)
}
36
/// Tunable parameters for the NLTV dipole inversion.
///
/// Field order and derives are part of the public surface; defaults are
/// provided by the `Default` impl below.
#[derive(Clone, Debug)]
pub struct NltvParams {
    /// Regularization strength λ for the weighted TV term
    pub lambda: f64,
    /// Reweighting (nonlinearity) penalty parameter μ
    pub mu: f64,
    /// Relative-change tolerance for the inner ADMM loop
    pub tol: f64,
    /// Upper bound on inner ADMM iterations per reweighting pass
    pub max_iter: usize,
    /// Number of outer Newton-like weight updates
    pub newton_iter: usize,
}
51
52impl Default for NltvParams {
53    fn default() -> Self {
54        Self {
55            lambda: 1e-3,
56            mu: 1.0,
57            tol: 1e-3,
58            max_iter: 250,
59            newton_iter: 10,
60        }
61    }
62}
63
64/// NLTV dipole inversion using iteratively reweighted ADMM
65///
66/// # Arguments
67/// * `local_field` - Local field values (nx * ny * nz)
68/// * `mask` - Binary mask (nx * ny * nz), 1 = inside ROI
69/// * `nx`, `ny`, `nz` - Array dimensions
70/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm
71/// * `bdir` - B0 field direction
72/// * `lambda` - Regularization parameter (typically 1e-3)
73/// * `mu` - Reweighting parameter for nonlinearity (typically 1.0)
74/// * `tol` - Convergence tolerance
75/// * `max_iter` - Maximum ADMM iterations
76/// * `newton_iter` - Reweighting updates (inner Newton-like iterations)
77///
78/// # Returns
79/// Susceptibility map
80pub fn nltv(
81    local_field: &[f64],
82    mask: &[u8],
83    nx: usize, ny: usize, nz: usize,
84    vsx: f64, vsy: f64, vsz: f64,
85    bdir: (f64, f64, f64),
86    lambda: f64,
87    mu: f64,
88    tol: f64,
89    max_iter: usize,
90    newton_iter: usize,
91) -> Vec<f64> {
92    nltv_with_progress(
93        local_field, mask, nx, ny, nz, vsx, vsy, vsz,
94        bdir, lambda, mu, tol, max_iter, newton_iter,
95        |_, _| {} // no-op progress callback
96    )
97}
98
/// NLTV with progress callback
///
/// Same solver as [`nltv`]; additionally invokes `progress_callback(current, total)`
/// once per inner ADMM iteration. `total = max_iter * newton_iter` is an upper
/// bound — inner loops may converge and break early, so `current` need not
/// reach `total`.
///
/// Solves min_x ||Dx - f||₂² + λ Σ w_i |∇x|_i via ADMM, refreshing the
/// weights w_i after each inner ADMM solve (iterative reweighting).
pub fn nltv_with_progress<F>(
    local_field: &[f64],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
    bdir: (f64, f64, f64),
    lambda: f64,
    mu: f64,
    tol: f64,
    max_iter: usize,
    newton_iter: usize,
    mut progress_callback: F,
) -> Vec<f64>
where
    F: FnMut(usize, usize),
{
    let n_total = nx * ny * nz;
    let eps = 1e-6; // Small constant to avoid division by zero in the weight update

    // ========================================================================
    // Pre-compute kernels (done once)
    // ========================================================================

    // Create FFT workspace (caches plans and scratch buffers for reuse)
    let mut fft_ws = Fft3dWorkspace::new(nx, ny, nz);

    let d_kernel = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
    let l_kernel = laplacian_kernel(nx, ny, nz, vsx, vsy, vsz, true);

    // FFT of Laplacian kernel
    let mut l_complex: Vec<Complex64> = l_kernel.iter()
        .map(|&x| Complex64::new(x, 0.0))
        .collect();
    fft_ws.fft3d(&mut l_complex);

    // Compute rho adaptively (for ADMM)
    // NOTE(review): "adaptively" here is a fixed 100x coupling to lambda —
    // presumably an empirical choice; confirm against the reference code.
    let rho = 100.0 * lambda;

    // Pre-compute inverse of (D^H D + ρ L)
    // Frequencies where the operator is (near-)singular get inverse 0, i.e.
    // they are zeroed rather than amplified.
    let mut inv_a: Vec<f64> = vec![0.0; n_total];
    for i in 0..n_total {
        let a = d_kernel[i] * d_kernel[i] + rho * l_complex[i].re;
        inv_a[i] = if a.abs() > 1e-20 { 1.0 / a } else { 0.0 };
    }

    // Pre-compute D^H * FFT(f)
    // l_complex is no longer needed past this point, so its buffer is reused.
    let f_hat = &mut l_complex;
    for i in 0..n_total {
        f_hat[i] = Complex64::new(local_field[i], 0.0);
    }
    fft_ws.fft3d(f_hat);
    for i in 0..n_total {
        // The dipole kernel is stored as real values, so D^H reduces to D here.
        f_hat[i] = f_hat[i] * d_kernel[i] * inv_a[i];
    }

    // ========================================================================
    // Pre-allocate working buffers
    // ========================================================================

    let mut x = vec![0.0; n_total];
    let mut x_prev = vec![0.0; n_total];

    // Dual variables (persist across Newton passes — warm start)
    let mut ux = vec![0.0; n_total];
    let mut uy = vec![0.0; n_total];
    let mut uz = vec![0.0; n_total];

    // Gradient buffers; also double as storage for (2z - v) between iterations
    let mut gx = vec![0.0; n_total];
    let mut gy = vec![0.0; n_total];
    let mut gz = vec![0.0; n_total];

    // Divergence buffer
    let mut div_d = vec![0.0; n_total];

    // Complex FFT buffer
    let mut work_complex = vec![Complex64::new(0.0, 0.0); n_total];

    // Adaptive weights for nonlinear term (all 1.0 => plain TV on first pass)
    let mut weights = vec![1.0; n_total];

    let total_iter = max_iter * newton_iter;
    let mut current_iter = 0;

    // ========================================================================
    // Outer loop: Newton-like reweighting
    // ========================================================================
    for _newton in 0..newton_iter {
        let lambda_over_rho = lambda / rho;

        // ====================================================================
        // Inner loop: ADMM with current weights
        // ====================================================================
        for _iter in 0..max_iter {
            current_iter += 1;
            progress_callback(current_iter, total_iter);

            // Swap x and x_prev: x_prev now holds the previous iterate and
            // x becomes scratch that is fully overwritten below.
            std::mem::swap(&mut x, &mut x_prev);

            // ================================================================
            // x-subproblem
            // ================================================================
            // On the very first iteration gx/gy/gz are all zero, so div_d is
            // zero and the update reduces to the unregularized solution f_hat.
            bdiv_inplace(&mut div_d, &gx, &gy, &gz, nx, ny, nz, vsx, vsy, vsz);

            for i in 0..n_total {
                work_complex[i] = Complex64::new(div_d[i], 0.0);
            }
            fft_ws.fft3d(&mut work_complex);

            // x_hat = f_hat - rho * FFT(div) * inv_a
            // Note: bdiv computes positive divergence ∇·, but the adjoint ∇ᵀ = -∇·,
            // so we subtract (see Eq. [7] in Kames et al. 2018).
            for i in 0..n_total {
                work_complex[i] = f_hat[i] - rho * work_complex[i] * inv_a[i];
            }

            fft_ws.ifft3d(&mut work_complex);
            for i in 0..n_total {
                x[i] = work_complex[i].re;
            }

            // ================================================================
            // Convergence check
            // ================================================================
            let mut norm_diff_sq = 0.0;
            let mut norm_x_sq = 0.0;
            for i in 0..n_total {
                let diff = x[i] - x_prev[i];
                norm_diff_sq += diff * diff;
                norm_x_sq += x[i] * x[i];
            }

            // Relative change ||x - x_prev|| / ||x||, guarded against ||x|| = 0
            let rel_change = norm_diff_sq.sqrt() / (norm_x_sq.sqrt() + 1e-20);
            if rel_change < tol {
                break;
            }

            // ================================================================
            // z-subproblem + u-update with adaptive weights
            // ================================================================
            fgrad_inplace(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, vsx, vsy, vsz);

            for i in 0..n_total {
                let grad_x = gx[i];
                let grad_y = gy[i];
                let grad_z = gz[i];

                // v = ∇x + u: the argument of the proximal operator
                let vx = grad_x + ux[i];
                let vy = grad_y + uy[i];
                let vz = grad_z + uz[i];

                // Weighted soft thresholding (z-update, applied per component —
                // anisotropic TV rather than a joint-magnitude shrink)
                let zx_i = weighted_shrink(vx, lambda_over_rho, weights[i]);
                let zy_i = weighted_shrink(vy, lambda_over_rho, weights[i]);
                let zz_i = weighted_shrink(vz, lambda_over_rho, weights[i]);

                // u update: u_new = v - z (i.e. u += ∇x - z)
                ux[i] = vx - zx_i;
                uy[i] = vy - zy_i;
                uz[i] = vz - zz_i;

                // Store (z - u_new) = 2z - v for next iteration's div
                gx[i] = 2.0 * zx_i - vx;
                gy[i] = 2.0 * zy_i - vy;
                gz[i] = 2.0 * zz_i - vz;
            }
        }

        // ====================================================================
        // Update weights based on current gradient magnitude (Newton update)
        // ====================================================================
        // NOTE(review): this overwrites the (2z - v) buffers with ∇x, so the
        // first x-update of the next ADMM pass sees ∇x instead of the last
        // ADMM split state — presumably an intentional warm restart; confirm.
        fgrad_inplace(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, vsx, vsy, vsz);

        for i in 0..n_total {
            // Gradient magnitude
            let grad_mag = (gx[i] * gx[i] + gy[i] * gy[i] + gz[i] * gz[i]).sqrt();

            // Reweighting: w = 1 / (|∇x| + eps)^(1-q) where q is close to 1 for L1
            // Using mu to control nonlinearity: w = 1 / (|∇x| + mu*eps)
            weights[i] = 1.0 / (grad_mag + mu * eps);
        }

        // Normalize weights to prevent explosion (max weight capped at 1)
        let max_weight: f64 = weights.iter().cloned().fold(0.0, f64::max);
        if max_weight > 1.0 {
            for w in weights.iter_mut() {
                *w /= max_weight;
            }
        }
    }

    // Apply mask: zero the solution outside the ROI
    for i in 0..n_total {
        if mask[i] == 0 {
            x[i] = 0.0;
        }
    }

    x
}
301
302/// NLTV with default parameters (matches QSM.jl nltv.jl defaults)
303pub fn nltv_default(
304    local_field: &[f64],
305    mask: &[u8],
306    nx: usize, ny: usize, nz: usize,
307    vsx: f64, vsy: f64, vsz: f64,
308) -> Vec<f64> {
309    nltv(
310        local_field, mask, nx, ny, nz, vsx, vsy, vsz,
311        (0.0, 0.0, 1.0),  // bdir
312        1e-3,             // lambda (QSM.jl default)
313        1.0,              // mu
314        1e-3,             // tol
315        250,              // max_iter
316        10                // newton_iter
317    )
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_nltv_zero_field() {
        // A zero local field must reconstruct to an (essentially) zero map.
        let n = 8;
        let voxels = n * n * n;
        let field = vec![0.0; voxels];
        let mask = vec![1u8; voxels];

        let chi = nltv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-3, 1.0, 1e-2, 10, 2,
        );

        for &val in &chi {
            assert!(val.abs() < 1e-8, "Zero field should give zero chi");
        }
    }

    #[test]
    fn test_nltv_finite() {
        // A smooth ramp field must never produce NaN/inf in the output.
        let n = 8;
        let voxels = n * n * n;
        let field: Vec<f64> = (0..voxels).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; voxels];

        let chi = nltv(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-3, 1.0, 1e-2, 10, 2,
        );

        for (i, &val) in chi.iter().enumerate() {
            assert!(val.is_finite(), "Chi should be finite at index {}", i);
        }
    }

    #[test]
    fn test_weighted_shrink() {
        // With w = 1 the operator is ordinary soft thresholding.
        assert!((weighted_shrink(1.0, 0.5, 1.0) - 0.5).abs() < 1e-10);
        assert!((weighted_shrink(-1.0, 0.5, 1.0) - (-0.5)).abs() < 1e-10);
        assert!((weighted_shrink(0.3, 0.5, 1.0) - 0.0).abs() < 1e-10);

        // Halving the weight halves the effective threshold.
        assert!((weighted_shrink(1.0, 0.5, 0.5) - 0.75).abs() < 1e-10);
        assert!((weighted_shrink(0.3, 0.5, 0.5) - 0.05).abs() < 1e-10);
    }
}
367}