Skip to main content

qsm_core/inversion/
tgv.rs

//! TGV-QSM: Total Generalized Variation for Quantitative Susceptibility Mapping
//!
//! Single-step QSM reconstruction from wrapped phase data using TGV regularization.
//!
//! References:
//! Langkammer, C., Bredies, K., Poser, B.A., et al. (2015).
//! "Fast quantitative susceptibility mapping using 3D EPI and total generalized variation."
//! NeuroImage, 111:622-630. https://doi.org/10.1016/j.neuroimage.2015.02.041
//!
//! Chatnuntawech, I., McDaniel, P., et al. (2017).
//! "Single-step quantitative susceptibility mapping with variational penalties."
//! NMR in Biomedicine, 30(4):e3570. https://doi.org/10.1002/nbm.3570
//!
//! Reference implementation: https://github.com/korbinian90/QuantitativeSusceptibilityMappingTGV.jl
//!
//! The algorithm solves:
//!   min_χ ||∇²(phase) - D*χ||₂² + α₁||∇χ - w||₁ + α₀||ε(w)||₁
//!
//! where:
//! - χ is the susceptibility map
//! - w is an auxiliary vector field (velocity)
//! - ε(w) is the symmetric gradient of w
//! - D is the dipole kernel
//! - α₀, α₁ are TGV regularization parameters
//!
//! Optimizations:
//! - Bounding box reduction: only process the region containing the mask
//! - Pre-allocated buffers: all temporary arrays allocated once outside the loop
//! - Early termination: convergence check every 100 iterations
30
31use std::f32::consts::PI;
32
/// Settings controlling a TGV-QSM reconstruction.
#[derive(Clone, Debug)]
pub struct TgvParams {
    /// Weight α₁ of the first-order (gradient) TGV term.
    pub alpha1: f32,
    /// Weight α₀ of the second-order (symmetric-gradient) TGV term.
    pub alpha0: f32,
    /// Number of primal-dual iterations to run.
    pub iterations: usize,
    /// How many times the mask is eroded.
    pub erosions: usize,
    /// Multiplier on the primal step size; larger is faster but less stable.
    pub step_size: f32,
    /// Main magnetic field strength, in Tesla.
    pub fieldstrength: f32,
    /// Echo time, in seconds.
    pub te: f32,
    /// Convergence tolerance on the relative change in χ.
    pub tol: f32,
}

impl Default for TgvParams {
    /// Defaults: α₁ = 0.003, α₀ = 0.002, 1000 iterations, 3 erosions,
    /// step size 3, a 3 T field, a 20 ms echo time and a tolerance of 1e-5.
    fn default() -> Self {
        TgvParams {
            // Regularization weights.
            alpha0: 0.002,
            alpha1: 0.003,
            // Solver controls.
            iterations: 1000,
            step_size: 3.0,
            tol: 1e-5,
            // Pre-processing.
            erosions: 3,
            // Acquisition parameters.
            fieldstrength: 3.0,
            te: 0.020, // 20 ms
        }
    }
}
68
/// Default TGV weights for a regularization level.
///
/// `regularization` is clamped to `1..=4`. The result is the tuple
/// `(alpha0, alpha1)` computed as
/// `(0.001, 0.001) + (0.001, 0.002) * (level - 1)`, following the formula
/// quoted from the Julia QuantitativeSusceptibilityMappingTGV.jl reference.
/// Level 2 yields α₀ = 0.002 and α₁ = 0.003.
pub fn get_default_alpha(regularization: u8) -> (f32, f32) {
    // Number of increments above the weakest level (0..=3 after clamping).
    let steps = f32::from(regularization.clamp(1, 4) - 1);
    // Both weights are at least 0.001, so no lower clamp is required.
    (0.001 + 0.001 * steps, 0.001 + 0.002 * steps)
}
80
/// Default iteration count for a given voxel size and step size.
///
/// Evaluates `max(1000, 3200 / prod(res)^0.42) / step_size^0.6` (the formula
/// quoted from the Julia reference) and rounds to the nearest integer.
pub fn get_default_iterations(res: (f32, f32, f32), step_size: f32) -> usize {
    let (dx, dy, dz) = res;
    let voxel_volume = dx * dy * dz;
    // Smaller voxels need more iterations, but never fewer than 1000 before scaling.
    let resolution_term = 3200.0 / voxel_volume.powf(0.42);
    let unscaled = 1000.0_f32.max(resolution_term);
    (unscaled / step_size.powf(0.6)).round() as usize
}
89
/// Axis-aligned index box around the non-zero voxels of a mask.
///
/// The `*_min` bounds are inclusive and the `*_max` bounds are exclusive, so
/// the box covers `i_min..i_max`, `j_min..j_max` and `k_min..k_max`.
#[derive(Clone, Debug)]
struct BoundingBox {
    i_min: usize,
    i_max: usize,
    j_min: usize,
    j_max: usize,
    k_min: usize,
    k_max: usize,
}

impl BoundingBox {
    /// Find the smallest box containing every non-zero voxel of `mask`, grown
    /// by `padding` voxels on each side and clipped to the `nx × ny × nz` volume.
    ///
    /// `mask` is indexed as `i + j * nx + k * nx * ny`.
    ///
    /// If the mask contains no non-zero voxel, an empty box (all bounds zero)
    /// is returned. Without this guard the "min" bounds would stay at the
    /// volume size while the "max" bounds stay near zero, producing an
    /// inverted box whose dimensions underflow.
    ///
    /// # Panics
    /// Panics if `mask` has fewer than `nx * ny * nz` elements.
    fn from_mask(mask: &[u8], nx: usize, ny: usize, nz: usize, padding: usize) -> Self {
        let mut lo = (nx, ny, nz);
        let mut hi = (0usize, 0usize, 0usize);
        let mut found = false;

        for k in 0..nz {
            for j in 0..ny {
                for i in 0..nx {
                    if mask[i + j * nx + k * nx * ny] != 0 {
                        found = true;
                        lo = (lo.0.min(i), lo.1.min(j), lo.2.min(k));
                        hi = (hi.0.max(i), hi.1.max(j), hi.2.max(k));
                    }
                }
            }
        }

        if !found {
            return Self { i_min: 0, i_max: 0, j_min: 0, j_max: 0, k_min: 0, k_max: 0 };
        }

        // Grow by `padding`, convert the upper bound to exclusive, clip to the volume.
        let upper = |last: usize, len: usize| last.saturating_add(padding).saturating_add(1).min(len);

        Self {
            i_min: lo.0.saturating_sub(padding),
            i_max: upper(hi.0, nx),
            j_min: lo.1.saturating_sub(padding),
            j_max: upper(hi.1, ny),
            k_min: lo.2.saturating_sub(padding),
            k_max: upper(hi.2, nz),
        }
    }

    /// Extent of the box along each axis; an inverted axis counts as zero.
    fn dims(&self) -> (usize, usize, usize) {
        (
            self.i_max.saturating_sub(self.i_min),
            self.j_max.saturating_sub(self.j_min),
            self.k_max.saturating_sub(self.k_min),
        )
    }

    /// Number of voxels inside the box.
    fn total(&self) -> usize {
        let (bx, by, bz) = self.dims();
        bx * by * bz
    }
}
146
147/// Extract sub-volume from full volume
148fn extract_subvolume<T: Copy + Default>(
149    full: &[T],
150    bbox: &BoundingBox,
151    nx: usize, ny: usize, _nz: usize,
152) -> Vec<T> {
153    let (bx, by, bz) = bbox.dims();
154    let mut sub = vec![T::default(); bx * by * bz];
155
156    for k in 0..bz {
157        for j in 0..by {
158            for i in 0..bx {
159                let full_idx = (bbox.i_min + i) + (bbox.j_min + j) * nx + (bbox.k_min + k) * nx * ny;
160                let sub_idx = i + j * bx + k * bx * by;
161                sub[sub_idx] = full[full_idx];
162            }
163        }
164    }
165    sub
166}
167
168/// Insert sub-volume back into full volume
169fn insert_subvolume<T: Copy>(
170    full: &mut [T],
171    sub: &[T],
172    bbox: &BoundingBox,
173    nx: usize, ny: usize, _nz: usize,
174) {
175    let (bx, by, bz) = bbox.dims();
176
177    for k in 0..bz {
178        for j in 0..by {
179            for i in 0..bx {
180                let full_idx = (bbox.i_min + i) + (bbox.j_min + j) * nx + (bbox.k_min + k) * nx * ny;
181                let sub_idx = i + j * bx + k * bx * by;
182                full[full_idx] = sub[sub_idx];
183            }
184        }
185    }
186}
187
/// Compute a 3×3×3 finite-difference stencil of the dipole operator
/// `(1/3)·∇² − (b·∇)²` for voxel size `res` and field direction `b0_dir`.
///
/// The stencil is indexed as `stencil[di + 1][dj + 1][dk + 1]` for a neighbour
/// offset `(di, dj, dk)`. Only the centre, the six face neighbours and the
/// twelve in-plane diagonal neighbours are non-zero.
///
/// `b0_dir` is normalized internally; a zero vector produces NaN entries.
///
/// The mixed derivatives use the standard four-corner formula
/// `∂²f/∂x∂y ≈ [f(+,+) − f(+,−) − f(−,+) + f(−,−)] / (4·dx·dy)`. Together with
/// the factor 2 from expanding `(b·∇)²`, each corner weight is therefore
/// `±0.5·bx·by/(dx·dy)`. (An earlier version omitted the 0.5, which doubled
/// every cross term for oblique field directions.)
pub fn compute_dipole_stencil(
    res: (f32, f32, f32),
    b0_dir: (f32, f32, f32),
) -> [[[f32; 3]; 3]; 3] {
    let (dx, dy, dz) = res;
    let (bx, by, bz) = b0_dir;

    // Normalize B0 direction
    let b_norm = (bx * bx + by * by + bz * bz).sqrt();
    let (bx, by, bz) = (bx / b_norm, by / b_norm, bz / b_norm);

    let mut stencil = [[[0.0f32; 3]; 3]; 3];

    let hx2 = 1.0 / (dx * dx);
    let hy2 = 1.0 / (dy * dy);
    let hz2 = 1.0 / (dz * dz);
    let factor = 1.0 / 3.0;

    // Centre: -(2/3)·Σ 1/h² from the Laplacian, +2·Σ b²/h² from -(b·∇)².
    stencil[1][1][1] = -2.0 * (hx2 + hy2 + hz2) * factor
                     + 2.0 * (bx * bx * hx2 + by * by * hy2 + bz * bz * hz2);

    // Face neighbours along x, y and z.
    let x_face = hx2 * factor - bx * bx * hx2;
    stencil[0][1][1] = x_face;
    stencil[2][1][1] = x_face;

    let y_face = hy2 * factor - by * by * hy2;
    stencil[1][0][1] = y_face;
    stencil[1][2][1] = y_face;

    let z_face = hz2 * factor - bz * bz * hz2;
    stencil[1][1][0] = z_face;
    stencil[1][1][2] = z_face;

    // Cross terms for oblique B0: -2·b_a·b_b·∂²/∂a∂b with the 1/(4·h_a·h_b)
    // four-corner formula, i.e. ∓0.5·b_a·b_b/(h_a·h_b) per corner.
    let hxy = 1.0 / (dx * dy);
    let hxz = 1.0 / (dx * dz);
    let hyz = 1.0 / (dy * dz);

    let xy_factor = -0.5 * bx * by * hxy;
    stencil[0][0][1] = xy_factor;
    stencil[2][2][1] = xy_factor;
    stencil[0][2][1] = -xy_factor;
    stencil[2][0][1] = -xy_factor;

    let xz_factor = -0.5 * bx * bz * hxz;
    stencil[0][1][0] = xz_factor;
    stencil[2][1][2] = xz_factor;
    stencil[0][1][2] = -xz_factor;
    stencil[2][1][0] = -xz_factor;

    let yz_factor = -0.5 * by * bz * hyz;
    stencil[1][0][0] = yz_factor;
    stencil[1][2][2] = yz_factor;
    stencil[1][0][2] = -yz_factor;
    stencil[1][2][0] = -yz_factor;

    stencil
}
250
251/// Compute SVD-fitted oblique 27-point dipole stencil via DST-based Poisson solves.
252///
253/// Matches the Julia reference implementation from QuantitativeSusceptibilityMappingTGV.jl.
254/// For each unique centrosymmetric pair of the 26 off-center stencil positions, a Poisson
255/// equation is solved using DST-I to obtain Green's function values. These are then fitted
256/// to the analytical dipole field via least squares, and resolution-weighted to produce
257/// the final stencil.
258///
259/// Note: The delta function for offsets (di,dj,dk) and (-di,-dj,-dk) are identical
260/// (both place +1 at mid±I and -2 at mid), so we only compute 13 unique Poisson
261/// solutions and solve a full-rank 13×13 system. This avoids the rank deficiency
262/// that would occur with all 26 positions.
263pub fn compute_oblique_stencil(
264    res: (f32, f32, f32),
265    b0_dir: (f32, f32, f32),
266) -> [[[f32; 3]; 3]; 3] {
267    let n: usize = 64;
268    let singularity_cutout = 4.0_f64;
269    let mid = n / 2; // 32 (0-indexed center)
270    let n2 = n * n;
271    let n3 = n * n * n;
272
273    let (dx, dy, dz) = (res.0 as f64, res.1 as f64, res.2 as f64);
274
275    // Normalize B0 direction
276    let b_norm = ((b0_dir.0 * b0_dir.0 + b0_dir.1 * b0_dir.1 + b0_dir.2 * b0_dir.2) as f64).sqrt();
277    let bdir = (b0_dir.0 as f64 / b_norm, b0_dir.1 as f64 / b_norm, b0_dir.2 as f64 / b_norm);
278
279    // Compute dipole field on 64³ grid
280    let mut d = vec![f64::NAN; n3];
281    let mut d_mask = vec![false; n3];
282
283    for k in 0..n {
284        for j in 0..n {
285            for i in 0..n {
286                let x = i as f64 - mid as f64;
287                let y = j as f64 - mid as f64;
288                let z = k as f64 - mid as f64;
289                let r = (x * x + y * y + z * z).sqrt();
290
291                let idx = i + j * n + k * n2;
292                if r < singularity_cutout {
293                    // d[idx] remains NAN, d_mask[idx] remains false
294                } else {
295                    let xz = (bdir.0 * x + bdir.1 * y + bdir.2 * z) / r;
296                    let kappa = (3.0 * xz * xz - 1.0) / (4.0 * std::f64::consts::PI * r * r * r);
297                    d[idx] = kappa;
298                    d_mask[idx] = true;
299                }
300            }
301        }
302    }
303
304    // DST-I eigenvalue components: coord2[k] = 2 * sin(π*(k+1) / (2*(N+1)))
305    let coord2_sq: Vec<f64> = (0..n)
306        .map(|k| {
307            let v = 2.0 * (std::f64::consts::PI * (k as f64 + 1.0) / (2.0 * (n as f64 + 1.0))).sin();
308            v * v
309        })
310        .collect();
311
312    // 3D eigenvalue grid: coord2_grid[i,j,k] = coord2[i]² + coord2[j]² + coord2[k]²
313    let mut coord2_grid = vec![0.0_f64; n3];
314    for k in 0..n {
315        for j in 0..n {
316            for i in 0..n {
317                coord2_grid[i + j * n + k * n2] = coord2_sq[i] + coord2_sq[j] + coord2_sq[k];
318            }
319        }
320    }
321
322    // Pre-compute DST-I sin table
323    let sin_table = dst_sin_table(n);
324    // Inverse DST-I scale: our DST computes Σ x[n]*sin(...) (no factor of 2),
325    // so the inverse is (2/(N+1))^3 for 3D (unlike FFTW which uses 1/(2*(N+1))^3).
326    let idst_scale = (2.0 / (n as f64 + 1.0)).powi(3);
327
328    // Enumerate all 26 stencil positions (excluding center), column-major order
329    let mut stencil_positions: Vec<(i32, i32, i32)> = Vec::with_capacity(26);
330    for dk in -1..=1_i32 {
331        for dj in -1..=1_i32 {
332            for di in -1..=1_i32 {
333                if di == 0 && dj == 0 && dk == 0 {
334                    continue;
335                }
336                stencil_positions.push((di, dj, dk));
337            }
338        }
339    }
340
341    // Collect valid dipole point indices
342    let valid_indices: Vec<usize> = d_mask
343        .iter()
344        .enumerate()
345        .filter(|(_, &m)| m)
346        .map(|(idx, _)| idx)
347        .collect();
348
349    // Exploit centrosymmetry: delta for (di,dj,dk) is identical to (-di,-dj,-dk),
350    // so only compute 13 unique Poisson solutions (one per centrosymmetric pair).
351    // Pair p corresponds to stencil_positions[p] and stencil_positions[25-p].
352    let num_pairs = 13;
353    let mut a_rows: Vec<Vec<f64>> = Vec::with_capacity(num_pairs);
354
355    for p in 0..num_pairs {
356        let (di, dj, dk) = stencil_positions[p];
357
358        // Create delta function: delta[mid+I] = 1, delta[mid-I] = 1, delta[mid] = -2
359        let mut delta = vec![0.0_f64; n3];
360        let pi = (mid as i32 + di) as usize;
361        let pj = (mid as i32 + dj) as usize;
362        let pk = (mid as i32 + dk) as usize;
363        let mi = (mid as i32 - di) as usize;
364        let mj = (mid as i32 - dj) as usize;
365        let mk = (mid as i32 - dk) as usize;
366
367        delta[pi + pj * n + pk * n2] = 1.0;
368        delta[mi + mj * n + mk * n2] += 1.0; // += handles coincident positions
369        delta[mid + mid * n + mid * n2] = -2.0;
370
371        // Forward 3D DST-I
372        let mut fdelta = dst3d(&delta, n, &sin_table);
373
374        // Poisson solve: divide by -coord2_grid
375        for idx in 0..n3 {
376            fdelta[idx] /= -coord2_grid[idx];
377        }
378
379        // Inverse 3D DST-I (= forward DST-I * scale)
380        let vdelta = dst3d(&fdelta, n, &sin_table);
381
382        // Extract values at valid dipole points, applying inverse scale
383        let row: Vec<f64> = valid_indices
384            .iter()
385            .map(|&idx| vdelta[idx] * idst_scale)
386            .collect();
387        a_rows.push(row);
388    }
389
390    // Extract dipole values at valid points
391    let d_valid: Vec<f64> = valid_indices.iter().map(|&idx| d[idx]).collect();
392
393    // Solve least squares: B^T * y ≈ d_valid via normal equations (13 unknowns, full rank)
394    // G = B * B^T (13×13), h = B * d_valid (13×1)
395    let mut g = vec![vec![0.0_f64; num_pairs]; num_pairs];
396    for i in 0..num_pairs {
397        for j in 0..=i {
398            let dot: f64 = a_rows[i]
399                .iter()
400                .zip(a_rows[j].iter())
401                .map(|(&a, &b)| a * b)
402                .sum();
403            g[i][j] = dot;
404            g[j][i] = dot;
405        }
406    }
407
408    let mut h = vec![0.0_f64; num_pairs];
409    for i in 0..num_pairs {
410        h[i] = a_rows[i]
411            .iter()
412            .zip(d_valid.iter())
413            .map(|(&a, &b)| a * b)
414            .sum();
415    }
416
417    // Solve G * y = h via eigendecomposition with thresholding (handles rank deficiency
418    // from centrosymmetry and other spatial symmetries like x-y equivalence)
419    let y = solve_symmetric_pseudoinverse(&g, &h);
420
421    // Assemble stencil: assign each pair's coefficient to both positions
422    // y[p] = x[p] + x[25-p] = 2*x[p] (by symmetry), so stencil[I] = 2*x[I] = y[p]
423    let mut result = [[[0.0f32; 3]; 3]; 3];
424    for p in 0..num_pairs {
425        let coeff = y[p] as f32;
426
427        // First member of pair
428        let (di, dj, dk) = stencil_positions[p];
429        result[(di + 1) as usize][(dj + 1) as usize][(dk + 1) as usize] = coeff;
430
431        // Second member (centrosymmetric partner)
432        let (di2, dj2, dk2) = stencil_positions[25 - p];
433        result[(di2 + 1) as usize][(dj2 + 1) as usize][(dk2 + 1) as usize] = coeff;
434    }
435
436    // Apply resolution weights: w = (i²/dx² + j²/dy² + k²/dz²) / (i² + j² + k²)
437    for dk in -1..=1_i32 {
438        for dj in -1..=1_i32 {
439            for di in -1..=1_i32 {
440                if di == 0 && dj == 0 && dk == 0 {
441                    continue;
442                }
443                let si = (di + 1) as usize;
444                let sj = (dj + 1) as usize;
445                let sk = (dk + 1) as usize;
446
447                let i2 = (di * di) as f64;
448                let j2 = (dj * dj) as f64;
449                let k2 = (dk * dk) as f64;
450                let weight = (i2 / (dx * dx) + j2 / (dy * dy) + k2 / (dz * dz)) / (i2 + j2 + k2);
451                result[si][sj][sk] *= weight as f32;
452            }
453        }
454    }
455
456    // Center = -sum(all others)
457    let mut total = 0.0f32;
458    for dk in 0..3 {
459        for dj in 0..3 {
460            for di in 0..3 {
461                if !(di == 1 && dj == 1 && dk == 1) {
462                    total += result[di][dj][dk];
463                }
464            }
465        }
466    }
467    result[1][1][1] = -total;
468
469    result
470}
471
/// Tabulate the DST-I basis for transforms of length `n`.
///
/// Entry `[j][k]` holds `sin(π·(j+1)·(k+1) / (n+1))`; the table is `n × n`.
fn dst_sin_table(n: usize) -> Vec<Vec<f64>> {
    let scale = std::f64::consts::PI / (n as f64 + 1.0);
    (0..n)
        .map(|row| {
            let row_factor = row as f64 + 1.0;
            (0..n)
                .map(|col| (row_factor * (col as f64 + 1.0) * scale).sin())
                .collect()
        })
        .collect()
}
484
485/// 3D separable DST-I (Type I Discrete Sine Transform).
486fn dst3d(input: &[f64], n: usize, sin_table: &[Vec<f64>]) -> Vec<f64> {
487    let n2 = n * n;
488    let mut data = input.to_vec();
489    let mut buf_in = vec![0.0_f64; n];
490    let mut buf_out = vec![0.0_f64; n];
491
492    // Transform along x (contiguous dimension)
493    for k in 0..n {
494        for j in 0..n {
495            let base = j * n + k * n2;
496            buf_in.copy_from_slice(&data[base..base + n]);
497            dst1(&buf_in, sin_table, &mut buf_out);
498            data[base..base + n].copy_from_slice(&buf_out);
499        }
500    }
501
502    // Transform along y (stride = n)
503    for k in 0..n {
504        for i in 0..n {
505            for j in 0..n {
506                buf_in[j] = data[i + j * n + k * n2];
507            }
508            dst1(&buf_in, sin_table, &mut buf_out);
509            for j in 0..n {
510                data[i + j * n + k * n2] = buf_out[j];
511            }
512        }
513    }
514
515    // Transform along z (stride = n*n)
516    for j in 0..n {
517        for i in 0..n {
518            for k in 0..n {
519                buf_in[k] = data[i + j * n + k * n2];
520            }
521            dst1(&buf_in, sin_table, &mut buf_out);
522            for k in 0..n {
523                data[i + j * n + k * n2] = buf_out[k];
524            }
525        }
526    }
527
528    data
529}
530
/// 1D DST-I: `output[k] = Σ_j input[j] · sin_table[j][k]`.
///
/// The transform length is `input.len()`; only the first `input.len()` entries
/// of `output` are written.
fn dst1(input: &[f64], sin_table: &[Vec<f64>], output: &mut [f64]) {
    let n = input.len();
    for (k, out) in output[..n].iter_mut().enumerate() {
        // Accumulate left to right from 0.0.
        *out = input
            .iter()
            .enumerate()
            .fold(0.0_f64, |acc, (j, &x)| acc + x * sin_table[j][k]);
    }
}
542
/// Solve `G·x = h` for a symmetric positive semi-definite `G` via a
/// thresholded eigendecomposition (pseudo-inverse).
///
/// `G = V·Λ·Vᵀ` is computed with cyclic Jacobi rotations. The solution is
/// `x = V · diag(1/λ) · Vᵀ · h`, where eigenvalues whose magnitude does not
/// exceed `1e-10 · max|λ|` are treated as zero, so rank-deficient systems give
/// the minimum-norm least-squares solution.
///
/// The Jacobi convergence tolerance is relative to the largest matrix entry
/// (`1e-15 · max|G_ij|`). An absolute tolerance would treat a matrix with
/// uniformly small entries as already diagonal and return a wrong solution.
/// An all-zero matrix yields the zero vector.
///
/// `g` is expected to be `n × n` with `n = h.len()`.
fn solve_symmetric_pseudoinverse(g: &[Vec<f64>], h: &[f64]) -> Vec<f64> {
    let n = h.len();

    // Working copy: the Jacobi method diagonalizes it in place.
    let mut a: Vec<Vec<f64>> = g.to_vec();

    // Largest entry magnitude sets the scale for the convergence tolerance.
    let scale = a
        .iter()
        .flat_map(|row| row.iter())
        .fold(0.0_f64, |m, &x| m.max(x.abs()));
    if scale == 0.0 {
        // Zero matrix: every eigenvalue is thresholded away.
        return vec![0.0_f64; n];
    }

    let mut v = vec![vec![0.0_f64; n]; n];
    for i in 0..n {
        v[i][i] = 1.0;
    }

    // Jacobi eigendecomposition for symmetric matrices
    let max_sweeps = 100;
    let tol = 1e-15 * scale;

    for _ in 0..max_sweeps {
        // Converged once every off-diagonal element is negligible.
        let mut max_off = 0.0_f64;
        for i in 0..n {
            for j in (i + 1)..n {
                max_off = max_off.max(a[i][j].abs());
            }
        }
        if max_off <= tol {
            break;
        }

        // Sweep all off-diagonal pairs
        for p in 0..n {
            for q in (p + 1)..n {
                // `<=` also skips exact zeros, which would make tau 0/0.
                if a[p][q].abs() <= tol {
                    continue;
                }

                // Rotation angle that zeroes a[p][q]; the smaller root keeps it stable.
                let app = a[p][p];
                let aqq = a[q][q];
                let apq = a[p][q];
                let tau = (aqq - app) / (2.0 * apq);
                let t = if tau >= 0.0 {
                    1.0 / (tau + (1.0 + tau * tau).sqrt())
                } else {
                    -1.0 / (-tau + (1.0 + tau * tau).sqrt())
                };
                let c = 1.0 / (1.0 + t * t).sqrt();
                let s = t * c;

                // A' = Jᵀ·A·J: only rows/columns p and q change.
                for i in 0..n {
                    if i == p || i == q {
                        continue;
                    }
                    let aip = a[i][p];
                    let aiq = a[i][q];
                    a[i][p] = c * aip - s * aiq;
                    a[p][i] = a[i][p];
                    a[i][q] = s * aip + c * aiq;
                    a[q][i] = a[i][q];
                }

                // Update 2×2 block
                a[p][p] = c * c * app - 2.0 * c * s * apq + s * s * aqq;
                a[q][q] = s * s * app + 2.0 * c * s * apq + c * c * aqq;
                a[p][q] = 0.0;
                a[q][p] = 0.0;

                // Accumulate eigenvectors: V' = V·J
                for i in 0..n {
                    let vip = v[i][p];
                    let viq = v[i][q];
                    v[i][p] = c * vip - s * viq;
                    v[i][q] = s * vip + c * viq;
                }
            }
        }
    }

    // Eigenvalues are on the diagonal.
    let eigenvalues: Vec<f64> = (0..n).map(|i| a[i][i]).collect();
    let max_eigen = eigenvalues.iter().cloned().fold(0.0_f64, |m, b| m.max(b.abs()));
    let threshold = 1e-10 * max_eigen;

    // Project h onto the eigenvectors: Vᵀ·h.
    let mut vt_h = vec![0.0_f64; n];
    for i in 0..n {
        for j in 0..n {
            vt_h[i] += v[j][i] * h[j];
        }
    }

    // Invert the significant eigenvalues, drop the rest.
    for i in 0..n {
        if eigenvalues[i].abs() > threshold {
            vt_h[i] /= eigenvalues[i];
        } else {
            vt_h[i] = 0.0;
        }
    }

    // Back to the original basis: x = V·(Λ⁺·Vᵀ·h).
    let mut x = vec![0.0_f64; n];
    for i in 0..n {
        for j in 0..n {
            x[i] += v[i][j] * vt_h[j];
        }
    }

    x
}
655
/// Apply a 3×3×3 stencil to a volume stored as `i + j*nx + k*nx*ny`.
///
/// `output[idx]` is the sum over the 27 neighbours of
/// `stencil[di][dj][dk] * input[neighbour]`, where stencil index 0/1/2 maps to
/// offset −1/0/+1. Voxels outside the mask and voxels on the outer face of the
/// volume (where the neighbourhood would leave the array) are set to 0; the
/// original comments describe this as matching Julia's `wave_local`.
fn apply_stencil(
    output: &mut [f32],
    input: &[f32],
    stencil: &[[[f32; 3]; 3]; 3],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
) {
    let plane = nx * ny;

    for k in 0..nz {
        for j in 0..ny {
            for i in 0..nx {
                let idx = i + j * nx + k * plane;

                let on_border =
                    i == 0 || j == 0 || k == 0 || i + 1 >= nx || j + 1 >= ny || k + 1 >= nz;
                if mask[idx] == 0 || on_border {
                    output[idx] = 0.0;
                    continue;
                }

                // Linear index of the (-1, -1, -1) neighbour; valid because i, j, k >= 1.
                let corner = idx - 1 - nx - plane;

                let mut acc = 0.0f32;
                for dk in 0..3 {
                    for dj in 0..3 {
                        for di in 0..3 {
                            acc += stencil[di][dj][dk] * input[corner + di + dj * nx + dk * plane];
                        }
                    }
                }
                output[idx] = acc;
            }
        }
    }
}
702
703/// Compute Laplacian of wrapped phase using the DEL method
704pub fn compute_phase_laplacian(
705    phase: &[f32],
706    mask: &[u8],
707    nx: usize, ny: usize, nz: usize,
708    vsx: f32, vsy: f32, vsz: f32,
709) -> Vec<f32> {
710    let n_total = nx * ny * nz;
711
712    let sin_phase: Vec<f32> = phase.iter().map(|&p| p.sin()).collect();
713    let cos_phase: Vec<f32> = phase.iter().map(|&p| p.cos()).collect();
714
715    let lap_sin = compute_laplacian(&sin_phase, nx, ny, nz, vsx, vsy, vsz);
716    let lap_cos = compute_laplacian(&cos_phase, nx, ny, nz, vsx, vsy, vsz);
717
718    let mut laplacian = vec![0.0f32; n_total];
719    for i in 0..n_total {
720        if mask[i] != 0 {
721            laplacian[i] = lap_sin[i] * cos_phase[i] - lap_cos[i] * sin_phase[i];
722        }
723    }
724
725    laplacian
726}
727
/// 7-point discrete Laplacian of an `nx × ny × nz` volume with voxel size
/// `(vsx, vsy, vsz)`.
///
/// The volume is stored as `i + j*nx + k*nx*ny`. At the faces of the volume
/// the missing neighbour is replaced by the voxel itself (index clamping).
fn compute_laplacian(
    input: &[f32],
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
) -> Vec<f32> {
    let plane = nx * ny;
    let mut output = vec![0.0f32; plane * nz];

    let hx2 = 1.0 / (vsx * vsx);
    let hy2 = 1.0 / (vsy * vsy);
    let hz2 = 1.0 / (vsz * vsz);
    let center = -2.0 * (hx2 + hy2 + hz2);

    // Neighbour coordinates clamped to the valid range `0..len`.
    let below = |c: usize| c.saturating_sub(1);
    let above = |c: usize, len: usize| (c + 1).min(len - 1);

    for k in 0..nz {
        let z_off = k * plane;
        let (z_lo, z_hi) = (below(k) * plane, above(k, nz) * plane);

        for j in 0..ny {
            let y_off = j * nx;
            let (y_lo, y_hi) = (below(j) * nx, above(j, ny) * nx);
            let row = y_off + z_off;

            for i in 0..nx {
                let idx = row + i;

                let x_pair = input[row + below(i)] + input[row + above(i, nx)];
                let y_pair = input[i + y_lo + z_off] + input[i + y_hi + z_off];
                let z_pair = input[i + y_off + z_lo] + input[i + y_off + z_hi];

                output[idx] = center * input[idx]
                    + hx2 * x_pair
                    + hy2 * y_pair
                    + hz2 * z_pair;
            }
        }
    }

    output
}
766
/// Masked 7-point Laplacian of `input`, written into `output`.
///
/// Both volumes are `nx × ny × nz`, stored as `i + j*nx + k*nx*ny`, with voxel
/// size `(vsx, vsy, vsz)`. Voxels where `mask` is zero are set to 0. At the
/// faces of the volume the missing neighbour is replaced by the centre value
/// (zero normal derivative), as the original comment notes the Julia code does.
fn apply_laplacian_inplace(
    output: &mut [f32],
    input: &[f32],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
) {
    let hx2 = 1.0 / (vsx * vsx);
    let hy2 = 1.0 / (vsy * vsy);
    let hz2 = 1.0 / (vsz * vsz);
    let plane = nx * ny;

    for k in 0..nz {
        for j in 0..ny {
            for i in 0..nx {
                let idx = i + j * nx + k * plane;

                if mask[idx] == 0 {
                    output[idx] = 0.0;
                    continue;
                }

                let c = input[idx];
                // Second difference from a lower and an upper neighbour.
                let second_diff = |lo: f32, hi: f32| lo - 2.0 * c + hi;

                let x_lo = if i > 0 { input[idx - 1] } else { c };
                let x_hi = if i + 1 < nx { input[idx + 1] } else { c };
                let y_lo = if j > 0 { input[idx - nx] } else { c };
                let y_hi = if j + 1 < ny { input[idx + nx] } else { c };
                let z_lo = if k > 0 { input[idx - plane] } else { c };
                let z_hi = if k + 1 < nz { input[idx + plane] } else { c };

                output[idx] = hx2 * second_diff(x_lo, x_hi)
                            + hy2 * second_diff(y_lo, y_hi)
                            + hz2 * second_diff(z_lo, z_hi);
            }
        }
    }
}
811
/// Erode a binary mask by one voxel using 6-connectivity.
///
/// A voxel survives (output 1) only if it and its six face neighbours are all
/// non-zero in `mask`. Voxels on the outer face of the `nx × ny × nz` volume
/// are always removed. The result holds only 0 and 1.
pub fn erode_mask(mask: &[u8], nx: usize, ny: usize, nz: usize) -> Vec<u8> {
    let plane = nx * ny;
    let n_total = plane * nz;
    let mut eroded = vec![0u8; n_total];

    // Require the full volume up front.
    let mask = &mask[..n_total];

    // Only interior voxels can survive, so the faces are never visited.
    for k in 1..nz.saturating_sub(1) {
        for j in 1..ny.saturating_sub(1) {
            for i in 1..nx.saturating_sub(1) {
                let idx = i + j * nx + k * plane;
                if mask[idx] == 0 {
                    continue;
                }

                let neighbours = [
                    idx - 1,
                    idx + 1,
                    idx - nx,
                    idx + nx,
                    idx - plane,
                    idx + plane,
                ];
                if neighbours.iter().all(|&n| mask[n] != 0) {
                    eroded[idx] = 1;
                }
            }
        }
    }

    eroded
}
846
/// Returns `4 · (1/dx² + 1/dy² + 1/dz²)` for voxel size `res = (dx, dy, dz)`.
///
/// Presumably this is the usual upper bound on the squared norm of the
/// finite-difference gradient operator; confirm against the caller.
#[inline]
fn grad_norm_sq(res: (f32, f32, f32)) -> f32 {
    let inv_sq = |h: f32| 1.0 / (h * h);
    4.0 * (inv_sq(res.0) + inv_sq(res.1) + inv_sq(res.2))
}
853
/// Squared spectral norm of the 3×3 TGV operator matrix, by power iteration.
///
/// The operator matrix is
///
/// ```text
/// M = [ 0   g   1 ]
///     [ 0   0   g ]
///     [ g²  w   0 ]
/// ```
///
/// with `g2` supplying the `g²` entry. The function runs 20 power-iteration
/// steps on
///
/// ```text
/// MᵀM = [ g⁴    g²w     0    ]
///       [ g²w   g²+w²   g    ]
///       [ 0     g       g²+1 ]
/// ```
///
/// starting from `(1, 1, 1)`, and returns the Rayleigh quotient of the final
/// vector as the estimate of the largest eigenvalue.
fn compute_operator_norm_sqr(g: f32, g2: f32, w: f32) -> f32 {
    // Non-zero entries of MᵀM.
    let g4 = g2 * g2;
    let g2w = g2 * w;
    let g2_w2 = g2 + w * w;
    let g2_1 = g2 + 1.0;

    // y = (MᵀM)·v, written out so structural zeros are never multiplied.
    let apply = |v: [f32; 3]| -> [f32; 3] {
        [
            g4 * v[0] + g2w * v[1],
            g2w * v[0] + g2_w2 * v[1] + g * v[2],
            g * v[1] + g2_1 * v[2],
        ]
    };

    let mut v = [1.0f32; 3];
    for _ in 0..20 {
        let y = apply(v);
        let norm = (y[0] * y[0] + y[1] * y[1] + y[2] * y[2]).sqrt();
        if norm < 1e-10 {
            // Degenerate product: keep the current vector.
            break;
        }
        v = [y[0] / norm, y[1] / norm, y[2] / norm];
    }

    // Rayleigh quotient vᵀ·(MᵀM)·v.
    let y = apply(v);
    v[0] * y[0] + v[1] * y[1] + v[2] * y[2]
}
899
/// Euclidean length of the vector `(x, y, z)`.
#[inline]
fn norm3(x: f32, y: f32, z: f32) -> f32 {
    let sum_sq = x * x + y * y + z * z;
    sum_sq.sqrt()
}
905
/// Frobenius norm of a symmetric 3×3 tensor given by its six unique components.
///
/// Off-diagonal components appear twice in the full tensor and are therefore
/// counted with weight 2.
#[inline]
fn frobenius_norm(sxx: f32, sxy: f32, sxz: f32, syy: f32, syz: f32, szz: f32) -> f32 {
    let diagonal = sxx * sxx + syy * syy + szz * szz;
    let off_diagonal = sxy * sxy + sxz * sxz + syz * syz;
    (diagonal + 2.0 * off_diagonal).sqrt()
}
911
912/// L-infinity projection for 3-component vector
913#[inline]
914fn project_linf3(px: &mut f32, py: &mut f32, pz: &mut f32, threshold: f32) {
915    let norm = norm3(*px, *py, *pz);
916    if norm > threshold {
917        let scale = threshold / norm;
918        *px *= scale;
919        *py *= scale;
920        *pz *= scale;
921    }
922}
923
924/// L-infinity projection for 6-component symmetric tensor
925#[inline]
926fn project_linf6(
927    qxx: &mut f32, qxy: &mut f32, qxz: &mut f32,
928    qyy: &mut f32, qyz: &mut f32, qzz: &mut f32,
929    threshold: f32,
930) {
931    let norm = frobenius_norm(*qxx, *qxy, *qxz, *qyy, *qyz, *qzz);
932    if norm > threshold {
933        let scale = threshold / norm;
934        *qxx *= scale;
935        *qxy *= scale;
936        *qxz *= scale;
937        *qyy *= scale;
938        *qyz *= scale;
939        *qzz *= scale;
940    }
941}
942
/// Relative L2 change between two iterates, restricted to the mask.
///
/// Returns `||chi - chi_prev||₂ / ||chi||₂`, where both norms only include
/// voxels with a non-zero `mask` entry. If the squared norm of `chi` inside
/// the mask is not greater than `1e-10`, `1.0` is returned instead.
///
/// `mask` and `chi_prev` are indexed by positions of `chi`, so they must be
/// at least as long as `chi` (for `chi_prev`, at every masked position).
fn compute_relative_change(chi: &[f32], chi_prev: &[f32], mask: &[u8]) -> f32 {
    let (diff_sq, norm_sq) = chi
        .iter()
        .enumerate()
        .filter(|&(i, _)| mask[i] != 0)
        .fold((0.0f32, 0.0f32), |(diff_acc, norm_acc), (i, &current)| {
            let delta = current - chi_prev[i];
            (diff_acc + delta * delta, norm_acc + current * current)
        });

    if norm_sq > 1e-10 {
        (diff_sq / norm_sq).sqrt()
    } else {
        // Degenerate (near-zero) iterate: report "no convergence information".
        1.0
    }
}
962
/// Scratch memory for the TGV primal-dual iteration.
///
/// Every buffer has one `f32` per voxel of the processed (sub-)volume and is
/// allocated exactly once by [`TgvWorkspace::new`], so the iteration loop
/// itself performs no allocations.
struct TgvWorkspace {
    // --- Primal variables (plain name = current iterate, trailing `_` = companion buffer) ---
    chi: Vec<f32>,
    chi_: Vec<f32>,
    /// Snapshot of `chi` taken at the previous convergence check.
    chi_prev: Vec<f32>,
    phi: Vec<f32>,
    phi_: Vec<f32>,
    wx: Vec<f32>,
    wy: Vec<f32>,
    wz: Vec<f32>,
    wx_: Vec<f32>,
    wy_: Vec<f32>,
    wz_: Vec<f32>,

    // --- Dual variables ---
    /// Dual of the data term.
    eta: Vec<f32>,
    /// Dual of the gradient term (3-component vector field).
    px: Vec<f32>,
    py: Vec<f32>,
    pz: Vec<f32>,
    /// Dual of the symmetric-gradient term (6 unique tensor components).
    qxx: Vec<f32>,
    qxy: Vec<f32>,
    qxz: Vec<f32>,
    qyy: Vec<f32>,
    qyz: Vec<f32>,
    qzz: Vec<f32>,

    // --- General-purpose temporaries ---
    temp1: Vec<f32>,
    temp2: Vec<f32>,
    gx: Vec<f32>,
    gy: Vec<f32>,
    gz: Vec<f32>,

    // --- Symmetric gradient buffers (reused every iteration) ---
    sxx: Vec<f32>,
    sxy: Vec<f32>,
    sxz: Vec<f32>,
    syy: Vec<f32>,
    syz: Vec<f32>,
    szz: Vec<f32>,

    // --- Divergence buffers (reused every iteration) ---
    divqx: Vec<f32>,
    divqy: Vec<f32>,
    divqz: Vec<f32>,
}

impl TgvWorkspace {
    /// Allocate a workspace in which every buffer holds `n` zeros.
    fn new(n: usize) -> Self {
        // One freshly zeroed buffer per call.
        let zeros = || vec![0.0f32; n];

        Self {
            chi: zeros(),
            chi_: zeros(),
            chi_prev: zeros(),
            phi: zeros(),
            phi_: zeros(),
            wx: zeros(),
            wy: zeros(),
            wz: zeros(),
            wx_: zeros(),
            wy_: zeros(),
            wz_: zeros(),
            eta: zeros(),
            px: zeros(),
            py: zeros(),
            pz: zeros(),
            qxx: zeros(),
            qxy: zeros(),
            qxz: zeros(),
            qyy: zeros(),
            qyz: zeros(),
            qzz: zeros(),
            temp1: zeros(),
            temp2: zeros(),
            gx: zeros(),
            gy: zeros(),
            gz: zeros(),
            sxx: zeros(),
            sxy: zeros(),
            sxz: zeros(),
            syy: zeros(),
            syz: zeros(),
            szz: zeros(),
            divqx: zeros(),
            divqy: zeros(),
            divqz: zeros(),
        }
    }
}
1052
1053/// Main TGV-QSM reconstruction
1054pub fn tgv_qsm(
1055    phase: &[f32],
1056    mask: &[u8],
1057    nx: usize, ny: usize, nz: usize,
1058    vsx: f32, vsy: f32, vsz: f32,
1059    params: &TgvParams,
1060    b0_dir: (f32, f32, f32),
1061) -> Vec<f32> {
1062    tgv_qsm_with_progress(phase, mask, nx, ny, nz, vsx, vsy, vsz, params, b0_dir, |_, _| {})
1063}
1064
1065/// TGV-QSM with progress callback (optimized version)
1066pub fn tgv_qsm_with_progress<F>(
1067    phase: &[f32],
1068    mask: &[u8],
1069    nx: usize, ny: usize, nz: usize,
1070    vsx: f32, vsy: f32, vsz: f32,
1071    params: &TgvParams,
1072    b0_dir: (f32, f32, f32),
1073    progress: F,
1074) -> Vec<f32>
1075where
1076    F: Fn(usize, usize),
1077{
1078    let n_total = nx * ny * nz;
1079    let res = (vsx, vsy, vsz);
1080
1081    // Erode mask
1082    let mut mask_eroded = mask.to_vec();
1083    for _ in 0..params.erosions {
1084        mask_eroded = erode_mask(&mask_eroded, nx, ny, nz);
1085    }
1086
1087    // Create mask0 (one more erosion for internal computations)
1088    let mask0 = erode_mask(&mask_eroded, nx, ny, nz);
1089
1090    // Find bounding box (with padding of 2 voxels)
1091    let bbox = BoundingBox::from_mask(&mask0, nx, ny, nz, 2);
1092    let (bx, by, bz) = bbox.dims();
1093    let b_total = bbox.total();
1094
1095    // Extract sub-volumes for the bounding box region
1096    let phase_sub = extract_subvolume(phase, &bbox, nx, ny, nz);
1097    let mask0_sub = extract_subvolume(&mask0, &bbox, nx, ny, nz);
1098    let mask_eroded_sub = extract_subvolume(&mask_eroded, &bbox, nx, ny, nz);
1099
1100    // Compute phase Laplacian on sub-volume
1101    let mut laplace_phi0 = compute_phase_laplacian(&phase_sub, &mask0_sub, bx, by, bz, vsx, vsy, vsz);
1102
1103    // Subtract mean within mask
1104    let (sum, count): (f32, usize) = laplace_phi0.iter().zip(mask0_sub.iter())
1105        .filter(|(_, &m)| m != 0)
1106        .fold((0.0, 0), |(s, c), (&v, _)| (s + v, c + 1));
1107    if count > 0 {
1108        let mean = sum / count as f32;
1109        for (v, &m) in laplace_phi0.iter_mut().zip(mask0_sub.iter()) {
1110            if m != 0 {
1111                *v -= mean;
1112            }
1113        }
1114    }
1115
1116    // Compute SVD-fitted oblique dipole stencil (matching Julia reference)
1117    let stencil = compute_oblique_stencil(res, b0_dir);
1118
1119    // Compute step sizes for convergence (matching Julia implementation)
1120    let grad_norm_squared = grad_norm_sq(res);
1121    let grad_norm = grad_norm_squared.sqrt();
1122    let wave_norm: f32 = stencil.iter().flatten().flatten().map(|x| x.abs()).sum();
1123    let norm_sqr = compute_operator_norm_sqr(grad_norm, grad_norm_squared, wave_norm);
1124
1125    let tau = 1.0 / norm_sqr.sqrt();
1126    let sigma = tau;
1127
1128    // step_size is applied selectively in the updates:
1129    // - eta, phi: use base sigma/tau
1130    // - p, q, chi, w: use sigma * step_size / tau * step_size
1131    let sigma_step = sigma * params.step_size;
1132    let tau_step = tau * params.step_size;
1133
1134    // Projection thresholds are alpha values (NOT 1/alpha!)
1135    // Julia: projects p to ||p|| <= alpha1, q to ||q|| <= alpha0
1136    let alpha = (params.alpha0, params.alpha1);
1137
1138    // Pre-allocate all workspace buffers
1139    let mut ws = TgvWorkspace::new(b_total);
1140
1141    let mut _converged = false;
1142    let mut final_iter = params.iterations;
1143
1144    // Main iteration loop
1145    for iter in 0..params.iterations {
1146        progress(iter, params.iterations);
1147
1148        // Convergence check every 100 iterations
1149        if iter > 0 && iter % 100 == 0 {
1150            let rel_change = compute_relative_change(&ws.chi, &ws.chi_prev, &mask0_sub);
1151            if rel_change < params.tol {
1152                _converged = true;
1153                final_iter = iter;
1154                break;
1155            }
1156            // Save current chi for next convergence check
1157            ws.chi_prev.copy_from_slice(&ws.chi);
1158        }
1159
1160        // === DUAL UPDATE ===
1161
1162        // 1. Update eta (data term dual)
1163        apply_laplacian_inplace(&mut ws.temp1, &ws.phi_, &mask0_sub, bx, by, bz, vsx, vsy, vsz);
1164        apply_stencil(&mut ws.temp2, &ws.chi_, &stencil, &mask0_sub, bx, by, bz);
1165
1166        for i in 0..b_total {
1167            if mask0_sub[i] != 0 {
1168                ws.eta[i] += sigma * (-ws.temp1[i] + ws.temp2[i] - laplace_phi0[i]);
1169            }
1170        }
1171
1172        // 2. Update p (gradient dual)
1173        // Julia: p += mask0 * sigma * grad(chi) - mask * sigma * w
1174        // Compute unmasked gradient first
1175        crate::utils::gradient::fgrad_inplace_f32(
1176            &mut ws.gx, &mut ws.gy, &mut ws.gz, &ws.chi_, bx, by, bz, vsx, vsy, vsz
1177        );
1178
1179        for i in 0..b_total {
1180            let in_mask0 = mask0_sub[i] != 0;
1181            let in_mask = mask_eroded_sub[i] != 0;
1182
1183            if in_mask0 || in_mask {
1184                // gradient term scaled by mask0, w term scaled by mask
1185                let sigmaw0 = if in_mask0 { sigma_step } else { 0.0 };
1186                let sigmaw = if in_mask { sigma_step } else { 0.0 };
1187
1188                ws.px[i] += sigmaw0 * ws.gx[i] - sigmaw * ws.wx_[i];
1189                ws.py[i] += sigmaw0 * ws.gy[i] - sigmaw * ws.wy_[i];
1190                ws.pz[i] += sigmaw0 * ws.gz[i] - sigmaw * ws.wz_[i];
1191
1192                project_linf3(&mut ws.px[i], &mut ws.py[i], &mut ws.pz[i], alpha.1);
1193            }
1194        }
1195
1196        // 3. Update q (symmetric gradient dual)
1197        crate::utils::gradient::symgrad_inplace_f32(
1198            &mut ws.sxx, &mut ws.sxy, &mut ws.sxz, &mut ws.syy, &mut ws.syz, &mut ws.szz,
1199            &ws.wx_, &ws.wy_, &ws.wz_, bx, by, bz, vsx, vsy, vsz
1200        );
1201
1202        for i in 0..b_total {
1203            if mask0_sub[i] != 0 {
1204                ws.qxx[i] += sigma_step * ws.sxx[i];
1205                ws.qxy[i] += sigma_step * ws.sxy[i];
1206                ws.qxz[i] += sigma_step * ws.sxz[i];
1207                ws.qyy[i] += sigma_step * ws.syy[i];
1208                ws.qyz[i] += sigma_step * ws.syz[i];
1209                ws.qzz[i] += sigma_step * ws.szz[i];
1210
1211                project_linf6(
1212                    &mut ws.qxx[i], &mut ws.qxy[i], &mut ws.qxz[i],
1213                    &mut ws.qyy[i], &mut ws.qyz[i], &mut ws.qzz[i],
1214                    alpha.0
1215                );
1216            }
1217        }
1218
1219        // === VARIABLE SWAP ===
1220        std::mem::swap(&mut ws.phi, &mut ws.phi_);
1221        std::mem::swap(&mut ws.chi, &mut ws.chi_);
1222        std::mem::swap(&mut ws.wx, &mut ws.wx_);
1223        std::mem::swap(&mut ws.wy, &mut ws.wy_);
1224        std::mem::swap(&mut ws.wz, &mut ws.wz_);
1225
1226        // === PRIMAL UPDATE ===
1227
1228        // 1. Update phi
1229        for i in 0..b_total {
1230            ws.temp1[i] = if mask0_sub[i] != 0 { ws.eta[i] } else { 0.0 };
1231        }
1232        apply_laplacian_inplace(&mut ws.temp2, &ws.temp1, &mask0_sub, bx, by, bz, vsx, vsy, vsz);
1233
1234        for i in 0..b_total {
1235            let denom = 1.0 + if mask_eroded_sub[i] != 0 { tau } else { 0.0 };
1236            ws.phi[i] = (ws.phi_[i] + tau * ws.temp2[i]) / denom;
1237        }
1238
1239        // 2. Update chi
1240        crate::utils::gradient::bdiv_masked_inplace_f32(
1241            &mut ws.temp1, &ws.px, &ws.py, &ws.pz, &mask0_sub, bx, by, bz, vsx, vsy, vsz
1242        );
1243
1244        for i in 0..b_total {
1245            ws.gx[i] = if mask0_sub[i] != 0 { ws.eta[i] } else { 0.0 };
1246        }
1247        apply_stencil(&mut ws.temp2, &ws.gx, &stencil, &mask0_sub, bx, by, bz);
1248
1249        for i in 0..b_total {
1250            ws.chi[i] = ws.chi_[i] + tau_step * (ws.temp1[i] - ws.temp2[i]);
1251        }
1252
1253        // 3. Update w
1254        for i in 0..b_total {
1255            let m = if mask0_sub[i] != 0 { 1.0 } else { 0.0 };
1256            ws.sxx[i] = ws.qxx[i] * m;
1257            ws.sxy[i] = ws.qxy[i] * m;
1258            ws.sxz[i] = ws.qxz[i] * m;
1259            ws.syy[i] = ws.qyy[i] * m;
1260            ws.syz[i] = ws.qyz[i] * m;
1261            ws.szz[i] = ws.qzz[i] * m;
1262        }
1263
1264        crate::utils::gradient::symdiv_inplace_f32(
1265            &mut ws.divqx, &mut ws.divqy, &mut ws.divqz,
1266            &ws.sxx, &ws.sxy, &ws.sxz, &ws.syy, &ws.syz, &ws.szz,
1267            bx, by, bz, vsx, vsy, vsz
1268        );
1269
1270        // Julia: w_dest = w; if mask: w_dest += tau*(p + div(mask0*q))
1271        for i in 0..b_total {
1272            ws.wx[i] = ws.wx_[i];
1273            ws.wy[i] = ws.wy_[i];
1274            ws.wz[i] = ws.wz_[i];
1275            if mask_eroded_sub[i] != 0 {
1276                ws.wx[i] += tau_step * (ws.px[i] + ws.divqx[i]);
1277                ws.wy[i] += tau_step * (ws.py[i] + ws.divqy[i]);
1278                ws.wz[i] += tau_step * (ws.pz[i] + ws.divqz[i]);
1279            }
1280        }
1281
1282        // === EXTRAGRADIENT UPDATE ===
1283        for i in 0..b_total {
1284            ws.phi_[i] = 2.0 * ws.phi[i] - ws.phi_[i];
1285            ws.chi_[i] = 2.0 * ws.chi[i] - ws.chi_[i];
1286            ws.wx_[i] = 2.0 * ws.wx[i] - ws.wx_[i];
1287            ws.wy_[i] = 2.0 * ws.wy[i] - ws.wy_[i];
1288            ws.wz_[i] = 2.0 * ws.wz[i] - ws.wz_[i];
1289        }
1290    }
1291
1292    progress(final_iter, params.iterations);
1293
1294    // Scale to susceptibility (ppm)
1295    let gamma = 42.5781f32;  // Hz/T
1296    let scale = 1.0 / (2.0 * PI * params.te * params.fieldstrength * gamma);
1297
1298    // Create full-size result and insert sub-volume
1299    let mut result = vec![0.0f32; n_total];
1300
1301    // Scale chi in sub-volume and apply mask
1302    let mut chi_scaled = vec![0.0f32; b_total];
1303    for i in 0..b_total {
1304        if mask_eroded_sub[i] != 0 {
1305            chi_scaled[i] = ws.chi[i] * scale;
1306        }
1307    }
1308
1309    // Insert back into full volume
1310    insert_subvolume(&mut result, &chi_scaled, &bbox, nx, ny, nz);
1311
1312    result
1313}
1314
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum the 27 entries of a 3x3x3 kernel accessed through `at(i, j, k)`.
    ///
    /// Taking an accessor instead of the kernel itself keeps this helper
    /// independent of the concrete stencil type. Iteration order is k
    /// (outermost), j, i (innermost).
    fn sum_3x3x3(at: impl Fn(usize, usize, usize) -> f32) -> f32 {
        let mut sum = 0.0f32;
        for k in 0..3 {
            for j in 0..3 {
                for i in 0..3 {
                    sum += at(i, j, k);
                }
            }
        }
        sum
    }

    /// Build a spherical mask in an `n`³ volume (index = i + j*n + k*n*n).
    ///
    /// Voxels whose distance to `(center, center, center)` is strictly less
    /// than `radius` are set to 1, all others to 0.
    fn sphere_mask(n: usize, center: f32, radius: f32) -> Vec<u8> {
        let mut mask = vec![0u8; n * n * n];
        for k in 0..n {
            for j in 0..n {
                for i in 0..n {
                    let dx = i as f32 - center;
                    let dy = j as f32 - center;
                    let dz = k as f32 - center;
                    if (dx * dx + dy * dy + dz * dz).sqrt() < radius {
                        mask[i + j * n + k * n * n] = 1;
                    }
                }
            }
        }
        mask
    }

    #[test]
    fn test_dipole_stencil() {
        let stencil = compute_dipole_stencil((1.0, 1.0, 1.0), (0.0, 0.0, 1.0));

        let sum = sum_3x3x3(|i, j, k| stencil[i][j][k]);
        assert!(sum.abs() < 1e-6, "Stencil sum should be ~0, got {}", sum);
    }

    #[test]
    fn test_phase_laplacian() {
        let nx = 4;
        let ny = 4;
        let nz = 4;
        let n = nx * ny * nz;

        let phase = vec![1.0f32; n];
        let mask = vec![1u8; n];

        let lap = compute_phase_laplacian(&phase, &mask, nx, ny, nz, 1.0, 1.0, 1.0);

        let max_val = lap.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        assert!(max_val < 1e-5, "Laplacian of constant should be ~0, got max {}", max_val);
    }

    #[test]
    fn test_erode_mask() {
        let nx = 5;
        let ny = 5;
        let nz = 5;

        let mask = vec![1u8; nx * ny * nz];
        let eroded = erode_mask(&mask, nx, ny, nz);

        // The volume center survives one erosion; the corner voxel does not.
        let center = 2 + 2 * nx + 2 * nx * ny;
        assert_eq!(eroded[center], 1);
        assert_eq!(eroded[0], 0);
    }

    #[test]
    fn test_default_alpha() {
        let (a0, a1) = get_default_alpha(2);
        assert!((a0 - 0.002).abs() < 1e-6);
        assert!((a1 - 0.003).abs() < 1e-6);
    }

    #[test]
    fn test_bounding_box() {
        let nx = 10;
        let ny = 10;
        let nz = 10;
        let mut mask = vec![0u8; nx * ny * nz];

        // Set a small region in the center
        for k in 3..7 {
            for j in 3..7 {
                for i in 3..7 {
                    mask[i + j * nx + k * nx * ny] = 1;
                }
            }
        }

        let bbox = BoundingBox::from_mask(&mask, nx, ny, nz, 1);

        // Should be 3-1=2 to 6+1+1=8 (with padding 1)
        assert_eq!(bbox.i_min, 2);
        assert_eq!(bbox.i_max, 8);
        assert_eq!(bbox.j_min, 2);
        assert_eq!(bbox.j_max, 8);
    }

    #[test]
    fn test_tgv_qsm_small() {
        let n = 12;
        let n_total = n * n * n;

        let mask = sphere_mask(n, 6.0, 4.0);

        // Fill phase with a linear ramp (z direction) masked
        let mut phase = vec![0.0f32; n_total];
        for k in 0..n {
            for j in 0..n {
                for i in 0..n {
                    let idx = i + j * n + k * n * n;
                    if mask[idx] != 0 {
                        phase[idx] = 0.1 * k as f32;
                    }
                }
            }
        }

        let params = TgvParams {
            iterations: 10,
            erosions: 1,
            ..TgvParams::default()
        };

        let result = tgv_qsm(&phase, &mask, n, n, n, 1.0, 1.0, 1.0, &params, (0.0, 0.0, 1.0));

        assert_eq!(result.len(), n_total);

        // All values should be finite
        for &v in &result {
            assert!(v.is_finite(), "Result contains non-finite value: {}", v);
        }

        // There should be at least some non-zero values within the (eroded) mask
        let has_nonzero = result.iter().any(|&v| v.abs() > 1e-20);
        assert!(has_nonzero, "Result is entirely zero; expected non-zero values within mask");
    }

    #[test]
    fn test_oblique_stencil() {
        let stencil = compute_oblique_stencil((1.0, 1.0, 1.0), (0.0, 0.0, 1.0));

        // Sum of all elements should be ~0 (Laplacian-like operator)
        let sum = sum_3x3x3(|i, j, k| stencil[i][j][k]);
        assert!(
            sum.abs() < 1e-4,
            "Oblique stencil sum should be ~0, got {}",
            sum
        );

        // Center element should be approximately -sum(others), so verify it matches
        let off_sum = sum_3x3x3(|i, j, k| {
            if i == 1 && j == 1 && k == 1 {
                0.0
            } else {
                stencil[i][j][k]
            }
        });
        assert!(
            (stencil[1][1][1] + off_sum).abs() < 1e-6,
            "Center should be -sum(others): center={}, off_sum={}",
            stencil[1][1][1], off_sum
        );
    }

    #[test]
    fn test_oblique_stencil_aniso() {
        let stencil = compute_oblique_stencil((1.0, 1.0, 2.0), (0.2, 0.3, 0.9));

        // Sum of all elements should be ~0
        let sum = sum_3x3x3(|i, j, k| stencil[i][j][k]);
        assert!(
            sum.abs() < 1e-4,
            "Anisotropic oblique stencil sum should be ~0, got {}",
            sum
        );
    }

    #[test]
    fn test_get_default_iterations() {
        // With isotropic 1mm voxels and step_size=1.0
        let it = get_default_iterations((1.0, 1.0, 1.0), 1.0);
        assert!(it >= 1000, "Iterations should be >= 1000 for 1mm iso, got {}", it);

        // With larger voxels the count must still respect the 1000-iteration floor
        let it_large = get_default_iterations((2.0, 2.0, 2.0), 1.0);
        assert!(it_large >= 1000, "Iterations should still be >= 1000 for 2mm iso");

        // With very small voxels the count should be larger than 1mm iso
        let it_small = get_default_iterations((0.5, 0.5, 0.5), 1.0);
        assert!(it_small > it, "Smaller voxels should need more iterations: {} vs {}", it_small, it);

        // Higher step_size should reduce iterations
        let it_fast = get_default_iterations((1.0, 1.0, 1.0), 3.0);
        assert!(it_fast < it, "Higher step_size should give fewer iterations: {} vs {}", it_fast, it);
    }

    #[test]
    fn test_compute_relative_change() {
        // Two identical arrays -> relative change = 0
        let chi = vec![1.0f32, 2.0, 3.0, 4.0];
        let chi_prev = vec![1.0f32, 2.0, 3.0, 4.0];
        let mask = vec![1u8, 1, 1, 1];
        let rc = compute_relative_change(&chi, &chi_prev, &mask);
        assert!(rc.abs() < 1e-10, "Identical arrays should give 0 change, got {}", rc);

        // One element changed
        let chi2 = vec![1.1f32, 2.0, 3.0, 4.0];
        let rc2 = compute_relative_change(&chi2, &chi_prev, &mask);
        assert!(rc2 > 0.0, "Different arrays should give positive change");
        // Expected: sqrt(0.01 / (1.21 + 4 + 9 + 16)) = sqrt(0.01 / 30.21)
        let expected = (0.01f32 / 30.21).sqrt();
        assert!(
            (rc2 - expected).abs() < 1e-5,
            "Expected relative change ~{}, got {}",
            expected,
            rc2
        );

        // Masked elements should be ignored
        let mask_partial = vec![1u8, 0, 0, 0];
        let chi3 = vec![2.0f32, 999.0, 999.0, 999.0];
        let chi_prev3 = vec![1.0f32, 0.0, 0.0, 0.0];
        let rc3 = compute_relative_change(&chi3, &chi_prev3, &mask_partial);
        // diff_sq = (2-1)^2 = 1, norm_sq = 4 => sqrt(1/4) = 0.5
        let expected3 = (1.0f32 / 4.0).sqrt();
        assert!(
            (rc3 - expected3).abs() < 1e-6,
            "Masked relative change expected {}, got {}",
            expected3,
            rc3
        );

        // All zeros -> should return 1.0 (norm_sq < 1e-10)
        let zeros = vec![0.0f32; 4];
        let rc4 = compute_relative_change(&zeros, &zeros, &mask);
        assert!(
            (rc4 - 1.0).abs() < 1e-6,
            "Zero norm should return 1.0, got {}",
            rc4
        );
    }

    #[test]
    fn test_tgv_convergence() {
        // Zero phase => chi should be ~0.
        //
        // At the first convergence check (iter 100) chi_prev is still the
        // all-zero initial buffer, so the relative change evaluates to 1.0
        // (either ||chi|| / ||chi|| or the zero-norm fallback). The tolerance
        // is therefore set slightly above 1.0 to force an early exit there.
        let n = 12;
        let n_total = n * n * n;

        let mask = sphere_mask(n, 6.0, 4.0);
        let phase = vec![0.0f32; n_total];

        let params = TgvParams {
            iterations: 1000,
            erosions: 1,
            tol: 1.1, // Very loose tolerance so convergence triggers at first check (iter 100)
            ..TgvParams::default()
        };

        // The callback is `Fn`, so interior mutability is needed to record calls.
        let progress_iters = std::cell::RefCell::new(Vec::new());
        let result = tgv_qsm_with_progress(
            &phase, &mask, n, n, n, 1.0, 1.0, 1.0, &params, (0.0, 0.0, 1.0),
            |iter, _total| { progress_iters.borrow_mut().push(iter); }
        );

        assert_eq!(result.len(), n_total);

        // With zero phase, all output values should be very close to zero
        let max_abs = result.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
        assert!(
            max_abs < 1e-3,
            "Zero-phase TGV result should be ~0, got max abs {}",
            max_abs
        );

        // Early convergence: the last progress call should report iter <= 100
        let iters = progress_iters.borrow();
        let &last_iter = iters.last().unwrap();
        assert!(
            last_iter <= 100,
            "Expected early convergence by iter 100, but last progress was at iter {}",
            last_iter
        );
    }
}