1use num_complex::Complex64;
15use crate::fft::Fft3dWorkspace;
16use crate::kernels::dipole::dipole_kernel;
17use crate::kernels::laplacian::laplacian_kernel;
18use crate::utils::gradient::{fgrad_inplace, bdiv_inplace};
19
20#[cfg(feature = "parallel")]
21use crate::par::*;
22
/// Soft-thresholding operator: pulls `x` toward zero by `threshold`.
///
/// Values strictly above `threshold` are reduced by it, values strictly below
/// `-threshold` are increased by it, and everything else (including NaN)
/// collapses to `0.0`.
#[inline]
fn shrink(x: f64, threshold: f64) -> f64 {
    match x {
        above if above > threshold => above - threshold,
        below if below < -threshold => below + threshold,
        _ => 0.0,
    }
}
34
/// Tuning parameters for the RTS dipole inversion ([`rts`]).
///
/// [`RtsParams::default`] supplies the values used by [`rts_default`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RtsParams {
    /// Threshold on the dipole-kernel magnitude `|D|`: k-space points with
    /// `|D| > delta` are treated as well-conditioned and solved in step 1.
    pub delta: f64,
    /// Data-fidelity weight applied to the well-conditioned k-space region in
    /// the step-2 solve (the ill-conditioned region gets weight 0).
    pub mu: f64,
    /// ADMM penalty parameter of step 2; the TV shrinkage threshold is `1 / rho`.
    pub rho: f64,
    /// Stopping tolerance on the relative change `‖x − x_prev‖ / ‖x‖` between
    /// consecutive step-2 iterates.
    pub tol: f64,
    /// Maximum number of step-2 (ADMM) iterations.
    pub max_iter: usize,
    /// Number of refinement passes in the step-1 well-conditioned solve.
    pub lsmr_iter: usize,
}

impl Default for RtsParams {
    /// `delta = 0.15`, `mu = 1e5`, `rho = 10`, `tol = 1e-2`,
    /// `max_iter = 20`, `lsmr_iter = 4`.
    fn default() -> Self {
        Self {
            delta: 0.15,
            mu: 1e5,
            rho: 10.0,
            tol: 1e-2,
            max_iter: 20,
            lsmr_iter: 4,
        }
    }
}
64
65pub fn rts(
83 local_field: &[f64],
84 mask: &[u8],
85 nx: usize, ny: usize, nz: usize,
86 vsx: f64, vsy: f64, vsz: f64,
87 bdir: (f64, f64, f64),
88 delta: f64,
89 mu: f64,
90 rho: f64,
91 tol: f64,
92 max_iter: usize,
93 lsmr_iter: usize,
94) -> Vec<f64> {
95 rts_with_progress(
96 local_field, mask, nx, ny, nz, vsx, vsy, vsz,
97 bdir, delta, mu, rho, tol, max_iter, lsmr_iter,
98 |_, _| {} )
100}
101
102pub fn rts_with_progress<F>(
112 local_field: &[f64],
113 mask: &[u8],
114 nx: usize, ny: usize, nz: usize,
115 vsx: f64, vsy: f64, vsz: f64,
116 bdir: (f64, f64, f64),
117 delta: f64,
118 mu: f64,
119 rho: f64,
120 tol: f64,
121 max_iter: usize,
122 lsmr_iter: usize,
123 mut progress_callback: F,
124) -> Vec<f64>
125where
126 F: FnMut(usize, usize),
127{
128 let n_total = nx * ny * nz;
129
130 let mut fft_ws = Fft3dWorkspace::new(nx, ny, nz);
136
137 let d_kernel = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
139
140 let l_kernel = laplacian_kernel(nx, ny, nz, vsx, vsy, vsz, true);
142
143 let mut work_complex: Vec<Complex64> = l_kernel.iter()
145 .map(|&x| Complex64::new(x, 0.0))
146 .collect();
147 fft_ws.fft3d(&mut work_complex);
148
149 let mut m_mask: Vec<f64> = vec![0.0; n_total];
151 let mut inv_a: Vec<f64> = vec![0.0; n_total];
152
153 for i in 0..n_total {
154 let l_fft_i = work_complex[i].re;
155 if d_kernel[i].abs() > delta {
156 m_mask[i] = mu;
157 }
158 let a = m_mask[i] + rho * l_fft_i;
159 if a.abs() > 1e-20 {
160 inv_a[i] = rho / a;
161 }
162 }
163
164 for i in 0..n_total {
170 work_complex[i] = Complex64::new(local_field[i], 0.0);
171 }
172 fft_ws.fft3d(&mut work_complex);
173
174 let field_fft: Vec<Complex64> = work_complex.clone();
176
177 for i in 0..n_total {
180 let d = d_kernel[i];
181 if d.abs() > delta {
182 work_complex[i] = field_fft[i] * d / (d * d + 1e-6);
183 } else {
184 work_complex[i] = Complex64::new(0.0, 0.0);
185 }
186 }
187
188 let mut residual = vec![Complex64::new(0.0, 0.0); n_total];
191 for _ in 0..lsmr_iter {
192 for i in 0..n_total {
194 residual[i] = field_fft[i] - work_complex[i] * d_kernel[i];
195 }
196
197 for i in 0..n_total {
199 let d = d_kernel[i];
200 if d.abs() > delta {
201 work_complex[i] += residual[i] * d / (d * d + 1e-6);
202 }
203 }
204 }
205
206 fft_ws.ifft3d(&mut work_complex);
208
209 let mut x = vec![0.0; n_total];
211 for i in 0..n_total {
212 x[i] = if mask[i] != 0 { work_complex[i].re } else { 0.0 };
213 }
214
215 for i in 0..n_total {
221 work_complex[i] = Complex64::new(x[i], 0.0);
222 }
223 fft_ws.fft3d(&mut work_complex);
224
225 let mut f_hat: Vec<Complex64> = vec![Complex64::new(0.0, 0.0); n_total];
226 for i in 0..n_total {
227 if m_mask[i].abs() > 1e-20 && inv_a[i].abs() > 1e-20 {
228 f_hat[i] = work_complex[i] * (m_mask[i] / rho) * inv_a[i];
229 }
230 }
231
232 let mut x_prev = vec![0.0; n_total];
237
238 let mut ux = vec![0.0; n_total];
240 let mut uy = vec![0.0; n_total];
241 let mut uz = vec![0.0; n_total];
242
243 let mut gx = vec![0.0; n_total];
245 let mut gy = vec![0.0; n_total];
246 let mut gz = vec![0.0; n_total];
247
248 let mut div_v = vec![0.0; n_total];
250
251 let inv_rho = 1.0 / rho;
252
253 for iter in 0..max_iter {
258 progress_callback(iter + 1, max_iter);
259
260 std::mem::swap(&mut x, &mut x_prev);
262
263 bdiv_inplace(&mut div_v, &gx, &gy, &gz, nx, ny, nz, vsx, vsy, vsz);
269
270 for i in 0..n_total {
272 work_complex[i] = Complex64::new(div_v[i], 0.0);
273 }
274 fft_ws.fft3d(&mut work_complex);
275
276 for i in 0..n_total {
280 work_complex[i] = f_hat[i] - work_complex[i] * inv_a[i];
281 }
282
283 fft_ws.ifft3d(&mut work_complex);
285 for i in 0..n_total {
286 x[i] = work_complex[i].re;
287 }
288
289 let mut norm_diff_sq = 0.0;
293 let mut norm_x_sq = 0.0;
294 for i in 0..n_total {
295 let diff = x[i] - x_prev[i];
296 norm_diff_sq += diff * diff;
297 norm_x_sq += x[i] * x[i];
298 }
299
300 let rel_change = norm_diff_sq.sqrt() / (norm_x_sq.sqrt() + 1e-20);
301 if rel_change < tol {
302 progress_callback(iter + 1, iter + 1);
303 break;
304 }
305
306 fgrad_inplace(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, vsx, vsy, vsz);
311
312 for i in 0..n_total {
314 let vx = gx[i] + ux[i];
315 let vy = gy[i] + uy[i];
316 let vz = gz[i] + uz[i];
317
318 let zx_i = shrink(vx, inv_rho);
319 let zy_i = shrink(vy, inv_rho);
320 let zz_i = shrink(vz, inv_rho);
321
322 ux[i] = vx - zx_i;
323 uy[i] = vy - zy_i;
324 uz[i] = vz - zz_i;
325
326 gx[i] = 2.0 * zx_i - vx;
327 gy[i] = 2.0 * zy_i - vy;
328 gz[i] = 2.0 * zz_i - vz;
329 }
330 }
331
332 for i in 0..n_total {
334 if mask[i] == 0 { x[i] = 0.0; }
335 }
336
337 x
338}
339
340pub fn rts_default(
342 local_field: &[f64],
343 mask: &[u8],
344 nx: usize, ny: usize, nz: usize,
345 vsx: f64, vsy: f64, vsz: f64,
346) -> Vec<f64> {
347 let p = RtsParams::default();
348 rts(
349 local_field, mask, nx, ny, nz, vsx, vsy, vsz,
350 (0.0, 0.0, 1.0),
351 p.delta, p.mu, p.rho, p.tol, p.max_iter, p.lsmr_iter,
352 )
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shrink_soft_threshold() {
        assert_eq!(shrink(2.0, 0.5), 1.5);
        assert_eq!(shrink(-2.0, 0.5), -1.5);
        assert_eq!(shrink(0.25, 0.5), 0.0);
        // Values exactly at the threshold collapse to zero.
        assert_eq!(shrink(0.5, 0.5), 0.0);
        assert_eq!(shrink(-0.5, 0.5), 0.0);
    }

    #[test]
    fn test_rts_params_default_values() {
        let p = RtsParams::default();
        assert_eq!(p.delta, 0.15);
        assert_eq!(p.mu, 1e5);
        assert_eq!(p.rho, 10.0);
        assert_eq!(p.tol, 1e-2);
        assert_eq!(p.max_iter, 20);
        assert_eq!(p.lsmr_iter, 4);
    }

    #[test]
    fn test_rts_zero_field() {
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        let chi = rts(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 0.15, 1e5, 10.0, 1e-2, 5, 2,
        );

        for &val in chi.iter() {
            assert!(val.abs() < 1e-6, "Zero field should give near-zero chi");
        }
    }

    #[test]
    fn test_rts_default_zero_field() {
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        let chi = rts_default(&field, &mask, n, n, n, 1.0, 1.0, 1.0);

        assert_eq!(chi.len(), n * n * n);
        assert!(
            chi.iter().all(|v| v.abs() < 1e-6),
            "Zero field should give near-zero chi with default parameters"
        );
    }

    #[test]
    fn test_rts_finite() {
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];

        let chi = rts(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 0.15, 1e5, 10.0, 1e-2, 5, 2,
        );

        for (i, &val) in chi.iter().enumerate() {
            assert!(val.is_finite(), "Chi should be finite at index {}", i);
        }
    }

    #[test]
    fn test_rts_mask() {
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();
        let mut mask = vec![1u8; n * n * n];
        mask[0] = 0;
        mask[10] = 0;

        let chi = rts(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 0.15, 1e5, 10.0, 1e-2, 5, 2,
        );

        assert_eq!(chi[0], 0.0, "Masked voxel should be zero");
        assert_eq!(chi[10], 0.0, "Masked voxel should be zero");
    }

    #[test]
    fn test_rts_progress_reports_completion() {
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];
        let max_iter = 5;
        let mut calls: Vec<(usize, usize)> = Vec::new();

        let _ = rts_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 0.15, 1e5, 10.0, 1e-2, max_iter, 2,
            |current, total| calls.push((current, total)),
        );

        assert!(!calls.is_empty(), "callback should be invoked at least once");
        assert_eq!(calls[0], (1, max_iter), "first report is (1, max_iter)");
        // Whether the solver converges early or runs out of iterations, the
        // final report has current == total.
        let &(last_current, last_total) = calls.last().unwrap();
        assert_eq!(last_current, last_total);
        assert!(calls.iter().all(|&(current, _)| (1..=max_iter).contains(&current)));
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_rts_parallel_matches_sequential() {
        let n = 16;
        let field: Vec<f64> = (0..n * n * n).map(|i| ((i as f64) * 0.7).sin() * 0.01).collect();
        let mask = vec![1u8; n * n * n];

        let pool_1 = rayon::ThreadPoolBuilder::new().num_threads(1).build().unwrap();
        let chi_seq = pool_1.install(|| {
            rts(&field, &mask, n, n, n, 1.0, 1.0, 1.0,
                (0.0, 0.0, 1.0), 0.15, 1e5, 10.0, 1e-4, 20, 4)
        });

        let chi_par = rts(&field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 0.15, 1e5, 10.0, 1e-4, 20, 4);

        for (i, (s, p)) in chi_seq.iter().zip(chi_par.iter()).enumerate() {
            assert!(
                (s - p).abs() < 1e-10,
                "RTS mismatch at voxel {}: seq={} par={} diff={}",
                i, s, p, (s - p).abs()
            );
        }
    }
}