#[cfg(feature = "simd")]
use wide::f32x4;

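/// Number of f32 lanes processed per SIMD iteration.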
#[cfg(feature = "simd")]
pub const SIMD_WIDTH: usize = 4;

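/// Scalar fallback: elements are processed one at a time.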
#[cfg(not(feature = "simd"))]
pub const SIMD_WIDTH: usize = 1;

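/// Dot product of two equal-length f32 slices.
///
/// Accumulates four lanes at a time; any trailing elements are handled
/// by a scalar loop.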
#[cfg(feature = "simd")]
#[inline]
pub fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let n = a.len();
    let chunks = n / SIMD_WIDTH;
    let remainder = n % SIMD_WIDTH;

    let mut sum = f32x4::ZERO;

    // Multiply-accumulate four lanes per iteration.
    for i in 0..chunks {
        let idx = i * SIMD_WIDTH;
        let va = f32x4::from(&a[idx..idx + SIMD_WIDTH]);
        let vb = f32x4::from(&b[idx..idx + SIMD_WIDTH]);
        sum += va * vb;
    }

    // Horizontal sum of the four partial sums.
    let mut result = sum.reduce_add();

    // Scalar tail for the up to SIMD_WIDTH - 1 remaining elements.
    let start = chunks * SIMD_WIDTH;
    for i in 0..remainder {
        result += a[start + i] * b[start + i];
    }

    result
}

#[cfg(not(feature = "simd"))]
#[inline]
pub fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum()
}

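/// Squared Euclidean norm of `a`, i.e. the dot product of `a` with itself.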
#[cfg(feature = "simd")]
#[inline]
pub fn norm_squared_f32(a: &[f32]) -> f32 {
    let n = a.len();
    let chunks = n / SIMD_WIDTH;
    let remainder = n % SIMD_WIDTH;

    let mut sum = f32x4::ZERO;

    for i in 0..chunks {
        let idx = i * SIMD_WIDTH;
        let va = f32x4::from(&a[idx..idx + SIMD_WIDTH]);
        sum += va * va;
    }

    let mut result = sum.reduce_add();

    let start = chunks * SIMD_WIDTH;
    for i in 0..remainder {
        result += a[start + i] * a[start + i];
    }

    result
}

#[cfg(not(feature = "simd"))]
#[inline]
pub fn norm_squared_f32(a: &[f32]) -> f32 {
    a.iter().map(|&ai| ai * ai).sum()
}

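/// In-place AXPY update: `a[i] += alpha * b[i]` for every element.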
#[cfg(feature = "simd")]
#[inline]
pub fn axpy_f32(a: &mut [f32], alpha: f32, b: &[f32]) {
    debug_assert_eq!(a.len(), b.len());
    let n = a.len();
    let chunks = n / SIMD_WIDTH;
    let remainder = n % SIMD_WIDTH;

    let valpha = f32x4::splat(alpha);

    for i in 0..chunks {
        let idx = i * SIMD_WIDTH;
        let va = f32x4::from(&a[idx..idx + SIMD_WIDTH]);
        let vb = f32x4::from(&b[idx..idx + SIMD_WIDTH]);
        let result = va + valpha * vb;
        a[idx..idx + SIMD_WIDTH].copy_from_slice(result.as_array_ref());
    }

    let start = chunks * SIMD_WIDTH;
    for i in 0..remainder {
        a[start + i] += alpha * b[start + i];
    }
}

#[cfg(not(feature = "simd"))]
#[inline]
pub fn axpy_f32(a: &mut [f32], alpha: f32, b: &[f32]) {
    debug_assert_eq!(a.len(), b.len());
    for i in 0..a.len() {
        a[i] += alpha * b[i];
    }
}

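/// In-place XPBY update: `a[i] = b[i] + beta * a[i]` for every element.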
#[cfg(feature = "simd")]
#[inline]
pub fn xpby_f32(a: &mut [f32], b: &[f32], beta: f32) {
    debug_assert_eq!(a.len(), b.len());
    let n = a.len();
    let chunks = n / SIMD_WIDTH;
    let remainder = n % SIMD_WIDTH;

    let vbeta = f32x4::splat(beta);

    for i in 0..chunks {
        let idx = i * SIMD_WIDTH;
        let va = f32x4::from(&a[idx..idx + SIMD_WIDTH]);
        let vb = f32x4::from(&b[idx..idx + SIMD_WIDTH]);
        let result = vb + vbeta * va;
        a[idx..idx + SIMD_WIDTH].copy_from_slice(result.as_array_ref());
    }

    let start = chunks * SIMD_WIDTH;
    for i in 0..remainder {
        a[start + i] = b[start + i] + beta * a[start + i];
    }
}

#[cfg(not(feature = "simd"))]
#[inline]
pub fn xpby_f32(a: &mut [f32], b: &[f32], beta: f32) {
    debug_assert_eq!(a.len(), b.len());
    for i in 0..a.len() {
        a[i] = b[i] + beta * a[i];
    }
}

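/// Componentwise weighted gradient: for each axis `c` in x, y, z,
/// `out_c[i] = m_c[i]^2 * p[i] * g_c[i]`.
///
/// All ten slices must have the same length.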
#[cfg(feature = "simd")]
#[inline]
pub fn apply_gradient_weights_f32(
    out_x: &mut [f32], out_y: &mut [f32], out_z: &mut [f32],
    mx: &[f32], my: &[f32], mz: &[f32],
    p: &[f32],
    gx: &[f32], gy: &[f32], gz: &[f32],
) {
    let n = out_x.len();
    debug_assert!(out_y.len() == n && out_z.len() == n);
    debug_assert!(mx.len() == n && my.len() == n && mz.len() == n);
    debug_assert!(p.len() == n && gx.len() == n && gy.len() == n && gz.len() == n);

    let chunks = n / SIMD_WIDTH;
    let remainder = n % SIMD_WIDTH;

    for i in 0..chunks {
        let idx = i * SIMD_WIDTH;

        let vmx = f32x4::from(&mx[idx..idx + SIMD_WIDTH]);
        let vmy = f32x4::from(&my[idx..idx + SIMD_WIDTH]);
        let vmz = f32x4::from(&mz[idx..idx + SIMD_WIDTH]);
        let vp = f32x4::from(&p[idx..idx + SIMD_WIDTH]);
        let vgx = f32x4::from(&gx[idx..idx + SIMD_WIDTH]);
        let vgy = f32x4::from(&gy[idx..idx + SIMD_WIDTH]);
        let vgz = f32x4::from(&gz[idx..idx + SIMD_WIDTH]);

        // out = m^2 * p * g per axis, matching the scalar tail below.
        let rx = vmx * vp * vmx * vgx;
        let ry = vmy * vp * vmy * vgy;
        let rz = vmz * vp * vmz * vgz;

        out_x[idx..idx + SIMD_WIDTH].copy_from_slice(rx.as_array_ref());
        out_y[idx..idx + SIMD_WIDTH].copy_from_slice(ry.as_array_ref());
        out_z[idx..idx + SIMD_WIDTH].copy_from_slice(rz.as_array_ref());
    }

    let start = chunks * SIMD_WIDTH;
    for i in 0..remainder {
        let idx = start + i;
        out_x[idx] = mx[idx] * p[idx] * mx[idx] * gx[idx];
        out_y[idx] = my[idx] * p[idx] * my[idx] * gy[idx];
        out_z[idx] = mz[idx] * p[idx] * mz[idx] * gz[idx];
    }
}

#[cfg(not(feature = "simd"))]
#[inline]
pub fn apply_gradient_weights_f32(
    out_x: &mut [f32], out_y: &mut [f32], out_z: &mut [f32],
    mx: &[f32], my: &[f32], mz: &[f32],
    p: &[f32],
    gx: &[f32], gy: &[f32], gz: &[f32],
) {
    let n = out_x.len();
    for i in 0..n {
        out_x[i] = mx[i] * p[i] * mx[i] * gx[i];
        out_y[i] = my[i] * p[i] * my[i] * gy[i];
        out_z[i] = mz[i] * p[i] * mz[i] * gz[i];
    }
}

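/// Smoothed inverse magnitude weights:
/// `p[i] = 1 / sqrt((mx*gx)^2 + (my*gy)^2 + (mz*gz)^2 + beta)`,
/// where `beta > 0` keeps the expression finite when the weighted
/// gradient vanishes.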
#[cfg(feature = "simd")]
#[inline]
pub fn compute_p_weights_f32(
    p: &mut [f32],
    mx: &[f32], my: &[f32], mz: &[f32],
    gx: &[f32], gy: &[f32], gz: &[f32],
    beta: f32,
) {
    let n = p.len();
    let chunks = n / SIMD_WIDTH;
    let remainder = n % SIMD_WIDTH;

    let vbeta = f32x4::splat(beta);

    for i in 0..chunks {
        let idx = i * SIMD_WIDTH;

        let vmx = f32x4::from(&mx[idx..idx + SIMD_WIDTH]);
        let vmy = f32x4::from(&my[idx..idx + SIMD_WIDTH]);
        let vmz = f32x4::from(&mz[idx..idx + SIMD_WIDTH]);
        let vgx = f32x4::from(&gx[idx..idx + SIMD_WIDTH]);
        let vgy = f32x4::from(&gy[idx..idx + SIMD_WIDTH]);
        let vgz = f32x4::from(&gz[idx..idx + SIMD_WIDTH]);

        let ux = vmx * vgx;
        let uy = vmy * vgy;
        let uz = vmz * vgz;

        let norm_sq = ux * ux + uy * uy + uz * uz + vbeta;
        // Exact division rather than `recip()`: wide's `recip()` uses a
        // low-precision hardware approximation on most targets, which would
        // make the SIMD path diverge from the scalar tail below.
        let result = f32x4::splat(1.0) / norm_sq.sqrt();

        p[idx..idx + SIMD_WIDTH].copy_from_slice(result.as_array_ref());
    }

    let start = chunks * SIMD_WIDTH;
    for i in 0..remainder {
        let idx = start + i;
        let ux = mx[idx] * gx[idx];
        let uy = my[idx] * gy[idx];
        let uz = mz[idx] * gz[idx];
        p[idx] = 1.0 / (ux * ux + uy * uy + uz * uz + beta).sqrt();
    }
}

#[cfg(not(feature = "simd"))]
#[inline]
pub fn compute_p_weights_f32(
    p: &mut [f32],
    mx: &[f32], my: &[f32], mz: &[f32],
    gx: &[f32], gy: &[f32], gz: &[f32],
    beta: f32,
) {
    let n = p.len();
    for i in 0..n {
        let ux = mx[i] * gx[i];
        let uy = my[i] * gy[i];
        let uz = mz[i] * gz[i];
        p[i] = 1.0 / (ux * ux + uy * uy + uz * uz + beta).sqrt();
    }
}

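/// Linear combination of a regularization term and a data term:
/// `out[i] = lambda * reg[i] + data[i]`.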
#[cfg(feature = "simd")]
#[inline]
pub fn combine_terms_f32(out: &mut [f32], reg: &[f32], data: &[f32], lambda: f32) {
    let n = out.len();
    let chunks = n / SIMD_WIDTH;
    let remainder = n % SIMD_WIDTH;

    let vlambda = f32x4::splat(lambda);

    for i in 0..chunks {
        let idx = i * SIMD_WIDTH;
        let vreg = f32x4::from(&reg[idx..idx + SIMD_WIDTH]);
        let vdata = f32x4::from(&data[idx..idx + SIMD_WIDTH]);
        let result = vlambda * vreg + vdata;
        out[idx..idx + SIMD_WIDTH].copy_from_slice(result.as_array_ref());
    }

    let start = chunks * SIMD_WIDTH;
    for i in 0..remainder {
        out[start + i] = lambda * reg[start + i] + data[start + i];
    }
}

#[cfg(not(feature = "simd"))]
#[inline]
pub fn combine_terms_f32(out: &mut [f32], reg: &[f32], data: &[f32], lambda: f32) {
    for i in 0..out.len() {
        out[i] = lambda * reg[i] + data[i];
    }
}

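/// Negates every element of `a` in place.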
#[cfg(feature = "simd")]
#[inline]
pub fn negate_f32(a: &mut [f32]) {
    let n = a.len();
    let chunks = n / SIMD_WIDTH;
    let remainder = n % SIMD_WIDTH;

    for i in 0..chunks {
        let idx = i * SIMD_WIDTH;
        let va = f32x4::from(&a[idx..idx + SIMD_WIDTH]);
        let result = -va;
        a[idx..idx + SIMD_WIDTH].copy_from_slice(result.as_array_ref());
    }

    let start = chunks * SIMD_WIDTH;
    for i in 0..remainder {
        a[start + i] = -a[start + i];
    }
}

#[cfg(not(feature = "simd"))]
#[inline]
pub fn negate_f32(a: &mut [f32]) {
    for val in a.iter_mut() {
        *val = -*val;
    }
}

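/// Scales every element of `a` in place by `alpha`.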
#[cfg(feature = "simd")]
#[inline]
pub fn scale_f32(a: &mut [f32], alpha: f32) {
    let n = a.len();
    let chunks = n / SIMD_WIDTH;
    let remainder = n % SIMD_WIDTH;

    let valpha = f32x4::splat(alpha);

    for i in 0..chunks {
        let idx = i * SIMD_WIDTH;
        let va = f32x4::from(&a[idx..idx + SIMD_WIDTH]);
        let result = valpha * va;
        a[idx..idx + SIMD_WIDTH].copy_from_slice(result.as_array_ref());
    }

    let start = chunks * SIMD_WIDTH;
    for i in 0..remainder {
        a[start + i] *= alpha;
    }
}

#[cfg(not(feature = "simd"))]
#[inline]
pub fn scale_f32(a: &mut [f32], alpha: f32) {
    for val in a.iter_mut() {
        *val *= alpha;
    }
}

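/// Scales every element of `a` in place by `alpha` (scalar f64 version).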
#[inline]
pub fn scale_f64(a: &mut [f64], alpha: f64) {
    for val in a.iter_mut() {
        *val *= alpha;
    }
}

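/// Elementwise difference: `out[i] = a[i] - b[i]`.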
#[cfg(feature = "simd")]
#[inline]
pub fn subtract_f32(out: &mut [f32], a: &[f32], b: &[f32]) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(out.len(), a.len());
    let n = a.len();
    let chunks = n / SIMD_WIDTH;
    let remainder = n % SIMD_WIDTH;

    for i in 0..chunks {
        let idx = i * SIMD_WIDTH;
        let va = f32x4::from(&a[idx..idx + SIMD_WIDTH]);
        let vb = f32x4::from(&b[idx..idx + SIMD_WIDTH]);
        let result = va - vb;
        out[idx..idx + SIMD_WIDTH].copy_from_slice(result.as_array_ref());
    }

    let start = chunks * SIMD_WIDTH;
    for i in 0..remainder {
        out[start + i] = a[start + i] - b[start + i];
    }
}

#[cfg(not(feature = "simd"))]
#[inline]
pub fn subtract_f32(out: &mut [f32], a: &[f32], b: &[f32]) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(out.len(), a.len());
    for i in 0..out.len() {
        out[i] = a[i] - b[i];
    }
}

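/// Elementwise difference for f64: `out[i] = a[i] - b[i]` (scalar only).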
#[inline]
pub fn subtract_f64(out: &mut [f64], a: &[f64], b: &[f64]) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(out.len(), a.len());
    for i in 0..out.len() {
        out[i] = a[i] - b[i];
    }
}

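/// Dot product of two equal-length f64 slices (no SIMD path is provided
/// for f64).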
#[inline]
pub fn dot_product_f64(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum()
}

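/// Squared Euclidean norm of `a` (scalar f64 version).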
#[inline]
pub fn norm_squared_f64(a: &[f64]) -> f64 {
    a.iter().map(|&ai| ai * ai).sum()
}

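/// In-place AXPY update for f64: `a[i] += alpha * b[i]`.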
#[inline]
pub fn axpy_f64(a: &mut [f64], alpha: f64, b: &[f64]) {
    debug_assert_eq!(a.len(), b.len());
    for i in 0..a.len() {
        a[i] += alpha * b[i];
    }
}

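/// Negates every element of `a` in place (f64 version).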
#[inline]
pub fn negate_f64(a: &mut [f64]) {
    for val in a.iter_mut() {
        *val = -*val;
    }
}

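/// Componentwise weighted gradient for f64; see
/// [`apply_gradient_weights_f32`] for the formula.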
#[inline]
pub fn apply_gradient_weights_f64(
    out_x: &mut [f64], out_y: &mut [f64], out_z: &mut [f64],
    mx: &[f64], my: &[f64], mz: &[f64],
    p: &[f64],
    gx: &[f64], gy: &[f64], gz: &[f64],
) {
    let n = out_x.len();
    for i in 0..n {
        out_x[i] = mx[i] * p[i] * mx[i] * gx[i];
        out_y[i] = my[i] * p[i] * my[i] * gy[i];
        out_z[i] = mz[i] * p[i] * mz[i] * gz[i];
    }
}

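// The tests below run against whichever implementation the "simd"
// feature selects, so the same suite validates both code paths.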
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dot_product() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];

        let result = dot_product_f32(&a, &b);
        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();

        assert!((result - expected).abs() < 1e-6);
    }

    #[test]
    fn test_norm_squared() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let result = norm_squared_f32(&a);
        let expected: f32 = a.iter().map(|x| x * x).sum();

        assert!((result - expected).abs() < 1e-6);
    }

    #[test]
    fn test_axpy() {
        let mut a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];
        let alpha = 0.5f32;

        let expected: Vec<f32> = a.iter().zip(b.iter()).map(|(x, y)| x + alpha * y).collect();
        axpy_f32(&mut a, alpha, &b);

        for (r, e) in a.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6);
        }
    }

    #[test]
    fn test_negate() {
        let mut a = vec![1.0f32, -2.0, 3.0, -4.0, 5.0];
        let expected = vec![-1.0f32, 2.0, -3.0, 4.0, -5.0];

        negate_f32(&mut a);

        for (r, e) in a.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6);
        }
    }

    #[test]
    fn test_gradient_weights() {
        let n = 5;
        let mx = vec![1.0f32; n];
        let my = vec![0.5f32; n];
        let mz = vec![0.8f32; n];
        let p = vec![2.0f32; n];
        let gx = vec![0.1f32; n];
        let gy = vec![0.2f32; n];
        let gz = vec![0.3f32; n];

        let mut out_x = vec![0.0f32; n];
        let mut out_y = vec![0.0f32; n];
        let mut out_z = vec![0.0f32; n];

        apply_gradient_weights_f32(&mut out_x, &mut out_y, &mut out_z, &mx, &my, &mz, &p, &gx, &gy, &gz);

        let ex = 1.0 * 2.0 * 1.0 * 0.1;
        let ey = 0.5 * 2.0 * 0.5 * 0.2;
        let ez = 0.8 * 2.0 * 0.8 * 0.3;

        assert!((out_x[0] - ex).abs() < 1e-6);
        assert!((out_y[0] - ey).abs() < 1e-6);
        assert!((out_z[0] - ez).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_f64() {
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0f64, 3.0, 4.0, 5.0, 6.0];

        let result = dot_product_f64(&a, &b);
        let expected: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();

        assert!((result - expected).abs() < 1e-12);
    }

    #[test]
    fn test_norm_squared_f64() {
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];

        let result = norm_squared_f64(&a);
        let expected: f64 = a.iter().map(|x| x * x).sum();

        assert!((result - expected).abs() < 1e-12);
    }

    #[test]
    fn test_axpy_f64() {
        let mut a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0f64, 3.0, 4.0, 5.0, 6.0];
        let alpha = 0.5f64;

        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(x, y)| x + alpha * y).collect();
        axpy_f64(&mut a, alpha, &b);

        for (r, e) in a.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-12);
        }
    }

    #[test]
    fn test_negate_f64() {
        let mut a = vec![1.0f64, -2.0, 3.0, -4.0, 5.0];
        let expected = vec![-1.0f64, 2.0, -3.0, 4.0, -5.0];

        negate_f64(&mut a);

        for (r, e) in a.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-12);
        }
    }

    #[test]
    fn test_scale_f32() {
        let mut a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let alpha = 2.5f32;

        let expected: Vec<f32> = a.iter().map(|x| x * alpha).collect();
        scale_f32(&mut a, alpha);

        for (r, e) in a.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "scale_f32 mismatch: got {}, expected {}", r, e);
        }
    }

    #[test]
    fn test_subtract_f32() {
        let a = vec![5.0f32, 4.0, 3.0, 2.0, 1.0, 0.5, 0.25];
        let b = vec![1.0f32, 1.5, 2.0, 2.5, 3.0, 0.1, 0.05];
        let mut out = vec![0.0f32; 7];

        let expected: Vec<f32> = a.iter().zip(b.iter()).map(|(x, y)| x - y).collect();
        subtract_f32(&mut out, &a, &b);

        for (r, e) in out.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "subtract_f32 mismatch: got {}, expected {}", r, e);
        }
    }

    #[test]
    fn test_scale_f64() {
        let mut a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let alpha = 3.14f64;

        let expected: Vec<f64> = a.iter().map(|x| x * alpha).collect();
        scale_f64(&mut a, alpha);

        for (r, e) in a.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-12, "scale_f64 mismatch: got {}, expected {}", r, e);
        }
    }

    #[test]
    fn test_subtract_f64() {
        let a = vec![10.0f64, 20.0, 30.0, 40.0, 50.0];
        let b = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let mut out = vec![0.0f64; 5];

        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(x, y)| x - y).collect();
        subtract_f64(&mut out, &a, &b);

        for (r, e) in out.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-12, "subtract_f64 mismatch: got {}, expected {}", r, e);
        }
    }

    #[test]
    fn test_gradient_weights_f64() {
        let n = 5;
        let mx = vec![1.0f64; n];
        let my = vec![0.5f64; n];
        let mz = vec![0.8f64; n];
        let p = vec![2.0f64; n];
        let gx = vec![0.1f64; n];
        let gy = vec![0.2f64; n];
        let gz = vec![0.3f64; n];

        let mut out_x = vec![0.0f64; n];
        let mut out_y = vec![0.0f64; n];
        let mut out_z = vec![0.0f64; n];

        apply_gradient_weights_f64(&mut out_x, &mut out_y, &mut out_z, &mx, &my, &mz, &p, &gx, &gy, &gz);

        let ex = 1.0 * 2.0 * 1.0 * 0.1;
        let ey = 0.5 * 2.0 * 0.5 * 0.2;
        let ez = 0.8 * 2.0 * 0.8 * 0.3;

        for i in 0..n {
            assert!((out_x[i] - ex).abs() < 1e-12);
            assert!((out_y[i] - ey).abs() < 1e-12);
            assert!((out_z[i] - ez).abs() < 1e-12);
        }
    }

    #[test]
    fn test_dot_product_f32_large() {
        let n = 256;
        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..n).map(|i| 1.0 - (i as f32) * 0.005).collect();

        let result = dot_product_f32(&a, &b);
        let expected: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();

        assert!(
            (result - expected).abs() < 1e-2,
            "dot_product_f32 large: got {}, expected {}",
            result, expected
        );
    }

    #[test]
    fn test_dot_product_f32_large_exact_multiple() {
        let n = 128;
        let a = vec![2.0f32; n];
        let b = vec![3.0f32; n];

        let result = dot_product_f32(&a, &b);
        let expected = 2.0 * 3.0 * n as f32;

        assert!(
            (result - expected).abs() < 1e-3,
            "dot_product_f32 exact multiple: got {}, expected {}",
            result, expected
        );
    }

    #[test]
    fn test_norm_squared_f32_large() {
        let n = 256;
        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();

        let result = norm_squared_f32(&a);
        let expected: f32 = a.iter().map(|&x| x * x).sum();

        assert!(
            (result - expected).abs() < 1.0,
            "norm_squared_f32 large: got {}, expected {}",
            result, expected
        );
    }

    #[test]
    fn test_norm_squared_f32_large_exact_multiple() {
        let n = 128;
        let a = vec![0.5f32; n];

        let result = norm_squared_f32(&a);
        let expected = 0.25 * n as f32;

        assert!(
            (result - expected).abs() < 1e-3,
            "norm_squared_f32 exact: got {}, expected {}",
            result, expected
        );
    }

    #[test]
    fn test_axpy_f32_large() {
        let n = 256;
        let mut a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.05).collect();
        let alpha = 2.5f32;

        let expected: Vec<f32> = a.iter().zip(b.iter()).map(|(&x, &y)| x + alpha * y).collect();
        axpy_f32(&mut a, alpha, &b);

        for i in 0..n {
            assert!(
                (a[i] - expected[i]).abs() < 1e-3,
                "axpy_f32 large mismatch at {}: got {}, expected {}",
                i, a[i], expected[i]
            );
        }
    }

    #[test]
    fn test_axpy_f32_large_exact_multiple() {
        let n = 128;
        let mut a = vec![1.0f32; n];
        let b = vec![2.0f32; n];
        let alpha = 3.0f32;

        axpy_f32(&mut a, alpha, &b);

        for &val in &a {
            assert!(
                (val - 7.0).abs() < 1e-6,
                "axpy exact: got {}, expected 7.0",
                val
            );
        }
    }

    #[test]
    fn test_negate_f32_large() {
        let n = 256;
        let original: Vec<f32> = (0..n).map(|i| (i as f32) - 128.0).collect();
        let mut a = original.clone();

        negate_f32(&mut a);

        for i in 0..n {
            assert!(
                (a[i] + original[i]).abs() < 1e-6,
                "negate_f32 large: a[{}] = {}, expected {}",
                i, a[i], -original[i]
            );
        }
    }

    #[test]
    fn test_negate_f32_large_exact_multiple() {
        let n = 128;
        let mut a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let expected: Vec<f32> = a.iter().map(|&x| -x).collect();

        negate_f32(&mut a);

        for i in 0..n {
            assert!(
                (a[i] - expected[i]).abs() < 1e-6,
                "negate_f32 exact: a[{}] = {}, expected {}",
                i, a[i], expected[i]
            );
        }
    }

    #[test]
    fn test_apply_gradient_weights_f32_large() {
        let n = 256;
        let mx: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.001).collect();
        let my: Vec<f32> = (0..n).map(|i| 0.3 + (i as f32) * 0.002).collect();
        let mz: Vec<f32> = (0..n).map(|i| 0.8 - (i as f32) * 0.001).collect();
        let p: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.01).collect();
        let gx: Vec<f32> = (0..n).map(|i| (i as f32 * 0.1).sin()).collect();
        let gy: Vec<f32> = (0..n).map(|i| (i as f32 * 0.2).cos()).collect();
        let gz: Vec<f32> = (0..n).map(|i| (i as f32 * 0.05).sin()).collect();

        let mut out_x = vec![0.0f32; n];
        let mut out_y = vec![0.0f32; n];
        let mut out_z = vec![0.0f32; n];

        apply_gradient_weights_f32(
            &mut out_x, &mut out_y, &mut out_z,
            &mx, &my, &mz, &p, &gx, &gy, &gz,
        );

        for i in 0..n {
            let ex = mx[i] * p[i] * mx[i] * gx[i];
            let ey = my[i] * p[i] * my[i] * gy[i];
            let ez = mz[i] * p[i] * mz[i] * gz[i];

            assert!(
                (out_x[i] - ex).abs() < 1e-4,
                "gradient_weights_f32 x[{}]: got {}, expected {}",
                i, out_x[i], ex
            );
            assert!(
                (out_y[i] - ey).abs() < 1e-4,
                "gradient_weights_f32 y[{}]: got {}, expected {}",
                i, out_y[i], ey
            );
            assert!(
                (out_z[i] - ez).abs() < 1e-4,
                "gradient_weights_f32 z[{}]: got {}, expected {}",
                i, out_z[i], ez
            );
        }
    }

    #[test]
    fn test_scale_f32_large() {
        let n = 256;
        let mut a: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
        let alpha = 3.14f32;

        let expected: Vec<f32> = a.iter().map(|&x| x * alpha).collect();
        scale_f32(&mut a, alpha);

        for i in 0..n {
            assert!(
                (a[i] - expected[i]).abs() < 1e-3,
                "scale_f32 large: a[{}] = {}, expected {}",
                i, a[i], expected[i]
            );
        }
    }

    #[test]
    fn test_subtract_f32_large() {
        let n = 256;
        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 2.0).collect();
        let b: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
        let mut out = vec![0.0f32; n];

        let expected: Vec<f32> = a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect();
        subtract_f32(&mut out, &a, &b);

        for i in 0..n {
            assert!(
                (out[i] - expected[i]).abs() < 1e-3,
                "subtract_f32 large: out[{}] = {}, expected {}",
                i, out[i], expected[i]
            );
        }
    }

    #[test]
    fn test_xpby_f32_large() {
        let n = 256;
        let mut a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.05).collect();
        let beta = 0.7f32;

        let expected: Vec<f32> = a.iter().zip(b.iter()).map(|(&ai, &bi)| bi + beta * ai).collect();
        xpby_f32(&mut a, &b, beta);

        for i in 0..n {
            assert!(
                (a[i] - expected[i]).abs() < 1e-3,
                "xpby_f32 large: a[{}] = {}, expected {}",
                i, a[i], expected[i]
            );
        }
    }

    #[test]
    fn test_combine_terms_f32_large() {
        let n = 256;
        let reg: Vec<f32> = (0..n).map(|i| (i as f32) * 0.01).collect();
        let data: Vec<f32> = (0..n).map(|i| 1.0 - (i as f32) * 0.005).collect();
        let mut out = vec![0.0f32; n];
        let lambda = 2.5f32;

        let expected: Vec<f32> = reg.iter().zip(data.iter())
            .map(|(&r, &d)| lambda * r + d).collect();
        combine_terms_f32(&mut out, &reg, &data, lambda);

        for i in 0..n {
            assert!(
                (out[i] - expected[i]).abs() < 1e-3,
                "combine_terms_f32 large: out[{}] = {}, expected {}",
                i, out[i], expected[i]
            );
        }
    }

    #[test]
    fn test_compute_p_weights_f32_large() {
        let n = 256;
        let mx: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.001).collect();
        let my: Vec<f32> = (0..n).map(|i| 0.3 + (i as f32) * 0.002).collect();
        let mz: Vec<f32> = (0..n).map(|i| 0.8 - (i as f32) * 0.001).collect();
        let gx: Vec<f32> = (0..n).map(|i| (i as f32 * 0.1).sin()).collect();
        let gy: Vec<f32> = (0..n).map(|i| (i as f32 * 0.2).cos()).collect();
        let gz: Vec<f32> = (0..n).map(|i| (i as f32 * 0.05).sin()).collect();
        let beta = 1e-4f32;

        let mut p = vec![0.0f32; n];
        compute_p_weights_f32(&mut p, &mx, &my, &mz, &gx, &gy, &gz, beta);

        for i in 0..n {
            let ux = mx[i] * gx[i];
            let uy = my[i] * gy[i];
            let uz = mz[i] * gz[i];
            let expected = 1.0 / (ux * ux + uy * uy + uz * uz + beta).sqrt();

            assert!(
                (p[i] - expected).abs() < 1e-3,
                "compute_p_weights large: p[{}] = {}, expected {}",
                i, p[i], expected
            );
        }
    }

    #[test]
    fn test_dot_product_f32_with_remainder() {
        let n = 131; // not a multiple of 4, so the SIMD build exercises its scalar tail
        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..n).map(|i| 1.0 - (i as f32) * 0.005).collect();

        let result = dot_product_f32(&a, &b);
        let expected: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();

        assert!(
            (result - expected).abs() < 1.0,
            "dot_product_f32 remainder: got {}, expected {}",
            result, expected
        );
    }

    #[test]
    fn test_norm_squared_f32_with_remainder() {
        let n = 131;
        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();

        let result = norm_squared_f32(&a);
        let expected: f32 = a.iter().map(|&x| x * x).sum();

        assert!(
            (result - expected).abs() < 1.0,
            "norm_squared_f32 remainder: got {}, expected {}",
            result, expected
        );
    }
}