// qsm_core/utils/padding.rs

//! Padding utilities for FFT.
//!
//! Functions to zero-pad 3D arrays to sizes that are efficient for FFT, and to
//! extract the original region again afterwards.

/// Find the next size that is efficient for FFT.
///
/// FFT is most efficient when the size factors into small primes (2, 3, 5).
/// This returns the smallest `n >= size` whose only prime factors are 2, 3
/// and 5 (`1` qualifies, having no prime factors at all).
///
/// `size == 0` returns `0`, so an empty axis stays empty.
///
/// # Examples
/// ```ignore
/// assert_eq!(next_fast_fft_size(7), 8);
/// assert_eq!(next_fast_fft_size(17), 18); // 2 * 3^2
/// ```
pub fn next_fast_fft_size(size: usize) -> usize {
    // 0 is divisible by every factor, so the factor-stripping loop below
    // would never reduce it to 1 and the search would never terminate.
    if size == 0 {
        return 0;
    }

    let mut n = size;
    loop {
        // Strip all factors of 2, 3 and 5; n is "fast" iff nothing remains.
        let mut m = n;
        for &p in &[2usize, 3, 5] {
            while m % p == 0 {
                m /= p;
            }
        }
        if m == 1 {
            return n;
        }
        n += 1;
    }
}

24/// Pad a 3D array to fast FFT size
25///
26/// # Arguments
27/// * `data` - Input array (nx * ny * nz)
28/// * `nx`, `ny`, `nz` - Original dimensions
29/// * `min_pad` - Minimum padding on each side (negative means no padding)
30///
31/// # Returns
32/// (padded_data, new_nx, new_ny, new_nz)
33pub fn pad_to_fast_fft(
34    data: &[f64],
35    nx: usize, ny: usize, nz: usize,
36    min_pad: (i32, i32, i32),
37) -> (Vec<f64>, usize, usize, usize) {
38    // Calculate new sizes
39    let new_nx = if min_pad.0 >= 0 {
40        next_fast_fft_size(nx + 2 * min_pad.0 as usize)
41    } else {
42        nx
43    };
44    let new_ny = if min_pad.1 >= 0 {
45        next_fast_fft_size(ny + 2 * min_pad.1 as usize)
46    } else {
47        ny
48    };
49    let new_nz = if min_pad.2 >= 0 {
50        next_fast_fft_size(nz + 2 * min_pad.2 as usize)
51    } else {
52        nz
53    };
54
55    // Create padded array (zero-filled)
56    let new_total = new_nx * new_ny * new_nz;
57    let mut padded = vec![0.0; new_total];
58
59    // Copy original data (Fortran order: index = i + j*nx + k*nx*ny)
60    for k in 0..nz {
61        for j in 0..ny {
62            for i in 0..nx {
63                let old_idx = i + j * nx + k * nx * ny;
64                let new_idx = i + j * new_nx + k * new_nx * new_ny;
65                padded[new_idx] = data[old_idx];
66            }
67        }
68    }
69
70    (padded, new_nx, new_ny, new_nz)
71}
72
/// Extract the original-sized region from a padded array.
///
/// Inverse of [`pad_to_fast_fft`]. It copies the
/// `orig_nx x orig_ny x orig_nz` corner that starts at index 0 of `padded`
/// into a new, densely packed array. Both arrays use Fortran order
/// (x fastest: index = `i + j*nx + k*nx*ny`).
///
/// # Arguments
/// * `padded` - Padded array, laid out with x/y strides `padded_nx` / `padded_ny`
/// * `padded_nx`, `padded_ny` - Padded x/y dimensions (used as strides)
/// * `_padded_nz` - Padded z dimension; currently unused (the z extent is only
///   bounds-checked implicitly by slice indexing)
/// * `orig_nx`, `orig_ny`, `orig_nz` - Dimensions of the region to extract
///
/// # Panics
/// - If `orig_nx > padded_nx` or `orig_ny > padded_ny`. Without this check the
///   copy would silently read across row/slab boundaries and return scrambled data.
/// - If `padded` is too short to contain the requested region.
pub fn unpad(
    padded: &[f64],
    padded_nx: usize, padded_ny: usize, _padded_nz: usize,
    orig_nx: usize, orig_ny: usize, orig_nz: usize,
) -> Vec<f64> {
    let orig_total = orig_nx * orig_ny * orig_nz;
    if orig_total == 0 {
        return Vec::new();
    }

    assert!(
        orig_nx <= padded_nx && orig_ny <= padded_ny,
        "unpad: original x/y dims ({}, {}) exceed padded x/y dims ({}, {})",
        orig_nx,
        orig_ny,
        padded_nx,
        padded_ny
    );

    let mut data = Vec::with_capacity(orig_total);
    // Appending each x-row in (k, j) order reproduces the dense layout
    // i + j*orig_nx + k*orig_nx*orig_ny.
    for k in 0..orig_nz {
        for j in 0..orig_ny {
            let start = (j + k * padded_ny) * padded_nx;
            data.extend_from_slice(&padded[start..start + orig_nx]);
        }
    }

    data
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fast_fft_sizes() {
        // Sizes that already factor into 2, 3 and 5 are returned unchanged.
        for &n in &[1usize, 2, 3, 4, 5, 6, 8, 9, 10, 12, 15, 16, 64] {
            assert_eq!(next_fast_fft_size(n), n);
        }

        // 7 is not a fast size, should round up to 8
        assert_eq!(next_fast_fft_size(7), 8);
        assert_eq!(next_fast_fft_size(11), 12);
        assert_eq!(next_fast_fft_size(13), 15);

        // 17 should round up to 18 (2 * 9 = 2 * 3^2)
        assert_eq!(next_fast_fft_size(17), 18);
        assert_eq!(next_fast_fft_size(97), 100);
    }

    #[test]
    fn test_pad_unpad_roundtrip() {
        let (nx, ny, nz) = (5, 6, 7);
        let data: Vec<f64> = (0..nx * ny * nz).map(|i| i as f64).collect();

        let (padded, pnx, pny, pnz) = pad_to_fast_fft(&data, nx, ny, nz, (2, 2, 2));

        // Padded size should be >= original + 2*padding
        assert!(pnx >= nx + 4);
        assert!(pny >= ny + 4);
        assert!(pnz >= nz + 4);
        assert_eq!(padded.len(), pnx * pny * pnz);

        let recovered = unpad(&padded, pnx, pny, pnz, nx, ny, nz);

        // Compare whole vectors: a zip-based loop would silently accept a
        // `recovered` that is shorter than `data`.
        assert_eq!(recovered, data);
    }

    #[test]
    fn test_padding_is_zero_and_data_at_origin() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        // x padded (2 + 2*1 = 4), y left as is (negative pad), z pad 0 (1 is fast).
        let (padded, pnx, pny, pnz) = pad_to_fast_fft(&data, 2, 3, 1, (1, -1, 0));
        assert_eq!((pnx, pny, pnz), (4, 3, 1));
        assert_eq!(
            padded,
            vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0]
        );
    }
}