// qsm_core/utils/simd_ops.rs
1//! SIMD-accelerated operations for QSM processing
2//!
3//! This module provides vectorized versions of common operations used in
4//! iterative algorithms like MEDI, TV, and TGV. When the `simd` feature is
5//! enabled, these use 128-bit SIMD (f32x4) which is compatible with both
6//! native SSE/NEON and WASM SIMD.
7//!
8//! All operations have scalar fallbacks when SIMD is disabled.
9
10#[cfg(feature = "simd")]
11use wide::f32x4;
12
/// SIMD lane width (4 for f32x4)
#[cfg(feature = "simd")]
pub const SIMD_WIDTH: usize = 4;

/// SIMD lane width when SIMD is disabled: scalar code processes one element
/// per "lane", so the chunk/remainder arithmetic below degenerates gracefully.
#[cfg(not(feature = "simd"))]
pub const SIMD_WIDTH: usize = 1;
19
20// ============================================================================
21// Dot Product Operations
22// ============================================================================
23
/// Compute dot product: sum(a[i] * b[i])
#[cfg(feature = "simd")]
#[inline]
pub fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    let mut lanes_a = a.chunks_exact(SIMD_WIDTH);
    let mut lanes_b = b.chunks_exact(SIMD_WIDTH);

    // Keep four running partial sums in a single SIMD register.
    let mut acc = f32x4::ZERO;
    for (ca, cb) in (&mut lanes_a).zip(&mut lanes_b) {
        acc += f32x4::from(ca) * f32x4::from(cb);
    }

    // Collapse the register to a scalar, then fold in the tail elements
    // that did not fill a full lane group.
    let mut total = acc.reduce_add();
    for (&x, &y) in lanes_a.remainder().iter().zip(lanes_b.remainder()) {
        total += x * y;
    }
    total
}
54
/// Compute dot product: sum(a[i] * b[i]) — scalar fallback.
#[cfg(not(feature = "simd"))]
#[inline]
pub fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Sequential fold of the pairwise products; same accumulation order
    // as the map+sum form it replaces.
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y)
}
61
/// Compute squared norm: sum(a[i]^2)
#[cfg(feature = "simd")]
#[inline]
pub fn norm_squared_f32(a: &[f32]) -> f32 {
    let mut lanes = a.chunks_exact(SIMD_WIDTH);

    // Vectorized sum of squares over the full lane groups.
    let mut acc = f32x4::ZERO;
    for c in &mut lanes {
        let v = f32x4::from(c);
        acc += v * v;
    }

    // Horizontal sum, then the scalar tail for leftover elements.
    let mut total = acc.reduce_add();
    for &x in lanes.remainder() {
        total += x * x;
    }
    total
}
87
/// Compute squared norm: sum(a[i]^2) — scalar fallback.
#[cfg(not(feature = "simd"))]
#[inline]
pub fn norm_squared_f32(a: &[f32]) -> f32 {
    // Sequential fold of x*x; equivalent to the map+sum form it replaces.
    a.iter().fold(0.0f32, |acc, &x| acc + x * x)
}
93
94// ============================================================================
95// Fused Multiply-Add Operations
96// ============================================================================
97
/// Compute a[i] = a[i] + alpha * b[i] (axpy operation)
#[cfg(feature = "simd")]
#[inline]
pub fn axpy_f32(a: &mut [f32], alpha: f32, b: &[f32]) {
    debug_assert_eq!(a.len(), b.len());

    let valpha = f32x4::splat(alpha);

    let mut lanes_a = a.chunks_exact_mut(SIMD_WIDTH);
    let mut lanes_b = b.chunks_exact(SIMD_WIDTH);
    for (ca, cb) in (&mut lanes_a).zip(&mut lanes_b) {
        let updated = f32x4::from(&ca[..]) + valpha * f32x4::from(cb);
        ca.copy_from_slice(updated.as_array_ref());
    }

    // Scalar tail for elements past the last full lane group.
    for (x, &y) in lanes_a.into_remainder().iter_mut().zip(lanes_b.remainder()) {
        *x += alpha * y;
    }
}
122
/// Compute a[i] = a[i] + alpha * b[i] (axpy operation) — scalar fallback.
#[cfg(not(feature = "simd"))]
#[inline]
pub fn axpy_f32(a: &mut [f32], alpha: f32, b: &[f32]) {
    debug_assert_eq!(a.len(), b.len());
    // Zipped iteration avoids the per-element bounds checks of indexed
    // access and lets the compiler auto-vectorize the loop.
    for (ai, &bi) in a.iter_mut().zip(b) {
        *ai += alpha * bi;
    }
}
131
/// Compute a[i] = b[i] + beta * a[i] (used in CG for p update)
#[cfg(feature = "simd")]
#[inline]
pub fn xpby_f32(a: &mut [f32], b: &[f32], beta: f32) {
    debug_assert_eq!(a.len(), b.len());

    let vbeta = f32x4::splat(beta);

    let mut lanes_a = a.chunks_exact_mut(SIMD_WIDTH);
    let mut lanes_b = b.chunks_exact(SIMD_WIDTH);
    for (ca, cb) in (&mut lanes_a).zip(&mut lanes_b) {
        let updated = f32x4::from(cb) + vbeta * f32x4::from(&ca[..]);
        ca.copy_from_slice(updated.as_array_ref());
    }

    // Scalar tail for elements past the last full lane group.
    for (x, &y) in lanes_a.into_remainder().iter_mut().zip(lanes_b.remainder()) {
        *x = y + beta * *x;
    }
}
156
/// Compute a[i] = b[i] + beta * a[i] (used in CG for p update) — scalar fallback.
#[cfg(not(feature = "simd"))]
#[inline]
pub fn xpby_f32(a: &mut [f32], b: &[f32], beta: f32) {
    debug_assert_eq!(a.len(), b.len());
    // Zipped iteration avoids the per-element bounds checks of indexed
    // access and lets the compiler auto-vectorize the loop.
    for (ai, &bi) in a.iter_mut().zip(b) {
        *ai = bi + beta * *ai;
    }
}
165
166// ============================================================================
167// Element-wise Operations for MEDI
168// ============================================================================
169
/// Apply per-direction gradient weights: out[i] = mx[i] * p[i] * mx[i] * gx[i]
/// This is the core operation in MEDI's regularization term
///
/// For each axis d in {x, y, z}: `out_d[i] = m_d[i]^2 * p[i] * g_d[i]`, where
/// `m_*` are per-direction weights, `p` is the shared per-voxel weight, and
/// `g_*` are gradient components. All ten slices must have equal length
/// (checked in debug builds only).
#[cfg(feature = "simd")]
#[inline]
pub fn apply_gradient_weights_f32(
    out_x: &mut [f32], out_y: &mut [f32], out_z: &mut [f32],
    mx: &[f32], my: &[f32], mz: &[f32],
    p: &[f32],
    gx: &[f32], gy: &[f32], gz: &[f32],
) {
    let n = out_x.len();
    debug_assert!(out_y.len() == n && out_z.len() == n);
    debug_assert!(mx.len() == n && my.len() == n && mz.len() == n);
    debug_assert!(p.len() == n && gx.len() == n && gy.len() == n && gz.len() == n);

    let chunks = n / SIMD_WIDTH;
    let remainder = n % SIMD_WIDTH;

    // Vectorized main loop: process SIMD_WIDTH voxels for all three axes
    // per iteration.
    for i in 0..chunks {
        let idx = i * SIMD_WIDTH;

        let vmx = f32x4::from(&mx[idx..idx + SIMD_WIDTH]);
        let vmy = f32x4::from(&my[idx..idx + SIMD_WIDTH]);
        let vmz = f32x4::from(&mz[idx..idx + SIMD_WIDTH]);
        let vp = f32x4::from(&p[idx..idx + SIMD_WIDTH]);
        let vgx = f32x4::from(&gx[idx..idx + SIMD_WIDTH]);
        let vgy = f32x4::from(&gy[idx..idx + SIMD_WIDTH]);
        let vgz = f32x4::from(&gz[idx..idx + SIMD_WIDTH]);

        // out = m * p * m * g = m^2 * p * g
        let rx = vmx * vp * vmx * vgx;
        let ry = vmy * vp * vmy * vgy;
        let rz = vmz * vp * vmz * vgz;

        out_x[idx..idx + SIMD_WIDTH].copy_from_slice(rx.as_array_ref());
        out_y[idx..idx + SIMD_WIDTH].copy_from_slice(ry.as_array_ref());
        out_z[idx..idx + SIMD_WIDTH].copy_from_slice(rz.as_array_ref());
    }

    // Scalar tail for elements past the last full lane group, using the
    // same multiplication order as the vector path.
    let start = chunks * SIMD_WIDTH;
    for i in 0..remainder {
        let idx = start + i;
        out_x[idx] = mx[idx] * p[idx] * mx[idx] * gx[idx];
        out_y[idx] = my[idx] * p[idx] * my[idx] * gy[idx];
        out_z[idx] = mz[idx] * p[idx] * mz[idx] * gz[idx];
    }
}
217
/// Apply per-direction gradient weights — scalar fallback.
///
/// For each axis d in {x, y, z}: `out_d[i] = m_d[i]^2 * p[i] * g_d[i]`.
#[cfg(not(feature = "simd"))]
#[inline]
pub fn apply_gradient_weights_f32(
    out_x: &mut [f32], out_y: &mut [f32], out_z: &mut [f32],
    mx: &[f32], my: &[f32], mz: &[f32],
    p: &[f32],
    gx: &[f32], gy: &[f32], gz: &[f32],
) {
    let n = out_x.len();
    // Length checks mirror the SIMD variant so both builds fail fast (in
    // debug) on mismatched slice lengths; previously only the SIMD path
    // had these assertions.
    debug_assert!(out_y.len() == n && out_z.len() == n);
    debug_assert!(mx.len() == n && my.len() == n && mz.len() == n);
    debug_assert!(p.len() == n && gx.len() == n && gy.len() == n && gz.len() == n);
    for i in 0..n {
        out_x[i] = mx[i] * p[i] * mx[i] * gx[i];
        out_y[i] = my[i] * p[i] * my[i] * gy[i];
        out_z[i] = mz[i] * p[i] * mz[i] * gz[i];
    }
}
233
/// Compute P = 1 / sqrt(ux^2 + uy^2 + uz^2 + beta)
/// where ux = mx * gx, uy = my * gy, uz = mz * gz
///
/// `beta` is added inside the square root, so the result stays finite even
/// where all weighted gradient components are zero (provided beta > 0).
#[cfg(feature = "simd")]
#[inline]
pub fn compute_p_weights_f32(
    p: &mut [f32],
    mx: &[f32], my: &[f32], mz: &[f32],
    gx: &[f32], gy: &[f32], gz: &[f32],
    beta: f32,
) {
    let n = p.len();
    let chunks = n / SIMD_WIDTH;
    let remainder = n % SIMD_WIDTH;

    let vbeta = f32x4::splat(beta);

    // Vectorized main loop: SIMD_WIDTH voxels per iteration.
    for i in 0..chunks {
        let idx = i * SIMD_WIDTH;

        let vmx = f32x4::from(&mx[idx..idx + SIMD_WIDTH]);
        let vmy = f32x4::from(&my[idx..idx + SIMD_WIDTH]);
        let vmz = f32x4::from(&mz[idx..idx + SIMD_WIDTH]);
        let vgx = f32x4::from(&gx[idx..idx + SIMD_WIDTH]);
        let vgy = f32x4::from(&gy[idx..idx + SIMD_WIDTH]);
        let vgz = f32x4::from(&gz[idx..idx + SIMD_WIDTH]);

        // Weighted gradient components per axis.
        let ux = vmx * vgx;
        let uy = vmy * vgy;
        let uz = vmz * vgz;

        // NOTE: sqrt().recip() here vs. 1.0 / sqrt() in the scalar tail —
        // results may differ in the last ulp between the two paths.
        let norm_sq = ux * ux + uy * uy + uz * uz + vbeta;
        let result = norm_sq.sqrt().recip();

        p[idx..idx + SIMD_WIDTH].copy_from_slice(result.as_array_ref());
    }

    // Scalar tail for elements past the last full lane group.
    let start = chunks * SIMD_WIDTH;
    for i in 0..remainder {
        let idx = start + i;
        let ux = mx[idx] * gx[idx];
        let uy = my[idx] * gy[idx];
        let uz = mz[idx] * gz[idx];
        p[idx] = 1.0 / (ux * ux + uy * uy + uz * uz + beta).sqrt();
    }
}
279
/// Compute P = 1 / sqrt(ux^2 + uy^2 + uz^2 + beta) — scalar fallback.
/// where ux = mx * gx, uy = my * gy, uz = mz * gz
#[cfg(not(feature = "simd"))]
#[inline]
pub fn compute_p_weights_f32(
    p: &mut [f32],
    mx: &[f32], my: &[f32], mz: &[f32],
    gx: &[f32], gy: &[f32], gz: &[f32],
    beta: f32,
) {
    let n = p.len();
    // Length checks added for parity with the debug assertions in the
    // other kernels of this module.
    debug_assert!(mx.len() == n && my.len() == n && mz.len() == n);
    debug_assert!(gx.len() == n && gy.len() == n && gz.len() == n);
    for i in 0..n {
        let ux = mx[i] * gx[i];
        let uy = my[i] * gy[i];
        let uz = mz[i] * gz[i];
        p[i] = 1.0 / (ux * ux + uy * uy + uz * uz + beta).sqrt();
    }
}
296
/// Combine regularization and data terms: out[i] = lambda * reg[i] + data[i]
/// Matches MATLAB MEDI: y = D + R where D is data term, R = lambda * reg term
#[cfg(feature = "simd")]
#[inline]
pub fn combine_terms_f32(out: &mut [f32], reg: &[f32], data: &[f32], lambda: f32) {
    // Length checks added for consistency: this was the only multi-slice
    // SIMD kernel in the module without them.
    debug_assert_eq!(out.len(), reg.len());
    debug_assert_eq!(out.len(), data.len());

    let n = out.len();
    let chunks = n / SIMD_WIDTH;
    let remainder = n % SIMD_WIDTH;

    let vlambda = f32x4::splat(lambda);

    // Vectorized main loop: SIMD_WIDTH elements per iteration.
    for i in 0..chunks {
        let idx = i * SIMD_WIDTH;
        let vreg = f32x4::from(&reg[idx..idx + SIMD_WIDTH]);
        let vdata = f32x4::from(&data[idx..idx + SIMD_WIDTH]);
        let result = vlambda * vreg + vdata;
        out[idx..idx + SIMD_WIDTH].copy_from_slice(result.as_array_ref());
    }

    // Scalar tail for elements past the last full lane group.
    let start = chunks * SIMD_WIDTH;
    for i in 0..remainder {
        out[start + i] = lambda * reg[start + i] + data[start + i];
    }
}
321
/// Combine regularization and data terms: out[i] = lambda * reg[i] + data[i]
/// — scalar fallback.
#[cfg(not(feature = "simd"))]
#[inline]
pub fn combine_terms_f32(out: &mut [f32], reg: &[f32], data: &[f32], lambda: f32) {
    // Length checks added for parity with the other kernels in this module.
    debug_assert_eq!(out.len(), reg.len());
    debug_assert_eq!(out.len(), data.len());
    // Zipped iteration avoids per-element bounds checks on all three slices.
    for ((o, &r), &d) in out.iter_mut().zip(reg).zip(data) {
        *o = lambda * r + d;
    }
}
329
/// Negate array in place: a[i] = -a[i]
#[cfg(feature = "simd")]
#[inline]
pub fn negate_f32(a: &mut [f32]) {
    let mut lanes = a.chunks_exact_mut(SIMD_WIDTH);
    for c in &mut lanes {
        let flipped = -f32x4::from(&c[..]);
        c.copy_from_slice(flipped.as_array_ref());
    }
    // Scalar tail for elements past the last full lane group.
    for x in lanes.into_remainder() {
        *x = -*x;
    }
}
350
/// Negate array in place: a[i] = -a[i] — scalar fallback.
#[cfg(not(feature = "simd"))]
#[inline]
pub fn negate_f32(a: &mut [f32]) {
    a.iter_mut().for_each(|x| *x = -*x);
}
358
359// ============================================================================
360// Scale Operations
361// ============================================================================
362
/// Scale array in place: a[i] = alpha * a[i]
#[cfg(feature = "simd")]
#[inline]
pub fn scale_f32(a: &mut [f32], alpha: f32) {
    let valpha = f32x4::splat(alpha);

    let mut lanes = a.chunks_exact_mut(SIMD_WIDTH);
    for c in &mut lanes {
        let scaled = valpha * f32x4::from(&c[..]);
        c.copy_from_slice(scaled.as_array_ref());
    }
    // Scalar tail for elements past the last full lane group.
    for x in lanes.into_remainder() {
        *x *= alpha;
    }
}
385
/// Scale array in place: a[i] = alpha * a[i] — scalar fallback.
#[cfg(not(feature = "simd"))]
#[inline]
pub fn scale_f32(a: &mut [f32], alpha: f32) {
    a.iter_mut().for_each(|x| *x *= alpha);
}
393
/// Scale array in place: a[i] = alpha * a[i] (f64 version)
#[inline]
pub fn scale_f64(a: &mut [f64], alpha: f64) {
    a.iter_mut().for_each(|x| *x *= alpha);
}
401
402// ============================================================================
403// Subtract Operations
404// ============================================================================
405
/// Subtract arrays element-wise: out[i] = a[i] - b[i]
#[cfg(feature = "simd")]
#[inline]
pub fn subtract_f32(out: &mut [f32], a: &[f32], b: &[f32]) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(out.len(), a.len());

    let mut lanes_o = out.chunks_exact_mut(SIMD_WIDTH);
    let mut lanes_a = a.chunks_exact(SIMD_WIDTH);
    let mut lanes_b = b.chunks_exact(SIMD_WIDTH);
    for ((co, ca), cb) in (&mut lanes_o).zip(&mut lanes_a).zip(&mut lanes_b) {
        let diff = f32x4::from(ca) - f32x4::from(cb);
        co.copy_from_slice(diff.as_array_ref());
    }

    // Scalar tail for elements past the last full lane group.
    let tail_o = lanes_o.into_remainder();
    for ((o, &x), &y) in tail_o.iter_mut().zip(lanes_a.remainder()).zip(lanes_b.remainder()) {
        *o = x - y;
    }
}
429
/// Subtract arrays element-wise: out[i] = a[i] - b[i] — scalar fallback.
#[cfg(not(feature = "simd"))]
#[inline]
pub fn subtract_f32(out: &mut [f32], a: &[f32], b: &[f32]) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(out.len(), a.len());
    // Zipped iteration over all three slices.
    for ((o, &x), &y) in out.iter_mut().zip(a).zip(b) {
        *o = x - y;
    }
}
439
/// Subtract arrays element-wise: out[i] = a[i] - b[i] (f64 version)
#[inline]
pub fn subtract_f64(out: &mut [f64], a: &[f64], b: &[f64]) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(out.len(), a.len());
    // Zipped iteration over all three slices.
    for ((o, &x), &y) in out.iter_mut().zip(a).zip(b) {
        *o = x - y;
    }
}
449
450// ============================================================================
451// F64 Versions of Core Operations
452// ============================================================================
453
/// Compute dot product: sum(a[i] * b[i]) (f64 version)
#[inline]
pub fn dot_product_f64(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    // Sequential fold; same accumulation order as map+sum.
    a.iter().zip(b).fold(0.0f64, |acc, (&x, &y)| acc + x * y)
}
460
/// Compute squared norm: sum(a[i]^2) (f64 version)
#[inline]
pub fn norm_squared_f64(a: &[f64]) -> f64 {
    // Sequential fold of x*x; same accumulation order as map+sum.
    a.iter().fold(0.0f64, |acc, &x| acc + x * x)
}
466
/// Compute a[i] = a[i] + alpha * b[i] (axpy operation, f64 version)
#[inline]
pub fn axpy_f64(a: &mut [f64], alpha: f64, b: &[f64]) {
    debug_assert_eq!(a.len(), b.len());
    // Zipped iteration instead of indexed access.
    for (ai, &bi) in a.iter_mut().zip(b) {
        *ai += alpha * bi;
    }
}
475
/// Negate array in place: a[i] = -a[i] (f64 version)
#[inline]
pub fn negate_f64(a: &mut [f64]) {
    a.iter_mut().for_each(|x| *x = -*x);
}
483
/// Apply per-direction gradient weights (f64 version)
///
/// For each axis d in {x, y, z}: `out_d[i] = m_d[i]^2 * p[i] * g_d[i]`.
#[inline]
pub fn apply_gradient_weights_f64(
    out_x: &mut [f64], out_y: &mut [f64], out_z: &mut [f64],
    mx: &[f64], my: &[f64], mz: &[f64],
    p: &[f64],
    gx: &[f64], gy: &[f64], gz: &[f64],
) {
    let n = out_x.len();
    // Length checks mirror the f32 SIMD variant's debug assertions so
    // mismatched inputs fail fast in debug builds.
    debug_assert!(out_y.len() == n && out_z.len() == n);
    debug_assert!(mx.len() == n && my.len() == n && mz.len() == n);
    debug_assert!(p.len() == n && gx.len() == n && gy.len() == n && gz.len() == n);
    for i in 0..n {
        out_x[i] = mx[i] * p[i] * mx[i] * gx[i];
        out_y[i] = my[i] * p[i] * my[i] * gy[i];
        out_z[i] = mz[i] * p[i] * mz[i] * gz[i];
    }
}
499
500// ============================================================================
501// Tests
502// ============================================================================
503
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: with the `simd` feature enabled these tests exercise the f32x4
    // kernels; without it they exercise the scalar fallbacks. The
    // "large" tests use 256 elements so the SIMD loop runs many iterations,
    // the "exact_multiple" tests use lengths divisible by SIMD_WIDTH (no
    // remainder path), and the "with_remainder" tests use a length that is
    // NOT a multiple of SIMD_WIDTH to cover the scalar tail.

    #[test]
    fn test_dot_product() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];

        let result = dot_product_f32(&a, &b);
        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();

        assert!((result - expected).abs() < 1e-6);
    }

    #[test]
    fn test_norm_squared() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let result = norm_squared_f32(&a);
        let expected: f32 = a.iter().map(|x| x * x).sum();

        assert!((result - expected).abs() < 1e-6);
    }

    #[test]
    fn test_axpy() {
        let mut a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];
        let alpha = 0.5f32;

        let expected: Vec<f32> = a.iter().zip(b.iter()).map(|(x, y)| x + alpha * y).collect();
        axpy_f32(&mut a, alpha, &b);

        for (r, e) in a.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6);
        }
    }

    #[test]
    fn test_negate() {
        let mut a = vec![1.0f32, -2.0, 3.0, -4.0, 5.0];
        let expected = vec![-1.0f32, 2.0, -3.0, 4.0, -5.0];

        negate_f32(&mut a);

        for (r, e) in a.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6);
        }
    }

    #[test]
    fn test_gradient_weights() {
        let n = 5;
        let mx = vec![1.0f32; n];
        let my = vec![0.5f32; n];
        let mz = vec![0.8f32; n];
        let p = vec![2.0f32; n];
        let gx = vec![0.1f32; n];
        let gy = vec![0.2f32; n];
        let gz = vec![0.3f32; n];

        let mut out_x = vec![0.0f32; n];
        let mut out_y = vec![0.0f32; n];
        let mut out_z = vec![0.0f32; n];

        apply_gradient_weights_f32(&mut out_x, &mut out_y, &mut out_z, &mx, &my, &mz, &p, &gx, &gy, &gz);

        // Check first element manually
        let ex = 1.0 * 2.0 * 1.0 * 0.1;  // mx * p * mx * gx
        let ey = 0.5 * 2.0 * 0.5 * 0.2;  // my * p * my * gy
        let ez = 0.8 * 2.0 * 0.8 * 0.3;  // mz * p * mz * gz

        assert!((out_x[0] - ex).abs() < 1e-6);
        assert!((out_y[0] - ey).abs() < 1e-6);
        assert!((out_z[0] - ez).abs() < 1e-6);
    }

    // ====================================================================
    // F64 versions and new operations
    // ====================================================================

    #[test]
    fn test_dot_product_f64() {
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0f64, 3.0, 4.0, 5.0, 6.0];

        let result = dot_product_f64(&a, &b);
        let expected: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();

        assert!((result - expected).abs() < 1e-12);
    }

    #[test]
    fn test_norm_squared_f64() {
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];

        let result = norm_squared_f64(&a);
        let expected: f64 = a.iter().map(|x| x * x).sum();

        assert!((result - expected).abs() < 1e-12);
    }

    #[test]
    fn test_axpy_f64() {
        let mut a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0f64, 3.0, 4.0, 5.0, 6.0];
        let alpha = 0.5f64;

        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(x, y)| x + alpha * y).collect();
        axpy_f64(&mut a, alpha, &b);

        for (r, e) in a.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-12);
        }
    }

    #[test]
    fn test_negate_f64() {
        let mut a = vec![1.0f64, -2.0, 3.0, -4.0, 5.0];
        let expected = vec![-1.0f64, 2.0, -3.0, 4.0, -5.0];

        negate_f64(&mut a);

        for (r, e) in a.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-12);
        }
    }

    #[test]
    fn test_scale_f32() {
        let mut a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let alpha = 2.5f32;

        let expected: Vec<f32> = a.iter().map(|x| x * alpha).collect();
        scale_f32(&mut a, alpha);

        for (r, e) in a.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "scale_f32 mismatch: got {}, expected {}", r, e);
        }
    }

    #[test]
    fn test_subtract_f32() {
        let a = vec![5.0f32, 4.0, 3.0, 2.0, 1.0, 0.5, 0.25];
        let b = vec![1.0f32, 1.5, 2.0, 2.5, 3.0, 0.1, 0.05];
        let mut out = vec![0.0f32; 7];

        let expected: Vec<f32> = a.iter().zip(b.iter()).map(|(x, y)| x - y).collect();
        subtract_f32(&mut out, &a, &b);

        for (r, e) in out.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "subtract_f32 mismatch: got {}, expected {}", r, e);
        }
    }

    #[test]
    fn test_scale_f64() {
        let mut a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let alpha = 3.14f64;

        let expected: Vec<f64> = a.iter().map(|x| x * alpha).collect();
        scale_f64(&mut a, alpha);

        for (r, e) in a.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-12, "scale_f64 mismatch: got {}, expected {}", r, e);
        }
    }

    #[test]
    fn test_subtract_f64() {
        let a = vec![10.0f64, 20.0, 30.0, 40.0, 50.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let mut out = vec![0.0f64; 5];

        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(x, y)| x - y).collect();
        subtract_f64(&mut out, &a, &b);

        for (r, e) in out.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-12, "subtract_f64 mismatch: got {}, expected {}", r, e);
        }
    }

    #[test]
    fn test_gradient_weights_f64() {
        let n = 5;
        let mx = vec![1.0f64; n];
        let my = vec![0.5f64; n];
        let mz = vec![0.8f64; n];
        let p = vec![2.0f64; n];
        let gx = vec![0.1f64; n];
        let gy = vec![0.2f64; n];
        let gz = vec![0.3f64; n];

        let mut out_x = vec![0.0f64; n];
        let mut out_y = vec![0.0f64; n];
        let mut out_z = vec![0.0f64; n];

        apply_gradient_weights_f64(&mut out_x, &mut out_y, &mut out_z, &mx, &my, &mz, &p, &gx, &gy, &gz);

        let ex = 1.0 * 2.0 * 1.0 * 0.1;
        let ey = 0.5 * 2.0 * 0.5 * 0.2;
        let ez = 0.8 * 2.0 * 0.8 * 0.3;

        for i in 0..n {
            assert!((out_x[i] - ex).abs() < 1e-12);
            assert!((out_y[i] - ey).abs() < 1e-12);
            assert!((out_z[i] - ez).abs() < 1e-12);
        }
    }

    // ====================================================================
    // Large-array f32 tests for SIMD coverage
    // ====================================================================

    #[test]
    fn test_dot_product_f32_large() {
        // Use 128+ elements so SIMD loop executes multiple iterations
        let n = 256;
        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..n).map(|i| 1.0 - (i as f32) * 0.005).collect();

        let result = dot_product_f32(&a, &b);
        let expected: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();

        assert!(
            (result - expected).abs() < 1e-2,
            "dot_product_f32 large: got {}, expected {}",
            result, expected
        );
    }

    #[test]
    fn test_dot_product_f32_large_exact_multiple() {
        // Array size that is an exact multiple of SIMD_WIDTH (no remainder)
        let n = 128;
        let a = vec![2.0f32; n];
        let b = vec![3.0f32; n];

        let result = dot_product_f32(&a, &b);
        let expected = 2.0 * 3.0 * n as f32;

        assert!(
            (result - expected).abs() < 1e-3,
            "dot_product_f32 exact multiple: got {}, expected {}",
            result, expected
        );
    }

    #[test]
    fn test_norm_squared_f32_large() {
        let n = 256;
        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();

        let result = norm_squared_f32(&a);
        let expected: f32 = a.iter().map(|&x| x * x).sum();

        assert!(
            (result - expected).abs() < 1.0,
            "norm_squared_f32 large: got {}, expected {}",
            result, expected
        );
    }

    #[test]
    fn test_norm_squared_f32_large_exact_multiple() {
        let n = 128;
        let a = vec![0.5f32; n];

        let result = norm_squared_f32(&a);
        let expected = 0.25 * n as f32;

        assert!(
            (result - expected).abs() < 1e-3,
            "norm_squared_f32 exact: got {}, expected {}",
            result, expected
        );
    }

    #[test]
    fn test_axpy_f32_large() {
        let n = 256;
        let mut a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.05).collect();
        let alpha = 2.5f32;

        let expected: Vec<f32> = a.iter().zip(b.iter()).map(|(&x, &y)| x + alpha * y).collect();
        axpy_f32(&mut a, alpha, &b);

        for i in 0..n {
            assert!(
                (a[i] - expected[i]).abs() < 1e-3,
                "axpy_f32 large mismatch at {}: got {}, expected {}",
                i, a[i], expected[i]
            );
        }
    }

    #[test]
    fn test_axpy_f32_large_exact_multiple() {
        let n = 128;
        let mut a = vec![1.0f32; n];
        let b = vec![2.0f32; n];
        let alpha = 3.0f32;

        axpy_f32(&mut a, alpha, &b);

        for &val in &a {
            assert!(
                (val - 7.0).abs() < 1e-6, // 1 + 3*2 = 7
                "axpy exact: got {}, expected 7.0",
                val
            );
        }
    }

    #[test]
    fn test_negate_f32_large() {
        let n = 256;
        let original: Vec<f32> = (0..n).map(|i| (i as f32) - 128.0).collect();
        let mut a = original.clone();

        negate_f32(&mut a);

        for i in 0..n {
            assert!(
                (a[i] + original[i]).abs() < 1e-6,
                "negate_f32 large: a[{}] = {}, expected {}",
                i, a[i], -original[i]
            );
        }
    }

    #[test]
    fn test_negate_f32_large_exact_multiple() {
        let n = 128;
        let mut a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let expected: Vec<f32> = a.iter().map(|&x| -x).collect();

        negate_f32(&mut a);

        for i in 0..n {
            assert!(
                (a[i] - expected[i]).abs() < 1e-6,
                "negate_f32 exact: a[{}] = {}, expected {}",
                i, a[i], expected[i]
            );
        }
    }

    #[test]
    fn test_apply_gradient_weights_f32_large() {
        let n = 256;
        let mx: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.001).collect();
        let my: Vec<f32> = (0..n).map(|i| 0.3 + (i as f32) * 0.002).collect();
        let mz: Vec<f32> = (0..n).map(|i| 0.8 - (i as f32) * 0.001).collect();
        let p: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.01).collect();
        let gx: Vec<f32> = (0..n).map(|i| (i as f32 * 0.1).sin()).collect();
        let gy: Vec<f32> = (0..n).map(|i| (i as f32 * 0.2).cos()).collect();
        let gz: Vec<f32> = (0..n).map(|i| (i as f32 * 0.05).sin()).collect();

        let mut out_x = vec![0.0f32; n];
        let mut out_y = vec![0.0f32; n];
        let mut out_z = vec![0.0f32; n];

        apply_gradient_weights_f32(
            &mut out_x, &mut out_y, &mut out_z,
            &mx, &my, &mz, &p, &gx, &gy, &gz,
        );

        // Verify against scalar computation
        for i in 0..n {
            let ex = mx[i] * p[i] * mx[i] * gx[i];
            let ey = my[i] * p[i] * my[i] * gy[i];
            let ez = mz[i] * p[i] * mz[i] * gz[i];

            assert!(
                (out_x[i] - ex).abs() < 1e-4,
                "gradient_weights_f32 x[{}]: got {}, expected {}",
                i, out_x[i], ex
            );
            assert!(
                (out_y[i] - ey).abs() < 1e-4,
                "gradient_weights_f32 y[{}]: got {}, expected {}",
                i, out_y[i], ey
            );
            assert!(
                (out_z[i] - ez).abs() < 1e-4,
                "gradient_weights_f32 z[{}]: got {}, expected {}",
                i, out_z[i], ez
            );
        }
    }

    #[test]
    fn test_scale_f32_large() {
        let n = 256;
        let mut a: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
        let alpha = 3.14f32;

        let expected: Vec<f32> = a.iter().map(|&x| x * alpha).collect();
        scale_f32(&mut a, alpha);

        for i in 0..n {
            assert!(
                (a[i] - expected[i]).abs() < 1e-3,
                "scale_f32 large: a[{}] = {}, expected {}",
                i, a[i], expected[i]
            );
        }
    }

    #[test]
    fn test_subtract_f32_large() {
        let n = 256;
        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 2.0).collect();
        let b: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
        let mut out = vec![0.0f32; n];

        let expected: Vec<f32> = a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect();
        subtract_f32(&mut out, &a, &b);

        for i in 0..n {
            assert!(
                (out[i] - expected[i]).abs() < 1e-3,
                "subtract_f32 large: out[{}] = {}, expected {}",
                i, out[i], expected[i]
            );
        }
    }

    #[test]
    fn test_xpby_f32_large() {
        let n = 256;
        let mut a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.05).collect();
        let beta = 0.7f32;

        let expected: Vec<f32> = a.iter().zip(b.iter()).map(|(&ai, &bi)| bi + beta * ai).collect();
        xpby_f32(&mut a, &b, beta);

        for i in 0..n {
            assert!(
                (a[i] - expected[i]).abs() < 1e-3,
                "xpby_f32 large: a[{}] = {}, expected {}",
                i, a[i], expected[i]
            );
        }
    }

    #[test]
    fn test_combine_terms_f32_large() {
        let n = 256;
        let reg: Vec<f32> = (0..n).map(|i| (i as f32) * 0.01).collect();
        let data: Vec<f32> = (0..n).map(|i| 1.0 - (i as f32) * 0.005).collect();
        let mut out = vec![0.0f32; n];
        let lambda = 2.5f32;

        let expected: Vec<f32> = reg.iter().zip(data.iter())
            .map(|(&r, &d)| lambda * r + d).collect();
        combine_terms_f32(&mut out, &reg, &data, lambda);

        for i in 0..n {
            assert!(
                (out[i] - expected[i]).abs() < 1e-3,
                "combine_terms_f32 large: out[{}] = {}, expected {}",
                i, out[i], expected[i]
            );
        }
    }

    #[test]
    fn test_compute_p_weights_f32_large() {
        let n = 256;
        let mx: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.001).collect();
        let my: Vec<f32> = (0..n).map(|i| 0.3 + (i as f32) * 0.002).collect();
        let mz: Vec<f32> = (0..n).map(|i| 0.8 - (i as f32) * 0.001).collect();
        let gx: Vec<f32> = (0..n).map(|i| (i as f32 * 0.1).sin()).collect();
        let gy: Vec<f32> = (0..n).map(|i| (i as f32 * 0.2).cos()).collect();
        let gz: Vec<f32> = (0..n).map(|i| (i as f32 * 0.05).sin()).collect();
        let beta = 1e-4f32;

        let mut p = vec![0.0f32; n];
        compute_p_weights_f32(&mut p, &mx, &my, &mz, &gx, &gy, &gz, beta);

        // Verify against scalar computation
        for i in 0..n {
            let ux = mx[i] * gx[i];
            let uy = my[i] * gy[i];
            let uz = mz[i] * gz[i];
            let expected = 1.0 / (ux * ux + uy * uy + uz * uz + beta).sqrt();

            assert!(
                (p[i] - expected).abs() < 1e-3,
                "compute_p_weights large: p[{}] = {}, expected {}",
                i, p[i], expected
            );
        }
    }

    #[test]
    fn test_dot_product_f32_with_remainder() {
        // Array size that is NOT a multiple of SIMD_WIDTH to test remainder path
        let n = 131; // 131 = 32*4 + 3 remainder
        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..n).map(|i| 1.0 - (i as f32) * 0.005).collect();

        let result = dot_product_f32(&a, &b);
        let expected: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();

        assert!(
            (result - expected).abs() < 1.0,
            "dot_product_f32 remainder: got {}, expected {}",
            result, expected
        );
    }

    #[test]
    fn test_norm_squared_f32_with_remainder() {
        let n = 131;
        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();

        let result = norm_squared_f32(&a);
        let expected: f32 = a.iter().map(|&x| x * x).sum();

        assert!(
            (result - expected).abs() < 1.0,
            "norm_squared_f32 remainder: got {}, expected {}",
            result, expected
        );
    }
}