Skip to main content

qsm_core/bgremove/
lbv.rs

//! Laplacian Boundary Value (LBV) background field removal
//!
//! LBV removes background fields by solving the Laplace equation inside the mask
//! with Dirichlet boundary conditions from the total field at the mask boundary.
//!
//! The method exploits that background fields satisfy ∇²b = 0 inside the ROI.
//!
//! Reference:
//! Zhou, D., Liu, T., Spincemaille, P., Wang, Y. (2014).
//! "Background field removal by solving the Laplacian boundary value problem."
//! NMR in Biomedicine, 27(3):312-319. https://doi.org/10.1002/nbm.3064
//!
//! Reference implementation: https://github.com/kamesy/QSM.jl

/// LBV algorithm parameters
#[derive(Clone, Debug)]
pub struct LbvParams {
    /// Convergence tolerance
    pub tol: f64,
}

impl Default for LbvParams {
    /// Returns the default parameters (`tol = 1e-6`).
    fn default() -> Self {
        Self { tol: 1e-6 }
    }
}

/// Collects the linear indices of all interior voxels, in ascending order.
///
/// A voxel is interior when it does not lie on the outer face of the volume,
/// its own mask value is non-zero, and all six face-connected neighbours have
/// a non-zero mask value. Indices use the layout `x + y * nx + z * nx * ny`.
///
/// Volumes with any dimension smaller than 3 have no interior voxels; the
/// saturating subtraction keeps zero-sized dimensions from underflowing.
fn lbv_interior_indices(mask: &[u8], nx: usize, ny: usize, nz: usize) -> Vec<usize> {
    let stride_y = nx;
    let stride_z = nx * ny;
    let mut interior = Vec::new();

    for z in 1..nz.saturating_sub(1) {
        for y in 1..ny.saturating_sub(1) {
            for x in 1..nx.saturating_sub(1) {
                let idx = x + y * stride_y + z * stride_z;

                // x, y, z are all >= 1 and <= dim - 2 here, so every
                // neighbour index below is inside the volume.
                let all_inside = mask[idx] != 0
                    && mask[idx - 1] != 0
                    && mask[idx + 1] != 0
                    && mask[idx - stride_y] != 0
                    && mask[idx + stride_y] != 0
                    && mask[idx - stride_z] != 0
                    && mask[idx + stride_z] != 0;

                if all_inside {
                    interior.push(idx);
                }
            }
        }
    }

    interior
}

/// LBV background field removal
///
/// Solves ∇²b = 0 inside mask with b = f on boundary to find background field,
/// then computes local field as l = f - b.
///
/// The Laplace equation is solved on the interior voxels (masked voxels whose
/// six face neighbours are all masked and which do not touch the volume edge)
/// with Gauss-Seidel sweeps and over-relaxation (ω = 1.5). All other voxels
/// keep their total-field value and act as fixed Dirichlet data.
///
/// # Arguments
/// * `field` - Total field (nx * ny * nz), x varying fastest
/// * `mask` - Binary mask (nx * ny * nz), non-zero = brain, 0 = background
/// * `nx`, `ny`, `nz` - Array dimensions
/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm
/// * `tol` - Convergence tolerance for iterative solver. Iteration stops once
///   the largest per-voxel update of a sweep is below
///   `tol * max(max|field|, 1.0)`. Because of the floor at 1.0 the criterion
///   is relative only for fields whose magnitude exceeds 1; for smaller
///   fields it is an absolute tolerance.
/// * `max_iter` - Maximum iterations
///
/// # Returns
/// Tuple of (local_field, eroded_mask). Both have length `nx * ny * nz`.
/// `eroded_mask` is 1 on interior voxels and 0 elsewhere; `local_field` is
/// zero outside the eroded mask.
///
/// # Panics
/// Panics if `field` or `mask` holds fewer than `nx * ny * nz` elements.
pub fn lbv(
    field: &[f64],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
    tol: f64,
    max_iter: usize,
) -> (Vec<f64>, Vec<u8>) {
    // Same solver as the progress-reporting variant; the callback is a no-op.
    lbv_with_progress(
        field, mask, nx, ny, nz, vsx, vsy, vsz, tol, max_iter,
        |_, _| {},
    )
}

/// LBV with progress callback
///
/// Identical to [`lbv`], but reports progress through `progress_callback`,
/// which receives `(current, total)`:
///
/// * `(iter, max_iter)` before every 10th sweep (iter = 0, 10, 20, ...),
/// * `(iter + 1, iter + 1)` once, right after the sweep that converged,
/// * `(max_iter, max_iter)` always, after the iteration loop has ended
///   (also when convergence was already reported).
///
/// See [`lbv`] for the meaning of the remaining arguments, the return value
/// and the panic conditions.
pub fn lbv_with_progress<F>(
    field: &[f64],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
    tol: f64,
    max_iter: usize,
    mut progress_callback: F,
) -> (Vec<f64>, Vec<u8>)
where
    F: FnMut(usize, usize),
{
    let n_total = nx * ny * nz;
    assert!(
        field.len() >= n_total,
        "lbv: field has {} elements, expected at least nx * ny * nz = {}",
        field.len(),
        n_total
    );
    assert!(
        mask.len() >= n_total,
        "lbv: mask has {} elements, expected at least nx * ny * nz = {}",
        mask.len(),
        n_total
    );

    // Inverse squared voxel sizes weight the second differences of the
    // 7-point Laplacian stencil; `diag` is the stencil's centre coefficient.
    let dx2_inv = 1.0 / (vsx * vsx);
    let dy2_inv = 1.0 / (vsy * vsy);
    let dz2_inv = 1.0 / (vsz * vsz);
    let diag = -2.0 * (dx2_inv + dy2_inv + dz2_inv);

    let stride_y = nx;
    let stride_z = nx * ny;

    // Ascending index order reproduces a z-outer / y / x-inner sweep, which
    // matters because Gauss-Seidel results depend on the update order.
    let interior = lbv_interior_indices(mask, nx, ny, nz);

    // Start from the total field: non-interior voxels are never modified and
    // therefore provide the Dirichlet boundary values.
    let mut bg_field = field.to_vec();

    // Scale the tolerance by the field magnitude, floored at 1.0.
    let field_scale = field.iter()
        .map(|&v| v.abs())
        .fold(0.0f64, f64::max)
        .max(1.0);
    let scaled_tol = tol * field_scale;

    let omega = 1.5; // Over-relaxation parameter

    for iter in 0..max_iter {
        if iter % 10 == 0 {
            progress_callback(iter, max_iter);
        }

        let mut max_change = 0.0f64;

        for &idx in &interior {
            // Weighted sum of the six neighbours (off-diagonal stencil part).
            let sum = dx2_inv * (bg_field[idx - 1] + bg_field[idx + 1])
                    + dy2_inv * (bg_field[idx - stride_y] + bg_field[idx + stride_y])
                    + dz2_inv * (bg_field[idx - stride_z] + bg_field[idx + stride_z]);

            // Gauss-Seidel value: solve diag * b_new = -sum
            let new_val = -sum / diag;

            // SOR: move past the Gauss-Seidel value by the factor omega.
            let old_val = bg_field[idx];
            let updated = old_val + omega * (new_val - old_val);

            max_change = max_change.max((updated - old_val).abs());
            bg_field[idx] = updated;
        }

        if max_change < scaled_tol {
            progress_callback(iter + 1, iter + 1);
            break;
        }
    }

    progress_callback(max_iter, max_iter);

    // Local field = total field - background field, on interior voxels only.
    let mut local_field = vec![0.0; n_total];
    let mut eroded_mask = vec![0u8; n_total];

    for &idx in &interior {
        local_field[idx] = field[idx] - bg_field[idx];
        eroded_mask[idx] = 1;
    }

    (local_field, eroded_mask)
}
302
303/// LBV with default parameters
304///
305/// Note: QSM.jl uses multigrid-preconditioned CG which converges in ~max(dims)
306/// iterations. Our Gauss-Seidel SOR needs more iterations, so we use 3*max(dims)
307/// with a floor of 500.
308pub fn lbv_default(
309    field: &[f64],
310    mask: &[u8],
311    nx: usize, ny: usize, nz: usize,
312    vsx: f64, vsy: f64, vsz: f64,
313) -> (Vec<f64>, Vec<u8>) {
314    let max_iter = (3 * nx.max(ny).max(nz)).max(500);
315    lbv(field, mask, nx, ny, nz, vsx, vsy, vsz, 1e-6, max_iter)
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Field on an n×n×n grid that grows linearly with z (0.1 per slice).
    /// A linear field is harmonic, so LBV should remove it entirely.
    fn linear_z_field(n: usize) -> Vec<f64> {
        (0..n * n * n)
            .map(|i| ((i / (n * n)) as f64) * 0.1)
            .collect()
    }

    /// Spherical mask centred at `n / 2` with radius `n / 3` (integer division).
    fn sphere_mask(n: usize) -> Vec<u8> {
        let center = (n / 2) as i32;
        let radius = (n / 3) as i32;
        let mut mask = vec![0u8; n * n * n];

        for z in 0..n {
            for y in 0..n {
                for x in 0..n {
                    let dx = (x as i32) - center;
                    let dy = (y as i32) - center;
                    let dz = (z as i32) - center;
                    if dx * dx + dy * dy + dz * dz <= radius * radius {
                        mask[x + y * n + z * n * n] = 1;
                    }
                }
            }
        }

        mask
    }

    /// Number of voxels set in a 0/1 mask.
    fn voxel_count(mask: &[u8]) -> usize {
        mask.iter().map(|&m| m as usize).sum()
    }

    /// Largest absolute value among the voxels where `mask` is non-zero.
    fn max_abs_within(values: &[f64], mask: &[u8]) -> f64 {
        values.iter()
            .zip(mask.iter())
            .filter(|(_, &m)| m != 0)
            .map(|(&v, _)| v.abs())
            .fold(0.0f64, f64::max)
    }

    /// Fails with the offending index if any value is NaN or infinite.
    fn assert_all_finite(values: &[f64], context: &str) {
        for (i, &val) in values.iter().enumerate() {
            assert!(val.is_finite(), "{}: non-finite value at index {}", context, i);
        }
    }

    #[test]
    fn test_lbv_zero_field() {
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        let (local, _) = lbv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 1e-6, 100);

        for &val in local.iter() {
            assert!(val.abs() < 1e-10, "Zero field should give zero local field");
        }
    }

    #[test]
    fn test_lbv_finite() {
        let n = 16;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();
        let mask = sphere_mask(n);

        let (local, eroded_mask) = lbv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 1e-5, 100);

        assert_all_finite(&local, "Local field should be finite");

        // A sphere always has surface voxels, so erosion must strictly shrink it.
        let eroded_count = voxel_count(&eroded_mask);
        let mask_count = voxel_count(&mask);
        assert!(eroded_count < mask_count, "Eroded mask should be smaller than original mask");
        assert!(eroded_count > 0, "Eroded mask should not be empty for reasonable-sized input");
    }

    #[test]
    fn test_lbv_harmonic_removal() {
        // A field linear in z is harmonic, and the 7-point stencil reproduces
        // that exactly, so only floating-point round-off may remain.
        let n = 16;
        let field = linear_z_field(n);
        let mask = sphere_mask(n);

        let (local, eroded_mask) = lbv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 1e-6, 500);

        assert!(voxel_count(&eroded_mask) > 0, "Eroded mask should not be empty");
        let max_local = max_abs_within(&local, &eroded_mask);
        assert!(max_local < 1e-6, "Harmonic field should be removed, got max {}", max_local);
    }

    #[test]
    fn test_lbv_nonuniform_voxels() {
        let n = 16;
        let field = linear_z_field(n);
        let mask = sphere_mask(n);

        // Use anisotropic voxel sizes
        let (local, eroded_mask) = lbv(
            &field, &mask, n, n, n, 0.5, 1.0, 2.0, 1e-5, 200
        );

        assert_all_finite(&local, "LBV nonuniform voxels");

        let eroded_count = voxel_count(&eroded_mask);
        assert!(eroded_count > 0, "LBV nonuniform: eroded mask should not be empty");
    }

    #[test]
    fn test_lbv_tolerance() {
        let n = 16;
        let field = linear_z_field(n);
        let mask = sphere_mask(n);

        // Tight tolerance should produce at least as good harmonic removal
        let (local_tight, _) = lbv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 1e-8, 1000);

        // Loose tolerance
        let (local_loose, _) = lbv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 1e-2, 50);

        let max_tight = max_abs_within(&local_tight, &mask);
        let max_loose = max_abs_within(&local_loose, &mask);

        assert!(
            max_tight <= max_loose + 1e-6,
            "Tight tolerance max={} should be <= loose tolerance max={}",
            max_tight, max_loose
        );
    }

    #[test]
    fn test_lbv_with_progress() {
        let n = 16;
        let field = linear_z_field(n);
        let mask = sphere_mask(n);

        let mut progress_calls = Vec::new();
        let (local, eroded) = lbv_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0, 1e-5, 200,
            |iter, max| { progress_calls.push((iter, max)); }
        );

        assert_eq!(local.len(), n * n * n);
        assert!(!progress_calls.is_empty(), "Progress callback should be called");
        // Iteration 0 is always announced, and completion is always reported last.
        assert_eq!(progress_calls.first(), Some(&(0, 200)));
        assert_eq!(progress_calls.last(), Some(&(200, 200)));
        assert_all_finite(&local, "LBV with progress");
        assert!(voxel_count(&eroded) > 0);
    }

    #[test]
    fn test_lbv_default_wrapper() {
        let n = 16;
        let field = linear_z_field(n);
        let mask = sphere_mask(n);

        let (local, eroded) = lbv_default(&field, &mask, n, n, n, 1.0, 1.0, 1.0);

        assert_eq!(local.len(), n * n * n);
        assert_all_finite(&local, "LBV default wrapper");
        assert!(voxel_count(&eroded) > 0);
    }

    #[test]
    fn test_lbv_non_harmonic_field() {
        // A non-harmonic field should not be fully removed
        let n = 16;
        // Quadratic field: x^2 - not harmonic in 3D (non-zero Laplacian)
        let field: Vec<f64> = (0..n * n * n)
            .map(|i| {
                let xf = ((i % n) as f64) / (n as f64);
                xf * xf
            })
            .collect();
        let mask = sphere_mask(n);

        let (local, eroded_mask) = lbv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 1e-6, 500);

        assert_all_finite(&local, "LBV non-harmonic field");

        // Local field should have some non-zero values (non-harmonic part remains)
        let has_nonzero = max_abs_within(&local, &eroded_mask) > 1e-10;
        assert!(has_nonzero, "Non-harmonic field should leave non-zero local field");
    }

    #[test]
    fn test_lbv_small_mask() {
        // Test with a very small mask (just a few voxels) to exercise boundary logic
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.01).collect();

        // Small central mask: a 4x4x4 cube
        let mut mask = vec![0u8; n * n * n];
        for z in 2..6 {
            for y in 2..6 {
                for x in 2..6 {
                    mask[x + y * n + z * n * n] = 1;
                }
            }
        }

        let (local, eroded) = lbv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 1e-5, 100);

        assert_all_finite(&local, "LBV small mask");

        // Eroding a 4x4x4 cube by one voxel leaves its 2x2x2 core.
        let eroded_count = voxel_count(&eroded);
        let mask_count = voxel_count(&mask);
        assert!(eroded_count <= mask_count);
        assert_eq!(eroded_count, 2 * 2 * 2);
    }

    #[test]
    fn test_lbv_edge_mask_voxels() {
        // Mask that includes edge voxels of the volume
        let n = 8;
        let field = vec![1.0; n * n * n];
        let mask = vec![1u8; n * n * n]; // Full mask including edges

        let (local, eroded) = lbv(&field, &mask, n, n, n, 1.0, 1.0, 1.0, 1e-5, 100);

        assert_all_finite(&local, "LBV edge mask voxels");

        // Edge voxels should be boundary, not interior, so they should not be in eroded mask
        // Check corners
        assert_eq!(eroded[0], 0, "Corner should not be in eroded mask");
        assert_eq!(eroded[n - 1], 0, "Corner should not be in eroded mask");
        assert_eq!(eroded[n * n * n - 1], 0, "Corner should not be in eroded mask");

        // With a full mask exactly the (n - 2)^3 non-edge voxels are interior.
        let eroded_count = voxel_count(&eroded);
        assert!(eroded_count > 0, "Should have some interior voxels");
        assert_eq!(eroded_count, (n - 2) * (n - 2) * (n - 2));
    }
}