1use num_complex::Complex64;
14use crate::fft::{fft3d, ifft3d};
15use crate::kernels::dipole::dipole_kernel;
16
17#[cfg(feature = "parallel")]
18use crate::par::*;
19
/// Parameter bundle for the `pdf` background-field removal routines.
#[derive(Debug, Clone)]
pub struct PdfParams {
    /// Solver tolerance. The default (`1e-5`) is the same value that
    /// `pdf_default` passes to `pdf`.
    pub tol: f64,
}

impl Default for PdfParams {
    /// Returns the parameter set with `tol = 1e-5`.
    fn default() -> PdfParams {
        PdfParams { tol: 1.0e-5 }
    }
}
43
44pub fn pdf(
47 field: &[f64],
48 mask: &[u8],
49 nx: usize, ny: usize, nz: usize,
50 vsx: f64, vsy: f64, vsz: f64,
51 bdir: (f64, f64, f64),
52 tol: f64,
53 max_iter: usize,
54) -> Vec<f64> {
55 pdf_with_progress(field, mask, nx, ny, nz, vsx, vsy, vsz, bdir, tol, max_iter, |_, _| {})
56}
57
58
59
60fn apply_a(
62 x: &[f64],
63 bg_mask: &[f64],
64 d_kernel: &[f64],
65 brain_mask: &[f64],
66 nx: usize, ny: usize, nz: usize,
67) -> Vec<f64> {
68 let n_total = nx * ny * nz;
69
70 let mut temp: Vec<Complex64> = x.iter()
72 .zip(bg_mask.iter())
73 .map(|(&xi, &m)| Complex64::new(xi * m, 0.0))
74 .collect();
75
76 fft3d(&mut temp, nx, ny, nz);
78
79 for i in 0..n_total {
81 temp[i] *= d_kernel[i];
82 }
83
84 ifft3d(&mut temp, nx, ny, nz);
86
87 temp.iter()
89 .zip(brain_mask.iter())
90 .map(|(t, &w)| t.re * w)
91 .collect()
92}
93
94fn apply_at(
96 u: &[f64],
97 brain_mask: &[f64],
98 d_kernel: &[f64],
99 bg_mask: &[f64],
100 nx: usize, ny: usize, nz: usize,
101) -> Vec<f64> {
102 let n_total = nx * ny * nz;
103
104 let mut temp: Vec<Complex64> = u.iter()
106 .zip(brain_mask.iter())
107 .map(|(&ui, &w)| Complex64::new(ui * w, 0.0))
108 .collect();
109
110 fft3d(&mut temp, nx, ny, nz);
112
113 for i in 0..n_total {
115 temp[i] *= d_kernel[i];
116 }
117
118 ifft3d(&mut temp, nx, ny, nz);
120
121 temp.iter()
123 .zip(bg_mask.iter())
124 .map(|(t, &m)| t.re * m)
125 .collect()
126}
127
/// Returns the Euclidean (L2) norm of `v`: the square root of the sum of
/// squared elements.
fn vec_norm(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|value| value * value).sum();
    sum_of_squares.sqrt()
}
132
133pub fn pdf_with_progress<F>(
137 field: &[f64],
138 mask: &[u8],
139 nx: usize, ny: usize, nz: usize,
140 vsx: f64, vsy: f64, vsz: f64,
141 bdir: (f64, f64, f64),
142 tol: f64,
143 max_iter: usize,
144 mut progress_callback: F,
145) -> Vec<f64>
146where
147 F: FnMut(usize, usize),
148{
149 let n_total = nx * ny * nz;
150
151 let d_kernel = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
153
154 let bg_mask: Vec<f64> = mask.iter()
156 .map(|&m| if m == 0 { 1.0 } else { 0.0 })
157 .collect();
158
159 let brain_mask: Vec<f64> = mask.iter()
161 .map(|&m| if m != 0 { 1.0 } else { 0.0 })
162 .collect();
163
164 let b: Vec<f64> = field.iter()
166 .zip(brain_mask.iter())
167 .map(|(&f, &w)| f * w)
168 .collect();
169
170 let mut x = vec![0.0; n_total];
172
173 let mut u = b.clone();
174 let mut beta = vec_norm(&u);
175
176 if beta < 1e-20 {
177 return vec![0.0; n_total];
178 }
179
180 for i in 0..n_total {
181 u[i] /= beta;
182 }
183
184 let mut v = apply_at(&u, &brain_mask, &d_kernel, &bg_mask, nx, ny, nz);
185 let mut alpha = vec_norm(&v);
186
187 if alpha < 1e-20 {
188 return vec![0.0; n_total];
189 }
190
191 for i in 0..n_total {
192 v[i] /= alpha;
193 }
194
195 let norm_b = beta;
197 let mut w = v.clone();
198 let mut phi_bar = beta;
199 let mut rho_bar = alpha;
200
201 let mut beta_dd = beta;
203 let mut beta_d = 0.0;
204 let mut rho_d_old = 1.0;
205 let mut tau_tilde_old = 0.0;
206 let mut theta_tilde = 0.0;
207 let mut zeta = 0.0;
208 let d_accum = 0.0; let mut norm_a2 = alpha * alpha;
212
213 let mut zeta_bar = alpha * beta;
215 let mut alpha_bar = alpha;
216 let mut c_bar = 1.0;
217 let mut s_bar = 0.0;
218 let mut rho_val = 1.0;
219 let mut rho_bar_lsmr = 1.0;
220
221 for iter in 0..max_iter {
222 progress_callback(iter + 1, max_iter);
224
225 let av = apply_a(&v, &bg_mask, &d_kernel, &brain_mask, nx, ny, nz);
227 for i in 0..n_total {
228 u[i] = av[i] - alpha * u[i];
229 }
230 beta = vec_norm(&u);
231
232 if beta < 1e-20 {
233 progress_callback(iter + 1, iter + 1);
234 break;
235 }
236
237 for i in 0..n_total {
238 u[i] /= beta;
239 }
240
241 let atu = apply_at(&u, &brain_mask, &d_kernel, &bg_mask, nx, ny, nz);
242 for i in 0..n_total {
243 v[i] = atu[i] - beta * v[i];
244 }
245 alpha = vec_norm(&v);
246
247 if alpha < 1e-20 {
248 progress_callback(iter + 1, iter + 1);
249 break;
250 }
251
252 for i in 0..n_total {
253 v[i] /= alpha;
254 }
255
256 let rho = (rho_bar * rho_bar + beta * beta).sqrt();
258 let c = rho_bar / rho;
259 let s = beta / rho;
260 let theta = s * alpha;
261 rho_bar = -c * alpha;
262 let phi = c * phi_bar;
263 phi_bar = s * phi_bar;
264
265 let phi_rho = phi / rho;
267 let theta_rho = theta / rho;
268
269 for i in 0..n_total {
270 x[i] += phi_rho * w[i];
271 w[i] = v[i] - theta_rho * w[i];
272 }
273
274 let _rho_old = rho_val;
276 rho_val = (alpha_bar * alpha_bar + beta * beta).sqrt();
277 let c_lsmr = alpha_bar / rho_val;
278 let s_lsmr = beta / rho_val;
279 let theta_new = s_lsmr * alpha;
280 alpha_bar = c_lsmr * alpha;
281
282 let _rho_bar_old = rho_bar_lsmr;
283 let zeta_old = zeta;
284 let theta_bar = s_bar * rho_val;
285 let rho_tmp = c_bar * rho_val;
286 rho_bar_lsmr = (rho_tmp * rho_tmp + theta_new * theta_new).sqrt();
287 c_bar = rho_tmp / rho_bar_lsmr;
288 s_bar = theta_new / rho_bar_lsmr;
289 zeta = c_bar * zeta_bar;
290 zeta_bar = -s_bar * zeta_bar;
291
292 let beta_hat = c_lsmr * beta_dd;
294 beta_dd = -s_lsmr * beta_dd;
295
296 let theta_tilde_old = theta_tilde;
297 let rho_tilde_old = (rho_d_old * rho_d_old + theta_bar * theta_bar).sqrt();
298 let c_tilde_old = rho_d_old / rho_tilde_old;
299 let s_tilde_old = theta_bar / rho_tilde_old;
300 theta_tilde = s_tilde_old * rho_bar_lsmr;
301 rho_d_old = c_tilde_old * rho_bar_lsmr;
302 beta_d = -s_tilde_old * beta_d + c_tilde_old * beta_hat;
303
304 tau_tilde_old = (zeta_old - theta_tilde_old * tau_tilde_old) / rho_tilde_old;
305 let tau_d = (zeta - theta_tilde * tau_tilde_old) / rho_d_old;
306
307 let norm_r = (d_accum + (beta_d - tau_d).powi(2) + beta_dd.powi(2)).sqrt();
308
309 norm_a2 += beta * beta;
311 let norm_a = norm_a2.sqrt();
312 norm_a2 += alpha * alpha;
313
314 let norm_ar = zeta_bar.abs();
316
317 let norm_x = vec_norm(&x);
319 let test1 = norm_r / norm_b;
320 let test2 = if norm_a * norm_r > 0.0 { norm_ar / (norm_a * norm_r) } else { 0.0 };
321 let eps_r = tol + tol * norm_a * norm_x / norm_b;
322
323 if test1 <= eps_r || test2 <= tol {
324 progress_callback(iter + 1, iter + 1);
325 break;
326 }
327 }
328
329 let mut bg_source: Vec<Complex64> = x.iter()
331 .zip(bg_mask.iter())
332 .map(|(&xi, &m)| Complex64::new(xi * m, 0.0))
333 .collect();
334
335 fft3d(&mut bg_source, nx, ny, nz);
336
337 for i in 0..n_total {
338 bg_source[i] *= d_kernel[i];
339 }
340
341 ifft3d(&mut bg_source, nx, ny, nz);
342
343 let mut local_field = vec![0.0; n_total];
345 for i in 0..n_total {
346 if mask[i] != 0 {
347 local_field[i] = field[i] - bg_source[i].re;
348 }
349 }
350
351 local_field
352}
353
354pub fn pdf_default(
356 field: &[f64],
357 mask: &[u8],
358 nx: usize, ny: usize, nz: usize,
359 vsx: f64, vsy: f64, vsz: f64,
360) -> Vec<f64> {
361 let max_iter = ((nx * ny * nz) as f64).sqrt().ceil() as usize;
362 pdf(
363 field, mask, nx, ny, nz, vsx, vsy, vsz,
364 (0.0, 0.0, 1.0), 1e-5, max_iter )
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Linear ramp `field[i] = i * 0.001` over an `n * n * n` grid.
    fn ramp_field(n: usize) -> Vec<f64> {
        (0..n * n * n).map(|i| (i as f64) * 0.001).collect()
    }

    /// Mask set to 1 inside a sphere of `radius` voxels centred at `n / 2`
    /// on every axis, 0 elsewhere. Index layout is `x + y * n + z * n * n`;
    /// because the sphere is symmetric in all three axes, the set of marked
    /// indices is the same for any axis ordering.
    fn sphere_mask(n: usize, radius: usize) -> Vec<u8> {
        let center = (n / 2) as i32;
        let radius_sq = (radius * radius) as i32;
        let mut mask = vec![0u8; n * n * n];
        for z in 0..n {
            for y in 0..n {
                for x in 0..n {
                    let dx = (x as i32) - center;
                    let dy = (y as i32) - center;
                    let dz = (z as i32) - center;
                    if dx * dx + dy * dy + dz * dz <= radius_sq {
                        mask[x + y * n + z * n * n] = 1;
                    }
                }
            }
        }
        mask
    }

    #[test]
    fn test_pdf_zero_field() {
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-5, 10,
        );

        for &val in local.iter() {
            assert!(val.abs() < 1e-10, "Zero field should give zero local field");
        }
    }

    #[test]
    fn test_pdf_finite() {
        let n = 8;
        let field = ramp_field(n);
        let mask = sphere_mask(n, n / 4);

        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-5, 20,
        );

        for (i, &val) in local.iter().enumerate() {
            assert!(val.is_finite(), "Local field should be finite at index {}", i);
        }
    }

    #[test]
    fn test_pdf_mask() {
        let n = 8;
        let field = ramp_field(n);
        let mut mask = vec![1u8; n * n * n];
        mask[0] = 0;
        mask[10] = 0;

        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-5, 10,
        );

        assert_eq!(local[0], 0.0, "Masked voxel should be zero");
        assert_eq!(local[10], 0.0, "Masked voxel should be zero");
    }

    #[test]
    fn test_pdf_nonuniform_voxels() {
        let n = 8;
        let field = ramp_field(n);
        let mask = sphere_mask(n, n / 4);

        let local = pdf(
            &field, &mask, n, n, n, 0.5, 1.0, 2.0,
            (0.0, 0.0, 1.0), 1e-5, 20,
        );

        for (i, &val) in local.iter().enumerate() {
            assert!(val.is_finite(), "PDF with nonuniform voxels should be finite at index {}", i);
        }

        for i in 0..n * n * n {
            if mask[i] == 0 {
                assert_eq!(local[i], 0.0, "Outside mask should be zero");
            }
        }
    }

    #[test]
    fn test_pdf_varying_field() {
        let n = 8;

        // Field that varies quadratically along z only.
        let mut field = vec![0.0; n * n * n];
        for z in 0..n {
            for y in 0..n {
                for x in 0..n {
                    let idx = x + y * n + z * n * n;
                    let zf = (z as f64) / (n as f64);
                    field[idx] = zf * zf * 0.5;
                }
            }
        }

        let mask = sphere_mask(n, n / 4);

        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-5, 30,
        );

        for (i, &val) in local.iter().enumerate() {
            assert!(val.is_finite(), "Varying field should produce finite results at index {}", i);
        }

        let any_changed = (0..n * n * n)
            .any(|i| mask[i] != 0 && (local[i] - field[i]).abs() > 1e-10);
        assert!(any_changed, "PDF should modify the field inside the mask for a varying field");
    }

    #[test]
    fn test_pdf_larger_volume() {
        let n = 16;

        // Field of the form (3 cos^2(theta) - 1) / r^3 around the grid centre.
        let mut field = vec![0.0; n * n * n];
        let center = n / 2;
        for z in 0..n {
            for y in 0..n {
                for x in 0..n {
                    let dx = (x as f64) - (center as f64);
                    let dy = (y as f64) - (center as f64);
                    let dz = (z as f64) - (center as f64);
                    let r2 = dx * dx + dy * dy + dz * dz;
                    if r2 > 0.5 {
                        let r = r2.sqrt();
                        let cos_theta = dz / r;
                        field[x + y * n + z * n * n] =
                            (3.0 * cos_theta * cos_theta - 1.0) / (r * r * r) * 0.01;
                    }
                }
            }
        }

        let mask = sphere_mask(n, n / 3);

        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-4, 50,
        );

        assert_eq!(local.len(), n * n * n);
        for (i, &val) in local.iter().enumerate() {
            assert!(val.is_finite(), "PDF larger volume: finite at index {}", i);
        }

        for i in 0..n * n * n {
            if mask[i] == 0 {
                assert_eq!(local[i], 0.0);
            }
        }
    }

    #[test]
    fn test_pdf_more_iterations() {
        let n = 8;
        let field = ramp_field(n);
        let mask = sphere_mask(n, n / 4);

        let local_many = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-8, 100,
        );

        let local_few = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-8, 5,
        );

        for &val in local_many.iter().chain(local_few.iter()) {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_pdf_different_bdir() {
        let n = 8;
        let field = ramp_field(n);
        let mask = sphere_mask(n, n / 4);

        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.1, 0.2, 0.97), 1e-5, 20,
        );

        for &val in &local {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_pdf_with_progress() {
        let n = 8;
        let field = ramp_field(n);
        let mask = sphere_mask(n, n / 4);

        let mut progress_calls = Vec::new();
        let local = pdf_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-5, 20,
            |iter, max| { progress_calls.push((iter, max)); },
        );

        assert_eq!(local.len(), n * n * n);
        assert!(!progress_calls.is_empty(), "Progress should be called at least once");
        for &val in &local {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_pdf_default_wrapper() {
        let n = 8;
        let field = ramp_field(n);
        let mask = sphere_mask(n, n / 4);

        let local = pdf_default(&field, &mask, n, n, n, 1.0, 1.0, 1.0);
        assert_eq!(local.len(), n * n * n);
        for &val in &local {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_pdf_all_mask() {
        let n = 8;
        let field = ramp_field(n);
        let mask = vec![1u8; n * n * n];

        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-5, 10,
        );

        for &val in &local {
            assert!(val.is_finite());
        }
    }
}