/// Solves `A x = b` with the unpreconditioned conjugate gradient (CG) method.
///
/// `A` is never formed explicitly: `a_op(v)` must return the product `A v` for any
/// vector `v` of length `b.len()`. CG's convergence guarantees only hold when `A` is
/// symmetric positive definite; for other operators the result is not meaningful.
///
/// # Arguments
///
/// * `a_op` - Matrix-free operator returning `A v`.
/// * `b` - Right-hand side.
/// * `x0` - Initial guess; must have the same length as `b`.
/// * `tol` - Relative stopping tolerance. Iteration stops once
///   `‖b − A x‖₂ < tol · ‖b‖₂`. If `b` is the zero vector the test is absolute instead
///   (`‖b − A x‖₂ < tol`), since a relative test against `0` could never succeed.
/// * `max_iter` - Upper bound on the number of CG iterations.
///
/// # Returns
///
/// The final iterate `x`. No convergence flag is reported: the result is returned as-is
/// whether the tolerance was met, `max_iter` was exhausted, or the iteration broke down.
///
/// # Panics
///
/// Panics if `x0.len() != b.len()`, or if `a_op` returns a vector whose length is not
/// `b.len()`.
pub fn cg_solve<F>(a_op: F, b: &[f64], x0: &[f64], tol: f64, max_iter: usize) -> Vec<f64>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    /// Euclidean inner product of two equal-length slices (summed left to right).
    fn dot(u: &[f64], v: &[f64]) -> f64 {
        u.iter().zip(v).map(|(ui, vi)| ui * vi).sum()
    }

    let n = b.len();
    assert_eq!(x0.len(), n, "x0 must have the same length as b");

    // Wrap the operator so a wrong-length result fails loudly instead of being
    // silently truncated by `zip` further down.
    let apply = |v: &[f64]| -> Vec<f64> {
        let out = a_op(v);
        assert_eq!(out.len(), n, "a_op must return a vector of the same length as its input");
        out
    };

    let mut x = x0.to_vec();

    // Initial residual r = b - A x0; the first search direction is the residual itself.
    let mut r: Vec<f64> = b
        .iter()
        .zip(apply(&x))
        .map(|(&bi, axi)| bi - axi)
        .collect();
    let mut p = r.clone();
    let mut rs_old = dot(&r, &r);

    let b_norm = dot(b, b).sqrt();
    let threshold = if b_norm > 0.0 { tol * b_norm } else { tol };

    // The initial guess may already satisfy the tolerance.
    if rs_old.sqrt() < threshold {
        return x;
    }

    for _ in 0..max_iter {
        let ap = apply(&p);
        let p_ap = dot(&p, &ap);

        // Breakdown test: p is (numerically) A-orthogonal to itself, so the step length
        // would be infinite or undefined. Comparing against ‖p‖·‖Ap‖ rather than an
        // absolute epsilon keeps the test scale-invariant, so systems with tiny (but
        // well-posed) magnitudes are not mistaken for breakdown. An exact zero (e.g.
        // p == 0) always trips it, which prevents a division by zero below.
        let p_norm = dot(&p, &p).sqrt();
        let ap_norm = dot(&ap, &ap).sqrt();
        if p_ap.abs() <= f64::EPSILON * p_norm * ap_norm {
            break;
        }

        let alpha = rs_old / p_ap;

        for (xi, &pi) in x.iter_mut().zip(&p) {
            *xi += alpha * pi;
        }
        for (ri, &api) in r.iter_mut().zip(&ap) {
            *ri -= alpha * api;
        }

        let rs_new = dot(&r, &r);
        if rs_new.sqrt() < threshold {
            break;
        }

        // Fletcher–Reeves style update of the search direction.
        let beta = rs_new / rs_old;
        for (pi, &ri) in p.iter_mut().zip(&r) {
            *pi = ri + beta * *pi;
        }

        rs_old = rs_new;
    }

    x
}
84
#[cfg(test)]
mod tests {
    use super::*;

    /// With `A = I`, CG must reproduce the right-hand side.
    #[test]
    fn test_cg_identity() {
        let rhs: Vec<f64> = vec![1.0, 2.0, 3.0];
        let start: Vec<f64> = vec![0.0; 3];

        let solution = cg_solve(|v| v.to_vec(), &rhs, &start, 1e-10, 100);

        solution
            .iter()
            .zip(&rhs)
            .for_each(|(got, want)| assert!((got - want).abs() < 1e-8, "x should equal b"));
    }

    /// With a diagonal operator the exact answer is `b[i] / d[i]`.
    #[test]
    fn test_cg_diagonal() {
        let diag: [f64; 3] = [2.0, 3.0, 4.0];
        let rhs: Vec<f64> = vec![2.0, 6.0, 12.0];
        let start: Vec<f64> = vec![0.0; 3];

        // Matrix-free application of diag(2, 3, 4).
        let scale = |v: &[f64]| -> Vec<f64> { v.iter().zip(&diag).map(|(vi, di)| vi * di).collect() };
        let solution = cg_solve(scale, &rhs, &start, 1e-10, 100);

        let expected: [f64; 3] = [1.0, 2.0, 3.0];
        for (&got, &want) in solution.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-8, "Expected {}, got {}", want, got);
        }
    }
}