Skip to main content

qsm_core/solvers/
cg.rs

1//! Conjugate Gradient solver
2//!
3//! Solves Ax = b for symmetric positive definite A.
4
/// Conjugate gradient solver
///
/// Solves Ax = b where A is a linear operator represented by a closure
/// (matrix-free: `a_op(v)` must return A*v). Convergence is only guaranteed
/// when A is symmetric positive definite.
///
/// Iteration stops as soon as one of these happens:
/// * the residual satisfies `||b - A*x|| <= tol * ||b||`;
/// * `max_iter` iterations have run;
/// * the curvature `p·Ap` is exactly zero or not finite, so no step can be taken.
///
/// # Arguments
/// * `a_op` - Closure that computes A*x; must return a vector of length `b.len()`
/// * `b` - Right-hand side vector
/// * `x0` - Initial guess, same length as `b`
/// * `tol` - Relative convergence tolerance on the residual norm
/// * `max_iter` - Maximum iterations
///
/// # Returns
/// Solution vector x. If `b` is the zero vector, the zero vector is returned
/// without applying the operator; for SPD A it is the unique solution. If
/// `x0` already satisfies the tolerance, it is returned unchanged.
///
/// # Panics
/// Panics if `x0.len() != b.len()`, or if `a_op` returns a vector whose length
/// differs from `b.len()`.
pub fn cg_solve<F>(
    a_op: F,
    b: &[f64],
    x0: &[f64],
    tol: f64,
    max_iter: usize,
) -> Vec<f64>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    /// Euclidean inner product of two equal-length slices.
    fn dot(u: &[f64], v: &[f64]) -> f64 {
        u.iter().zip(v).map(|(&ui, &vi)| ui * vi).sum()
    }

    let n = b.len();
    assert_eq!(x0.len(), n, "cg_solve: x0 must have the same length as b");

    let b_norm = dot(b, b).sqrt();
    // With b = 0 the relative criterion `||r|| <= tol * 0` could only be met by an
    // exactly-zero residual, which round-off rarely produces. For SPD A the
    // solution of Ax = 0 is x = 0, so return it directly.
    if b_norm == 0.0 {
        return vec![0.0; n];
    }
    let threshold = tol * b_norm;

    // Apply the operator and enforce the length contract. Without this check,
    // the iterator zips below would silently truncate a wrong-sized result.
    let apply = |v: &[f64]| -> Vec<f64> {
        let out = a_op(v);
        assert_eq!(
            out.len(),
            n,
            "cg_solve: a_op must return a vector of length b.len()"
        );
        out
    };

    let mut x = x0.to_vec();

    // r = b - A*x
    let mut r: Vec<f64> = b
        .iter()
        .zip(&apply(&x))
        .map(|(&bi, &axi)| bi - axi)
        .collect();
    let mut rsold = dot(&r, &r);

    // The initial guess may already be accurate enough; don't spend an
    // operator application to find that out.
    if rsold.sqrt() <= threshold {
        return x;
    }

    let mut p = r.clone();

    for _iter in 0..max_iter {
        let ap = apply(&p);
        let pap = dot(&p, &ap);

        // Breakdown guard. An absolute cutoff would depend on scale: for a
        // small-magnitude b (say ~1e-11), p·Ap is legitimately below 1e-20.
        // Only a zero or non-finite curvature leaves the step length undefined.
        if pap == 0.0 || !pap.is_finite() {
            break;
        }

        let alpha = rsold / pap;

        // x = x + alpha * p
        for (xi, &pi) in x.iter_mut().zip(&p) {
            *xi += alpha * pi;
        }

        // r = r - alpha * A*p
        for (ri, &api) in r.iter_mut().zip(&ap) {
            *ri -= alpha * api;
        }

        let rsnew = dot(&r, &r);

        // Check convergence (relative to ||b||)
        if rsnew.sqrt() <= threshold {
            break;
        }

        let beta = rsnew / rsold;

        // p = r + beta * p
        for (pi, &ri) in p.iter_mut().zip(&r) {
            *pi = ri + beta * *pi;
        }

        rsold = rsnew;
    }

    x
}
84
#[cfg(test)]
mod tests {
    use super::*;

    /// With A = I the exact solution is the right-hand side itself.
    #[test]
    fn test_cg_identity() {
        let rhs = vec![1.0, 2.0, 3.0];
        let start = vec![0.0; 3];
        let identity = |v: &[f64]| v.to_vec();

        let solution = cg_solve(identity, &rhs, &start, 1e-10, 100);

        solution
            .iter()
            .zip(&rhs)
            .for_each(|(got, want)| assert!((got - want).abs() < 1e-8, "x should equal b"));
    }

    /// A = diag(2, 3, 4) and b = [2, 6, 12] give the exact solution x = [1, 2, 3].
    #[test]
    fn test_cg_diagonal() {
        let scales = [2.0, 3.0, 4.0];
        let rhs = vec![2.0, 6.0, 12.0];
        let start = vec![0.0; 3];
        let scale = |v: &[f64]| -> Vec<f64> {
            v.iter().zip(&scales).map(|(a, s)| a * s).collect()
        };

        let solution = cg_solve(scale, &rhs, &start, 1e-10, 100);

        let expected = [1.0, 2.0, 3.0];
        for (got, want) in solution.iter().zip(&expected) {
            assert!((got - want).abs() < 1e-8, "Expected {}, got {}", want, got);
        }
    }
}