Skip to main content

qsm_core/inversion/
tikhonov.rs

1//! Tikhonov regularization for QSM
2//!
3//! Tikhonov regularization adds an L2 penalty term to stabilize the inversion:
4//!
5//! χ = argmin_x ||Dx - f||₂² + λ||Γx||₂²
6//!
7//! This has a closed-form solution in k-space:
8//! χ̂ = D* · f̂ / (|D|² + λ|Γ|²)
9//!
10//! Reference:
11//! Bilgic, B., et al. (2014). "Fast image reconstruction with L2-regularization."
12//! Journal of Magnetic Resonance Imaging, 40(1):181-191. https://doi.org/10.1002/jmri.24365
13//!
14//! Reference implementation: https://github.com/kamesy/QSM.jl
15
16use num_complex::Complex64;
17use crate::fft::{fft3d, ifft3d};
18use crate::kernels::dipole::dipole_kernel;
19use crate::kernels::laplacian::laplacian_kernel;
20
/// Choice of penalty operator Γ used by the Tikhonov inversion.
#[derive(Clone, Copy, Debug)]
pub enum Regularization {
    /// Penalize the solution itself, λ‖x‖₂² (Γ = I).
    Identity,
    /// Penalize the spatial gradient, λ‖∇x‖₂²; evaluated via the negative
    /// Laplacian, since ‖∇x‖₂² = ⟨x, −Δx⟩.
    Gradient,
    /// Penalize the Laplacian, λ‖Δx‖₂².
    Laplacian,
}
31
32/// Tikhonov algorithm parameters
33#[derive(Clone, Debug)]
34pub struct TikhonovParams {
35    /// Regularization weight
36    pub lambda: f64,
37    /// Regularization type
38    pub reg: Regularization,
39}
40
41impl Default for TikhonovParams {
42    fn default() -> Self {
43        Self { lambda: 0.01, reg: Regularization::Identity }
44    }
45}
46
47/// Tikhonov regularization for dipole inversion
48///
49/// # Arguments
50/// * `local_field` - Local field values
51/// * `mask` - Binary mask (1 = inside ROI, 0 = outside)
52/// * `nx`, `ny`, `nz` - Array dimensions
53/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm
54/// * `bdir` - B0 field direction
55/// * `lambda` - Regularization parameter (typically 1e-2 to 1e-4)
56/// * `reg` - Type of regularization
57///
58/// # Returns
59/// Susceptibility map
60pub fn tikhonov(
61    local_field: &[f64],
62    mask: &[u8],
63    nx: usize, ny: usize, nz: usize,
64    vsx: f64, vsy: f64, vsz: f64,
65    bdir: (f64, f64, f64),
66    lambda: f64,
67    reg: Regularization,
68) -> Vec<f64> {
69    let n_total = nx * ny * nz;
70
71    // Generate dipole kernel
72    let d = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
73
74    // Generate regularization kernel and FFT it
75    let gamma: Vec<f64> = match reg {
76        Regularization::Identity => {
77            vec![1.0; n_total]
78        }
79        Regularization::Gradient => {
80            // Negative Laplacian for gradient regularization
81            let l = laplacian_kernel(nx, ny, nz, vsx, vsy, vsz, true);
82            // FFT to get frequency response
83            let mut l_complex: Vec<Complex64> = l.iter()
84                .map(|&x| Complex64::new(x, 0.0))
85                .collect();
86            fft3d(&mut l_complex, nx, ny, nz);
87            // Take real part (Laplacian FFT is real)
88            l_complex.iter().map(|c| c.re).collect()
89        }
90        Regularization::Laplacian => {
91            // Laplacian squared
92            let l = laplacian_kernel(nx, ny, nz, vsx, vsy, vsz, false);
93            let mut l_complex: Vec<Complex64> = l.iter()
94                .map(|&x| Complex64::new(x, 0.0))
95                .collect();
96            fft3d(&mut l_complex, nx, ny, nz);
97            // |Γ|²
98            l_complex.iter().map(|c| c.re * c.re).collect()
99        }
100    };
101
102    // Compute Tikhonov inverse: D / (D² + λΓ)
103    let inv_d: Vec<f64> = d.iter().zip(gamma.iter()).map(|(&dval, &gval)| {
104        let denom = dval * dval + lambda * gval;
105        if denom.abs() > 1e-20 {
106            dval / denom
107        } else {
108            0.0
109        }
110    }).collect();
111
112    // Convert local field to complex
113    let mut field_complex: Vec<Complex64> = local_field.iter()
114        .map(|&x| Complex64::new(x, 0.0))
115        .collect();
116
117    // FFT of local field
118    fft3d(&mut field_complex, nx, ny, nz);
119
120    // Multiply by Tikhonov inverse
121    for i in 0..n_total {
122        field_complex[i] *= inv_d[i];
123    }
124
125    // IFFT to get susceptibility
126    ifft3d(&mut field_complex, nx, ny, nz);
127
128    // Extract real part and apply mask
129    let mut chi: Vec<f64> = field_complex.iter()
130        .map(|c| c.re)
131        .collect();
132
133    // Apply mask
134    for i in 0..n_total {
135        if mask[i] == 0 {
136            chi[i] = 0.0;
137        }
138    }
139
140    chi
141}
142
143/// Tikhonov with default parameters (identity regularization, λ=0.01; matches QSM.jl)
144pub fn tikhonov_default(
145    local_field: &[f64],
146    mask: &[u8],
147    nx: usize, ny: usize, nz: usize,
148    vsx: f64, vsy: f64, vsz: f64,
149) -> Vec<f64> {
150    tikhonov(
151        local_field, mask, nx, ny, nz, vsx, vsy, vsz,
152        (0.0, 0.0, 1.0), 0.01, Regularization::Identity
153    )
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Edge length of the cubic test volumes.
    const N: usize = 8;

    /// Mask covering the whole `N³` volume.
    fn full_mask() -> Vec<u8> {
        vec![1u8; N * N * N]
    }

    /// Smooth, non-trivial test field.
    fn sine_field() -> Vec<f64> {
        (0..N * N * N).map(|i| ((i as f64) * 0.1).sin()).collect()
    }

    /// Sum of absolute element-wise differences.
    fn l1_diff(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum()
    }

    #[test]
    fn test_tikhonov_zero_field() {
        let field = vec![0.0; N * N * N];
        let mask = full_mask();

        let chi = tikhonov_default(&field, &mask, N, N, N, 1.0, 1.0, 1.0);

        for val in &chi {
            assert!(val.abs() < 1e-10, "Zero field should give zero chi");
        }
    }

    #[test]
    fn test_tikhonov_finite() {
        let field: Vec<f64> = (0..N * N * N).map(|i| (i as f64) * 0.01).collect();
        let mask = full_mask();

        let chi = tikhonov_default(&field, &mask, N, N, N, 1.0, 1.0, 1.0);

        for (i, val) in chi.iter().enumerate() {
            assert!(val.is_finite(), "Chi should be finite at index {}", i);
        }
    }

    #[test]
    fn test_tikhonov_regularization_types() {
        let field = sine_field();
        let mask = full_mask();
        let bdir = (0.0, 0.0, 1.0);

        let run = |reg| tikhonov(&field, &mask, N, N, N, 1.0, 1.0, 1.0, bdir, 0.01, reg);
        let chi_id = run(Regularization::Identity);
        let chi_grad = run(Regularization::Gradient);
        let chi_lap = run(Regularization::Laplacian);

        // All three penalties should produce distinct solutions.
        assert!(l1_diff(&chi_id, &chi_grad) > 1e-10, "Identity and Gradient should differ");
        assert!(l1_diff(&chi_grad, &chi_lap) > 1e-10, "Gradient and Laplacian should differ");
        assert!(l1_diff(&chi_id, &chi_lap) > 1e-10, "Identity and Laplacian should differ");
    }

    #[test]
    fn test_tikhonov_mask_zeroes_outside() {
        let field = sine_field();
        let mask: Vec<u8> = (0..N * N * N).map(|i| u8::from(i % 2 == 0)).collect();

        let chi = tikhonov_default(&field, &mask, N, N, N, 1.0, 1.0, 1.0);

        for (i, (&val, &m)) in chi.iter().zip(&mask).enumerate() {
            if m == 0 {
                assert_eq!(val, 0.0, "Voxel {} outside the mask should be zero", i);
            }
        }
        assert!(
            chi.iter().zip(&mask).any(|(&v, &m)| m == 1 && v != 0.0),
            "Voxels inside the mask should carry a non-zero solution"
        );
    }
}