// qsm_core/inversion/tv.rs
1//! Total Variation (TV) regularized dipole inversion using ADMM
2//!
3//! Solves the L1-regularized inverse problem:
4//! min_x ||Dx - f||₂² + λ||∇x||₁
5//!
6//! using Alternating Direction Method of Multipliers (ADMM).
7//!
8//! Reference:
9//! Bilgic, B., Fan, A.P., Polimeni, J.R., et al. (2014).
10//! "Fast quantitative susceptibility mapping with L1-regularization and automatic
11//! parameter selection." Magnetic Resonance in Medicine, 72(5):1444-1459.
12//! https://doi.org/10.1002/mrm.25029
13//!
14//! Reference implementation: https://github.com/kamesy/QSM.jl
15
16use num_complex::Complex64;
17use crate::fft::Fft3dWorkspace;
18use crate::kernels::dipole::dipole_kernel;
19use crate::kernels::laplacian::laplacian_kernel;
20use crate::utils::gradient::{bdiv_inplace, fgrad_inplace};
21
22#[cfg(feature = "parallel")]
23use crate::par::*;
24
/// Soft thresholding (shrinkage) operator for L1 regularization.
///
/// Computes `shrink(x, t) = sign(x) * max(|x| - t, 0)`: inputs inside
/// `[-threshold, threshold]` collapse to zero, everything else is pulled
/// toward zero by `threshold`. This is the proximal operator of `t * |x|`.
#[inline]
fn shrink(x: f64, threshold: f64) -> f64 {
    match x {
        v if v > threshold => v - threshold,
        v if v < -threshold => v + threshold,
        _ => 0.0,
    }
}
37
/// Parameters controlling the TV-ADMM dipole inversion.
#[derive(Clone, Debug)]
pub struct TvParams {
    /// Regularization parameter (typically 1e-3 to 1e-4)
    pub lambda: f64,
    /// ADMM penalty parameter (typically 100*lambda)
    pub rho: f64,
    /// Convergence tolerance
    pub tol: f64,
    /// Maximum iterations
    pub max_iter: usize,
}

impl Default for TvParams {
    /// Default settings: λ = 2e-4 with ρ = 100·λ, stopping once the
    /// relative change drops below 1e-3 or after 250 iterations.
    fn default() -> Self {
        TvParams { lambda: 2e-4, rho: 2e-2, tol: 1e-3, max_iter: 250 }
    }
}
61
62/// TV-ADMM dipole inversion (optimized)
63///
64/// # Arguments
65/// * `local_field` - Local field values (nx * ny * nz)
66/// * `mask` - Binary mask (nx * ny * nz), 1 = inside ROI
67/// * `nx`, `ny`, `nz` - Array dimensions
68/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm
69/// * `bdir` - B0 field direction
70/// * `lambda` - Regularization parameter (typically 1e-3 to 1e-4)
71/// * `rho` - ADMM penalty parameter (typically 100*lambda)
72/// * `tol` - Convergence tolerance
73/// * `max_iter` - Maximum iterations
74///
75/// # Returns
76/// Susceptibility map
77pub fn tv_admm(
78    local_field: &[f64],
79    mask: &[u8],
80    nx: usize, ny: usize, nz: usize,
81    vsx: f64, vsy: f64, vsz: f64,
82    bdir: (f64, f64, f64),
83    lambda: f64,
84    rho: f64,
85    tol: f64,
86    max_iter: usize,
87) -> Vec<f64> {
88    tv_admm_with_progress(
89        local_field, mask, nx, ny, nz, vsx, vsy, vsz,
90        bdir, lambda, rho, tol, max_iter,
91        |_, _| {} // no-op progress callback
92    )
93}
94
/// TV-ADMM with progress callback (optimized)
///
/// Scaled-form ADMM splitting of min_x ||Dx - f||₂² + λ||∇x||₁:
///   x-step: (D^H D + ρ ∇ᵀ∇) x = D^H f + ρ ∇ᵀ(z - u)  — diagonal in k-space
///   z-step: z = shrink(∇x + u, λ/ρ)                     — pointwise shrinkage
///   u-step: u ← u + ∇x - z                              — dual ascent
///
/// Optimized implementation with:
/// - Pre-allocated buffers (zero allocations per iteration)
/// - In-place gradient/divergence operations
/// - Buffer swapping instead of cloning
/// - Fused z-subproblem and u-update
///
/// Same as `tv_admm` but calls `progress_callback(iteration, max_iter)` each iteration.
/// On early convergence the final call is `progress_callback(k, k)` so consumers
/// see a completed progress bar.
pub fn tv_admm_with_progress<F>(
    local_field: &[f64],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
    bdir: (f64, f64, f64),
    lambda: f64,
    rho: f64,
    tol: f64,
    max_iter: usize,
    mut progress_callback: F,
) -> Vec<f64>
where
    F: FnMut(usize, usize),
{
    let n_total = nx * ny * nz;

    // ========================================================================
    // Pre-compute kernels (done once)
    // ========================================================================

    // Create FFT workspace (caches plans and scratch buffers for reuse)
    let mut fft_ws = Fft3dWorkspace::new(nx, ny, nz);

    // Generate dipole kernel D
    // NOTE(review): treated below as a real-valued k-space kernel with the
    // same index layout as `fft_ws` output — confirm against `dipole_kernel`.
    let d_kernel = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);

    // Generate negative Laplacian kernel (for -Δ = ∇ᵀ∇)
    let l_kernel = laplacian_kernel(nx, ny, nz, vsx, vsy, vsz, true);

    // FFT of Laplacian kernel
    let mut l_complex: Vec<Complex64> = l_kernel.iter()
        .map(|&x| Complex64::new(x, 0.0))
        .collect();
    fft_ws.fft3d(&mut l_complex);

    // Pre-compute inverse of (D^H D + ρ L) for x-subproblem.
    // The system matrix is diagonal in k-space; only the real part of L is
    // used (imaginary part should vanish for a symmetric stencil).
    // Near-zero entries are truncated to 0 instead of inverted, which keeps
    // the solve stable where the operator is (numerically) singular.
    let mut inv_a: Vec<f64> = vec![0.0; n_total];
    for i in 0..n_total {
        let a = d_kernel[i] * d_kernel[i] + rho * l_complex[i].re;
        inv_a[i] = if a.abs() > 1e-20 { 1.0 / a } else { 0.0 };
    }

    // Pre-compute D^H * f for constant part of RHS (reuse l_complex as work buffer)
    let f_hat = &mut l_complex; // Reuse buffer; l_complex contents no longer needed
    for i in 0..n_total {
        f_hat[i] = Complex64::new(local_field[i], 0.0);
    }
    fft_ws.fft3d(f_hat);

    // f_hat = D^H * FFT(f) * inv_a  (D is real, so D^H reduces to D)
    for i in 0..n_total {
        f_hat[i] = f_hat[i] * d_kernel[i] * inv_a[i];
    }

    // ========================================================================
    // Pre-allocate ALL working buffers (zero allocations in iteration loop)
    // ========================================================================

    // Solution and previous (for convergence check)
    let mut x = vec![0.0; n_total];
    let mut x_prev = vec![0.0; n_total];

    // Dual variables (scaled Lagrange multipliers), one per gradient axis
    let mut ux = vec![0.0; n_total];
    let mut uy = vec![0.0; n_total];
    let mut uz = vec![0.0; n_total];

    // Gradient buffers (reused for z-u computation)
    let mut gx = vec![0.0; n_total];
    let mut gy = vec![0.0; n_total];
    let mut gz = vec![0.0; n_total];

    // Divergence buffer
    let mut div_d = vec![0.0; n_total];

    // Complex FFT buffer (reused each iteration)
    let mut work_complex = vec![Complex64::new(0.0, 0.0); n_total];

    // Shrinkage threshold for the z-subproblem
    let lambda_over_rho = lambda / rho;

    // ========================================================================
    // ADMM iterations (zero allocations per iteration)
    // ========================================================================
    for iter in 0..max_iter {
        // Report progress
        progress_callback(iter + 1, max_iter);

        // Swap x and x_prev (no allocation, just pointer swap)
        std::mem::swap(&mut x, &mut x_prev);

        // ====================================================================
        // x-subproblem: solve (D^H D + ρ L) x = D^H f + ρ div(z - u)
        // ====================================================================

        // gx/gy/gz currently hold ∇x from previous iteration (or zero initially)
        // After z-subproblem, they hold z - u (we compute this at end of loop)

        // Compute div(z - u) into div_d
        // On first iteration, gx/gy/gz are zero, so div is zero
        bdiv_inplace(&mut div_d, &gx, &gy, &gz, nx, ny, nz, vsx, vsy, vsz);

        // Prepare FFT: work_complex = div_d
        for i in 0..n_total {
            work_complex[i] = Complex64::new(div_d[i], 0.0);
        }
        fft_ws.fft3d(&mut work_complex);

        // x_hat = f_hat - rho * FFT(div) * inv_a
        // Note: bdiv computes positive divergence ∇·, but the adjoint ∇ᵀ = -∇·,
        // so we subtract (see Eq. [7] in Kames et al. 2018).
        for i in 0..n_total {
            work_complex[i] = f_hat[i] - rho * work_complex[i] * inv_a[i];
        }

        // IFFT to get x (imaginary part is numerical noise and discarded)
        fft_ws.ifft3d(&mut work_complex);
        for i in 0..n_total {
            x[i] = work_complex[i].re;
        }

        // ====================================================================
        // Convergence check (before z/u update for efficiency)
        // ====================================================================
        let mut norm_diff_sq = 0.0;
        let mut norm_x_sq = 0.0;
        for i in 0..n_total {
            let diff = x[i] - x_prev[i];
            norm_diff_sq += diff * diff;
            norm_x_sq += x[i] * x[i];
        }

        // Relative change ||x - x_prev|| / ||x||; the 1e-20 guards against
        // division by zero when x is identically zero.
        let rel_change = norm_diff_sq.sqrt() / (norm_x_sq.sqrt() + 1e-20);
        if rel_change < tol {
            // Signal completion: current == total
            progress_callback(iter + 1, iter + 1);
            break;
        }

        // ====================================================================
        // Fused: z-subproblem + u-update + prepare (z-u) for next iteration
        // ====================================================================

        // Compute gradient of x into gx/gy/gz
        fgrad_inplace(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, vsx, vsy, vsz);

        // Fused z-subproblem + u-update
        for i in 0..n_total {
            // v = ∇x + u (per axis)
            let vx = gx[i] + ux[i];
            let vy = gy[i] + uy[i];
            let vz = gz[i] + uz[i];

            // z = shrink(v, λ/ρ)
            let zx_i = shrink(vx, lambda_over_rho);
            let zy_i = shrink(vy, lambda_over_rho);
            let zz_i = shrink(vz, lambda_over_rho);

            // u ← v - z  (equivalent to u + ∇x - z)
            ux[i] = vx - zx_i;
            uy[i] = vy - zy_i;
            uz[i] = vz - zz_i;

            // Store z - u = 2z - v in gx/gy/gz for the next x-subproblem's
            // div(z - u) — this is what makes the loop allocation-free.
            gx[i] = 2.0 * zx_i - vx;
            gy[i] = 2.0 * zy_i - vy;
            gz[i] = 2.0 * zz_i - vz;
        }
    }

    // Apply mask: zero out everything outside the ROI
    for i in 0..n_total {
        if mask[i] == 0 { x[i] = 0.0; }
    }

    x
}
276
277/// TV-ADMM with default parameters
278pub fn tv_admm_default(
279    local_field: &[f64],
280    mask: &[u8],
281    nx: usize, ny: usize, nz: usize,
282    vsx: f64, vsy: f64, vsz: f64,
283) -> Vec<f64> {
284    let p = TvParams::default();
285    tv_admm(
286        local_field, mask, nx, ny, nz, vsx, vsy, vsz,
287        (0.0, 0.0, 1.0),
288        p.lambda, p.rho, p.tol, p.max_iter,
289    )
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::gradient::fgrad;

    #[test]
    fn test_shrink() {
        // (input, threshold, expected) — covers both active branches and
        // the dead zone on either side of zero.
        let cases = [
            (1.0, 0.5, 0.5),
            (-1.0, 0.5, -0.5),
            (0.3, 0.5, 0.0),
            (-0.3, 0.5, 0.0),
        ];
        for &(x, t, want) in cases.iter() {
            assert!((shrink(x, t) - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_tv_admm_zero_field() {
        // A zero local field must invert to a zero susceptibility map.
        let n = 8;
        let voxels = n * n * n;
        let field = vec![0.0; voxels];
        let mask = vec![1u8; voxels];

        let chi = tv_admm(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-3, 0.1, 1e-2, 10,
        );

        for &val in &chi {
            assert!(val.abs() < 1e-8, "Zero field should give zero chi, got {}", val);
        }
    }

    #[test]
    fn test_tv_admm_finite() {
        // Every output voxel must be finite (no NaN/Inf from the solver).
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];

        let chi = tv_admm(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-3, 0.1, 1e-2, 10,
        );

        for (i, &val) in chi.iter().enumerate() {
            assert!(val.is_finite(), "Chi should be finite at index {}", i);
        }
    }

    #[test]
    fn test_tv_admm_smoother_than_tkd() {
        // Strong regularization on a highly oscillatory field should yield
        // a result whose total variation is well-behaved.
        let n = 8;
        // Alternating ±0.01 "noise" field
        let field: Vec<f64> = (0..n * n * n)
            .map(|i| if i % 2 == 0 { 0.01 } else { -0.01 })
            .collect();
        let mask = vec![1u8; n * n * n];

        let chi_tv = tv_admm(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-2, 1.0, 1e-2, 50, // Strong regularization
        );

        // Total variation = L1 norm of the forward gradient
        let (gx, gy, gz) = fgrad(&chi_tv, n, n, n, 1.0, 1.0, 1.0);
        let tv: f64 = gx
            .iter()
            .chain(&gy)
            .chain(&gz)
            .map(|g| g.abs())
            .sum();

        // Exact value depends on parameters, but it must be bounded.
        assert!(tv.is_finite(), "TV should be finite");
    }

    /// Verify parallel and sequential TV-ADMM produce identical results.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_tv_parallel_matches_sequential() {
        let n = 16;
        let field: Vec<f64> = (0..n * n * n)
            .map(|i| ((i as f64) * 0.7).sin() * 0.01)
            .collect();
        let mask = vec![1u8; n * n * n];

        let run = |f: &[f64], m: &[u8]| {
            tv_admm(f, m, n, n, n, 1.0, 1.0, 1.0,
                (0.0, 0.0, 1.0), 1e-3, 0.1, 1e-3, 50)
        };

        // Sequential (1 thread)
        let pool_1 = rayon::ThreadPoolBuilder::new().num_threads(1).build().unwrap();
        let chi_seq = pool_1.install(|| run(&field, &mask));

        // Parallel (default threads)
        let chi_par = run(&field, &mask);

        // Element-wise comparison to near machine precision
        for (i, (s, p)) in chi_seq.iter().zip(chi_par.iter()).enumerate() {
            assert!(
                (s - p).abs() < 1e-10,
                "TV mismatch at voxel {}: seq={} par={} diff={}",
                i, s, p, (s - p).abs()
            );
        }
    }
}