/// Bundle of numeric settings for the iLSQR solver.
///
/// NOTE(review): nothing in the visible part of this file reads this struct;
/// the field descriptions are inferred from the field names — confirm against
/// the callers.
#[derive(Clone, Debug)]
pub struct IlsqrParams {
    /// Tolerance value (default `0.01`); presumably a convergence threshold.
    pub tol: f64,
    /// Iteration cap (default `50`); presumably the maximum solver iterations.
    pub max_iter: usize,
}

impl Default for IlsqrParams {
    /// Returns the stock configuration: `tol = 0.01`, `max_iter = 50`.
    fn default() -> Self {
        let tol = 0.01;
        let max_iter = 50;
        IlsqrParams { tol, max_iter }
    }
}
34
35use std::cell::RefCell;
36use num_complex::Complex64;
37use crate::fft::Fft3dWorkspace;
38use crate::kernels::dipole::dipole_kernel;
39use crate::kernels::smv::smv_kernel;
40use crate::utils::gradient::{fgrad, bdiv};
41
42pub fn lsqr<F, G>(
61 apply_a: F,
62 apply_at: G,
63 b: &[f64],
64 tol: f64,
65 max_iter: usize,
66) -> Vec<f64>
67where
68 F: Fn(&[f64]) -> Vec<f64>,
69 G: Fn(&[f64]) -> Vec<f64>,
70{
71 let mut u = b.to_vec();
73 let mut beta = norm(&u);
74
75 if beta > 0.0 {
76 scale_inplace(&mut u, 1.0 / beta);
77 }
78
79 let mut v = apply_at(&u);
80 let n = v.len();
81 let mut alpha = norm(&v);
82
83 if alpha > 0.0 {
84 scale_inplace(&mut v, 1.0 / alpha);
85 }
86
87 let mut w = v.clone();
88 let mut x = vec![0.0; n];
89
90 let mut phi_bar = beta;
91 let mut rho_bar = alpha;
92
93 let bnorm = beta;
94
95 for _iter in 0..max_iter {
96 let mut u_new = apply_a(&v);
98 axpy(&mut u_new, -alpha, &u);
99 beta = norm(&u_new);
100
101 if beta > 0.0 {
102 scale_inplace(&mut u_new, 1.0 / beta);
103 }
104 u = u_new;
105
106 let mut v_new = apply_at(&u);
107 axpy(&mut v_new, -beta, &v);
108 alpha = norm(&v_new);
109
110 if alpha > 0.0 {
111 scale_inplace(&mut v_new, 1.0 / alpha);
112 }
113 v = v_new;
114
115 let rho = (rho_bar * rho_bar + beta * beta).sqrt();
117 let c = rho_bar / rho;
118 let s = beta / rho;
119 let theta = s * alpha;
120 rho_bar = -c * alpha;
121 let phi = c * phi_bar;
122 phi_bar = s * phi_bar;
123
124 let t1 = phi / rho;
126 let t2 = -theta / rho;
127
128 for i in 0..n {
129 x[i] += t1 * w[i];
130 w[i] = v[i] + t2 * w[i];
131 }
132
133 let rel_residual = phi_bar / (bnorm + 1e-20);
135
136 if rel_residual < tol {
137 break;
138 }
139 }
140
141 x
142}
143
144fn norm_complex(x: &[Complex64]) -> f64 {
150 x.iter().map(|c| c.norm_sqr()).sum::<f64>().sqrt()
151}
152
153fn scale_complex_inplace(x: &mut [Complex64], s: f64) {
155 for v in x.iter_mut() {
156 *v *= s;
157 }
158}
159
160fn axpy_complex(y: &mut [Complex64], a: f64, x: &[Complex64]) {
162 for (yi, xi) in y.iter_mut().zip(x.iter()) {
163 *yi += a * xi;
164 }
165}
166
/// LSQR iteration for complex-valued vectors, driven by operator closures.
///
/// # Arguments
/// * `apply_a`  - applies the forward operator to a vector
/// * `apply_ah` - applies the second operator of the bidiagonalization
///                (by its name, the adjoint of `apply_a` — the caller is
///                responsible for that relationship)
/// * `b`        - right-hand side
/// * `tol`      - used for both stopping tolerances (`atol` and `btol`)
/// * `max_iter` - upper bound on the number of iterations
/// * `verbose`  - when `true`, prints per-iteration diagnostics to stderr
///
/// # Returns
/// The solution estimate, with the length of the vector returned by
/// `apply_ah`. The zero vector is returned when `b` or `apply_ah(b)` has
/// zero norm.
pub fn lsqr_complex<F, G>(
    apply_a: F,
    apply_ah: G,
    b: &[Complex64],
    tol: f64,
    max_iter: usize,
    verbose: bool,
) -> Vec<Complex64>
where
    F: Fn(&[Complex64]) -> Vec<Complex64>,
    G: Fn(&[Complex64]) -> Vec<Complex64>,
{
    // beta * u = b
    let mut u = b.to_vec();
    let mut beta = norm_complex(&u);

    if beta > 0.0 {
        scale_complex_inplace(&mut u, 1.0 / beta);
    }

    // alpha * v = apply_ah(u)
    let mut v = apply_ah(&u);
    let n = v.len();
    let mut alpha = norm_complex(&v);

    if alpha > 0.0 {
        scale_complex_inplace(&mut v, 1.0 / alpha);
    }

    let mut w = v.clone();
    let mut x = vec![Complex64::new(0.0, 0.0); n];

    let mut phi_bar = beta;
    let mut rho_bar = alpha;

    let bnorm = beta;
    // The single `tol` argument serves as both stopping tolerances.
    let atol = tol;
    let btol = tol;

    // Running sum of squared alpha/beta values; its square root is used as
    // the operator-norm estimate in the stopping tests.
    let mut norm_a2 = alpha * alpha;

    // State of the recurrence that estimates the norm of x.
    let mut xxnorm = 0.0;
    let mut z_sol = 0.0;
    let mut cs2 = -1.0;
    let mut sn2 = 0.0;

    // Nothing to iterate on when b or apply_ah(b) is zero.
    if alpha * beta == 0.0 {
        return x;
    }

    for _iter in 0..max_iter {
        // beta * u_new = A v - alpha * u
        let mut u_new = apply_a(&v);
        axpy_complex(&mut u_new, -alpha, &u);
        beta = norm_complex(&u_new);

        if beta > 0.0 {
            scale_complex_inplace(&mut u_new, 1.0 / beta);
        }
        u = u_new;

        // alpha * v_new = apply_ah(u) - beta * v
        let mut v_new = apply_ah(&u);
        axpy_complex(&mut v_new, -beta, &v);
        alpha = norm_complex(&v_new);

        if alpha > 0.0 {
            scale_complex_inplace(&mut v_new, 1.0 / alpha);
        }
        v = v_new;

        // Plane rotation built from (rho_bar, beta).
        let rho = (rho_bar * rho_bar + beta * beta).sqrt();
        let c = rho_bar / rho;
        let s = beta / rho;
        let theta = s * alpha;
        rho_bar = -c * alpha;
        let phi = c * phi_bar;
        phi_bar = s * phi_bar;

        // Second rotation: produces `xnorm`, the estimate of the norm of x
        // that enters `rtol` below.
        let delta = sn2 * rho;
        let gambar = -cs2 * rho;
        let rhs = phi - delta * z_sol;
        let zbar = rhs / gambar;
        let xnorm = (xxnorm + zbar * zbar).sqrt();
        let gamma = (gambar * gambar + theta * theta).sqrt();
        cs2 = gambar / gamma;
        sn2 = theta / gamma;
        z_sol = rhs / gamma;
        xxnorm += z_sol * z_sol;

        // Update the solution and the search direction (v is already the
        // new vector at this point).
        let t1 = phi / rho;
        let t2 = -theta / rho;
        for i in 0..n {
            x[i] += t1 * w[i];
            w[i] = v[i] + t2 * w[i];
        }

        // Residual-norm estimate and estimate of the norm of A' r.
        let normr = phi_bar;
        let norm_ar = alpha * (c * phi_bar).abs();

        norm_a2 += beta * beta + alpha * alpha;
        let norm_a = norm_a2.sqrt();

        // The 1e-20 terms keep the ratios defined when a denominator is zero.
        let test1 = normr / (bnorm + 1e-20);
        let test2 = norm_ar / ((norm_a * normr) + 1e-20);
        let rtol = btol + atol * norm_a * xnorm / (bnorm + 1e-20);

        if verbose {
            eprintln!(" LSQR iter {:>3}: ||r||/||b||={:.6e} ||A'r||/(||A||·||r||)={:.6e} rtol={:.6e}",
                _iter + 1, test1, test2, rtol);
        }

        // Stop when either the normal-equation test or the residual test passes.
        if test2 <= atol || test1 <= rtol {
            if verbose {
                eprintln!(" LSQR converged at iteration {} (test1={:.4e}, test2={:.4e})",
                    _iter + 1, test1, test2);
            }
            break;
        }
    }

    x
}
305
/// LSMR iteration for real vectors, driven by operator closures.
///
/// # Arguments
/// * `apply_a`  - applies the forward operator
/// * `apply_at` - applies the second operator of the bidiagonalization
///                (by its name, the transpose of `apply_a`)
/// * `b`        - right-hand side
/// * `n`        - length of the solution vector; the update loop indexes the
///                output of `apply_at` with `0..n`, so that output must have
///                at least `n` entries
/// * `atol`, `btol` - stopping tolerances used in `test2` and `rtol` below
/// * `max_iter` - upper bound on the number of iterations
/// * `_verbose` - when `true`, prints per-iteration diagnostics to stderr
///                (the parameter is used despite the leading underscore)
///
/// # Returns
/// The solution estimate of length `n`; the zero vector when `b` or
/// `apply_at(b)` has zero norm.
pub fn lsmr<F, G>(
    apply_a: F,
    apply_at: G,
    b: &[f64],
    n: usize,
    atol: f64,
    btol: f64,
    max_iter: usize,
    _verbose: bool,
) -> Vec<f64>
where
    F: Fn(&[f64]) -> Vec<f64>,
    G: Fn(&[f64]) -> Vec<f64>,
{
    // beta * u = b
    let mut u = b.to_vec();
    let mut beta = norm(&u);

    if beta > 0.0 {
        scale_inplace(&mut u, 1.0 / beta);
    }

    // alpha * v = apply_at(u)
    let mut v = apply_at(&u);
    let mut alpha = norm(&v);

    if alpha > 0.0 {
        scale_inplace(&mut v, 1.0 / alpha);
    }

    // State of the two rotation sequences.
    let mut alpha_bar = alpha;
    let mut zeta_bar = alpha * beta;
    let mut rho = 1.0;
    let mut rho_bar = 1.0;
    let mut c_bar = 1.0;
    let mut s_bar = 0.0;

    let mut h = v.clone();
    let mut h_bar = vec![0.0; n];
    let mut x = vec![0.0; n];

    // State of the residual-norm estimate.
    let normb = beta;
    let mut betadd = beta;
    let mut betad = 0.0;
    let mut rhodold = 1.0;
    let mut tautildeold = 0.0;
    let mut thetatilde = 0.0;
    let mut zeta = 0.0;
    // Constant zero term inside `normr`; it is never updated in this function.
    let d = 0.0;

    // State of the operator-norm and condition-number estimates.
    let mut norm_a2 = alpha * alpha;
    let mut maxrbar = 0.0f64;
    let mut minrbar = 1e100f64;
    // Fixed condition-number limit; `ctol` is its reciprocal.
    let conlim = 1e8;
    let ctol = if conlim > 0.0 { 1.0 / conlim } else { 0.0 };

    // Nothing to iterate on when b or apply_at(b) is zero.
    if alpha * beta == 0.0 {
        return x;
    }

    for _iter in 0..max_iter {
        // beta * u_new = A v - alpha * u
        let mut u_new = apply_a(&v);
        axpy(&mut u_new, -alpha, &u);
        beta = norm(&u_new);

        if beta > 0.0 {
            scale_inplace(&mut u_new, 1.0 / beta);
        }
        u = u_new;

        // alpha * v_new = apply_at(u) - beta * v
        let mut v_new = apply_at(&u);
        axpy(&mut v_new, -beta, &v);
        alpha = norm(&v_new);

        if alpha > 0.0 {
            scale_inplace(&mut v_new, 1.0 / alpha);
        }
        v = v_new;

        // First rotation, built from (alpha_bar, beta).
        let rho_old = rho;
        rho = (alpha_bar * alpha_bar + beta * beta).sqrt();
        let c = alpha_bar / rho;
        let s = beta / rho;
        let theta_new = s * alpha;
        alpha_bar = c * alpha;

        // Second rotation, built from (c_bar * rho, theta_new).
        let rho_bar_old = rho_bar;
        let zeta_old = zeta;
        let theta_bar = s_bar * rho;
        let rho_temp = c_bar * rho;
        rho_bar = (rho_temp * rho_temp + theta_new * theta_new).sqrt();
        c_bar = rho_temp / rho_bar;
        s_bar = theta_new / rho_bar;
        zeta = c_bar * zeta_bar;
        zeta_bar = -s_bar * zeta_bar;

        // Update h_bar, x and h; each line uses the values written by the
        // lines above it, so the order matters.
        for i in 0..n {
            h_bar[i] = h[i] - (theta_bar * rho / (rho_old * rho_bar_old)) * h_bar[i];
            x[i] += (zeta / (rho * rho_bar)) * h_bar[i];
            h[i] = v[i] - (theta_new / rho) * h[i];
        }

        // Residual-norm estimate.
        let betaacute = betadd;
        let betahat = c * betaacute;
        betadd = -s * betaacute;

        let thetatildeold = thetatilde;
        let rhotildeold = (rhodold * rhodold + theta_bar * theta_bar).sqrt();
        let ctildeold = rhodold / rhotildeold;
        let stildeold = theta_bar / rhotildeold;
        thetatilde = stildeold * rho_bar;
        rhodold = ctildeold * rho_bar;
        betad = -stildeold * betad + ctildeold * betahat;

        tautildeold = (zeta_old - thetatildeold * tautildeold) / rhotildeold;
        let taud = (zeta - thetatilde * tautildeold) / rhodold;
        let normr = (d + (betad - taud).powi(2) + betadd * betadd).sqrt();

        // Operator-norm estimate: `norm_a` is taken after adding beta^2 but
        // before adding alpha^2.
        norm_a2 += beta * beta;
        let norm_a = norm_a2.sqrt();
        norm_a2 += alpha * alpha;

        // Condition-number estimate; the minimum is only tracked from the
        // second iteration on.
        maxrbar = maxrbar.max(rho_bar_old);
        if _iter > 0 {
            minrbar = minrbar.min(rho_bar_old);
        }
        let cond_a = maxrbar.max(rho_temp) / minrbar.min(rho_temp);

        let norm_ar = zeta_bar.abs();
        let normx = norm(&x);

        // The 1e-20 terms keep the ratios defined when a denominator is zero.
        let test1 = normr / (normb + 1e-20);
        let test2 = norm_ar / ((norm_a * normr) + 1e-20);
        let test3 = 1.0 / (cond_a + 1e-20);
        let rtol = btol + atol * norm_a * normx / (normb + 1e-20);

        if _verbose {
            eprintln!(" LSMR iter {:>3}: ||r||/||b||={:.6e} ||A'r||/(||A||·||r||)={:.6e} 1/condA={:.6e} rtol={:.6e}",
                _iter + 1, test1, test2, test3, rtol);
        }

        // Stop on any of: condition limit reached, normal-equation test,
        // residual test.
        if test3 <= ctol || test2 <= atol || test1 <= rtol {
            if _verbose {
                let reason = if test1 <= rtol { "test1 (residual)"
                } else if test2 <= atol { "test2 (||A'r||)"
                } else { "test3 (cond(A))" };
                eprintln!(" LSMR converged at iteration {} via {}", _iter + 1, reason);
            }
            break;
        }
    }

    x
}
501
502fn compute_laplacian(
512 f: &[f64],
513 mask: &[u8],
514 nx: usize, ny: usize, nz: usize,
515 vsx: f64, vsy: f64, vsz: f64,
516) -> Vec<f64> {
517 let n_total = nx * ny * nz;
518 let mut lap = vec![0.0; n_total];
519
520 let hx = 1.0 / (vsx * vsx);
521 let hy = 1.0 / (vsy * vsy);
522 let hz = 1.0 / (vsz * vsz);
523
524 let nxny = nx * ny;
525
526 for k in 0..nz {
527 for j in 0..ny {
528 let jk_offset = j * nx + k * nxny;
529 for i in 0..nx {
530 let l = i + jk_offset;
531
532 if mask[l] == 0 {
533 continue;
534 }
535
536 lap[l] += hx * lap1_axis(f, mask, l, 1, nx, i, nx);
538
539 lap[l] += hy * lap1_axis(f, mask, l, nx, nxny, j * nx, nxny);
541
542 lap[l] += hz * lap1_axis(f, mask, l, nxny, n_total, k * nxny, n_total);
544 }
545 }
546 }
547
548 lap
549}
550
551#[inline]
567fn lap1_axis(
568 f: &[f64],
569 mask: &[u8],
570 l: usize,
571 a: usize,
572 n_axis: usize,
573 coord: usize,
574 n_total: usize,
575) -> f64 {
576 let n_end = n_axis - a; let n_interior = n_axis - 2 * a; let stencil = if coord.wrapping_sub(a) < n_interior {
583 2 * (mask[l + a] as u8) + (mask[l - a] as u8)
585 } else {
586 if coord == 0 { 2 } else if coord == n_end { 1 } else { 0 }
588 };
589
590 match stencil {
591 3 => {
592 f[l - a] - 2.0 * f[l] + f[l + a]
594 }
595 2 => {
596 lap1_forward(f, mask, l, a, n_axis, coord, n_total)
598 }
599 1 => {
600 lap1_backward(f, mask, l, a, n_axis, coord, n_total)
602 }
603 _ => 0.0, }
605}
606
/// One-sided (forward) second-difference of `f` at `l` along an axis with
/// linear stride `a`.
///
/// Picks the widest stencil whose far points lie inside the axis
/// (`coord + k * a < n_axis`) and inside the mask:
/// four-point, then three-point, then the plain difference `f[l + a] - f[l]`.
/// `_n_total` is unused.
#[inline]
fn lap1_forward(
    f: &[f64],
    mask: &[u8],
    l: usize,
    a: usize,
    n_axis: usize,
    coord: usize,
    _n_total: usize,
) -> f64 {
    let (p1, p2, p3) = (l + a, l + 2 * a, l + 3 * a);
    let in_mask = |idx: usize| mask[idx] != 0;

    // Four-point stencil.
    if coord + 3 * a < n_axis && in_mask(p2) && in_mask(p3) {
        return 2.0 * f[l] - 5.0 * f[p1] + 4.0 * f[p2] - f[p3];
    }

    // Three-point stencil.
    if coord + 2 * a < n_axis && in_mask(p2) {
        return f[l] - 2.0 * f[p1] + f[p2];
    }

    // Fallback: plain forward difference.
    f[p1] - f[l]
}
631
/// One-sided (backward) second-difference of `f` at `l` along an axis with
/// linear stride `a`; mirror image of `lap1_forward`.
///
/// The wrapping subtraction on `coord` turns "fewer than `k` samples behind"
/// into a huge value that fails the `< n_axis` test, so the mask is only
/// indexed at positions that exist. `_n_total` is unused.
#[inline]
fn lap1_backward(
    f: &[f64],
    mask: &[u8],
    l: usize,
    a: usize,
    n_axis: usize,
    coord: usize,
    _n_total: usize,
) -> f64 {
    let reaches_back = |steps: usize| coord.wrapping_sub(steps * a) < n_axis;

    // Four-point stencil.
    if reaches_back(3) && mask[l - 3 * a] != 0 && mask[l - 2 * a] != 0 {
        let (m3, m2, m1) = (l - 3 * a, l - 2 * a, l - a);
        return -f[m3] + 4.0 * f[m2] - 5.0 * f[m1] + 2.0 * f[l];
    }

    // Three-point stencil.
    if reaches_back(2) && mask[l - 2 * a] != 0 {
        let (m2, m1) = (l - 2 * a, l - a);
        return f[m2] - 2.0 * f[m1] + f[l];
    }

    // Fallback: plain difference towards the previous sample.
    f[l - a] - f[l]
}
656
657fn laplacian_weights_ilsqr(
661 f: &[f64],
662 mask: &[u8],
663 nx: usize, ny: usize, nz: usize,
664 vsx: f64, vsy: f64, vsz: f64,
665 pmin: f64,
666 pmax: f64,
667) -> Vec<f64> {
668 let n_total = nx * ny * nz;
669 let mut w = vec![0.0; n_total];
670
671 let lap = compute_laplacian(f, mask, nx, ny, nz, vsx, vsy, vsz);
673
674 let mut masked_lap: Vec<f64> = lap.iter()
676 .zip(mask.iter())
677 .filter(|(_, &m)| m > 0)
678 .map(|(&l, _)| l)
679 .collect();
680
681 if masked_lap.is_empty() {
682 return w;
683 }
684
685 masked_lap.sort_by(|a, b| a.partial_cmp(b).unwrap());
687
688 let thr_min = prctile(&masked_lap, pmin);
689 let thr_max = prctile(&masked_lap, pmax);
690
691 let range = thr_max - thr_min;
692
693 for i in 0..n_total {
695 if mask[i] == 0 {
696 continue;
697 }
698
699 let l = lap[i];
700
701 if l < thr_min {
702 w[i] = 1.0;
703 } else if l > thr_max {
704 w[i] = 0.0;
705 } else if range > 1e-10 {
706 w[i] = (thr_max - l) / range;
707 }
708 }
709
710 w
711}
712
713fn dipole_kspace_weights_ilsqr(
717 d: &[f64],
718 n_exp: f64,
719 pa: f64,
720 pb: f64,
721) -> Vec<f64> {
722 let len = d.len();
723 let mut w = vec![0.0; len];
724
725 for i in 0..len {
727 w[i] = d[i].abs().powf(n_exp);
728 }
729
730 let mut vals: Vec<f64> = w.to_vec();
732 vals.sort_by(|a, b| a.partial_cmp(b).unwrap());
733
734 if vals.is_empty() {
735 return vec![0.0; len];
736 }
737
738 let ab_min = prctile(&vals, pa);
739 let ab_max = prctile(&vals, pb);
740
741 let range = ab_max - ab_min;
742
743 for i in 0..len {
745 if range > 1e-20 {
746 w[i] = (w[i] - ab_min) / range;
747 }
748 w[i] = w[i].max(0.0).min(1.0);
749 }
750
751 w
752}
753
/// Mask-aware first differences of `f` along x, y and z.
///
/// For every in-mask voxel and every axis: a forward difference is used when
/// the next voxel exists and is in the mask; otherwise a backward difference
/// when the previous voxel exists and is in the mask; otherwise 0. Differences
/// are divided by the voxel size of the axis. Out-of-mask voxels stay 0.
///
/// Returns `(dx, dy, dz)`, each of length `nx * ny * nz`.
fn fgrad_masked(
    f: &[f64],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let plane = nx * ny;
    let volume = plane * nz;

    let mut dx = vec![0.0; volume];
    let mut dy = vec![0.0; volume];
    let mut dz = vec![0.0; volume];

    let (inv_x, inv_y, inv_z) = (1.0 / vsx, 1.0 / vsy, 1.0 / vsz);

    // Forward difference if possible, else backward, else zero.
    let one_sided = |l: usize, step: usize, has_next: bool, has_prev: bool, scale: f64| -> f64 {
        if has_next && mask[l + step] != 0 {
            scale * (f[l + step] - f[l])
        } else if has_prev && mask[l - step] != 0 {
            scale * (f[l] - f[l - step])
        } else {
            0.0
        }
    };

    for k in 0..nz {
        for j in 0..ny {
            let base = j * nx + k * plane;
            for i in 0..nx {
                let l = base + i;
                if mask[l] == 0 {
                    continue;
                }

                dx[l] = one_sided(l, 1, i + 1 < nx, i > 0, inv_x);
                dy[l] = one_sided(l, nx, j + 1 < ny, j > 0, inv_y);
                dz[l] = one_sided(l, plane, k + 1 < nz, k > 0, inv_z);
            }
        }
    }

    (dx, dy, dz)
}
815
816fn gradient_weights_ilsqr(
818 x: &[f64],
819 mask: &[u8],
820 nx: usize, ny: usize, nz: usize,
821 vsx: f64, vsy: f64, vsz: f64,
822 pmin: f64,
823 pmax: f64,
824) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
825 let (gx, gy, gz) = fgrad_masked(x, mask, nx, ny, nz, vsx, vsy, vsz);
827
828 let wx = gradient_weights_component(&gx, mask, pmin, pmax);
830 let wy = gradient_weights_component(&gy, mask, pmin, pmax);
831 let wz = gradient_weights_component(&gz, mask, pmin, pmax);
832
833 (wx, wy, wz)
834}
835
836fn gradient_weights_component(
837 g: &[f64],
838 mask: &[u8],
839 pmin: f64,
840 pmax: f64,
841) -> Vec<f64> {
842 let len = g.len();
843 let mut w = vec![0.0; len];
844
845 let mut masked_g: Vec<f64> = g.iter()
847 .zip(mask.iter())
848 .filter(|(_, &m)| m > 0)
849 .map(|(&v, _)| v)
850 .collect();
851
852 if masked_g.is_empty() {
853 return w;
854 }
855
856 masked_g.sort_by(|a, b| a.partial_cmp(b).unwrap());
857
858 let thr_min = prctile(&masked_g, pmin);
859 let thr_max = prctile(&masked_g, pmax);
860
861 let range = thr_max - thr_min;
862
863 for i in 0..len {
864 if mask[i] == 0 {
865 continue;
866 }
867
868 let v = g[i];
869
870 if v < thr_min {
871 w[i] = 1.0;
872 } else if v > thr_max {
873 w[i] = 0.0;
874 } else if range > 1e-10 {
875 w[i] = (thr_max - v) / range;
876 }
877
878 w[i] *= mask[i] as f64;
880 }
881
882 w
883}
884
/// Euclidean (L2) norm of a real vector.
fn norm(x: &[f64]) -> f64 {
    let sum_sq: f64 = x.iter().map(|v| v * v).sum();
    sum_sq.sqrt()
}
892
/// Multiplies every entry of `x` by `s`, in place.
fn scale_inplace(x: &mut [f64], s: f64) {
    x.iter_mut().for_each(|entry| *entry *= s);
}
898
/// In-place update `y += a * x`.
///
/// Only the overlapping prefix of the two slices is touched.
fn axpy(y: &mut [f64], a: f64, x: &[f64]) {
    let shared = y.len().min(x.len());
    for idx in 0..shared {
        y[idx] += a * x[idx];
    }
}
904
/// Element-wise product of two slices; the result has the length of the
/// shorter input.
fn multiply_elementwise(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut product = Vec::with_capacity(a.len().min(b.len()));
    for (&lhs, &rhs) in a.iter().zip(b.iter()) {
        product.push(lhs * rhs);
    }
    product
}
908
/// Element-wise sign: `1.0` for positive, `-1.0` for negative, and `0.0` for
/// everything else (both zeros and NaN).
fn sign_array(x: &[f64]) -> Vec<f64> {
    let mut signs = Vec::with_capacity(x.len());
    for value in x {
        let sign = match value.partial_cmp(&0.0) {
            Some(std::cmp::Ordering::Greater) => 1.0,
            Some(std::cmp::Ordering::Less) => -1.0,
            _ => 0.0,
        };
        signs.push(sign);
    }
    signs
}
916
/// Linearly interpolated percentile of an ascending-sorted slice.
///
/// `p` is a percentage: 0 maps to the first element, 100 to the last, and
/// values in between interpolate linearly over the index range `0..=n-1`.
/// Percentages outside `[0, 100]` are clamped to that range, so an
/// out-of-range `p` can neither index past the end of the slice nor
/// extrapolate below the minimum.
///
/// Returns `0.0` for an empty slice and the single element for a slice of
/// length one. The caller must pass sorted data; this is not checked.
fn prctile(sorted: &[f64], p: f64) -> f64 {
    match sorted {
        [] => return 0.0,
        [only] => return *only,
        _ => {}
    }

    let last = sorted.len() - 1;

    // Fractional index of the requested percentile.
    let pct = p.clamp(0.0, 100.0);
    let h = (pct / 100.0) * last as f64;

    // `lo` is capped as well so the lookup stays in bounds even if rounding
    // were to push `h` past `last`.
    let lo = (h.floor() as usize).min(last);
    let hi = (lo + 1).min(last);
    let frac = h - lo as f64;

    sorted[lo] + frac * (sorted[hi] - sorted[lo])
}
930
931fn lsqr_step(
937 f: &[f64],
938 mask: &[u8],
939 d: &[f64],
940 nx: usize, ny: usize, nz: usize,
941 vsx: f64, vsy: f64, vsz: f64,
942 workspace: &mut Fft3dWorkspace,
943) -> Vec<f64> {
944
945 let pmin = 60.0;
947 let pmax = 99.9;
948 let tol_lsqr = 0.01;
949 let maxit_lsqr = 50;
950
951 let w = laplacian_weights_ilsqr(f, mask, nx, ny, nz, vsx, vsy, vsz, pmin, pmax);
953
954 let wf: Vec<Complex64> = w.iter().zip(f.iter())
956 .map(|(&wi, &fi)| Complex64::new(wi * fi, 0.0))
957 .collect();
958
959 let mut wf_fft = wf.clone();
960 workspace.fft3d(&mut wf_fft);
961
962 let b: Vec<Complex64> = wf_fft.iter().zip(d.iter())
964 .map(|(wfi, &di)| wfi * di)
965 .collect();
966
967 let lsqr_ws = RefCell::new(Fft3dWorkspace::new(nx, ny, nz));
971 let apply_a = |x: &[Complex64]| -> Vec<Complex64> {
972 let dx: Vec<Complex64> = x.iter().zip(d.iter())
974 .map(|(xi, &di)| xi * di)
975 .collect();
976
977 let mut dx_ifft = dx.clone();
979 let mut ws = lsqr_ws.borrow_mut();
980 ws.ifft3d(&mut dx_ifft);
981
982 let wdx: Vec<Complex64> = w.iter().zip(dx_ifft.iter())
984 .map(|(&wi, dxi)| Complex64::new(wi * dxi.re, 0.0))
985 .collect();
986
987 let mut wdx_fft = wdx.clone();
989 ws.fft3d(&mut wdx_fft);
990
991 wdx_fft.iter().zip(d.iter())
993 .map(|(wdxi, &di)| wdxi * di)
994 .collect()
995 };
996
997 let apply_ah = |x: &[Complex64]| -> Vec<Complex64> {
999 apply_a(x)
1000 };
1001
1002 let x_lsqr = lsqr_complex(apply_a, apply_ah, &b, tol_lsqr, maxit_lsqr, false);
1004
1005 let mut x_ifft = x_lsqr;
1007 workspace.ifft3d(&mut x_ifft);
1008
1009 x_ifft.iter().zip(mask.iter())
1011 .map(|(xi, &mi)| if mi > 0 { xi.re } else { 0.0 })
1012 .collect()
1013}
1014
/// Second stage of the pipeline: a direct (non-iterative) estimate `xfs`.
///
/// Steps, as performed below:
/// 1. multiply the FFT of `f` by the sign of the kernel `d`;
/// 2. blend that spectrum with its product with the real part of the FFT of
///    an `smv_kernel` (radius argument 3.0), using weights derived from `d`;
///    the blend is applied twice, with a mask-and-real-part step in image
///    space in between;
/// 3. compute a thresholded-inverse solution `x_tkd` from the same field;
/// 4. fit `x_tkd ≈ a * x_fs + b` by least squares over all voxels and return
///    `a * x_fs + b` inside the mask (0 outside).
///
/// `vsx`, `vsy`, `vsz` are only forwarded to `smv_kernel`.
fn fastqsm_step(
    f: &[f64],
    mask: &[u8],
    d: &[f64],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
    workspace: &mut Fft3dWorkspace,
) -> Vec<f64> {
    let n_total = nx * ny * nz;

    // FFT of the input field.
    let f_complex: Vec<Complex64> = f.iter()
        .map(|&v| Complex64::new(v, 0.0))
        .collect();

    let mut f_fft = f_complex;
    workspace.fft3d(&mut f_fft);

    // sign(d) * FFT(f); sign_array maps zero entries of d to 0.
    let sign_d = sign_array(d);
    let x: Vec<Complex64> = f_fft.iter().zip(sign_d.iter())
        .map(|(fi, &si)| fi * si)
        .collect();

    // Blend weights from |d|^n_exp, rescaled between its 1st and 30th percentiles.
    let pa = 1.0;
    let pb = 30.0;
    let n_exp = 0.001;
    let wfs = dipole_kspace_weights_ilsqr(d, n_exp, pa, pb);

    // smv_kernel is defined elsewhere; presumably a spherical-mean-value
    // kernel with radius r_smv — confirm units against its definition.
    let r_smv = 3.0;
    let h = smv_kernel(nx, ny, nz, vsx, vsy, vsz, r_smv);

    // Only the real part of the kernel's FFT is kept.
    let h_complex: Vec<Complex64> = h.iter()
        .map(|&v| Complex64::new(v, 0.0))
        .collect();
    let mut h_fft_complex = h_complex;
    workspace.fft3d(&mut h_fft_complex);
    let h_fft: Vec<f64> = h_fft_complex.iter().map(|c| c.re).collect();

    // First blend: w * x + (1 - w) * h * x.
    let mut x_filtered: Vec<Complex64> = x.iter()
        .zip(wfs.iter())
        .zip(h_fft.iter())
        .map(|((xi, &wi), &hi)| {
            xi * wi + xi * hi * (1.0 - wi)
        })
        .collect();

    workspace.ifft3d(&mut x_filtered);

    // In image space: zero outside the mask, drop the imaginary part inside.
    for (xi, &mi) in x_filtered.iter_mut().zip(mask.iter()) {
        if mi == 0 {
            *xi = Complex64::new(0.0, 0.0);
        } else {
            *xi = Complex64::new(xi.re, 0.0);
        }
    }

    workspace.fft3d(&mut x_filtered);

    // Second blend, same formula as the first.
    let mut x_filtered2: Vec<Complex64> = x_filtered.iter()
        .zip(wfs.iter())
        .zip(h_fft.iter())
        .map(|((xi, &wi), &hi)| {
            xi * wi + xi * hi * (1.0 - wi)
        })
        .collect();

    workspace.ifft3d(&mut x_filtered2);

    let x_fs: Vec<f64> = x_filtered2.iter().zip(mask.iter())
        .map(|(xi, &mi)| if mi > 0 { xi.re } else { 0.0 })
        .collect();

    // Thresholded inverse of d: entries with |d| < t0 are replaced by
    // sign(d) / t0.
    // NOTE(review): `f64::signum` returns 1.0 for +0.0, so an exactly-zero
    // kernel entry becomes +1/t0 here, whereas `sign_array` above maps zero
    // to 0 — confirm which behaviour is intended.
    let t0 = 1.0 / 8.0;
    let mut inv_d = vec![0.0; n_total];
    for i in 0..n_total {
        if d[i].abs() < t0 {
            inv_d[i] = d[i].signum() / t0;
        } else {
            inv_d[i] = 1.0 / d[i];
        }
    }

    let x_tkd_fft: Vec<Complex64> = f_fft.iter().zip(inv_d.iter())
        .map(|(fi, &idi)| fi * idi)
        .collect();

    let mut x_tkd_complex = x_tkd_fft;
    workspace.ifft3d(&mut x_tkd_complex);

    let x_tkd: Vec<f64> = x_tkd_complex.iter().zip(mask.iter())
        .map(|(xi, &mi)| if mi > 0 { xi.re } else { 0.0 })
        .collect();

    // Sums for the linear fit x_tkd ≈ a * x_fs + b. They run over every
    // voxel, including the out-of-mask zeros, and n_all is the full volume.
    let sum_xfs: f64 = x_fs.iter().map(|&v| v).sum();
    let sum_xtkd: f64 = x_tkd.iter().map(|&v| v).sum();
    let sum_xfs2: f64 = x_fs.iter().map(|&v| v * v).sum();
    let sum_xfs_xtkd: f64 = x_fs.iter().zip(x_tkd.iter())
        .map(|(&xf, &xt)| xf * xt)
        .sum();

    let n_all: f64 = n_total as f64;

    // Determinant of the 2x2 normal equations.
    let det = sum_xfs2 * n_all - sum_xfs * sum_xfs;

    // Fall back to the identity map when the system is (near) singular.
    let (a, b) = if det.abs() > 1e-20 {
        let a = (n_all * sum_xfs_xtkd - sum_xfs * sum_xtkd) / det;
        let b = (sum_xfs2 * sum_xtkd - sum_xfs * sum_xfs_xtkd) / det;
        (a, b)
    } else {
        (1.0, 0.0)
    };

    x_fs.iter().zip(mask.iter())
        .map(|(&xf, &mi)| if mi > 0 { a * xf + b } else { 0.0 })
        .collect()
}
1150
1151fn susceptibility_artifacts_step(
1157 x0: &[f64],
1158 xfs: &[f64],
1159 mask: &[u8],
1160 d: &[f64],
1161 nx: usize, ny: usize, nz: usize,
1162 vsx: f64, vsy: f64, vsz: f64,
1163 tol: f64,
1164 maxit: usize,
1165 _workspace: &mut Fft3dWorkspace,
1166) -> Vec<f64> {
1167 let n_total = nx * ny * nz;
1168
1169 let pmin = 50.0;
1171 let pmax = 70.0;
1172 let (wx, wy, wz) = gradient_weights_ilsqr(xfs, mask, nx, ny, nz, vsx, vsy, vsz, pmin, pmax);
1173
1174 let thr = 0.1;
1176 let mic: Vec<f64> = d.iter().map(|&di| if di.abs() < thr { 1.0 } else { 0.0 }).collect();
1177
1178 let (dx, dy, dz) = fgrad(x0, nx, ny, nz, vsx, vsy, vsz);
1180
1181 let bx = multiply_elementwise(&wx, &dx);
1183 let by = multiply_elementwise(&wy, &dy);
1184 let bz = multiply_elementwise(&wz, &dz);
1185
1186 let mut b = Vec::with_capacity(3 * n_total);
1187 b.extend_from_slice(&bx);
1188 b.extend_from_slice(&by);
1189 b.extend_from_slice(&bz);
1190
1191 let lsmr_ws = RefCell::new(Fft3dWorkspace::new(nx, ny, nz));
1194 let apply_a = |x_in: &[f64]| -> Vec<f64> {
1195 let x_complex: Vec<Complex64> = x_in.iter()
1198 .map(|&v| Complex64::new(v, 0.0))
1199 .collect();
1200
1201 let mut x_fft = x_complex;
1202 let mut ws = lsmr_ws.borrow_mut();
1203 ws.fft3d(&mut x_fft);
1204
1205 let x_mic: Vec<Complex64> = x_fft.iter().zip(mic.iter())
1207 .map(|(xi, &mi)| xi * mi)
1208 .collect();
1209
1210 let mut x_ifft = x_mic;
1211 ws.ifft3d(&mut x_ifft);
1212
1213 let x_filtered: Vec<f64> = x_ifft.iter().map(|xi| xi.re).collect();
1214
1215 let (gx, gy, gz) = fgrad(&x_filtered, nx, ny, nz, vsx, vsy, vsz);
1217
1218 let mut result = Vec::with_capacity(3 * n_total);
1220 result.extend(wx.iter().zip(gx.iter()).map(|(&w, &g)| w * g));
1221 result.extend(wy.iter().zip(gy.iter()).map(|(&w, &g)| w * g));
1222 result.extend(wz.iter().zip(gz.iter()).map(|(&w, &g)| w * g));
1223
1224 result
1225 };
1226
1227 let apply_at = |y_in: &[f64]| -> Vec<f64> {
1229 let yx = &y_in[0..n_total];
1231 let yy = &y_in[n_total..2*n_total];
1232 let yz = &y_in[2*n_total..3*n_total];
1233
1234 let wyx: Vec<f64> = wx.iter().zip(yx.iter()).map(|(&w, &y)| w * y).collect();
1236 let wyy: Vec<f64> = wy.iter().zip(yy.iter()).map(|(&w, &y)| w * y).collect();
1237 let wyz: Vec<f64> = wz.iter().zip(yz.iter()).map(|(&w, &y)| w * y).collect();
1238
1239 let div = bdiv(&wyx, &wyy, &wyz, nx, ny, nz, vsx, vsy, vsz);
1243
1244 let div_complex: Vec<Complex64> = div.iter()
1246 .map(|&v| Complex64::new(-v, 0.0))
1247 .collect();
1248
1249 let mut div_fft = div_complex;
1250 let mut ws = lsmr_ws.borrow_mut();
1251 ws.fft3d(&mut div_fft);
1252
1253 let div_mic: Vec<Complex64> = div_fft.iter().zip(mic.iter())
1254 .map(|(di, &mi)| di * mi)
1255 .collect();
1256
1257 let mut div_ifft = div_mic;
1258 ws.ifft3d(&mut div_ifft);
1259
1260 div_ifft.iter().map(|di| di.re).collect()
1261 };
1262
1263 let xsa = lsmr(apply_a, apply_at, &b, n_total, tol, tol, maxit, false);
1265
1266 xsa.iter().zip(mask.iter())
1268 .map(|(&x, &m)| if m > 0 { x } else { 0.0 })
1269 .collect()
1270}
1271
1272pub fn ilsqr(
1290 field: &[f64],
1291 mask: &[u8],
1292 nx: usize, ny: usize, nz: usize,
1293 vsx: f64, vsy: f64, vsz: f64,
1294 bdir: (f64, f64, f64),
1295 tol: f64,
1296 maxit: usize,
1297) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
1298 let d = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
1300
1301 let mut workspace = Fft3dWorkspace::new(nx, ny, nz);
1303
1304 let xlsqr = lsqr_step(field, mask, &d, nx, ny, nz, vsx, vsy, vsz, &mut workspace);
1306
1307 let xfs = fastqsm_step(field, mask, &d, nx, ny, nz, vsx, vsy, vsz, &mut workspace);
1309
1310 let xsa = susceptibility_artifacts_step(
1312 &xlsqr, &xfs, mask, &d,
1313 nx, ny, nz, vsx, vsy, vsz,
1314 tol, maxit, &mut workspace
1315 );
1316
1317 let chi: Vec<f64> = xlsqr.iter().zip(xsa.iter()).zip(mask.iter())
1319 .map(|((&xl, &xs), &m)| if m > 0 { xl - xs } else { 0.0 })
1320 .collect();
1321
1322 (chi, xsa, xfs, xlsqr)
1323}
1324
1325pub fn ilsqr_simple(
1327 field: &[f64],
1328 mask: &[u8],
1329 nx: usize, ny: usize, nz: usize,
1330 vsx: f64, vsy: f64, vsz: f64,
1331 bdir: (f64, f64, f64),
1332 tol: f64,
1333 maxit: usize,
1334) -> Vec<f64> {
1335 let (chi, _, _, _) = ilsqr(field, mask, nx, ny, nz, vsx, vsy, vsz, bdir, tol, maxit);
1336 chi
1337}
1338
1339pub fn ilsqr_with_progress<F>(
1341 field: &[f64],
1342 mask: &[u8],
1343 nx: usize, ny: usize, nz: usize,
1344 vsx: f64, vsy: f64, vsz: f64,
1345 bdir: (f64, f64, f64),
1346 tol: f64,
1347 maxit: usize,
1348 mut progress_callback: F,
1349) -> Vec<f64>
1350where
1351 F: FnMut(usize, usize),
1352{
1353 let d = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
1355
1356 let mut workspace = Fft3dWorkspace::new(nx, ny, nz);
1358
1359 progress_callback(1, 4);
1360
1361 let xlsqr = lsqr_step(field, mask, &d, nx, ny, nz, vsx, vsy, vsz, &mut workspace);
1363
1364 progress_callback(2, 4);
1365
1366 let xfs = fastqsm_step(field, mask, &d, nx, ny, nz, vsx, vsy, vsz, &mut workspace);
1368
1369 progress_callback(3, 4);
1370
1371 let xsa = susceptibility_artifacts_step(
1373 &xlsqr, &xfs, mask, &d,
1374 nx, ny, nz, vsx, vsy, vsz,
1375 tol, maxit, &mut workspace
1376 );
1377
1378 progress_callback(4, 4);
1379
1380 xlsqr.iter().zip(xsa.iter()).zip(mask.iter())
1382 .map(|((&xl, &xs), &m)| if m > 0 { xl - xs } else { 0.0 })
1383 .collect()
1384}
1385
/// Unit tests for the solvers, the weighting helpers and the full pipeline.
#[cfg(test)]
mod tests {
    use super::*;

    /// `lsqr` recovers x = 2 for the diagonal system diag(1..=10) * x = 2 * diag.
    #[test]
    fn test_lsqr_simple() {
        let n = 10;
        let diag: Vec<f64> = (1..=n).map(|i| i as f64).collect();
        let b: Vec<f64> = diag.iter().map(|&d| d * 2.0).collect();
        // A diagonal operator is its own transpose, so the same closure is
        // passed for both arguments.
        let apply_a = |x: &[f64]| -> Vec<f64> {
            x.iter().zip(diag.iter()).map(|(&xi, &di)| xi * di).collect()
        };

        let x = lsqr(apply_a, apply_a, &b, 1e-10, 100);

        for (i, &xi) in x.iter().enumerate() {
            assert!((xi - 2.0).abs() < 1e-6, "x[{}] = {}, expected 2.0", i, xi);
        }
    }

    /// `norm` of (3, 4) is 5.
    #[test]
    fn test_norm() {
        let x = vec![3.0, 4.0];
        assert!((norm(&x) - 5.0).abs() < 1e-10);
    }

    /// `sign_array` maps negative, zero and positive values to -1, 0 and 1.
    #[test]
    fn test_sign_array() {
        let x = vec![-2.0, 0.0, 3.0];
        let s = sign_array(&x);
        assert_eq!(s, vec![-1.0, 0.0, 1.0]);
    }

    /// `lsqr_complex` recovers a complex solution of a real diagonal system.
    #[test]
    fn test_lsqr_complex_diagonal() {
        let diag = vec![1.0, 2.0, 3.0];
        let expected = vec![
            Complex64::new(1.0, 1.0),
            Complex64::new(2.0, 1.0),
            Complex64::new(3.0, 1.0),
        ];
        // Right-hand side b = diag * expected.
        let b: Vec<Complex64> = expected.iter().zip(diag.iter())
            .map(|(&xi, &di)| xi * di)
            .collect();

        // Each `move` closure needs its own copy of the diagonal.
        let diag_a = diag.clone();
        let diag_ah = diag.clone();
        let apply_a = move |x: &[Complex64]| -> Vec<Complex64> {
            x.iter().zip(diag_a.iter()).map(|(&xi, &di)| xi * di).collect()
        };
        let apply_ah = move |x: &[Complex64]| -> Vec<Complex64> {
            x.iter().zip(diag_ah.iter()).map(|(&xi, &di)| xi * di).collect()
        };

        let x = lsqr_complex(apply_a, apply_ah, &b, 1e-10, 100, false);

        for (i, (xi, ei)) in x.iter().zip(expected.iter()).enumerate() {
            assert!((xi.re - ei.re).abs() < 1e-6,
                "x[{}].re = {}, expected {}", i, xi.re, ei.re);
            assert!((xi.im - ei.im).abs() < 1e-6,
                "x[{}].im = {}, expected {}", i, xi.im, ei.im);
        }
    }

    /// `lsmr` with the identity operator returns a finite vector of the right
    /// length whose residual is smaller than ||b||.
    #[test]
    fn test_lsmr_diagonal() {
        let b = vec![3.0, 5.0, 7.0];

        // Identity operator and its transpose.
        let apply_a = |x: &[f64]| -> Vec<f64> { x.to_vec() };
        let apply_at = |x: &[f64]| -> Vec<f64> { x.to_vec() };

        let x = lsmr(apply_a, apply_at, &b, 3, 1e-6, 1e-6, 200, false);

        assert_eq!(x.len(), 3);
        for (i, &xi) in x.iter().enumerate() {
            assert!(xi.is_finite(), "x[{}] = {} is not finite", i, xi);
        }

        // With A = I the residual is simply x - b.
        let residual: f64 = x.iter().zip(b.iter())
            .map(|(&xi, &bi)| (xi - bi).powi(2))
            .sum::<f64>()
            .sqrt();
        let bnorm: f64 = b.iter().map(|&bi| bi * bi).sum::<f64>().sqrt();
        assert!(residual < bnorm,
            "residual {} should be less than ||b|| = {}", residual, bnorm);
    }

    /// Laplacian weights are finite, lie in [0, 1], and vanish outside the mask.
    #[test]
    fn test_laplacian_weights() {
        let (nx, ny, nz) = (4, 4, 4);
        let n_total = nx * ny * nz;
        let mut mask = vec![0u8; n_total];
        let mut field = vec![0.0; n_total];

        // Small ball around the grid centre with a constant field inside.
        for k in 0..nz {
            for j in 0..ny {
                for i in 0..nx {
                    let idx = i + j * nx + k * nx * ny;
                    let ci = i as f64 - 1.5;
                    let cj = j as f64 - 1.5;
                    let ck = k as f64 - 1.5;
                    let r2 = ci * ci + cj * cj + ck * ck;
                    if r2 < 2.5 {
                        mask[idx] = 1;
                        field[idx] = 5.0;
                    }
                }
            }
        }

        let w = laplacian_weights_ilsqr(&field, &mask, nx, ny, nz, 1.0, 1.0, 1.0, 10.0, 90.0);

        for (i, &wi) in w.iter().enumerate() {
            assert!(wi.is_finite(), "weight[{}] is not finite", i);
            assert!(wi >= 0.0 && wi <= 1.0, "weight[{}] = {} out of [0,1]", i, wi);
        }

        for i in 0..n_total {
            if mask[i] == 0 {
                assert_eq!(w[i], 0.0, "weight outside mask should be 0 at index {}", i);
            }
        }
    }

    /// Dipole k-space weights lie in [0, 1], are ~0 where D = 0 and large
    /// where |D| = 1.
    #[test]
    fn test_dipole_kspace_weights() {
        let d = vec![0.0, 0.01, 0.1, 0.3, 0.5, 0.7, 1.0, -0.5, -1.0, 0.0];

        let w = dipole_kspace_weights_ilsqr(&d, 1.0, 1.0, 90.0);

        for (i, &wi) in w.iter().enumerate() {
            assert!(wi >= 0.0 && wi <= 1.0,
                "weight[{}] = {} out of [0,1]", i, wi);
        }

        assert!(w[0] <= 1e-10, "weight at D=0 should be ~0, got {}", w[0]);

        // Indices 6 and 8 hold D = 1.0 and D = -1.0.
        assert!(w[6] > 0.5, "weight at |D|=1.0 should be large, got {}", w[6]);
        assert!(w[8] > 0.5, "weight at |D|=1.0 should be large, got {}", w[8]);
    }

    /// Gradient weights of a linear ramp along x are finite and lie in [0, 1].
    #[test]
    fn test_gradient_weights() {
        let (nx, ny, nz) = (4, 4, 4);
        let n_total = nx * ny * nz;

        // Full mask.
        let mask = vec![1u8; n_total];

        // Field equal to the x index.
        let mut field = vec![0.0; n_total];
        for k in 0..nz {
            for j in 0..ny {
                for i in 0..nx {
                    let idx = i + j * nx + k * nx * ny;
                    field[idx] = i as f64;
                }
            }
        }

        let (wx, wy, wz) = gradient_weights_ilsqr(
            &field, &mask, nx, ny, nz, 1.0, 1.0, 1.0, 10.0, 90.0
        );

        for i in 0..n_total {
            assert!(wx[i].is_finite(), "wx[{}] is not finite", i);
            assert!(wy[i].is_finite(), "wy[{}] is not finite", i);
            assert!(wz[i].is_finite(), "wz[{}] is not finite", i);
            assert!(wx[i] >= 0.0 && wx[i] <= 1.0, "wx[{}] = {} out of [0,1]", i, wx[i]);
            assert!(wy[i] >= 0.0 && wy[i] <= 1.0, "wy[{}] = {} out of [0,1]", i, wy[i]);
            assert!(wz[i] >= 0.0 && wz[i] <= 1.0, "wz[{}] = {} out of [0,1]", i, wz[i]);
        }

        // The y and z gradients of this field are zero everywhere; only
        // finiteness of the resulting weights is checked.
        let wy_sum: f64 = wy.iter().sum();
        let wz_sum: f64 = wz.iter().sum();
        assert!(wy_sum.is_finite(), "wy sum is not finite");
        assert!(wz_sum.is_finite(), "wz sum is not finite");
    }

    /// End-to-end smoke test of `ilsqr_simple` on an 8x8x8 volume: output
    /// size, finiteness, zeros outside the mask, and a nonzero result inside.
    #[test]
    fn test_ilsqr_small() {
        let (nx, ny, nz) = (8, 8, 8);
        let n_total = nx * ny * nz;
        let vsx = 1.0;
        let vsy = 1.0;
        let vsz = 1.0;
        let bdir = (0.0, 0.0, 1.0);

        // Spherical mask of radius 3 around the grid centre.
        let mut mask = vec![0u8; n_total];
        let cx = (nx as f64 - 1.0) / 2.0;
        let cy = (ny as f64 - 1.0) / 2.0;
        let cz = (nz as f64 - 1.0) / 2.0;
        let radius = 3.0;

        for k in 0..nz {
            for j in 0..ny {
                for i in 0..nx {
                    let idx = i + j * nx + k * nx * ny;
                    let di = i as f64 - cx;
                    let dj = j as f64 - cy;
                    let dk = k as f64 - cz;
                    if di * di + dj * dj + dk * dk < radius * radius {
                        mask[idx] = 1;
                    }
                }
            }
        }

        // Smooth synthetic field inside the mask.
        let mut field = vec![0.0; n_total];
        for k in 0..nz {
            for j in 0..ny {
                for i in 0..nx {
                    let idx = i + j * nx + k * nx * ny;
                    if mask[idx] > 0 {
                        let di = i as f64 - cx;
                        let dj = j as f64 - cy;
                        let dk = k as f64 - cz;
                        field[idx] = 0.01 * (dk * dk - di * di - dj * dj)
                            / (di * di + dj * dj + dk * dk + 1.0);
                    }
                }
            }
        }

        // Loose tolerance and few iterations keep the test fast.
        let tol = 0.1;
        let maxit = 5;
        let chi = ilsqr_simple(&field, &mask, nx, ny, nz, vsx, vsy, vsz, bdir, tol, maxit);

        assert_eq!(chi.len(), n_total, "output size mismatch");

        for (i, &v) in chi.iter().enumerate() {
            assert!(v.is_finite(), "chi[{}] = {} is not finite", i, v);
        }

        for i in 0..n_total {
            if mask[i] == 0 {
                assert_eq!(chi[i], 0.0, "chi outside mask should be 0 at index {}", i);
            }
        }

        let inside_sum: f64 = chi.iter().zip(mask.iter())
            .filter(|(_, &m)| m > 0)
            .map(|(&v, _)| v.abs())
            .sum();
        assert!(inside_sum > 0.0, "chi should not be all zeros inside the mask");
    }
}