1use crate::utils::gaussian_smooth_3d;
16
17#[derive(Clone, Debug)]
19pub struct SwiParams {
20 pub hp_sigma: [f64; 3],
22 pub scaling: PhaseScaling,
24 pub strength: f64,
26 pub mip_window: usize,
28}
29
30impl Default for SwiParams {
31 fn default() -> Self {
32 Self {
33 hp_sigma: [4.0, 4.0, 0.0],
34 scaling: PhaseScaling::Tanh,
35 strength: 4.0,
36 mip_window: 7,
37 }
38 }
39}
40
/// Selects how filtered phase values are turned into weights by
/// `create_phase_mask`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PhaseScaling {
    /// Sigmoid weighting `(1 + tanh(1 - phase / m)) / 2`, where `m` is derived
    /// from the median of the positive in-mask phase values and the strength.
    Tanh,
    /// Same as [`PhaseScaling::Tanh`], applied to the negated phase.
    NegativeTanh,
    /// Positive in-mask phase is mapped linearly from its observed range onto
    /// `1 -> 0` and raised to the strength; every other voxel gets weight 1.
    Positive,
    /// Non-positive in-mask phase is mapped linearly from its observed range
    /// onto `0 -> 1` and raised to the strength; every other voxel gets
    /// weight 1.
    Negative,
    /// Applies the `Positive` mapping to positive phase and the `Negative`
    /// mapping to non-positive phase; voxels whose mask value is 0 get 0.
    Triangular,
}
56
57pub fn highpass_filter(
68 data: &[f64],
69 mask: &[u8],
70 nx: usize, ny: usize, nz: usize,
71 sigma: [f64; 3],
72) -> Vec<f64> {
73 let nbox = 4; let smoothed = gaussian_smooth_3d(data, sigma, Some(mask), None, nbox, nx, ny, nz);
75 let n_total = nx * ny * nz;
76 let mut result = vec![0.0; n_total];
77 for i in 0..n_total {
78 if mask[i] == 1 {
79 result[i] = data[i] - smoothed[i];
80 }
81 }
82 result
83}
84
85pub fn create_phase_mask(
98 phase: &[f64],
99 mask: &[u8],
100 scaling: PhaseScaling,
101 strength: f64,
102) -> Vec<f64> {
103 let n = phase.len();
104 let mut result = vec![0.0; n];
105
106 for i in 0..n {
108 if mask[i] == 1 {
109 result[i] = phase[i];
110 }
111 }
112
113 let effective_scaling = if scaling == PhaseScaling::NegativeTanh {
115 for v in result.iter_mut() {
116 *v = -*v;
117 }
118 PhaseScaling::Tanh
119 } else {
120 scaling
121 };
122
123 match effective_scaling {
124 PhaseScaling::Tanh => {
125 let mut positives: Vec<f64> = (0..n)
127 .filter(|&i| mask[i] == 1 && result[i] > 0.0)
128 .map(|i| result[i])
129 .collect();
130
131 let m = if positives.is_empty() {
132 1.0
133 } else {
134 positives.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
135 let mid = positives.len() / 2;
136 let median = if positives.len().is_multiple_of(2) {
137 (positives[mid - 1] + positives[mid]) / 2.0
138 } else {
139 positives[mid]
140 };
141 median * 10.0 / strength
142 };
143
144 for v in result.iter_mut() {
145 *v = (1.0 + (1.0 - *v / m).tanh()) / 2.0;
146 }
147 }
148 PhaseScaling::Positive => {
149 let (min_pos, max_pos) = positive_range(&result, mask);
151 for i in 0..n {
152 if result[i] > 0.0 && mask[i] == 1 {
153 result[i] = rescale(result[i], min_pos, max_pos, 1.0, 0.0).powf(strength);
154 } else {
155 result[i] = 1.0;
156 }
157 }
158 }
159 PhaseScaling::Negative => {
160 let (min_neg, max_neg) = negative_range(&result, mask);
162 for i in 0..n {
163 if result[i] <= 0.0 && mask[i] == 1 {
164 result[i] = rescale(result[i], min_neg, max_neg, 0.0, 1.0).powf(strength);
165 } else {
166 result[i] = 1.0;
167 }
168 }
169 }
170 PhaseScaling::Triangular => {
171 let (min_pos, max_pos) = positive_range(&result, mask);
173 let (min_neg, max_neg) = negative_range(&result, mask);
174 for i in 0..n {
175 if mask[i] == 0 {
176 result[i] = 0.0;
177 } else if result[i] > 0.0 {
178 result[i] = rescale(result[i], min_pos, max_pos, 1.0, 0.0).powf(strength);
179 } else {
180 result[i] = rescale(result[i], min_neg, max_neg, 0.0, 1.0).powf(strength);
181 }
182 }
183 }
184 PhaseScaling::NegativeTanh => unreachable!(),
185 }
186
187 for v in &mut result {
189 if *v < 0.0 {
190 *v = 0.0;
191 }
192 }
193
194 for i in 0..n {
196 if mask[i] == 0 {
197 result[i] = 0.0;
198 }
199 }
200
201 result
202}
203
204#[allow(clippy::too_many_arguments)]
221pub fn calculate_swi(
222 phase: &[f64],
223 magnitude: &[f64],
224 mask: &[u8],
225 nx: usize, ny: usize, nz: usize,
226 _vsx: f64, _vsy: f64, _vsz: f64,
227 hp_sigma: [f64; 3],
228 scaling: PhaseScaling,
229 strength: f64,
230) -> Vec<f64> {
231 let n_total = nx * ny * nz;
232
233 let filtered = highpass_filter(phase, mask, nx, ny, nz, hp_sigma);
235
236 let phase_mask = create_phase_mask(&filtered, mask, scaling, strength);
238
239 let mut swi = vec![0.0; n_total];
241 for i in 0..n_total {
242 swi[i] = magnitude[i] * phase_mask[i];
243 }
244
245 swi
246}
247
248#[allow(clippy::too_many_arguments)]
252pub fn calculate_swi_default(
253 phase: &[f64],
254 magnitude: &[f64],
255 mask: &[u8],
256 nx: usize, ny: usize, nz: usize,
257 vsx: f64, vsy: f64, vsz: f64,
258) -> Vec<f64> {
259 calculate_swi(
260 phase, magnitude, mask,
261 nx, ny, nz,
262 vsx, vsy, vsz,
263 [4.0, 4.0, 0.0],
264 PhaseScaling::Tanh,
265 4.0,
266 )
267}
268
/// Sliding-window minimum projection along the slowest (z) axis.
///
/// `data` is read as a flattened volume indexed `i + j * nx + k * nx * ny`.
/// Output slice `k` holds, for every in-plane position, the minimum of input
/// slices `k ..= k + window - 1`.
///
/// # Returns
/// A flattened volume with `nx * ny * (nz - window + 1)` entries, or an empty
/// vector when `window` is 0 or larger than `nz`.
///
/// # Panics
/// Panics if `data` is too short for the requested dimensions.
pub fn create_mip(data: &[f64], nx: usize, ny: usize, nz: usize, window: usize) -> Vec<f64> {
    if window == 0 || window > nz {
        return Vec::new();
    }

    let slice_len = nx * ny;
    let n_out_slices = nz - window + 1;

    // An output voxel shares its flat index with the first input voxel of its
    // window; the remaining window members sit `slice_len` apart.
    (0..slice_len * n_out_slices)
        .map(|start| {
            (1..window).fold(data[start], |lowest, step| {
                let candidate = data[start + step * slice_len];
                if candidate < lowest {
                    candidate
                } else {
                    lowest
                }
            })
        })
        .collect()
}
313
314pub fn create_mip_default(
316 data: &[f64],
317 nx: usize, ny: usize, nz: usize,
318) -> Vec<f64> {
319 create_mip(data, nx, ny, nz, 7)
320}
321
/// Applies a softplus-shaped scaling to every magnitude value.
///
/// With `f = factor / offset`, each value `v` is mapped to
/// `softplus(f * (v - offset)) / f - softplus(f * (0 - offset)) / f`, so an
/// input of `0.0` maps to `0.0`. The softplus is evaluated in the
/// overflow-safe form `ln(1 + exp(-|x|)) + max(x, 0)`.
///
/// # Returns
/// A vector with one scaled value per input. When `|offset| < 1e-20` the
/// input is returned unchanged (as a copy).
pub fn softplus_scaling(magnitude: &[f64], offset: f64, factor: f64) -> Vec<f64> {
    if offset.abs() < 1e-20 {
        return magnitude.to_vec();
    }

    let slope = factor / offset;

    let softplus = |x: f64| (1.0 + (-x.abs()).exp()).ln() + x.max(0.0);
    let scaled = |val: f64| softplus(slope * (val - offset)) / slope;

    // Subtracting the value at 0 anchors the curve at the origin.
    let baseline = scaled(0.0);

    magnitude.iter().map(|&val| scaled(val) - baseline).collect()
}
356
/// Smallest and largest strictly positive value among voxels with `mask == 1`.
///
/// Returns `(0.0, 1.0)` when no such voxel exists.
///
/// # Panics
/// Panics if `mask` is shorter than `data`.
fn positive_range(data: &[f64], mask: &[u8]) -> (f64, f64) {
    let (lo, hi) = data
        .iter()
        .zip(&mask[..data.len()])
        .filter(|&(&value, &m)| m == 1 && value > 0.0)
        .fold((f64::MAX, f64::MIN), |(lo, hi), (&value, _)| {
            (
                if value < lo { value } else { lo },
                if value > hi { value } else { hi },
            )
        });

    // The sentinels are still inverted only when nothing was selected.
    if lo > hi {
        (0.0, 1.0)
    } else {
        (lo, hi)
    }
}
375
/// Smallest and largest non-positive value (`<= 0.0`) among voxels with
/// `mask == 1`.
///
/// Returns `(-1.0, 0.0)` when no such voxel exists.
///
/// # Panics
/// Panics if `mask` is shorter than `data`.
fn negative_range(data: &[f64], mask: &[u8]) -> (f64, f64) {
    let (lo, hi) = data
        .iter()
        .zip(&mask[..data.len()])
        .filter(|&(&value, &m)| m == 1 && value <= 0.0)
        .fold((f64::MAX, f64::MIN), |(lo, hi), (&value, _)| {
            (
                if value < lo { value } else { lo },
                if value > hi { value } else { hi },
            )
        });

    // The sentinels are still inverted only when nothing was selected.
    if lo > hi {
        (-1.0, 0.0)
    } else {
        (lo, hi)
    }
}
392
/// Linearly maps `val` from `[old_min, old_max]` onto `[new_min, new_max]`.
///
/// The relative position of `val` in the old interval is clamped to `[0, 1]`
/// first, so out-of-range inputs land on an end point of the new interval.
/// The new interval may be reversed (`new_min > new_max`). When the old
/// interval is degenerate (`|old_max - old_min| < 1e-20`) the midpoint of the
/// new interval is returned.
#[inline]
fn rescale(val: f64, old_min: f64, old_max: f64, new_min: f64, new_max: f64) -> f64 {
    let old_span = old_max - old_min;
    if old_span.abs() < 1e-20 {
        (new_min + new_max) / 2.0
    } else {
        let fraction = ((val - old_min) / old_span).clamp(0.0, 1.0);
        new_min + fraction * (new_max - new_min)
    }
}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero phase with unit magnitude must give finite, non-negative output.
    #[test]
    fn test_calculate_swi_zero_phase() {
        let n = 8;
        let nn = n * n * n;
        let phase = vec![0.0; nn];
        let magnitude = vec![1.0; nn];
        let mask = vec![1u8; nn];

        let swi = calculate_swi_default(&phase, &magnitude, &mask, n, n, n, 1.0, 1.0, 1.0);

        assert_eq!(swi.len(), nn);
        for &v in &swi {
            assert!(v.is_finite(), "SWI values should be finite");
            assert!(v >= 0.0, "SWI values should be non-negative");
        }
    }

    /// Voxels with a zero mask value must come out as exactly zero.
    #[test]
    fn test_calculate_swi_mask() {
        let n = 8;
        let nn = n * n * n;
        let phase = vec![0.1; nn];
        let magnitude = vec![1.0; nn];
        let mut mask = vec![1u8; nn];
        mask[0] = 0;
        mask[1] = 0;

        let swi = calculate_swi_default(&phase, &magnitude, &mask, n, n, n, 1.0, 1.0, 1.0);

        assert_eq!(swi[0], 0.0, "Outside mask should be 0");
        assert_eq!(swi[1], 0.0, "Outside mask should be 0");
    }

    /// Every scaling mode must produce weights inside `[0, 1]`.
    #[test]
    fn test_phase_mask_range() {
        let n = 10;
        let nn = n * n * n;
        let phase: Vec<f64> = (0..nn).map(|i| (i as f64 * 0.01) - 5.0).collect();
        let mask = vec![1u8; nn];

        for &scaling in &[
            PhaseScaling::Tanh,
            PhaseScaling::NegativeTanh,
            PhaseScaling::Positive,
            PhaseScaling::Negative,
            PhaseScaling::Triangular,
        ] {
            let pm = create_phase_mask(&phase, &mask, scaling, 4.0);
            for (i, &v) in pm.iter().enumerate() {
                assert!(v >= 0.0, "{:?}: value at {} = {} < 0", scaling, i, v);
                assert!(v <= 1.0 + 1e-10, "{:?}: value at {} = {} > 1", scaling, i, v);
            }
        }
    }

    /// High-pass filtering a constant volume should leave (almost) nothing.
    #[test]
    fn test_highpass_filter_constant() {
        let n = 16;
        let nn = n * n * n;
        let data = vec![5.0; nn];
        let mask = vec![1u8; nn];

        let result = highpass_filter(&data, &mask, n, n, n, [2.0, 2.0, 0.0]);

        for &v in &result {
            assert!(v.abs() < 1.0, "High-pass of constant should be near zero, got {}", v);
        }
    }

    /// A single dark voxel must show up in every window that covers its slice.
    #[test]
    fn test_mip_basic() {
        let (nx, ny, nz) = (3, 3, 5);
        // Flat index of voxel (i, j, k) in the x-fastest layout used by `create_mip`.
        let at = |i: usize, j: usize, k: usize| i + j * nx + k * nx * ny;

        let mut data = vec![10.0; nx * ny * nz];
        data[at(1, 1, 2)] = 1.0;

        let mip = create_mip(&data, nx, ny, nz, 3);
        assert_eq!(mip.len(), nx * ny * 3);

        // Windows starting at slices 0, 1 and 2 all contain slice 2.
        for k_out in 0..3 {
            assert_eq!(mip[at(1, 1, k_out)], 1.0);
        }
    }

    /// A window larger than the slice count yields an empty result.
    #[test]
    fn test_mip_window_too_large() {
        let mip = create_mip(&[1.0; 27], 3, 3, 3, 10);
        assert!(mip.is_empty());
    }

    /// The scaling maps 0 to 0 and never decreases on increasing input.
    #[test]
    fn test_softplus_scaling() {
        let mag = vec![0.0, 0.5, 1.0, 2.0];
        let result = softplus_scaling(&mag, 1.0, 2.0);

        assert!(result[0].abs() < 1e-10, "softplus(0) should be ~0, got {}", result[0]);
        for pair in result.windows(2) {
            assert!(pair[1] >= pair[0], "softplus should be monotonically increasing");
        }
    }

    /// End points and midpoint for both a forward and a reversed target range.
    #[test]
    fn test_rescale() {
        assert!((rescale(0.0, 0.0, 10.0, 0.0, 1.0) - 0.0).abs() < 1e-10);
        assert!((rescale(5.0, 0.0, 10.0, 0.0, 1.0) - 0.5).abs() < 1e-10);
        assert!((rescale(10.0, 0.0, 10.0, 0.0, 1.0) - 1.0).abs() < 1e-10);
        assert!((rescale(0.0, 0.0, 10.0, 1.0, 0.0) - 1.0).abs() < 1e-10);
        assert!((rescale(10.0, 0.0, 10.0, 1.0, 0.0) - 0.0).abs() < 1e-10);
    }
}