1use num_complex::Complex32;
32use crate::fft::Fft3dWorkspaceF32;
33use crate::kernels::dipole::dipole_kernel_f32;
34use crate::kernels::smv::smv_kernel_f32;
35use crate::utils::simd_ops::{
36 dot_product_f32, norm_squared_f32, axpy_f32, xpby_f32,
37 apply_gradient_weights_f32, compute_p_weights_f32, combine_terms_f32, negate_f32,
38};
#[derive(Clone, Debug)]
/// Tuning parameters for the MEDI-L1 reconstruction (`medi_l1`).
///
/// Field names mirror the parameters of [`medi_l1`] / [`medi_l1_with_progress`].
pub struct MediParams {
    /// Regularization weight balancing the data-fidelity term against the
    /// weighted-gradient term (see `combine_terms_f32` usage in the solver).
    pub lambda: f64,
    /// Enable MERIT-style dynamic re-weighting of the data term between
    /// outer iterations.
    pub merit: bool,
    /// Apply spherical-mean-value (SMV) preprocessing to the field and mask.
    pub smv: bool,
    /// Radius (in the same units as the voxel size) of the SMV kernel.
    pub smv_radius: f64,
    /// Data-term weighting mode: 0 = uniform inside the mask, otherwise
    /// SNR weighting (1 / n_std, normalized) — see `dataterm_mask_f32`.
    pub data_weighting: i32,
    /// Fraction of in-mask gradient magnitudes treated as edges when
    /// building the gradient mask (see `gradient_mask_f32`).
    pub percentage: f64,
    /// Relative residual tolerance for the inner conjugate-gradient solve.
    pub cg_tol: f64,
    /// Maximum inner CG iterations per outer iteration.
    pub cg_max_iter: usize,
    /// Maximum outer (Gauss-Newton-style) iterations.
    pub max_iter: usize,
    /// Relative-change stopping tolerance for the outer loop.
    pub tol: f64,
}
68
impl Default for MediParams {
    /// Literature-style defaults; note `smv` defaults to `true` here.
    fn default() -> Self {
        Self {
            lambda: 7.5e-5,
            merit: false,
            smv: true,
            smv_radius: 5.0,
            data_weighting: 1,
            percentage: 0.3,
            cg_tol: 0.01,
            cg_max_iter: 10,
            max_iter: 30,
            tol: 0.1,
        }
    }
}
85
/// Preallocated scratch buffers shared by all stages of the MEDI solver.
///
/// Every buffer has length `n_total = nx * ny * nz`; reusing them avoids
/// per-iteration allocation in the inner CG loop.
pub struct MediWorkspace {
    /// Total voxel count (`nx * ny * nz`).
    pub n_total: usize,
    pub nx: usize,
    pub ny: usize,
    pub nz: usize,
    /// Voxel sizes along x/y/z, used to scale finite differences.
    pub vsx: f32,
    pub vsy: f32,
    pub vsz: f32,

    /// FFT plan/workspace for 3-D transforms of this volume size.
    pub fft_ws: Fft3dWorkspaceF32,

    // Forward-gradient outputs (one buffer per axis).
    pub gx: Vec<f32>,
    pub gy: Vec<f32>,
    pub gz: Vec<f32>,

    // Edge-weighted gradient components (inputs to the divergence).
    pub reg_x: Vec<f32>,
    pub reg_y: Vec<f32>,
    pub reg_z: Vec<f32>,

    // Backward-divergence output.
    pub div_buf: Vec<f32>,

    // Complex scratch for FFT-based convolutions.
    pub complex_buf: Vec<Complex32>,
    pub complex_buf2: Vec<Complex32>,

    // Real part of dipole-convolution results.
    pub dipole_buf: Vec<f32>,

    // Conjugate-gradient state: residual, search direction, A*p.
    pub cg_r: Vec<f32>,
    pub cg_p: Vec<f32>,
    pub cg_ap: Vec<f32>,
}
125
126impl MediWorkspace {
127 pub fn new(nx: usize, ny: usize, nz: usize, vsx: f32, vsy: f32, vsz: f32) -> Self {
129 let n_total = nx * ny * nz;
130
131 Self {
132 n_total,
133 nx, ny, nz,
134 vsx, vsy, vsz,
135 fft_ws: Fft3dWorkspaceF32::new(nx, ny, nz),
136 gx: vec![0.0; n_total],
137 gy: vec![0.0; n_total],
138 gz: vec![0.0; n_total],
139 reg_x: vec![0.0; n_total],
140 reg_y: vec![0.0; n_total],
141 reg_z: vec![0.0; n_total],
142 div_buf: vec![0.0; n_total],
143 complex_buf: vec![Complex32::new(0.0, 0.0); n_total],
144 complex_buf2: vec![Complex32::new(0.0, 0.0); n_total],
145 dipole_buf: vec![0.0; n_total],
146 cg_r: vec![0.0; n_total],
147 cg_p: vec![0.0; n_total],
148 cg_ap: vec![0.0; n_total],
149 }
150 }
151}
152
#[inline]
/// Convolve `x` with the dipole kernel `d_kernel` (frequency-domain
/// multiplication), writing the real result to `out`.
///
/// Thin forwarding wrapper around the workspace's fused implementation;
/// `complex_buf` is the FFT scratch buffer.
fn apply_dipole_conv(
    fft_ws: &mut Fft3dWorkspaceF32,
    x: &[f32],
    d_kernel: &[f32],
    out: &mut [f32],
    complex_buf: &mut [Complex32],
) {
    fft_ws.apply_dipole_inplace(x, d_kernel, out, complex_buf);
}
164
/// Mutable views into the `MediWorkspace` buffers needed by
/// `apply_medi_operator_core`.
///
/// Bundling them lets the operator borrow several workspace fields at once
/// while `ws.fft_ws` is borrowed separately (split-borrow pattern).
struct MediOpBuffers<'a> {
    gx: &'a mut [f32],
    gy: &'a mut [f32],
    gz: &'a mut [f32],
    reg_x: &'a mut [f32],
    reg_y: &'a mut [f32],
    reg_z: &'a mut [f32],
    div_buf: &'a mut [f32],
    dipole_buf: &'a mut [f32],
    complex_buf: &'a mut [Complex32],
    complex_buf2: &'a mut [Complex32],
}
178
#[inline]
/// Apply the Gauss-Newton system operator to `dx`, writing the result to `out`.
///
/// Computes two terms and combines them via `combine_terms_f32` with weight
/// `lambda`:
/// 1. regularization: backward-divergence of the edge-weighted forward
///    gradient of `dx` (weights `mx`/`my`/`mz`, per-voxel scale `vr`);
/// 2. data term: D^T |w|^2 D dx, where D is the dipole convolution and `w`
///    the current complex data weights.
///
/// Buffer order matters: `bufs.dipole_buf` is first the dipole-conv output,
/// then reused for the real part of the second FFT pass.
fn apply_medi_operator_core(
    fft_ws: &mut Fft3dWorkspaceF32,
    bufs: &mut MediOpBuffers,
    n: usize,
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
    dx: &[f32],
    w: &[Complex32],
    d_kernel: &[f32],
    mx: &[f32], my: &[f32], mz: &[f32], vr: &[f32],
    lambda: f32,
    out: &mut [f32],
) {
    // Forward gradient of the update direction (periodic boundaries).
    fgrad_periodic_inplace_f32(bufs.gx, bufs.gy, bufs.gz, dx, nx, ny, nz, vsx, vsy, vsz);

    // Edge weighting of each gradient component.
    apply_gradient_weights_f32(
        bufs.reg_x, bufs.reg_y, bufs.reg_z,
        mx, my, mz, vr,
        bufs.gx, bufs.gy, bufs.gz,
    );

    // Divergence closes the grad^T W grad regularization term.
    bdiv_periodic_inplace_f32(bufs.div_buf, bufs.reg_x, bufs.reg_y, bufs.reg_z, nx, ny, nz, vsx, vsy, vsz);

    // Data term: D dx ...
    apply_dipole_conv(fft_ws, dx, d_kernel, bufs.dipole_buf, bufs.complex_buf);

    // ... scaled by |w|^2 per voxel ...
    for i in 0..n {
        let w_mag_sq = w[i].norm_sqr();
        bufs.complex_buf2[i] = Complex32::new(bufs.dipole_buf[i] * w_mag_sq, 0.0);
    }

    // ... then convolved with the dipole kernel again (D is self-adjoint in
    // k-space: pointwise multiply by the real kernel).
    fft_ws.fft3d(bufs.complex_buf2);
    for i in 0..n {
        bufs.complex_buf2[i] *= d_kernel[i];
    }
    fft_ws.ifft3d(bufs.complex_buf2);

    // Keep only the real part; imaginary residue is FFT round-off.
    for i in 0..n {
        bufs.dipole_buf[i] = bufs.complex_buf2[i].re;
    }
    combine_terms_f32(out, bufs.div_buf, bufs.dipole_buf, lambda);
}
237
#[inline]
/// Solve `A x = b` with plain conjugate gradient, where `A` is the MEDI
/// Gauss-Newton operator (`apply_medi_operator_core`).
///
/// `x` is zero-initialized here, so the initial residual is simply `b`.
/// Stops when `max_iter` is reached, the residual drops below `tol * |b|`,
/// or `p^T A p` degenerates. `progress_callback(done, total)` is invoked at
/// the top of each iteration.
fn cg_solve_medi<F>(
    ws: &mut MediWorkspace,
    w: &[Complex32],
    d_kernel: &[f32],
    mx: &[f32], my: &[f32], mz: &[f32], vr: &[f32],
    lambda: f32,
    b: &[f32],
    x: &mut [f32],
    tol: f32,
    max_iter: usize,
    mut progress_callback: F,
) where
    F: FnMut(usize, usize),
{
    let n = ws.n_total;
    let (nx, ny, nz) = (ws.nx, ws.ny, ws.nz);
    let (vsx, vsy, vsz) = (ws.vsx, ws.vsy, ws.vsz);

    x.fill(0.0);

    // r = b - A*0 = b; p starts as r.
    ws.cg_r.copy_from_slice(b);

    ws.cg_p.copy_from_slice(&ws.cg_r);

    let mut rsold: f32 = norm_squared_f32(&ws.cg_r);

    // Trivial right-hand side: nothing to solve.
    let b_norm: f32 = norm_squared_f32(b).sqrt();
    if b_norm < 1e-10 {
        return; }

    // Local copy of p: the operator call borrows most of `ws` mutably, so p
    // cannot be passed directly from the workspace at the same time.
    let mut p_copy = vec![0.0f32; n];

    for cg_iter in 0..max_iter {
        progress_callback(cg_iter + 1, max_iter);

        p_copy.copy_from_slice(&ws.cg_p);

        // Scoped so the split borrows of `ws` end before the dot products.
        {
            let mut bufs = MediOpBuffers {
                gx: &mut ws.gx,
                gy: &mut ws.gy,
                gz: &mut ws.gz,
                reg_x: &mut ws.reg_x,
                reg_y: &mut ws.reg_y,
                reg_z: &mut ws.reg_z,
                div_buf: &mut ws.div_buf,
                dipole_buf: &mut ws.dipole_buf,
                complex_buf: &mut ws.complex_buf,
                complex_buf2: &mut ws.complex_buf2,
            };
            apply_medi_operator_core(
                &mut ws.fft_ws, &mut bufs, n, nx, ny, nz, vsx, vsy, vsz,
                &p_copy, w, d_kernel, mx, my, mz, vr, lambda, &mut ws.cg_ap
            );
        }

        let pap: f32 = dot_product_f32(&ws.cg_p, &ws.cg_ap);

        // Degenerate curvature: abort rather than divide by ~0.
        if pap.abs() < 1e-15 {
            break;
        }

        let alpha = rsold / pap;

        // x += alpha * p; r -= alpha * A p.
        axpy_f32(x, alpha, &ws.cg_p);

        axpy_f32(&mut ws.cg_r, -alpha, &ws.cg_ap);

        let rsnew: f32 = norm_squared_f32(&ws.cg_r);
        let residual = rsnew.sqrt();

        // Relative-residual convergence test.
        if residual < tol * b_norm {
            break;
        }

        let beta = rsnew / rsold;

        // p = r + beta * p.
        xpby_f32(&mut ws.cg_p, &ws.cg_r, beta);

        rsold = rsnew;
    }
}
345
346#[allow(clippy::too_many_arguments)]
370pub fn medi_l1(
371 local_field: &[f64],
372 n_std: &[f64],
373 magnitude: &[f64],
374 mask: &[u8],
375 nx: usize, ny: usize, nz: usize,
376 vsx: f64, vsy: f64, vsz: f64,
377 lambda: f64,
378 bdir: (f64, f64, f64),
379 merit: bool,
380 smv: bool,
381 smv_radius: f64,
382 data_weighting: i32,
383 percentage: f64,
384 cg_tol: f64,
385 cg_max_iter: usize,
386 max_iter: usize,
387 tol: f64,
388) -> Vec<f64> {
389 let n_total = nx * ny * nz;
390
391 let vsx_f32 = vsx as f32;
393 let vsy_f32 = vsy as f32;
394 let vsz_f32 = vsz as f32;
395 let lambda_f32 = lambda as f32;
396 let bdir_f32 = (bdir.0 as f32, bdir.1 as f32, bdir.2 as f32);
397 let smv_radius_f32 = smv_radius as f32;
398 let percentage_f32 = percentage as f32;
399 let cg_tol_f32 = cg_tol as f32;
400 let tol_f32 = tol as f32;
401
402 let local_field_f32: Vec<f32> = local_field.iter().map(|&v| v as f32).collect();
404 let n_std_f32: Vec<f32> = n_std.iter().map(|&v| v as f32).collect();
405 let magnitude_f32: Vec<f32> = magnitude.iter().map(|&v| v as f32).collect();
406
407 let mut ws = MediWorkspace::new(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);
409
410 let mut rdf: Vec<f32> = local_field_f32.clone();
412 let mut work_mask: Vec<u8> = mask.to_vec();
413 let mut tempn: Vec<f32> = n_std_f32.clone();
414
415 for i in 0..n_total {
417 if mask[i] == 0 {
418 tempn[i] = 0.0;
419 }
420 }
421
422 let mut d_kernel = dipole_kernel_f32(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, bdir_f32);
424
425 let sphere_k = if smv {
427 let sk = smv_kernel_f32(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, smv_radius_f32);
428
429 let mut sk_fft: Vec<Complex32> = sk.iter()
431 .map(|&v| Complex32::new(v, 0.0))
432 .collect();
433 ws.fft_ws.fft3d(&mut sk_fft);
434
435 let mask_f32: Vec<f32> = work_mask.iter().map(|&m| m as f32).collect();
437 let smv_mask = apply_smv_kernel_ws(&mask_f32, &sk_fft, &mut ws);
438 for i in 0..n_total {
439 work_mask[i] = if smv_mask[i] > 0.999 { 1 } else { 0 };
440 }
441
442 for i in 0..n_total {
444 d_kernel[i] *= 1.0 - sk[i];
445 }
446
447 let smv_rdf = apply_smv_kernel_ws(&rdf, &sk_fft, &mut ws);
449 for i in 0..n_total {
450 rdf[i] -= smv_rdf[i];
451 if work_mask[i] == 0 {
452 rdf[i] = 0.0;
453 }
454 }
455
456 let tempn_sq: Vec<f32> = tempn.iter().map(|&t| t * t).collect();
458 let smv_tempn_sq = apply_smv_kernel_ws(&tempn_sq, &sk_fft, &mut ws);
459 for i in 0..n_total {
460 tempn[i] = (smv_tempn_sq[i] + tempn_sq[i]).sqrt();
461 }
462
463 Some(sk_fft)
464 } else {
465 None
466 };
467
468 let mut m = dataterm_mask_f32(data_weighting, &tempn, &work_mask);
470
471 let mut b0: Vec<Complex32> = rdf.iter()
473 .zip(m.iter())
474 .map(|(&f, &mi)| {
475 let phase = Complex32::new(0.0, f);
476 mi * phase.exp()
477 })
478 .collect();
479
480 let (w_gx, w_gy, w_gz) = gradient_mask_f32(&magnitude_f32, &work_mask, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, percentage_f32);
483
484 let w_gx = if w_gx.iter().any(|&v| v != 0.0) { w_gx } else { magnitude_f32.clone() };
486 let w_gy = if w_gy.iter().any(|&v| v != 0.0) { w_gy } else { magnitude_f32.clone() };
487 let w_gz = if w_gz.iter().any(|&v| v != 0.0) { w_gz } else { magnitude_f32.clone() };
488
489 let mut chi = vec![0.0f32; n_total];
491 let mut dx = vec![0.0f32; n_total]; let mut rhs = vec![0.0f32; n_total]; let mut vr = vec![0.0f32; n_total]; let mut w: Vec<Complex32> = vec![Complex32::new(0.0, 0.0); n_total]; let mut chi_prev = vec![0.0f32; n_total]; let mut badpoint = vec![0.0f32; n_total];
497 let mut n_std_work: Vec<f32> = n_std_f32.clone();
498
499 let beta = 1.49e-8_f32;
503
504 for _iter in 0..max_iter {
506 chi_prev.copy_from_slice(&chi);
508
509 fgrad_periodic_inplace_f32(
514 &mut ws.gx, &mut ws.gy, &mut ws.gz,
515 &chi, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32,
516 );
517
518 compute_p_weights_f32(&mut vr, &w_gx, &w_gy, &w_gz, &ws.gx, &ws.gy, &ws.gz, beta);
519
520 apply_dipole_conv(&mut ws.fft_ws, &chi, &d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
522 for i in 0..n_total {
523 let phase = Complex32::new(0.0, ws.dipole_buf[i]);
524 w[i] = m[i] * phase.exp();
525 }
526
527 compute_rhs_inplace(&chi, &w, &b0, &d_kernel, &w_gx, &w_gy, &w_gz, &vr, lambda_f32, &mut rhs, &mut ws);
529
530 negate_f32(&mut rhs);
532
533 cg_solve_medi(&mut ws, &w, &d_kernel, &w_gx, &w_gy, &w_gz, &vr, lambda_f32, &rhs, &mut dx, cg_tol_f32, cg_max_iter, |_, _| {});
535
536 axpy_f32(&mut chi, 1.0, &dx);
538
539 let norm_dx_sq = norm_squared_f32(&dx);
541 let norm_chi_sq = norm_squared_f32(&chi_prev);
542 let rel_change = norm_dx_sq.sqrt() / (norm_chi_sq.sqrt() + 1e-6);
543
544 if merit {
546 apply_dipole_conv(&mut ws.fft_ws, &chi, &d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
548 let mut wres: Vec<Complex32> = ws.dipole_buf.iter()
549 .zip(m.iter())
550 .zip(b0.iter())
551 .map(|((&dc, &mi), &b0i)| {
552 let phase = Complex32::new(0.0, dc);
553 mi * phase.exp() - b0i
554 })
555 .collect();
556
557 let mask_count = work_mask.iter().filter(|&&m| m != 0).count() as f32;
559 if mask_count > 0.0 {
560 let mean_wres: Complex32 = wres.iter()
561 .zip(work_mask.iter())
562 .filter(|(_, &m)| m != 0)
563 .map(|(w, _)| w)
564 .sum::<Complex32>() / mask_count;
565
566 for i in 0..n_total {
567 if work_mask[i] != 0 {
568 wres[i] -= mean_wres;
569 }
570 }
571 }
572
573 let abs_wres: Vec<f32> = wres.iter()
575 .zip(work_mask.iter())
576 .filter(|(_, &m)| m != 0)
577 .map(|(w, _)| w.norm())
578 .collect();
579
580 if !abs_wres.is_empty() {
581 let mean_abs: f32 = abs_wres.iter().sum::<f32>() / abs_wres.len() as f32;
582 let var: f32 = abs_wres.iter()
583 .map(|&v| (v - mean_abs).powi(2))
584 .sum::<f32>() / abs_wres.len() as f32;
585 let factor = var.sqrt() * 6.0;
586
587 if factor > 1e-10 {
588 let mut wres_norm: Vec<f32> = wres.iter()
590 .map(|w| w.norm() / factor)
591 .collect();
592
593 for v in wres_norm.iter_mut() {
595 if *v < 1.0 {
596 *v = 1.0;
597 }
598 }
599
600 for i in 0..n_total {
602 if wres_norm[i] > 1.0 {
603 badpoint[i] = 1.0;
604 }
605 if work_mask[i] != 0 {
606 n_std_work[i] *= wres_norm[i].powi(2);
607 }
608 }
609
610 tempn = n_std_work.clone();
612 if let Some(ref sk_fft) = sphere_k {
613 let tempn_sq: Vec<f32> = tempn.iter().map(|&t| t * t).collect();
614 let smv_tempn_sq = apply_smv_kernel_ws(&tempn_sq, sk_fft, &mut ws);
615 for i in 0..n_total {
616 tempn[i] = (smv_tempn_sq[i] + tempn_sq[i]).sqrt();
617 }
618 }
619
620 m = dataterm_mask_f32(data_weighting, &tempn, &work_mask);
622 b0 = rdf.iter()
623 .zip(m.iter())
624 .map(|(&f, &mi)| {
625 let phase = Complex32::new(0.0, f);
626 mi * phase.exp()
627 })
628 .collect();
629 }
630 }
631 }
632
633 if rel_change < tol_f32 {
634 break;
635 }
636 }
637
638 let _ = badpoint;
640
641 chi.iter()
643 .zip(mask.iter())
644 .map(|(&c, &m)| if m == 0 { 0.0 } else { c as f64 })
645 .collect()
646}
647
648fn apply_smv_kernel_ws(
650 x: &[f32],
651 sk_fft: &[Complex32],
652 ws: &mut MediWorkspace,
653) -> Vec<f32> {
654 let n_total = ws.n_total;
655
656 for (c, &r) in ws.complex_buf.iter_mut().zip(x.iter()) {
658 *c = Complex32::new(r, 0.0);
659 }
660
661 ws.fft_ws.fft3d(&mut ws.complex_buf);
662
663 for i in 0..n_total {
664 ws.complex_buf[i] *= sk_fft[i];
665 }
666
667 ws.fft_ws.ifft3d(&mut ws.complex_buf);
668
669 ws.complex_buf.iter().map(|c| c.re).collect()
670}
671
/// Build the Gauss-Newton right-hand side (the gradient of the cost at the
/// current `chi`) into `rhs`.
///
/// Combines, via `combine_terms_f32` with weight `lambda`:
/// - the divergence of the edge-weighted gradient of `chi`;
/// - the data-term gradient D^T Re(-i * conj(w) * (w - b0)), evaluated with
///   an FFT pass through `ws.complex_buf2`.
fn compute_rhs_inplace(
    chi: &[f32],
    w: &[Complex32],
    b0: &[Complex32],
    d_kernel: &[f32],
    mx: &[f32], my: &[f32], mz: &[f32], vr: &[f32],
    lambda: f32,
    rhs: &mut [f32],
    ws: &mut MediWorkspace,
) {
    let n = ws.n_total;

    // Regularization gradient: div(W * grad(chi)).
    fgrad_periodic_inplace_f32(
        &mut ws.gx, &mut ws.gy, &mut ws.gz,
        chi, ws.nx, ws.ny, ws.nz,
        ws.vsx, ws.vsy, ws.vsz,
    );

    apply_gradient_weights_f32(
        &mut ws.reg_x, &mut ws.reg_y, &mut ws.reg_z,
        mx, my, mz, vr,
        &ws.gx, &ws.gy, &ws.gz,
    );

    bdiv_periodic_inplace_f32(
        &mut ws.div_buf,
        &ws.reg_x, &ws.reg_y, &ws.reg_z,
        ws.nx, ws.ny, ws.nz,
        ws.vsx, ws.vsy, ws.vsz,
    );

    // Phase-residual term: conj(w) * (-i) * (w - b0) per voxel.
    for i in 0..n {
        let diff = w[i] - b0[i];
        let conj_w = w[i].conj();
        let neg_i = Complex32::new(0.0, -1.0);
        ws.complex_buf2[i] = conj_w * neg_i * diff;
    }

    // Apply D^T via k-space multiplication with the (real) dipole kernel.
    ws.fft_ws.fft3d(&mut ws.complex_buf2);

    for i in 0..n {
        ws.complex_buf2[i] *= d_kernel[i];
    }

    ws.fft_ws.ifft3d(&mut ws.complex_buf2);

    // Keep the real part only.
    for i in 0..n {
        ws.dipole_buf[i] = ws.complex_buf2[i].re;
    }

    combine_terms_f32(rhs, &ws.div_buf, &ws.dipole_buf, lambda);
}
740
/// Build the per-voxel data-term weights.
///
/// `mode == 0`: uniform weighting — 1 inside the mask, 0 outside.
/// Otherwise: SNR weighting `1 / n_std` inside the mask (0 where the noise
/// estimate is ~0), normalized so the in-mask mean weight is 1, with voxels
/// outside the mask forced to 0.
fn dataterm_mask_f32(mode: i32, n_std: &[f32], mask: &[u8]) -> Vec<f32> {
    if mode == 0 {
        return mask.iter().map(|&m| if m != 0 { 1.0 } else { 0.0 }).collect();
    }

    // Inverse-noise weights, guarded against division by (near-)zero.
    let mut weights: Vec<f32> = n_std
        .iter()
        .zip(mask.iter())
        .map(|(&sigma, &m)| {
            if m != 0 && sigma > 1e-10 {
                sigma.recip()
            } else {
                0.0
            }
        })
        .collect();

    // Normalize to unit mean over the mask (skipped for empty/degenerate masks).
    let in_mask = mask.iter().filter(|&&m| m != 0).count();
    if in_mask > 0 {
        let total: f32 = weights
            .iter()
            .zip(mask.iter())
            .filter(|(_, &m)| m != 0)
            .map(|(&wi, _)| wi)
            .sum();
        let mean = total / in_mask as f32;

        if mean > 1e-10 {
            for wi in weights.iter_mut() {
                *wi /= mean;
            }
        }
    }

    // Zero out everything outside the mask.
    for (wi, &m) in weights.iter_mut().zip(mask.iter()) {
        if m == 0 {
            *wi = 0.0;
        }
    }

    weights
}
794
795pub(crate) fn gradient_mask_f32(
811 magnitude: &[f32],
812 mask: &[u8],
813 nx: usize, ny: usize, nz: usize,
814 vsx: f32, vsy: f32, vsz: f32,
815 percentage: f32,
816) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
817 let n_total = nx * ny * nz;
818
819 let mag_max = magnitude.iter()
821 .zip(mask.iter())
822 .filter(|(_, &m)| m != 0)
823 .map(|(&v, _)| v.abs())
824 .fold(0.0_f32, f32::max);
825
826 let mag_normalized: Vec<f32> = magnitude.iter()
827 .zip(mask.iter())
828 .map(|(&m, &msk)| {
829 if msk != 0 && mag_max > 1e-10 {
830 m / mag_max
831 } else {
832 0.0
833 }
834 })
835 .collect();
836
837 let (gx, gy, gz) = fgrad_linext_f32(&mag_normalized, nx, ny, nz, vsx, vsy, vsz);
839
840 let abs_gx: Vec<f32> = gx.iter().map(|&v| v.abs()).collect();
842 let abs_gy: Vec<f32> = gy.iter().map(|&v| v.abs()).collect();
843 let abs_gz: Vec<f32> = gz.iter().map(|&v| v.abs()).collect();
844
845 let mut all_grads: Vec<f32> = Vec::with_capacity(3 * n_total);
847 for i in 0..n_total {
848 if mask[i] != 0 {
849 all_grads.push(abs_gx[i]);
850 all_grads.push(abs_gy[i]);
851 all_grads.push(abs_gz[i]);
852 }
853 }
854
855 if all_grads.is_empty() {
856 return (vec![1.0; n_total], vec![1.0; n_total], vec![1.0; n_total]);
857 }
858
859 all_grads.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
862 let percentile_idx = ((1.0 - percentage) * (all_grads.len() - 1) as f32) as usize;
863 let threshold = all_grads[percentile_idx.min(all_grads.len() - 1)];
864
865 let mx: Vec<f32> = abs_gx.iter()
868 .zip(mask.iter())
869 .map(|(&g, &m)| if m != 0 && g < threshold { 1.0 } else { 0.0 })
870 .collect();
871
872 let my: Vec<f32> = abs_gy.iter()
873 .zip(mask.iter())
874 .map(|(&g, &m)| if m != 0 && g < threshold { 1.0 } else { 0.0 })
875 .collect();
876
877 let mz: Vec<f32> = abs_gz.iter()
878 .zip(mask.iter())
879 .map(|(&g, &m)| if m != 0 && g < threshold { 1.0 } else { 0.0 })
880 .collect();
881
882 (mx, my, mz)
883}
884
885pub(crate) fn fgrad_linext_f32(
888 x: &[f32],
889 nx: usize, ny: usize, nz: usize,
890 vsx: f32, vsy: f32, vsz: f32,
891) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
892 let n_total = nx * ny * nz;
893 let mut gx = vec![0.0f32; n_total];
894 let mut gy = vec![0.0f32; n_total];
895 let mut gz = vec![0.0f32; n_total];
896 fgrad_linext_inplace_f32(&mut gx, &mut gy, &mut gz, x, nx, ny, nz, vsx, vsy, vsz);
897 (gx, gy, gz)
898}
899
900#[inline]
903pub(crate) fn fgrad_linext_inplace_f32(
904 gx: &mut [f32], gy: &mut [f32], gz: &mut [f32],
905 x: &[f32],
906 nx: usize, ny: usize, nz: usize,
907 vsx: f32, vsy: f32, vsz: f32,
908) {
909 let hx = 1.0 / vsx;
910 let hy = 1.0 / vsy;
911 let hz = 1.0 / vsz;
912
913 for k in 0..nz {
914 let k_offset = k * nx * ny;
915
916 for j in 0..ny {
917 let j_offset = j * nx;
918
919 for i in 0..nx {
920 let idx = i + j_offset + k_offset;
921 let x_val = x[idx];
922
923 if i + 1 < nx {
926 gx[idx] = (x[idx + 1] - x_val) * hx;
927 } else if i > 0 {
928 gx[idx] = gx[idx - 1];
930 } else {
931 gx[idx] = 0.0;
932 }
933
934 if j + 1 < ny {
935 gy[idx] = (x[i + (j + 1) * nx + k_offset] - x_val) * hy;
936 } else if j > 0 {
937 gy[idx] = gy[i + (j - 1) * nx + k_offset];
938 } else {
939 gy[idx] = 0.0;
940 }
941
942 if k + 1 < nz {
943 gz[idx] = (x[i + j_offset + (k + 1) * nx * ny] - x_val) * hz;
944 } else if k > 0 {
945 gz[idx] = gz[i + j_offset + (k - 1) * nx * ny];
946 } else {
947 gz[idx] = 0.0;
948 }
949 }
950 }
951 }
952}
953
954
955#[inline]
959pub(crate) fn fgrad_periodic_inplace_f32(
960 gx: &mut [f32], gy: &mut [f32], gz: &mut [f32],
961 x: &[f32],
962 nx: usize, ny: usize, nz: usize,
963 vsx: f32, vsy: f32, vsz: f32,
964) {
965 let hx = 1.0 / vsx;
966 let hy = 1.0 / vsy;
967 let hz = 1.0 / vsz;
968 let nxny = nx * ny;
969
970 for k in 0..nz {
971 let k_offset = k * nxny;
972
973 for j in 0..ny {
974 let j_offset = j * nx;
975
976 for i in 0..nx {
977 let idx = i + j_offset + k_offset;
978 let x_val = x[idx];
979
980 let x_next = if i + 1 < nx { x[idx + 1] } else { x[j_offset + k_offset] };
982 gx[idx] = (x_next - x_val) * hx;
983
984 let y_next = if j + 1 < ny { x[i + (j + 1) * nx + k_offset] } else { x[i + k_offset] };
986 gy[idx] = (y_next - x_val) * hy;
987
988 let z_next = if k + 1 < nz { x[i + j_offset + (k + 1) * nxny] } else { x[i + j_offset] };
990 gz[idx] = (z_next - x_val) * hz;
991 }
992 }
993 }
994}
995
996#[inline]
1000pub(crate) fn bdiv_periodic_inplace_f32(
1001 div: &mut [f32],
1002 gx: &[f32], gy: &[f32], gz: &[f32],
1003 nx: usize, ny: usize, nz: usize,
1004 vsx: f32, vsy: f32, vsz: f32,
1005) {
1006 let hx = -1.0 / vsx; let hy = -1.0 / vsy;
1008 let hz = -1.0 / vsz;
1009 let nxny = nx * ny;
1010
1011 for k in 0..nz {
1012 let k_offset = k * nxny;
1013
1014 for j in 0..ny {
1015 let j_offset = j * nx;
1016
1017 for i in 0..nx {
1018 let idx = i + j_offset + k_offset;
1019
1020 let gx_prev = if i > 0 { gx[idx - 1] } else { gx[(nx - 1) + j_offset + k_offset] };
1022 let gx_term = (gx[idx] - gx_prev) * hx;
1023
1024 let gy_prev = if j > 0 { gy[i + (j - 1) * nx + k_offset] } else { gy[i + (ny - 1) * nx + k_offset] };
1026 let gy_term = (gy[idx] - gy_prev) * hy;
1027
1028 let gz_prev = if k > 0 { gz[i + j_offset + (k - 1) * nxny] } else { gz[i + j_offset + (nz - 1) * nxny] };
1030 let gz_term = (gz[idx] - gz_prev) * hz;
1031
1032 div[idx] = gx_term + gy_term + gz_term;
1033 }
1034 }
1035 }
1036}
1037
#[allow(clippy::too_many_arguments)]
/// MEDI-L1 QSM reconstruction with a progress callback.
///
/// Solves for the susceptibility map `chi` from the `local_field` map, a
/// noise standard-deviation map `n_std`, the `magnitude` image, and a binary
/// `mask`, via an outer Gauss-Newton-style loop with an inner CG solve.
/// `progress_callback(done, total)` is called once per inner CG iteration
/// (total = `max_iter * cg_max_iter`) and once more on early convergence.
///
/// Returns `chi` as `f64`, with voxels outside the original `mask` zeroed.
pub fn medi_l1_with_progress<F>(
    local_field: &[f64],
    n_std: &[f64],
    magnitude: &[f64],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
    lambda: f64,
    bdir: (f64, f64, f64),
    merit: bool,
    smv: bool,
    smv_radius: f64,
    data_weighting: i32,
    percentage: f64,
    cg_tol: f64,
    cg_max_iter: usize,
    max_iter: usize,
    tol: f64,
    mut progress_callback: F,
) -> Vec<f64>
where
    F: FnMut(usize, usize),
{
    let n_total = nx * ny * nz;

    // All internal math runs in f32; narrow inputs/parameters once up front.
    let vsx_f32 = vsx as f32;
    let vsy_f32 = vsy as f32;
    let vsz_f32 = vsz as f32;
    let lambda_f32 = lambda as f32;
    let bdir_f32 = (bdir.0 as f32, bdir.1 as f32, bdir.2 as f32);
    let smv_radius_f32 = smv_radius as f32;
    let percentage_f32 = percentage as f32;
    let cg_tol_f32 = cg_tol as f32;
    let tol_f32 = tol as f32;

    let local_field_f32: Vec<f32> = local_field.iter().map(|&v| v as f32).collect();
    let n_std_f32: Vec<f32> = n_std.iter().map(|&v| v as f32).collect();
    let magnitude_f32: Vec<f32> = magnitude.iter().map(|&v| v as f32).collect();

    let mut ws = MediWorkspace::new(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);

    // Working copies that SMV preprocessing and MERIT may modify.
    let mut rdf: Vec<f32> = local_field_f32.clone();
    let mut work_mask: Vec<u8> = mask.to_vec();
    let mut tempn: Vec<f32> = n_std_f32.clone();

    // Noise outside the mask is irrelevant; zero it.
    for i in 0..n_total {
        if mask[i] == 0 {
            tempn[i] = 0.0;
        }
    }

    let mut d_kernel = dipole_kernel_f32(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, bdir_f32);

    // Optional SMV preprocessing: erode the mask, high-pass the field with
    // (1 - S), fold S into the dipole kernel, and adjust the noise map.
    // Keeps the kernel FFT for MERIT's later noise updates.
    let sphere_k = if smv {
        let sk = smv_kernel_f32(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, smv_radius_f32);

        let mut sk_fft: Vec<Complex32> = sk.iter()
            .map(|&v| Complex32::new(v, 0.0))
            .collect();
        ws.fft_ws.fft3d(&mut sk_fft);

        // Erode the mask: keep voxels whose SMV average stays ~1.
        let mask_f32: Vec<f32> = work_mask.iter().map(|&m| m as f32).collect();
        let smv_mask = apply_smv_kernel_ws(&mask_f32, &sk_fft, &mut ws);
        for i in 0..n_total {
            work_mask[i] = if smv_mask[i] > 0.999 { 1 } else { 0 };
        }

        // Fold (1 - S) into the dipole kernel (k-space, real factors).
        for i in 0..n_total {
            d_kernel[i] *= 1.0 - sk[i];
        }

        // High-pass the field: rdf -= SMV(rdf), restricted to the eroded mask.
        let smv_rdf = apply_smv_kernel_ws(&rdf, &sk_fft, &mut ws);
        for i in 0..n_total {
            rdf[i] -= smv_rdf[i];
            if work_mask[i] == 0 {
                rdf[i] = 0.0;
            }
        }

        // Propagate noise through the SMV filter: variances add.
        let tempn_sq: Vec<f32> = tempn.iter().map(|&t| t * t).collect();
        let smv_tempn_sq = apply_smv_kernel_ws(&tempn_sq, &sk_fft, &mut ws);
        for i in 0..n_total {
            tempn[i] = (smv_tempn_sq[i] + tempn_sq[i]).sqrt();
        }

        Some(sk_fft)
    } else {
        None
    };

    let mut m = dataterm_mask_f32(data_weighting, &tempn, &work_mask);

    // Complex data target: b0 = m * exp(i * rdf) per voxel.
    let mut b0: Vec<Complex32> = rdf.iter()
        .zip(m.iter())
        .map(|(&f, &mi)| {
            let phase = Complex32::new(0.0, f);
            mi * phase.exp()
        })
        .collect();

    let (w_gx, w_gy, w_gz) = gradient_mask_f32(&magnitude_f32, &work_mask, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, percentage_f32);

    // Degenerate (all-zero) edge masks fall back to the raw magnitude.
    let w_gx = if w_gx.iter().any(|&v| v != 0.0) { w_gx } else { magnitude_f32.clone() };
    let w_gy = if w_gy.iter().any(|&v| v != 0.0) { w_gy } else { magnitude_f32.clone() };
    let w_gz = if w_gz.iter().any(|&v| v != 0.0) { w_gz } else { magnitude_f32.clone() };

    // Solver state for the outer loop.
    let mut chi = vec![0.0f32; n_total];
    let mut dx = vec![0.0f32; n_total];
    let mut rhs = vec![0.0f32; n_total];
    let mut vr = vec![0.0f32; n_total];
    let mut w: Vec<Complex32> = vec![Complex32::new(0.0, 0.0); n_total];
    let mut chi_prev = vec![0.0f32; n_total];
    let mut badpoint = vec![0.0f32; n_total];
    let mut n_std_work: Vec<f32> = n_std_f32.clone();

    // Smoothing epsilon inside the L1-gradient weight (see compute_p_weights_f32).
    let beta = 1.49e-8_f32;

    let total_steps = max_iter * cg_max_iter;

    for iter in 0..max_iter {
        chi_prev.copy_from_slice(&chi);

        // Re-linearize: L1 weights vr from the current gradient of chi.
        fgrad_periodic_inplace_f32(
            &mut ws.gx, &mut ws.gy, &mut ws.gz,
            &chi, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32,
        );

        compute_p_weights_f32(&mut vr, &w_gx, &w_gy, &w_gz, &ws.gx, &ws.gy, &ws.gz, beta);

        // Current complex data estimate: w = m * exp(i * (D chi)).
        apply_dipole_conv(&mut ws.fft_ws, &chi, &d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
        for i in 0..n_total {
            let phase = Complex32::new(0.0, ws.dipole_buf[i]);
            w[i] = m[i] * phase.exp();
        }

        compute_rhs_inplace(&chi, &w, &b0, &d_kernel, &w_gx, &w_gy, &w_gz, &vr, lambda_f32, &mut rhs, &mut ws);

        // CG solves A dx = -gradient.
        negate_f32(&mut rhs);

        let gn_iter = iter;
        cg_solve_medi(
            &mut ws, &w, &d_kernel, &w_gx, &w_gy, &w_gz, &vr, lambda_f32, &rhs, &mut dx, cg_tol_f32, cg_max_iter,
            |cg_iter, cg_total| {
                let current = gn_iter * cg_total + cg_iter;
                progress_callback(current, total_steps);
            }
        );

        // chi += dx.
        axpy_f32(&mut chi, 1.0, &dx);

        // Relative step size used as the outer stopping criterion.
        let norm_dx_sq = norm_squared_f32(&dx);
        let norm_chi_sq = norm_squared_f32(&chi_prev);
        let rel_change = norm_dx_sq.sqrt() / (norm_chi_sq.sqrt() + 1e-6);

        // MERIT: inflate the noise estimate of outlier voxels (large complex
        // residual) so the data term down-weights them next iteration.
        if merit {
            apply_dipole_conv(&mut ws.fft_ws, &chi, &d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
            let mut wres: Vec<Complex32> = ws.dipole_buf.iter()
                .zip(m.iter())
                .zip(b0.iter())
                .map(|((&dc, &mi), &b0i)| {
                    let phase = Complex32::new(0.0, dc);
                    mi * phase.exp() - b0i
                })
                .collect();

            // Center the residual over the mask.
            let mask_count = work_mask.iter().filter(|&&m| m != 0).count() as f32;
            if mask_count > 0.0 {
                let mean_wres: Complex32 = wres.iter()
                    .zip(work_mask.iter())
                    .filter(|(_, &m)| m != 0)
                    .map(|(w, _)| w)
                    .sum::<Complex32>() / mask_count;

                for i in 0..n_total {
                    if work_mask[i] != 0 {
                        wres[i] -= mean_wres;
                    }
                }
            }

            let abs_wres: Vec<f32> = wres.iter()
                .zip(work_mask.iter())
                .filter(|(_, &m)| m != 0)
                .map(|(w, _)| w.norm())
                .collect();

            if !abs_wres.is_empty() {
                // Outlier scale: 6 standard deviations of |residual|.
                let mean_abs: f32 = abs_wres.iter().sum::<f32>() / abs_wres.len() as f32;
                let var: f32 = abs_wres.iter()
                    .map(|&v| (v - mean_abs).powi(2))
                    .sum::<f32>() / abs_wres.len() as f32;
                let factor = var.sqrt() * 6.0;

                if factor > 1e-10 {
                    let mut wres_norm: Vec<f32> = wres.iter()
                        .map(|w| w.norm() / factor)
                        .collect();

                    // Clamp to >= 1 so inliers keep their noise estimate.
                    for v in wres_norm.iter_mut() {
                        if *v < 1.0 {
                            *v = 1.0;
                        }
                    }

                    for i in 0..n_total {
                        if wres_norm[i] > 1.0 {
                            badpoint[i] = 1.0;
                        }
                        if work_mask[i] != 0 {
                            n_std_work[i] *= wres_norm[i].powi(2);
                        }
                    }

                    // Rebuild the noise map (re-applying SMV propagation if
                    // enabled), then the data weights and target.
                    tempn = n_std_work.clone();
                    if let Some(ref sk_fft) = sphere_k {
                        let tempn_sq: Vec<f32> = tempn.iter().map(|&t| t * t).collect();
                        let smv_tempn_sq = apply_smv_kernel_ws(&tempn_sq, sk_fft, &mut ws);
                        for i in 0..n_total {
                            tempn[i] = (smv_tempn_sq[i] + tempn_sq[i]).sqrt();
                        }
                    }

                    m = dataterm_mask_f32(data_weighting, &tempn, &work_mask);
                    b0 = rdf.iter()
                        .zip(m.iter())
                        .map(|(&f, &mi)| {
                            let phase = Complex32::new(0.0, f);
                            mi * phase.exp()
                        })
                        .collect();
                }
            }
        }

        if rel_change < tol_f32 {
            // Report completion before breaking out early.
            progress_callback(total_steps, total_steps);
            break;
        }
    }

    // badpoint is collected for diagnostics but currently unused.
    let _ = badpoint;

    // Widen to f64 and re-apply the ORIGINAL (un-eroded) mask.
    chi.iter()
        .zip(mask.iter())
        .map(|(&c, &m)| if m == 0 { 0.0 } else { c as f64 })
        .collect()
}
1334
/// Convenience front end to `medi_l1` with unit noise and fixed parameters.
///
/// NOTE(review): the hard-coded arguments mostly match `MediParams::default()`
/// but pass `smv = false`, whereas the struct default is `smv: true` —
/// confirm which is intended before unifying them.
pub fn medi_l1_default(
    local_field: &[f64],
    mask: &[u8],
    magnitude: &[f64],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
) -> Vec<f64> {
    // Uniform unit noise map.
    let n_std = vec![1.0; local_field.len()];

    // Args: lambda, bdir, merit, smv, smv_radius, data_weighting,
    // percentage, cg_tol, cg_max_iter, max_iter, tol.
    medi_l1(
        local_field,
        &n_std,
        magnitude,
        mask,
        nx, ny, nz,
        vsx, vsy, vsz,
        7.5e-5, (0.0, 0.0, 1.0), false, false, 5.0, 1, 0.3, 0.01, 10, 30, 0.1, )
}
1366
1367#[cfg(test)]
1368mod tests {
1369 use super::*;
1370
    #[test]
    // Mode 0 yields a weight of exactly 1 at every in-mask voxel.
    fn test_dataterm_mask_uniform() {
        let n_std = vec![1.0f32; 27];
        let mask = vec![1u8; 27];

        let w = dataterm_mask_f32(0, &n_std, &mask);

        for &wi in w.iter() {
            assert!((wi - 1.0).abs() < 1e-10);
        }
    }
1382
    #[test]
    // SNR mode normalizes the in-mask weights to unit mean.
    fn test_dataterm_mask_snr() {
        let n_std = vec![2.0f32; 27];
        let mask = vec![1u8; 27];

        let w = dataterm_mask_f32(1, &n_std, &mask);

        let mean: f32 = w.iter().sum::<f32>() / 27.0;
        assert!((mean - 1.0).abs() < 1e-5);
    }
1394
    #[test]
    // Gradient masks must be strictly binary (0 or 1) everywhere.
    fn test_gradient_mask_constant() {
        let mag = vec![1.0f32; 8 * 8 * 8];
        let mask = vec![1u8; 8 * 8 * 8];

        let (mx, my, mz) = gradient_mask_f32(&mag, &mask, 8, 8, 8, 1.0, 1.0, 1.0, 0.3);

        for i in 0..(8 * 8 * 8) {
            assert!(mx[i] == 0.0 || mx[i] == 1.0, "mx should be binary, got {}", mx[i]);
            assert!(my[i] == 0.0 || my[i] == 1.0, "my should be binary, got {}", my[i]);
            assert!(mz[i] == 0.0 || mz[i] == 1.0, "mz should be binary, got {}", mz[i]);
        }
    }
1410
    #[test]
    // A zero field map should reconstruct to (near-)zero susceptibility.
    fn test_medi_zero_field() {
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];
        let mag = vec![1.0; n * n * n];
        let n_std = vec![1.0; n * n * n];

        let chi = medi_l1(
            &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
            1000.0, (0.0, 0.0, 1.0), false, false, 5.0, 1, 0.9, 0.1, 10, 3, 0.1
        );

        for &val in chi.iter() {
            assert!(val.abs() < 1e-4, "Zero field should give near-zero chi, got {}", val);
        }
    }
1428
    #[test]
    // A smooth nonzero field must not produce NaN/inf in the output.
    fn test_medi_finite() {
        let n = 8;
        let field: Vec<f64> = (0..n*n*n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];
        let mag = vec![1.0; n * n * n];
        let n_std = vec![1.0; n * n * n];

        let chi = medi_l1(
            &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
            1000.0, (0.0, 0.0, 1.0), false, false, 5.0, 1, 0.9, 0.1, 10, 3, 0.1
        );

        for (i, &val) in chi.iter().enumerate() {
            assert!(val.is_finite(), "Chi should be finite at index {}", i);
        }
    }
1446
    #[test]
    // Same finiteness check with SMV preprocessing enabled (radius 2).
    fn test_medi_with_smv() {
        let n = 8;
        let field: Vec<f64> = (0..n*n*n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];
        let mag = vec![1.0; n * n * n];
        let n_std = vec![1.0; n * n * n];

        let chi = medi_l1(
            &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
            1000.0, (0.0, 0.0, 1.0), false, true, 2.0, 1, 0.9, 0.1, 10, 3, 0.1
        );

        for (i, &val) in chi.iter().enumerate() {
            assert!(val.is_finite(), "Chi with SMV should be finite at index {}", i);
        }
    }
1465
    #[test]
    // Voxels excluded by the input mask must come back exactly zero.
    fn test_medi_mask() {
        let n = 8;
        let field: Vec<f64> = (0..n*n*n).map(|i| (i as f64) * 0.001).collect();
        let mut mask = vec![1u8; n * n * n];
        let mag = vec![1.0; n * n * n];
        let n_std = vec![1.0; n * n * n];
        mask[0] = 0;
        mask[10] = 0;

        let chi = medi_l1(
            &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
            1000.0, (0.0, 0.0, 1.0), false, false, 5.0, 1, 0.9, 0.1, 10, 3, 0.1
        );

        assert_eq!(chi[0], 0.0, "Masked voxel should be zero");
        assert_eq!(chi[10], 0.0, "Masked voxel should be zero");
    }
1484
    /// Manual two-iteration walkthrough of the MEDI solver on a real
    /// background-removed field map, dumping every intermediate volume as raw
    /// little-endian f32 (`*_rust.raw`) — presumably for offline numerical
    /// comparison against a reference implementation's matching dumps
    /// (TODO confirm what produces the non-`_rust` counterparts).
    ///
    /// Ignored by default: it depends on a hard-coded local data path and is
    /// a debugging aid rather than a correctness test.
    #[test]
    #[ignore]
    fn test_medi_debug() {
        let data_path = "/home/ashley/OUT/2bgRemoved.nii";
        // Soft-skip when the fixture is absent so this cannot fail elsewhere.
        if !std::path::Path::new(data_path).exists() {
            eprintln!("Skipping: {} not found", data_path);
            return;
        }

        let outdir = "/home/ashley/OUT/debug";
        std::fs::create_dir_all(outdir).ok();

        // Load the NIfTI volume and pull out grid dimensions and voxel size.
        let bytes = std::fs::read(data_path).unwrap();
        let nifti_data = crate::nifti_io::load_nifti(&bytes).unwrap();
        let (nx, ny, nz) = nifti_data.dims;
        let (vsx, vsy, vsz) = nifti_data.voxel_size;

        let n_total = nx * ny * nz;
        let vsx_f32 = vsx as f32;
        let vsy_f32 = vsy as f32;
        let vsz_f32 = vsz as f32;

        eprintln!("Data: {}x{}x{}, voxel: {}x{}x{}", nx, ny, nz, vsx, vsy, vsz);

        let local_field: Vec<f32> = nifti_data.data.iter().map(|&v| v as f32).collect();

        // Derive the mask from (near-)non-zero field values.
        let mask: Vec<u8> = local_field.iter()
            .map(|&v| if v.abs() > 1e-10 { 1 } else { 0 })
            .collect();
        let mask_count: usize = mask.iter().filter(|&&m| m != 0).count();
        eprintln!("Mask voxels: {} / {}", mask_count, n_total);

        // Dump the inputs first so every stage can be diffed offline.
        save_f32_raw(&local_field, &format!("{}/f_rust.raw", outdir));
        let mask_f32: Vec<f32> = mask.iter().map(|&m| m as f32).collect();
        save_f32_raw(&mask_f32, &format!("{}/mask_rust.raw", outdir));

        // Solver constants hard-coded for one known configuration.
        let lambda: f32 = 7.5e-5;
        let beta: f32 = 1.49e-8;
        let bdir = (0.0f32, 0.0f32, 1.0f32);
        let cg_tol: f32 = 0.1; let cg_max_iter: usize = 10;

        // Dipole kernel for the B0 direction `bdir`.
        let d_kernel = crate::kernels::dipole::dipole_kernel_f32(
            nx, ny, nz, vsx_f32, vsy_f32, vsz_f32, bdir,
        );
        save_f32_raw(&d_kernel, &format!("{}/D_rust.raw", outdir));
        eprintln!("D: min={} max={} D[0]={}", fmin(&d_kernel), fmax(&d_kernel), d_kernel[0]);

        let mut ws = MediWorkspace::new(nx, ny, nz, vsx_f32, vsy_f32, vsz_f32);

        // Mask as an f32 multiplier (1.0 inside, 0.0 outside).
        let m: Vec<f32> = mask.iter().map(|&m| if m != 0 { 1.0 } else { 0.0 }).collect();

        // b0 = m .* exp(i * field): measured phase as a unit complex signal,
        // zeroed outside the mask.
        let b0: Vec<Complex32> = local_field.iter().zip(m.iter())
            .map(|(&f, &mi)| {
                let phase = Complex32::new(0.0, f);
                mi * phase.exp()
            })
            .collect();

        // Gradient weights are plain mask copies (no edge weighting) here.
        let w_gx: Vec<f32> = m.clone();
        let w_gy: Vec<f32> = m.clone();
        let w_gz: Vec<f32> = m.clone();

        eprintln!("\n=== Iteration 1 ===");
        let mut chi = vec![0.0f32; n_total];  // current susceptibility estimate
        let mut dx = vec![0.0f32; n_total];   // update produced by each CG solve
        let mut rhs = vec![0.0f32; n_total];  // right-hand side buffer
        let mut vr = vec![0.0f32; n_total];   // P weights from chi's gradient
        let mut w: Vec<Complex32> = vec![Complex32::new(0.0, 0.0); n_total];

        // P weights from the gradient of the current chi (all zeros at iter 1).
        fgrad_periodic_inplace_f32(
            &mut ws.gx, &mut ws.gy, &mut ws.gz,
            &chi, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32,
        );
        compute_p_weights_f32(&mut vr, &w_gx, &w_gy, &w_gz, &ws.gx, &ws.gy, &ws.gz, beta);
        save_f32_raw(&vr, &format!("{}/P1_rust.raw", outdir));
        eprintln!("P1: min={} max={} mean={}", fmin(&vr), fmax(&vr), fmean(&vr));

        // w = m .* exp(i * (D (*) chi)): model phase of the current estimate.
        apply_dipole_conv(&mut ws.fft_ws, &chi, &d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
        for i in 0..n_total {
            let phase = Complex32::new(0.0, ws.dipole_buf[i]);
            w[i] = m[i] * phase.exp();
        }

        compute_rhs_inplace(
            &chi, &w, &b0, &d_kernel,
            &w_gx, &w_gy, &w_gz, &vr, lambda,
            &mut rhs, &mut ws,
        );
        save_f32_raw(&rhs, &format!("{}/rhs1_rust.raw", outdir));
        eprintln!("RHS1: min={} max={} norm={}", fmin(&rhs), fmax(&rhs), fnorm(&rhs));

        // The CG below uses -rhs as its right-hand side; flip the sign in place.
        negate_f32(&mut rhs);

        // --- CG solve #1 (verbose: logs residual, alpha, pAp each iteration) ---
        let mut cg_residuals: Vec<f32> = Vec::new();  // NOTE(review): pushed to but never read — dead
        {
            let n = ws.n_total;
            // Shadow the grid parameters with the workspace's copies.
            let (nx, ny, nz) = (ws.nx, ws.ny, ws.nz);
            let (vsx, vsy, vsz) = (ws.vsx, ws.vsy, ws.vsz);

            dx.fill(0.0);
            ws.cg_r.copy_from_slice(&rhs);
            ws.cg_p.copy_from_slice(&ws.cg_r);
            let mut rsold: f32 = norm_squared_f32(&ws.cg_r);
            let b_norm: f32 = norm_squared_f32(&rhs).sqrt();

            // p is copied out so the operator can read it while the workspace
            // buffers are mutably borrowed.
            let mut p_copy = vec![0.0f32; n];
            let mut prev_residual = rsold.sqrt();  // NOTE(review): assigned but never read — dead

            for cg_iter in 0..cg_max_iter {
                let residual_before = rsold.sqrt();
                cg_residuals.push(residual_before);

                p_copy.copy_from_slice(&ws.cg_p);
                {
                    // Bundle all scratch buffers for one operator application.
                    let mut bufs = MediOpBuffers {
                        gx: &mut ws.gx, gy: &mut ws.gy, gz: &mut ws.gz,
                        reg_x: &mut ws.reg_x, reg_y: &mut ws.reg_y, reg_z: &mut ws.reg_z,
                        div_buf: &mut ws.div_buf, dipole_buf: &mut ws.dipole_buf,
                        complex_buf: &mut ws.complex_buf, complex_buf2: &mut ws.complex_buf2,
                    };
                    // cg_ap = A * p
                    apply_medi_operator_core(
                        &mut ws.fft_ws, &mut bufs, n, nx, ny, nz, vsx, vsy, vsz,
                        &p_copy, &w, &d_kernel, &w_gx, &w_gy, &w_gz, &vr, lambda, &mut ws.cg_ap,
                    );
                }

                let pap: f32 = dot_product_f32(&ws.cg_p, &ws.cg_ap);
                // Guard against breakdown when p'Ap is effectively zero.
                if pap.abs() < 1e-15 { break; }
                let alpha = rsold / pap;

                axpy_f32(&mut dx, alpha, &ws.cg_p);         // dx += alpha * p
                axpy_f32(&mut ws.cg_r, -alpha, &ws.cg_ap);  // r  -= alpha * A*p

                let rsnew: f32 = norm_squared_f32(&ws.cg_r);
                let residual = rsnew.sqrt();

                eprintln!(" CG iter {}: res={:.6e}, alpha={:.6e}, pap={:.6e}",
                    cg_iter + 1, residual, alpha, pap);

                // Relative convergence test against the initial RHS norm.
                if residual < cg_tol * b_norm { break; }

                let beta_cg = rsnew / rsold;
                xpby_f32(&mut ws.cg_p, &ws.cg_r, beta_cg);  // p = r + beta * p
                rsold = rsnew;
                prev_residual = residual;
            }
        }
        save_f32_raw(&dx, &format!("{}/dx1_rust.raw", outdir));
        eprintln!("dx1: min={} max={} norm={}", fmin(&dx), fmax(&dx), fnorm(&dx));

        // chi += dx
        axpy_f32(&mut chi, 1.0, &dx);
        save_f32_raw(&chi, &format!("{}/chi1_rust.raw", outdir));
        eprintln!("chi1: min={} max={} norm={}", fmin(&chi), fmax(&chi), fnorm(&chi));

        eprintln!("\n=== Iteration 2 ===");

        // Same pipeline again, starting from the updated chi.
        fgrad_periodic_inplace_f32(
            &mut ws.gx, &mut ws.gy, &mut ws.gz,
            &chi, nx, ny, nz, vsx_f32, vsy_f32, vsz_f32,
        );
        compute_p_weights_f32(&mut vr, &w_gx, &w_gy, &w_gz, &ws.gx, &ws.gy, &ws.gz, beta);
        save_f32_raw(&vr, &format!("{}/P2_rust.raw", outdir));
        eprintln!("P2: min={} max={} mean={}", fmin(&vr), fmax(&vr), fmean(&vr));

        apply_dipole_conv(&mut ws.fft_ws, &chi, &d_kernel, &mut ws.dipole_buf, &mut ws.complex_buf);
        for i in 0..n_total {
            let phase = Complex32::new(0.0, ws.dipole_buf[i]);
            w[i] = m[i] * phase.exp();
        }

        compute_rhs_inplace(
            &chi, &w, &b0, &d_kernel,
            &w_gx, &w_gy, &w_gz, &vr, lambda,
            &mut rhs, &mut ws,
        );
        save_f32_raw(&rhs, &format!("{}/rhs2_rust.raw", outdir));
        eprintln!("RHS2: min={} max={} norm={}", fmin(&rhs), fmax(&rhs), fnorm(&rhs));

        negate_f32(&mut rhs);

        // --- CG solve #2 (terse logging, no residual bookkeeping) ---
        {
            let n = ws.n_total;
            let (nx, ny, nz) = (ws.nx, ws.ny, ws.nz);
            let (vsx, vsy, vsz) = (ws.vsx, ws.vsy, ws.vsz);

            dx.fill(0.0);
            ws.cg_r.copy_from_slice(&rhs);
            ws.cg_p.copy_from_slice(&ws.cg_r);
            let mut rsold: f32 = norm_squared_f32(&ws.cg_r);
            let b_norm: f32 = norm_squared_f32(&rhs).sqrt();
            let mut p_copy = vec![0.0f32; n];

            for cg_iter in 0..cg_max_iter {
                p_copy.copy_from_slice(&ws.cg_p);
                {
                    let mut bufs = MediOpBuffers {
                        gx: &mut ws.gx, gy: &mut ws.gy, gz: &mut ws.gz,
                        reg_x: &mut ws.reg_x, reg_y: &mut ws.reg_y, reg_z: &mut ws.reg_z,
                        div_buf: &mut ws.div_buf, dipole_buf: &mut ws.dipole_buf,
                        complex_buf: &mut ws.complex_buf, complex_buf2: &mut ws.complex_buf2,
                    };
                    apply_medi_operator_core(
                        &mut ws.fft_ws, &mut bufs, n, nx, ny, nz, vsx, vsy, vsz,
                        &p_copy, &w, &d_kernel, &w_gx, &w_gy, &w_gz, &vr, lambda, &mut ws.cg_ap,
                    );
                }
                let pap: f32 = dot_product_f32(&ws.cg_p, &ws.cg_ap);
                if pap.abs() < 1e-15 { break; }
                let alpha = rsold / pap;
                axpy_f32(&mut dx, alpha, &ws.cg_p);
                axpy_f32(&mut ws.cg_r, -alpha, &ws.cg_ap);
                let rsnew: f32 = norm_squared_f32(&ws.cg_r);
                let residual = rsnew.sqrt();
                eprintln!(" CG iter {}: res={:.6e}", cg_iter + 1, residual);
                if residual < cg_tol * b_norm { break; }
                let beta_cg = rsnew / rsold;
                xpby_f32(&mut ws.cg_p, &ws.cg_r, beta_cg);
                rsold = rsnew;
            }
        }
        save_f32_raw(&dx, &format!("{}/dx2_rust.raw", outdir));
        eprintln!("dx2: min={} max={} norm={}", fmin(&dx), fmax(&dx), fnorm(&dx));

        axpy_f32(&mut chi, 1.0, &dx);
        save_f32_raw(&chi, &format!("{}/chi2_rust.raw", outdir));
        eprintln!("chi2: min={} max={} norm={}", fmin(&chi), fmax(&chi), fnorm(&chi));

        eprintln!("\nDone. Intermediates saved to {}", outdir);
    }
1743
1744 fn save_f32_raw(data: &[f32], path: &str) {
1745 use std::io::Write;
1746 let mut file = std::fs::File::create(path).unwrap();
1747 for &val in data {
1748 file.write_all(&val.to_le_bytes()).unwrap();
1749 }
1750 }
1751
1752 fn fmin(data: &[f32]) -> f32 { data.iter().cloned().fold(f32::MAX, f32::min) }
1753 fn fmax(data: &[f32]) -> f32 { data.iter().cloned().fold(f32::MIN, f32::max) }
1754 fn fmean(data: &[f32]) -> f32 { data.iter().sum::<f32>() / data.len() as f32 }
1755 fn fnorm(data: &[f32]) -> f32 { data.iter().map(|&v| v * v).sum::<f32>().sqrt() }
1756
1757 #[test]
1758 fn test_medi_l1_small() {
1759 let n = 8;
1761 let n_total = n * n * n;
1762
1763 let mut field = vec![0.0f64; n_total];
1765 let center = n / 2;
1766 for z in 0..n {
1767 for y in 0..n {
1768 for x in 0..n {
1769 let idx = x + y * n + z * n * n;
1770 let dx = (x as f64) - (center as f64);
1771 let dy = (y as f64) - (center as f64);
1772 let dz = (z as f64) - (center as f64);
1773 let r2 = dx*dx + dy*dy + dz*dz;
1774 if r2 > 1.0 {
1775 let r = r2.sqrt();
1777 let cos_theta = dz / r;
1778 field[idx] = (3.0 * cos_theta * cos_theta - 1.0) / (r * r * r) * 0.01;
1779 }
1780 }
1781 }
1782 }
1783
1784 let mask = vec![1u8; n_total];
1785 let mag = vec![1.0f64; n_total];
1786 let n_std = vec![1.0f64; n_total];
1787
1788 let chi = medi_l1(
1790 &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1791 1e-4, (0.0, 0.0, 1.0), false, false, 5.0, 1, 0.3, 0.01, 10, 5, 0.1, );
1803
1804 assert_eq!(chi.len(), n_total);
1806 for (i, &val) in chi.iter().enumerate() {
1807 assert!(val.is_finite(), "MEDI L1 chi should be finite at index {}", i);
1808 }
1809
1810 let chi_norm: f64 = chi.iter().map(|&v| v * v).sum::<f64>().sqrt();
1812 assert!(chi_norm > 1e-10, "MEDI L1 should produce non-zero susceptibility for dipole field, got norm={}", chi_norm);
1813 }
1814
1815 #[test]
1816 fn test_medi_weight_types() {
1817 let n = 8;
1818 let n_total = n * n * n;
1819
1820 let field: Vec<f64> = (0..n_total).map(|i| (i as f64) * 0.001).collect();
1821 let mask = vec![1u8; n_total];
1822 let mag = vec![1.0f64; n_total];
1823 let n_std = vec![1.0f64; n_total];
1824
1825 let chi_uniform = medi_l1(
1827 &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1828 1000.0, (0.0, 0.0, 1.0), false, false, 5.0,
1829 0, 0.9, 0.1, 10, 3, 0.1
1831 );
1832
1833 let n_std_varying: Vec<f64> = (0..n_total).map(|i| 0.5 + (i as f64) * 0.01).collect();
1835 let chi_snr = medi_l1(
1836 &field, &n_std_varying, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1837 1000.0, (0.0, 0.0, 1.0), false, false, 5.0,
1838 1, 0.9, 0.1, 10, 3, 0.1
1840 );
1841
1842 for (i, &val) in chi_uniform.iter().enumerate() {
1844 assert!(val.is_finite(), "Uniform weighting chi should be finite at {}", i);
1845 }
1846 for (i, &val) in chi_snr.iter().enumerate() {
1847 assert!(val.is_finite(), "SNR weighting chi should be finite at {}", i);
1848 }
1849
1850 let diff_norm: f64 = chi_uniform.iter()
1852 .zip(chi_snr.iter())
1853 .map(|(&a, &b)| (a - b).powi(2))
1854 .sum::<f64>()
1855 .sqrt();
1856 assert!(diff_norm.is_finite(), "Difference between weight modes should be finite");
1858 }
1859
1860 #[test]
1861 fn test_medi_l1_with_progress_small() {
1862 let n = 8;
1863 let n_total = n * n * n;
1864
1865 let field: Vec<f64> = (0..n_total).map(|i| (i as f64) * 0.001).collect();
1866 let mask = vec![1u8; n_total];
1867 let mag = vec![1.0f64; n_total];
1868 let n_std = vec![1.0f64; n_total];
1869
1870 let mut progress_calls = 0usize;
1871 let chi = medi_l1_with_progress(
1872 &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1873 1000.0, (0.0, 0.0, 1.0), false, false, 5.0,
1874 1, 0.9, 0.1, 10, 3, 0.1,
1875 |_iter, _max| { progress_calls += 1; },
1876 );
1877
1878 assert_eq!(chi.len(), n_total);
1879 for &val in &chi {
1880 assert!(val.is_finite(), "medi_l1_with_progress output should be finite");
1881 }
1882 assert!(progress_calls > 0, "progress callback should be called at least once");
1883 }
1884
1885 #[test]
1886 fn test_medi_l1_with_merit() {
1887 let n = 8;
1888 let n_total = n * n * n;
1889
1890 let mut field = vec![0.0f64; n_total];
1891 let center = n / 2;
1892 for z in 0..n {
1893 for y in 0..n {
1894 for x in 0..n {
1895 let idx = x + y * n + z * n * n;
1896 let dx = (x as f64) - (center as f64);
1897 let dy = (y as f64) - (center as f64);
1898 let dz = (z as f64) - (center as f64);
1899 let r2 = dx * dx + dy * dy + dz * dz;
1900 if r2 > 1.0 {
1901 let r = r2.sqrt();
1902 let cos_theta = dz / r;
1903 field[idx] = (3.0 * cos_theta * cos_theta - 1.0) / (r * r * r) * 0.01;
1904 }
1905 }
1906 }
1907 }
1908
1909 let mask = vec![1u8; n_total];
1910 let mag = vec![1.0f64; n_total];
1911 let n_std = vec![1.0f64; n_total];
1912
1913 let chi = medi_l1(
1915 &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1916 1e-4, (0.0, 0.0, 1.0),
1917 true, false, 5.0, 1, 0.3, 0.01, 10, 5, 0.1,
1919 );
1920
1921 assert_eq!(chi.len(), n_total);
1922 for &val in &chi {
1923 assert!(val.is_finite(), "MEDI with merit should produce finite results");
1924 }
1925 }
1926
1927 #[test]
1928 fn test_medi_l1_with_smv() {
1929 let n = 8;
1930 let n_total = n * n * n;
1931
1932 let field: Vec<f64> = (0..n_total).map(|i| (i as f64) * 0.001).collect();
1933 let mask = vec![1u8; n_total];
1934 let mag = vec![1.0f64; n_total];
1935 let n_std = vec![1.0f64; n_total];
1936
1937 let chi = medi_l1(
1939 &field, &n_std, &mag, &mask, n, n, n, 1.0, 1.0, 1.0,
1940 1000.0, (0.0, 0.0, 1.0),
1941 false,
1942 true, 3.0, 1, 0.3, 0.01, 10, 3, 0.1,
1945 );
1946
1947 assert_eq!(chi.len(), n_total);
1948 for &val in &chi {
1949 assert!(val.is_finite(), "MEDI with SMV should produce finite results");
1950 }
1951 }
1952}