/// Builds the stable plane (Givens) rotation `(c, s, r)` that maps `(a, b)`
/// onto `(r, 0)`, i.e. `c*a + s*b = r` and `-s*a + c*b = 0` with `r >= 0`.
///
/// The radius uses `f64::hypot`, so no intermediate overflow/underflow occurs.
/// When both inputs are zero the identity rotation `(1, 0, 0)` is returned.
fn sym_ortho(a: f64, b: f64) -> (f64, f64, f64) {
    let r = a.hypot(b);
    if r == 0.0 {
        (1.0, 0.0, 0.0)
    } else {
        (a / r, b / r, r)
    }
}

/// Solves the (optionally damped) least-squares problem
///
/// ```text
/// min_x  ||A x - b||² + lambda² ||x||²
/// ```
///
/// with LSMR (Fong & Saunders), touching `A` only through matrix-vector
/// products.
///
/// # Arguments
/// * `a_op`     - computes `A * x` for `x` of length `n`; must return a vector
///   of length `b.len()`.
/// * `at_op`    - computes `Aᵀ * y` for `y` of length `b.len()`. The length of
///   its first result defines `n`, the length of the returned solution.
/// * `b`        - right-hand side.
/// * `lambda`   - damping (Tikhonov regularization) parameter; `0.0` gives an
///   ordinary least-squares solve.
/// * `tol`      - absolute tolerance. Iteration stops once `|zeta_bar|`, which
///   LSMR maintains as `||Aᵀ(b - A x) - lambda² x||`, drops below `tol`.
/// * `max_iter` - maximum number of bidiagonalization steps.
///
/// # Returns
/// The current iterate `x` (length `n`). This is the zero vector when
/// `Aᵀ b = 0`, since zero then solves the (damped) normal equations.
///
/// # Panics
/// Panics if `a_op` does not return `b.len()` values, or if `at_op` returns
/// vectors of inconsistent length.
pub fn lsmr_solve<F, G>(
    a_op: F,
    at_op: G,
    b: &[f64],
    lambda: f64,
    tol: f64,
    max_iter: usize,
) -> Vec<f64>
where
    F: Fn(&[f64]) -> Vec<f64>,
    G: Fn(&[f64]) -> Vec<f64>,
{
    let m = b.len();

    // Start the Golub-Kahan bidiagonalization:
    //   beta_1 u_1 = b,   alpha_1 v_1 = Aᵀ u_1.
    let mut u = b.to_vec();
    let mut beta = norm(&u);
    if beta > 0.0 {
        u.iter_mut().for_each(|ui| *ui /= beta);
    }

    let mut v = at_op(&u);
    let n = v.len();
    let mut alpha = norm(&v);
    if alpha > 0.0 {
        v.iter_mut().for_each(|vi| *vi /= alpha);
    }

    let mut x = vec![0.0; n];
    if alpha == 0.0 || beta == 0.0 {
        // Aᵀ b = 0: x = 0 already satisfies the (damped) normal equations.
        return x;
    }

    let mut h = v.clone();
    let mut h_bar = vec![0.0; n];

    let mut alpha_bar = alpha;
    let mut zeta_bar = alpha * beta;
    let mut rho = 1.0;
    let mut rho_bar = 1.0;
    let mut c_bar = 1.0;
    let mut s_bar = 0.0;

    for _ in 0..max_iter {
        // Continue the bidiagonalization. `alpha`/`beta` are loop-carried so
        // each step uses the previous step's values:
        //   beta_{k+1} u_{k+1}  = A v_k       - alpha_k u_k
        //   alpha_{k+1} v_{k+1} = Aᵀ u_{k+1} - beta_{k+1} v_k
        let av = a_op(&v);
        assert_eq!(av.len(), m, "a_op must return a vector of length b.len()");
        for (ui, &avi) in u.iter_mut().zip(&av) {
            *ui = avi - alpha * *ui;
        }
        beta = norm(&u);
        if beta > 0.0 {
            u.iter_mut().for_each(|ui| *ui /= beta);
            let atu = at_op(&u);
            assert_eq!(atu.len(), n, "at_op must return vectors of a consistent length");
            for (vi, &atui) in v.iter_mut().zip(&atu) {
                *vi = atui - beta * *vi;
            }
            alpha = norm(&v);
            if alpha > 0.0 {
                v.iter_mut().for_each(|vi| *vi /= alpha);
            }
        }

        // Fold the damping term into the current diagonal entry.
        let (_, _, alpha_hat) = sym_ortho(alpha_bar, lambda);

        // Rotation Q_k: turn the lower-bidiagonal B_k into upper-bidiagonal R_k.
        let rho_old = rho;
        let (c, s, rho_k) = sym_ortho(alpha_hat, beta);
        rho = rho_k;
        if rho == 0.0 {
            // Only reachable when alpha_bar, lambda and beta are all zero.
            break;
        }
        let theta_new = s * alpha;
        alpha_bar = c * alpha;

        // Rotation Qbar_k: turn R_kᵀ into upper-bidiagonal Rbar_k.
        let rho_bar_old = rho_bar;
        let theta_bar = s_bar * rho;
        let (c_bar_k, s_bar_k, rho_bar_k) = sym_ortho(c_bar * rho, theta_new);
        c_bar = c_bar_k;
        s_bar = s_bar_k;
        rho_bar = rho_bar_k;
        if rho_bar == 0.0 {
            break;
        }

        let zeta = c_bar * zeta_bar;
        zeta_bar = -s_bar * zeta_bar;

        // Update the search directions and the iterate. The h_bar recurrence
        // needs the *previous* rho and rho_bar.
        let h_bar_scale = theta_bar * rho / (rho_old * rho_bar_old);
        for (hb, &hi) in h_bar.iter_mut().zip(&h) {
            *hb = hi - h_bar_scale * *hb;
        }

        let x_scale = zeta / (rho * rho_bar);
        for (xi, &hb) in x.iter_mut().zip(&h_bar) {
            *xi += x_scale * hb;
        }

        let h_scale = theta_new / rho;
        for (hi, &vi) in h.iter_mut().zip(&v) {
            *hi = vi - h_scale * *hi;
        }

        // |zeta_bar| == ||Aᵀ(b - A x) - lambda² x|| for the updated x.
        if zeta_bar.abs() < tol {
            break;
        }
    }

    x
}

/// Euclidean (L2) norm of `v`.
fn norm(v: &[f64]) -> f64 {
    v.iter().map(|&x| x * x).sum::<f64>().sqrt()
}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` and `expected` have the same length and agree
    /// element-wise to within `tol`.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "output length mismatch");
        for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "x[{}] = {} differs from expected {}",
                i,
                a,
                e
            );
        }
    }

    #[test]
    fn test_lsmr_identity() {
        // With A = I and no damping the least-squares solution is b itself.
        let b = vec![1.0, 2.0, 3.0];

        let x = lsmr_solve(|v| v.to_vec(), |v| v.to_vec(), &b, 0.0, 1e-10, 100);

        assert_close(&x, &b, 1e-6);
    }

    #[test]
    fn test_lsmr_diagonal() {
        // A = diag(1, 2, 3) is square and nonsingular, so A x = b has the
        // exact solution x_i = b_i / d_i = (1, 2, 3).
        let diag = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 4.0, 9.0];

        let a_op = |x: &[f64]| -> Vec<f64> {
            x.iter().zip(&diag).map(|(&xi, &di)| xi * di).collect()
        };
        // A diagonal matrix is symmetric, so Aᵀ applies the same scaling.
        let at_op = |y: &[f64]| -> Vec<f64> {
            y.iter().zip(&diag).map(|(&yi, &di)| yi * di).collect()
        };

        let x = lsmr_solve(a_op, at_op, &b, 0.0, 1e-10, 200);

        assert_close(&x, &[1.0, 2.0, 3.0], 1e-6);
    }

    #[test]
    fn test_lsmr_overdetermined() {
        // A = [I₂; I₂] observes each unknown twice. The least-squares solution
        // averages the paired observations: ((2 + 4) / 2, (6 + 8) / 2) = (3, 7).
        let a_op = |x: &[f64]| -> Vec<f64> { vec![x[0], x[1], x[0], x[1]] };
        let at_op = |y: &[f64]| -> Vec<f64> { vec![y[0] + y[2], y[1] + y[3]] };
        let b = vec![2.0, 6.0, 4.0, 8.0];

        let x = lsmr_solve(a_op, at_op, &b, 0.0, 1e-10, 200);

        assert_close(&x, &[3.0, 7.0], 1e-6);
    }

    #[test]
    fn test_lsmr_regularized() {
        // With A = I, min ||x - b||² + λ²||x||² has the closed form
        // x = b / (1 + λ²).
        let b = vec![10.0, 20.0, 30.0];
        let lambda = 1.0;

        let x = lsmr_solve(|v| v.to_vec(), |v| v.to_vec(), &b, lambda, 1e-10, 200);

        let expected: Vec<f64> = b
            .iter()
            .map(|&bi| bi / (1.0 + lambda * lambda))
            .collect();
        assert_close(&x, &expected, 1e-6);
    }

    #[test]
    fn test_lsmr_zero_rhs() {
        // b = 0 means x = 0 is optimal; the length still comes from at_op.
        let b = vec![0.0; 3];

        let x = lsmr_solve(|v| v.to_vec(), |v| v.to_vec(), &b, 0.0, 1e-10, 10);

        assert_close(&x, &[0.0, 0.0, 0.0], 1e-12);
    }
}