1use num_complex::{Complex32, Complex64};
7use rustfft::{Fft, FftPlanner, FftDirection};
8use std::f64::consts::PI;
9use std::sync::Arc;
10
11#[cfg(feature = "parallel")]
12use rayon::prelude::*;
13
/// Raw, thread-shareable handle to a mutable `Complex64` buffer, letting rayon
/// tasks write disjoint (possibly strided) elements of one slice.
///
/// The address is stored as a `usize` so the handle can cross thread
/// boundaries. The `PhantomData` ties it to the mutable borrow it was created
/// from, so it can neither outlive the buffer nor coexist with other uses of
/// that borrow. Because every field is `Send + Sync`, the handle is too; no
/// hand-written `unsafe impl` is needed.
#[cfg(feature = "parallel")]
#[derive(Clone, Copy)]
struct SendPtr<'a> {
    ptr: usize,
    len: usize,
    _borrow: std::marker::PhantomData<&'a mut [Complex64]>,
}

#[cfg(feature = "parallel")]
impl<'a> SendPtr<'a> {
    /// Captures the address and length of `data` for the duration of its borrow.
    fn new(data: &'a mut [Complex64]) -> Self {
        Self {
            ptr: data.as_mut_ptr() as usize,
            len: data.len(),
            _borrow: std::marker::PhantomData,
        }
    }

    /// Reconstructs the full buffer as a mutable slice.
    ///
    /// # Safety
    ///
    /// Every call yields a `&mut` covering the *whole* buffer. The caller must
    /// ensure no two such slices (or any other reference into the buffer) are
    /// live at the same time. Calling this from several rayon tasks at once is
    /// undefined behaviour even if each task only touches disjoint indices.
    /// Prefer raw element access through `ptr` for concurrent writers.
    unsafe fn as_slice(&self) -> &'a mut [Complex64] {
        std::slice::from_raw_parts_mut(self.ptr as *mut Complex64, self.len)
    }
}
37
38pub struct Fft3dWorkspace {
40 nx: usize,
41 ny: usize,
42 nz: usize,
43 n_total: usize,
44 fft_x: Arc<dyn Fft<f64>>,
46 fft_y: Arc<dyn Fft<f64>>,
47 fft_z: Arc<dyn Fft<f64>>,
48 ifft_x: Arc<dyn Fft<f64>>,
50 ifft_y: Arc<dyn Fft<f64>>,
51 ifft_z: Arc<dyn Fft<f64>>,
52 scratch_x: Vec<Complex64>,
54 scratch_y: Vec<Complex64>,
55 scratch_z: Vec<Complex64>,
56 buffer_y: Vec<Complex64>,
57 buffer_z: Vec<Complex64>,
58}
59
60impl Fft3dWorkspace {
61 pub fn new(nx: usize, ny: usize, nz: usize) -> Self {
63 let mut planner = FftPlanner::new();
64
65 let fft_x = planner.plan_fft(nx, FftDirection::Forward);
66 let fft_y = planner.plan_fft(ny, FftDirection::Forward);
67 let fft_z = planner.plan_fft(nz, FftDirection::Forward);
68
69 let ifft_x = planner.plan_fft(nx, FftDirection::Inverse);
70 let ifft_y = planner.plan_fft(ny, FftDirection::Inverse);
71 let ifft_z = planner.plan_fft(nz, FftDirection::Inverse);
72
73 let scratch_x = vec![Complex64::new(0.0, 0.0); fft_x.get_inplace_scratch_len().max(ifft_x.get_inplace_scratch_len())];
74 let scratch_y = vec![Complex64::new(0.0, 0.0); fft_y.get_inplace_scratch_len().max(ifft_y.get_inplace_scratch_len())];
75 let scratch_z = vec![Complex64::new(0.0, 0.0); fft_z.get_inplace_scratch_len().max(ifft_z.get_inplace_scratch_len())];
76
77 Self {
78 nx, ny, nz,
79 n_total: nx * ny * nz,
80 fft_x, fft_y, fft_z,
81 ifft_x, ifft_y, ifft_z,
82 scratch_x, scratch_y, scratch_z,
83 buffer_y: vec![Complex64::new(0.0, 0.0); ny],
84 buffer_z: vec![Complex64::new(0.0, 0.0); nz],
85 }
86 }
87
88 pub fn fft3d(&mut self, data: &mut [Complex64]) {
90 let (nx, ny, nz) = (self.nx, self.ny, self.nz);
91
92 for k in 0..nz {
94 for j in 0..ny {
95 let start = idx3d(0, j, k, nx, ny);
96 self.fft_x.process_with_scratch(&mut data[start..start + nx], &mut self.scratch_x);
97 }
98 }
99
100 for k in 0..nz {
102 for i in 0..nx {
103 for j in 0..ny {
104 self.buffer_y[j] = data[idx3d(i, j, k, nx, ny)];
105 }
106 self.fft_y.process_with_scratch(&mut self.buffer_y, &mut self.scratch_y);
107 for j in 0..ny {
108 data[idx3d(i, j, k, nx, ny)] = self.buffer_y[j];
109 }
110 }
111 }
112
113 for j in 0..ny {
115 for i in 0..nx {
116 for k in 0..nz {
117 self.buffer_z[k] = data[idx3d(i, j, k, nx, ny)];
118 }
119 self.fft_z.process_with_scratch(&mut self.buffer_z, &mut self.scratch_z);
120 for k in 0..nz {
121 data[idx3d(i, j, k, nx, ny)] = self.buffer_z[k];
122 }
123 }
124 }
125 }
126
127 pub fn ifft3d(&mut self, data: &mut [Complex64]) {
129 let (nx, ny, nz) = (self.nx, self.ny, self.nz);
130 let n_total = self.n_total as f64;
131
132 for k in 0..nz {
134 for j in 0..ny {
135 let start = idx3d(0, j, k, nx, ny);
136 self.ifft_x.process_with_scratch(&mut data[start..start + nx], &mut self.scratch_x);
137 }
138 }
139
140 for k in 0..nz {
142 for i in 0..nx {
143 for j in 0..ny { self.buffer_y[j] = data[idx3d(i, j, k, nx, ny)]; }
144 self.ifft_y.process_with_scratch(&mut self.buffer_y, &mut self.scratch_y);
145 for j in 0..ny { data[idx3d(i, j, k, nx, ny)] = self.buffer_y[j]; }
146 }
147 }
148
149 for j in 0..ny {
151 for i in 0..nx {
152 for k in 0..nz { self.buffer_z[k] = data[idx3d(i, j, k, nx, ny)]; }
153 self.ifft_z.process_with_scratch(&mut self.buffer_z, &mut self.scratch_z);
154 for k in 0..nz { data[idx3d(i, j, k, nx, ny)] = self.buffer_z[k]; }
155 }
156 }
157
158 for val in data.iter_mut() { *val /= n_total; }
160 }
161
162 #[inline]
165 pub fn apply_dipole_inplace(&mut self, x: &[f64], d_kernel: &[f64], out: &mut [f64], complex_buf: &mut [Complex64]) {
166 for (c, &r) in complex_buf.iter_mut().zip(x.iter()) {
168 *c = Complex64::new(r, 0.0);
169 }
170
171 self.fft3d(complex_buf);
172
173 for (c, &d) in complex_buf.iter_mut().zip(d_kernel.iter()) {
175 *c *= d;
176 }
177
178 self.ifft3d(complex_buf);
179
180 for (o, c) in out.iter_mut().zip(complex_buf.iter()) {
182 *o = c.re;
183 }
184 }
185}
186
/// Flattens the 3D index `(i, j, k)` of an `nx * ny * nz` volume into a linear
/// offset, with `i` varying fastest and `k` slowest.
///
/// `nz` is not needed: the slowest axis only scales by the slab size `nx * ny`.
#[inline(always)]
pub fn idx3d(i: usize, j: usize, k: usize, nx: usize, ny: usize) -> usize {
    let row_offset = j * nx;
    let slab_offset = k * nx * ny;
    i + row_offset + slab_offset
}
193
194pub struct Fft3dWorkspaceF32 {
201 nx: usize,
202 ny: usize,
203 nz: usize,
204 n_total: usize,
205 fft_x: Arc<dyn Fft<f32>>,
207 fft_y: Arc<dyn Fft<f32>>,
208 fft_z: Arc<dyn Fft<f32>>,
209 ifft_x: Arc<dyn Fft<f32>>,
211 ifft_y: Arc<dyn Fft<f32>>,
212 ifft_z: Arc<dyn Fft<f32>>,
213 scratch_x: Vec<Complex32>,
215 scratch_y: Vec<Complex32>,
216 scratch_z: Vec<Complex32>,
217 buffer_y: Vec<Complex32>,
218 buffer_z: Vec<Complex32>,
219}
220
221impl Fft3dWorkspaceF32 {
222 pub fn new(nx: usize, ny: usize, nz: usize) -> Self {
224 let mut planner = FftPlanner::<f32>::new();
225
226 let fft_x = planner.plan_fft(nx, FftDirection::Forward);
227 let fft_y = planner.plan_fft(ny, FftDirection::Forward);
228 let fft_z = planner.plan_fft(nz, FftDirection::Forward);
229
230 let ifft_x = planner.plan_fft(nx, FftDirection::Inverse);
231 let ifft_y = planner.plan_fft(ny, FftDirection::Inverse);
232 let ifft_z = planner.plan_fft(nz, FftDirection::Inverse);
233
234 let scratch_x = vec![Complex32::new(0.0, 0.0); fft_x.get_inplace_scratch_len().max(ifft_x.get_inplace_scratch_len())];
235 let scratch_y = vec![Complex32::new(0.0, 0.0); fft_y.get_inplace_scratch_len().max(ifft_y.get_inplace_scratch_len())];
236 let scratch_z = vec![Complex32::new(0.0, 0.0); fft_z.get_inplace_scratch_len().max(ifft_z.get_inplace_scratch_len())];
237
238 Self {
239 nx, ny, nz,
240 n_total: nx * ny * nz,
241 fft_x, fft_y, fft_z,
242 ifft_x, ifft_y, ifft_z,
243 scratch_x, scratch_y, scratch_z,
244 buffer_y: vec![Complex32::new(0.0, 0.0); ny],
245 buffer_z: vec![Complex32::new(0.0, 0.0); nz],
246 }
247 }
248
249 #[inline]
251 pub fn fft3d(&mut self, data: &mut [Complex32]) {
252 let (nx, ny, nz) = (self.nx, self.ny, self.nz);
253
254 for k in 0..nz {
256 for j in 0..ny {
257 let start = idx3d(0, j, k, nx, ny);
258 self.fft_x.process_with_scratch(&mut data[start..start + nx], &mut self.scratch_x);
259 }
260 }
261
262 for k in 0..nz {
264 for i in 0..nx {
265 for j in 0..ny {
266 self.buffer_y[j] = data[idx3d(i, j, k, nx, ny)];
267 }
268 self.fft_y.process_with_scratch(&mut self.buffer_y, &mut self.scratch_y);
269 for j in 0..ny {
270 data[idx3d(i, j, k, nx, ny)] = self.buffer_y[j];
271 }
272 }
273 }
274
275 for j in 0..ny {
277 for i in 0..nx {
278 for k in 0..nz {
279 self.buffer_z[k] = data[idx3d(i, j, k, nx, ny)];
280 }
281 self.fft_z.process_with_scratch(&mut self.buffer_z, &mut self.scratch_z);
282 for k in 0..nz {
283 data[idx3d(i, j, k, nx, ny)] = self.buffer_z[k];
284 }
285 }
286 }
287 }
288
289 #[inline]
291 pub fn ifft3d(&mut self, data: &mut [Complex32]) {
292 let (nx, ny, nz) = (self.nx, self.ny, self.nz);
293 let n_total = self.n_total as f32;
294
295 for k in 0..nz {
297 for j in 0..ny {
298 let start = idx3d(0, j, k, nx, ny);
299 self.ifft_x.process_with_scratch(&mut data[start..start + nx], &mut self.scratch_x);
300 }
301 }
302
303 for k in 0..nz {
305 for i in 0..nx {
306 for j in 0..ny {
307 self.buffer_y[j] = data[idx3d(i, j, k, nx, ny)];
308 }
309 self.ifft_y.process_with_scratch(&mut self.buffer_y, &mut self.scratch_y);
310 for j in 0..ny {
311 data[idx3d(i, j, k, nx, ny)] = self.buffer_y[j];
312 }
313 }
314 }
315
316 for j in 0..ny {
318 for i in 0..nx {
319 for k in 0..nz {
320 self.buffer_z[k] = data[idx3d(i, j, k, nx, ny)];
321 }
322 self.ifft_z.process_with_scratch(&mut self.buffer_z, &mut self.scratch_z);
323 for k in 0..nz {
324 data[idx3d(i, j, k, nx, ny)] = self.buffer_z[k];
325 }
326 }
327 }
328
329 for val in data.iter_mut() {
331 *val /= n_total;
332 }
333 }
334
335 #[inline]
337 pub fn apply_dipole_inplace(&mut self, x: &[f32], d_kernel: &[f32], out: &mut [f32], complex_buf: &mut [Complex32]) {
338 for (c, &r) in complex_buf.iter_mut().zip(x.iter()) {
340 *c = Complex32::new(r, 0.0);
341 }
342
343 self.fft3d(complex_buf);
344
345 for (c, &d) in complex_buf.iter_mut().zip(d_kernel.iter()) {
347 *c *= d;
348 }
349
350 self.ifft3d(complex_buf);
351
352 for (o, c) in out.iter_mut().zip(complex_buf.iter()) {
354 *o = c.re;
355 }
356 }
357}
358
359pub fn fft3d(data: &mut [Complex64], nx: usize, ny: usize, nz: usize) {
364 let mut planner = FftPlanner::new();
365
366 let fft_x = planner.plan_fft(nx, FftDirection::Forward);
368 #[cfg(feature = "parallel")]
369 {
370 let scratch_len = fft_x.get_inplace_scratch_len();
371 data.par_chunks_mut(nx).for_each(|row| {
372 let mut scratch = vec![Complex64::new(0.0, 0.0); scratch_len];
373 fft_x.process_with_scratch(row, &mut scratch);
374 });
375 }
376 #[cfg(not(feature = "parallel"))]
377 {
378 let mut scratch_x = vec![Complex64::new(0.0, 0.0); fft_x.get_inplace_scratch_len()];
379 for chunk in data.chunks_mut(nx) {
380 fft_x.process_with_scratch(chunk, &mut scratch_x);
381 }
382 }
383
384 let fft_y = planner.plan_fft(ny, FftDirection::Forward);
386 #[cfg(feature = "parallel")]
387 {
388 let scratch_len = fft_y.get_inplace_scratch_len();
389 let nxy = nx * ny;
390 let pairs: Vec<(usize, usize)> = (0..nz).flat_map(|k| (0..nx).map(move |i| (k, i))).collect();
391 let data_send = SendPtr::new(data);
392 pairs.par_iter().for_each(|&(k, i)| {
393 let mut buffer = vec![Complex64::new(0.0, 0.0); ny];
394 let mut scratch = vec![Complex64::new(0.0, 0.0); scratch_len];
395 unsafe {
396 let slice = data_send.as_slice();
397 for j in 0..ny { buffer[j] = slice[i + j * nx + k * nxy]; }
398 fft_y.process_with_scratch(&mut buffer, &mut scratch);
399 for j in 0..ny { slice[i + j * nx + k * nxy] = buffer[j]; }
400 }
401 });
402 }
403 #[cfg(not(feature = "parallel"))]
404 {
405 let mut scratch_y = vec![Complex64::new(0.0, 0.0); fft_y.get_inplace_scratch_len()];
406 let mut buffer_y = vec![Complex64::new(0.0, 0.0); ny];
407 for k in 0..nz {
408 for i in 0..nx {
409 for j in 0..ny { buffer_y[j] = data[idx3d(i, j, k, nx, ny)]; }
410 fft_y.process_with_scratch(&mut buffer_y, &mut scratch_y);
411 for j in 0..ny { data[idx3d(i, j, k, nx, ny)] = buffer_y[j]; }
412 }
413 }
414 }
415
416 let fft_z = planner.plan_fft(nz, FftDirection::Forward);
418 #[cfg(feature = "parallel")]
419 {
420 let scratch_len = fft_z.get_inplace_scratch_len();
421 let nxy = nx * ny;
422 let pairs: Vec<(usize, usize)> = (0..ny).flat_map(|j| (0..nx).map(move |i| (j, i))).collect();
423 let data_send = SendPtr::new(data);
424 pairs.par_iter().for_each(|&(j, i)| {
425 let mut buffer = vec![Complex64::new(0.0, 0.0); nz];
426 let mut scratch = vec![Complex64::new(0.0, 0.0); scratch_len];
427 unsafe {
428 let slice = data_send.as_slice();
429 for k in 0..nz { buffer[k] = slice[i + j * nx + k * nxy]; }
430 fft_z.process_with_scratch(&mut buffer, &mut scratch);
431 for k in 0..nz { slice[i + j * nx + k * nxy] = buffer[k]; }
432 }
433 });
434 }
435 #[cfg(not(feature = "parallel"))]
436 {
437 let mut scratch_z = vec![Complex64::new(0.0, 0.0); fft_z.get_inplace_scratch_len()];
438 let mut buffer_z = vec![Complex64::new(0.0, 0.0); nz];
439 for j in 0..ny {
440 for i in 0..nx {
441 for k in 0..nz { buffer_z[k] = data[idx3d(i, j, k, nx, ny)]; }
442 fft_z.process_with_scratch(&mut buffer_z, &mut scratch_z);
443 for k in 0..nz { data[idx3d(i, j, k, nx, ny)] = buffer_z[k]; }
444 }
445 }
446 }
447}
448
449pub fn ifft3d(data: &mut [Complex64], nx: usize, ny: usize, nz: usize) {
454 let mut planner = FftPlanner::new();
455 let n_total = (nx * ny * nz) as f64;
456
457 let ifft_x = planner.plan_fft(nx, FftDirection::Inverse);
459 #[cfg(feature = "parallel")]
460 {
461 let scratch_len = ifft_x.get_inplace_scratch_len();
462 data.par_chunks_mut(nx).for_each(|row| {
463 let mut scratch = vec![Complex64::new(0.0, 0.0); scratch_len];
464 ifft_x.process_with_scratch(row, &mut scratch);
465 });
466 }
467 #[cfg(not(feature = "parallel"))]
468 {
469 let mut scratch_x = vec![Complex64::new(0.0, 0.0); ifft_x.get_inplace_scratch_len()];
470 for chunk in data.chunks_mut(nx) {
471 ifft_x.process_with_scratch(chunk, &mut scratch_x);
472 }
473 }
474
475 let ifft_y = planner.plan_fft(ny, FftDirection::Inverse);
477 #[cfg(feature = "parallel")]
478 {
479 let scratch_len = ifft_y.get_inplace_scratch_len();
480 let nxy = nx * ny;
481 let pairs: Vec<(usize, usize)> = (0..nz).flat_map(|k| (0..nx).map(move |i| (k, i))).collect();
482 let data_send = SendPtr::new(data);
483 pairs.par_iter().for_each(|&(k, i)| {
484 let mut buffer = vec![Complex64::new(0.0, 0.0); ny];
485 let mut scratch = vec![Complex64::new(0.0, 0.0); scratch_len];
486 unsafe {
487 let slice = data_send.as_slice();
488 for j in 0..ny { buffer[j] = slice[i + j * nx + k * nxy]; }
489 ifft_y.process_with_scratch(&mut buffer, &mut scratch);
490 for j in 0..ny { slice[i + j * nx + k * nxy] = buffer[j]; }
491 }
492 });
493 }
494 #[cfg(not(feature = "parallel"))]
495 {
496 let mut scratch_y = vec![Complex64::new(0.0, 0.0); ifft_y.get_inplace_scratch_len()];
497 let mut buffer_y = vec![Complex64::new(0.0, 0.0); ny];
498 for k in 0..nz {
499 for i in 0..nx {
500 for j in 0..ny { buffer_y[j] = data[idx3d(i, j, k, nx, ny)]; }
501 ifft_y.process_with_scratch(&mut buffer_y, &mut scratch_y);
502 for j in 0..ny { data[idx3d(i, j, k, nx, ny)] = buffer_y[j]; }
503 }
504 }
505 }
506
507 let ifft_z = planner.plan_fft(nz, FftDirection::Inverse);
509 #[cfg(feature = "parallel")]
510 {
511 let scratch_len = ifft_z.get_inplace_scratch_len();
512 let nxy = nx * ny;
513 let pairs: Vec<(usize, usize)> = (0..ny).flat_map(|j| (0..nx).map(move |i| (j, i))).collect();
514 let data_send = SendPtr::new(data);
515 pairs.par_iter().for_each(|&(j, i)| {
516 let mut buffer = vec![Complex64::new(0.0, 0.0); nz];
517 let mut scratch = vec![Complex64::new(0.0, 0.0); scratch_len];
518 unsafe {
519 let slice = data_send.as_slice();
520 for k in 0..nz { buffer[k] = slice[i + j * nx + k * nxy]; }
521 ifft_z.process_with_scratch(&mut buffer, &mut scratch);
522 for k in 0..nz { slice[i + j * nx + k * nxy] = buffer[k]; }
523 }
524 });
525 }
526 #[cfg(not(feature = "parallel"))]
527 {
528 let mut scratch_z = vec![Complex64::new(0.0, 0.0); ifft_z.get_inplace_scratch_len()];
529 let mut buffer_z = vec![Complex64::new(0.0, 0.0); nz];
530 for j in 0..ny {
531 for i in 0..nx {
532 for k in 0..nz { buffer_z[k] = data[idx3d(i, j, k, nx, ny)]; }
533 ifft_z.process_with_scratch(&mut buffer_z, &mut scratch_z);
534 for k in 0..nz { data[idx3d(i, j, k, nx, ny)] = buffer_z[k]; }
535 }
536 }
537 }
538
539 let n_total_f = n_total;
541 #[cfg(feature = "parallel")]
542 data.par_iter_mut().for_each(|val| { *val /= n_total_f; });
543 #[cfg(not(feature = "parallel"))]
544 for val in data.iter_mut() { *val /= n_total_f; }
545}
546
547pub fn fft3d_real(data: &[f64], nx: usize, ny: usize, nz: usize) -> Vec<Complex64> {
552 let mut complex_data: Vec<Complex64> = data.iter()
553 .map(|&x| Complex64::new(x, 0.0))
554 .collect();
555 fft3d(&mut complex_data, nx, ny, nz);
556 complex_data
557}
558
559pub fn ifft3d_real(data: &[Complex64], nx: usize, ny: usize, nz: usize) -> Vec<f64> {
563 let mut complex_data = data.to_vec();
564 ifft3d(&mut complex_data, nx, ny, nz);
565 complex_data.iter().map(|c| c.re).collect()
566}
567
/// Returns the sample frequencies of an `n`-point DFT with sample spacing `d`,
/// in the same order as the FFT output (NumPy `fftfreq` convention).
///
/// Bins `0..ceil(n/2)` are non-negative (`i / (n*d)`). The remaining bins are
/// negative (`(i - n) / (n*d)`), so for even `n` the Nyquist bin is negative.
pub fn fftfreq(n: usize, d: f64) -> Vec<f64> {
    let step = 1.0 / (n as f64 * d);
    // First index of the negative-frequency half.
    let first_negative = (n + 1) / 2;
    (0..n)
        .map(|i| {
            if i < first_negative {
                (i as f64) * step
            } else {
                ((i as i64) - (n as i64)) as f64 * step
            }
        })
        .collect()
}
593
/// Single-precision sample frequencies of an `n`-point DFT with spacing `d`,
/// in FFT output order (NumPy `fftfreq` convention).
///
/// Bins `0..ceil(n/2)` are non-negative; the rest are `(i - n) / (n*d)`.
pub fn fftfreq_f32(n: usize, d: f32) -> Vec<f32> {
    let step = 1.0f32 / (n as f32 * d);
    let first_negative = (n + 1) / 2;
    let mut freq = Vec::with_capacity(n);
    for i in 0..n {
        let bin = if i < first_negative {
            i as f32
        } else {
            ((i as i64) - (n as i64)) as f32
        };
        freq.push(bin * step);
    }
    freq
}
619
/// Returns a copy of an `nx * ny * nz` volume with the zero-frequency sample
/// moved to the centre: each axis is cyclically shifted forward by `n / 2`.
///
/// Elements beyond `nx * ny * nz` are ignored. Panics if `data` is shorter
/// than the volume.
pub fn fftshift(data: &[f64], nx: usize, ny: usize, nz: usize) -> Vec<f64> {
    let nxy = nx * ny;
    let n_total = nxy * nz;
    let volume = &data[..n_total];
    let mut shifted = vec![0.0; n_total];

    for (src, &value) in volume.iter().enumerate() {
        // Recover (i, j, k) from the x-fastest linear offset.
        let i = src % nx;
        let j = (src / nx) % ny;
        let k = src / nxy;
        let dst = (i + nx / 2) % nx + ((j + ny / 2) % ny) * nx + ((k + nz / 2) % nz) * nxy;
        shifted[dst] = value;
    }

    shifted
}
645
/// Inverse of `fftshift`: returns a copy of an `nx * ny * nz` volume with each
/// axis cyclically shifted forward by `ceil(n / 2)`, undoing the centring.
///
/// Elements beyond `nx * ny * nz` are ignored. Panics if `data` is shorter
/// than the volume.
pub fn ifftshift(data: &[f64], nx: usize, ny: usize, nz: usize) -> Vec<f64> {
    let n_total = nx * ny * nz;
    let (sx, sy, sz) = ((nx + 1) / 2, (ny + 1) / 2, (nz + 1) / 2);
    let mut out = vec![0.0; n_total];

    // `src` walks the input linearly, matching the k/j/i loop nesting.
    let mut src = 0;
    for k in 0..nz {
        let slab_offset = (k + sz) % nz * nx * ny;
        for j in 0..ny {
            let row_offset = (j + sy) % ny * nx;
            for i in 0..nx {
                out[(i + sx) % nx + row_offset + slab_offset] = data[src];
                src += 1;
            }
        }
    }

    out
}
671
/// In-place equivalent of [`fftshift`]: moves the zero-frequency sample of an
/// `nx * ny * nz` volume to the centre by shifting each axis forward by `n / 2`.
///
/// The 3D shift is separable, so it is applied as three cyclic rotations:
/// - every x-row right by `nx / 2`;
/// - every `nx * ny` slab right by `ny / 2` whole rows;
/// - the whole volume right by `nz / 2` whole slabs.
///
/// This is correct for any extents, including odd ones. The previous
/// pairwise-swap scheme only worked when every extent was even (or 1), because
/// only then is the shift its own inverse.
///
/// Only the first `nx * ny * nz` elements are touched.
///
/// # Panics
///
/// Panics if `data.len() < nx * ny * nz`.
pub fn fftshift_inplace(data: &mut [f64], nx: usize, ny: usize, nz: usize) {
    let nxy = nx * ny;
    let n_total = nxy * nz;
    // An empty volume has nothing to shift (and `chunks_exact_mut(0)` panics).
    if n_total == 0 {
        return;
    }
    let volume = &mut data[..n_total];

    for row in volume.chunks_exact_mut(nx) {
        row.rotate_right(nx / 2);
    }
    // Rotating a slab by whole rows keeps the x-shift already applied to each row.
    for slab in volume.chunks_exact_mut(nxy) {
        slab.rotate_right((ny / 2) * nx);
    }
    volume.rotate_right((nz / 2) * nxy);
}
700
/// Wraps an angle in radians into the closed interval `[-π, π]`.
///
/// `%` keeps the sign of `angle`, so the remainder lies in `(-2π, 2π)`; at most
/// one full turn is then added or removed. NaN propagates unchanged.
#[inline]
pub fn wrap_angle(angle: f64) -> f64 {
    let full_turn = 2.0 * PI;
    let r = angle % full_turn;
    if r > PI {
        r - full_turn
    } else if r < -PI {
        r + full_turn
    } else {
        r
    }
}
712
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic complex volume with non-trivial real and imaginary parts.
    fn sample_volume(len: usize) -> Vec<Complex64> {
        (0..len)
            .map(|i| Complex64::new((i as f64 * 0.3).sin(), (i as f64 * 0.7).cos()))
            .collect()
    }

    /// Asserts that two complex buffers agree element-wise within `tol`.
    fn assert_all_close(expected: &[Complex64], actual: &[Complex64], tol: f64, what: &str) {
        assert_eq!(expected.len(), actual.len(), "{}: length mismatch", what);
        for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
            assert!(
                (e - a).norm() < tol,
                "{} mismatch at {}: expected={} actual={}",
                what,
                i,
                e,
                a
            );
        }
    }

    /// Runs `f` inside a dedicated rayon pool with `threads` workers.
    #[cfg(feature = "parallel")]
    fn in_pool<T: Send>(threads: usize, f: impl FnOnce() -> T + Send) -> T {
        rayon::ThreadPoolBuilder::new()
            .num_threads(threads)
            .build()
            .expect("failed to build rayon thread pool")
            .install(f)
    }

    #[test]
    fn test_fft_ifft_roundtrip() {
        let nx = 4;
        let ny = 4;
        let nz = 4;

        let original: Vec<f64> = (0..nx * ny * nz).map(|i| i as f64).collect();

        let mut data: Vec<Complex64> = original.iter().map(|&x| Complex64::new(x, 0.0)).collect();

        fft3d(&mut data, nx, ny, nz);
        ifft3d(&mut data, nx, ny, nz);

        for (i, (&orig, result)) in original.iter().zip(data.iter()).enumerate() {
            assert!(
                (result.re - orig).abs() < 1e-10,
                "Mismatch at index {}: expected {}, got {}",
                i,
                orig,
                result.re
            );
            assert!(
                result.im.abs() < 1e-10,
                "Imaginary part not zero at index {}: {}",
                i,
                result.im
            );
        }
    }

    #[test]
    fn test_fft_ifft_roundtrip_non_cubic() {
        // Distinct, partly odd extents catch stride mix-ups that cubic grids hide.
        let (nx, ny, nz) = (3, 4, 5);
        let original = sample_volume(nx * ny * nz);
        let mut data = original.clone();
        fft3d(&mut data, nx, ny, nz);
        ifft3d(&mut data, nx, ny, nz);
        assert_all_close(&original, &data, 1e-10, "non-cubic roundtrip");
    }

    #[test]
    fn test_free_functions_match_workspace_non_cubic() {
        let (nx, ny, nz) = (5, 3, 4);
        let input = sample_volume(nx * ny * nz);
        let mut ws = Fft3dWorkspace::new(nx, ny, nz);

        let mut expected = input.clone();
        ws.fft3d(&mut expected);
        let mut actual = input.clone();
        fft3d(&mut actual, nx, ny, nz);
        assert_all_close(&expected, &actual, 1e-10, "fft3d vs workspace");

        let mut expected = input.clone();
        ws.ifft3d(&mut expected);
        let mut actual = input;
        ifft3d(&mut actual, nx, ny, nz);
        assert_all_close(&expected, &actual, 1e-10, "ifft3d vs workspace");
    }

    #[test]
    fn test_fftfreq() {
        let freq = fftfreq(4, 1.0);
        assert!((freq[0] - 0.0).abs() < 1e-10);
        assert!((freq[1] - 0.25).abs() < 1e-10);
        assert!((freq[2] - (-0.5)).abs() < 1e-10);
        assert!((freq[3] - (-0.25)).abs() < 1e-10);

        let freq = fftfreq(5, 1.0);
        assert!((freq[0] - 0.0).abs() < 1e-10);
        assert!((freq[1] - 0.2).abs() < 1e-10);
        assert!((freq[2] - 0.4).abs() < 1e-10);
        assert!((freq[3] - (-0.4)).abs() < 1e-10);
        assert!((freq[4] - (-0.2)).abs() < 1e-10);
    }

    #[test]
    fn test_fft_f32_roundtrip() {
        let nx = 4;
        let ny = 4;
        let nz = 4;

        let original: Vec<f32> = (0..nx * ny * nz).map(|i| i as f32).collect();

        let mut data: Vec<Complex32> = original.iter().map(|&x| Complex32::new(x, 0.0)).collect();

        let mut ws = Fft3dWorkspaceF32::new(nx, ny, nz);
        ws.fft3d(&mut data);
        ws.ifft3d(&mut data);

        for (i, (&orig, result)) in original.iter().zip(data.iter()).enumerate() {
            assert!(
                (result.re - orig).abs() < 1e-4,
                "f32 roundtrip mismatch at index {}: expected {}, got {}",
                i,
                orig,
                result.re
            );
            assert!(
                result.im.abs() < 1e-4,
                "f32 imaginary part not zero at index {}: {}",
                i,
                result.im
            );
        }
    }

    #[test]
    fn test_fftshift_even() {
        let nx = 4;
        let ny = 4;
        let nz = 4;
        let n = nx * ny * nz;

        let data: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let shifted = fftshift(&data, nx, ny, nz);

        assert!(
            (shifted[idx3d(2, 2, 2, nx, ny)] - data[idx3d(0, 0, 0, nx, ny)]).abs() < 1e-12,
            "fftshift: element at (0,0,0) should move to (2,2,2)"
        );

        assert!(
            (shifted[idx3d(3, 3, 3, nx, ny)] - data[idx3d(1, 1, 1, nx, ny)]).abs() < 1e-12,
            "fftshift: element at (1,1,1) should move to (3,3,3)"
        );

        assert_eq!(shifted.len(), n);
    }

    #[test]
    fn test_ifftshift_roundtrip() {
        let nx = 4;
        let ny = 4;
        let nz = 4;
        let n = nx * ny * nz;

        let data: Vec<f64> = (0..n).map(|i| (i as f64) * 0.1).collect();

        let shifted = fftshift(&data, nx, ny, nz);
        let unshifted = ifftshift(&shifted, nx, ny, nz);

        for i in 0..n {
            assert!(
                (unshifted[i] - data[i]).abs() < 1e-12,
                "ifftshift(fftshift(x)) != x at index {}: expected {}, got {}",
                i,
                data[i],
                unshifted[i]
            );
        }
    }

    #[test]
    fn test_fftshift_inplace() {
        let nx = 4;
        let ny = 4;
        let nz = 4;
        let n = nx * ny * nz;

        let original: Vec<f64> = (0..n).map(|i| i as f64).collect();

        let shifted_copy = fftshift(&original, nx, ny, nz);

        let mut data = original.clone();
        fftshift_inplace(&mut data, nx, ny, nz);

        for i in 0..n {
            assert!(
                (data[i] - shifted_copy[i]).abs() < 1e-12,
                "fftshift_inplace mismatch at index {}: expected {}, got {}",
                i,
                shifted_copy[i],
                data[i]
            );
        }
    }

    // The workspace methods are strictly sequential, so they are the reference
    // for the rayon-backed free functions. The free functions are run on both a
    // single-worker and a multi-worker pool, and non-cubic extents are used so
    // that stride errors in any pass show up.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_parallel_fft3d_matches_sequential_workspace() {
        let (nx, ny, nz) = (12, 10, 9);
        let input = sample_volume(nx * ny * nz);

        let mut expected = input.clone();
        Fft3dWorkspace::new(nx, ny, nz).fft3d(&mut expected);

        for &threads in &[1usize, 4] {
            let actual = in_pool(threads, || {
                let mut data = input.clone();
                fft3d(&mut data, nx, ny, nz);
                data
            });
            assert_all_close(&expected, &actual, 1e-10, "parallel fft3d");
        }
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_parallel_ifft3d_matches_sequential_workspace() {
        let (nx, ny, nz) = (12, 10, 9);
        let input = sample_volume(nx * ny * nz);

        let mut expected = input.clone();
        Fft3dWorkspace::new(nx, ny, nz).ifft3d(&mut expected);

        for &threads in &[1usize, 4] {
            let actual = in_pool(threads, || {
                let mut data = input.clone();
                ifft3d(&mut data, nx, ny, nz);
                data
            });
            assert_all_close(&expected, &actual, 1e-10, "parallel ifft3d");
        }
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_fft_ifft_roundtrip_parallel() {
        let (nx, ny, nz) = (16, 15, 14);
        let original = sample_volume(nx * ny * nz);

        let roundtrip = in_pool(4, || {
            let mut data = original.clone();
            fft3d(&mut data, nx, ny, nz);
            ifft3d(&mut data, nx, ny, nz);
            data
        });
        assert_all_close(&original, &roundtrip, 1e-10, "parallel roundtrip");
    }
}