1use num_complex::Complex64;
17use crate::fft::Fft3dWorkspace;
18use crate::kernels::dipole::dipole_kernel;
19use crate::kernels::laplacian::laplacian_kernel;
20use crate::utils::gradient::{bdiv_inplace, fgrad_inplace};
21
22#[cfg(feature = "parallel")]
23use crate::par::*;
24
#[inline]
/// Soft-thresholding of a scalar.
///
/// Moves `x` toward zero by `threshold`. Values whose magnitude does not
/// strictly exceed `threshold` collapse to `0.0`. Both comparisons are strict,
/// so `x == ±threshold` yields `0.0`, and a NaN `x` also yields `0.0`.
/// For `threshold >= 0` this is the proximal operator of `threshold * |x|`.
fn shrink(x: f64, threshold: f64) -> f64 {
    match x {
        above if above > threshold => above - threshold,
        below if below < -threshold => below + threshold,
        _ => 0.0,
    }
}
37
/// Tuning parameters for the TV-ADMM solver.
///
/// The fields map one-to-one onto the `lambda`, `rho`, `tol` and `max_iter`
/// arguments of `tv_admm`. All fields are plain numbers, so the type is `Copy`
/// and comparable with `==`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TvParams {
    /// Total-variation weight. The solver soft-thresholds gradients by `lambda / rho`.
    pub lambda: f64,
    /// ADMM penalty parameter. It weights the Laplacian term of the x-update
    /// operator `D^2 + rho * L` and divides `lambda` in the threshold.
    pub rho: f64,
    /// Stopping tolerance on the relative change `||x_k - x_{k-1}|| / ||x_k||`.
    pub tol: f64,
    /// Maximum number of ADMM iterations.
    pub max_iter: usize,
}

impl Default for TvParams {
    /// `lambda = 2e-4`, `rho = 2e-2`, `tol = 1e-3`, `max_iter = 250`.
    fn default() -> Self {
        Self {
            lambda: 2e-4,
            rho: 2e-2,
            tol: 1e-3,
            max_iter: 250,
        }
    }
}
61
62pub fn tv_admm(
78 local_field: &[f64],
79 mask: &[u8],
80 nx: usize, ny: usize, nz: usize,
81 vsx: f64, vsy: f64, vsz: f64,
82 bdir: (f64, f64, f64),
83 lambda: f64,
84 rho: f64,
85 tol: f64,
86 max_iter: usize,
87) -> Vec<f64> {
88 tv_admm_with_progress(
89 local_field, mask, nx, ny, nz, vsx, vsy, vsz,
90 bdir, lambda, rho, tol, max_iter,
91 |_, _| {} )
93}
94
95pub fn tv_admm_with_progress<F>(
105 local_field: &[f64],
106 mask: &[u8],
107 nx: usize, ny: usize, nz: usize,
108 vsx: f64, vsy: f64, vsz: f64,
109 bdir: (f64, f64, f64),
110 lambda: f64,
111 rho: f64,
112 tol: f64,
113 max_iter: usize,
114 mut progress_callback: F,
115) -> Vec<f64>
116where
117 F: FnMut(usize, usize),
118{
119 let n_total = nx * ny * nz;
120
121 let mut fft_ws = Fft3dWorkspace::new(nx, ny, nz);
127
128 let d_kernel = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, bdir);
130
131 let l_kernel = laplacian_kernel(nx, ny, nz, vsx, vsy, vsz, true);
133
134 let mut l_complex: Vec<Complex64> = l_kernel.iter()
136 .map(|&x| Complex64::new(x, 0.0))
137 .collect();
138 fft_ws.fft3d(&mut l_complex);
139
140 let mut inv_a: Vec<f64> = vec![0.0; n_total];
142 for i in 0..n_total {
143 let a = d_kernel[i] * d_kernel[i] + rho * l_complex[i].re;
144 inv_a[i] = if a.abs() > 1e-20 { 1.0 / a } else { 0.0 };
145 }
146
147 let f_hat = &mut l_complex; for i in 0..n_total {
150 f_hat[i] = Complex64::new(local_field[i], 0.0);
151 }
152 fft_ws.fft3d(f_hat);
153
154 for i in 0..n_total {
156 f_hat[i] = f_hat[i] * d_kernel[i] * inv_a[i];
157 }
158
159 let mut x = vec![0.0; n_total];
165 let mut x_prev = vec![0.0; n_total];
166
167 let mut ux = vec![0.0; n_total];
169 let mut uy = vec![0.0; n_total];
170 let mut uz = vec![0.0; n_total];
171
172 let mut gx = vec![0.0; n_total];
174 let mut gy = vec![0.0; n_total];
175 let mut gz = vec![0.0; n_total];
176
177 let mut div_d = vec![0.0; n_total];
179
180 let mut work_complex = vec![Complex64::new(0.0, 0.0); n_total];
182
183 let lambda_over_rho = lambda / rho;
184
185 for iter in 0..max_iter {
189 progress_callback(iter + 1, max_iter);
191
192 std::mem::swap(&mut x, &mut x_prev);
194
195 bdiv_inplace(&mut div_d, &gx, &gy, &gz, nx, ny, nz, vsx, vsy, vsz);
205
206 for i in 0..n_total {
208 work_complex[i] = Complex64::new(div_d[i], 0.0);
209 }
210 fft_ws.fft3d(&mut work_complex);
211
212 for i in 0..n_total {
216 work_complex[i] = f_hat[i] - rho * work_complex[i] * inv_a[i];
217 }
218
219 fft_ws.ifft3d(&mut work_complex);
221 for i in 0..n_total {
222 x[i] = work_complex[i].re;
223 }
224
225 let mut norm_diff_sq = 0.0;
229 let mut norm_x_sq = 0.0;
230 for i in 0..n_total {
231 let diff = x[i] - x_prev[i];
232 norm_diff_sq += diff * diff;
233 norm_x_sq += x[i] * x[i];
234 }
235
236 let rel_change = norm_diff_sq.sqrt() / (norm_x_sq.sqrt() + 1e-20);
237 if rel_change < tol {
238 progress_callback(iter + 1, iter + 1);
239 break;
240 }
241
242 fgrad_inplace(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, vsx, vsy, vsz);
248
249 for i in 0..n_total {
251 let vx = gx[i] + ux[i];
252 let vy = gy[i] + uy[i];
253 let vz = gz[i] + uz[i];
254
255 let zx_i = shrink(vx, lambda_over_rho);
256 let zy_i = shrink(vy, lambda_over_rho);
257 let zz_i = shrink(vz, lambda_over_rho);
258
259 ux[i] = vx - zx_i;
260 uy[i] = vy - zy_i;
261 uz[i] = vz - zz_i;
262
263 gx[i] = 2.0 * zx_i - vx;
264 gy[i] = 2.0 * zy_i - vy;
265 gz[i] = 2.0 * zz_i - vz;
266 }
267 }
268
269 for i in 0..n_total {
271 if mask[i] == 0 { x[i] = 0.0; }
272 }
273
274 x
275}
276
277pub fn tv_admm_default(
279 local_field: &[f64],
280 mask: &[u8],
281 nx: usize, ny: usize, nz: usize,
282 vsx: f64, vsy: f64, vsz: f64,
283) -> Vec<f64> {
284 let p = TvParams::default();
285 tv_admm(
286 local_field, mask, nx, ny, nz, vsx, vsy, vsz,
287 (0.0, 0.0, 1.0),
288 p.lambda, p.rho, p.tol, p.max_iter,
289 )
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::gradient::fgrad;

    #[test]
    fn test_shrink() {
        assert!((shrink(1.0, 0.5) - 0.5).abs() < 1e-10);
        assert!((shrink(-1.0, 0.5) - (-0.5)).abs() < 1e-10);
        assert!((shrink(0.3, 0.5) - 0.0).abs() < 1e-10);
        assert!((shrink(-0.3, 0.5) - 0.0).abs() < 1e-10);
        // The comparisons are strict, so values exactly at the threshold are zeroed.
        assert_eq!(shrink(0.5, 0.5), 0.0);
        assert_eq!(shrink(-0.5, 0.5), 0.0);
    }

    #[test]
    fn test_tv_admm_zero_field() {
        let n = 8;
        let field = vec![0.0; n * n * n];
        let mask = vec![1u8; n * n * n];

        let chi = tv_admm(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-3, 0.1, 1e-2, 10
        );

        for &val in chi.iter() {
            assert!(val.abs() < 1e-8, "Zero field should give zero chi, got {}", val);
        }
    }

    #[test]
    fn test_tv_admm_finite() {
        let n = 8;
        let field: Vec<f64> = (0..n*n*n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];

        let chi = tv_admm(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-3, 0.1, 1e-2, 10
        );

        assert_eq!(chi.len(), n * n * n);
        for (i, &val) in chi.iter().enumerate() {
            assert!(val.is_finite(), "Chi should be finite at index {}", i);
        }
    }

    /// Alternating-sign input solved with stronger regularisation.
    ///
    /// NOTE(review): despite its name, this test makes no TKD comparison. It
    /// only checks that the total variation of the result is finite.
    #[test]
    fn test_tv_admm_smoother_than_tkd() {
        let n = 8;
        let mut field = vec![0.0; n * n * n];
        for i in 0..n*n*n {
            field[i] = if i % 2 == 0 { 0.01 } else { -0.01 };
        }
        let mask = vec![1u8; n * n * n];

        let chi_tv = tv_admm(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-2, 1.0, 1e-2, 50
        );

        let (gx, gy, gz) = fgrad(&chi_tv, n, n, n, 1.0, 1.0, 1.0);
        let tv: f64 = gx.iter().chain(gy.iter()).chain(gz.iter())
            .map(|&g| g.abs())
            .sum();

        assert!(tv.is_finite(), "TV should be finite");
    }

    /// Voxels outside the mask must be exactly zero in the output.
    #[test]
    fn test_tv_admm_applies_mask() {
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| ((i as f64) * 0.3).sin() * 0.01).collect();
        let mask: Vec<u8> = (0..n * n * n).map(|i| u8::from(i % 3 != 0)).collect();

        let chi = tv_admm(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-3, 0.1, 1e-2, 10
        );

        assert_eq!(chi.len(), n * n * n);
        for (i, (&val, &m)) in chi.iter().zip(&mask).enumerate() {
            if m == 0 {
                assert_eq!(val, 0.0, "Voxel {} is outside the mask but got {}", i, val);
            }
        }
    }

    /// Checks the documented callback contract: the first report is
    /// (1, max_iter), and an early stop adds at most one extra (k, k) report.
    #[test]
    fn test_tv_admm_progress_reports() {
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];
        let max_iter = 10;

        let mut calls: Vec<(usize, usize)> = Vec::new();
        let chi = tv_admm_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-3, 0.1, 1e-2, max_iter,
            |done, total| calls.push((done, total)),
        );

        assert_eq!(chi.len(), n * n * n);
        assert_eq!(calls.first(), Some(&(1, max_iter)));
        assert!(calls.len() <= max_iter + 1);
        for &(done, total) in &calls {
            assert!((1..=total).contains(&done) && total <= max_iter);
        }
    }

    /// With no iterations the solver returns zeros and never reports progress.
    #[test]
    fn test_tv_admm_zero_iterations() {
        let n = 8;
        let field: Vec<f64> = (0..n * n * n).map(|i| (i as f64) * 0.001).collect();
        let mask = vec![1u8; n * n * n];

        let mut calls = 0usize;
        let chi = tv_admm_with_progress(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-3, 0.1, 1e-2, 0,
            |_, _| calls += 1,
        );

        assert_eq!(calls, 0);
        assert_eq!(chi.len(), n * n * n);
        assert!(chi.iter().all(|&v| v == 0.0));
    }

    /// A field shorter than nx*ny*nz must be rejected.
    #[test]
    #[should_panic]
    fn test_tv_admm_rejects_short_field() {
        let n = 8;
        let field = vec![0.0; n * n * n - 1];
        let mask = vec![1u8; n * n * n];

        let _ = tv_admm(
            &field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-3, 0.1, 1e-2, 5
        );
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_tv_parallel_matches_sequential() {
        let n = 16;
        let field: Vec<f64> = (0..n*n*n).map(|i| ((i as f64) * 0.7).sin() * 0.01).collect();
        let mask = vec![1u8; n * n * n];

        let pool_1 = rayon::ThreadPoolBuilder::new().num_threads(1).build().unwrap();
        let chi_seq = pool_1.install(|| {
            tv_admm(&field, &mask, n, n, n, 1.0, 1.0, 1.0,
                (0.0, 0.0, 1.0), 1e-3, 0.1, 1e-3, 50)
        });

        let chi_par = tv_admm(&field, &mask, n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1e-3, 0.1, 1e-3, 50);

        for (i, (s, p)) in chi_seq.iter().zip(chi_par.iter()).enumerate() {
            assert!(
                (s - p).abs() < 1e-10,
                "TV mismatch at voxel {}: seq={} par={} diff={}",
                i, s, p, (s - p).abs()
            );
        }
    }
}