1use num_complex::Complex64;
15use crate::fft::{fft3d, ifft3d};
16use crate::kernels::dipole::dipole_kernel;
17
/// Tunable parameters for thresholded k-space division (TKD).
///
/// The default is `threshold = 0.15`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TkdParams {
    /// Kernel-magnitude cut-off: wherever `|D| <= threshold`, the inverse
    /// `1 / D` is replaced by `±1 / threshold`.
    pub threshold: f64,
}

impl Default for TkdParams {
    /// Returns the parameters with `threshold = 0.15`.
    fn default() -> Self {
        Self { threshold: 0.15 }
    }
}
30
31pub fn tkd(
47 local_field: &[f64],
48 mask: &[u8],
49 nx: usize, ny: usize, nz: usize,
50 vsx: f64, vsy: f64, vsz: f64,
51 bdir: (f64, f64, f64),
52 threshold: f64,
53) -> Vec<f64> {
54 let n_total = nx * ny * nz;
55
56 let d = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
58
59 let inv_threshold = 1.0 / threshold;
62 let inv_d: Vec<f64> = d.iter().map(|&dval| {
63 if dval.abs() <= threshold {
64 if dval >= 0.0 { inv_threshold } else { -inv_threshold }
66 } else {
67 1.0 / dval
69 }
70 }).collect();
71
72 let mut field_complex: Vec<Complex64> = local_field.iter()
74 .map(|&x| Complex64::new(x, 0.0))
75 .collect();
76
77 fft3d(&mut field_complex, nx, ny, nz);
79
80 for i in 0..n_total {
82 field_complex[i] *= inv_d[i];
83 }
84
85 ifft3d(&mut field_complex, nx, ny, nz);
87
88 let mut chi: Vec<f64> = field_complex.iter()
90 .map(|c| c.re)
91 .collect();
92
93 for i in 0..n_total {
95 if mask[i] == 0 {
96 chi[i] = 0.0;
97 }
98 }
99
100 chi
101}
102
103pub fn tkd_default(
105 local_field: &[f64],
106 mask: &[u8],
107 nx: usize, ny: usize, nz: usize,
108 vsx: f64, vsy: f64, vsz: f64,
109) -> Vec<f64> {
110 let p = TkdParams::default();
111 tkd(local_field, mask, nx, ny, nz, vsx, vsy, vsz, (0.0, 0.0, 1.0), p.threshold)
112}
113
114pub fn tsvd(
119 local_field: &[f64],
120 mask: &[u8],
121 nx: usize, ny: usize, nz: usize,
122 vsx: f64, vsy: f64, vsz: f64,
123 bdir: (f64, f64, f64),
124 threshold: f64,
125) -> Vec<f64> {
126 let n_total = nx * ny * nz;
127
128 let d = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
130
131 let inv_d: Vec<f64> = d.iter().map(|&dval| {
134 if dval.abs() <= threshold {
135 0.0 } else {
137 1.0 / dval
138 }
139 }).collect();
140
141 let mut field_complex: Vec<Complex64> = local_field.iter()
143 .map(|&x| Complex64::new(x, 0.0))
144 .collect();
145
146 fft3d(&mut field_complex, nx, ny, nz);
148
149 for i in 0..n_total {
151 field_complex[i] *= inv_d[i];
152 }
153
154 ifft3d(&mut field_complex, nx, ny, nz);
156
157 let mut chi: Vec<f64> = field_complex.iter()
159 .map(|c| c.re)
160 .collect();
161
162 for i in 0..n_total {
164 if mask[i] == 0 {
165 chi[i] = 0.0;
166 }
167 }
168
169 chi
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Voxel count of an `n × n × n` cube.
    fn voxels(n: usize) -> usize {
        n * n * n
    }

    #[test]
    fn test_tkd_zero_field() {
        let side = 8;
        let field = vec![0.0; voxels(side)];
        let mask = vec![1u8; voxels(side)];

        let chi = tkd_default(&field, &mask, side, side, side, 1.0, 1.0, 1.0);

        for v in &chi {
            assert!(v.abs() < 1e-10, "Zero field should give zero chi");
        }
    }

    #[test]
    fn test_tkd_mask() {
        let side = 8;
        let field = vec![1.0; voxels(side)];
        let mut mask = vec![1u8; voxels(side)];
        // Exclude the first two voxels from the region of interest.
        mask.iter_mut().take(2).for_each(|m| *m = 0);

        let chi = tkd_default(&field, &mask, side, side, side, 1.0, 1.0, 1.0);

        for v in &chi[..2] {
            assert_eq!(*v, 0.0, "Outside mask should be 0");
        }
    }

    #[test]
    fn test_tkd_finite() {
        let side = 8;
        let field: Vec<f64> = (0..voxels(side)).map(|k| (k as f64) * 0.01).collect();
        let mask = vec![1u8; voxels(side)];

        let chi = tkd_default(&field, &mask, side, side, side, 1.0, 1.0, 1.0);

        if let Some(i) = chi.iter().position(|v| !v.is_finite()) {
            panic!("Chi should be finite at index {}", i);
        }
    }

    #[test]
    fn test_tsvd_vs_tkd() {
        let side = 16;
        let field: Vec<f64> = (0..voxels(side)).map(|k| ((k as f64) * 0.1).sin()).collect();
        let mask = vec![1u8; voxels(side)];
        let bdir = (0.0, 0.0, 1.0);
        let thr = 0.15;

        let chi_tkd = tkd(&field, &mask, side, side, side, 1.0, 1.0, 1.0, bdir, thr);
        let chi_tsvd = tsvd(&field, &mask, side, side, side, 1.0, 1.0, 1.0, bdir, thr);

        let diff: f64 = chi_tkd
            .iter()
            .zip(&chi_tsvd)
            .map(|(a, b)| (a - b).abs())
            .sum();

        assert!(diff > 1e-10, "TKD and TSVD should give different results");
    }
}