Skip to main content

qsm_core/
swi.rs

//! Susceptibility Weighted Imaging (SWI)
//!
//! SWI enhances susceptibility contrast by combining magnitude and phase
//! information. The phase is high-pass filtered, converted to a [0, 1] mask,
//! and multiplied with the magnitude.
//!
//! Reference:
//! Eckstein, K., et al. (2018). "Computationally efficient combination of
//! multi-channel phase data from multi-echo acquisitions (ASPIRE)."
//! Magnetic Resonance in Medicine, 79:2996-3006.
//! https://doi.org/10.1002/mrm.26963
//!
//! Reference implementation: https://github.com/korbinian90/CLEARSWI.jl
14
15use crate::utils::gaussian_smooth_3d;
16
17/// SWI algorithm parameters
18#[derive(Clone, Debug)]
19pub struct SwiParams {
20    /// High-pass filter sigma in voxels [x, y, z]
21    pub hp_sigma: [f64; 3],
22    /// Phase scaling type
23    pub scaling: PhaseScaling,
24    /// Phase scaling strength
25    pub strength: f64,
26    /// MIP window size in slices
27    pub mip_window: usize,
28}
29
30impl Default for SwiParams {
31    fn default() -> Self {
32        Self {
33            hp_sigma: [4.0, 4.0, 0.0],
34            scaling: PhaseScaling::Tanh,
35            strength: 4.0,
36            mip_window: 7,
37        }
38    }
39}
40
/// Mapping from high-pass filtered phase to a `[0, 1]` phase-mask weight.
///
/// Consumed by `create_phase_mask`. Besides `PartialEq`, it derives `Eq` and
/// `Hash`, so variants can be used as set or map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PhaseScaling {
    /// Sigmoid weighting: `(1 + tanh(1 - x/m)) / 2`
    /// where `m = median(positive_phase) * 10 / strength`
    /// (`m = 1` when no in-mask phase is positive).
    Tanh,
    /// Negate phase first, then apply Tanh
    NegativeTanh,
    /// Traditional SWI: positive phase is rescaled from its in-mask
    /// `[min, max]` to `[1, 0]` and raised to `strength`; non-positive → 1
    Positive,
    /// Traditional SWI: non-positive phase is rescaled from its in-mask
    /// `[min, max]` to `[0, 1]` and raised to `strength`; positive → 1
    Negative,
    /// Both directions suppressed: positive phase as in `Positive`,
    /// non-positive phase as in `Negative`
    Triangular,
}
56
57/// High-pass filter by subtracting Gaussian-smoothed version
58///
59/// # Arguments
60/// * `data` - Input data (e.g. unwrapped phase)
61/// * `mask` - Binary mask (1 = inside, 0 = outside)
62/// * `nx`, `ny`, `nz` - Array dimensions
63/// * `sigma` - Gaussian sigma for each dimension in voxels (e.g. [4, 4, 0])
64///
65/// # Returns
66/// High-pass filtered data
67pub fn highpass_filter(
68    data: &[f64],
69    mask: &[u8],
70    nx: usize, ny: usize, nz: usize,
71    sigma: [f64; 3],
72) -> Vec<f64> {
73    let nbox = 4; // masked smoothing uses nbox=4 in MriResearchTools
74    let smoothed = gaussian_smooth_3d(data, sigma, Some(mask), None, nbox, nx, ny, nz);
75    let n_total = nx * ny * nz;
76    let mut result = vec![0.0; n_total];
77    for i in 0..n_total {
78        if mask[i] == 1 {
79            result[i] = data[i] - smoothed[i];
80        }
81    }
82    result
83}
84
85/// Create phase mask from filtered phase values
86///
87/// Converts phase to a [0, 1] weighting mask using the specified scaling.
88///
89/// # Arguments
90/// * `phase` - High-pass filtered phase
91/// * `mask` - Binary mask
92/// * `scaling` - Phase scaling type
93/// * `strength` - Scaling strength (higher = stronger phase contrast)
94///
95/// # Returns
96/// Phase mask with values in [0, 1]
97pub fn create_phase_mask(
98    phase: &[f64],
99    mask: &[u8],
100    scaling: PhaseScaling,
101    strength: f64,
102) -> Vec<f64> {
103    let n = phase.len();
104    let mut result = vec![0.0; n];
105
106    // Copy phase into result, zeroing outside mask
107    for i in 0..n {
108        if mask[i] == 1 {
109            result[i] = phase[i];
110        }
111    }
112
113    // Handle NegativeTanh by negating first
114    let effective_scaling = if scaling == PhaseScaling::NegativeTanh {
115        for v in result.iter_mut() {
116            *v = -*v;
117        }
118        PhaseScaling::Tanh
119    } else {
120        scaling
121    };
122
123    match effective_scaling {
124        PhaseScaling::Tanh => {
125            // m = median(positive phase in mask) * 10 / strength
126            let mut positives: Vec<f64> = (0..n)
127                .filter(|&i| mask[i] == 1 && result[i] > 0.0)
128                .map(|i| result[i])
129                .collect();
130
131            let m = if positives.is_empty() {
132                1.0
133            } else {
134                positives.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
135                let mid = positives.len() / 2;
136                let median = if positives.len().is_multiple_of(2) {
137                    (positives[mid - 1] + positives[mid]) / 2.0
138                } else {
139                    positives[mid]
140                };
141                median * 10.0 / strength
142            };
143
144            for v in result.iter_mut() {
145                *v = (1.0 + (1.0 - *v / m).tanh()) / 2.0;
146            }
147        }
148        PhaseScaling::Positive => {
149            // Positive phase: rescale to [1,0] then ^strength; negative → 1
150            let (min_pos, max_pos) = positive_range(&result, mask);
151            for i in 0..n {
152                if result[i] > 0.0 && mask[i] == 1 {
153                    result[i] = rescale(result[i], min_pos, max_pos, 1.0, 0.0).powf(strength);
154                } else {
155                    result[i] = 1.0;
156                }
157            }
158        }
159        PhaseScaling::Negative => {
160            // Negative phase: rescale to [0,1] then ^strength; positive → 1
161            let (min_neg, max_neg) = negative_range(&result, mask);
162            for i in 0..n {
163                if result[i] <= 0.0 && mask[i] == 1 {
164                    result[i] = rescale(result[i], min_neg, max_neg, 0.0, 1.0).powf(strength);
165                } else {
166                    result[i] = 1.0;
167                }
168            }
169        }
170        PhaseScaling::Triangular => {
171            // Both directions suppressed
172            let (min_pos, max_pos) = positive_range(&result, mask);
173            let (min_neg, max_neg) = negative_range(&result, mask);
174            for i in 0..n {
175                if mask[i] == 0 {
176                    result[i] = 0.0;
177                } else if result[i] > 0.0 {
178                    result[i] = rescale(result[i], min_pos, max_pos, 1.0, 0.0).powf(strength);
179                } else {
180                    result[i] = rescale(result[i], min_neg, max_neg, 0.0, 1.0).powf(strength);
181                }
182            }
183        }
184        PhaseScaling::NegativeTanh => unreachable!(),
185    }
186
187    // Clamp to [0, 1]
188    for v in &mut result {
189        if *v < 0.0 {
190            *v = 0.0;
191        }
192    }
193
194    // Zero outside mask
195    for i in 0..n {
196        if mask[i] == 0 {
197            result[i] = 0.0;
198        }
199    }
200
201    result
202}
203
204/// Calculate SWI from unwrapped phase and magnitude
205///
206/// Pipeline: high-pass filter phase → create phase mask → multiply with magnitude.
207///
208/// # Arguments
209/// * `phase` - Unwrapped phase (single echo or combined)
210/// * `magnitude` - Magnitude image (single echo or combined)
211/// * `mask` - Binary brain mask
212/// * `nx`, `ny`, `nz` - Array dimensions
213/// * `vsx`, `vsy`, `vsz` - Voxel sizes in mm (unused, reserved for consistency)
214/// * `hp_sigma` - High-pass filter sigma in voxels [x, y, z]
215/// * `scaling` - Phase scaling type
216/// * `strength` - Phase scaling strength
217///
218/// # Returns
219/// SWI image (magnitude × phase mask)
220#[allow(clippy::too_many_arguments)]
221pub fn calculate_swi(
222    phase: &[f64],
223    magnitude: &[f64],
224    mask: &[u8],
225    nx: usize, ny: usize, nz: usize,
226    _vsx: f64, _vsy: f64, _vsz: f64,
227    hp_sigma: [f64; 3],
228    scaling: PhaseScaling,
229    strength: f64,
230) -> Vec<f64> {
231    let n_total = nx * ny * nz;
232
233    // High-pass filter phase
234    let filtered = highpass_filter(phase, mask, nx, ny, nz, hp_sigma);
235
236    // Create phase mask
237    let phase_mask = create_phase_mask(&filtered, mask, scaling, strength);
238
239    // SWI = magnitude × phase_mask
240    let mut swi = vec![0.0; n_total];
241    for i in 0..n_total {
242        swi[i] = magnitude[i] * phase_mask[i];
243    }
244
245    swi
246}
247
248/// Calculate SWI with default parameters
249///
250/// Defaults: sigma=[4,4,0], Tanh scaling, strength=4
251#[allow(clippy::too_many_arguments)]
252pub fn calculate_swi_default(
253    phase: &[f64],
254    magnitude: &[f64],
255    mask: &[u8],
256    nx: usize, ny: usize, nz: usize,
257    vsx: f64, vsy: f64, vsz: f64,
258) -> Vec<f64> {
259    calculate_swi(
260        phase, magnitude, mask,
261        nx, ny, nz,
262        vsx, vsy, vsz,
263        [4.0, 4.0, 0.0],
264        PhaseScaling::Tanh,
265        4.0,
266    )
267}
268
/// Minimum intensity projection (mIP) along the z-axis.
///
/// Output slice `k` holds, for every `(x, y)`, the minimum of input slices
/// `k .. k + window`.
///
/// # Arguments
/// * `data` - 3D volume (Fortran order: x fastest, z slowest)
/// * `nx`, `ny`, `nz` - Array dimensions
/// * `window` - Number of slices in the projection window
///
/// # Returns
/// Volume of `nx × ny × (nz - window + 1)` values, or an empty vec when
/// `window` is 0 or larger than `nz`.
pub fn create_mip(
    data: &[f64],
    nx: usize,
    ny: usize,
    nz: usize,
    window: usize,
) -> Vec<f64> {
    if window == 0 || window > nz {
        return vec![];
    }

    let plane = nx * ny;
    let depth = nz - window + 1;
    let mut projection = Vec::with_capacity(plane * depth);

    for start in 0..depth {
        // Seed the output slice with the first slice of the window, then
        // fold in the rest one contiguous slice at a time.
        let base = projection.len();
        projection.extend_from_slice(&data[start * plane..(start + 1) * plane]);
        for layer in start + 1..start + window {
            let next = &data[layer * plane..(layer + 1) * plane];
            for (acc, &candidate) in projection[base..].iter_mut().zip(next) {
                if candidate < *acc {
                    *acc = candidate;
                }
            }
        }
    }

    projection
}
313
314/// Minimum intensity projection with default window of 7 slices
315pub fn create_mip_default(
316    data: &[f64],
317    nx: usize, ny: usize, nz: usize,
318) -> Vec<f64> {
319    create_mip(data, nx, ny, nz, 7)
320}
321
/// Shifted softplus scaling of a magnitude image for enhanced contrast.
///
/// Each sample `x` becomes `softplus(x) - softplus(0)`, where
/// `softplus(x) = (ln(1 + exp(-|f·(x - offset)|)) + max(0, f·(x - offset))) / f`
/// and `f = factor / offset`. This is the overflow-free form of
/// `ln(1 + exp(f·(x - offset))) / f`. Subtracting `softplus(0)` maps a zero
/// input to exactly zero.
///
/// # Arguments
/// * `magnitude` - Input magnitude data
/// * `offset` - Softplus offset (controls the transition point); when
///   `|offset| < 1e-20` the input is returned unchanged
/// * `factor` - Steepness factor (e.g. 2.0)
///
/// # Returns
/// Scaled magnitude, same length as the input.
pub fn softplus_scaling(
    magnitude: &[f64],
    offset: f64,
    factor: f64,
) -> Vec<f64> {
    if offset.abs() < 1e-20 {
        return magnitude.to_vec();
    }

    let steepness = factor / offset;
    let softplus = |x: f64| {
        let z = steepness * (x - offset);
        ((1.0 + (-z.abs()).exp()).ln() + z.max(0.0)) / steepness
    };
    let baseline = softplus(0.0);

    magnitude.iter().map(|&x| softplus(x) - baseline).collect()
}
356
357// ---- Helpers ----
358
/// Return `(min, max)` over the strictly positive values of `data` whose
/// `mask` entry is `1`, or the fallback `(0.0, 1.0)` when there are none.
fn positive_range(data: &[f64], mask: &[u8]) -> (f64, f64) {
    let (lowest, highest) = data
        .iter()
        .zip(&mask[..data.len()])
        .filter(|&(&v, &m)| m == 1 && v > 0.0)
        .fold((f64::MAX, f64::MIN), |(lo, hi), (&v, _)| {
            (if v < lo { v } else { lo }, if v > hi { v } else { hi })
        });
    if lowest > highest {
        (0.0, 1.0)
    } else {
        (lowest, highest)
    }
}
375
/// Return `(min, max)` over the non-positive (`<= 0.0`) values of `data`
/// whose `mask` entry is `1`, or the fallback `(-1.0, 0.0)` when there are none.
fn negative_range(data: &[f64], mask: &[u8]) -> (f64, f64) {
    let (lowest, highest) = data
        .iter()
        .zip(&mask[..data.len()])
        .filter(|&(&v, &m)| m == 1 && v <= 0.0)
        .fold((f64::MAX, f64::MIN), |(lo, hi), (&v, _)| {
            (if v < lo { v } else { lo }, if v > hi { v } else { hi })
        });
    if lowest > highest {
        (-1.0, 0.0)
    } else {
        (lowest, highest)
    }
}
392
/// Linearly map `val` from `[old_min, old_max]` onto `[new_min, new_max]`.
///
/// The relative position is clamped to `[0, 1]`, so out-of-range inputs land
/// on the nearer target endpoint. A degenerate source interval (width below
/// `1e-20`) maps every input to the target midpoint.
#[inline]
fn rescale(val: f64, old_min: f64, old_max: f64, new_min: f64, new_max: f64) -> f64 {
    let span = old_max - old_min;
    if span.abs() < 1e-20 {
        return (new_min + new_max) / 2.0;
    }
    let frac = ((val - old_min) / span).clamp(0.0, 1.0);
    new_min + frac * (new_max - new_min)
}
405
#[cfg(test)]
mod tests {
    use super::*;

    const ALL_SCALINGS: [PhaseScaling; 5] = [
        PhaseScaling::Tanh,
        PhaseScaling::NegativeTanh,
        PhaseScaling::Positive,
        PhaseScaling::Negative,
        PhaseScaling::Triangular,
    ];

    #[test]
    fn test_calculate_swi_zero_phase() {
        let n = 8;
        let nn = n * n * n;
        let phase = vec![0.0; nn];
        let magnitude = vec![1.0; nn];
        let mask = vec![1u8; nn];

        let swi = calculate_swi_default(&phase, &magnitude, &mask, n, n, n, 1.0, 1.0, 1.0);

        // Zero phase has no positive values, so the Tanh scale falls back to
        // m = 1 and every voxel gets (1 + tanh(1)) / 2 ≈ 0.88 (magnitude is 1).
        let expected = (1.0 + 1.0_f64.tanh()) / 2.0;
        assert_eq!(swi.len(), nn);
        for &v in &swi {
            assert!(v.is_finite(), "SWI values should be finite");
            assert!((v - expected).abs() < 1e-9, "expected {}, got {}", expected, v);
        }
    }

    #[test]
    fn test_calculate_swi_mask() {
        let n = 8;
        let nn = n * n * n;
        let phase = vec![0.1; nn];
        let magnitude = vec![1.0; nn];
        let mut mask = vec![1u8; nn];
        mask[0] = 0;
        mask[1] = 0;

        let swi = calculate_swi_default(&phase, &magnitude, &mask, n, n, n, 1.0, 1.0, 1.0);

        assert_eq!(swi[0], 0.0, "Outside mask should be 0");
        assert_eq!(swi[1], 0.0, "Outside mask should be 0");
    }

    #[test]
    fn test_calculate_swi_default_matches_swi_params_default() {
        let n = 6;
        let nn = n * n * n;
        let phase: Vec<f64> = (0..nn).map(|i| ((i % 17) as f64 - 8.0) * 0.05).collect();
        let magnitude: Vec<f64> = (0..nn).map(|i| 1.0 + (i % 5) as f64).collect();
        let mask = vec![1u8; nn];
        let p = SwiParams::default();

        let a = calculate_swi_default(&phase, &magnitude, &mask, n, n, n, 1.0, 1.0, 1.0);
        let b = calculate_swi(
            &phase, &magnitude, &mask, n, n, n, 1.0, 1.0, 1.0, p.hp_sigma, p.scaling, p.strength,
        );
        assert_eq!(a.len(), b.len());
        for (x, y) in a.iter().zip(&b) {
            assert!((x - y).abs() < 1e-12, "{} != {}", x, y);
        }
    }

    #[test]
    fn test_swi_params_default_values() {
        let p = SwiParams::default();
        assert_eq!(p.hp_sigma, [4.0, 4.0, 0.0]);
        assert_eq!(p.scaling, PhaseScaling::Tanh);
        assert_eq!(p.strength, 4.0);
        assert_eq!(p.mip_window, 7);
    }

    #[test]
    fn test_phase_mask_range() {
        let n = 10;
        let nn = n * n * n;
        let phase: Vec<f64> = (0..nn).map(|i| (i as f64 * 0.01) - 5.0).collect();
        let mask = vec![1u8; nn];

        for scaling in &ALL_SCALINGS {
            let pm = create_phase_mask(&phase, &mask, *scaling, 4.0);
            for (i, &v) in pm.iter().enumerate() {
                assert!(v >= 0.0, "{:?}: value at {} = {} < 0", scaling, i, v);
                assert!(v <= 1.0 + 1e-10, "{:?}: value at {} = {} > 1", scaling, i, v);
            }
        }
    }

    #[test]
    fn test_phase_mask_zero_outside_mask() {
        let phase: Vec<f64> = (0..64).map(|i| (i as f64 - 32.0) * 0.1).collect();
        let mut mask = vec![1u8; 64];
        for i in (0..64).step_by(5) {
            mask[i] = 0;
        }

        for scaling in &ALL_SCALINGS {
            let pm = create_phase_mask(&phase, &mask, *scaling, 2.0);
            assert_eq!(pm.len(), phase.len());
            for (i, (&v, &m)) in pm.iter().zip(&mask).enumerate() {
                if m == 0 {
                    assert_eq!(v, 0.0, "{:?}: voxel {} outside mask = {}", scaling, i, v);
                }
            }
        }
    }

    #[test]
    fn test_highpass_filter_constant() {
        // Constant input should give zero output (constant is its own smooth)
        let n = 16;
        let nn = n * n * n;
        let data = vec![5.0; nn];
        let mask = vec![1u8; nn];

        let result = highpass_filter(&data, &mask, n, n, n, [2.0, 2.0, 0.0]);

        for &v in &result {
            assert!(v.abs() < 1.0, "High-pass of constant should be near zero, got {}", v);
        }
    }

    #[test]
    fn test_mip_basic() {
        // 3x3x5 volume, mIP with window=3 → 3x3x3 output
        let (nx, ny, nz) = (3, 3, 5);
        let mut data = vec![10.0; nx * ny * nz];
        // Place a low value at (1, 1, 2)
        let idx = 1 + nx + 2 * nx * ny;
        data[idx] = 1.0;

        let mip = create_mip(&data, nx, ny, nz, 3);
        assert_eq!(mip.len(), nx * ny * 3);

        // Every window (starting at z = 0, 1, 2) includes z = 2, so the
        // minimum at (1, 1) appears in all three output slices.
        for k_out in 0..3 {
            assert_eq!(mip[1 + nx + k_out * nx * ny], 1.0);
        }
        // Untouched columns keep the background value.
        assert_eq!(mip[0], 10.0);
        assert_eq!(mip[2 + 2 * nx + 2 * nx * ny], 10.0);
    }

    #[test]
    fn test_mip_window_too_large() {
        let mip = create_mip(&[1.0; 27], 3, 3, 3, 10);
        assert!(mip.is_empty());
    }

    #[test]
    fn test_mip_zero_window() {
        assert!(create_mip(&[1.0; 27], 3, 3, 3, 0).is_empty());
    }

    #[test]
    fn test_mip_default_window() {
        let (nx, ny, nz) = (2, 2, 9);
        let data: Vec<f64> = (0..nx * ny * nz).map(|i| ((i * 7) % 11) as f64).collect();
        let mip = create_mip_default(&data, nx, ny, nz);
        assert_eq!(mip, create_mip(&data, nx, ny, nz, 7));
        assert_eq!(mip.len(), nx * ny * (nz - 6));
    }

    #[test]
    fn test_softplus_scaling() {
        let mag = vec![0.0, 0.5, 1.0, 2.0];
        let result = softplus_scaling(&mag, 1.0, 2.0);

        // softplus(0, offset=1, factor=2) should be 0 (baseline subtracted)
        assert!(result[0].abs() < 1e-10, "softplus(0) should be ~0, got {}", result[0]);
        // Values should increase monotonically
        for i in 1..result.len() {
            assert!(result[i] >= result[i - 1], "softplus should be monotonically increasing");
        }
    }

    #[test]
    fn test_rescale() {
        assert!((rescale(0.0, 0.0, 10.0, 0.0, 1.0) - 0.0).abs() < 1e-10);
        assert!((rescale(5.0, 0.0, 10.0, 0.0, 1.0) - 0.5).abs() < 1e-10);
        assert!((rescale(10.0, 0.0, 10.0, 0.0, 1.0) - 1.0).abs() < 1e-10);
        // Inverted rescale
        assert!((rescale(0.0, 0.0, 10.0, 1.0, 0.0) - 1.0).abs() < 1e-10);
        assert!((rescale(10.0, 0.0, 10.0, 1.0, 0.0) - 0.0).abs() < 1e-10);
        // Degenerate source range maps to the target midpoint
        assert!((rescale(3.0, 2.0, 2.0, 0.0, 1.0) - 0.5).abs() < 1e-10);
    }
}