Skip to main content

qsm_core/separation/
chi_sep_medi.rs

//! Chi-separation using MEDI-based coupled optimization
//!
//! Separates the total susceptibility into paramagnetic (chi+) and diamagnetic (chi-)
//! components using the chi_pos + chi_neg formulation:
//!   chi_pos >= 0 (paramagnetic, e.g. iron), stored in Hz internally
//!   chi_neg <= 0 (diamagnetic, e.g. myelin), stored in Hz internally
//!   chi_total = chi_pos + chi_neg
//! The returned maps are converted to ppm.
//!
//! Forward model (all in Hz):
//!   field   = D * (chi_pos + chi_neg)
//!   R2'(Hz) = dr_p_eff * chi_pos + dr_q_eff * (-chi_neg)
//!           = dr_p_eff * |chi_pos| + dr_q_eff * |chi_neg|
//!   where dr_eff = ppm_factor * Dr is a dimensionless effective relaxivity
//!   (Dr in Hz/ppm, ppm_factor = 1e6 / cf in ppm/Hz).
//!
//! The sign constraints (chi_pos >= 0, chi_neg <= 0) help break the gauge freedom of
//! the chi_pos + chi_neg formulation. In most voxels either chi_pos = 0 or chi_neg = 0.
//! That pins one variable to its constraint boundary and prevents correlated drift.
//! Among the data terms, only R2' constrains the remaining gauge direction. The field
//! term depends only on chi_pos + chi_neg.
//!
//! Reference:
//! Shin, H., et al. (2021). "chi-separation: Magnetic susceptibility source
//! separation toward iron and myelin mapping in the brain." NeuroImage, 240:118371.

24use num_complex::Complex32;
25use crate::fft::Fft3dWorkspaceF32;
26use crate::kernels::dipole::dipole_kernel_f32;
27use crate::inversion::medi::{
28    gradient_mask_f32,
29    fgrad_periodic_inplace_f32,
30    bdiv_periodic_inplace_f32,
31};
32use crate::utils::simd_ops::{
33    dot_product_f32, norm_squared_f32, axpy_f32, xpby_f32,
34    apply_gradient_weights_f32, compute_p_weights_f32,
35};
36
37/// Workspace for chi-separation — holds all reusable buffers (f32).
38struct ChiSepWorkspace {
39    n: usize,
40    nx: usize, ny: usize, nz: usize,
41    vsx: f32, vsy: f32, vsz: f32,
42
43    fft_ws: Fft3dWorkspaceF32,
44
45    gx: Vec<f32>,
46    gy: Vec<f32>,
47    gz: Vec<f32>,
48
49    reg_x: Vec<f32>,
50    reg_y: Vec<f32>,
51    reg_z: Vec<f32>,
52
53    div_buf: Vec<f32>,
54
55    complex_buf: Vec<Complex32>,
56    dipole_buf: Vec<f32>,
57
58    tmp: Vec<f32>,
59}
60
61impl ChiSepWorkspace {
62    fn new(nx: usize, ny: usize, nz: usize, vsx: f32, vsy: f32, vsz: f32) -> Self {
63        let n = nx * ny * nz;
64        Self {
65            n, nx, ny, nz, vsx, vsy, vsz,
66            fft_ws: Fft3dWorkspaceF32::new(nx, ny, nz),
67            gx: vec![0.0; n],
68            gy: vec![0.0; n],
69            gz: vec![0.0; n],
70            reg_x: vec![0.0; n],
71            reg_y: vec![0.0; n],
72            reg_z: vec![0.0; n],
73            div_buf: vec![0.0; n],
74            complex_buf: vec![Complex32::new(0.0, 0.0); n],
75            dipole_buf: vec![0.0; n],
76            tmp: vec![0.0; n],
77        }
78    }
79}
80
81/// Chi-separation using MEDI-based coupled optimization.
82///
83/// # Arguments
84/// * `local_field` - Local field map in Hz
85/// * `r2prime` - R2' map in Hz
86/// * `magnitude` - Magnitude image for edge weighting
87/// * `mask` - Binary brain mask, 1 = brain
88/// * `cf` - Central frequency in Hz (e.g. 123.2e6 for 3T)
89/// * `dr_pos` - Paramagnetic relaxivity in Hz/ppm (default: 114.0)
90/// * `dr_neg` - Diamagnetic relaxivity in Hz/ppm (default: 30.0)
91///
92/// # Returns
93/// `(chi_pos, chi_neg, chi_total)` — susceptibility maps in ppm
94#[allow(clippy::too_many_arguments)]
95pub fn chi_sep_medi(
96    local_field: &[f64],
97    r2prime: &[f64],
98    magnitude: &[f64],
99    mask: &[u8],
100    nx: usize, ny: usize, nz: usize,
101    vsx: f64, vsy: f64, vsz: f64,
102    bdir: (f64, f64, f64),
103    cf: f64,
104    lambda_para: f64,
105    lambda_dia: f64,
106    lambda_cpl: f64,
107    dr_pos: f64,
108    dr_neg: f64,
109    percentage: f64,
110    cg_tol: f64,
111    cg_max_iter: usize,
112    max_iter: usize,
113    tol: f64,
114) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
115    chi_sep_medi_with_progress(
116        local_field, r2prime, magnitude, mask,
117        nx, ny, nz, vsx, vsy, vsz, bdir, cf,
118        lambda_para, lambda_dia, lambda_cpl,
119        dr_pos, dr_neg, percentage,
120        cg_tol, cg_max_iter, max_iter, tol,
121        |_, _| {},
122    )
123}
124
125/// Chi-separation with progress callback.
126#[allow(clippy::too_many_arguments)]
127pub fn chi_sep_medi_with_progress<F>(
128    local_field: &[f64],
129    r2prime: &[f64],
130    magnitude: &[f64],
131    mask: &[u8],
132    nx: usize, ny: usize, nz: usize,
133    vsx: f64, vsy: f64, vsz: f64,
134    bdir: (f64, f64, f64),
135    cf: f64,
136    lambda_para: f64,
137    lambda_dia: f64,
138    lambda_cpl: f64,
139    dr_pos: f64,
140    dr_neg: f64,
141    percentage: f64,
142    cg_tol: f64,
143    cg_max_iter: usize,
144    max_iter: usize,
145    tol: f64,
146    mut progress_callback: F,
147) -> (Vec<f64>, Vec<f64>, Vec<f64>)
148where
149    F: FnMut(usize, usize),
150{
151    let n = nx * ny * nz;
152    let ppm_factor = (1.0e6 / cf) as f32;
153
154    let vsx_f32 = vsx as f32;
155    let vsy_f32 = vsy as f32;
156    let vsz_f32 = vsz as f32;
157    let bdir_f32 = (bdir.0 as f32, bdir.1 as f32, bdir.2 as f32);
158    let lambda_para_f32 = lambda_para as f32;
159    let lambda_dia_f32 = lambda_dia as f32;
160    let lambda_cpl_f32 = lambda_cpl as f32;
161    let cg_tol_f32 = cg_tol as f32;
162    let tol_f32 = tol as f32;
163
164    // Effective relaxivities: Dr(Hz/ppm) * ppm_factor(ppm/Hz) = dimensionless
165    // R2'(Hz) = Dr_pos * chi_ppm = Dr_pos * (chi_Hz * ppm_factor) = dr_p_eff * chi_Hz
166    let dr_p_eff = ppm_factor * dr_pos as f32;
167    let dr_q_eff = ppm_factor * dr_neg as f32;
168
169    // Auto-tune R2' normalization for field-strength-independent gauge breaking.
170    //
171    // The chi-sep gauge mode (χ+ grows, χ- shrinks equally) lives in the null space
172    // of the field Hessian and is ONLY constrained by R2'. The gauge eigenvalue is
173    // λ_cpl * (dr_p + dr_q)² which at 7T is only ~23 vs ~1000 for TV modes.
174    //
175    // We scale the R2' equation (both data and relaxivities) by r2_scale so that the
176    // gauge eigenvalue = max(λ_para, λ_dia), matching the TV regularization strength.
177    // This doesn't change the solution (R2' residual is still 0 at truth).
178    //
179    // r2_scale = sqrt(target / (λ_cpl * dr_sum²))
180    // → eigenvalue = λ_cpl * (r2_scale * dr_sum)² = target
181    let dr_sum = dr_p_eff + dr_q_eff;
182    let target_eig = 10.0 * lambda_para_f32.max(lambda_dia_f32);
183    let r2_scale = (target_eig / (lambda_cpl_f32 * dr_sum * dr_sum)).sqrt();
184    let dr_p_use = dr_p_eff * r2_scale;
185    let dr_q_use = dr_q_eff * r2_scale;
186
187    let field_f32: Vec<f32> = local_field.iter()
188        .zip(mask.iter())
189        .map(|(&v, &m)| if m != 0 { v as f32 } else { 0.0 })
190        .collect();
191    let r2p_f32: Vec<f32> = r2prime.iter()
192        .zip(mask.iter())
193        .map(|(&v, &m)| if m != 0 { (v as f32) * r2_scale } else { 0.0 })
194        .collect();
195    let mag_f32: Vec<f32> = magnitude.iter().map(|&v| v as f32).collect();
196
197    let mut ws = ChiSepWorkspace::new(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);
198    let d_kernel = dipole_kernel_f32(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, bdir_f32);
199    let (mx, my, mz) = gradient_mask_f32(
200        &mag_f32, mask, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, percentage as f32,
201    );
202
203    // chi_pos >= 0 (paramagnetic), chi_neg <= 0 (diamagnetic), both in Hz
204    let mut chi_pos = vec![0.0f32; n];
205    let mut chi_neg = vec![0.0f32; n];
206
207    let mut vr_pos = vec![0.0f32; n];
208    let mut vr_neg = vec![0.0f32; n];
209    let mut vr_sum = vec![0.0f32; n];
210    let n2 = 2 * n;
211    let mut dx = vec![0.0f32; n2];
212    let mut rhs = vec![0.0f32; n2];
213    let mut chi_sum_buf = vec![0.0f32; n];
214
215    // TV weight on (chi+ + chi-) sum — paper uses 2*lambda for sum term
216    // Disabled for now (0.0) pending parameter tuning; the sum TV can over-couple components
217    let lambda_sum_f32 = 0.0_f32;
218
219    let eps = 1.0e-6_f32;
220
221    for iter in 0..max_iter {
222        progress_callback(iter + 1, max_iter);
223
224        // --- TV reweighting ---
225        fgrad_periodic_inplace_f32(&mut ws.gx, &mut ws.gy, &mut ws.gz,
226            &chi_pos, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);
227        compute_p_weights_f32(&mut vr_pos, &mx, &my, &mz, &ws.gx, &ws.gy, &ws.gz, eps);
228
229        fgrad_periodic_inplace_f32(&mut ws.gx, &mut ws.gy, &mut ws.gz,
230            &chi_neg, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);
231        compute_p_weights_f32(&mut vr_neg, &mx, &my, &mz, &ws.gx, &ws.gy, &ws.gz, eps);
232
233        // TV weights for the sum (chi+ + chi-)
234        for i in 0..n {
235            chi_sum_buf[i] = chi_pos[i] + chi_neg[i];
236        }
237        fgrad_periodic_inplace_f32(&mut ws.gx, &mut ws.gy, &mut ws.gz,
238            &chi_sum_buf, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);
239        compute_p_weights_f32(&mut vr_sum, &mx, &my, &mz, &ws.gx, &ws.gy, &ws.gz, eps);
240
241        // --- Residuals ---
242        // field_residual = field - D*(chi_pos + chi_neg)
243        for i in 0..n {
244            ws.tmp[i] = chi_pos[i] + chi_neg[i];
245        }
246        let chi_sum = ws.tmp.clone();
247        ws.fft_ws.apply_dipole_inplace(&chi_sum, &d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
248        let field_residual: Vec<f32> = field_f32.iter()
249            .zip(ws.dipole_buf.iter())
250            .map(|(&f, &d)| f - d)
251            .collect();
252
253        // R2' residual (normalized): r2_norm - dr_p_use * chi_pos + dr_q_use * chi_neg
254        // where r2_norm = R2' / (dr_p_eff + dr_q_eff), dr_p_use = dr_p_eff / dr_sum
255        let r2_residual: Vec<f32> = (0..n).map(|i| {
256            r2p_f32[i] - dr_p_use * chi_pos[i] + dr_q_use * chi_neg[i]
257        }).collect();
258
259        // --- Build gradient (MEDI convention: b_orig = gradient, then negate) ---
260        //
261        // Gradient derived from first principles:
262        //   J = λ_para*TV(χ+) + λ_dia*TV(χ-) + λ_cpl/2*||field_res||² + λ_cpl/2*||r2_res||²
263        //
264        // ∂J/∂χ+ = λ_para*TV_grad(χ+) - λ_cpl*D(field_res) - λ_cpl*r2_res*dr_p_eff
265        // ∂J/∂χ- = λ_dia*TV_grad(χ-)  - λ_cpl*D(field_res) + λ_cpl*r2_res*dr_q_eff
266        //
267        // TV_grad(χ) = bdiv(wG*Vr*wG*fgrad(χ)) [MEDI convention, IS the gradient]
268        //   because bdiv = -div, so bdiv(wG*Vr*wG*∇χ) = -div(wG*Vr*wG*∇χ) = ∂TV/∂χ
269        //
270        // Field: ∂||field_res||²/∂χ+ = -2*D(field_res), with 1/2 → -D(field_res)
271        //   Same for χ- since chi_total = χ+ + χ-, ∂/∂χ- has same sign
272        //
273        // R2': r2_res = R2' - dr_p*χ+ + dr_q*χ- (since |χ-| = -χ-)
274        //   ∂r2_res/∂χ+ = -dr_p → ∂||r2_res||²/∂χ+ = -2*r2_res*dr_p, with 1/2 → -r2_res*dr_p
275        //   ∂r2_res/∂χ- = +dr_q → ∂||r2_res||²/∂χ- = +2*r2_res*dr_q, with 1/2 → +r2_res*dr_q
276
277        // TV gradient for chi_pos
278        fgrad_periodic_inplace_f32(&mut ws.gx, &mut ws.gy, &mut ws.gz,
279            &chi_pos, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);
280        apply_gradient_weights_f32(&mut ws.reg_x, &mut ws.reg_y, &mut ws.reg_z,
281            &mx, &my, &mz, &vr_pos, &ws.gx, &ws.gy, &ws.gz);
282        bdiv_periodic_inplace_f32(&mut ws.div_buf,
283            &ws.reg_x, &ws.reg_y, &ws.reg_z, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);
284        for i in 0..n {
285            rhs[i] = lambda_para_f32 * ws.div_buf[i];
286        }
287
288        // TV gradient for chi_neg
289        fgrad_periodic_inplace_f32(&mut ws.gx, &mut ws.gy, &mut ws.gz,
290            &chi_neg, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);
291        apply_gradient_weights_f32(&mut ws.reg_x, &mut ws.reg_y, &mut ws.reg_z,
292            &mx, &my, &mz, &vr_neg, &ws.gx, &ws.gy, &ws.gz);
293        bdiv_periodic_inplace_f32(&mut ws.div_buf,
294            &ws.reg_x, &ws.reg_y, &ws.reg_z, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);
295        for i in 0..n {
296            rhs[n + i] = lambda_dia_f32 * ws.div_buf[i];
297        }
298
299        // TV gradient for (chi+ + chi-) sum — applied equally to both components
300        for i in 0..n {
301            chi_sum_buf[i] = chi_pos[i] + chi_neg[i];
302        }
303        fgrad_periodic_inplace_f32(&mut ws.gx, &mut ws.gy, &mut ws.gz,
304            &chi_sum_buf, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);
305        apply_gradient_weights_f32(&mut ws.reg_x, &mut ws.reg_y, &mut ws.reg_z,
306            &mx, &my, &mz, &vr_sum, &ws.gx, &ws.gy, &ws.gz);
307        bdiv_periodic_inplace_f32(&mut ws.div_buf,
308            &ws.reg_x, &ws.reg_y, &ws.reg_z, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);
309        for i in 0..n {
310            let tv_sum = lambda_sum_f32 * ws.div_buf[i];
311            rhs[i] += tv_sum;
312            rhs[n + i] += tv_sum;
313        }
314
315        // Field fidelity gradient: -λ_cpl * D(field_res), SAME sign for both
316        ws.fft_ws.apply_dipole_inplace(&field_residual, &d_kernel,
317            &mut ws.dipole_buf, &mut ws.complex_buf);
318        for i in 0..n {
319            let fg = lambda_cpl_f32 * ws.dipole_buf[i];
320            rhs[i] -= fg;
321            rhs[n + i] -= fg;
322        }
323
324        // R2' fidelity gradient (normalized):
325        //   ∂J/∂χ+ contribution: -λ_cpl * r2_res * dr_p_use
326        //   ∂J/∂χ- contribution: +λ_cpl * r2_res * dr_q_use
327        for i in 0..n {
328            if mask[i] == 0 { continue; }
329            rhs[i] -= lambda_cpl_f32 * r2_residual[i] * dr_p_use;
330            rhs[n + i] += lambda_cpl_f32 * r2_residual[i] * dr_q_use;
331        }
332
333        // Negate for CG: b = -gradient
334        for v in rhs.iter_mut() {
335            *v = -*v;
336        }
337
338        // --- CG solve ---
339        cg_solve_chisep(
340            &mut ws, &d_kernel,
341            &mx, &my, &mz,
342            &vr_pos, &vr_neg, &vr_sum,
343            lambda_para_f32, lambda_dia_f32, lambda_sum_f32, lambda_cpl_f32,
344            dr_p_use, dr_q_use,
345            mask,
346            &rhs, &mut dx,
347            cg_tol_f32, cg_max_iter,
348        );
349
350        // --- Update (half Newton step for stability with sign constraints) ---
351        for i in 0..n {
352            chi_pos[i] += 0.5 * dx[i];
353            chi_neg[i] += 0.5 * dx[n + i];
354        }
355
356        // --- Enforce constraints: chi_pos >= 0, chi_neg <= 0 ---
357        for i in 0..n {
358            if mask[i] == 0 {
359                chi_pos[i] = 0.0;
360                chi_neg[i] = 0.0;
361            } else {
362                chi_pos[i] = chi_pos[i].max(0.0);
363                chi_neg[i] = chi_neg[i].min(0.0);
364            }
365        }
366
367        // --- Convergence check ---
368        let update_norm = norm_squared_f32(&dx).sqrt();
369        let sol_norm = (norm_squared_f32(&chi_pos) + norm_squared_f32(&chi_neg)).sqrt();
370        let ratio = update_norm / (sol_norm + 1e-6);
371
372        if ratio < tol_f32 {
373            break;
374        }
375    }
376
377    // Convert Hz -> ppm
378    let chi_pos_out: Vec<f64> = chi_pos.iter()
379        .zip(mask.iter())
380        .map(|(&v, &m)| if m == 0 { 0.0 } else { (v * ppm_factor) as f64 })
381        .collect();
382    let chi_neg_out: Vec<f64> = chi_neg.iter()
383        .zip(mask.iter())
384        .map(|(&v, &m)| if m == 0 { 0.0 } else { (v * ppm_factor) as f64 })
385        .collect();
386    let chi_total: Vec<f64> = chi_pos_out.iter()
387        .zip(chi_neg_out.iter())
388        .map(|(&p, &n)| p + n)
389        .collect();
390
391    (chi_pos_out, chi_neg_out, chi_total)
392}
393
394/// Apply chi-sep Hessian operator A to doubled vector dx = [d_pos; d_neg].
395///
396/// A_pos = λ_para * TV_hess(d_pos) + λ_sum * TV_hess_sum(d_pos+d_neg)
397///       + λ_cpl * D²(d_pos + d_neg) + λ_cpl * dr_p * (dr_p * d_pos - dr_q * d_neg)
398/// A_neg = λ_dia * TV_hess(d_neg) + λ_sum * TV_hess_sum(d_pos+d_neg)
399///       + λ_cpl * D²(d_pos + d_neg) - λ_cpl * dr_q * (dr_p * d_pos - dr_q * d_neg)
400///
401/// TV_hessian uses IRLS: bdiv(wG*Vr*wG*fgrad(.)), positive semi-definite.
402/// Field fidelity: D², same for both, positive semi-definite.
403/// R2': rank-1 structure [dr_p, -dr_q]^T * [dr_p, -dr_q], positive semi-definite.
404#[allow(clippy::too_many_arguments)]
405fn apply_chisep_operator(
406    ws: &mut ChiSepWorkspace,
407    d_kernel: &[f32],
408    mx: &[f32], my: &[f32], mz: &[f32],
409    vr_pos: &[f32], vr_neg: &[f32], vr_sum: &[f32],
410    lambda_para: f32, lambda_dia: f32, lambda_sum: f32, lambda_cpl: f32,
411    dr_p: f32, dr_q: f32,
412    mask: &[u8],
413    dx: &[f32],
414    out: &mut [f32],
415) {
416    let n = ws.n;
417    let (nx, ny, nz) = (ws.nx, ws.ny, ws.nz);
418    let (vsx, vsy, vsz) = (ws.vsx, ws.vsy, ws.vsz);
419
420    let d_pos = &dx[..n];
421    let d_neg = &dx[n..];
422
423    // TV for chi_pos: λ_para * bdiv(wG*Vr_pos*wG*fgrad(d_pos))
424    fgrad_periodic_inplace_f32(&mut ws.gx, &mut ws.gy, &mut ws.gz,
425        d_pos, nx, ny, nz, vsx, vsy, vsz);
426    apply_gradient_weights_f32(&mut ws.reg_x, &mut ws.reg_y, &mut ws.reg_z,
427        mx, my, mz, vr_pos, &ws.gx, &ws.gy, &ws.gz);
428    bdiv_periodic_inplace_f32(&mut ws.div_buf,
429        &ws.reg_x, &ws.reg_y, &ws.reg_z, nx, ny, nz, vsx, vsy, vsz);
430    for i in 0..n {
431        out[i] = lambda_para * ws.div_buf[i];
432    }
433
434    // TV for chi_neg: λ_dia * bdiv(wG*Vr_neg*wG*fgrad(d_neg))
435    fgrad_periodic_inplace_f32(&mut ws.gx, &mut ws.gy, &mut ws.gz,
436        d_neg, nx, ny, nz, vsx, vsy, vsz);
437    apply_gradient_weights_f32(&mut ws.reg_x, &mut ws.reg_y, &mut ws.reg_z,
438        mx, my, mz, vr_neg, &ws.gx, &ws.gy, &ws.gz);
439    bdiv_periodic_inplace_f32(&mut ws.div_buf,
440        &ws.reg_x, &ws.reg_y, &ws.reg_z, nx, ny, nz, vsx, vsy, vsz);
441    for i in 0..n {
442        out[n + i] = lambda_dia * ws.div_buf[i];
443    }
444
445    // TV for sum (d_pos + d_neg): λ_sum * bdiv(wG*Vr_sum*wG*fgrad(d_pos+d_neg))
446    // Applied equally to both components
447    for i in 0..n {
448        ws.tmp[i] = d_pos[i] + d_neg[i];
449    }
450    let sum_for_tv = ws.tmp.clone();
451    fgrad_periodic_inplace_f32(&mut ws.gx, &mut ws.gy, &mut ws.gz,
452        &sum_for_tv, nx, ny, nz, vsx, vsy, vsz);
453    apply_gradient_weights_f32(&mut ws.reg_x, &mut ws.reg_y, &mut ws.reg_z,
454        mx, my, mz, vr_sum, &ws.gx, &ws.gy, &ws.gz);
455    bdiv_periodic_inplace_f32(&mut ws.div_buf,
456        &ws.reg_x, &ws.reg_y, &ws.reg_z, nx, ny, nz, vsx, vsy, vsz);
457    for i in 0..n {
458        let tv_s = lambda_sum * ws.div_buf[i];
459        out[i] += tv_s;
460        out[n + i] += tv_s;
461    }
462
463    // Field fidelity: λ_cpl * D²(d_pos + d_neg), SAME for both
464    for i in 0..n {
465        ws.tmp[i] = d_pos[i] + d_neg[i];
466    }
467    let sum_copy = ws.tmp.clone();
468    ws.fft_ws.apply_dipole_inplace(&sum_copy, d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
469    let d1_copy = ws.dipole_buf.clone();
470    ws.fft_ws.apply_dipole_inplace(&d1_copy, d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
471
472    for i in 0..n {
473        let ff = lambda_cpl * ws.dipole_buf[i];
474        out[i] += ff;
475        out[n + i] += ff;
476    }
477
478    // R2' fidelity: rank-1 Hessian [dr_p, -dr_q]^T * [dr_p, -dr_q]
479    // r2_lin = dr_p * d_pos - dr_q * d_neg
480    // out_pos += λ_cpl * dr_p * r2_lin
481    // out_neg -= λ_cpl * dr_q * r2_lin
482    for i in 0..n {
483        if mask[i] == 0 { continue; }
484        let r2_lin = dr_p * d_pos[i] - dr_q * d_neg[i];
485        out[i] += lambda_cpl * dr_p * r2_lin;
486        out[n + i] -= lambda_cpl * dr_q * r2_lin;
487    }
488}
489
490/// CG solver for the doubled chi-sep system.
491#[allow(clippy::too_many_arguments)]
492fn cg_solve_chisep(
493    ws: &mut ChiSepWorkspace,
494    d_kernel: &[f32],
495    mx: &[f32], my: &[f32], mz: &[f32],
496    vr_pos: &[f32], vr_neg: &[f32], vr_sum: &[f32],
497    lambda_para: f32, lambda_dia: f32, lambda_sum: f32, lambda_cpl: f32,
498    dr_p: f32, dr_q: f32,
499    mask: &[u8],
500    b: &[f32],
501    x: &mut [f32],
502    tol: f32,
503    max_iter: usize,
504) {
505    let n2 = 2 * ws.n;
506    x.fill(0.0);
507
508    let mut cg_r = vec![0.0f32; n2];
509    let mut cg_p = vec![0.0f32; n2];
510    let mut cg_ap = vec![0.0f32; n2];
511
512    cg_r.copy_from_slice(&b[..n2]);
513    cg_p.copy_from_slice(&cg_r);
514
515    let mut rsold = dot_product_f32(&cg_r, &cg_r);
516    let b_norm = dot_product_f32(b, b).sqrt();
517
518    if b_norm < 1e-10 {
519        return;
520    }
521
522    for _cg_iter in 0..max_iter {
523        apply_chisep_operator(
524            ws, d_kernel, mx, my, mz,
525            vr_pos, vr_neg, vr_sum,
526            lambda_para, lambda_dia, lambda_sum, lambda_cpl,
527            dr_p, dr_q,
528            mask,
529            &cg_p, &mut cg_ap,
530        );
531
532        let pap = dot_product_f32(&cg_p, &cg_ap);
533        if pap.abs() < 1e-15 {
534            break;
535        }
536
537        let alpha = rsold / pap;
538        axpy_f32(x, alpha, &cg_p);
539        axpy_f32(&mut cg_r, -alpha, &cg_ap);
540
541        let rsnew = dot_product_f32(&cg_r, &cg_r);
542        if rsnew.sqrt() < tol * b_norm {
543            break;
544        }
545
546        let beta_cg = rsnew / rsold;
547        xpby_f32(&mut cg_p, &cg_r, beta_cg);
548        rsold = rsnew;
549    }
550}
551
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fft::{fft3d_real, ifft3d_real};
    use crate::kernels::dipole::dipole_kernel;

    /// Binary sphere indicator on an `nx*ny*nz` grid (x fastest). A voxel is `1.0` when
    /// its index lies within distance `r` of `(cx, cy, cz)` and `0.0` elsewhere.
    fn make_sphere(nx: usize, ny: usize, nz: usize, cx: f64, cy: f64, cz: f64, r: f64) -> Vec<f64> {
        let plane = nx * ny;
        let radius_sq = r * r;
        (0..plane * nz)
            .map(|idx| {
                let ox = (idx % nx) as f64 - cx;
                let oy = ((idx / nx) % ny) as f64 - cy;
                let oz = (idx / plane) as f64 - cz;
                if ox * ox + oy * oy + oz * oz <= radius_sq { 1.0 } else { 0.0 }
            })
            .collect()
    }

    #[test]
    fn test_chi_sep_medi_basic() {
        let (nx, ny, nz) = (32, 32, 32);
        let n = nx * ny * nz;
        let (vsx, vsy, vsz) = (1.0, 1.0, 1.0);
        let bdir = (0.0, 0.0, 1.0);
        let cf: f64 = 123.2e6; // 3T
        let hz_per_ppm = cf / 1.0e6;

        // Ground truth: a paramagnetic core (r <= 4) inside a diamagnetic shell (4 < r <= 8),
        // all within a spherical brain mask (r <= 12).
        let chi_pos_true_ppm = 0.05;
        let chi_neg_true_ppm = -0.03;
        let inner = make_sphere(nx, ny, nz, 16.0, 16.0, 16.0, 4.0);
        let outer = make_sphere(nx, ny, nz, 16.0, 16.0, 16.0, 8.0);
        let brain = make_sphere(nx, ny, nz, 16.0, 16.0, 16.0, 12.0);

        let in_core = |i: usize| inner[i] > 0.5;
        let in_shell = |i: usize| outer[i] > 0.5 && inner[i] < 0.5;

        let chi_pos_ppm: Vec<f64> = (0..n)
            .map(|i| if in_core(i) { chi_pos_true_ppm } else { 0.0 })
            .collect();
        let chi_neg_ppm: Vec<f64> = (0..n)
            .map(|i| if in_shell(i) { chi_neg_true_ppm } else { 0.0 })
            .collect();

        // Forward model: field_Hz = D * chi_total_Hz
        let chi_total_hz: Vec<f64> = chi_pos_ppm.iter()
            .zip(&chi_neg_ppm)
            .map(|(&p, &q)| (p + q) * hz_per_ppm)
            .collect();
        let d = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
        let chi_fft = fft3d_real(&chi_total_hz, nx, ny, nz);
        let field_fft: Vec<_> = chi_fft.iter()
            .zip(d.iter())
            .map(|(&c, &dk)| c * dk)
            .collect();
        let local_field = ifft3d_real(&field_fft, nx, ny, nz);

        // R2'(Hz) = Dr_pos * |chi+_ppm| + Dr_neg * |chi-_ppm|
        let dr_pos: f64 = 114.0;
        let dr_neg: f64 = 30.0;
        let r2prime: Vec<f64> = chi_pos_ppm.iter()
            .zip(&chi_neg_ppm)
            .map(|(&p, &q)| dr_pos * p.abs() + dr_neg * q.abs())
            .collect();

        let mask: Vec<u8> = brain.iter().map(|&v| u8::from(v > 0.5)).collect();

        // Magnitude contrast: brighter core, darker shell, zero outside the brain.
        let magnitude: Vec<f64> = (0..n)
            .map(|i| {
                const BASE: f64 = 100.0;
                if mask[i] == 0 {
                    0.0
                } else if inner[i] > 0.5 {
                    BASE * 1.5
                } else if outer[i] > 0.5 {
                    BASE * 0.7
                } else {
                    BASE
                }
            })
            .collect();

        let (lambda_para, lambda_dia, lambda_cpl) = (1000.0, 1000.0, 100.0);
        let (percentage, cg_tol, cg_max_iter, max_iter, tol) = (0.3, 0.01, 100, 10, 0.1);
        let (chi_pos_out, chi_neg_out, chi_total_out) = chi_sep_medi(
            &local_field, &r2prime, &magnitude, &mask,
            nx, ny, nz, vsx, vsy, vsz, bdir, cf,
            lambda_para, lambda_dia, lambda_cpl,
            dr_pos, dr_neg,
            percentage, cg_tol, cg_max_iter, max_iter, tol,
        );

        // Sign constraints hold inside the mask.
        for (i, _) in mask.iter().enumerate().filter(|&(_, &m)| m != 0) {
            assert!(chi_pos_out[i] >= -1e-10,
                "chi+ should be non-negative at voxel {}, got {}", i, chi_pos_out[i]);
            assert!(chi_neg_out[i] <= 1e-10,
                "chi- should be non-positive at voxel {}, got {}", i, chi_neg_out[i]);
        }

        // chi_total is exactly the sum of the two components.
        for (i, ((&t, &p), &q)) in chi_total_out.iter()
            .zip(&chi_pos_out)
            .zip(&chi_neg_out)
            .enumerate()
        {
            let diff = (t - p - q).abs();
            assert!(diff < 1e-10, "chi_total != chi+ + chi- at voxel {}", i);
        }

        // Both components are actually recovered (non-trivial solution).
        let pos_max = chi_pos_out.iter().copied().fold(0.0_f64, f64::max);
        let neg_min = chi_neg_out.iter().copied().fold(0.0_f64, f64::min);
        assert!(pos_max > 0.0, "chi+ should have positive values, max={}", pos_max);
        assert!(neg_min < 0.0, "chi- should have negative values, min={}", neg_min);
    }
}