Skip to main content

qsm_core/utils/
gradient.rs

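//! Gradient operators for QSM
//!
//! Forward-difference gradient and backward-difference divergence operators
//! (plus the symmetric gradient/divergence pair used for TGV), used in TV
//! regularization and other algorithms.
//!
//! The `f64` operators use periodic boundary conditions; the unmasked `f32`
//! operators use zero boundary conditions to match the Julia implementation.
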
/// Forward-difference gradient with periodic (wrap-around) boundaries,
/// written into caller-provided buffers.
///
/// For every voxel the three components are
/// `(x[successor along axis] - x[voxel]) / voxel_size`; the successor of the
/// last sample on an axis is the first sample on that axis.
///
/// Data is laid out in Fortran (column-major) order:
/// `index = i + j*nx + k*nx*ny`.
///
/// # Arguments
/// * `gx`, `gy`, `gz` - Output components, each at least `nx*ny*nz` long (fully overwritten)
/// * `x` - Input volume (`nx*ny*nz` samples)
/// * `nx`, `ny`, `nz` - Grid dimensions
/// * `vsx`, `vsy`, `vsz` - Voxel sizes along x, y and z
///
/// # Panics
/// Panics with an out-of-bounds index if any slice is shorter than `nx*ny*nz`.
#[inline]
pub fn fgrad_inplace(
    gx: &mut [f64], gy: &mut [f64], gz: &mut [f64],
    x: &[f64],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
) {
    // Multiply by reciprocals in the hot loop instead of dividing.
    let (inv_x, inv_y, inv_z) = (1.0 / vsx, 1.0 / vsy, 1.0 / vsz);
    let slab = nx * ny;

    for k in 0..nz {
        let base = k * slab;
        // Start of the next z-slab, wrapping back to slab 0 after the last one.
        let next_slab = if k + 1 == nz { 0 } else { base + slab };

        for j in 0..ny {
            let row = base + j * nx;
            // Start of the next row within this slab, wrapping to row 0.
            let next_row = if j + 1 == ny { base } else { row + nx };
            // Start of this same row inside the next slab.
            let row_above = next_slab + j * nx;

            for i in 0..nx {
                let here = row + i;
                let east = if i + 1 == nx { row } else { here + 1 };
                let centre = x[here];

                gx[here] = (x[east] - centre) * inv_x;
                gy[here] = (x[next_row + i] - centre) * inv_y;
                gz[here] = (x[row_above + i] - centre) * inv_z;
            }
        }
    }
}

/// Backward-difference divergence with periodic (wrap-around) boundaries,
/// written into a caller-provided buffer.
///
/// For every voxel `v`:
/// `div[v] = (gx[v] - gx[v-1x]) / vsx + (gy[v] - gy[v-1y]) / vsy + (gz[v] - gz[v-1z]) / vsz`,
/// where the predecessor of the first sample on an axis is the last sample on
/// that axis.
///
/// The result is the divergence itself, which with these periodic boundaries
/// equals the *negative* adjoint of [`fgrad_inplace`]:
/// `<fgrad(x), g> = -<x, bdiv(g)>`.
///
/// Data is laid out in Fortran (column-major) order:
/// `index = i + j*nx + k*nx*ny`.
///
/// # Arguments
/// * `div` - Output divergence, at least `nx*ny*nz` long (fully overwritten)
/// * `gx`, `gy`, `gz` - Vector-field components (`nx*ny*nz` samples each)
/// * `nx`, `ny`, `nz` - Grid dimensions
/// * `vsx`, `vsy`, `vsz` - Voxel sizes along x, y and z
///
/// # Panics
/// Panics with an out-of-bounds index if any slice is shorter than `nx*ny*nz`.
#[inline]
pub fn bdiv_inplace(
    div: &mut [f64],
    gx: &[f64], gy: &[f64], gz: &[f64],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
) {
    let hx = 1.0 / vsx;
    let hy = 1.0 / vsy;
    let hz = 1.0 / vsz;

    // Fortran order: index = i + j*nx + k*nx*ny
    for k in 0..nz {
        // Periodic BC: the slab before k = 0 is the last slab.
        let km1 = if k == 0 { nz - 1 } else { k - 1 };
        let k_offset = k * nx * ny;
        let km1_offset = km1 * nx * ny;

        for j in 0..ny {
            let jm1 = if j == 0 { ny - 1 } else { j - 1 };
            let j_offset = j * nx;
            let jm1_offset = jm1 * nx;

            for i in 0..nx {
                let im1 = if i == 0 { nx - 1 } else { i - 1 };

                let idx = i + j_offset + k_offset;
                let idx_xm = im1 + j_offset + k_offset;
                let idx_ym = i + jm1_offset + k_offset;
                let idx_zm = i + j_offset + km1_offset;

                // Divergence = -(adjoint of the periodic forward gradient).
                div[idx] = (gx[idx] - gx[idx_xm]) * hx
                         + (gy[idx] - gy[idx_ym]) * hy
                         + (gz[idx] - gz[idx_zm]) * hz;
            }
        }
    }
}

/// Forward difference gradient operator (allocating, periodic boundaries)
///
/// Computes forward differences along each axis:
/// `g[v] = (x[successor of v] - x[v]) / voxel_size`, where the successor of the
/// last sample on an axis is the first sample on that axis. Produces exactly
/// the same values as `fgrad_inplace`, but allocates its own output buffers.
///
/// # Arguments
/// * `x` - Input array (`nx * ny * nz`, Fortran order: `i + j*nx + k*nx*ny`)
/// * `nx`, `ny`, `nz` - Array dimensions
/// * `vsx`, `vsy`, `vsz` - Voxel sizes
///
/// # Returns
/// Tuple of `(gx, gy, gz)` gradient components, each of length `nx * ny * nz`.
///
/// # Panics
/// Panics with an out-of-bounds index if `x` is shorter than `nx * ny * nz`.
pub fn fgrad(
    x: &[f64],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let n_total = nx * ny * nz;
    let mut gx = vec![0.0; n_total];
    let mut gy = vec![0.0; n_total];
    let mut gz = vec![0.0; n_total];

    let hx = 1.0 / vsx;
    let hy = 1.0 / vsy;
    let hz = 1.0 / vsz;
    let plane = nx * ny;

    // Fortran order: index = i + j*nx + k*nx*ny.
    // The periodic wrap uses a comparison instead of `%`: an integer division
    // per voxel in the innermost loop is needless work, and this keeps the
    // indexing identical to `fgrad_inplace`.
    for k in 0..nz {
        let k_offset = k * plane;
        let kp1_offset = if k + 1 < nz { k_offset + plane } else { 0 };

        for j in 0..ny {
            let j_offset = j * nx;
            let jp1_offset = if j + 1 < ny { j_offset + nx } else { 0 };

            for i in 0..nx {
                let ip1 = if i + 1 < nx { i + 1 } else { 0 };

                let idx = i + j_offset + k_offset;
                let x_val = x[idx];

                gx[idx] = (x[ip1 + j_offset + k_offset] - x_val) * hx;
                gy[idx] = (x[i + jp1_offset + k_offset] - x_val) * hy;
                gz[idx] = (x[i + j_offset + kp1_offset] - x_val) * hz;
            }
        }
    }

    (gx, gy, gz)
}

/// Backward divergence operator (allocating, periodic boundaries)
///
/// For every voxel `v`:
/// `div[v] = (gx[v] - gx[v-1x]) / vsx + (gy[v] - gy[v-1y]) / vsy + (gz[v] - gz[v-1z]) / vsz`,
/// where the predecessor of the first sample on an axis is the last sample on
/// that axis.
///
/// The returned value is the divergence itself (not its negation). With these
/// periodic boundaries it is the *negative* adjoint of `fgrad`:
/// `<fgrad(x), g> = -<x, bdiv(g)>`.
///
/// # Arguments
/// * `gx`, `gy`, `gz` - Gradient components (`nx * ny * nz`, Fortran order)
/// * `nx`, `ny`, `nz` - Array dimensions
/// * `vsx`, `vsy`, `vsz` - Voxel sizes
///
/// # Returns
/// The divergence, of length `nx * ny * nz`.
///
/// # Panics
/// Panics with an out-of-bounds index if any input is shorter than `nx * ny * nz`.
pub fn bdiv(
    gx: &[f64], gy: &[f64], gz: &[f64],
    nx: usize, ny: usize, nz: usize,
    vsx: f64, vsy: f64, vsz: f64,
) -> Vec<f64> {
    let n_total = nx * ny * nz;
    let mut div = vec![0.0; n_total];

    let hx = 1.0 / vsx;
    let hy = 1.0 / vsy;
    let hz = 1.0 / vsz;
    let plane = nx * ny;

    // Fortran order: index = i + j*nx + k*nx*ny
    for k in 0..nz {
        let k_offset = k * plane;
        // Periodic BC: the slab before k = 0 is the last slab.
        let km1_offset = if k == 0 { (nz - 1) * plane } else { k_offset - plane };

        for j in 0..ny {
            let j_offset = j * nx;
            let jm1_offset = if j == 0 { (ny - 1) * nx } else { j_offset - nx };

            for i in 0..nx {
                let im1 = if i == 0 { nx - 1 } else { i - 1 };

                let idx = i + j_offset + k_offset;
                let idx_xm = im1 + j_offset + k_offset;
                let idx_ym = i + jm1_offset + k_offset;
                let idx_zm = i + j_offset + km1_offset;

                // Divergence = -(adjoint of the periodic forward gradient).
                div[idx] = (gx[idx] - gx[idx_xm]) * hx
                         + (gy[idx] - gy[idx_ym]) * hy
                         + (gz[idx] - gz[idx_zm]) * hz;
            }
        }
    }

    div
}

/// Squared gradient magnitude per voxel: `|∇x|² = gx² + gy² + gz²`.
///
/// The output length is that of the shortest input slice; trailing elements of
/// longer slices are ignored.
pub fn grad_magnitude_squared(
    gx: &[f64], gy: &[f64], gz: &[f64]
) -> Vec<f64> {
    let squares = gx
        .iter()
        .zip(gy)
        .zip(gz)
        .map(|((x, y), z)| x * x + y * y + z * z);
    squares.collect()
}

// ============================================================================
// F32 (Single Precision) Gradient Functions
// ============================================================================

/// Forward-difference gradient (in-place, f32) with zero boundary conditions.
///
/// Where a successor exists along an axis, the component is
/// `(x[successor] - x[voxel]) / voxel_size`; on the last sample of an axis it
/// is `0.0` (matching Julia). Data is in Fortran order:
/// `index = i + j*nx + k*nx*ny`.
///
/// # Arguments
/// * `gx`, `gy`, `gz` - Output components, each at least `nx*ny*nz` long (fully overwritten)
/// * `x` - Input volume (`nx*ny*nz` samples)
/// * `nx`, `ny`, `nz` - Grid dimensions
/// * `vsx`, `vsy`, `vsz` - Voxel sizes along x, y and z
#[inline]
pub fn fgrad_inplace_f32(
    gx: &mut [f32], gy: &mut [f32], gz: &mut [f32],
    x: &[f32],
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
) {
    let (inv_x, inv_y, inv_z) = (1.0 / vsx, 1.0 / vsy, 1.0 / vsz);
    let plane = nx * ny;

    for k in 0..nz {
        let has_next_slab = k + 1 < nz;

        for j in 0..ny {
            let has_next_row = j + 1 < ny;
            let row = j * nx + k * plane;

            for i in 0..nx {
                let here = row + i;
                let v = x[here];

                gx[here] = if i + 1 < nx { (x[here + 1] - v) * inv_x } else { 0.0 };
                gy[here] = if has_next_row { (x[here + nx] - v) * inv_y } else { 0.0 };
                gz[here] = if has_next_slab { (x[here + plane] - v) * inv_z } else { 0.0 };
            }
        }
    }
}

/// Backward-difference divergence (in-place, f32) with zero boundary
/// conditions (matching Julia).
///
/// For every voxel `v`:
/// `div[v] = (gx[v] - gx[v-1x]) / vsx + (gy[v] - gy[v-1y]) / vsy + (gz[v] - gz[v-1z]) / vsz`,
/// where a predecessor before the first sample on an axis is taken as `0.0`.
/// Data is in Fortran order: `index = i + j*nx + k*nx*ny`.
///
/// NOTE(review): at the last sample along an axis this still uses the current
/// value, so it is the exact negative adjoint of `fgrad_inplace_f32` only for
/// fields whose last-slice entries are zero (as `fgrad_inplace_f32` output
/// is) — confirm this is the intended pairing.
///
/// # Arguments
/// * `div` - Output divergence, at least `nx*ny*nz` long (fully overwritten)
/// * `gx`, `gy`, `gz` - Vector-field components (`nx*ny*nz` samples each)
/// * `nx`, `ny`, `nz` - Grid dimensions
/// * `vsx`, `vsy`, `vsz` - Voxel sizes along x, y and z
#[inline]
pub fn bdiv_inplace_f32(
    div: &mut [f32],
    gx: &[f32], gy: &[f32], gz: &[f32],
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
) {
    let (rx, ry, rz) = (1.0 / vsx, 1.0 / vsy, 1.0 / vsz);
    let plane = nx * ny;

    for k in 0..nz {
        for j in 0..ny {
            let row = j * nx + k * plane;

            for i in 0..nx {
                let at = row + i;

                let west = if i == 0 { 0.0 } else { gx[at - 1] };
                let south = if j == 0 { 0.0 } else { gy[at - nx] };
                let below = if k == 0 { 0.0 } else { gz[at - plane] };

                div[at] = (gx[at] - west) * rx
                        + (gy[at] - south) * ry
                        + (gz[at] - below) * rz;
            }
        }
    }
}

292/// Forward difference gradient operator (allocating, f32)
293pub fn fgrad_f32(
294    x: &[f32],
295    nx: usize, ny: usize, nz: usize,
296    vsx: f32, vsy: f32, vsz: f32,
297) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
298    let n_total = nx * ny * nz;
299    let mut gx = vec![0.0f32; n_total];
300    let mut gy = vec![0.0f32; n_total];
301    let mut gz = vec![0.0f32; n_total];
302    fgrad_inplace_f32(&mut gx, &mut gy, &mut gz, x, nx, ny, nz, vsx, vsy, vsz);
303    (gx, gy, gz)
304}

// ============================================================================
// Symmetric Gradient (for TGV)
// ============================================================================

/// Symmetric gradient of a vector field for TGV regularization (in-place, f32).
///
/// Given `w = (wx, wy, wz)`, writes the six independent entries of the
/// symmetric tensor `(∇w + ∇wᵀ) / 2`:
/// * `sxx = ∂wx/∂x`, `syy = ∂wy/∂y`, `szz = ∂wz/∂z`
/// * `sxy = (∂wx/∂y + ∂wy/∂x) / 2`
/// * `sxz = (∂wx/∂z + ∂wz/∂x) / 2`
/// * `syz = (∂wy/∂z + ∂wz/∂y) / 2`
///
/// Derivatives are forward differences scaled by the inverse voxel size; a
/// derivative along an axis is `0.0` on the last sample of that axis (zero
/// boundary, matching Julia). Data is in Fortran order:
/// `index = i + j*nx + k*nx*ny`. All six outputs are fully overwritten.
#[inline]
pub fn symgrad_inplace_f32(
    sxx: &mut [f32], sxy: &mut [f32], sxz: &mut [f32],
    syy: &mut [f32], syz: &mut [f32], szz: &mut [f32],
    wx: &[f32], wy: &[f32], wz: &[f32],
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
) {
    let (rx, ry, rz) = (1.0 / vsx, 1.0 / vsy, 1.0 / vsz);
    let plane = nx * ny;

    for k in 0..nz {
        for j in 0..ny {
            for i in 0..nx {
                let p = i + j * nx + k * plane;
                let (ux, uy, uz) = (wx[p], wy[p], wz[p]);

                // x-neighbour: Sxx, plus half of dwy/dx (-> Sxy) and half of dwz/dx (-> Sxz).
                let (dxx, mut xy, mut xz) = if i + 1 < nx {
                    let q = p + 1;
                    (
                        (wx[q] - ux) * rx,
                        (wy[q] - uy) * rx * 0.5,
                        (wz[q] - uz) * rx * 0.5,
                    )
                } else {
                    (0.0, 0.0, 0.0)
                };

                // y-neighbour: Syy, plus half of dwx/dy (-> Sxy) and half of dwz/dy (-> Syz).
                let (dyy, mut yz) = if j + 1 < ny {
                    let q = p + nx;
                    xy += (wx[q] - ux) * ry * 0.5;
                    ((wy[q] - uy) * ry, (wz[q] - uz) * ry * 0.5)
                } else {
                    (0.0, 0.0)
                };

                // z-neighbour: Szz, plus half of dwx/dz (-> Sxz) and half of dwy/dz (-> Syz).
                let dzz = if k + 1 < nz {
                    let q = p + plane;
                    xz += (wx[q] - ux) * rz * 0.5;
                    yz += (wy[q] - uy) * rz * 0.5;
                    (wz[q] - uz) * rz
                } else {
                    0.0
                };

                sxx[p] = dxx;
                sxy[p] = xy;
                sxz[p] = xz;
                syy[p] = dyy;
                syz[p] = yz;
                szz[p] = dzz;
            }
        }
    }
}

/// Row-wise divergence of a symmetric tensor field (in-place, f32).
///
/// Treats the six inputs as the tensor
/// `[[sxx, sxy, sxz], [sxy, syy, syz], [sxz, syz, szz]]` and writes the
/// divergence of each row into `divx`, `divy`, `divz`. Derivatives are
/// backward differences scaled by the inverse voxel size; the sample before
/// the first one on an axis is taken as `0.0` (zero boundary, matching Julia).
/// Data is in Fortran order: `index = i + j*nx + k*nx*ny`.
///
/// Intended as the counterpart of `symgrad_inplace_f32`.
/// NOTE(review): exact (negative) adjointness with `symgrad_inplace_f32`
/// presumably assumes off-diagonal entries are counted twice in the tensor
/// inner product and that last-slice entries are zero — confirm with callers.
#[inline]
pub fn symdiv_inplace_f32(
    divx: &mut [f32], divy: &mut [f32], divz: &mut [f32],
    sxx: &[f32], sxy: &[f32], sxz: &[f32],
    syy: &[f32], syz: &[f32], szz: &[f32],
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
) {
    let (rx, ry, rz) = (1.0 / vsx, 1.0 / vsy, 1.0 / vsz);
    let plane = nx * ny;

    for k in 0..nz {
        for j in 0..ny {
            for i in 0..nx {
                let p = i + j * nx + k * plane;

                // Backward difference of one component at `p` along each axis;
                // the missing sample before the first one counts as 0.
                let back_x = |f: &[f32]| f[p] - if i > 0 { f[p - 1] } else { 0.0 };
                let back_y = |f: &[f32]| f[p] - if j > 0 { f[p - nx] } else { 0.0 };
                let back_z = |f: &[f32]| f[p] - if k > 0 { f[p - plane] } else { 0.0 };

                divx[p] = back_x(sxx) * rx + back_y(sxy) * ry + back_z(sxz) * rz;
                divy[p] = back_x(sxy) * rx + back_y(syy) * ry + back_z(syz) * rz;
                divz[p] = back_x(sxz) * rx + back_y(syz) * ry + back_z(szz) * rz;
            }
        }
    }
}

/// Forward-difference gradient (in-place, f32), restricted to a mask.
///
/// Where `mask` is non-zero, each component is
/// `(x[successor] - x[voxel]) / voxel_size`; every masked-out voxel gets
/// `0.0` in all three outputs. The successor of the last sample on an axis is
/// the first sample on that axis (periodic wrap).
///
/// NOTE(review): this periodic wrap differs from the zero boundary used by
/// `fgrad_inplace_f32` and `bdiv_masked_inplace_f32` — confirm it is intended
/// (it only matters when the mask touches the last slice of an axis).
///
/// # Arguments
/// * `gx`, `gy`, `gz` - Output components, each at least `nx*ny*nz` long (fully overwritten)
/// * `x` - Input volume (Fortran order: `i + j*nx + k*nx*ny`)
/// * `mask` - Non-zero marks voxels where the gradient is computed
/// * `nx`, `ny`, `nz` - Grid dimensions
/// * `vsx`, `vsy`, `vsz` - Voxel sizes along x, y and z
#[inline]
pub fn fgrad_masked_inplace_f32(
    gx: &mut [f32], gy: &mut [f32], gz: &mut [f32],
    x: &[f32],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
) {
    let (rx, ry, rz) = (1.0 / vsx, 1.0 / vsy, 1.0 / vsz);
    let plane = nx * ny;

    for k in 0..nz {
        let z_base = k * plane;
        let z_next = if k + 1 == nz { 0 } else { z_base + plane };

        for j in 0..ny {
            let y_off = j * nx;
            let y_next = if j + 1 == ny { 0 } else { y_off + nx };

            for i in 0..nx {
                let p = i + y_off + z_base;

                if mask[p] != 0 {
                    let x_next = if i + 1 == nx { 0 } else { i + 1 };
                    let v = x[p];
                    gx[p] = (x[x_next + y_off + z_base] - v) * rx;
                    gy[p] = (x[i + y_next + z_base] - v) * ry;
                    gz[p] = (x[i + y_off + z_next] - v) * rz;
                } else {
                    gx[p] = 0.0;
                    gy[p] = 0.0;
                    gz[p] = 0.0;
                }
            }
        }
    }
}

/// Backward divergence operator (in-place, f32) - masked version.
///
/// Implements, per axis, the Julia formula
/// `div[i] = mask[i]*g[i] - mask[i-1]*g[i-1]` (the neighbour term is 0 for the
/// first sample along the axis). Each term is scaled by the inverse voxel size
/// and the results are summed over x, y and z. Voxels outside `mask` receive
/// `0.0`.
///
/// A masked-out neighbour contributes nothing. Its gradient value is never
/// multiplied into the result, so NaN or infinity stored outside the mask
/// cannot leak into in-mask voxels.
///
/// # Arguments
/// * `div` - Output divergence, at least `nx*ny*nz` long (fully overwritten)
/// * `gx`, `gy`, `gz` - Vector-field components (`nx*ny*nz` samples each)
/// * `mask` - Non-zero marks voxels inside the region of interest
/// * `nx`, `ny`, `nz` - Grid dimensions (Fortran order: `i + j*nx + k*nx*ny`)
/// * `vsx`, `vsy`, `vsz` - Voxel sizes along x, y and z
#[inline]
pub fn bdiv_masked_inplace_f32(
    div: &mut [f32],
    gx: &[f32], gy: &[f32], gz: &[f32],
    mask: &[u8],
    nx: usize, ny: usize, nz: usize,
    vsx: f32, vsy: f32, vsz: f32,
) {
    let hx = 1.0 / vsx;
    let hy = 1.0 / vsy;
    let hz = 1.0 / vsz;
    let plane = nx * ny;

    // Scaled contribution of the backward neighbour at flat index `n`.
    // Masked-out neighbours are skipped by branching instead of being
    // multiplied by 0.0, because 0.0 * NaN (or 0.0 * inf) is NaN.
    let neighbour = |g: &[f32], n: usize, h: f32| -> f32 {
        if mask[n] != 0 { g[n] * h } else { 0.0 }
    };

    for k in 0..nz {
        let k_offset = k * plane;

        for j in 0..ny {
            let j_offset = j * nx;

            for i in 0..nx {
                let idx = i + j_offset + k_offset;

                if mask[idx] == 0 {
                    div[idx] = 0.0;
                    continue;
                }

                // The centre voxel is inside the mask here, so its own term
                // needs no mask factor.
                let gx_term = gx[idx] * hx - if i > 0 { neighbour(gx, idx - 1, hx) } else { 0.0 };
                let gy_term = gy[idx] * hy - if j > 0 { neighbour(gy, idx - nx, hy) } else { 0.0 };
                let gz_term = gz[idx] * hz - if k > 0 { neighbour(gz, idx - plane, hz) } else { 0.0 };

                div[idx] = gx_term + gy_term + gz_term;
            }
        }
    }
}

557#[cfg(test)]
558mod tests {
559    use super::*;
560
561    #[test]
562    fn test_grad_constant() {
563        // Gradient of constant should be zero
564        let n = 4;
565        let x = vec![1.0; n * n * n];
566
567        let (gx, gy, gz) = fgrad(&x, n, n, n, 1.0, 1.0, 1.0);
568
569        for i in 0..n*n*n {
570            assert!(gx[i].abs() < 1e-10);
571            assert!(gy[i].abs() < 1e-10);
572            assert!(gz[i].abs() < 1e-10);
573        }
574    }
575
576    #[test]
577    fn test_div_grad_adjoint() {
578        // Check that <grad(x), h> = <x, -div(h)> (adjoint relationship)
579        let n = 4;
580        let x: Vec<f64> = (0..n*n*n).map(|i| (i as f64) * 0.1).collect();
581
582        // Create an arbitrary vector field h
583        let hx: Vec<f64> = (0..n*n*n).map(|i| ((i as f64) * 0.2).sin()).collect();
584        let hy: Vec<f64> = (0..n*n*n).map(|i| ((i as f64) * 0.3).cos()).collect();
585        let hz: Vec<f64> = (0..n*n*n).map(|i| ((i as f64) * 0.1).sin()).collect();
586
587        let (gx, gy, gz) = fgrad(&x, n, n, n, 1.0, 1.0, 1.0);
588        let div_h = bdiv(&hx, &hy, &hz, n, n, n, 1.0, 1.0, 1.0);
589
590        // <grad(x), h> should equal <x, -div(h)>
591        let lhs: f64 = gx.iter().zip(hx.iter())
592            .chain(gy.iter().zip(hy.iter()))
593            .chain(gz.iter().zip(hz.iter()))
594            .map(|(&a, &b)| a * b)
595            .sum();
596
597        // Note: bdiv returns div, not -div, so we need to negate
598        let rhs: f64 = x.iter().zip(div_h.iter())
599            .map(|(&xi, &di)| -xi * di)
600            .sum();
601
602        let rel_err = (lhs - rhs).abs() / (lhs.abs() + rhs.abs() + 1e-10);
603        assert!(rel_err < 1e-10, "Adjoint property failed: lhs={}, rhs={}, rel_err={}", lhs, rhs, rel_err);
604    }
605
606    #[test]
607    fn test_fgrad_inplace_f32() {
608        // Linear ramp in x: x[i,j,k] = i
609        let nx = 4;
610        let ny = 3;
611        let nz = 3;
612        let n = nx * ny * nz;
613        let mut x = vec![0.0f32; n];
614        for k in 0..nz {
615            for j in 0..ny {
616                for i in 0..nx {
617                    x[i + j * nx + k * nx * ny] = i as f32;
618                }
619            }
620        }
621
622        let mut gx = vec![0.0f32; n];
623        let mut gy = vec![0.0f32; n];
624        let mut gz = vec![0.0f32; n];
625        fgrad_inplace_f32(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, 1.0, 1.0, 1.0);
626
627        // Interior x-gradient should be 1.0
628        for k in 0..nz {
629            for j in 0..ny {
630                for i in 0..nx {
631                    let idx = i + j * nx + k * nx * ny;
632                    if i + 1 < nx {
633                        assert!(
634                            (gx[idx] - 1.0).abs() < 1e-6,
635                            "gx[{},{},{}] = {}, expected 1.0",
636                            i, j, k, gx[idx]
637                        );
638                    } else {
639                        // Boundary: zero BC
640                        assert!(
641                            gx[idx].abs() < 1e-6,
642                            "gx at boundary should be 0, got {}",
643                            gx[idx]
644                        );
645                    }
646                    // y and z gradients should be 0 everywhere (x doesn't vary along y or z)
647                    if j + 1 < ny {
648                        assert!(gy[idx].abs() < 1e-6, "gy should be 0, got {}", gy[idx]);
649                    }
650                    if k + 1 < nz {
651                        assert!(gz[idx].abs() < 1e-6, "gz should be 0, got {}", gz[idx]);
652                    }
653                }
654            }
655        }
656
657        // Test with non-unit voxel size
658        let vsx = 2.0f32;
659        fgrad_inplace_f32(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, vsx, 1.0, 1.0);
660        for k in 0..nz {
661            for j in 0..ny {
662                for i in 0..(nx - 1) {
663                    let idx = i + j * nx + k * nx * ny;
664                    assert!(
665                        (gx[idx] - 0.5).abs() < 1e-6,
666                        "gx with vsx=2 should be 0.5, got {}",
667                        gx[idx]
668                    );
669                }
670            }
671        }
672    }
673
674    #[test]
675    fn test_symgrad_f32() {
676        // Linear vector field: wx = y, wy = 0, wz = 0
677        // => Sxx=0, Syy=0, Szz=0, Sxy = (dwx/dy + dwy/dx)/2 = (1+0)/2 = 0.5
678        // Sxz=0, Syz=0
679        let nx = 4;
680        let ny = 4;
681        let nz = 4;
682        let n = nx * ny * nz;
683
684        let mut wx = vec![0.0f32; n];
685        let wy = vec![0.0f32; n];
686        let wz = vec![0.0f32; n];
687
688        // wx = j (y coordinate)
689        for k in 0..nz {
690            for j in 0..ny {
691                for i in 0..nx {
692                    wx[i + j * nx + k * nx * ny] = j as f32;
693                }
694            }
695        }
696
697        let mut sxx = vec![0.0f32; n];
698        let mut sxy = vec![0.0f32; n];
699        let mut sxz = vec![0.0f32; n];
700        let mut syy = vec![0.0f32; n];
701        let mut syz = vec![0.0f32; n];
702        let mut szz = vec![0.0f32; n];
703
704        symgrad_inplace_f32(
705            &mut sxx, &mut sxy, &mut sxz, &mut syy, &mut syz, &mut szz,
706            &wx, &wy, &wz,
707            nx, ny, nz, 1.0, 1.0, 1.0,
708        );
709
710        // Check interior points (away from boundaries)
711        for k in 0..(nz - 1) {
712            for j in 0..(ny - 1) {
713                for i in 0..(nx - 1) {
714                    let idx = i + j * nx + k * nx * ny;
715                    assert!(
716                        sxx[idx].abs() < 1e-6,
717                        "sxx[{},{},{}] = {}, expected 0",
718                        i, j, k, sxx[idx]
719                    );
720                    assert!(
721                        (sxy[idx] - 0.5).abs() < 1e-6,
722                        "sxy[{},{},{}] = {}, expected 0.5",
723                        i, j, k, sxy[idx]
724                    );
725                    assert!(
726                        sxz[idx].abs() < 1e-6,
727                        "sxz[{},{},{}] = {}, expected 0",
728                        i, j, k, sxz[idx]
729                    );
730                    assert!(
731                        syy[idx].abs() < 1e-6,
732                        "syy[{},{},{}] = {}, expected 0",
733                        i, j, k, syy[idx]
734                    );
735                    assert!(
736                        syz[idx].abs() < 1e-6,
737                        "syz[{},{},{}] = {}, expected 0",
738                        i, j, k, syz[idx]
739                    );
740                    assert!(
741                        szz[idx].abs() < 1e-6,
742                        "szz[{},{},{}] = {}, expected 0",
743                        i, j, k, szz[idx]
744                    );
745                }
746            }
747        }
748    }
749
750    #[test]
751    fn test_symdiv_f32() {
752        // Constant tensor field => divergence should be zero at interior
753        let nx = 4;
754        let ny = 4;
755        let nz = 4;
756        let n = nx * ny * nz;
757
758        let sxx = vec![1.0f32; n];
759        let sxy = vec![0.5f32; n];
760        let sxz = vec![0.0f32; n];
761        let syy = vec![1.0f32; n];
762        let syz = vec![0.0f32; n];
763        let szz = vec![1.0f32; n];
764
765        let mut divx = vec![0.0f32; n];
766        let mut divy = vec![0.0f32; n];
767        let mut divz = vec![0.0f32; n];
768
769        symdiv_inplace_f32(
770            &mut divx, &mut divy, &mut divz,
771            &sxx, &sxy, &sxz, &syy, &syz, &szz,
772            nx, ny, nz, 1.0, 1.0, 1.0,
773        );
774
775        // For a constant tensor field, backward differences give 0 at interior points
776        // but non-zero at boundaries (i=0, j=0, k=0) due to zero BC
777        for k in 1..nz {
778            for j in 1..ny {
779                for i in 1..nx {
780                    let idx = i + j * nx + k * nx * ny;
781                    assert!(
782                        divx[idx].abs() < 1e-6,
783                        "divx[{},{},{}] = {}, expected 0",
784                        i, j, k, divx[idx]
785                    );
786                    assert!(
787                        divy[idx].abs() < 1e-6,
788                        "divy[{},{},{}] = {}, expected 0",
789                        i, j, k, divy[idx]
790                    );
791                    assert!(
792                        divz[idx].abs() < 1e-6,
793                        "divz[{},{},{}] = {}, expected 0",
794                        i, j, k, divz[idx]
795                    );
796                }
797            }
798        }
799
800        // At i=0, j=0, k=0 the backward BC (0.0 for i-1) means
801        // divx[0,0,0] = (sxx[0]-0)*1 + (sxy[0]-0)*1 + (sxz[0]-0)*1 = 1 + 0.5 + 0 = 1.5
802        let idx_origin = 0;
803        assert!(
804            (divx[idx_origin] - 1.5).abs() < 1e-5,
805            "divx at origin expected 1.5, got {}",
806            divx[idx_origin]
807        );
808    }
809
810    #[test]
811    fn test_bdiv_masked_f32() {
812        let nx = 4;
813        let ny = 4;
814        let nz = 4;
815        let n = nx * ny * nz;
816
817        // Create a mask that is 1 for the inner 2x2x2 cube
818        let mut mask = vec![0u8; n];
819        for k in 1..3 {
820            for j in 1..3 {
821                for i in 1..3 {
822                    mask[i + j * nx + k * nx * ny] = 1;
823                }
824            }
825        }
826
827        // Constant gradient field
828        let gx = vec![1.0f32; n];
829        let gy = vec![1.0f32; n];
830        let gz = vec![1.0f32; n];
831
832        let mut div = vec![0.0f32; n];
833        bdiv_masked_inplace_f32(&mut div, &gx, &gy, &gz, &mask, nx, ny, nz, 1.0, 1.0, 1.0);
834
835        // Outside the mask, divergence should be zero
836        for k in 0..nz {
837            for j in 0..ny {
838                for i in 0..nx {
839                    let idx = i + j * nx + k * nx * ny;
840                    if mask[idx] == 0 {
841                        assert!(
842                            div[idx].abs() < 1e-6,
843                            "div outside mask at [{},{},{}] = {}, expected 0",
844                            i, j, k, div[idx]
845                        );
846                    }
847                }
848            }
849        }
850
851        // Inside the mask with a neighbor also in mask: constant field gives 0 divergence
852        // At (2,2,2): mask[2,2,2]=1, mask[1,2,2]=1 => gx_term = 1*1 - 1*1 = 0, same for y,z
853        let idx_222 = 2 + 2 * nx + 2 * nx * ny;
854        assert!(
855            div[idx_222].abs() < 1e-6,
856            "div at inner point (2,2,2) should be 0 for constant field, got {}",
857            div[idx_222]
858        );
859
860        // At (1,1,1): mask[1,1,1]=1, mask[0,1,1]=0
861        // gx_term = 1*1*1 - 0 = 1 (since mask at i-1 is 0)
862        // gy_term = 1*1*1 - 0 = 1
863        // gz_term = 1*1*1 - 0 = 1
864        // total = 3
865        let idx_111 = 1 + 1 * nx + 1 * nx * ny;
866        assert!(
867            (div[idx_111] - 3.0).abs() < 1e-5,
868            "div at boundary-of-mask (1,1,1) expected 3.0, got {}",
869            div[idx_111]
870        );
871    }
872
873    #[test]
874    fn test_fgrad_bdiv_adjoint_f64() {
875        // More thorough adjoint test with non-uniform voxels and larger grid
876        let nx = 6;
877        let ny = 7;
878        let nz = 5;
879        let n = nx * ny * nz;
880        let vsx = 1.5;
881        let vsy = 0.8;
882        let vsz = 2.0;
883
884        // Create non-trivial scalar field
885        let x: Vec<f64> = (0..n)
886            .map(|idx| {
887                let i = idx % nx;
888                let j = (idx / nx) % ny;
889                let k = idx / (nx * ny);
890                (i as f64 * 0.3).sin() + (j as f64 * 0.7).cos() + (k as f64 * 0.2)
891            })
892            .collect();
893
894        // Create non-trivial vector field
895        let hx: Vec<f64> = (0..n)
896            .map(|idx| {
897                let i = idx % nx;
898                let j = (idx / nx) % ny;
899                ((i as f64 + 1.0) * (j as f64 + 1.0)).sqrt()
900            })
901            .collect();
902        let hy: Vec<f64> = (0..n)
903            .map(|idx| {
904                let k = idx / (nx * ny);
905                (k as f64 * 0.5 + 0.1).sin()
906            })
907            .collect();
908        let hz: Vec<f64> = (0..n)
909            .map(|idx| {
910                let i = idx % nx;
911                let k = idx / (nx * ny);
912                (i as f64 * 0.4 - k as f64 * 0.2).cos()
913            })
914            .collect();
915
916        let (gx, gy, gz) = fgrad(&x, nx, ny, nz, vsx, vsy, vsz);
917        let div_h = bdiv(&hx, &hy, &hz, nx, ny, nz, vsx, vsy, vsz);
918
919        // <grad(x), h> = sum(gx*hx + gy*hy + gz*hz)
920        let lhs: f64 = gx.iter().zip(hx.iter())
921            .chain(gy.iter().zip(hy.iter()))
922            .chain(gz.iter().zip(hz.iter()))
923            .map(|(&a, &b)| a * b)
924            .sum();
925
926        // <x, -div(h)> = -sum(x * div_h)
927        let rhs: f64 = x.iter().zip(div_h.iter())
928            .map(|(&xi, &di)| -xi * di)
929            .sum();
930
931        let rel_err = (lhs - rhs).abs() / (lhs.abs().max(rhs.abs()) + 1e-15);
932        assert!(
933            rel_err < 1e-10,
934            "Adjoint property failed for non-uniform voxels: lhs={}, rhs={}, rel_err={}",
935            lhs, rhs, rel_err
936        );
937    }
938
939    // ====================================================================
    // Gradient, divergence and symmetric-gradient tests on larger arrays (mostly f32; some f64)
941    // ====================================================================
942
943    #[test]
944    fn test_fgrad_f32_allocating() {
945        // Test the allocating version
946        let nx = 8;
947        let ny = 8;
948        let nz = 8;
949        let n = nx * ny * nz;
950
951        let mut x = vec![0.0f32; n];
952        // Linear ramp in x
953        for k in 0..nz {
954            for j in 0..ny {
955                for i in 0..nx {
956                    x[i + j * nx + k * nx * ny] = i as f32;
957                }
958            }
959        }
960
961        let (gx, gy, gz) = fgrad_f32(&x, nx, ny, nz, 1.0, 1.0, 1.0);
962
963        assert_eq!(gx.len(), n);
964        assert_eq!(gy.len(), n);
965        assert_eq!(gz.len(), n);
966
967        // Interior x-gradient should be 1.0
968        for k in 0..nz {
969            for j in 0..ny {
970                for i in 0..(nx - 1) {
971                    let idx = i + j * nx + k * nx * ny;
972                    assert!(
973                        (gx[idx] - 1.0).abs() < 1e-5,
974                        "fgrad_f32 gx[{},{},{}] = {}, expected 1.0",
975                        i, j, k, gx[idx]
976                    );
977                }
978            }
979        }
980    }
981
982    #[test]
983    fn test_fgrad_inplace_f32_larger() {
984        // Test with larger 8x8x8 volume (512 elements)
985        let nx = 8;
986        let ny = 8;
987        let nz = 8;
988        let n = nx * ny * nz;
989
990        let mut x = vec![0.0f32; n];
991        // Linear ramp in y
992        for k in 0..nz {
993            for j in 0..ny {
994                for i in 0..nx {
995                    x[i + j * nx + k * nx * ny] = j as f32;
996                }
997            }
998        }
999
1000        let mut gx = vec![0.0f32; n];
1001        let mut gy = vec![0.0f32; n];
1002        let mut gz = vec![0.0f32; n];
1003        fgrad_inplace_f32(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, 1.0, 1.0, 1.0);
1004
1005        // Interior y-gradient should be 1.0
1006        for k in 0..nz {
1007            for j in 0..(ny - 1) {
1008                for i in 0..nx {
1009                    let idx = i + j * nx + k * nx * ny;
1010                    assert!(
1011                        (gy[idx] - 1.0).abs() < 1e-5,
1012                        "fgrad_inplace_f32 gy at interior should be 1.0, got {}",
1013                        gy[idx]
1014                    );
1015                }
1016            }
1017        }
1018
1019        // gx should be 0 (no variation in x for y-ramp)
1020        for k in 0..nz {
1021            for j in 0..ny {
1022                for i in 0..(nx - 1) {
1023                    let idx = i + j * nx + k * nx * ny;
1024                    assert!(
1025                        gx[idx].abs() < 1e-5,
1026                        "gx should be 0 for y-ramp, got {}",
1027                        gx[idx]
1028                    );
1029                }
1030            }
1031        }
1032    }
1033
1034    #[test]
1035    fn test_fgrad_inplace_f32_z_ramp() {
1036        let nx = 8;
1037        let ny = 8;
1038        let nz = 8;
1039        let n = nx * ny * nz;
1040
1041        let mut x = vec![0.0f32; n];
1042        // Linear ramp in z
1043        for k in 0..nz {
1044            for j in 0..ny {
1045                for i in 0..nx {
1046                    x[i + j * nx + k * nx * ny] = k as f32;
1047                }
1048            }
1049        }
1050
1051        let mut gx = vec![0.0f32; n];
1052        let mut gy = vec![0.0f32; n];
1053        let mut gz = vec![0.0f32; n];
1054        fgrad_inplace_f32(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, 1.0, 1.0, 1.0);
1055
1056        // Interior z-gradient should be 1.0
1057        for k in 0..(nz - 1) {
1058            for j in 0..ny {
1059                for i in 0..nx {
1060                    let idx = i + j * nx + k * nx * ny;
1061                    assert!(
1062                        (gz[idx] - 1.0).abs() < 1e-5,
1063                        "gz at interior should be 1.0, got {}",
1064                        gz[idx]
1065                    );
1066                }
1067            }
1068        }
1069    }
1070
1071    #[test]
1072    fn test_bdiv_inplace_f32_larger() {
1073        // Test bdiv on 8x8x8 volume
1074        let nx = 8;
1075        let ny = 8;
1076        let nz = 8;
1077        let n = nx * ny * nz;
1078
1079        // Constant gradient field
1080        let gx = vec![1.0f32; n];
1081        let gy = vec![1.0f32; n];
1082        let gz = vec![1.0f32; n];
1083
1084        let mut div = vec![0.0f32; n];
1085        bdiv_inplace_f32(&mut div, &gx, &gy, &gz, nx, ny, nz, 1.0, 1.0, 1.0);
1086
1087        // For constant gradient, backward differences give 0 at interior (i>0,j>0,k>0)
1088        for k in 1..nz {
1089            for j in 1..ny {
1090                for i in 1..nx {
1091                    let idx = i + j * nx + k * nx * ny;
1092                    assert!(
1093                        div[idx].abs() < 1e-5,
1094                        "bdiv interior should be 0 for constant gradient, got {}",
1095                        div[idx]
1096                    );
1097                }
1098            }
1099        }
1100
1101        // At boundary (i=0 or j=0 or k=0), should be non-zero
1102        assert!(
1103            div[0].abs() > 0.1,
1104            "bdiv at origin should be non-zero for constant gradient"
1105        );
1106    }
1107
1108    #[test]
1109    fn test_fgrad_masked_inplace_f32_larger() {
1110        let nx = 8;
1111        let ny = 8;
1112        let nz = 8;
1113        let n = nx * ny * nz;
1114
1115        // Create a mask covering inner 6x6x6
1116        let mut mask = vec![0u8; n];
1117        for k in 1..7 {
1118            for j in 1..7 {
1119                for i in 1..7 {
1120                    mask[i + j * nx + k * nx * ny] = 1;
1121                }
1122            }
1123        }
1124
1125        // Linear ramp in x
1126        let mut x = vec![0.0f32; n];
1127        for k in 0..nz {
1128            for j in 0..ny {
1129                for i in 0..nx {
1130                    x[i + j * nx + k * nx * ny] = i as f32;
1131                }
1132            }
1133        }
1134
1135        let mut gx = vec![0.0f32; n];
1136        let mut gy = vec![0.0f32; n];
1137        let mut gz = vec![0.0f32; n];
1138        fgrad_masked_inplace_f32(
1139            &mut gx, &mut gy, &mut gz, &x, &mask,
1140            nx, ny, nz, 1.0, 1.0, 1.0,
1141        );
1142
1143        // Outside mask should be zero
1144        for i in 0..n {
1145            if mask[i] == 0 {
1146                assert_eq!(gx[i], 0.0);
1147                assert_eq!(gy[i], 0.0);
1148                assert_eq!(gz[i], 0.0);
1149            }
1150        }
1151
1152        // Inside mask, x-gradient should be ~1.0 (for interior)
1153        for k in 1..7 {
1154            for j in 1..7 {
1155                for i in 1..6 { // not at boundary of mask
1156                    let idx = i + j * nx + k * nx * ny;
1157                    if mask[idx] != 0 {
1158                        assert!(
1159                            (gx[idx] - 1.0).abs() < 1e-5,
1160                            "masked gx inside should be 1.0, got {} at ({},{},{})",
1161                            gx[idx], i, j, k
1162                        );
1163                    }
1164                }
1165            }
1166        }
1167    }
1168
1169    #[test]
1170    fn test_bdiv_masked_inplace_f32_larger() {
1171        let nx = 8;
1172        let ny = 8;
1173        let nz = 8;
1174        let n = nx * ny * nz;
1175
1176        let mut mask = vec![0u8; n];
1177        for k in 1..7 {
1178            for j in 1..7 {
1179                for i in 1..7 {
1180                    mask[i + j * nx + k * nx * ny] = 1;
1181                }
1182            }
1183        }
1184
1185        let gx = vec![1.0f32; n];
1186        let gy = vec![1.0f32; n];
1187        let gz = vec![1.0f32; n];
1188
1189        let mut div = vec![0.0f32; n];
1190        bdiv_masked_inplace_f32(&mut div, &gx, &gy, &gz, &mask, nx, ny, nz, 1.0, 1.0, 1.0);
1191
1192        // Outside mask should be zero
1193        for i in 0..n {
1194            if mask[i] == 0 {
1195                assert_eq!(div[i], 0.0, "Outside mask div should be zero");
1196            }
1197        }
1198    }
1199
1200    #[test]
1201    fn test_fgrad_inplace_f64() {
1202        // Test the f64 in-place gradient
1203        let nx = 8;
1204        let ny = 8;
1205        let nz = 8;
1206        let n = nx * ny * nz;
1207
1208        let mut x = vec![0.0f64; n];
1209        for k in 0..nz {
1210            for j in 0..ny {
1211                for i in 0..nx {
1212                    x[i + j * nx + k * nx * ny] = i as f64;
1213                }
1214            }
1215        }
1216
1217        let mut gx = vec![0.0f64; n];
1218        let mut gy = vec![0.0f64; n];
1219        let mut gz = vec![0.0f64; n];
1220        fgrad_inplace(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, 1.0, 1.0, 1.0);
1221
1222        // Interior gradient in x should be 1
1223        for k in 0..nz {
1224            for j in 0..ny {
1225                for i in 0..(nx - 1) {
1226                    let idx = i + j * nx + k * nx * ny;
1227                    assert!(
1228                        (gx[idx] - 1.0).abs() < 1e-10,
1229                        "fgrad_inplace gx should be 1.0 at interior, got {}",
1230                        gx[idx]
1231                    );
1232                }
1233            }
1234        }
1235    }
1236
1237    #[test]
1238    fn test_bdiv_inplace_f64() {
1239        let nx = 8;
1240        let ny = 8;
1241        let nz = 8;
1242        let n = nx * ny * nz;
1243
1244        let gx = vec![1.0f64; n];
1245        let gy = vec![1.0f64; n];
1246        let gz = vec![1.0f64; n];
1247
1248        let mut div = vec![0.0f64; n];
1249        bdiv_inplace(&mut div, &gx, &gy, &gz, nx, ny, nz, 1.0, 1.0, 1.0);
1250
1251        // Interior divergence of constant field should be 0
1252        for k in 1..nz {
1253            for j in 1..ny {
1254                for i in 1..nx {
1255                    let idx = i + j * nx + k * nx * ny;
1256                    assert!(
1257                        div[idx].abs() < 1e-10,
1258                        "bdiv_inplace interior should be 0, got {}",
1259                        div[idx]
1260                    );
1261                }
1262            }
1263        }
1264    }
1265
1266    #[test]
1267    fn test_grad_magnitude_squared() {
1268        let n = 64;
1269        let gx: Vec<f64> = (0..n).map(|i| i as f64 * 0.1).collect();
1270        let gy: Vec<f64> = (0..n).map(|i| i as f64 * 0.2).collect();
1271        let gz: Vec<f64> = (0..n).map(|i| i as f64 * 0.3).collect();
1272
1273        let mag_sq = grad_magnitude_squared(&gx, &gy, &gz);
1274
1275        assert_eq!(mag_sq.len(), n);
1276        for i in 0..n {
1277            let expected = gx[i] * gx[i] + gy[i] * gy[i] + gz[i] * gz[i];
1278            assert!(
1279                (mag_sq[i] - expected).abs() < 1e-10,
1280                "grad_magnitude_squared[{}] = {}, expected {}",
1281                i, mag_sq[i], expected
1282            );
1283        }
1284    }
1285
1286    #[test]
1287    fn test_fgrad_f32_nonunit_voxels() {
1288        let nx = 8;
1289        let ny = 8;
1290        let nz = 8;
1291        let n = nx * ny * nz;
1292        let vsx = 2.0f32;
1293        let vsy = 0.5f32;
1294        let vsz = 3.0f32;
1295
1296        // Linear ramp in each direction simultaneously
1297        let mut x = vec![0.0f32; n];
1298        for k in 0..nz {
1299            for j in 0..ny {
1300                for i in 0..nx {
1301                    x[i + j * nx + k * nx * ny] = i as f32 + j as f32 + k as f32;
1302                }
1303            }
1304        }
1305
1306        let mut gx = vec![0.0f32; n];
1307        let mut gy = vec![0.0f32; n];
1308        let mut gz = vec![0.0f32; n];
1309        fgrad_inplace_f32(&mut gx, &mut gy, &mut gz, &x, nx, ny, nz, vsx, vsy, vsz);
1310
1311        // Interior x-gradient should be 1/vsx = 0.5
1312        for k in 0..nz {
1313            for j in 0..ny {
1314                for i in 0..(nx - 1) {
1315                    let idx = i + j * nx + k * nx * ny;
1316                    assert!(
1317                        (gx[idx] - 1.0 / vsx).abs() < 1e-5,
1318                        "gx should be {}, got {}",
1319                        1.0 / vsx, gx[idx]
1320                    );
1321                }
1322            }
1323        }
1324
1325        // Interior y-gradient should be 1/vsy = 2.0
1326        for k in 0..nz {
1327            for j in 0..(ny - 1) {
1328                for i in 0..nx {
1329                    let idx = i + j * nx + k * nx * ny;
1330                    assert!(
1331                        (gy[idx] - 1.0 / vsy).abs() < 1e-5,
1332                        "gy should be {}, got {}",
1333                        1.0 / vsy, gy[idx]
1334                    );
1335                }
1336            }
1337        }
1338
1339        // Interior z-gradient should be 1/vsz
1340        for k in 0..(nz - 1) {
1341            for j in 0..ny {
1342                for i in 0..nx {
1343                    let idx = i + j * nx + k * nx * ny;
1344                    assert!(
1345                        (gz[idx] - 1.0 / vsz).abs() < 1e-5,
1346                        "gz should be {}, got {}",
1347                        1.0 / vsz, gz[idx]
1348                    );
1349                }
1350            }
1351        }
1352    }
1353
1354    #[test]
1355    fn test_symgrad_symdiv_basics_f32() {
1356        // Test basic properties of symgrad and symdiv
1357        let nx = 4;
1358        let ny = 4;
1359        let nz = 4;
1360        let n = nx * ny * nz;
1361
1362        // Zero vector field => zero symgrad
1363        let wx = vec![0.0f32; n];
1364        let wy = vec![0.0f32; n];
1365        let wz = vec![0.0f32; n];
1366
1367        let mut sxx = vec![0.0f32; n];
1368        let mut sxy = vec![0.0f32; n];
1369        let mut sxz = vec![0.0f32; n];
1370        let mut syy = vec![0.0f32; n];
1371        let mut syz = vec![0.0f32; n];
1372        let mut szz = vec![0.0f32; n];
1373
1374        symgrad_inplace_f32(
1375            &mut sxx, &mut sxy, &mut sxz, &mut syy, &mut syz, &mut szz,
1376            &wx, &wy, &wz, nx, ny, nz, 1.0, 1.0, 1.0,
1377        );
1378
1379        for i in 0..n {
1380            assert_eq!(sxx[i], 0.0);
1381            assert_eq!(sxy[i], 0.0);
1382            assert_eq!(sxz[i], 0.0);
1383            assert_eq!(syy[i], 0.0);
1384            assert_eq!(syz[i], 0.0);
1385            assert_eq!(szz[i], 0.0);
1386        }
1387
1388        // Constant vector field => zero symgrad
1389        let wx = vec![5.0f32; n];
1390        let wy = vec![3.0f32; n];
1391        let wz = vec![7.0f32; n];
1392
1393        symgrad_inplace_f32(
1394            &mut sxx, &mut sxy, &mut sxz, &mut syy, &mut syz, &mut szz,
1395            &wx, &wy, &wz, nx, ny, nz, 1.0, 1.0, 1.0,
1396        );
1397
1398        // Interior points (not at boundary) should have zero symgrad
1399        for k in 0..(nz - 1) {
1400            for j in 0..(ny - 1) {
1401                for i in 0..(nx - 1) {
1402                    let idx = i + j * nx + k * nx * ny;
1403                    assert!(
1404                        sxx[idx].abs() < 1e-6,
1405                        "Constant field: sxx should be 0 at interior"
1406                    );
1407                    assert!(
1408                        sxy[idx].abs() < 1e-6,
1409                        "Constant field: sxy should be 0 at interior"
1410                    );
1411                }
1412            }
1413        }
1414
1415        // Test symdiv of zero tensor => zero
1416        let qxx = vec![0.0f32; n];
1417        let qxy = vec![0.0f32; n];
1418        let qxz = vec![0.0f32; n];
1419        let qyy = vec![0.0f32; n];
1420        let qyz = vec![0.0f32; n];
1421        let qzz = vec![0.0f32; n];
1422
1423        let mut divx = vec![0.0f32; n];
1424        let mut divy = vec![0.0f32; n];
1425        let mut divz = vec![0.0f32; n];
1426        symdiv_inplace_f32(
1427            &mut divx, &mut divy, &mut divz,
1428            &qxx, &qxy, &qxz, &qyy, &qyz, &qzz,
1429            nx, ny, nz, 1.0, 1.0, 1.0,
1430        );
1431
1432        for i in 0..n {
1433            assert_eq!(divx[i], 0.0);
1434            assert_eq!(divy[i], 0.0);
1435            assert_eq!(divz[i], 0.0);
1436        }
1437    }
1438
1439    #[test]
1440    fn test_symgrad_nonuniform_voxels_f32() {
1441        // Test symgrad with non-unit voxel sizes
1442        let nx = 6;
1443        let ny = 6;
1444        let nz = 6;
1445        let n = nx * ny * nz;
1446        let vsx = 2.0f32;
1447        let vsy = 0.5f32;
1448        let vsz = 1.5f32;
1449
1450        // wx = x coordinate (linear in i)
1451        let mut wx = vec![0.0f32; n];
1452        let wy = vec![0.0f32; n];
1453        let wz = vec![0.0f32; n];
1454        for k in 0..nz {
1455            for j in 0..ny {
1456                for i in 0..nx {
1457                    wx[i + j * nx + k * nx * ny] = i as f32;
1458                }
1459            }
1460        }
1461
1462        let mut sxx = vec![0.0f32; n];
1463        let mut sxy = vec![0.0f32; n];
1464        let mut sxz = vec![0.0f32; n];
1465        let mut syy = vec![0.0f32; n];
1466        let mut syz = vec![0.0f32; n];
1467        let mut szz = vec![0.0f32; n];
1468
1469        symgrad_inplace_f32(
1470            &mut sxx, &mut sxy, &mut sxz, &mut syy, &mut syz, &mut szz,
1471            &wx, &wy, &wz, nx, ny, nz, vsx, vsy, vsz,
1472        );
1473
1474        // sxx should be dwx/dx = 1/vsx at interior
1475        for k in 0..(nz - 1) {
1476            for j in 0..(ny - 1) {
1477                for i in 0..(nx - 1) {
1478                    let idx = i + j * nx + k * nx * ny;
1479                    assert!(
1480                        (sxx[idx] - 1.0 / vsx).abs() < 1e-5,
1481                        "sxx should be {}, got {}",
1482                        1.0 / vsx, sxx[idx]
1483                    );
1484                }
1485            }
1486        }
1487    }
1488}