Skip to main content

qsm_core/
region_grow.rs

1use crate::priority_queue::BucketQueue;
2use std::f64::consts::PI;
3
/// The six face-adjacent neighbor steps of a voxel, as `(dim, di, dj, dk)`.
///
/// `dim` is the dimension used to look up the weight of the edge crossed by
/// the step (0 = x edges, 1 = y edges, 2 = z edges); `(di, dj, dk)` is the
/// coordinate delta. Each dimension lists its positive step before its
/// negative step, and that order is the order in which neighbors are visited.
const NEIGHBOR_OFFSETS: [(usize, i32, i32, i32); 6] = [
    // x edges
    (0, 1, 0, 0),
    (0, -1, 0, 0),
    // y edges
    (1, 0, 1, 0),
    (1, 0, -1, 0),
    // z edges
    (2, 0, 0, 1),
    (2, 0, 0, -1),
];

/// One full phase cycle (2π); wrapped phase values differ from their true
/// value by an integer multiple of this.
const TWO_PI: f64 = PI * 2.0;
16
/// Convert a 3D index to a flat index (Fortran order / column-major, matches NIfTI).
///
/// For an array of shape `(nx, ny, nz)`, index `[i, j, k]` maps to
/// `i + j*nx + k*nx*ny`, so `i` varies fastest. `nz` does not enter the
/// formula and is therefore not a parameter.
///
/// The caller must pass `i < nx` and `j < ny`. An out-of-range `i` or `j`
/// does not fail on its own: it yields the flat index of a *different* voxel.
/// This precondition is checked in debug builds only.
#[inline(always)]
fn idx3d(i: usize, j: usize, k: usize, nx: usize, ny: usize) -> usize {
    debug_assert!(
        i < nx && j < ny,
        "idx3d: in-plane index ({}, {}) outside {} x {}",
        i,
        j,
        nx,
        ny
    );
    i + j * nx + k * nx * ny
}
23
/// Convert a 4D index (for the weights array) to a flat index.
///
/// The weights buffer is addressed as three consecutive volumes of
/// `nx*ny*nz` elements, one per edge dimension (`dim` = 0, 1, 2 for x, y, z).
/// Inside each volume the layout is the same column-major layout used for
/// 3D volumes, so `[dim, i, j, k]` maps to
/// `i + j*nx + k*nx*ny + dim*nx*ny*nz`: `i` varies fastest and `dim` slowest.
///
/// The caller must pass `dim < 3`, `i < nx`, `j < ny` and `k < nz`. Violating
/// this yields the flat index of a different element rather than an error.
/// This precondition is checked in debug builds only.
#[inline(always)]
fn idx4d(dim: usize, i: usize, j: usize, k: usize, nx: usize, ny: usize, nz: usize) -> usize {
    debug_assert!(
        dim < 3 && i < nx && j < ny && k < nz,
        "idx4d: index ({}, {}, {}, {}) outside 3 x {} x {} x {}",
        dim,
        i,
        j,
        k,
        nx,
        ny,
        nz
    );
    i + j * nx + k * nx * ny + dim * nx * ny * nz
}
31
32/// Queue item: (target_i, target_j, target_k, ref_i, ref_j, ref_k)
33/// Stores both the target voxel to unwrap AND the reference voxel to use
34type QueueItem = (usize, usize, usize, usize, usize, usize);
35
36/// Region growing phase unwrapping (matches Python implementation exactly)
37///
38/// # Arguments
39/// * `phase` - Mutable slice of phase values (nx * ny * nz), will be modified in-place
40/// * `weights` - Weight values (3 * nx * ny * nz), layout [dim][x][y][z]
41/// * `mask` - Boolean mask (nx * ny * nz), 1 = process, 0 = skip
42/// * `nx`, `ny`, `nz` - Array dimensions
43/// * `seed_i`, `seed_j`, `seed_k` - Seed point coordinates
44///
45/// # Returns
46/// Number of voxels processed
47pub fn grow_region_unwrap(
48    phase: &mut [f64],
49    weights: &[u8],
50    mask: &mut [u8],  // Used as visited array (modified in-place)
51    nx: usize,
52    ny: usize,
53    nz: usize,
54    seed_i: usize,
55    seed_j: usize,
56    seed_k: usize,
57) -> usize {
58    let mut pq: BucketQueue<QueueItem> = BucketQueue::new(256);
59    let mut processed = 0usize;
60
61    // Mark seed as visited
62    let seed_idx = idx3d(seed_i, seed_j, seed_k, nx, ny);
63
64    // If mask[seed] is 0, seed is not in ROI - find alternative
65    if mask[seed_idx] == 0 {
66        return 0;
67    }
68
69    // Use mask as visited array: 2 = visited, 1 = in ROI but not visited, 0 = not in ROI
70    mask[seed_idx] = 2;
71    processed += 1;
72
73    // Add initial edges from seed (matching Python's queue item format)
74    for &(dim, di, dj, dk) in &NEIGHBOR_OFFSETS {
75        let ni = seed_i as i32 + di;
76        let nj = seed_j as i32 + dj;
77        let nk = seed_k as i32 + dk;
78
79        if ni >= 0 && ni < nx as i32 && nj >= 0 && nj < ny as i32 && nk >= 0 && nk < nz as i32 {
80            let ni = ni as usize;
81            let nj = nj as usize;
82            let nk = nk as usize;
83            let n_idx = idx3d(ni, nj, nk, nx, ny);
84
85            // Check if neighbor is in mask and not visited
86            if mask[n_idx] == 1 {
87                // Get weight at edge (min coordinates)
88                let ei = seed_i.min(ni);
89                let ej = seed_j.min(nj);
90                let ek = seed_k.min(nk);
91                let weight = weights[idx4d(dim, ei, ej, ek, nx, ny, nz)] as usize;
92
93                if weight > 0 {
94                    // Store (target, reference) coordinates like Python
95                    pq.push(weight, (ni, nj, nk, seed_i, seed_j, seed_k));
96                }
97            }
98        }
99    }
100
101    // Main region growing loop
102    while let Some((ni, nj, nk, oi, oj, ok)) = pq.pop() {
103        let n_idx = idx3d(ni, nj, nk, nx, ny);
104
105        // Skip if already visited
106        if mask[n_idx] != 1 {
107            continue;
108        }
109
110        // Unwrap using stored reference (exact Python match)
111        let new_val = phase[n_idx];
112        let old_val = phase[idx3d(oi, oj, ok, nx, ny)];
113
114        // Unwrap: new_val - 2π * round((new_val - old_val) / 2π)
115        let diff = new_val - old_val;
116        let n_wraps = (diff / TWO_PI).round();
117        phase[n_idx] = new_val - TWO_PI * n_wraps;
118
119        // Mark as visited
120        mask[n_idx] = 2;
121        processed += 1;
122
123        // Add new edges to unvisited neighbors
124        for &(dim, di, dj, dk) in &NEIGHBOR_OFFSETS {
125            let nni = ni as i32 + di;
126            let nnj = nj as i32 + dj;
127            let nnk = nk as i32 + dk;
128
129            if nni >= 0 && nni < nx as i32 && nnj >= 0 && nnj < ny as i32 && nnk >= 0 && nnk < nz as i32 {
130                let nni = nni as usize;
131                let nnj = nnj as usize;
132                let nnk = nnk as usize;
133                let nn_idx = idx3d(nni, nnj, nnk, nx, ny);
134
135                // Only add if in mask and not visited
136                if mask[nn_idx] == 1 {
137                    let ei = ni.min(nni);
138                    let ej = nj.min(nnj);
139                    let ek = nk.min(nnk);
140                    let weight = weights[idx4d(dim, ei, ej, ek, nx, ny, nz)] as usize;
141
142                    if weight > 0 {
143                        // Store current voxel as reference for neighbor
144                        pq.push(weight, (nni, nnj, nnk, ni, nj, nk));
145                    }
146                }
147            }
148        }
149    }
150
151    processed
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Grows over a fully masked 3x3x3 volume with uniform weights and checks
    /// that the single wrapped voxel is restored while the rest is untouched.
    #[test]
    fn test_simple_unwrap() {
        let (nx, ny, nz) = (3, 3, 3);

        // Smooth ramp along x at j = k = 0, with the last sample wrapped by -2π.
        let mut phase = vec![0.0f64; nx * ny * nz];
        phase[idx3d(0, 0, 0, nx, ny)] = 0.0;
        phase[idx3d(1, 0, 0, nx, ny)] = 0.1;
        phase[idx3d(2, 0, 0, nx, ny)] = 0.2 - TWO_PI; // Wrapped value

        // All weights = 255, so every edge may be crossed.
        let weights = vec![255u8; 3 * nx * ny * nz];

        // All voxels in mask (1 = in ROI, not visited)
        let mut mask = vec![1u8; nx * ny * nz];

        // Unwrap from center
        let processed = grow_region_unwrap(&mut phase, &weights, &mut mask, nx, ny, nz, 1, 1, 1);

        // The volume is fully connected, so every voxel must be reached and marked visited.
        assert_eq!(processed, nx * ny * nz);
        assert!(mask.iter().all(|&m| m == 2), "every voxel should be marked visited");

        // The wrapped value is shifted back by exactly one cycle.
        let unwrapped_val = phase[idx3d(2, 0, 0, nx, ny)];
        assert!((unwrapped_val - 0.2).abs() < 1e-9, "Expected ~0.2, got {}", unwrapped_val);

        // Values that were never wrapped must be left alone.
        assert!((phase[idx3d(1, 0, 0, nx, ny)] - 0.1).abs() < 1e-12);
        assert!(phase[idx3d(0, 0, 0, nx, ny)].abs() < 1e-12);
    }

    /// A seed whose mask value is 0 must stop immediately without side effects.
    #[test]
    fn test_seed_outside_mask_is_noop() {
        let (nx, ny, nz) = (2, 2, 2);

        let mut phase = vec![0.25f64; nx * ny * nz];
        let weights = vec![255u8; 3 * nx * ny * nz];
        let mut mask = vec![1u8; nx * ny * nz];
        mask[idx3d(0, 0, 0, nx, ny)] = 0;
        let mask_before = mask.clone();

        let processed = grow_region_unwrap(&mut phase, &weights, &mut mask, nx, ny, nz, 0, 0, 0);

        assert_eq!(processed, 0);
        assert_eq!(mask, mask_before);
        assert!(phase.iter().all(|&p| p == 0.25));
    }
}