Skip to main content

qsm_core/bgremove/
pdf.rs

//! Projection onto Dipole Fields (PDF) background field removal
//!
//! Projects the field onto dipole fields generated by sources outside
//! the brain mask, separating background and local fields.
//!
//! Reference:
//! Liu, T., Khalidov, I., de Rochefort, L., Spincemaille, P., Liu, J., Tsiouris, A.J.,
//! Wang, Y. (2011). "A novel background field removal method for MRI using projection
//! onto dipole fields." NMR in Biomedicine, 24(9):1129-1136. https://doi.org/10.1002/nbm.1670
//!
//! Reference implementation: https://github.com/kamesy/QSM.jl
12
13use num_complex::Complex64;
14use crate::fft::{fft3d, ifft3d};
15use crate::kernels::dipole::dipole_kernel;
16
17#[cfg(feature = "parallel")]
18use crate::par::*;
19
20/// PDF background field removal
21///
22/// # Arguments
23/// * `field` - Total field (nx * ny * nz)
24/// * `mask` - Binary mask (nx * ny * nz), 1 = brain, 0 = background
25/// * `nx`, `ny`, `nz` - Array dimensions
26/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm
27/// * `bdir` - B0 field direction
28/// * `tol` - Convergence tolerance for LSMR
29/// * `max_iter` - Maximum iterations for LSMR
30///
31/// PDF algorithm parameters
32#[derive(Clone, Debug)]
33pub struct PdfParams {
34    /// Convergence tolerance
35    pub tol: f64,
36}
37
38impl Default for PdfParams {
39    fn default() -> Self {
40        Self { tol: 1e-5 }
41    }
42}
43
44/// # Returns
45/// Local field with background removed
46pub fn pdf(
47    field: &[f64],
48    mask: &[u8],
49    nx: usize, ny: usize, nz: usize,
50    vsx: f64, vsy: f64, vsz: f64,
51    bdir: (f64, f64, f64),
52    tol: f64,
53    max_iter: usize,
54) -> Vec<f64> {
55    pdf_with_progress(field, mask, nx, ny, nz, vsx, vsy, vsz, bdir, tol, max_iter, |_, _| {})
56}
57
58
59
60/// Apply A = W * D * M_bg
61fn apply_a(
62    x: &[f64],
63    bg_mask: &[f64],
64    d_kernel: &[f64],
65    brain_mask: &[f64],
66    nx: usize, ny: usize, nz: usize,
67) -> Vec<f64> {
68    let n_total = nx * ny * nz;
69
70    // Apply background mask
71    let mut temp: Vec<Complex64> = x.iter()
72        .zip(bg_mask.iter())
73        .map(|(&xi, &m)| Complex64::new(xi * m, 0.0))
74        .collect();
75
76    // FFT
77    fft3d(&mut temp, nx, ny, nz);
78
79    // Apply dipole kernel
80    for i in 0..n_total {
81        temp[i] *= d_kernel[i];
82    }
83
84    // IFFT
85    ifft3d(&mut temp, nx, ny, nz);
86
87    // Apply brain mask weights
88    temp.iter()
89        .zip(brain_mask.iter())
90        .map(|(t, &w)| t.re * w)
91        .collect()
92}
93
94/// Apply A^T = M_bg * D * W
95fn apply_at(
96    u: &[f64],
97    brain_mask: &[f64],
98    d_kernel: &[f64],
99    bg_mask: &[f64],
100    nx: usize, ny: usize, nz: usize,
101) -> Vec<f64> {
102    let n_total = nx * ny * nz;
103
104    // Apply brain mask weights
105    let mut temp: Vec<Complex64> = u.iter()
106        .zip(brain_mask.iter())
107        .map(|(&ui, &w)| Complex64::new(ui * w, 0.0))
108        .collect();
109
110    // FFT
111    fft3d(&mut temp, nx, ny, nz);
112
113    // Apply dipole kernel (D is real and symmetric)
114    for i in 0..n_total {
115        temp[i] *= d_kernel[i];
116    }
117
118    // IFFT
119    ifft3d(&mut temp, nx, ny, nz);
120
121    // Apply background mask
122    temp.iter()
123        .zip(bg_mask.iter())
124        .map(|(t, &m)| t.re * m)
125        .collect()
126}
127
/// Vector 2-norm
///
/// Returns the square root of the sum of squared entries of `v`
/// (zero for an empty slice).
fn vec_norm(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|&c| c * c).sum();
    sum_of_squares.sqrt()
}
132
133/// PDF with progress callback
134///
135/// Same as `pdf` but calls `progress_callback(iteration, max_iter)` each iteration.
136pub fn pdf_with_progress<F>(
137    field: &[f64],
138    mask: &[u8],
139    nx: usize, ny: usize, nz: usize,
140    vsx: f64, vsy: f64, vsz: f64,
141    bdir: (f64, f64, f64),
142    tol: f64,
143    max_iter: usize,
144    mut progress_callback: F,
145) -> Vec<f64>
146where
147    F: FnMut(usize, usize),
148{
149    let n_total = nx * ny * nz;
150
151    // Generate dipole kernel
152    let d_kernel = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
153
154    // Create background mask (complement of brain mask)
155    let bg_mask: Vec<f64> = mask.iter()
156        .map(|&m| if m == 0 { 1.0 } else { 0.0 })
157        .collect();
158
159    // Brain mask as f64
160    let brain_mask: Vec<f64> = mask.iter()
161        .map(|&m| if m != 0 { 1.0 } else { 0.0 })
162        .collect();
163
164    // RHS: b = W * f where W is brain mask weights
165    let b: Vec<f64> = field.iter()
166        .zip(brain_mask.iter())
167        .map(|(&f, &w)| f * w)
168        .collect();
169
170    // Initialize solution
171    let mut x = vec![0.0; n_total];
172
173    let mut u = b.clone();
174    let mut beta = vec_norm(&u);
175
176    if beta < 1e-20 {
177        return vec![0.0; n_total];
178    }
179
180    for i in 0..n_total {
181        u[i] /= beta;
182    }
183
184    let mut v = apply_at(&u, &brain_mask, &d_kernel, &bg_mask, nx, ny, nz);
185    let mut alpha = vec_norm(&v);
186
187    if alpha < 1e-20 {
188        return vec![0.0; n_total];
189    }
190
191    for i in 0..n_total {
192        v[i] /= alpha;
193    }
194
195    // LSMR variables (following Fong & Saunders 2011 / QSM.jl implementation)
196    let norm_b = beta;
197    let mut w = v.clone();
198    let mut phi_bar = beta;
199    let mut rho_bar = alpha;
200
201    // Variables for ||r|| estimation (matching Julia's lsmr.jl)
202    let mut beta_dd = beta;
203    let mut beta_d = 0.0;
204    let mut rho_d_old = 1.0;
205    let mut tau_tilde_old = 0.0;
206    let mut theta_tilde = 0.0;
207    let mut zeta = 0.0;
208    let d_accum = 0.0; // beta_check^2 accumulator; always 0 when lambda=0
209
210    // Variables for ||A|| estimation
211    let mut norm_a2 = alpha * alpha;
212
213    // Variables for the QR factorization (needed for ||r|| recurrence)
214    let mut zeta_bar = alpha * beta;
215    let mut alpha_bar = alpha;
216    let mut c_bar = 1.0;
217    let mut s_bar = 0.0;
218    let mut rho_val = 1.0;
219    let mut rho_bar_lsmr = 1.0;
220
221    for iter in 0..max_iter {
222        // Report progress
223        progress_callback(iter + 1, max_iter);
224
225        // Bidiagonalization
226        let av = apply_a(&v, &bg_mask, &d_kernel, &brain_mask, nx, ny, nz);
227        for i in 0..n_total {
228            u[i] = av[i] - alpha * u[i];
229        }
230        beta = vec_norm(&u);
231
232        if beta < 1e-20 {
233            progress_callback(iter + 1, iter + 1);
234            break;
235        }
236
237        for i in 0..n_total {
238            u[i] /= beta;
239        }
240
241        let atu = apply_at(&u, &brain_mask, &d_kernel, &bg_mask, nx, ny, nz);
242        for i in 0..n_total {
243            v[i] = atu[i] - beta * v[i];
244        }
245        alpha = vec_norm(&v);
246
247        if alpha < 1e-20 {
248            progress_callback(iter + 1, iter + 1);
249            break;
250        }
251
252        for i in 0..n_total {
253            v[i] /= alpha;
254        }
255
256        // Construct and apply rotation Q_i (for solution update)
257        let rho = (rho_bar * rho_bar + beta * beta).sqrt();
258        let c = rho_bar / rho;
259        let s = beta / rho;
260        let theta = s * alpha;
261        rho_bar = -c * alpha;
262        let phi = c * phi_bar;
263        phi_bar = s * phi_bar;
264
265        // Update x and w
266        let phi_rho = phi / rho;
267        let theta_rho = theta / rho;
268
269        for i in 0..n_total {
270            x[i] += phi_rho * w[i];
271            w[i] = v[i] - theta_rho * w[i];
272        }
273
274        // LSMR QR factorization (matching Julia lsmr.jl, with lambda=0)
275        let _rho_old = rho_val;
276        rho_val = (alpha_bar * alpha_bar + beta * beta).sqrt();
277        let c_lsmr = alpha_bar / rho_val;
278        let s_lsmr = beta / rho_val;
279        let theta_new = s_lsmr * alpha;
280        alpha_bar = c_lsmr * alpha;
281
282        let _rho_bar_old = rho_bar_lsmr;
283        let zeta_old = zeta;
284        let theta_bar = s_bar * rho_val;
285        let rho_tmp = c_bar * rho_val;
286        rho_bar_lsmr = (rho_tmp * rho_tmp + theta_new * theta_new).sqrt();
287        c_bar = rho_tmp / rho_bar_lsmr;
288        s_bar = theta_new / rho_bar_lsmr;
289        zeta = c_bar * zeta_bar;
290        zeta_bar = -s_bar * zeta_bar;
291
292        // Estimate ||r|| (matching Julia lsmr.jl lines 264-287, with lambda=0)
293        let beta_hat = c_lsmr * beta_dd;
294        beta_dd = -s_lsmr * beta_dd;
295
296        let theta_tilde_old = theta_tilde;
297        let rho_tilde_old = (rho_d_old * rho_d_old + theta_bar * theta_bar).sqrt();
298        let c_tilde_old = rho_d_old / rho_tilde_old;
299        let s_tilde_old = theta_bar / rho_tilde_old;
300        theta_tilde = s_tilde_old * rho_bar_lsmr;
301        rho_d_old = c_tilde_old * rho_bar_lsmr;
302        beta_d = -s_tilde_old * beta_d + c_tilde_old * beta_hat;
303
304        tau_tilde_old = (zeta_old - theta_tilde_old * tau_tilde_old) / rho_tilde_old;
305        let tau_d = (zeta - theta_tilde * tau_tilde_old) / rho_d_old;
306
307        let norm_r = (d_accum + (beta_d - tau_d).powi(2) + beta_dd.powi(2)).sqrt();
308
309        // Estimate ||A||
310        norm_a2 += beta * beta;
311        let norm_a = norm_a2.sqrt();
312        norm_a2 += alpha * alpha;
313
314        // ||A'r|| estimate
315        let norm_ar = zeta_bar.abs();
316
317        // Convergence tests (matching Julia lsmr.jl lines 308-318)
318        let norm_x = vec_norm(&x);
319        let test1 = norm_r / norm_b;
320        let test2 = if norm_a * norm_r > 0.0 { norm_ar / (norm_a * norm_r) } else { 0.0 };
321        let eps_r = tol + tol * norm_a * norm_x / norm_b;
322
323        if test1 <= eps_r || test2 <= tol {
324            progress_callback(iter + 1, iter + 1);
325            break;
326        }
327    }
328
329    // Compute background field: b_field = D * (M_bg * x)
330    let mut bg_source: Vec<Complex64> = x.iter()
331        .zip(bg_mask.iter())
332        .map(|(&xi, &m)| Complex64::new(xi * m, 0.0))
333        .collect();
334
335    fft3d(&mut bg_source, nx, ny, nz);
336
337    for i in 0..n_total {
338        bg_source[i] *= d_kernel[i];
339    }
340
341    ifft3d(&mut bg_source, nx, ny, nz);
342
343    // Local field = total field - background field, masked
344    let mut local_field = vec![0.0; n_total];
345    for i in 0..n_total {
346        if mask[i] != 0 {
347            local_field[i] = field[i] - bg_source[i].re;
348        }
349    }
350
351    local_field
352}
353
354/// PDF with default parameters (adaptive max_iter matching QSM.jl: ceil(sqrt(nx*ny*nz)))
355pub fn pdf_default(
356    field: &[f64],
357    mask: &[u8],
358    nx: usize, ny: usize, nz: usize,
359    vsx: f64, vsy: f64, vsz: f64,
360) -> Vec<f64> {
361    let max_iter = ((nx * ny * nz) as f64).sqrt().ceil() as usize;
362    pdf(
363        field, mask, nx, ny, nz, vsx, vsy, vsz,
364        (0.0, 0.0, 1.0),  // bdir
365        1e-5,              // tol
366        max_iter           // adaptive max_iter
367    )
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n * n * n` mask that is 1 inside a sphere of the given
    /// `radius` (in voxels) centred at `(n / 2, n / 2, n / 2)` and 0 elsewhere.
    /// Voxels are indexed as `x + y * n + z * n * n`.
    fn sphere_mask(n: usize, radius: usize) -> Vec<u8> {
        let center = (n / 2) as i32;
        let radius_sq = (radius * radius) as i32;
        let mut mask = vec![0u8; n * n * n];

        for z in 0..n {
            for y in 0..n {
                for x in 0..n {
                    let dx = (x as i32) - center;
                    let dy = (y as i32) - center;
                    let dz = (z as i32) - center;
                    if dx * dx + dy * dy + dz * dz <= radius_sq {
                        mask[x + y * n + z * n * n] = 1;
                    }
                }
            }
        }

        mask
    }

    /// Linear ramp over the flat voxel index: `field[i] = i * 0.001`.
    fn ramp_field(n: usize) -> Vec<f64> {
        (0..n * n * n).map(|i| (i as f64) * 0.001).collect()
    }

    #[test]
    fn test_pdf_zero_field() {
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-5, 10
        );

        for &val in local.iter() {
            assert!(val.abs() < 1e-10, "Zero field should give zero local field");
        }
    }

    #[test]
    fn test_pdf_finite() {
        let n = 8;
        let field = ramp_field(n);

        // Spherical mask in the center
        let mask = sphere_mask(n, n / 4);

        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-5, 20
        );

        for (i, &val) in local.iter().enumerate() {
            assert!(val.is_finite(), "Local field should be finite at index {}", i);
        }
    }

    #[test]
    fn test_pdf_mask() {
        let n = 8;
        let field = ramp_field(n);
        let mut mask = vec![1u8; n * n * n];
        mask[0] = 0;
        mask[10] = 0;

        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-5, 10
        );

        assert_eq!(local[0], 0.0, "Masked voxel should be zero");
        assert_eq!(local[10], 0.0, "Masked voxel should be zero");
    }

    #[test]
    fn test_pdf_nonuniform_voxels() {
        let n = 8;
        let field = ramp_field(n);

        // Spherical mask in the center
        let mask = sphere_mask(n, n / 4);

        // Anisotropic voxel sizes
        let local = pdf(
            &field, &mask, n, n, n, 0.5, 1.0, 2.0,
            (0.0, 0.0, 1.0), 1e-5, 20
        );

        for (i, &val) in local.iter().enumerate() {
            assert!(val.is_finite(), "PDF with nonuniform voxels should be finite at index {}", i);
        }

        // Masked voxels should be zero
        for (&m, &val) in mask.iter().zip(&local) {
            if m == 0 {
                assert_eq!(val, 0.0, "Outside mask should be zero");
            }
        }
    }

    #[test]
    fn test_pdf_varying_field() {
        let n = 8;

        // Spatially varying field (quadratic in z)
        let mut field = vec![0.0; n * n * n];
        for z in 0..n {
            let zf = (z as f64) / (n as f64);
            for y in 0..n {
                for x in 0..n {
                    field[x + y * n + z * n * n] = zf * zf * 0.5;
                }
            }
        }

        // Spherical mask
        let mask = sphere_mask(n, n / 4);

        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-5, 30
        );

        // All values should be finite
        for (i, &val) in local.iter().enumerate() {
            assert!(val.is_finite(), "Varying field should produce finite results at index {}", i);
        }

        // The local field inside the mask should differ from the total field
        // (some background was removed)
        let any_changed = mask
            .iter()
            .zip(local.iter().zip(&field))
            .any(|(&m, (&l, &f))| m != 0 && (l - f).abs() > 1e-10);
        assert!(any_changed, "PDF should modify the field inside the mask for a varying field");
    }

    #[test]
    fn test_pdf_larger_volume() {
        // Use 16x16x16 to exercise more of the solver loop
        let n = 16;

        // Dipole-like field pattern
        let mut field = vec![0.0; n * n * n];
        let center = n / 2;
        for z in 0..n {
            for y in 0..n {
                for x in 0..n {
                    let dx = (x as f64) - (center as f64);
                    let dy = (y as f64) - (center as f64);
                    let dz = (z as f64) - (center as f64);
                    let r2 = dx * dx + dy * dy + dz * dz;
                    if r2 > 0.5 {
                        // Dipole-like field: (3*cos^2(theta) - 1) / r^3
                        let r = r2.sqrt();
                        let cos_theta = dz / r;
                        field[x + y * n + z * n * n] = (3.0 * cos_theta * cos_theta - 1.0) / (r * r * r) * 0.01;
                    }
                }
            }
        }

        // Spherical mask
        let mask = sphere_mask(n, n / 3);

        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-4, 50
        );

        assert_eq!(local.len(), n * n * n);
        for (i, &val) in local.iter().enumerate() {
            assert!(val.is_finite(), "PDF larger volume: finite at index {}", i);
        }

        // Masked-out voxels must be zero
        for (&m, &val) in mask.iter().zip(&local) {
            if m == 0 {
                assert_eq!(val, 0.0);
            }
        }
    }

    #[test]
    fn test_pdf_more_iterations() {
        // Run with a generous and a tight iteration budget so that both the
        // convergence exit and the max_iter exit are exercised.
        let n = 8;
        let field = ramp_field(n);
        let mask = sphere_mask(n, n / 4);

        let local_many = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-8, 100
        );

        let local_few = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-8, 5
        );

        // Both should be finite (this test does not compare the two results)
        for &val in local_many.iter().chain(local_few.iter()) {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_pdf_different_bdir() {
        // Test with non-standard B0 direction
        let n = 8;
        let field = ramp_field(n);
        let mask = sphere_mask(n, n / 4);

        // Tilted B0 direction
        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.1, 0.2, 0.97), 1e-5, 20
        );

        for &val in &local {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_pdf_with_progress() {
        let n = 8;
        let field = ramp_field(n);
        let mask = sphere_mask(n, n / 4);

        let mut progress_calls = Vec::new();
        let local = pdf_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-5, 20,
            |iter, max| { progress_calls.push((iter, max)); }
        );

        assert_eq!(local.len(), n * n * n);
        assert!(!progress_calls.is_empty(), "Progress should be called at least once");
        for &val in &local {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_pdf_default_wrapper() {
        let n = 8;
        let field = ramp_field(n);
        let mask = sphere_mask(n, n / 4);

        let local = pdf_default(&field, &mask, n, n, n, 1.0, 1.0, 1.0);
        assert_eq!(local.len(), n * n * n);
        for &val in &local {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_pdf_all_mask() {
        // All voxels are brain, so the background mask is empty and there are
        // no dipole sources to project onto. This only checks that the
        // degenerate case still returns finite values.
        let n = 8;
        let field = ramp_field(n);
        let mask = vec![1u8; n * n * n];

        let local = pdf(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-5, 10
        );

        for &val in &local {
            assert!(val.is_finite());
        }
    }
}
744        for &val in &local {
745            assert!(val.is_finite());
746        }
747    }
748}