Skip to main content

qsm_core/
fft.rs

1//! FFT wrapper for 3D transforms using rustfft
2//!
3//! Provides 3D FFT/IFFT operations compatible with NumPy's FFT conventions.
4//! Uses Fortran (column-major) order indexing to match NIfTI convention.
5
6use num_complex::{Complex32, Complex64};
7use rustfft::{Fft, FftPlanner, FftDirection};
8use std::f64::consts::PI;
9use std::sync::Arc;
10
11#[cfg(feature = "parallel")]
12use rayon::prelude::*;
13
/// Thread-shareable handle to a mutable `Complex64` buffer.
///
/// The buffer address is kept as a plain `usize` instead of a raw pointer so
/// the auto-trait rules for raw pointers do not get in the way; `Send` and
/// `Sync` are then asserted by hand below.
///
/// # Safety
///
/// The handle carries no lifetime and performs no aliasing checks. Whoever
/// uses it must keep the source buffer alive and guarantee that concurrent
/// users only touch non-overlapping elements.
#[cfg(feature = "parallel")]
#[derive(Clone, Copy)]
struct SendPtr {
    /// Address of the first element of the buffer.
    ptr: usize,
    /// Number of elements in the buffer.
    len: usize,
}

// SAFETY: the struct is just two integers; soundness of any access made
// through it is the responsibility of the code calling `as_slice`.
#[cfg(feature = "parallel")]
unsafe impl Send for SendPtr {}
#[cfg(feature = "parallel")]
unsafe impl Sync for SendPtr {}

#[cfg(feature = "parallel")]
impl SendPtr {
    /// Capture the address and length of `data`.
    fn new(data: &mut [Complex64]) -> Self {
        let len = data.len();
        let ptr = data.as_mut_ptr() as usize;
        SendPtr { ptr, len }
    }

    /// Rebuild a mutable slice over the captured buffer.
    ///
    /// # Safety
    ///
    /// The buffer passed to [`SendPtr::new`] must still be alive, and the
    /// caller must ensure that accesses made through slices obtained from
    /// different copies of the handle never overlap.
    unsafe fn as_slice(&self) -> &mut [Complex64] {
        let base = self.ptr as *mut Complex64;
        std::slice::from_raw_parts_mut(base, self.len)
    }
}
37
38/// FFT workspace that caches plans and scratch buffers for reuse
39pub struct Fft3dWorkspace {
40    nx: usize,
41    ny: usize,
42    nz: usize,
43    n_total: usize,
44    // Forward FFT plans
45    fft_x: Arc<dyn Fft<f64>>,
46    fft_y: Arc<dyn Fft<f64>>,
47    fft_z: Arc<dyn Fft<f64>>,
48    // Inverse FFT plans
49    ifft_x: Arc<dyn Fft<f64>>,
50    ifft_y: Arc<dyn Fft<f64>>,
51    ifft_z: Arc<dyn Fft<f64>>,
52    // Scratch buffers
53    scratch_x: Vec<Complex64>,
54    scratch_y: Vec<Complex64>,
55    scratch_z: Vec<Complex64>,
56    buffer_y: Vec<Complex64>,
57    buffer_z: Vec<Complex64>,
58}
59
60impl Fft3dWorkspace {
61    /// Create a new FFT workspace for the given dimensions
62    pub fn new(nx: usize, ny: usize, nz: usize) -> Self {
63        let mut planner = FftPlanner::new();
64
65        let fft_x = planner.plan_fft(nx, FftDirection::Forward);
66        let fft_y = planner.plan_fft(ny, FftDirection::Forward);
67        let fft_z = planner.plan_fft(nz, FftDirection::Forward);
68
69        let ifft_x = planner.plan_fft(nx, FftDirection::Inverse);
70        let ifft_y = planner.plan_fft(ny, FftDirection::Inverse);
71        let ifft_z = planner.plan_fft(nz, FftDirection::Inverse);
72
73        let scratch_x = vec![Complex64::new(0.0, 0.0); fft_x.get_inplace_scratch_len().max(ifft_x.get_inplace_scratch_len())];
74        let scratch_y = vec![Complex64::new(0.0, 0.0); fft_y.get_inplace_scratch_len().max(ifft_y.get_inplace_scratch_len())];
75        let scratch_z = vec![Complex64::new(0.0, 0.0); fft_z.get_inplace_scratch_len().max(ifft_z.get_inplace_scratch_len())];
76
77        Self {
78            nx, ny, nz,
79            n_total: nx * ny * nz,
80            fft_x, fft_y, fft_z,
81            ifft_x, ifft_y, ifft_z,
82            scratch_x, scratch_y, scratch_z,
83            buffer_y: vec![Complex64::new(0.0, 0.0); ny],
84            buffer_z: vec![Complex64::new(0.0, 0.0); nz],
85        }
86    }
87
88    /// In-place forward 3D FFT
89    pub fn fft3d(&mut self, data: &mut [Complex64]) {
90        let (nx, ny, nz) = (self.nx, self.ny, self.nz);
91
92        // Transform along x-axis (sequential — workspace reuses pre-allocated scratch)
93        for k in 0..nz {
94            for j in 0..ny {
95                let start = idx3d(0, j, k, nx, ny);
96                self.fft_x.process_with_scratch(&mut data[start..start + nx], &mut self.scratch_x);
97            }
98        }
99
100        // Transform along y-axis
101        for k in 0..nz {
102            for i in 0..nx {
103                for j in 0..ny {
104                    self.buffer_y[j] = data[idx3d(i, j, k, nx, ny)];
105                }
106                self.fft_y.process_with_scratch(&mut self.buffer_y, &mut self.scratch_y);
107                for j in 0..ny {
108                    data[idx3d(i, j, k, nx, ny)] = self.buffer_y[j];
109                }
110            }
111        }
112
113        // Transform along z-axis
114        for j in 0..ny {
115            for i in 0..nx {
116                for k in 0..nz {
117                    self.buffer_z[k] = data[idx3d(i, j, k, nx, ny)];
118                }
119                self.fft_z.process_with_scratch(&mut self.buffer_z, &mut self.scratch_z);
120                for k in 0..nz {
121                    data[idx3d(i, j, k, nx, ny)] = self.buffer_z[k];
122                }
123            }
124        }
125    }
126
127    /// In-place inverse 3D FFT (with normalization)
128    pub fn ifft3d(&mut self, data: &mut [Complex64]) {
129        let (nx, ny, nz) = (self.nx, self.ny, self.nz);
130        let n_total = self.n_total as f64;
131
132        // Transform along x-axis (sequential — workspace reuses pre-allocated scratch)
133        for k in 0..nz {
134            for j in 0..ny {
135                let start = idx3d(0, j, k, nx, ny);
136                self.ifft_x.process_with_scratch(&mut data[start..start + nx], &mut self.scratch_x);
137            }
138        }
139
140        // Transform along y-axis
141        for k in 0..nz {
142            for i in 0..nx {
143                for j in 0..ny { self.buffer_y[j] = data[idx3d(i, j, k, nx, ny)]; }
144                self.ifft_y.process_with_scratch(&mut self.buffer_y, &mut self.scratch_y);
145                for j in 0..ny { data[idx3d(i, j, k, nx, ny)] = self.buffer_y[j]; }
146            }
147        }
148
149        // Transform along z-axis
150        for j in 0..ny {
151            for i in 0..nx {
152                for k in 0..nz { self.buffer_z[k] = data[idx3d(i, j, k, nx, ny)]; }
153                self.ifft_z.process_with_scratch(&mut self.buffer_z, &mut self.scratch_z);
154                for k in 0..nz { data[idx3d(i, j, k, nx, ny)] = self.buffer_z[k]; }
155            }
156        }
157
158        // Normalize
159        for val in data.iter_mut() { *val /= n_total; }
160    }
161
162    /// Apply dipole convolution in-place: out = real(ifft(D * fft(x)))
163    /// Uses the provided complex buffer for the transform
164    #[inline]
165    pub fn apply_dipole_inplace(&mut self, x: &[f64], d_kernel: &[f64], out: &mut [f64], complex_buf: &mut [Complex64]) {
166        // Copy real to complex buffer
167        for (c, &r) in complex_buf.iter_mut().zip(x.iter()) {
168            *c = Complex64::new(r, 0.0);
169        }
170
171        self.fft3d(complex_buf);
172
173        // Multiply by kernel
174        for (c, &d) in complex_buf.iter_mut().zip(d_kernel.iter()) {
175            *c *= d;
176        }
177
178        self.ifft3d(complex_buf);
179
180        // Extract real part
181        for (o, c) in out.iter_mut().zip(complex_buf.iter()) {
182            *o = c.re;
183        }
184    }
185}
186
/// Flat offset of element `(i, j, k)` in a Fortran-ordered (column-major) volume.
///
/// `i` runs fastest, then `j`, then `k`: `offset = i + j*nx + k*nx*ny`.
/// The extent along z is not needed to compute the offset.
#[inline(always)]
pub fn idx3d(i: usize, j: usize, k: usize, nx: usize, ny: usize) -> usize {
    let slab_offset = k * nx * ny;
    let row_offset = j * nx;
    slab_offset + row_offset + i
}
193
194// ============================================================================
195// F32 (Single Precision) FFT Workspace
196// ============================================================================
197
198/// FFT workspace using f32 for better WASM performance
199/// Single precision halves memory bandwidth and is faster on most hardware
200pub struct Fft3dWorkspaceF32 {
201    nx: usize,
202    ny: usize,
203    nz: usize,
204    n_total: usize,
205    // Forward FFT plans
206    fft_x: Arc<dyn Fft<f32>>,
207    fft_y: Arc<dyn Fft<f32>>,
208    fft_z: Arc<dyn Fft<f32>>,
209    // Inverse FFT plans
210    ifft_x: Arc<dyn Fft<f32>>,
211    ifft_y: Arc<dyn Fft<f32>>,
212    ifft_z: Arc<dyn Fft<f32>>,
213    // Scratch buffers
214    scratch_x: Vec<Complex32>,
215    scratch_y: Vec<Complex32>,
216    scratch_z: Vec<Complex32>,
217    buffer_y: Vec<Complex32>,
218    buffer_z: Vec<Complex32>,
219}
220
221impl Fft3dWorkspaceF32 {
222    /// Create a new f32 FFT workspace for the given dimensions
223    pub fn new(nx: usize, ny: usize, nz: usize) -> Self {
224        let mut planner = FftPlanner::<f32>::new();
225
226        let fft_x = planner.plan_fft(nx, FftDirection::Forward);
227        let fft_y = planner.plan_fft(ny, FftDirection::Forward);
228        let fft_z = planner.plan_fft(nz, FftDirection::Forward);
229
230        let ifft_x = planner.plan_fft(nx, FftDirection::Inverse);
231        let ifft_y = planner.plan_fft(ny, FftDirection::Inverse);
232        let ifft_z = planner.plan_fft(nz, FftDirection::Inverse);
233
234        let scratch_x = vec![Complex32::new(0.0, 0.0); fft_x.get_inplace_scratch_len().max(ifft_x.get_inplace_scratch_len())];
235        let scratch_y = vec![Complex32::new(0.0, 0.0); fft_y.get_inplace_scratch_len().max(ifft_y.get_inplace_scratch_len())];
236        let scratch_z = vec![Complex32::new(0.0, 0.0); fft_z.get_inplace_scratch_len().max(ifft_z.get_inplace_scratch_len())];
237
238        Self {
239            nx, ny, nz,
240            n_total: nx * ny * nz,
241            fft_x, fft_y, fft_z,
242            ifft_x, ifft_y, ifft_z,
243            scratch_x, scratch_y, scratch_z,
244            buffer_y: vec![Complex32::new(0.0, 0.0); ny],
245            buffer_z: vec![Complex32::new(0.0, 0.0); nz],
246        }
247    }
248
249    /// In-place forward 3D FFT
250    #[inline]
251    pub fn fft3d(&mut self, data: &mut [Complex32]) {
252        let (nx, ny, nz) = (self.nx, self.ny, self.nz);
253
254        // Transform along x-axis
255        for k in 0..nz {
256            for j in 0..ny {
257                let start = idx3d(0, j, k, nx, ny);
258                self.fft_x.process_with_scratch(&mut data[start..start + nx], &mut self.scratch_x);
259            }
260        }
261
262        // Transform along y-axis
263        for k in 0..nz {
264            for i in 0..nx {
265                for j in 0..ny {
266                    self.buffer_y[j] = data[idx3d(i, j, k, nx, ny)];
267                }
268                self.fft_y.process_with_scratch(&mut self.buffer_y, &mut self.scratch_y);
269                for j in 0..ny {
270                    data[idx3d(i, j, k, nx, ny)] = self.buffer_y[j];
271                }
272            }
273        }
274
275        // Transform along z-axis
276        for j in 0..ny {
277            for i in 0..nx {
278                for k in 0..nz {
279                    self.buffer_z[k] = data[idx3d(i, j, k, nx, ny)];
280                }
281                self.fft_z.process_with_scratch(&mut self.buffer_z, &mut self.scratch_z);
282                for k in 0..nz {
283                    data[idx3d(i, j, k, nx, ny)] = self.buffer_z[k];
284                }
285            }
286        }
287    }
288
289    /// In-place inverse 3D FFT (with normalization)
290    #[inline]
291    pub fn ifft3d(&mut self, data: &mut [Complex32]) {
292        let (nx, ny, nz) = (self.nx, self.ny, self.nz);
293        let n_total = self.n_total as f32;
294
295        // Transform along x-axis
296        for k in 0..nz {
297            for j in 0..ny {
298                let start = idx3d(0, j, k, nx, ny);
299                self.ifft_x.process_with_scratch(&mut data[start..start + nx], &mut self.scratch_x);
300            }
301        }
302
303        // Transform along y-axis
304        for k in 0..nz {
305            for i in 0..nx {
306                for j in 0..ny {
307                    self.buffer_y[j] = data[idx3d(i, j, k, nx, ny)];
308                }
309                self.ifft_y.process_with_scratch(&mut self.buffer_y, &mut self.scratch_y);
310                for j in 0..ny {
311                    data[idx3d(i, j, k, nx, ny)] = self.buffer_y[j];
312                }
313            }
314        }
315
316        // Transform along z-axis
317        for j in 0..ny {
318            for i in 0..nx {
319                for k in 0..nz {
320                    self.buffer_z[k] = data[idx3d(i, j, k, nx, ny)];
321                }
322                self.ifft_z.process_with_scratch(&mut self.buffer_z, &mut self.scratch_z);
323                for k in 0..nz {
324                    data[idx3d(i, j, k, nx, ny)] = self.buffer_z[k];
325                }
326            }
327        }
328
329        // Normalize
330        for val in data.iter_mut() {
331            *val /= n_total;
332        }
333    }
334
335    /// Apply dipole convolution in-place: out = real(ifft(D * fft(x)))
336    #[inline]
337    pub fn apply_dipole_inplace(&mut self, x: &[f32], d_kernel: &[f32], out: &mut [f32], complex_buf: &mut [Complex32]) {
338        // Copy real to complex buffer
339        for (c, &r) in complex_buf.iter_mut().zip(x.iter()) {
340            *c = Complex32::new(r, 0.0);
341        }
342
343        self.fft3d(complex_buf);
344
345        // Multiply by kernel
346        for (c, &d) in complex_buf.iter_mut().zip(d_kernel.iter()) {
347            *c *= d;
348        }
349
350        self.ifft3d(complex_buf);
351
352        // Extract real part
353        for (o, c) in out.iter_mut().zip(complex_buf.iter()) {
354            *o = c.re;
355        }
356    }
357}
358
359/// 3D FFT (in-place, complex-to-complex)
360///
361/// Transforms data in Fortran order with shape (nx, ny, nz).
362/// Matches numpy.fft.fftn behavior.
363pub fn fft3d(data: &mut [Complex64], nx: usize, ny: usize, nz: usize) {
364    let mut planner = FftPlanner::new();
365
366    // Transform along x-axis (contiguous rows of length nx)
367    let fft_x = planner.plan_fft(nx, FftDirection::Forward);
368    #[cfg(feature = "parallel")]
369    {
370        let scratch_len = fft_x.get_inplace_scratch_len();
371        data.par_chunks_mut(nx).for_each(|row| {
372            let mut scratch = vec![Complex64::new(0.0, 0.0); scratch_len];
373            fft_x.process_with_scratch(row, &mut scratch);
374        });
375    }
376    #[cfg(not(feature = "parallel"))]
377    {
378        let mut scratch_x = vec![Complex64::new(0.0, 0.0); fft_x.get_inplace_scratch_len()];
379        for chunk in data.chunks_mut(nx) {
380            fft_x.process_with_scratch(chunk, &mut scratch_x);
381        }
382    }
383
384    // Transform along y-axis (stride nx)
385    let fft_y = planner.plan_fft(ny, FftDirection::Forward);
386    #[cfg(feature = "parallel")]
387    {
388        let scratch_len = fft_y.get_inplace_scratch_len();
389        let nxy = nx * ny;
390        let pairs: Vec<(usize, usize)> = (0..nz).flat_map(|k| (0..nx).map(move |i| (k, i))).collect();
391        let data_send = SendPtr::new(data);
392        pairs.par_iter().for_each(|&(k, i)| {
393            let mut buffer = vec![Complex64::new(0.0, 0.0); ny];
394            let mut scratch = vec![Complex64::new(0.0, 0.0); scratch_len];
395            unsafe {
396                let slice = data_send.as_slice();
397                for j in 0..ny { buffer[j] = slice[i + j * nx + k * nxy]; }
398                fft_y.process_with_scratch(&mut buffer, &mut scratch);
399                for j in 0..ny { slice[i + j * nx + k * nxy] = buffer[j]; }
400            }
401        });
402    }
403    #[cfg(not(feature = "parallel"))]
404    {
405        let mut scratch_y = vec![Complex64::new(0.0, 0.0); fft_y.get_inplace_scratch_len()];
406        let mut buffer_y = vec![Complex64::new(0.0, 0.0); ny];
407        for k in 0..nz {
408            for i in 0..nx {
409                for j in 0..ny { buffer_y[j] = data[idx3d(i, j, k, nx, ny)]; }
410                fft_y.process_with_scratch(&mut buffer_y, &mut scratch_y);
411                for j in 0..ny { data[idx3d(i, j, k, nx, ny)] = buffer_y[j]; }
412            }
413        }
414    }
415
416    // Transform along z-axis (stride nx*ny)
417    let fft_z = planner.plan_fft(nz, FftDirection::Forward);
418    #[cfg(feature = "parallel")]
419    {
420        let scratch_len = fft_z.get_inplace_scratch_len();
421        let nxy = nx * ny;
422        let pairs: Vec<(usize, usize)> = (0..ny).flat_map(|j| (0..nx).map(move |i| (j, i))).collect();
423        let data_send = SendPtr::new(data);
424        pairs.par_iter().for_each(|&(j, i)| {
425            let mut buffer = vec![Complex64::new(0.0, 0.0); nz];
426            let mut scratch = vec![Complex64::new(0.0, 0.0); scratch_len];
427            unsafe {
428                let slice = data_send.as_slice();
429                for k in 0..nz { buffer[k] = slice[i + j * nx + k * nxy]; }
430                fft_z.process_with_scratch(&mut buffer, &mut scratch);
431                for k in 0..nz { slice[i + j * nx + k * nxy] = buffer[k]; }
432            }
433        });
434    }
435    #[cfg(not(feature = "parallel"))]
436    {
437        let mut scratch_z = vec![Complex64::new(0.0, 0.0); fft_z.get_inplace_scratch_len()];
438        let mut buffer_z = vec![Complex64::new(0.0, 0.0); nz];
439        for j in 0..ny {
440            for i in 0..nx {
441                for k in 0..nz { buffer_z[k] = data[idx3d(i, j, k, nx, ny)]; }
442                fft_z.process_with_scratch(&mut buffer_z, &mut scratch_z);
443                for k in 0..nz { data[idx3d(i, j, k, nx, ny)] = buffer_z[k]; }
444            }
445        }
446    }
447}
448
449/// 3D IFFT (in-place, complex-to-complex)
450///
451/// Transforms data in Fortran order with shape (nx, ny, nz).
452/// Matches numpy.fft.ifftn behavior (includes 1/N normalization).
453pub fn ifft3d(data: &mut [Complex64], nx: usize, ny: usize, nz: usize) {
454    let mut planner = FftPlanner::new();
455    let n_total = (nx * ny * nz) as f64;
456
457    // Transform along x-axis
458    let ifft_x = planner.plan_fft(nx, FftDirection::Inverse);
459    #[cfg(feature = "parallel")]
460    {
461        let scratch_len = ifft_x.get_inplace_scratch_len();
462        data.par_chunks_mut(nx).for_each(|row| {
463            let mut scratch = vec![Complex64::new(0.0, 0.0); scratch_len];
464            ifft_x.process_with_scratch(row, &mut scratch);
465        });
466    }
467    #[cfg(not(feature = "parallel"))]
468    {
469        let mut scratch_x = vec![Complex64::new(0.0, 0.0); ifft_x.get_inplace_scratch_len()];
470        for chunk in data.chunks_mut(nx) {
471            ifft_x.process_with_scratch(chunk, &mut scratch_x);
472        }
473    }
474
475    // Transform along y-axis
476    let ifft_y = planner.plan_fft(ny, FftDirection::Inverse);
477    #[cfg(feature = "parallel")]
478    {
479        let scratch_len = ifft_y.get_inplace_scratch_len();
480        let nxy = nx * ny;
481        let pairs: Vec<(usize, usize)> = (0..nz).flat_map(|k| (0..nx).map(move |i| (k, i))).collect();
482        let data_send = SendPtr::new(data);
483        pairs.par_iter().for_each(|&(k, i)| {
484            let mut buffer = vec![Complex64::new(0.0, 0.0); ny];
485            let mut scratch = vec![Complex64::new(0.0, 0.0); scratch_len];
486            unsafe {
487                let slice = data_send.as_slice();
488                for j in 0..ny { buffer[j] = slice[i + j * nx + k * nxy]; }
489                ifft_y.process_with_scratch(&mut buffer, &mut scratch);
490                for j in 0..ny { slice[i + j * nx + k * nxy] = buffer[j]; }
491            }
492        });
493    }
494    #[cfg(not(feature = "parallel"))]
495    {
496        let mut scratch_y = vec![Complex64::new(0.0, 0.0); ifft_y.get_inplace_scratch_len()];
497        let mut buffer_y = vec![Complex64::new(0.0, 0.0); ny];
498        for k in 0..nz {
499            for i in 0..nx {
500                for j in 0..ny { buffer_y[j] = data[idx3d(i, j, k, nx, ny)]; }
501                ifft_y.process_with_scratch(&mut buffer_y, &mut scratch_y);
502                for j in 0..ny { data[idx3d(i, j, k, nx, ny)] = buffer_y[j]; }
503            }
504        }
505    }
506
507    // Transform along z-axis
508    let ifft_z = planner.plan_fft(nz, FftDirection::Inverse);
509    #[cfg(feature = "parallel")]
510    {
511        let scratch_len = ifft_z.get_inplace_scratch_len();
512        let nxy = nx * ny;
513        let pairs: Vec<(usize, usize)> = (0..ny).flat_map(|j| (0..nx).map(move |i| (j, i))).collect();
514        let data_send = SendPtr::new(data);
515        pairs.par_iter().for_each(|&(j, i)| {
516            let mut buffer = vec![Complex64::new(0.0, 0.0); nz];
517            let mut scratch = vec![Complex64::new(0.0, 0.0); scratch_len];
518            unsafe {
519                let slice = data_send.as_slice();
520                for k in 0..nz { buffer[k] = slice[i + j * nx + k * nxy]; }
521                ifft_z.process_with_scratch(&mut buffer, &mut scratch);
522                for k in 0..nz { slice[i + j * nx + k * nxy] = buffer[k]; }
523            }
524        });
525    }
526    #[cfg(not(feature = "parallel"))]
527    {
528        let mut scratch_z = vec![Complex64::new(0.0, 0.0); ifft_z.get_inplace_scratch_len()];
529        let mut buffer_z = vec![Complex64::new(0.0, 0.0); nz];
530        for j in 0..ny {
531            for i in 0..nx {
532                for k in 0..nz { buffer_z[k] = data[idx3d(i, j, k, nx, ny)]; }
533                ifft_z.process_with_scratch(&mut buffer_z, &mut scratch_z);
534                for k in 0..nz { data[idx3d(i, j, k, nx, ny)] = buffer_z[k]; }
535            }
536        }
537    }
538
539    // Normalize
540    let n_total_f = n_total;
541    #[cfg(feature = "parallel")]
542    data.par_iter_mut().for_each(|val| { *val /= n_total_f; });
543    #[cfg(not(feature = "parallel"))]
544    for val in data.iter_mut() { *val /= n_total_f; }
545}
546
547/// 3D FFT of real data (real-to-complex)
548///
549/// Returns complex array. Output shape is (nx, ny, nz) for simplicity
550/// (not the half-spectrum like numpy's rfft).
551pub fn fft3d_real(data: &[f64], nx: usize, ny: usize, nz: usize) -> Vec<Complex64> {
552    let mut complex_data: Vec<Complex64> = data.iter()
553        .map(|&x| Complex64::new(x, 0.0))
554        .collect();
555    fft3d(&mut complex_data, nx, ny, nz);
556    complex_data
557}
558
559/// 3D IFFT returning real part (complex-to-real)
560///
561/// Takes complex array, returns real array (imaginary parts discarded).
562pub fn ifft3d_real(data: &[Complex64], nx: usize, ny: usize, nz: usize) -> Vec<f64> {
563    let mut complex_data = data.to_vec();
564    ifft3d(&mut complex_data, nx, ny, nz);
565    complex_data.iter().map(|c| c.re).collect()
566}
567
/// Sample frequencies of an `n`-point DFT with sample spacing `d`.
///
/// Matches numpy.fft.fftfreq(n, d): the first `(n + 1) / 2` bins hold the
/// non-negative frequencies `0, 1, ...`, the remaining bins hold the negative
/// frequencies `..., -2, -1`, all multiplied by `1 / (n * d)`.
pub fn fftfreq(n: usize, d: f64) -> Vec<f64> {
    let step = 1.0 / (n as f64 * d);
    // Index of the first negative-frequency bin; equals n/2 for even n and
    // (n+1)/2 for odd n.
    let first_negative = (n + 1) / 2;

    (0..n)
        .map(|bin| {
            let signed_bin = if bin < first_negative {
                bin as f64
            } else {
                ((bin as i64) - (n as i64)) as f64
            };
            signed_bin * step
        })
        .collect()
}
593
/// Sample frequencies of an `n`-point DFT with sample spacing `d` (f32 version
/// for WASM performance).
///
/// Matches numpy.fft.fftfreq(n, d): the first `(n + 1) / 2` bins hold the
/// non-negative frequencies `0, 1, ...`, the remaining bins hold the negative
/// frequencies `..., -2, -1`, all multiplied by `1 / (n * d)`.
pub fn fftfreq_f32(n: usize, d: f32) -> Vec<f32> {
    let step = 1.0f32 / (n as f32 * d);
    // Index of the first negative-frequency bin; equals n/2 for even n and
    // (n+1)/2 for odd n.
    let first_negative = (n + 1) / 2;

    (0..n)
        .map(|bin| {
            let signed_bin = if bin < first_negative {
                bin as f32
            } else {
                ((bin as i64) - (n as i64)) as f32
            };
            signed_bin * step
        })
        .collect()
}
619
620/// 3D FFT shift: swap quadrants so zero-frequency is at center
621///
622/// Returns a new array with the zero-frequency component shifted to the center.
623/// Matches numpy.fft.fftshift behavior for 3D data in Fortran order.
624pub fn fftshift(data: &[f64], nx: usize, ny: usize, nz: usize) -> Vec<f64> {
625    let n_total = nx * ny * nz;
626    let mut out = vec![0.0; n_total];
627
628    let hx = nx / 2;
629    let hy = ny / 2;
630    let hz = nz / 2;
631
632    for k in 0..nz {
633        for j in 0..ny {
634            for i in 0..nx {
635                let si = (i + hx) % nx;
636                let sj = (j + hy) % ny;
637                let sk = (k + hz) % nz;
638                out[idx3d(si, sj, sk, nx, ny)] = data[idx3d(i, j, k, nx, ny)];
639            }
640        }
641    }
642
643    out
644}
645
646/// 3D inverse FFT shift: undo fftshift
647///
648/// Returns a new array with the zero-frequency component shifted back to the corner.
649/// Matches numpy.fft.ifftshift behavior for 3D data in Fortran order.
650pub fn ifftshift(data: &[f64], nx: usize, ny: usize, nz: usize) -> Vec<f64> {
651    let n_total = nx * ny * nz;
652    let mut out = vec![0.0; n_total];
653
654    let hx = (nx + 1) / 2;
655    let hy = (ny + 1) / 2;
656    let hz = (nz + 1) / 2;
657
658    for k in 0..nz {
659        for j in 0..ny {
660            for i in 0..nx {
661                let si = (i + hx) % nx;
662                let sj = (j + hy) % ny;
663                let sk = (k + hz) % nz;
664                out[idx3d(si, sj, sk, nx, ny)] = data[idx3d(i, j, k, nx, ny)];
665            }
666        }
667    }
668
669    out
670}
671
/// 3D FFT shift in-place: move the zero-frequency sample to the center.
///
/// Produces the same result as the out-of-place `fftshift` (every sample is
/// moved forward cyclically by `n / 2` along each axis of the Fortran-ordered
/// volume) without allocating. Works for even and odd dimensions alike.
///
/// Only the first `nx * ny * nz` elements of `data` are touched; an empty
/// volume (any dimension equal to zero) is a no-op.
///
/// # Panics
///
/// Panics if `data` holds fewer than `nx * ny * nz` elements.
pub fn fftshift_inplace(data: &mut [f64], nx: usize, ny: usize, nz: usize) {
    let nxy = nx * ny;
    let n_total = nxy * nz;
    if n_total == 0 {
        return;
    }
    let volume = &mut data[..n_total];

    // A cyclic shift along one axis is a rotation of every line of that axis.
    // In Fortran order each of them can be expressed as a rotation of a
    // contiguous block, and the three rotations commute.

    // x axis: rotate each contiguous row by nx/2 samples.
    for row in volume.chunks_exact_mut(nx) {
        row.rotate_right(nx / 2);
    }

    // y axis: inside each z-slab, rotate by ny/2 whole rows.
    for slab in volume.chunks_exact_mut(nxy) {
        slab.rotate_right((ny / 2) * nx);
    }

    // z axis: rotate the whole volume by nz/2 whole slabs.
    volume.rotate_right((nz / 2) * nxy);
}
700
/// Wrap an angle in radians into the interval [-π, π].
///
/// The angle is first reduced with the floating-point remainder of 2π (which
/// keeps the sign of the input) and then moved by one full turn if it still
/// lies outside the interval. Values exactly at ±π are returned unchanged.
#[inline]
pub fn wrap_angle(angle: f64) -> f64 {
    let full_turn = 2.0 * PI;
    let reduced = angle % full_turn;

    if reduced > PI {
        reduced - full_turn
    } else if reduced < -PI {
        reduced + full_turn
    } else {
        reduced
    }
}
712
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic complex test signal with non-trivial real and imaginary parts.
    fn sample_signal(len: usize) -> Vec<Complex64> {
        (0..len)
            .map(|i| Complex64::new((i as f64 * 0.3).sin(), (i as f64 * 0.7).cos()))
            .collect()
    }

    /// Assert that two complex buffers agree element-wise within `tol`.
    fn assert_close(actual: &[Complex64], expected: &[Complex64], tol: f64, what: &str) {
        assert_eq!(actual.len(), expected.len(), "{}: length mismatch", what);
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).norm() < tol,
                "{} mismatch at {}: expected {}, got {}",
                what, i, e, a
            );
        }
    }

    #[test]
    fn test_fft_ifft_roundtrip() {
        let nx = 4;
        let ny = 4;
        let nz = 4;

        // Create test data
        let original: Vec<f64> = (0..nx * ny * nz).map(|i| i as f64).collect();

        // FFT then IFFT
        let mut data: Vec<Complex64> = original.iter()
            .map(|&x| Complex64::new(x, 0.0))
            .collect();

        fft3d(&mut data, nx, ny, nz);
        ifft3d(&mut data, nx, ny, nz);

        // Check roundtrip
        for (i, (&orig, result)) in original.iter().zip(data.iter()).enumerate() {
            assert!(
                (result.re - orig).abs() < 1e-10,
                "Mismatch at index {}: expected {}, got {}",
                i, orig, result.re
            );
            assert!(
                result.im.abs() < 1e-10,
                "Imaginary part not zero at index {}: {}",
                i, result.im
            );
        }
    }

    /// The cached-plan workspace and the free functions must agree, including
    /// for a non-cubic volume with odd dimensions.
    #[test]
    fn test_workspace_matches_free_functions() {
        let (nx, ny, nz) = (3, 4, 5);
        let input = sample_signal(nx * ny * nz);
        let mut ws = Fft3dWorkspace::new(nx, ny, nz);

        let mut via_ws = input.clone();
        ws.fft3d(&mut via_ws);
        let mut via_fn = input.clone();
        fft3d(&mut via_fn, nx, ny, nz);
        assert_close(&via_fn, &via_ws, 1e-10, "forward FFT");

        let mut via_ws = input.clone();
        ws.ifft3d(&mut via_ws);
        let mut via_fn = input.clone();
        ifft3d(&mut via_fn, nx, ny, nz);
        assert_close(&via_fn, &via_ws, 1e-10, "inverse FFT");
    }

    #[test]
    fn test_fftfreq() {
        // Test even n=4
        let freq = fftfreq(4, 1.0);
        assert!((freq[0] - 0.0).abs() < 1e-10);
        assert!((freq[1] - 0.25).abs() < 1e-10);
        assert!((freq[2] - (-0.5)).abs() < 1e-10);
        assert!((freq[3] - (-0.25)).abs() < 1e-10);

        // Test odd n=5
        let freq = fftfreq(5, 1.0);
        assert!((freq[0] - 0.0).abs() < 1e-10);
        assert!((freq[1] - 0.2).abs() < 1e-10);
        assert!((freq[2] - 0.4).abs() < 1e-10);
        assert!((freq[3] - (-0.4)).abs() < 1e-10);
        assert!((freq[4] - (-0.2)).abs() < 1e-10);
    }

    #[test]
    fn test_fft_f32_roundtrip() {
        let nx = 4;
        let ny = 4;
        let nz = 4;

        let original: Vec<f32> = (0..nx * ny * nz).map(|i| i as f32).collect();

        let mut data: Vec<Complex32> = original.iter()
            .map(|&x| Complex32::new(x, 0.0))
            .collect();

        let mut ws = Fft3dWorkspaceF32::new(nx, ny, nz);
        ws.fft3d(&mut data);
        ws.ifft3d(&mut data);

        for (i, (&orig, result)) in original.iter().zip(data.iter()).enumerate() {
            assert!(
                (result.re - orig).abs() < 1e-4,
                "f32 roundtrip mismatch at index {}: expected {}, got {}",
                i, orig, result.re
            );
            assert!(
                result.im.abs() < 1e-4,
                "f32 imaginary part not zero at index {}: {}",
                i, result.im
            );
        }
    }

    #[test]
    fn test_fftshift_even() {
        // 4x4x4 array with sequential values
        let nx = 4;
        let ny = 4;
        let nz = 4;
        let n = nx * ny * nz;

        let data: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let shifted = fftshift(&data, nx, ny, nz);

        // After fftshift, element at (0,0,0) should move to (2,2,2)
        // Original index for (0,0,0) = 0
        // Shifted position (2,2,2) -> index = 2 + 2*4 + 2*4*4 = 2 + 8 + 32 = 42
        assert!(
            (shifted[idx3d(2, 2, 2, nx, ny)] - data[idx3d(0, 0, 0, nx, ny)]).abs() < 1e-12,
            "fftshift: element at (0,0,0) should move to (2,2,2)"
        );

        // Element at (1,1,1) should move to (3,3,3)
        assert!(
            (shifted[idx3d(3, 3, 3, nx, ny)] - data[idx3d(1, 1, 1, nx, ny)]).abs() < 1e-12,
            "fftshift: element at (1,1,1) should move to (3,3,3)"
        );

        // Size should be preserved
        assert_eq!(shifted.len(), n);
    }

    #[test]
    fn test_ifftshift_roundtrip() {
        let nx = 4;
        let ny = 4;
        let nz = 4;
        let n = nx * ny * nz;

        let data: Vec<f64> = (0..n).map(|i| (i as f64) * 0.1).collect();

        // fftshift then ifftshift should be identity (for even dimensions)
        let shifted = fftshift(&data, nx, ny, nz);
        let unshifted = ifftshift(&shifted, nx, ny, nz);

        for i in 0..n {
            assert!(
                (unshifted[i] - data[i]).abs() < 1e-12,
                "ifftshift(fftshift(x)) != x at index {}: expected {}, got {}",
                i, data[i], unshifted[i]
            );
        }
    }

    /// ifftshift must also undo fftshift when the dimensions are odd.
    #[test]
    fn test_ifftshift_roundtrip_odd() {
        let (nx, ny, nz) = (3, 5, 7);
        let n = nx * ny * nz;

        let data: Vec<f64> = (0..n).map(|i| (i as f64) * 0.1).collect();

        let shifted = fftshift(&data, nx, ny, nz);
        // The zero-frequency sample lands at the (floor) center.
        assert_eq!(shifted[idx3d(nx / 2, ny / 2, nz / 2, nx, ny)], data[0]);

        let unshifted = ifftshift(&shifted, nx, ny, nz);
        assert_eq!(unshifted, data);
    }

    #[test]
    fn test_fftshift_inplace() {
        let nx = 4;
        let ny = 4;
        let nz = 4;
        let n = nx * ny * nz;

        let original: Vec<f64> = (0..n).map(|i| i as f64).collect();

        // Compare in-place version with out-of-place version
        let shifted_copy = fftshift(&original, nx, ny, nz);

        let mut data = original.clone();
        fftshift_inplace(&mut data, nx, ny, nz);

        for i in 0..n {
            assert!(
                (data[i] - shifted_copy[i]).abs() < 1e-12,
                "fftshift_inplace mismatch at index {}: expected {}, got {}",
                i, shifted_copy[i], data[i]
            );
        }
    }

    #[test]
    fn test_wrap_angle() {
        assert_eq!(wrap_angle(0.0), 0.0);
        assert!((wrap_angle(1.5 * PI) + 0.5 * PI).abs() < 1e-12);
        assert!((wrap_angle(-1.5 * PI) - 0.5 * PI).abs() < 1e-12);
        assert!((wrap_angle(4.0 * PI + 0.25) - 0.25).abs() < 1e-12);
    }

    /// Verify the rayon-parallel free-function FFT against the sequential workspace.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_fft3d_workspace_parallel_matches_sequential() {
        let n = 16;
        let input = sample_signal(n * n * n);

        // Sequential reference: the workspace never uses rayon.
        let mut result_seq = input.clone();
        Fft3dWorkspace::new(n, n, n).fft3d(&mut result_seq);

        // Parallel: the free function distributes lines over the thread pool.
        let mut result_par = input.clone();
        fft3d(&mut result_par, n, n, n);

        assert_close(&result_par, &result_seq, 1e-10, "FFT (parallel vs sequential)");
    }

    /// Verify the rayon-parallel free-function IFFT against the sequential workspace.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_ifft3d_workspace_parallel_matches_sequential() {
        let n = 16;
        let input = sample_signal(n * n * n);

        // Sequential reference: the workspace never uses rayon.
        let mut result_seq = input.clone();
        Fft3dWorkspace::new(n, n, n).ifft3d(&mut result_seq);

        // Parallel: the free function distributes lines over the thread pool.
        let mut result_par = input.clone();
        ifft3d(&mut result_par, n, n, n);

        assert_close(&result_par, &result_seq, 1e-10, "IFFT (parallel vs sequential)");
    }

    /// Verify FFT roundtrip (FFT then IFFT = identity) through the parallel path.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_fft_ifft_roundtrip_parallel() {
        let n = 16;
        let original: Vec<Complex64> = (0..n * n * n)
            .map(|i| Complex64::new((i as f64 * 0.3).sin(), 0.0))
            .collect();

        let mut data = original.clone();
        fft3d(&mut data, n, n, n);
        ifft3d(&mut data, n, n, n);

        assert_close(&data, &original, 1e-10, "parallel roundtrip");
    }
}