// qsm_core/unwrap/romeo.rs
//! ROMEO weight calculation for phase unwrapping
//!
//! Calculates edge weights for region-growing phase unwrapping based on:
//! - Phase coherence
//! - Phase gradient coherence (multi-echo)
//! - Magnitude coherence
//! - Magnitude weights
//!
//! Reference:
//! Dymerska, B., Eckstein, K., Bachrata, B., Siow, B., Trattnig, S., Shmueli, K.,
//! Robinson, S.D. (2021). "Phase unwrapping with a rapid opensource minimum spanning
//! tree algorithm (ROMEO)." Magnetic Resonance in Medicine, 85(4):2294-2308.
//! <https://doi.org/10.1002/mrm.28563>
//!
//! Reference implementation: <https://github.com/korbinian90/MriResearchTools.jl>
16
/// Parameters selecting which optional components enter the ROMEO edge weight.
///
/// Phase coherence is always used; these flags toggle the remaining terms. Each field
/// corresponds to the same-named `use_*` argument of
/// [`calculate_weights_romeo_configurable`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct RomeoParams {
    /// Use phase gradient coherence weights (multi-echo; only effective when a
    /// second-echo phase is supplied)
    pub phase_gradient_coherence: bool,
    /// Use magnitude coherence weights (min/max similarity of neighbouring magnitudes)
    pub mag_coherence: bool,
    /// Use magnitude weighting (penalizes low-signal voxels)
    pub mag_weight: bool,
}
27
28impl Default for RomeoParams {
29    fn default() -> Self {
30        Self {
31            phase_gradient_coherence: true,
32            mag_coherence: true,
33            mag_weight: true,
34        }
35    }
36}
37
38use std::f64::consts::PI;
39
/// One full turn in radians (2π): the period of the wrapped phase.
const TWO_PI: f64 = PI * 2.0;
41
/// Wrap an angle (radians) into the interval [-π, π].
///
/// The truncated remainder `%` first brings the angle into (-2π, 2π), keeping the sign
/// of the input; one full turn is then added or subtracted if the value still lies
/// outside [-π, π]. Both endpoints are reachable (π stays π, -π stays -π), and NaN
/// passes through unchanged.
#[inline]
fn wrap_angle(angle: f64) -> f64 {
    let full_turn = 2.0 * PI;
    match angle % full_turn {
        r if r > PI => r - full_turn,
        r if r < -PI => r + full_turn,
        r => r,
    }
}
53
/// Flatten a voxel coordinate `(i, j, k)` into a linear index using Fortran
/// (column-major) order, matching NIfTI: `i` varies fastest, then `j`, then `k`.
/// Only `nx` and `ny` enter the strides; `nz` is not needed.
#[inline(always)]
fn idx3d(i: usize, j: usize, k: usize, nx: usize, ny: usize) -> usize {
    // Horner form of i + nx*j + nx*ny*k.
    (k * ny + j) * nx + i
}
59
60/// Calculate ROMEO edge weights for phase unwrapping
61///
62/// Returns weights array with shape (3, nx, ny, nz) for 3 directions (x, y, z).
63/// Weights are normalized to 0-255 for use with the bucket priority queue.
64///
65/// # Arguments
66/// * `phase` - Wrapped phase data (nx * ny * nz), first echo
67/// * `mag` - Magnitude data (nx * ny * nz), optional (pass empty slice if none)
68/// * `phase2` - Second echo phase for gradient coherence (optional)
69/// * `te1`, `te2` - Echo times for gradient coherence scaling
70/// * `mask` - Binary mask (nx * ny * nz), 1 = process
71/// * `nx`, `ny`, `nz` - Array dimensions
72///
73/// # Returns
74/// Weights array of size 3 * nx * ny * nz in C order [dim][i][j][k]
75pub fn calculate_weights_romeo(
76    phase: &[f64],
77    mag: &[f64],
78    phase2: Option<&[f64]>,
79    te1: f64,
80    te2: f64,
81    mask: &[u8],
82    nx: usize, ny: usize, nz: usize,
83) -> Vec<u8> {
84    // Default: all weight components enabled
85    calculate_weights_romeo_configurable(
86        phase, mag, phase2, te1, te2, mask, nx, ny, nz,
87        true, true, true  // use_phase_gradient_coherence, use_mag_coherence, use_mag_weight
88    )
89}
90
91/// Calculate ROMEO edge weights with configurable weight components
92///
93/// # Arguments
94/// * `phase` - Wrapped phase data (nx * ny * nz), first echo
95/// * `mag` - Magnitude data (nx * ny * nz), optional (pass empty slice if none)
96/// * `phase2` - Second echo phase for gradient coherence (optional)
97/// * `te1`, `te2` - Echo times for gradient coherence scaling
98/// * `mask` - Binary mask (nx * ny * nz), 1 = process
99/// * `nx`, `ny`, `nz` - Array dimensions
100/// * `use_phase_gradient_coherence` - Include phase gradient coherence (multi-echo temporal)
101/// * `use_mag_coherence` - Include magnitude coherence (min/max similarity)
102/// * `use_mag_weight` - Include magnitude weight (penalize low signal)
103///
104/// # Returns
105/// Weights array of size 3 * nx * ny * nz in C order [dim][i][j][k]
106pub fn calculate_weights_romeo_configurable(
107    phase: &[f64],
108    mag: &[f64],
109    phase2: Option<&[f64]>,
110    te1: f64,
111    te2: f64,
112    mask: &[u8],
113    nx: usize, ny: usize, nz: usize,
114    use_phase_gradient_coherence: bool,
115    use_mag_coherence: bool,
116    use_mag_weight: bool,
117) -> Vec<u8> {
118    let n_total = nx * ny * nz;
119    let mut weights = vec![0u8; 3 * n_total];
120
121    let has_mag = !mag.is_empty();
122    let has_phase2 = phase2.is_some();
123    let te_ratio = if te2.abs() > 1e-10 { te1 / te2 } else { 1.0 };
124
125    // Get max magnitude for normalization
126    let max_mag = if has_mag && use_mag_weight {
127        mag.iter().cloned().fold(0.0_f64, f64::max)
128    } else {
129        1.0
130    };
131    let half_max_mag = 0.5 * max_mag + 1e-12;
132
133    // Process each direction
134    for dim in 0..3_usize {
135        for i in 0..nx {
136            for j in 0..ny {
137                for k in 0..nz {
138                    // Get neighbor based on dimension
139                    let (ni, nj, nk) = match dim {
140                        0 => (i + 1, j, k),  // x direction
141                        1 => (i, j + 1, k),  // y direction
142                        _ => (i, j, k + 1),  // z direction
143                    };
144
145                    // Skip if neighbor is out of bounds
146                    if ni >= nx || nj >= ny || nk >= nz {
147                        continue;
148                    }
149
150                    let idx = idx3d(i, j, k, nx, ny);
151                    let idx_n = idx3d(ni, nj, nk, nx, ny);
152
153                    // Skip if either voxel is outside mask
154                    if mask[idx] == 0 || mask[idx_n] == 0 {
155                        continue;
156                    }
157
158                    // Phase difference
159                    let p1_diff = phase[idx_n] - phase[idx];
160
161                    // 1. Phase coherence: 1 - |wrap(diff)| / π (always on)
162                    let pc = 1.0 - wrap_angle(p1_diff).abs() / PI;
163
164                    // 2. Phase gradient coherence (optional, multi-echo only)
165                    let pgc = if use_phase_gradient_coherence && has_phase2 {
166                        let phase2_data = phase2.unwrap();
167                        let p2_diff = phase2_data[idx_n] - phase2_data[idx];
168                        let wrapped_p1 = wrap_angle(p1_diff);
169                        let wrapped_p2 = wrap_angle(p2_diff);
170                        (1.0 - (wrapped_p1 - wrapped_p2 * te_ratio).abs()).max(0.0)
171                    } else {
172                        1.0
173                    };
174
175                    // 3. Magnitude coherence: (min/max)² (optional)
176                    let mc = if use_mag_coherence && has_mag {
177                        let m1 = mag[idx];
178                        let m2 = mag[idx_n];
179                        let mag_min = m1.min(m2);
180                        let mag_max = m1.max(m2);
181                        if mag_max > 1e-12 {
182                            (mag_min / mag_max).powi(2)
183                        } else {
184                            0.0
185                        }
186                    } else {
187                        1.0
188                    };
189
190                    // 4. Magnitude weights: 0.5 + 0.5 * min(1, mag / (0.5 * max_mag)) (optional)
191                    let (mw1, mw2) = if use_mag_weight && has_mag {
192                        let mw1 = 0.5 + 0.5 * (mag[idx] / half_max_mag).min(1.0);
193                        let mw2 = 0.5 + 0.5 * (mag[idx_n] / half_max_mag).min(1.0);
194                        (mw1, mw2)
195                    } else {
196                        (1.0, 1.0)
197                    };
198
199                    // Combined weight
200                    let weight = pc * pgc * mc * mw1 * mw2;
201
202                    // Convert to u8 and store
203                    let weight_u8 = (weight.clamp(0.0, 1.0) * 255.0) as u8;
204
205                    // Store at edge location (min coordinate)
206                    let edge_idx = dim * n_total + idx3d(i, j, k, nx, ny);
207                    weights[edge_idx] = weight_u8;
208                }
209            }
210        }
211    }
212
213    weights
214}
215
216/// Simplified weight calculation for single-echo data (no phase2)
217pub fn calculate_weights_single_echo(
218    phase: &[f64],
219    mag: &[f64],
220    mask: &[u8],
221    nx: usize, ny: usize, nz: usize,
222) -> Vec<u8> {
223    calculate_weights_romeo(phase, mag, None, 1.0, 1.0, mask, nx, ny, nz)
224}
225
226/// Calculate per-voxel quality map from ROMEO edge weights
227///
228/// Computes ROMEO edge weights and then aggregates them per-voxel by averaging
229/// the incident edge weights across all 6 neighboring directions (±x, ±y, ±z).
230/// This produces a quality map where high values indicate voxels with coherent
231/// phase and magnitude, suitable for thresholding into a brain mask.
232///
233/// Reference: MriResearchTools.jl `romeovoxelquality()` function
234///
235/// # Arguments
236/// * `phase` - Wrapped phase data (nx * ny * nz), first echo
237/// * `mag` - Magnitude data (nx * ny * nz), optional (pass empty slice if none)
238/// * `phase2` - Second echo phase for gradient coherence (optional)
239/// * `te1`, `te2` - Echo times for gradient coherence scaling
240/// * `mask` - Binary mask (nx * ny * nz), 1 = process
241/// * `nx`, `ny`, `nz` - Array dimensions
242///
243/// # Returns
244/// Quality map of size nx * ny * nz with values in range [0, 100]
245pub fn voxel_quality_romeo(
246    phase: &[f64],
247    mag: &[f64],
248    phase2: Option<&[f64]>,
249    te1: f64,
250    te2: f64,
251    mask: &[u8],
252    nx: usize, ny: usize, nz: usize,
253) -> Vec<f64> {
254    let n_total = nx * ny * nz;
255    let weights = calculate_weights_romeo(phase, mag, phase2, te1, te2, mask, nx, ny, nz);
256
257    let mut quality = vec![0.0_f64; n_total];
258
259    for i in 0..nx {
260        for j in 0..ny {
261            for k in 0..nz {
262                let idx = idx3d(i, j, k, nx, ny);
263                if mask[idx] == 0 {
264                    continue;
265                }
266
267                let mut sum = 0.0_f64;
268                let mut count = 0u32;
269
270                // +x edge: stored at (i, j, k) in dim=0
271                if i + 1 < nx && mask[idx3d(i + 1, j, k, nx, ny)] != 0 {
272                    sum += weights[0 * n_total + idx] as f64;
273                    count += 1;
274                }
275                // -x edge: stored at (i-1, j, k) in dim=0
276                if i > 0 && mask[idx3d(i - 1, j, k, nx, ny)] != 0 {
277                    sum += weights[0 * n_total + idx3d(i - 1, j, k, nx, ny)] as f64;
278                    count += 1;
279                }
280                // +y edge: stored at (i, j, k) in dim=1
281                if j + 1 < ny && mask[idx3d(i, j + 1, k, nx, ny)] != 0 {
282                    sum += weights[1 * n_total + idx] as f64;
283                    count += 1;
284                }
285                // -y edge: stored at (i, j-1, k) in dim=1
286                if j > 0 && mask[idx3d(i, j - 1, k, nx, ny)] != 0 {
287                    sum += weights[1 * n_total + idx3d(i, j - 1, k, nx, ny)] as f64;
288                    count += 1;
289                }
290                // +z edge: stored at (i, j, k) in dim=2
291                if k + 1 < nz && mask[idx3d(i, j, k + 1, nx, ny)] != 0 {
292                    sum += weights[2 * n_total + idx] as f64;
293                    count += 1;
294                }
295                // -z edge: stored at (i, j, k-1) in dim=2
296                if k > 0 && mask[idx3d(i, j, k - 1, nx, ny)] != 0 {
297                    sum += weights[2 * n_total + idx3d(i, j, k - 1, nx, ny)] as f64;
298                    count += 1;
299                }
300
301                if count > 0 {
302                    // Normalize from 0-255 to 0-1, then scale to 0-100
303                    quality[idx] = (sum / count as f64) / 255.0 * 100.0;
304                }
305            }
306        }
307    }
308
309    quality
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wrap_angle() {
        assert!((wrap_angle(0.0) - 0.0).abs() < 1e-10);
        assert!((wrap_angle(PI) - PI).abs() < 1e-10);
        assert!((wrap_angle(-PI) - (-PI)).abs() < 1e-10);
        assert!((wrap_angle(TWO_PI) - 0.0).abs() < 1e-10);
        assert!((wrap_angle(3.0 * PI) - PI).abs() < 1e-10);
        assert!((wrap_angle(-3.0 * PI) - (-PI)).abs() < 1e-10);
        // Values beyond ±π land on the opposite side of the interval.
        assert!((wrap_angle(1.5 * PI) - (-0.5 * PI)).abs() < 1e-10);
        assert!((wrap_angle(-1.5 * PI) - (0.5 * PI)).abs() < 1e-10);
    }

    #[test]
    fn test_weights_constant_phase() {
        // Constant phase + uniform magnitude: every edge whose +1 neighbour exists is fully
        // coherent (255); slots for edges that would leave the volume stay 0.
        let n = 4;
        let n_total = n * n * n;
        let phase = vec![0.0; n_total];
        let mag = vec![1.0; n_total];
        let mask = vec![1u8; n_total];

        let weights = calculate_weights_single_echo(&phase, &mag, &mask, n, n, n);
        assert_eq!(weights.len(), 3 * n_total);

        for dim in 0..3 {
            for k in 0..n {
                for j in 0..n {
                    for i in 0..n {
                        let along_axis = [i, j, k][dim];
                        let expected = if along_axis + 1 < n { 255 } else { 0 };
                        assert_eq!(
                            weights[dim * n_total + idx3d(i, j, k, n, n)],
                            expected,
                            "dim {}, voxel ({}, {}, {})",
                            dim,
                            i,
                            j,
                            k
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_weights_wrapped_jump() {
        // A phase jump of exactly 2π in x at i = 2.
        let n = 4;
        let n_total = n * n * n;
        let mut phase = vec![0.0; n_total];
        for i in 2..n {
            for j in 0..n {
                for k in 0..n {
                    phase[idx3d(i, j, k, n, n)] = TWO_PI;
                }
            }
        }

        let mask = vec![1u8; n_total];
        let weights = calculate_weights_single_echo(&phase, &[], &mask, n, n, n);

        // ROMEO scores wrapped differences: a 2π jump wraps to 0, so the x-edges crossing
        // it (i = 1 -> 2) are fully coherent rather than penalized.
        for j in 0..n {
            for k in 0..n {
                assert_eq!(weights[idx3d(1, j, k, n, n)], 255);
            }
        }
    }

    #[test]
    fn test_weights_half_pi_step() {
        // A π/2 step halves phase coherence: 0.5 * 255 = 127.5, truncated to 127.
        let phase = [0.0, PI / 2.0];
        let mask = [1u8, 1];
        let weights = calculate_weights_single_echo(&phase, &[], &mask, 2, 1, 1);
        assert_eq!(weights, vec![127, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_weights_phase_gradient_coherence() {
        // Second echo at twice the echo time: a consistent phase2 leaves the weight at the
        // single-echo value; an inconsistent one lowers it.
        let phase = [0.0, 0.4];
        let consistent = [0.0, 0.8];
        let inconsistent = [0.0, 0.0];
        let mask = [1u8, 1];

        let single = calculate_weights_single_echo(&phase, &[], &mask, 2, 1, 1);
        let good = calculate_weights_romeo(
            &phase, &[], Some(&consistent[..]), 1.0, 2.0, &mask, 2, 1, 1,
        );
        let bad = calculate_weights_romeo(
            &phase, &[], Some(&inconsistent[..]), 1.0, 2.0, &mask, 2, 1, 1,
        );

        assert_eq!(good[0], single[0]);
        assert!(bad[0] < good[0]);
    }

    #[test]
    fn test_weights_mask() {
        // Weights should be 0 where mask is 0
        let n = 4;
        let phase = vec![0.5; n * n * n];
        let mut mask = vec![1u8; n * n * n];

        // Voxels (0,0,0) and (1,0,0) outside the mask.
        mask[0] = 0;
        mask[1] = 0;

        let weights = calculate_weights_single_echo(&phase, &[], &mask, n, n, n);

        assert_eq!(weights[0], 0); // x-edge (0,0,0)-(1,0,0): both ends masked out
        assert_eq!(weights[1], 0); // x-edge (1,0,0)-(2,0,0): lower end masked out
        assert_eq!(weights[idx3d(2, 0, 0, n, n)], 255); // x-edge (2,0,0)-(3,0,0): both in mask
    }

    #[test]
    fn test_voxel_quality_constant_phase() {
        // Constant phase + uniform magnitude → all quality values should be 100
        let n = 4;
        let phase = vec![0.0; n * n * n];
        let mag = vec![1.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        let quality = voxel_quality_romeo(&phase, &mag, None, 1.0, 1.0, &mask, n, n, n);

        assert_eq!(quality.len(), n * n * n);

        // Interior voxels (not on boundary) should have quality = 100
        let interior_q = quality[idx3d(1, 1, 1, n, n)];
        assert!((interior_q - 100.0).abs() < 1e-6,
                "Interior voxel quality should be 100, got {}", interior_q);
    }

    #[test]
    fn test_voxel_quality_masked() {
        // Masked-out voxels should have quality = 0
        let n = 4;
        let phase = vec![0.0; n * n * n];
        let mut mask = vec![1u8; n * n * n];
        mask[idx3d(1, 1, 1, n, n)] = 0;

        let quality = voxel_quality_romeo(&phase, &[], None, 1.0, 1.0, &mask, n, n, n);

        assert_eq!(quality[idx3d(1, 1, 1, n, n)], 0.0,
                   "Masked-out voxel should have quality 0");
    }

    #[test]
    fn test_voxel_quality_range() {
        // Quality values should be in range [0, 100]
        let n = 6;
        let phase: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.7).collect();
        let mag: Vec<f64> = (0..n * n * n).map(|i| (i as f64) / (n * n * n) as f64).collect();
        let mask = vec![1u8; n * n * n];

        let quality = voxel_quality_romeo(&phase, &mag, None, 1.0, 1.0, &mask, n, n, n);

        for &q in quality.iter() {
            assert!(q >= 0.0 && q <= 100.0,
                    "Quality should be in [0, 100], got {}", q);
        }
    }
}