1use num_complex::Complex64;
19use crate::fft::Fft3dWorkspace;
20use crate::kernels::dipole::dipole_kernel;
21use crate::kernels::laplacian::laplacian_kernel;
22use crate::utils::gradient::{bdiv_inplace, fgrad_inplace};
23
/// Soft-thresholds `x` by `threshold * weight`.
///
/// Values strictly above the scaled threshold move down by it, values strictly
/// below its negation move up by it, and everything else (including the exact
/// boundary and NaN, whose comparisons are false) maps to `0.0`.
#[inline]
fn weighted_shrink(x: f64, threshold: f64, weight: f64) -> f64 {
    let scaled = weight * threshold;
    match x {
        above if above > scaled => above - scaled,
        below if below < -scaled => below + scaled,
        _ => 0.0,
    }
}
36
/// Solver settings mirroring the tuning arguments of [`nltv`] /
/// [`nltv_with_progress`].
///
/// The `Default` values are the same ones [`nltv_default`] runs with.
#[derive(Clone, Debug, PartialEq)]
pub struct NltvParams {
    /// Regularisation weight (`lambda` argument). The solver also derives its
    /// ADMM penalty from it as `rho = 100 * lambda`.
    pub lambda: f64,
    /// Scales the `1e-6` floor added to the gradient magnitude when the
    /// per-voxel weights are recomputed (`mu` argument).
    pub mu: f64,
    /// Relative-change threshold `||x_k - x_{k-1}|| / ||x_k||` that ends an
    /// inner ADMM pass early (`tol` argument).
    pub tol: f64,
    /// Maximum number of inner ADMM iterations per outer pass (`max_iter`).
    pub max_iter: usize,
    /// Number of outer weight-update passes (`newton_iter`).
    pub newton_iter: usize,
}

impl Default for NltvParams {
    fn default() -> Self {
        Self {
            lambda: 1e-3,
            mu: 1.0,
            tol: 1e-3,
            max_iter: 250,
            newton_iter: 10,
        }
    }
}
63
64pub fn nltv(
81 local_field: &[f64],
82 mask: &[u8],
83 nx: usize, ny: usize, nz: usize,
84 vsx: f64, vsy: f64, vsz: f64,
85 bdir: (f64, f64, f64),
86 lambda: f64,
87 mu: f64,
88 tol: f64,
89 max_iter: usize,
90 newton_iter: usize,
91) -> Vec<f64> {
92 nltv_with_progress(
93 local_field, mask, nx, ny, nz, vsx, vsy, vsz,
94 bdir, lambda, mu, tol, max_iter, newton_iter,
95 |_, _| {} )
97}
98
99pub fn nltv_with_progress<F>(
101 local_field: &[f64],
102 mask: &[u8],
103 nx: usize, ny: usize, nz: usize,
104 vsx: f64, vsy: f64, vsz: f64,
105 bdir: (f64, f64, f64),
106 lambda: f64,
107 mu: f64,
108 tol: f64,
109 max_iter: usize,
110 newton_iter: usize,
111 mut progress_callback: F,
112) -> Vec<f64>
113where
114 F: FnMut(usize, usize),
115{
116 let n_total = nx * ny * nz;
117 let eps = 1e-6; let mut fft_ws = Fft3dWorkspace::new(nx, ny, nz);
125
126 let d_kernel = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
127 let l_kernel = laplacian_kernel(nx, ny, nz, vsx, vsy, vsz, true);
128
129 let mut l_complex: Vec<Complex64> = l_kernel.iter()
131 .map(|&x| Complex64::new(x, 0.0))
132 .collect();
133 fft_ws.fft3d(&mut l_complex);
134
135 let rho = 100.0 * lambda;
137
138 let mut inv_a: Vec<f64> = vec![0.0; n_total];
140 for i in 0..n_total {
141 let a = d_kernel[i] * d_kernel[i] + rho * l_complex[i].re;
142 inv_a[i] = if a.abs() > 1e-20 { 1.0 / a } else { 0.0 };
143 }
144
145 let f_hat = &mut l_complex;
147 for i in 0..n_total {
148 f_hat[i] = Complex64::new(local_field[i], 0.0);
149 }
150 fft_ws.fft3d(f_hat);
151 for i in 0..n_total {
152 f_hat[i] = f_hat[i] * d_kernel[i] * inv_a[i];
153 }
154
155 let mut x = vec![0.0; n_total];
160 let mut x_prev = vec![0.0; n_total];
161
162 let mut ux = vec![0.0; n_total];
164 let mut uy = vec![0.0; n_total];
165 let mut uz = vec![0.0; n_total];
166
167 let mut gx = vec![0.0; n_total];
169 let mut gy = vec![0.0; n_total];
170 let mut gz = vec![0.0; n_total];
171
172 let mut div_d = vec![0.0; n_total];
174
175 let mut work_complex = vec![Complex64::new(0.0, 0.0); n_total];
177
178 let mut weights = vec![1.0; n_total];
180
181 let total_iter = max_iter * newton_iter;
182 let mut current_iter = 0;
183
184 for _newton in 0..newton_iter {
188 let lambda_over_rho = lambda / rho;
189
190 for _iter in 0..max_iter {
194 current_iter += 1;
195 progress_callback(current_iter, total_iter);
196
197 std::mem::swap(&mut x, &mut x_prev);
199
200 bdiv_inplace(&mut div_d, &gx, &gy, &gz, nx, ny, nz, vsx, vsy, vsz);
204
205 for i in 0..n_total {
206 work_complex[i] = Complex64::new(div_d[i], 0.0);
207 }
208 fft_ws.fft3d(&mut work_complex);
209
210 for i in 0..n_total {
214 work_complex[i] = f_hat[i] - rho * work_complex[i] * inv_a[i];
215 }
216
217 fft_ws.ifft3d(&mut work_complex);
218 for i in 0..n_total {
219 x[i] = work_complex[i].re;
220 }
221
222 let mut norm_diff_sq = 0.0;
226 let mut norm_x_sq = 0.0;
227 for i in 0..n_total {
228 let diff = x[i] - x_prev[i];
229 norm_diff_sq += diff * diff;
230 norm_x_sq += x[i] * x[i];
231 }
232
233 let rel_change = norm_diff_sq.sqrt() / (norm_x_sq.sqrt() + 1e-20);
234 if rel_change < tol {
235 break;
236 }
237
238 fgrad_inplace(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, vsx, vsy, vsz);
242
243 for i in 0..n_total {
244 let grad_x = gx[i];
245 let grad_y = gy[i];
246 let grad_z = gz[i];
247
248 let vx = grad_x + ux[i];
249 let vy = grad_y + uy[i];
250 let vz = grad_z + uz[i];
251
252 let zx_i = weighted_shrink(vx, lambda_over_rho, weights[i]);
254 let zy_i = weighted_shrink(vy, lambda_over_rho, weights[i]);
255 let zz_i = weighted_shrink(vz, lambda_over_rho, weights[i]);
256
257 ux[i] = vx - zx_i;
259 uy[i] = vy - zy_i;
260 uz[i] = vz - zz_i;
261
262 gx[i] = 2.0 * zx_i - vx;
264 gy[i] = 2.0 * zy_i - vy;
265 gz[i] = 2.0 * zz_i - vz;
266 }
267 }
268
269 fgrad_inplace(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, vsx, vsy, vsz);
273
274 for i in 0..n_total {
275 let grad_mag = (gx[i] * gx[i] + gy[i] * gy[i] + gz[i] * gz[i]).sqrt();
277
278 weights[i] = 1.0 / (grad_mag + mu * eps);
281 }
282
283 let max_weight: f64 = weights.iter().cloned().fold(0.0, f64::max);
285 if max_weight > 1.0 {
286 for w in weights.iter_mut() {
287 *w /= max_weight;
288 }
289 }
290 }
291
292 for i in 0..n_total {
294 if mask[i] == 0 {
295 x[i] = 0.0;
296 }
297 }
298
299 x
300}
301
302pub fn nltv_default(
304 local_field: &[f64],
305 mask: &[u8],
306 nx: usize, ny: usize, nz: usize,
307 vsx: f64, vsy: f64, vsz: f64,
308) -> Vec<f64> {
309 nltv(
310 local_field, mask, nx, ny, nz, vsx, vsy, vsz,
311 (0.0, 0.0, 1.0), 1e-3, 1.0, 1e-3, 250, 10 )
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Short reconstruction shared by the solver tests.
    ///
    /// Uses an `n x n x n` grid, unit voxels, an all-ones mask, and
    /// 10 inner x 2 outer iterations.
    fn reconstruct(field: &[f64], n: usize) -> Vec<f64> {
        let mask = vec![1u8; n * n * n];
        nltv(
            field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-3, 1.0, 1e-2, 10, 2,
        )
    }

    #[test]
    fn test_nltv_zero_field() {
        let n = 8;
        let chi = reconstruct(&vec![0.0; n * n * n], n);

        for val in &chi {
            assert!(val.abs() < 1e-8, "Zero field should give zero chi");
        }
    }

    #[test]
    fn test_nltv_finite() {
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|k| k as f64 * 0.001).collect();
        let chi = reconstruct(&field, n);

        if let Some(bad) = chi.iter().position(|v| !v.is_finite()) {
            panic!("Chi should be finite at index {}", bad);
        }
    }

    #[test]
    fn test_weighted_shrink() {
        let close = |a: f64, b: f64| (a - b).abs() < 1e-10;

        // Unit weight: plain soft threshold.
        assert!(close(weighted_shrink(1.0, 0.5, 1.0), 0.5));
        assert!(close(weighted_shrink(-1.0, 0.5, 1.0), -0.5));
        assert!(close(weighted_shrink(0.3, 0.5, 1.0), 0.0));

        // A weight of 0.5 halves the effective threshold.
        assert!(close(weighted_shrink(1.0, 0.5, 0.5), 0.75));
        assert!(close(weighted_shrink(0.3, 0.5, 0.5), 0.05));
    }
}