// qsm_core/inversion/rts.rs
//! Rapid Two-Step (RTS) dipole inversion
//!
//! Two-step approach that combines:
//! 1. LSMR for well-conditioned k-space regions
//! 2. TV regularization for ill-conditioned regions
//!
//! Reference:
//! Kames, C., Wiggermann, V., Rauscher, A. (2018).
//! "Rapid two-step dipole inversion for susceptibility mapping with sparsity priors."
//! NeuroImage, 167:276-283. https://doi.org/10.1016/j.neuroimage.2017.11.018
//!
//! Reference implementation: https://github.com/kamesy/QSM.jl

use num_complex::Complex64;

use crate::fft::Fft3dWorkspace;
use crate::kernels::dipole::dipole_kernel;
use crate::kernels::laplacian::laplacian_kernel;
use crate::utils::gradient::{bdiv_inplace, fgrad_inplace};

#[cfg(feature = "parallel")]
use crate::par::*;

/// Soft thresholding (shrinkage) operator: `sign(x) * max(|x| - threshold, 0)`.
///
/// Used as the proximal operator of the L1 norm in the ADMM z-subproblem.
#[inline]
fn shrink(x: f64, threshold: f64) -> f64 {
    let magnitude = x.abs() - threshold;
    if magnitude > 0.0 {
        // Restore the sign of the input on the shrunk magnitude.
        magnitude.copysign(x)
    } else {
        0.0
    }
}
/// Tuning parameters for the RTS dipole inversion.
///
/// Typical values follow Kames et al. (2018); see `Default` for the
/// literature defaults.
#[derive(Clone, Debug)]
pub struct RtsParams {
    /// Threshold for ill-conditioned region (typically 0.15)
    pub delta: f64,
    /// Regularization parameter for well-conditioned region (typically 1e5)
    pub mu: f64,
    /// ADMM penalty parameter (typically 10)
    pub rho: f64,
    /// Convergence tolerance
    pub tol: f64,
    /// Maximum ADMM iterations
    pub max_iter: usize,
    /// LSMR iterations for step 1 (typically 4)
    pub lsmr_iter: usize,
}
52impl Default for RtsParams {
53    fn default() -> Self {
54        Self {
55            delta: 0.15,
56            mu: 1e5,
57            rho: 10.0,
58            tol: 1e-2,
59            max_iter: 20,
60            lsmr_iter: 4,
61        }
62    }
63}
64
65/// RTS dipole inversion (optimized)
66///
67/// # Arguments
68/// * `local_field` - Local field values (nx * ny * nz)
69/// * `mask` - Binary mask (nx * ny * nz), 1 = inside ROI
70/// * `nx`, `ny`, `nz` - Array dimensions
71/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm
72/// * `bdir` - B0 field direction
73/// * `delta` - Threshold for ill-conditioned region (typically 0.15)
74/// * `mu` - Regularization parameter for well-conditioned region (typically 1e5)
75/// * `rho` - ADMM penalty parameter (typically 10)
76/// * `tol` - Convergence tolerance
77/// * `max_iter` - Maximum ADMM iterations
78/// * `lsmr_iter` - LSMR iterations for step 1 (typically 4)
79///
80/// # Returns
81/// Susceptibility map
82pub fn rts(
83    local_field: &[f64],
84    mask: &[u8],
85    nx: usize, ny: usize, nz: usize,
86    vsx: f64, vsy: f64, vsz: f64,
87    bdir: (f64, f64, f64),
88    delta: f64,
89    mu: f64,
90    rho: f64,
91    tol: f64,
92    max_iter: usize,
93    lsmr_iter: usize,
94) -> Vec<f64> {
95    rts_with_progress(
96        local_field, mask, nx, ny, nz, vsx, vsy, vsz,
97        bdir, delta, mu, rho, tol, max_iter, lsmr_iter,
98        |_, _| {} // no-op progress callback
99    )
100}
101
/// RTS with progress callback (optimized)
///
/// Optimized implementation with:
/// - Pre-allocated buffers (zero allocations per iteration)
/// - In-place gradient/divergence operations
/// - Buffer swapping instead of cloning
/// - Fused z-subproblem and u-update
///
/// Same as `rts` but calls `progress_callback(iteration, max_iter)` each iteration.
///
/// # Panics
/// Indexing assumes `local_field` and `mask` each hold at least
/// `nx * ny * nz` elements; shorter slices panic.
pub fn rts_with_progress<F>(
    local_field: &[f64],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
    bdir: (f64, f64, f64),
    delta: f64,
    mu: f64,
    rho: f64,
    tol: f64,
    max_iter: usize,
    lsmr_iter: usize,
    mut progress_callback: F,
) -> Vec<f64>
where
    F: FnMut(usize, usize),
{
    let n_total = nx * ny * nz;

    // ========================================================================
    // Pre-compute kernels (done once)
    // ========================================================================

    // Create FFT workspace (caches plans and scratch buffers for reuse)
    let mut fft_ws = Fft3dWorkspace::new(nx, ny, nz);

    // Generate dipole kernel D
    let d_kernel = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);

    // Generate negative Laplacian kernel
    let l_kernel = laplacian_kernel(nx, ny, nz, vsx, vsy, vsz, true);

    // FFT of Laplacian kernel (reuse buffer for other purposes later)
    let mut work_complex: Vec<Complex64> = l_kernel.iter()
        .map(|&x| Complex64::new(x, 0.0))
        .collect();
    fft_ws.fft3d(&mut work_complex);

    // Compute well-conditioned mask M and inverse operator iA.
    // m_mask[i] = mu inside the well-conditioned k-space region (|D| > delta),
    // 0 elsewhere; inv_a[i] = rho / (M + rho * L_hat), the inverse of the
    // x-subproblem's diagonal system in k-space (scaled by rho).
    let mut m_mask: Vec<f64> = vec![0.0; n_total];
    let mut inv_a: Vec<f64> = vec![0.0; n_total];

    for i in 0..n_total {
        // Only the real part of the Laplacian spectrum is used; the imaginary
        // part is assumed negligible for this kernel — TODO confirm against
        // laplacian_kernel's construction.
        let l_fft_i = work_complex[i].re;
        if d_kernel[i].abs() > delta {
            m_mask[i] = mu;
        }
        let a = m_mask[i] + rho * l_fft_i;
        // Guard against dividing by a (near-)zero diagonal entry.
        if a.abs() > 1e-20 {
            inv_a[i] = rho / a;
        }
    }

    // ========================================================================
    // Step 1: Well-conditioned k-space (simplified LSMR)
    // ========================================================================

    // FFT of field (reuse work_complex)
    for i in 0..n_total {
        work_complex[i] = Complex64::new(local_field[i], 0.0);
    }
    fft_ws.fft3d(&mut work_complex);

    // Store field_fft for LSMR iterations
    let field_fft: Vec<Complex64> = work_complex.clone();

    // Initial estimate: chi = D * f / (D^2 + epsilon) for well-conditioned
    // Stored in work_complex
    for i in 0..n_total {
        let d = d_kernel[i];
        if d.abs() > delta {
            // 1e-6 is a small damping term to keep the division stable.
            work_complex[i] = field_fft[i] * d / (d * d + 1e-6);
        } else {
            work_complex[i] = Complex64::new(0.0, 0.0);
        }
    }

    // Simple iterative refinement for well-conditioned region
    // Use a temporary buffer for residual
    let mut residual = vec![Complex64::new(0.0, 0.0); n_total];
    for _ in 0..lsmr_iter {
        // residual = f - D * chi
        for i in 0..n_total {
            residual[i] = field_fft[i] - work_complex[i] * d_kernel[i];
        }

        // update chi for well-conditioned region (same damped step as above)
        for i in 0..n_total {
            let d = d_kernel[i];
            if d.abs() > delta {
                work_complex[i] += residual[i] * d / (d * d + 1e-6);
            }
        }
    }

    // Transform to spatial domain
    fft_ws.ifft3d(&mut work_complex);

    // Initialize x and apply mask (keep only voxels inside the ROI)
    let mut x = vec![0.0; n_total];
    for i in 0..n_total {
        x[i] = if mask[i] != 0 { work_complex[i].re } else { 0.0 };
    }

    // ========================================================================
    // Pre-compute constant part of RHS for ADMM
    // ========================================================================

    // F_hat = inv_a * M * FFT(x) / rho
    // This k-space term is constant across ADMM iterations, so it is built once.
    for i in 0..n_total {
        work_complex[i] = Complex64::new(x[i], 0.0);
    }
    fft_ws.fft3d(&mut work_complex);

    let mut f_hat: Vec<Complex64> = vec![Complex64::new(0.0, 0.0); n_total];
    for i in 0..n_total {
        if m_mask[i].abs() > 1e-20 && inv_a[i].abs() > 1e-20 {
            f_hat[i] = work_complex[i] * (m_mask[i] / rho) * inv_a[i];
        }
    }

    // ========================================================================
    // Pre-allocate ALL working buffers for ADMM (zero allocations in loop)
    // ========================================================================

    let mut x_prev = vec![0.0; n_total];

    // Dual variables (scaled Lagrange multipliers)
    let mut ux = vec![0.0; n_total];
    let mut uy = vec![0.0; n_total];
    let mut uz = vec![0.0; n_total];

    // Gradient buffers (reused for z-u computation)
    let mut gx = vec![0.0; n_total];
    let mut gy = vec![0.0; n_total];
    let mut gz = vec![0.0; n_total];

    // Divergence buffer
    let mut div_v = vec![0.0; n_total];

    // Soft-threshold level for the z-subproblem (lambda = 1 implied).
    let inv_rho = 1.0 / rho;

    // ========================================================================
    // Step 2: ADMM iterations (zero allocations per iteration)
    // ========================================================================

    for iter in 0..max_iter {
        progress_callback(iter + 1, max_iter);

        // Swap x and x_prev (no allocation); x_prev now holds the last iterate
        // and x becomes scratch space, fully overwritten below.
        std::mem::swap(&mut x, &mut x_prev);

        // ====================================================================
        // x-subproblem: (M + ρL)x = F + ρ∇ᵀ(z - u)
        // ====================================================================

        // gx/gy/gz hold (z - u) from previous iteration (or zero initially)
        bdiv_inplace(&mut div_v, &gx, &gy, &gz, nx, ny, nz, vsx, vsy, vsz);

        // Prepare FFT
        for i in 0..n_total {
            work_complex[i] = Complex64::new(div_v[i], 0.0);
        }
        fft_ws.fft3d(&mut work_complex);

        // x_hat = f_hat - inv_a * div_hat
        // Note: bdiv computes positive divergence ∇·, but the adjoint ∇ᵀ = -∇·,
        // so we subtract (see Eq. [7] in Kames et al. 2018).
        for i in 0..n_total {
            work_complex[i] = f_hat[i] - work_complex[i] * inv_a[i];
        }

        // IFFT to get x
        fft_ws.ifft3d(&mut work_complex);
        for i in 0..n_total {
            x[i] = work_complex[i].re;
        }

        // ====================================================================
        // Convergence check: relative L2 change ||x - x_prev|| / ||x||
        // ====================================================================
        let mut norm_diff_sq = 0.0;
        let mut norm_x_sq = 0.0;
        for i in 0..n_total {
            let diff = x[i] - x_prev[i];
            norm_diff_sq += diff * diff;
            norm_x_sq += x[i] * x[i];
        }

        // 1e-20 avoids division by zero when x is (near) all zeros.
        let rel_change = norm_diff_sq.sqrt() / (norm_x_sq.sqrt() + 1e-20);
        if rel_change < tol {
            // Signal completion: report the iteration count as the new total.
            progress_callback(iter + 1, iter + 1);
            break;
        }

        // ====================================================================
        // Fused: z-subproblem + u-update + prepare (z-u) for next iteration
        // ====================================================================

        fgrad_inplace(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, vsx, vsy, vsz);

        // Fused z-subproblem + u-update:
        //   v = ∇x + u;  z = shrink(v, 1/ρ);  u ← v - z;  g ← z - u = 2z - v
        for i in 0..n_total {
            let vx = gx[i] + ux[i];
            let vy = gy[i] + uy[i];
            let vz = gz[i] + uz[i];

            let zx_i = shrink(vx, inv_rho);
            let zy_i = shrink(vy, inv_rho);
            let zz_i = shrink(vz, inv_rho);

            ux[i] = vx - zx_i;
            uy[i] = vy - zy_i;
            uz[i] = vz - zz_i;

            gx[i] = 2.0 * zx_i - vx;
            gy[i] = 2.0 * zy_i - vy;
            gz[i] = 2.0 * zz_i - vz;
        }
    }

    // Apply mask: zero out everything outside the ROI in the final map
    for i in 0..n_total {
        if mask[i] == 0 { x[i] = 0.0; }
    }

    x
}
340/// RTS with default parameters
341pub fn rts_default(
342    local_field: &[f64],
343    mask: &[u8],
344    nx: usize, ny: usize, nz: usize,
345    vsx: f64, vsy: f64, vsz: f64,
346) -> Vec<f64> {
347    let p = RtsParams::default();
348    rts(
349        local_field, mask, nx, ny, nz, vsx, vsy, vsz,
350        (0.0, 0.0, 1.0),
351        p.delta, p.mu, p.rho, p.tol, p.max_iter, p.lsmr_iter,
352    )
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// A zero local field must invert to an (almost) zero susceptibility map.
    #[test]
    fn test_rts_zero_field() {
        let n = 8;
        let voxels = n * n * n;
        let field = vec![0.0; voxels];
        let mask = vec![1u8; voxels];

        let chi = rts(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 0.15, 1e5, 10.0, 1e-2, 5, 2,
        );

        for &val in &chi {
            assert!(val.abs() < 1e-6, "Zero field should give near-zero chi");
        }
    }

    /// The inversion must never produce NaN/Inf for a finite input field.
    #[test]
    fn test_rts_finite() {
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| i as f64 * 0.001).collect();
        let mask = vec![1u8; n * n * n];

        let chi = rts(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 0.15, 1e5, 10.0, 1e-2, 5, 2,
        );

        for (i, &val) in chi.iter().enumerate() {
            assert!(val.is_finite(), "Chi should be finite at index {}", i);
        }
    }

    /// Voxels excluded from the mask must come back exactly zero.
    #[test]
    fn test_rts_mask() {
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| i as f64 * 0.001).collect();
        let mut mask = vec![1u8; n * n * n];
        // Knock a couple of voxels out of the ROI.
        mask[0] = 0;
        mask[10] = 0;

        let chi = rts(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 0.15, 1e5, 10.0, 1e-2, 5, 2,
        );

        assert_eq!(chi[0], 0.0, "Masked voxel should be zero");
        assert_eq!(chi[10], 0.0, "Masked voxel should be zero");
    }

    /// Verify parallel and sequential RTS produce identical results.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_rts_parallel_matches_sequential() {
        let n = 16;
        let field: Vec<f64> = (0..n * n * n)
            .map(|i| (i as f64 * 0.7).sin() * 0.01)
            .collect();
        let mask = vec![1u8; n * n * n];

        // Sequential baseline: force a single rayon worker thread.
        let pool_1 = rayon::ThreadPoolBuilder::new().num_threads(1).build().unwrap();
        let chi_seq = pool_1.install(|| {
            rts(&field, &mask, n, n, n, 1.0, 1.0, 1.0,
                (0.0, 0.0, 1.0), 0.15, 1e5, 10.0, 1e-4, 20, 4)
        });

        // Parallel run on the default global thread pool.
        let chi_par = rts(&field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 0.15, 1e5, 10.0, 1e-4, 20, 4);

        // Results must agree voxel-by-voxel to within floating-point noise.
        for (i, (s, p)) in chi_seq.iter().zip(chi_par.iter()).enumerate() {
            assert!(
                (s - p).abs() < 1e-10,
                "RTS mismatch at voxel {}: seq={} par={} diff={}",
                i, s, p, (s - p).abs()
            );
        }
    }
}