1use std::f32::consts::PI;
32
/// Tuning parameters for [`tgv_qsm`] / [`tgv_qsm_with_progress`].
#[derive(Clone, Debug)]
pub struct TgvParams {
    /// Radius of the ball onto which the first-order dual variable `p` is projected
    /// (weight of the `∇χ − w` term).
    pub alpha1: f32,
    /// Radius of the ball onto which the second-order dual variable `q` is projected
    /// (weight of the symmetrised-gradient term of `w`).
    pub alpha0: f32,
    /// Maximum number of primal–dual iterations.
    pub iterations: usize,
    /// Number of times the input mask is eroded before reconstruction.
    pub erosions: usize,
    /// Multiplier on the base step sizes used for the `p`, `q`, `χ` and `w` updates.
    pub step_size: f32,
    /// Main field strength; only enters the final output scaling
    /// (presumably tesla — confirm).
    pub fieldstrength: f32,
    /// Echo time; only enters the final output scaling (presumably seconds — confirm).
    pub te: f32,
    /// Early-stopping threshold on the relative change of `χ`, checked every 100 iterations.
    pub tol: f32,
}

impl Default for TgvParams {
    /// Defaults: regularisation matching `get_default_alpha(2)`, 1000 iterations,
    /// 3 erosions, step size 3, field strength 3, echo time 0.020, tolerance 1e-5.
    fn default() -> Self {
        TgvParams {
            alpha0: 0.002,
            alpha1: 0.003,
            iterations: 1000,
            erosions: 3,
            step_size: 3.0,
            fieldstrength: 3.0,
            te: 0.020,
            tol: 1e-5,
        }
    }
}
68
/// Returns `(alpha0, alpha1)` for a regularisation level.
///
/// `regularization` is clamped to `1..=4`. Level 1 gives `(0.001, 0.001)`. Each
/// further level adds 0.001 to `alpha0` and 0.002 to `alpha1`, so level 2 gives
/// `(0.002, 0.003)`, which matches `TgvParams::default()`.
pub fn get_default_alpha(regularization: u8) -> (f32, f32) {
    let level = f32::from(regularization.clamp(1, 4)) - 1.0;
    let alpha0 = (0.001 + 0.001 * level).max(0.0);
    let alpha1 = (0.001 + 0.002 * level).max(0.0);
    (alpha0, alpha1)
}
80
/// Heuristic default iteration count for voxel size `res` and step multiplier `step_size`.
///
/// Computes `max(1000, 3200 / (rx·ry·rz)^0.42) / step_size^0.6`, rounded. Smaller
/// voxels and smaller step sizes give more iterations. Note that the 1000 floor is
/// applied *before* dividing by the step term. A `step_size` of zero gives infinity,
/// which the final `as usize` cast saturates to `usize::MAX`.
pub fn get_default_iterations(res: (f32, f32, f32), step_size: f32) -> usize {
    let (rx, ry, rz) = res;
    let voxel_volume = rx * ry * rz;
    let floored = f32::max(1000.0, 3200.0 / voxel_volume.powf(0.42));
    let scaled = floored / step_size.powf(0.6);
    scaled.round() as usize
}
89
/// Axis-aligned box around the non-zero voxels of a mask, in voxel indices.
///
/// The `*_min` bounds are inclusive and the `*_max` bounds are exclusive, so the box
/// covers `i_min..i_max`, `j_min..j_max`, `k_min..k_max`. An all-zero mask yields an
/// empty box (all bounds zero).
#[derive(Clone, Debug)]
struct BoundingBox {
    i_min: usize,
    i_max: usize,
    j_min: usize,
    j_max: usize,
    k_min: usize,
    k_max: usize,
}

impl BoundingBox {
    /// Computes the tight box around all voxels with `mask != 0`, then grows it by
    /// `padding` voxels on every side, clamped to the `nx × ny × nz` volume.
    ///
    /// `mask` is indexed x-fastest (`i + j*nx + k*nx*ny`). If no voxel is set, an
    /// empty box is returned instead of an inverted one.
    fn from_mask(mask: &[u8], nx: usize, ny: usize, nz: usize, padding: usize) -> Self {
        let mut i_min = nx;
        let mut i_max = 0;
        let mut j_min = ny;
        let mut j_max = 0;
        let mut k_min = nz;
        let mut k_max = 0;
        let mut found = false;

        for k in 0..nz {
            for j in 0..ny {
                for i in 0..nx {
                    if mask[i + j * nx + k * nx * ny] != 0 {
                        found = true;
                        i_min = i_min.min(i);
                        i_max = i_max.max(i);
                        j_min = j_min.min(j);
                        j_max = j_max.max(j);
                        k_min = k_min.min(k);
                        k_max = k_max.max(k);
                    }
                }
            }
        }

        if !found {
            // With no foreground voxel the sentinels above are inverted (min = n, max = 0).
            // After padding, `max < min` would make `dims()` underflow, so return an
            // explicitly empty box instead.
            return Self { i_min: 0, i_max: 0, j_min: 0, j_max: 0, k_min: 0, k_max: 0 };
        }

        let i_min = i_min.saturating_sub(padding);
        let j_min = j_min.saturating_sub(padding);
        let k_min = k_min.saturating_sub(padding);
        // +1 converts the inclusive maximum into an exclusive bound.
        let i_max = (i_max + padding + 1).min(nx);
        let j_max = (j_max + padding + 1).min(ny);
        let k_max = (k_max + padding + 1).min(nz);

        Self { i_min, i_max, j_min, j_max, k_min, k_max }
    }

    /// Extent of the box along x, y and z.
    fn dims(&self) -> (usize, usize, usize) {
        (self.i_max - self.i_min, self.j_max - self.j_min, self.k_max - self.k_min)
    }

    /// Number of voxels inside the box.
    fn total(&self) -> usize {
        let (bx, by, bz) = self.dims();
        bx * by * bz
    }
}
146
147fn extract_subvolume<T: Copy + Default>(
149 full: &[T],
150 bbox: &BoundingBox,
151 nx: usize, ny: usize, _nz: usize,
152) -> Vec<T> {
153 let (bx, by, bz) = bbox.dims();
154 let mut sub = vec![T::default(); bx * by * bz];
155
156 for k in 0..bz {
157 for j in 0..by {
158 for i in 0..bx {
159 let full_idx = (bbox.i_min + i) + (bbox.j_min + j) * nx + (bbox.k_min + k) * nx * ny;
160 let sub_idx = i + j * bx + k * bx * by;
161 sub[sub_idx] = full[full_idx];
162 }
163 }
164 }
165 sub
166}
167
168fn insert_subvolume<T: Copy>(
170 full: &mut [T],
171 sub: &[T],
172 bbox: &BoundingBox,
173 nx: usize, ny: usize, _nz: usize,
174) {
175 let (bx, by, bz) = bbox.dims();
176
177 for k in 0..bz {
178 for j in 0..by {
179 for i in 0..bx {
180 let full_idx = (bbox.i_min + i) + (bbox.j_min + j) * nx + (bbox.k_min + k) * nx * ny;
181 let sub_idx = i + j * bx + k * bx * by;
182 full[full_idx] = sub[sub_idx];
183 }
184 }
185 }
186}
187
/// Builds the 3×3×3 finite-difference stencil of the dipole operator
/// `(1/3)·Δ − (b·∇)²` for the unit field direction `b = b0_dir / |b0_dir|`.
///
/// `stencil[a][b][c]` is the weight of the neighbour at index offset `(a−1, b−1, c−1)`
/// along `(x, y, z)`, which is the layout `apply_stencil` expects. Pure second derivatives
/// use the 3-point central difference. Mixed derivatives use the 4-point central
/// difference `∂x∂y f ≈ [f(+,+) − f(+,−) − f(−,+) + f(−,−)] / (4·dx·dy)`. The weights
/// sum to zero.
///
/// `res` is the voxel size `(dx, dy, dz)`. A zero `b0_dir` produces NaN weights.
pub fn compute_dipole_stencil(
    res: (f32, f32, f32),
    b0_dir: (f32, f32, f32),
) -> [[[f32; 3]; 3]; 3] {
    let (dx, dy, dz) = res;
    let (bx, by, bz) = b0_dir;

    let b_norm = (bx * bx + by * by + bz * bz).sqrt();
    let bx = bx / b_norm;
    let by = by / b_norm;
    let bz = bz / b_norm;

    let mut stencil = [[[0.0f32; 3]; 3]; 3];

    // Inverse squared spacings used by the 3-point second differences.
    let hx2 = 1.0 / (dx * dx);
    let hy2 = 1.0 / (dy * dy);
    let hz2 = 1.0 / (dz * dz);
    let factor = 1.0 / 3.0;

    // Centre: (1/3)·(−2·Σh⁻²) from the Laplacian, plus 2·Σ b²·h⁻² from −(b·∇)².
    stencil[1][1][1] = -2.0 * (hx2 + hy2 + hz2) * factor
        + 2.0 * (bx * bx * hx2 + by * by * hy2 + bz * bz * hz2);

    // Face neighbours: (1/3)·h⁻² from the Laplacian minus b²·h⁻² along each axis.
    stencil[0][1][1] = hx2 * factor - bx * bx * hx2;
    stencil[2][1][1] = hx2 * factor - bx * bx * hx2;

    stencil[1][0][1] = hy2 * factor - by * by * hy2;
    stencil[1][2][1] = hy2 * factor - by * by * hy2;

    stencil[1][1][0] = hz2 * factor - bz * bz * hz2;
    stencil[1][1][2] = hz2 * factor - bz * bz * hz2;

    // Mixed terms: −(b·∇)² contains −2·bi·bj·∂i∂j. The 4-point central difference for
    // ∂i∂j divides by 4·di·dj, so each corner weight in the (i, j) plane has magnitude
    // bi·bj / (2·di·dj). The 0.5 below carries the 2/4 factor.
    let hxy = 0.5 / (dx * dy);
    let hxz = 0.5 / (dx * dz);
    let hyz = 0.5 / (dy * dz);

    // Same-sign diagonal corners get −bi·bj·h; opposite-sign corners get +bi·bj·h.
    let xy_factor = -bx * by * hxy;
    stencil[0][0][1] = xy_factor;
    stencil[2][2][1] = xy_factor;
    stencil[0][2][1] = -xy_factor;
    stencil[2][0][1] = -xy_factor;

    let xz_factor = -bx * bz * hxz;
    stencil[0][1][0] = xz_factor;
    stencil[2][1][2] = xz_factor;
    stencil[0][1][2] = -xz_factor;
    stencil[2][1][0] = -xz_factor;

    let yz_factor = -by * bz * hyz;
    stencil[1][0][0] = yz_factor;
    stencil[1][2][2] = yz_factor;
    stencil[1][0][2] = -yz_factor;
    stencil[1][2][0] = -yz_factor;

    stencil
}
250
/// Fits a symmetric 3×3×3 stencil for the dipole operator along `b0_dir` by least squares.
///
/// Procedure, as implemented below:
/// 1. Sample the continuous dipole kernel `(3·cos²θ − 1) / (4π·r³)` on a 64³ unit grid
///    centred at index 32, skipping voxels with `r < 4` (the singular core).
/// 2. For each of the 13 antipodal neighbour pairs `±p`, form the second difference
///    `δ(+p) + δ(−p) − 2·δ(0)`. Apply the inverse discrete Laplacian to it, with
///    Dirichlet boundaries, using the DST-I (`dst3d` → divide by `−4·sin²` eigenvalues
///    → `dst3d` → scale).
/// 3. Solve the normal equations `G·y = h` with `G = A·Aᵀ` and `h = A·d` via a
///    pseudo-inverse. Here `A` holds those responses restricted to the valid samples
///    and `d` holds the kernel samples. This minimises `‖Aᵀy − d‖²`.
/// 4. Write `y[p]` to both `+p` and `−p`, rescale each neighbour for anisotropic voxel
///    sizes, and set the centre so that all 27 weights sum to zero.
///
/// Layout: `result[a][b][c]` weights the neighbour at index offset `(a−1, b−1, c−1)`
/// along `(x, y, z)`, which is the layout `apply_stencil` expects.
///
/// Cost: 26 calls to `dst3d` on a 64³ grid. Each call performs O(n⁴) multiply-adds
/// (`dst1` is a direct O(n²) sum), so this is comparatively expensive.
///
/// NOTE(review): a zero `b0_dir` makes `b_norm` zero and propagates NaN. Presumably
/// `b0_dir` is expected to be non-zero — confirm with callers.
pub fn compute_oblique_stencil(
    res: (f32, f32, f32),
    b0_dir: (f32, f32, f32),
) -> [[[f32; 3]; 3]; 3] {
    // Grid size for the kernel fit and the radius (in voxels) excluded around the origin.
    let n: usize = 64;
    let singularity_cutout = 4.0_f64;
    let mid = n / 2;
    let n2 = n * n;
    let n3 = n * n * n;

    let (dx, dy, dz) = (res.0 as f64, res.1 as f64, res.2 as f64);

    // Unit field direction. The squared norm is summed in f32, then widened.
    let b_norm = ((b0_dir.0 * b0_dir.0 + b0_dir.1 * b0_dir.1 + b0_dir.2 * b0_dir.2) as f64).sqrt();
    let bdir = (b0_dir.0 as f64 / b_norm, b0_dir.1 as f64 / b_norm, b0_dir.2 as f64 / b_norm);

    // Kernel samples (NaN where excluded) and the mask of samples used in the fit.
    let mut d = vec![f64::NAN; n3];
    let mut d_mask = vec![false; n3];

    for k in 0..n {
        for j in 0..n {
            for i in 0..n {
                let x = i as f64 - mid as f64;
                let y = j as f64 - mid as f64;
                let z = k as f64 - mid as f64;
                let r = (x * x + y * y + z * z).sqrt();

                let idx = i + j * n + k * n2;
                if r < singularity_cutout {
                    // Inside the cut-out: leave NaN and exclude from the fit.
                } else {
                    // xz = cos θ between the offset vector and the field direction.
                    let xz = (bdir.0 * x + bdir.1 * y + bdir.2 * z) / r;
                    let kappa = (3.0 * xz * xz - 1.0) / (4.0 * std::f64::consts::PI * r * r * r);
                    d[idx] = kappa;
                    d_mask[idx] = true;
                }
            }
        }
    }

    // 1-D eigenvalues of the negative second-difference operator with Dirichlet
    // boundaries: 4·sin²(π(k+1) / (2(n+1))).
    let coord2_sq: Vec<f64> = (0..n)
        .map(|k| {
            let v = 2.0 * (std::f64::consts::PI * (k as f64 + 1.0) / (2.0 * (n as f64 + 1.0))).sin();
            v * v
        })
        .collect();

    // 3-D eigenvalues: sum of the per-axis ones.
    let mut coord2_grid = vec![0.0_f64; n3];
    for k in 0..n {
        for j in 0..n {
            for i in 0..n {
                coord2_grid[i + j * n + k * n2] = coord2_sq[i] + coord2_sq[j] + coord2_sq[k];
            }
        }
    }

    let sin_table = dst_sin_table(n);
    // Normalisation that turns two applications of the unnormalised DST-I into the identity.
    let idst_scale = (2.0 / (n as f64 + 1.0)).powi(3);

    // The 26 neighbour offsets in z-major, then y, then x order. With the centre
    // removed, entry p and entry 25 - p are antipodal.
    let mut stencil_positions: Vec<(i32, i32, i32)> = Vec::with_capacity(26);
    for dk in -1..=1_i32 {
        for dj in -1..=1_i32 {
            for di in -1..=1_i32 {
                if di == 0 && dj == 0 && dk == 0 {
                    continue;
                }
                stencil_positions.push((di, dj, dk));
            }
        }
    }

    let valid_indices: Vec<usize> = d_mask
        .iter()
        .enumerate()
        .filter(|(_, &m)| m)
        .map(|(idx, _)| idx)
        .collect();

    // One unknown per antipodal pair (the first 13 offsets).
    let num_pairs = 13;
    let mut a_rows: Vec<Vec<f64>> = Vec::with_capacity(num_pairs);

    for p in 0..num_pairs {
        let (di, dj, dk) = stencil_positions[p];

        // Second difference along offset p: +1 at +p and −p, −2 at the centre.
        let mut delta = vec![0.0_f64; n3];
        let pi = (mid as i32 + di) as usize;
        let pj = (mid as i32 + dj) as usize;
        let pk = (mid as i32 + dk) as usize;
        let mi = (mid as i32 - di) as usize;
        let mj = (mid as i32 - dj) as usize;
        let mk = (mid as i32 - dk) as usize;

        delta[pi + pj * n + pk * n2] = 1.0;
        delta[mi + mj * n + mk * n2] += 1.0;
        delta[mid + mid * n + mid * n2] = -2.0;

        // Inverse Laplacian in the DST domain: forward transform, divide by the
        // (negative) Laplacian eigenvalues, transform back.
        let mut fdelta = dst3d(&delta, n, &sin_table);

        for idx in 0..n3 {
            fdelta[idx] /= -coord2_grid[idx];
        }

        let vdelta = dst3d(&fdelta, n, &sin_table);

        // Keep only the samples that participate in the fit.
        let row: Vec<f64> = valid_indices
            .iter()
            .map(|&idx| vdelta[idx] * idst_scale)
            .collect();
        a_rows.push(row);
    }

    let d_valid: Vec<f64> = valid_indices.iter().map(|&idx| d[idx]).collect();

    // Gram matrix G = A·Aᵀ (symmetric, so only the lower triangle is computed).
    let mut g = vec![vec![0.0_f64; num_pairs]; num_pairs];
    for i in 0..num_pairs {
        for j in 0..=i {
            let dot: f64 = a_rows[i]
                .iter()
                .zip(a_rows[j].iter())
                .map(|(&a, &b)| a * b)
                .sum();
            g[i][j] = dot;
            g[j][i] = dot;
        }
    }

    // Right-hand side h = A·d.
    let mut h = vec![0.0_f64; num_pairs];
    for i in 0..num_pairs {
        h[i] = a_rows[i]
            .iter()
            .zip(d_valid.iter())
            .map(|(&a, &b)| a * b)
            .sum();
    }

    let y = solve_symmetric_pseudoinverse(&g, &h);

    // Scatter each pair coefficient to both members of the pair.
    let mut result = [[[0.0f32; 3]; 3]; 3];
    for p in 0..num_pairs {
        let coeff = y[p] as f32;

        let (di, dj, dk) = stencil_positions[p];
        result[(di + 1) as usize][(dj + 1) as usize][(dk + 1) as usize] = coeff;

        let (di2, dj2, dk2) = stencil_positions[25 - p];
        result[(di2 + 1) as usize][(dj2 + 1) as usize][(dk2 + 1) as usize] = coeff;
    }

    // Anisotropy correction: the fit used a unit grid. Each neighbour is scaled by the
    // mean inverse squared spacing along its non-zero offset axes.
    for dk in -1..=1_i32 {
        for dj in -1..=1_i32 {
            for di in -1..=1_i32 {
                if di == 0 && dj == 0 && dk == 0 {
                    continue;
                }
                let si = (di + 1) as usize;
                let sj = (dj + 1) as usize;
                let sk = (dk + 1) as usize;

                let i2 = (di * di) as f64;
                let j2 = (dj * dj) as f64;
                let k2 = (dk * dk) as f64;
                let weight = (i2 / (dx * dx) + j2 / (dy * dy) + k2 / (dz * dz)) / (i2 + j2 + k2);
                result[si][sj][sk] *= weight as f32;
            }
        }
    }

    // Centre weight makes the stencil sum to zero (constants are annihilated).
    let mut total = 0.0f32;
    for dk in 0..3 {
        for dj in 0..3 {
            for di in 0..3 {
                if !(di == 1 && dj == 1 && dk == 1) {
                    total += result[di][dj][dk];
                }
            }
        }
    }
    result[1][1][1] = -total;

    result
}
471
/// Precomputes the DST-I kernel `table[j][k] = sin((j+1)(k+1)·π / (n+1))` for an
/// `n`-point transform.
fn dst_sin_table(n: usize) -> Vec<Vec<f64>> {
    let scale = std::f64::consts::PI / (n as f64 + 1.0);
    (0..n)
        .map(|row| {
            let r = row as f64 + 1.0;
            (0..n).map(|col| (r * (col as f64 + 1.0) * scale).sin()).collect()
        })
        .collect()
}
484
485fn dst3d(input: &[f64], n: usize, sin_table: &[Vec<f64>]) -> Vec<f64> {
487 let n2 = n * n;
488 let mut data = input.to_vec();
489 let mut buf_in = vec![0.0_f64; n];
490 let mut buf_out = vec![0.0_f64; n];
491
492 for k in 0..n {
494 for j in 0..n {
495 let base = j * n + k * n2;
496 buf_in.copy_from_slice(&data[base..base + n]);
497 dst1(&buf_in, sin_table, &mut buf_out);
498 data[base..base + n].copy_from_slice(&buf_out);
499 }
500 }
501
502 for k in 0..n {
504 for i in 0..n {
505 for j in 0..n {
506 buf_in[j] = data[i + j * n + k * n2];
507 }
508 dst1(&buf_in, sin_table, &mut buf_out);
509 for j in 0..n {
510 data[i + j * n + k * n2] = buf_out[j];
511 }
512 }
513 }
514
515 for j in 0..n {
517 for i in 0..n {
518 for k in 0..n {
519 buf_in[k] = data[i + j * n + k * n2];
520 }
521 dst1(&buf_in, sin_table, &mut buf_out);
522 for k in 0..n {
523 data[i + j * n + k * n2] = buf_out[k];
524 }
525 }
526 }
527
528 data
529}
530
/// Direct O(n²) 1-D transform: `output[k] = Σ_j input[j] · sin_table[j][k]`,
/// for `k` in `0..input.len()`.
fn dst1(input: &[f64], sin_table: &[Vec<f64>], output: &mut [f64]) {
    let len = input.len();
    for k in 0..len {
        output[k] = (0..len).fold(0.0_f64, |acc, j| acc + input[j] * sin_table[j][k]);
    }
}
542
/// Solves `g · x = h` for a symmetric matrix `g` using an eigen-decomposition pseudo-inverse.
///
/// Diagonalises `g` with cyclic Jacobi rotations, so that `g = V·Λ·Vᵀ`. It then returns
/// `x = V·Λ⁺·Vᵀ·h`, where eigenvalues with `|λ| ≤ 1e-10 · max|λ|` are treated as zero.
/// This gives the minimum-norm least-squares solution when `g` is singular or
/// ill-conditioned.
///
/// Assumes `g` is square with `h.len()` rows and symmetric. Only the symmetric update
/// formulas are used, so asymmetry is not detected.
fn solve_symmetric_pseudoinverse(g: &[Vec<f64>], h: &[f64]) -> Vec<f64> {
    let n = h.len();

    // `a` is progressively diagonalised. `v` accumulates the rotations (eigenvectors as columns).
    let mut a: Vec<Vec<f64>> = g.to_vec();
    let mut v = vec![vec![0.0_f64; n]; n];
    for i in 0..n {
        v[i][i] = 1.0;
    }

    // NOTE(review): `tol` is absolute, not relative to the matrix scale. For matrices
    // with very small entries, convergence is decided by max_sweeps rather than tol.
    let max_sweeps = 100;
    let tol = 1e-15;

    for _ in 0..max_sweeps {
        // Stop once every off-diagonal entry is below tol.
        let mut max_off = 0.0_f64;
        for i in 0..n {
            for j in (i + 1)..n {
                max_off = max_off.max(a[i][j].abs());
            }
        }
        if max_off < tol {
            break;
        }

        for p in 0..n {
            for q in (p + 1)..n {
                if a[p][q].abs() < tol {
                    continue;
                }

                // Rotation angle that zeroes a[p][q]: t = tan θ is the smaller-magnitude
                // root of t² + 2·tau·t − 1 = 0 (chosen for numerical stability).
                let app = a[p][p];
                let aqq = a[q][q];
                let apq = a[p][q];
                let tau = (aqq - app) / (2.0 * apq);
                let t = if tau >= 0.0 {
                    1.0 / (tau + (1.0 + tau * tau).sqrt())
                } else {
                    -1.0 / (-tau + (1.0 + tau * tau).sqrt())
                };
                let c = 1.0 / (1.0 + t * t).sqrt();
                let s = t * c;

                // Rotate rows/columns p and q of the off-diagonal part, keeping `a` symmetric.
                for i in 0..n {
                    if i == p || i == q {
                        continue;
                    }
                    let aip = a[i][p];
                    let aiq = a[i][q];
                    a[i][p] = c * aip - s * aiq;
                    a[p][i] = a[i][p];
                    a[i][q] = s * aip + c * aiq;
                    a[q][i] = a[i][q];
                }

                // New diagonal entries; the (p, q) entry is zero by construction.
                a[p][p] = c * c * app - 2.0 * c * s * apq + s * s * aqq;
                a[q][q] = s * s * app + 2.0 * c * s * apq + c * c * aqq;
                a[p][q] = 0.0;
                a[q][p] = 0.0;

                // Accumulate the rotation into the eigenvector matrix: V ← V·J.
                for i in 0..n {
                    let vip = v[i][p];
                    let viq = v[i][q];
                    v[i][p] = c * vip - s * viq;
                    v[i][q] = s * vip + c * viq;
                }
            }
        }
    }

    // Eigenvalues are the diagonal; small ones (relative to the largest) are dropped.
    let eigenvalues: Vec<f64> = (0..n).map(|i| a[i][i]).collect();
    let max_eigen = eigenvalues.iter().cloned().fold(0.0_f64, |a, b| a.max(b.abs()));
    let threshold = 1e-10 * max_eigen;

    // vt_h = Vᵀ·h
    let mut vt_h = vec![0.0_f64; n];
    for i in 0..n {
        for j in 0..n {
            vt_h[i] += v[j][i] * h[j];
        }
    }

    // Apply Λ⁺ (truncated inverse).
    for i in 0..n {
        if eigenvalues[i].abs() > threshold {
            vt_h[i] /= eigenvalues[i];
        } else {
            vt_h[i] = 0.0;
        }
    }

    // x = V·(Λ⁺·Vᵀ·h)
    let mut x = vec![0.0_f64; n];
    for i in 0..n {
        for j in 0..n {
            x[i] += v[i][j] * vt_h[j];
        }
    }

    x
}
655
/// Applies a 3×3×3 `stencil` to `input` at every masked interior voxel, writing `output`.
///
/// `stencil[a][b][c]` weights the neighbour at index offset `(a−1, b−1, c−1)` along
/// `(x, y, z)`. Volumes are `nx × ny × nz`, x fastest. Voxels with `mask == 0`, and
/// voxels on the outer faces of the volume (where the stencil would leave the volume),
/// are set to `0.0`.
fn apply_stencil(
    output: &mut [f32],
    input: &[f32],
    stencil: &[[[f32; 3]; 3]; 3],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
) {
    let plane = nx * ny;
    for z in 0..nz {
        for y in 0..ny {
            for x in 0..nx {
                let here = x + y * nx + z * plane;
                let on_face = x == 0
                    || y == 0
                    || z == 0
                    || x + 1 >= nx
                    || y + 1 >= ny
                    || z + 1 >= nz;

                output[here] = if mask[here] == 0 || on_face {
                    0.0
                } else {
                    // Accumulate in z-major, then y, then x order.
                    let mut acc = 0.0f32;
                    for c in 0..3 {
                        let zz = z + c - 1;
                        for b in 0..3 {
                            let yy = y + b - 1;
                            for a in 0..3 {
                                let xx = x + a - 1;
                                acc += stencil[a][b][c] * input[xx + yy * nx + zz * plane];
                            }
                        }
                    }
                    acc
                };
            }
        }
    }
}
702
703pub fn compute_phase_laplacian(
705 phase: &[f32],
706 mask: &[u8],
707 nx: usize, ny: usize, nz: usize,
708 vsx: f32, vsy: f32, vsz: f32,
709) -> Vec<f32> {
710 let n_total = nx * ny * nz;
711
712 let sin_phase: Vec<f32> = phase.iter().map(|&p| p.sin()).collect();
713 let cos_phase: Vec<f32> = phase.iter().map(|&p| p.cos()).collect();
714
715 let lap_sin = compute_laplacian(&sin_phase, nx, ny, nz, vsx, vsy, vsz);
716 let lap_cos = compute_laplacian(&cos_phase, nx, ny, nz, vsx, vsy, vsz);
717
718 let mut laplacian = vec![0.0f32; n_total];
719 for i in 0..n_total {
720 if mask[i] != 0 {
721 laplacian[i] = lap_sin[i] * cos_phase[i] - lap_cos[i] * sin_phase[i];
722 }
723 }
724
725 laplacian
726}
727
/// 7-point Laplacian of an `nx × ny × nz` volume (x fastest) with voxel size `(vsx, vsy, vsz)`.
///
/// Neighbour indices are clamped to the volume, so at a boundary the missing neighbour
/// is replaced by the voxel itself (replicate padding).
fn compute_laplacian(
    input: &[f32],
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
) -> Vec<f32> {
    let plane = nx * ny;
    let (wx, wy, wz) = (1.0 / (vsx * vsx), 1.0 / (vsy * vsy), 1.0 / (vsz * vsz));
    let diag = -2.0 * (wx + wy + wz);
    let mut out = vec![0.0f32; nx * ny * nz];

    for z in 0..nz {
        let (zm, zp) = (z.saturating_sub(1), (z + 1).min(nz - 1));
        for y in 0..ny {
            let (ym, yp) = (y.saturating_sub(1), (y + 1).min(ny - 1));
            for x in 0..nx {
                let (xm, xp) = (x.saturating_sub(1), (x + 1).min(nx - 1));
                let at = |i: usize, j: usize, k: usize| input[i + j * nx + k * plane];

                out[x + y * nx + z * plane] = diag * at(x, y, z)
                    + wx * (at(xm, y, z) + at(xp, y, z))
                    + wy * (at(x, ym, z) + at(x, yp, z))
                    + wz * (at(x, y, zm) + at(x, y, zp));
            }
        }
    }

    out
}
766
/// Masked 7-point Laplacian of `input`, written into `output` (`input` is not modified).
///
/// At each masked voxel the result is `Σ_axis h⁻²·(prev − 2·centre + next)`. A neighbour
/// outside the volume is replaced by the centre value, so that axis contributes zero at
/// the boundary. Voxels with `mask == 0` are set to `0.0`.
fn apply_laplacian_inplace(
    output: &mut [f32],
    input: &[f32],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
) {
    let (wx, wy, wz) = (1.0 / (vsx * vsx), 1.0 / (vsy * vsy), 1.0 / (vsz * vsz));
    let plane = nx * ny;

    for z in 0..nz {
        for y in 0..ny {
            for x in 0..nx {
                let c = x + y * nx + z * plane;
                if mask[c] == 0 {
                    output[c] = 0.0;
                    continue;
                }

                let centre = input[c];
                let xm = if x > 0 { input[c - 1] } else { centre };
                let xp = if x + 1 < nx { input[c + 1] } else { centre };
                let ym = if y > 0 { input[c - nx] } else { centre };
                let yp = if y + 1 < ny { input[c + nx] } else { centre };
                let zm = if z > 0 { input[c - plane] } else { centre };
                let zp = if z + 1 < nz { input[c + plane] } else { centre };

                output[c] = wx * (xm - 2.0 * centre + xp)
                    + wy * (ym - 2.0 * centre + yp)
                    + wz * (zm - 2.0 * centre + zp);
            }
        }
    }
}
811
/// One step of binary erosion with a 6-connected structuring element.
///
/// A voxel stays set only if it is set, is not on the outer face of the volume, and all
/// six face neighbours are set. Output voxels are 1 or 0. Layout is x-fastest.
pub fn erode_mask(mask: &[u8], nx: usize, ny: usize, nz: usize) -> Vec<u8> {
    let plane = nx * ny;
    let mut eroded = vec![0u8; plane * nz];

    for z in 0..nz {
        for y in 0..ny {
            for x in 0..nx {
                let centre = x + y * nx + z * plane;
                if mask[centre] == 0 {
                    continue;
                }

                let interior = x > 0 && y > 0 && z > 0 && x + 1 < nx && y + 1 < ny && z + 1 < nz;
                if !interior {
                    continue;
                }

                let neighbours = [
                    centre - 1,
                    centre + 1,
                    centre - nx,
                    centre + nx,
                    centre - plane,
                    centre + plane,
                ];
                if neighbours.iter().all(|&nb| mask[nb] != 0) {
                    eroded[centre] = 1;
                }
            }
        }
    }

    eroded
}
846
/// Upper bound on the squared operator norm of the finite-difference gradient:
/// `4 · (dx⁻² + dy⁻² + dz⁻²)`.
#[inline]
fn grad_norm_sq(res: (f32, f32, f32)) -> f32 {
    let inv_sq = |h: f32| 1.0 / (h * h);
    4.0 * (inv_sq(res.0) + inv_sq(res.1) + inv_sq(res.2))
}
853
/// Estimates the largest eigenvalue of the symmetric 3×3 matrix
///
/// ```text
/// [ g2²     g2·w      0      ]
/// [ g2·w    g2 + w²   g      ]
/// [ 0       g         g2 + 1 ]
/// ```
///
/// The estimate uses 20 power-iteration steps from `(1, 1, 1)`, followed by a Rayleigh
/// quotient. The result is used as the squared operator norm when choosing the
/// primal–dual step sizes.
fn compute_operator_norm_sqr(g: f32, g2: f32, w: f32) -> f32 {
    let m00 = g2 * g2;
    let m01 = g2 * w;
    let m11 = g2 + w * w;
    let m12 = g;
    let m22 = g2 + 1.0;

    let mul = |v: [f32; 3]| -> [f32; 3] {
        [
            m00 * v[0] + m01 * v[1],
            m01 * v[0] + m11 * v[1] + m12 * v[2],
            m12 * v[1] + m22 * v[2],
        ]
    };

    let mut v = [1.0f32; 3];
    for _ in 0..20 {
        let y = mul(v);
        let len = (y[0] * y[0] + y[1] * y[1] + y[2] * y[2]).sqrt();
        if len < 1e-10 {
            break;
        }
        v = [y[0] / len, y[1] / len, y[2] / len];
    }

    let y = mul(v);
    v[0] * y[0] + v[1] * y[1] + v[2] * y[2]
}
899
/// Euclidean length of the vector `(x, y, z)`.
#[inline]
fn norm3(x: f32, y: f32, z: f32) -> f32 {
    let squared = x * x + y * y + z * z;
    squared.sqrt()
}
905
/// Frobenius norm of the symmetric 3×3 matrix given by its six unique entries.
/// Off-diagonal entries appear twice in the full matrix, hence the factor 2.
#[inline]
fn frobenius_norm(sxx: f32, sxy: f32, sxz: f32, syy: f32, syz: f32, szz: f32) -> f32 {
    let diagonal = sxx * sxx + syy * syy + szz * szz;
    let off_diagonal = sxy * sxy + sxz * sxz + syz * syz;
    (diagonal + 2.0 * off_diagonal).sqrt()
}
911
912#[inline]
914fn project_linf3(px: &mut f32, py: &mut f32, pz: &mut f32, threshold: f32) {
915 let norm = norm3(*px, *py, *pz);
916 if norm > threshold {
917 let scale = threshold / norm;
918 *px *= scale;
919 *py *= scale;
920 *pz *= scale;
921 }
922}
923
924#[inline]
926fn project_linf6(
927 qxx: &mut f32, qxy: &mut f32, qxz: &mut f32,
928 qyy: &mut f32, qyz: &mut f32, qzz: &mut f32,
929 threshold: f32,
930) {
931 let norm = frobenius_norm(*qxx, *qxy, *qxz, *qyy, *qyz, *qzz);
932 if norm > threshold {
933 let scale = threshold / norm;
934 *qxx *= scale;
935 *qxy *= scale;
936 *qxz *= scale;
937 *qyy *= scale;
938 *qyz *= scale;
939 *qzz *= scale;
940 }
941}
942
/// Relative L2 change `‖chi − chi_prev‖ / ‖chi‖`, restricted to voxels with `mask != 0`.
///
/// Returns `1.0` when `‖chi‖²` over the mask is not above `1e-10`. This avoids dividing
/// by (near) zero and is never mistaken for convergence.
fn compute_relative_change(chi: &[f32], chi_prev: &[f32], mask: &[u8]) -> f32 {
    let (diff_sq, norm_sq) = (0..chi.len())
        .filter(|&i| mask[i] != 0)
        .fold((0.0f32, 0.0f32), |(diff_acc, norm_acc), i| {
            let delta = chi[i] - chi_prev[i];
            (diff_acc + delta * delta, norm_acc + chi[i] * chi[i])
        });

    if norm_sq > 1e-10 {
        (diff_sq / norm_sq).sqrt()
    } else {
        1.0
    }
}
962
/// All per-voxel buffers used by the TGV primal–dual iteration.
///
/// Every buffer has the same length, one entry per voxel of the cropped sub-volume.
/// A trailing underscore marks the extrapolated copy of a primal variable.
struct TgvWorkspace {
    // Primal variables: susceptibility, phase correction and the vector field w.
    chi: Vec<f32>,
    chi_: Vec<f32>,
    chi_prev: Vec<f32>,
    phi: Vec<f32>,
    phi_: Vec<f32>,
    wx: Vec<f32>,
    wy: Vec<f32>,
    wz: Vec<f32>,
    wx_: Vec<f32>,
    wy_: Vec<f32>,
    wz_: Vec<f32>,

    // Dual variables: data term (eta), first-order (p) and second-order (q) TGV terms.
    eta: Vec<f32>,
    px: Vec<f32>,
    py: Vec<f32>,
    pz: Vec<f32>,
    qxx: Vec<f32>,
    qxy: Vec<f32>,
    qxz: Vec<f32>,
    qyy: Vec<f32>,
    qyz: Vec<f32>,
    qzz: Vec<f32>,

    // Scratch space for operator results.
    temp1: Vec<f32>,
    temp2: Vec<f32>,
    gx: Vec<f32>,
    gy: Vec<f32>,
    gz: Vec<f32>,

    // Symmetric-tensor scratch (symmetrised gradient / masked q).
    sxx: Vec<f32>,
    sxy: Vec<f32>,
    sxz: Vec<f32>,
    syy: Vec<f32>,
    syz: Vec<f32>,
    szz: Vec<f32>,

    // Divergence of the symmetric tensor field.
    divqx: Vec<f32>,
    divqy: Vec<f32>,
    divqz: Vec<f32>,
}

impl TgvWorkspace {
    /// Allocates every buffer zero-filled with length `n`.
    fn new(n: usize) -> Self {
        let zeros = || vec![0.0f32; n];
        Self {
            chi: zeros(),
            chi_: zeros(),
            chi_prev: zeros(),
            phi: zeros(),
            phi_: zeros(),
            wx: zeros(),
            wy: zeros(),
            wz: zeros(),
            wx_: zeros(),
            wy_: zeros(),
            wz_: zeros(),
            eta: zeros(),
            px: zeros(),
            py: zeros(),
            pz: zeros(),
            qxx: zeros(),
            qxy: zeros(),
            qxz: zeros(),
            qyy: zeros(),
            qyz: zeros(),
            qzz: zeros(),
            temp1: zeros(),
            temp2: zeros(),
            gx: zeros(),
            gy: zeros(),
            gz: zeros(),
            sxx: zeros(),
            sxy: zeros(),
            sxz: zeros(),
            syy: zeros(),
            syz: zeros(),
            szz: zeros(),
            divqx: zeros(),
            divqy: zeros(),
            divqz: zeros(),
        }
    }
}
1052
1053pub fn tgv_qsm(
1055 phase: &[f32],
1056 mask: &[u8],
1057 nx: usize, ny: usize, nz: usize,
1058 vsx: f32, vsy: f32, vsz: f32,
1059 params: &TgvParams,
1060 b0_dir: (f32, f32, f32),
1061) -> Vec<f32> {
1062 tgv_qsm_with_progress(phase, mask, nx, ny, nz, vsx, vsy, vsz, params, b0_dir, |_, _| {})
1063}
1064
1065pub fn tgv_qsm_with_progress<F>(
1067 phase: &[f32],
1068 mask: &[u8],
1069 nx: usize, ny: usize, nz: usize,
1070 vsx: f32, vsy: f32, vsz: f32,
1071 params: &TgvParams,
1072 b0_dir: (f32, f32, f32),
1073 progress: F,
1074) -> Vec<f32>
1075where
1076 F: Fn(usize, usize),
1077{
1078 let n_total = nx * ny * nz;
1079 let res = (vsx, vsy, vsz);
1080
1081 let mut mask_eroded = mask.to_vec();
1083 for _ in 0..params.erosions {
1084 mask_eroded = erode_mask(&mask_eroded, nx, ny, nz);
1085 }
1086
1087 let mask0 = erode_mask(&mask_eroded, nx, ny, nz);
1089
1090 let bbox = BoundingBox::from_mask(&mask0, nx, ny, nz, 2);
1092 let (bx, by, bz) = bbox.dims();
1093 let b_total = bbox.total();
1094
1095 let phase_sub = extract_subvolume(phase, &bbox, nx, ny, nz);
1097 let mask0_sub = extract_subvolume(&mask0, &bbox, nx, ny, nz);
1098 let mask_eroded_sub = extract_subvolume(&mask_eroded, &bbox, nx, ny, nz);
1099
1100 let mut laplace_phi0 = compute_phase_laplacian(&phase_sub, &mask0_sub, bx, by, bz, vsx, vsy, vsz);
1102
1103 let (sum, count): (f32, usize) = laplace_phi0.iter().zip(mask0_sub.iter())
1105 .filter(|(_, &m)| m != 0)
1106 .fold((0.0, 0), |(s, c), (&v, _)| (s + v, c + 1));
1107 if count > 0 {
1108 let mean = sum / count as f32;
1109 for (v, &m) in laplace_phi0.iter_mut().zip(mask0_sub.iter()) {
1110 if m != 0 {
1111 *v -= mean;
1112 }
1113 }
1114 }
1115
1116 let stencil = compute_oblique_stencil(res, b0_dir);
1118
1119 let grad_norm_squared = grad_norm_sq(res);
1121 let grad_norm = grad_norm_squared.sqrt();
1122 let wave_norm: f32 = stencil.iter().flatten().flatten().map(|x| x.abs()).sum();
1123 let norm_sqr = compute_operator_norm_sqr(grad_norm, grad_norm_squared, wave_norm);
1124
1125 let tau = 1.0 / norm_sqr.sqrt();
1126 let sigma = tau;
1127
1128 let sigma_step = sigma * params.step_size;
1132 let tau_step = tau * params.step_size;
1133
1134 let alpha = (params.alpha0, params.alpha1);
1137
1138 let mut ws = TgvWorkspace::new(b_total);
1140
1141 let mut _converged = false;
1142 let mut final_iter = params.iterations;
1143
1144 for iter in 0..params.iterations {
1146 progress(iter, params.iterations);
1147
1148 if iter > 0 && iter % 100 == 0 {
1150 let rel_change = compute_relative_change(&ws.chi, &ws.chi_prev, &mask0_sub);
1151 if rel_change < params.tol {
1152 _converged = true;
1153 final_iter = iter;
1154 break;
1155 }
1156 ws.chi_prev.copy_from_slice(&ws.chi);
1158 }
1159
1160 apply_laplacian_inplace(&mut ws.temp1, &ws.phi_, &mask0_sub, bx, by, bz, vsx, vsy, vsz);
1164 apply_stencil(&mut ws.temp2, &ws.chi_, &stencil, &mask0_sub, bx, by, bz);
1165
1166 for i in 0..b_total {
1167 if mask0_sub[i] != 0 {
1168 ws.eta[i] += sigma * (-ws.temp1[i] + ws.temp2[i] - laplace_phi0[i]);
1169 }
1170 }
1171
1172 crate::utils::gradient::fgrad_inplace_f32(
1176 &mut ws.gx, &mut ws.gy, &mut ws.gz, &ws.chi_, bx, by, bz, vsx, vsy, vsz
1177 );
1178
1179 for i in 0..b_total {
1180 let in_mask0 = mask0_sub[i] != 0;
1181 let in_mask = mask_eroded_sub[i] != 0;
1182
1183 if in_mask0 || in_mask {
1184 let sigmaw0 = if in_mask0 { sigma_step } else { 0.0 };
1186 let sigmaw = if in_mask { sigma_step } else { 0.0 };
1187
1188 ws.px[i] += sigmaw0 * ws.gx[i] - sigmaw * ws.wx_[i];
1189 ws.py[i] += sigmaw0 * ws.gy[i] - sigmaw * ws.wy_[i];
1190 ws.pz[i] += sigmaw0 * ws.gz[i] - sigmaw * ws.wz_[i];
1191
1192 project_linf3(&mut ws.px[i], &mut ws.py[i], &mut ws.pz[i], alpha.1);
1193 }
1194 }
1195
1196 crate::utils::gradient::symgrad_inplace_f32(
1198 &mut ws.sxx, &mut ws.sxy, &mut ws.sxz, &mut ws.syy, &mut ws.syz, &mut ws.szz,
1199 &ws.wx_, &ws.wy_, &ws.wz_, bx, by, bz, vsx, vsy, vsz
1200 );
1201
1202 for i in 0..b_total {
1203 if mask0_sub[i] != 0 {
1204 ws.qxx[i] += sigma_step * ws.sxx[i];
1205 ws.qxy[i] += sigma_step * ws.sxy[i];
1206 ws.qxz[i] += sigma_step * ws.sxz[i];
1207 ws.qyy[i] += sigma_step * ws.syy[i];
1208 ws.qyz[i] += sigma_step * ws.syz[i];
1209 ws.qzz[i] += sigma_step * ws.szz[i];
1210
1211 project_linf6(
1212 &mut ws.qxx[i], &mut ws.qxy[i], &mut ws.qxz[i],
1213 &mut ws.qyy[i], &mut ws.qyz[i], &mut ws.qzz[i],
1214 alpha.0
1215 );
1216 }
1217 }
1218
1219 std::mem::swap(&mut ws.phi, &mut ws.phi_);
1221 std::mem::swap(&mut ws.chi, &mut ws.chi_);
1222 std::mem::swap(&mut ws.wx, &mut ws.wx_);
1223 std::mem::swap(&mut ws.wy, &mut ws.wy_);
1224 std::mem::swap(&mut ws.wz, &mut ws.wz_);
1225
1226 for i in 0..b_total {
1230 ws.temp1[i] = if mask0_sub[i] != 0 { ws.eta[i] } else { 0.0 };
1231 }
1232 apply_laplacian_inplace(&mut ws.temp2, &ws.temp1, &mask0_sub, bx, by, bz, vsx, vsy, vsz);
1233
1234 for i in 0..b_total {
1235 let denom = 1.0 + if mask_eroded_sub[i] != 0 { tau } else { 0.0 };
1236 ws.phi[i] = (ws.phi_[i] + tau * ws.temp2[i]) / denom;
1237 }
1238
1239 crate::utils::gradient::bdiv_masked_inplace_f32(
1241 &mut ws.temp1, &ws.px, &ws.py, &ws.pz, &mask0_sub, bx, by, bz, vsx, vsy, vsz
1242 );
1243
1244 for i in 0..b_total {
1245 ws.gx[i] = if mask0_sub[i] != 0 { ws.eta[i] } else { 0.0 };
1246 }
1247 apply_stencil(&mut ws.temp2, &ws.gx, &stencil, &mask0_sub, bx, by, bz);
1248
1249 for i in 0..b_total {
1250 ws.chi[i] = ws.chi_[i] + tau_step * (ws.temp1[i] - ws.temp2[i]);
1251 }
1252
1253 for i in 0..b_total {
1255 let m = if mask0_sub[i] != 0 { 1.0 } else { 0.0 };
1256 ws.sxx[i] = ws.qxx[i] * m;
1257 ws.sxy[i] = ws.qxy[i] * m;
1258 ws.sxz[i] = ws.qxz[i] * m;
1259 ws.syy[i] = ws.qyy[i] * m;
1260 ws.syz[i] = ws.qyz[i] * m;
1261 ws.szz[i] = ws.qzz[i] * m;
1262 }
1263
1264 crate::utils::gradient::symdiv_inplace_f32(
1265 &mut ws.divqx, &mut ws.divqy, &mut ws.divqz,
1266 &ws.sxx, &ws.sxy, &ws.sxz, &ws.syy, &ws.syz, &ws.szz,
1267 bx, by, bz, vsx, vsy, vsz
1268 );
1269
1270 for i in 0..b_total {
1272 ws.wx[i] = ws.wx_[i];
1273 ws.wy[i] = ws.wy_[i];
1274 ws.wz[i] = ws.wz_[i];
1275 if mask_eroded_sub[i] != 0 {
1276 ws.wx[i] += tau_step * (ws.px[i] + ws.divqx[i]);
1277 ws.wy[i] += tau_step * (ws.py[i] + ws.divqy[i]);
1278 ws.wz[i] += tau_step * (ws.pz[i] + ws.divqz[i]);
1279 }
1280 }
1281
1282 for i in 0..b_total {
1284 ws.phi_[i] = 2.0 * ws.phi[i] - ws.phi_[i];
1285 ws.chi_[i] = 2.0 * ws.chi[i] - ws.chi_[i];
1286 ws.wx_[i] = 2.0 * ws.wx[i] - ws.wx_[i];
1287 ws.wy_[i] = 2.0 * ws.wy[i] - ws.wy_[i];
1288 ws.wz_[i] = 2.0 * ws.wz[i] - ws.wz_[i];
1289 }
1290 }
1291
1292 progress(final_iter, params.iterations);
1293
1294 let gamma = 42.5781f32; let scale = 1.0 / (2.0 * PI * params.te * params.fieldstrength * gamma);
1297
1298 let mut result = vec![0.0f32; n_total];
1300
1301 let mut chi_scaled = vec![0.0f32; b_total];
1303 for i in 0..b_total {
1304 if mask_eroded_sub[i] != 0 {
1305 chi_scaled[i] = ws.chi[i] * scale;
1306 }
1307 }
1308
1309 insert_subvolume(&mut result, &chi_scaled, &bbox, nx, ny, nz);
1311
1312 result
1313}
1314
1315#[cfg(test)]
1316mod tests {
1317 use super::*;
1318
1319 #[test]
1320 fn test_dipole_stencil() {
1321 let stencil = compute_dipole_stencil((1.0, 1.0, 1.0), (0.0, 0.0, 1.0));
1322
1323 let mut sum = 0.0f32;
1324 for k in 0..3 {
1325 for j in 0..3 {
1326 for i in 0..3 {
1327 sum += stencil[i][j][k];
1328 }
1329 }
1330 }
1331 assert!(sum.abs() < 1e-6, "Stencil sum should be ~0, got {}", sum);
1332 }
1333
1334 #[test]
1335 fn test_phase_laplacian() {
1336 let nx = 4;
1337 let ny = 4;
1338 let nz = 4;
1339 let n = nx * ny * nz;
1340
1341 let phase = vec![1.0f32; n];
1342 let mask = vec![1u8; n];
1343
1344 let lap = compute_phase_laplacian(&phase, &mask, nx, ny, nz, 1.0, 1.0, 1.0);
1345
1346 let max_val = lap.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
1347 assert!(max_val < 1e-5, "Laplacian of constant should be ~0, got max {}", max_val);
1348 }
1349
1350 #[test]
1351 fn test_erode_mask() {
1352 let nx = 5;
1353 let ny = 5;
1354 let nz = 5;
1355
1356 let mask = vec![1u8; nx * ny * nz];
1357 let eroded = erode_mask(&mask, nx, ny, nz);
1358
1359 let center = 2 + 2 * nx + 2 * nx * ny;
1360 assert_eq!(eroded[center], 1);
1361 assert_eq!(eroded[0], 0);
1362 }
1363
1364 #[test]
1365 fn test_default_alpha() {
1366 let (a0, a1) = get_default_alpha(2);
1367 assert!((a0 - 0.002).abs() < 1e-6);
1368 assert!((a1 - 0.003).abs() < 1e-6);
1369 }
1370
1371 #[test]
1372 fn test_bounding_box() {
1373 let nx = 10;
1374 let ny = 10;
1375 let nz = 10;
1376 let mut mask = vec![0u8; nx * ny * nz];
1377
1378 for k in 3..7 {
1380 for j in 3..7 {
1381 for i in 3..7 {
1382 mask[i + j * nx + k * nx * ny] = 1;
1383 }
1384 }
1385 }
1386
1387 let bbox = BoundingBox::from_mask(&mask, nx, ny, nz, 1);
1388
1389 assert_eq!(bbox.i_min, 2);
1391 assert_eq!(bbox.i_max, 8);
1392 assert_eq!(bbox.j_min, 2);
1393 assert_eq!(bbox.j_max, 8);
1394 }
1395
1396 #[test]
1397 fn test_tgv_qsm_small() {
1398 let n = 12;
1399 let n_total = n * n * n;
1400 let center = 6.0_f32;
1401 let radius = 4.0_f32;
1402
1403 let mut mask = vec![0u8; n_total];
1405 for k in 0..n {
1406 for j in 0..n {
1407 for i in 0..n {
1408 let dx = i as f32 - center;
1409 let dy = j as f32 - center;
1410 let dz = k as f32 - center;
1411 if (dx * dx + dy * dy + dz * dz).sqrt() < radius {
1412 mask[i + j * n + k * n * n] = 1;
1413 }
1414 }
1415 }
1416 }
1417
1418 let mut phase = vec![0.0f32; n_total];
1420 for k in 0..n {
1421 for j in 0..n {
1422 for i in 0..n {
1423 let idx = i + j * n + k * n * n;
1424 if mask[idx] != 0 {
1425 phase[idx] = 0.1 * k as f32;
1426 }
1427 }
1428 }
1429 }
1430
1431 let params = TgvParams {
1432 iterations: 10,
1433 erosions: 1,
1434 ..TgvParams::default()
1435 };
1436
1437 let result = tgv_qsm(&phase, &mask, n, n, n, 1.0, 1.0, 1.0, ¶ms, (0.0, 0.0, 1.0));
1438
1439 assert_eq!(result.len(), n_total);
1440
1441 for &v in &result {
1443 assert!(v.is_finite(), "Result contains non-finite value: {}", v);
1444 }
1445
1446 let has_nonzero = result.iter().any(|&v| v.abs() > 1e-20);
1448 assert!(has_nonzero, "Result is entirely zero; expected non-zero values within mask");
1449 }
1450
1451 #[test]
1452 fn test_oblique_stencil() {
1453 let stencil = compute_oblique_stencil((1.0, 1.0, 1.0), (0.0, 0.0, 1.0));
1454
1455 let mut sum = 0.0f32;
1457 for k in 0..3 {
1458 for j in 0..3 {
1459 for i in 0..3 {
1460 sum += stencil[i][j][k];
1461 }
1462 }
1463 }
1464 assert!(
1465 sum.abs() < 1e-4,
1466 "Oblique stencil sum should be ~0, got {}",
1467 sum
1468 );
1469
1470 let mut off_sum = 0.0f32;
1472 for k in 0..3 {
1473 for j in 0..3 {
1474 for i in 0..3 {
1475 if !(i == 1 && j == 1 && k == 1) {
1476 off_sum += stencil[i][j][k];
1477 }
1478 }
1479 }
1480 }
1481 assert!(
1482 (stencil[1][1][1] + off_sum).abs() < 1e-6,
1483 "Center should be -sum(others): center={}, off_sum={}",
1484 stencil[1][1][1], off_sum
1485 );
1486 }
1487
1488 #[test]
1489 fn test_oblique_stencil_aniso() {
1490 let stencil = compute_oblique_stencil((1.0, 1.0, 2.0), (0.2, 0.3, 0.9));
1491
1492 let mut sum = 0.0f32;
1494 for k in 0..3 {
1495 for j in 0..3 {
1496 for i in 0..3 {
1497 sum += stencil[i][j][k];
1498 }
1499 }
1500 }
1501 assert!(
1502 sum.abs() < 1e-4,
1503 "Anisotropic oblique stencil sum should be ~0, got {}",
1504 sum
1505 );
1506 }
1507
1508 #[test]
1509 fn test_get_default_iterations() {
1510 let it = get_default_iterations((1.0, 1.0, 1.0), 1.0);
1512 assert!(it >= 1000, "Iterations should be >= 1000 for 1mm iso, got {}", it);
1513
1514 let it_large = get_default_iterations((2.0, 2.0, 2.0), 1.0);
1516 assert!(it_large >= 1000, "Iterations should still be >= 1000 for 2mm iso");
1517
1518 let it_small = get_default_iterations((0.5, 0.5, 0.5), 1.0);
1520 assert!(it_small > it, "Smaller voxels should need more iterations: {} vs {}", it_small, it);
1521
1522 let it_fast = get_default_iterations((1.0, 1.0, 1.0), 3.0);
1524 assert!(it_fast < it, "Higher step_size should give fewer iterations: {} vs {}", it_fast, it);
1525 }
1526
1527 #[test]
1528 fn test_compute_relative_change() {
1529 let chi = vec![1.0f32, 2.0, 3.0, 4.0];
1531 let chi_prev = vec![1.0f32, 2.0, 3.0, 4.0];
1532 let mask = vec![1u8, 1, 1, 1];
1533 let rc = compute_relative_change(&chi, &chi_prev, &mask);
1534 assert!(rc.abs() < 1e-10, "Identical arrays should give 0 change, got {}", rc);
1535
1536 let chi2 = vec![1.1f32, 2.0, 3.0, 4.0];
1538 let rc2 = compute_relative_change(&chi2, &chi_prev, &mask);
1539 assert!(rc2 > 0.0, "Different arrays should give positive change");
1540 let expected = (0.01f32 / 30.21).sqrt();
1542 assert!(
1543 (rc2 - expected).abs() < 1e-5,
1544 "Expected relative change ~{}, got {}",
1545 expected,
1546 rc2
1547 );
1548
1549 let mask_partial = vec![1u8, 0, 0, 0];
1551 let chi3 = vec![2.0f32, 999.0, 999.0, 999.0];
1552 let chi_prev3 = vec![1.0f32, 0.0, 0.0, 0.0];
1553 let rc3 = compute_relative_change(&chi3, &chi_prev3, &mask_partial);
1554 let expected3 = (1.0f32 / 4.0).sqrt();
1556 assert!(
1557 (rc3 - expected3).abs() < 1e-6,
1558 "Masked relative change expected {}, got {}",
1559 expected3,
1560 rc3
1561 );
1562
1563 let zeros = vec![0.0f32; 4];
1565 let rc4 = compute_relative_change(&zeros, &zeros, &mask);
1566 assert!(
1567 (rc4 - 1.0).abs() < 1e-6,
1568 "Zero norm should return 1.0, got {}",
1569 rc4
1570 );
1571 }
1572
1573 #[test]
1574 fn test_tgv_convergence() {
1575 let n = 12;
1577 let n_total = n * n * n;
1578 let center = 6.0_f32;
1579 let radius = 4.0_f32;
1580
1581 let mut mask = vec![0u8; n_total];
1582 for k in 0..n {
1583 for j in 0..n {
1584 for i in 0..n {
1585 let dx = i as f32 - center;
1586 let dy = j as f32 - center;
1587 let dz = k as f32 - center;
1588 if (dx * dx + dy * dy + dz * dz).sqrt() < radius {
1589 mask[i + j * n + k * n * n] = 1;
1590 }
1591 }
1592 }
1593 }
1594
1595 let phase = vec![0.0f32; n_total];
1596
1597 let params = TgvParams {
1598 iterations: 1000,
1599 erosions: 1,
1600 tol: 1.1, ..TgvParams::default()
1602 };
1603
1604 let progress_iters = std::cell::RefCell::new(Vec::new());
1605 let result = tgv_qsm_with_progress(
1606 &phase, &mask, n, n, n, 1.0, 1.0, 1.0, ¶ms, (0.0, 0.0, 1.0),
1607 |iter, _total| { progress_iters.borrow_mut().push(iter); }
1608 );
1609
1610 assert_eq!(result.len(), n_total);
1611
1612 let max_abs = result.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
1614 assert!(
1615 max_abs < 1e-3,
1616 "Zero-phase TGV result should be ~0, got max abs {}",
1617 max_abs
1618 );
1619
1620 let iters = progress_iters.borrow();
1623 let &last_iter = iters.last().unwrap();
1624 assert!(
1625 last_iter <= 100,
1626 "Expected early convergence by iter 100, but last progress was at iter {}",
1627 last_iter
1628 );
1629 }
1630}