1use crate::fft::{fft3d, ifft3d};
16use crate::kernels::dipole::dipole_kernel;
17use num_complex::Complex64;
18
19pub fn adjust_offset(
37 removed_voxels: &[f64],
38 lfs_sdf: &[f64],
39 chi_1: &[f64],
40 chi_2: &[f64],
41 nx: usize, ny: usize, nz: usize,
42 vsx: f64, vsy: f64, vsz: f64,
43 b0_dir: (f64, f64, f64),
44 ppm: f64,
45) -> Vec<f64> {
46 let n_total = nx * ny * nz;
47
48 let lfs_scaled: Vec<f64> = lfs_sdf.iter().map(|&v| v / ppm).collect();
50
51 let removed_clean: Vec<f64> = removed_voxels.iter()
53 .map(|&v| if v < 0.0 { 0.0 } else { v })
54 .collect();
55
56 let chi_1_masked: Vec<f64> = chi_1.iter()
58 .zip(removed_clean.iter())
59 .map(|(&c, &r)| if r > 0.0 { c } else { 0.0 })
60 .collect();
61
62 let combined_chi: Vec<f64> = chi_1_masked.iter()
64 .zip(chi_2.iter())
65 .map(|(&c1, &c2)| c1 + c2)
66 .collect();
67
68 let d_kernel = dipole_kernel(nx, ny, nz, vsx, vsy, vsz, b0_dir);
70
71 let mut removed_complex: Vec<Complex64> = removed_clean.iter()
78 .map(|&v| Complex64::new(v, 0.0))
79 .collect();
80 let mut combined_complex: Vec<Complex64> = combined_chi.iter()
81 .map(|&v| Complex64::new(v, 0.0))
82 .collect();
83
84 fft3d(&mut removed_complex, nx, ny, nz);
86 fft3d(&mut combined_complex, nx, ny, nz);
87
88 let mut d_removed: Vec<Complex64> = removed_complex.iter()
90 .zip(d_kernel.iter())
91 .map(|(&c, &d)| c * d)
92 .collect();
93
94 let mut d_combined: Vec<Complex64> = combined_complex.iter()
95 .zip(d_kernel.iter())
96 .map(|(&c, &d)| c * d)
97 .collect();
98
99 ifft3d(&mut d_removed, nx, ny, nz);
101 ifft3d(&mut d_combined, nx, ny, nz);
102
103 let x1: Vec<f64> = d_removed.iter().map(|c| c.re).collect();
105
106 let x2: Vec<f64> = lfs_scaled.iter()
108 .zip(d_combined.iter())
109 .map(|(&lfs, c)| lfs - c.re)
110 .collect();
111
112 let x1_dot_x2: f64 = x1.iter().zip(x2.iter()).map(|(&a, &b)| a * b).sum();
114 let x1_dot_x1: f64 = x1.iter().map(|&a| a * a).sum();
115
116 let offset = if x1_dot_x1.abs() > 1e-10 {
117 x1_dot_x2 / x1_dot_x1
118 } else {
119 0.0
120 };
121
122 let adjusted: Vec<f64> = combined_chi.iter()
124 .zip(removed_clean.iter())
125 .map(|(&c, &r)| c + offset * r)
126 .collect();
127
128 adjusted
129}
130
/// Tunable parameters for the QSMART pipeline.
///
/// Only `ppm` depends on the field strength (see
/// [`QsmartParams::for_field_strength`]); every other field has a fixed default.
/// NOTE(review): the per-field descriptions below are inferred from the field
/// names — confirm against the code that consumes them.
#[derive(Clone, Debug, PartialEq)]
pub struct QsmartParams {
    /// Field-dependent scale factor, `2.675e8 * B0 / 1e6` (presumably passed as
    /// `ppm` to `adjust_offset`).
    pub ppm: f64,
    /// SDF `sigma1` for stage 1.
    pub sdf_sigma1_stage1: f64,
    /// SDF `sigma2` for stage 1.
    pub sdf_sigma2_stage1: f64,
    /// SDF `sigma1` for stage 2.
    pub sdf_sigma1_stage2: f64,
    /// SDF `sigma2` for stage 2.
    pub sdf_sigma2_stage2: f64,
    /// SDF spatial radius (presumably in voxels).
    pub sdf_spatial_radius: i32,
    /// SDF lower limit.
    pub sdf_lower_lim: f64,
    /// SDF curvature constant.
    pub sdf_curv_constant: f64,
    /// Sphere radius used for the vasculature step (presumably in voxels).
    pub vasc_sphere_radius: i32,
    /// `[min, max]` scale range for the Frangi filter.
    pub frangi_scale_range: [f64; 2],
    /// Step ratio between consecutive Frangi scales.
    pub frangi_scale_ratio: f64,
    /// Frangi `c` parameter.
    pub frangi_c: f64,
    /// Tolerance for the iLSQR solver.
    pub ilsqr_tol: f64,
    /// Maximum number of iLSQR iterations.
    pub ilsqr_max_iter: usize,
    /// Direction of the main magnetic field (default `(0, 0, 1)`).
    pub b0_dir: (f64, f64, f64),
}

impl Default for QsmartParams {
    /// Defaults for a 7 T acquisition.
    fn default() -> Self {
        Self {
            ppm: Self::ppm_for_field(Self::DEFAULT_FIELD_TESLA),
            sdf_sigma1_stage1: 10.0,
            sdf_sigma2_stage1: 10.0,
            sdf_sigma1_stage2: 8.0,
            sdf_sigma2_stage2: 2.0,
            sdf_spatial_radius: 8,
            sdf_lower_lim: 0.6,
            sdf_curv_constant: 500.0,
            vasc_sphere_radius: 8,
            frangi_scale_range: [0.5, 6.0],
            frangi_scale_ratio: 0.5,
            frangi_c: 500.0,
            ilsqr_tol: 0.01,
            ilsqr_max_iter: 50,
            b0_dir: (0.0, 0.0, 1.0),
        }
    }
}

impl QsmartParams {
    /// Gyromagnetic ratio used to derive `ppm` (the value matches the proton's,
    /// in rad·s⁻¹·T⁻¹ — NOTE(review): confirm intended units).
    const GYROMAGNETIC_RATIO: f64 = 2.675e8;

    /// Field strength, in tesla, that `Default` is tuned for.
    const DEFAULT_FIELD_TESLA: f64 = 7.0;

    /// Single source of truth for the `ppm` formula shared by `Default` and
    /// `for_field_strength`.
    fn ppm_for_field(field_tesla: f64) -> f64 {
        Self::GYROMAGNETIC_RATIO * field_tesla / 1e6
    }

    /// Returns the default parameters with `ppm` recomputed for `field_tesla`.
    ///
    /// All other fields keep their `Default` values. A zero field gives
    /// `ppm == 0`, which makes any later division by `ppm` non-finite.
    pub fn for_field_strength(field_tesla: f64) -> Self {
        Self {
            ppm: Self::ppm_for_field(field_tesla),
            ..Default::default()
        }
    }
}
198
/// Output volumes produced by a QSMART run.
///
/// NOTE(review): field descriptions are inferred from the names — confirm
/// against the code that fills this struct.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct QsmartResult {
    /// Final combined susceptibility map.
    pub chi_qsmart: Vec<f64>,
    /// Susceptibility map from stage 1.
    pub chi_stage1: Vec<f64>,
    /// Susceptibility map from stage 2.
    pub chi_stage2: Vec<f64>,
    /// Local field shift from stage 1.
    pub lfs_stage1: Vec<f64>,
    /// Local field shift from stage 2.
    pub lfs_stage2: Vec<f64>,
    /// Vasculature mask.
    pub vasc_mask: Vec<f64>,
}
214
/// Computes `mask * r_0 - vasc_only` element-wise.
///
/// Values can be negative where `vasc_only` exceeds `mask * r_0`;
/// `adjust_offset` clamps such entries to zero.
///
/// # Panics
/// Panics if the three slices do not all have the same length (rather than
/// silently truncating to the shortest, as a bare `zip` would).
pub fn compute_removed_voxels(
    mask: &[f64],
    r_0: &[f64],
    vasc_only: &[f64],
) -> Vec<f64> {
    assert_eq!(mask.len(), r_0.len(), "mask and r_0 must have the same length");
    assert_eq!(
        mask.len(),
        vasc_only.len(),
        "mask and vasc_only must have the same length"
    );
    mask.iter()
        .zip(r_0)
        .zip(vasc_only)
        .map(|((&m, &r), &v)| m * r - v)
        .collect()
}
230
/// Stage-1 weighting: the element-wise product `mask * r_0`.
///
/// # Panics
/// Panics if `mask` and `r_0` differ in length (rather than silently
/// truncating, as a bare `zip` would).
pub fn compute_weighted_mask_stage1(mask: &[f64], r_0: &[f64]) -> Vec<f64> {
    assert_eq!(mask.len(), r_0.len(), "mask and r_0 must have the same length");
    mask.iter().zip(r_0).map(|(&m, &r)| m * r).collect()
}
238
/// Stage-2 weighting: the element-wise product `mask * vasc_only * r_0`.
///
/// # Panics
/// Panics if the three slices do not all have the same length (rather than
/// silently truncating, as a bare `zip` would).
pub fn compute_weighted_mask_stage2(mask: &[f64], r_0: &[f64], vasc_only: &[f64]) -> Vec<f64> {
    assert_eq!(mask.len(), r_0.len(), "mask and r_0 must have the same length");
    assert_eq!(
        mask.len(),
        vasc_only.len(),
        "mask and vasc_only must have the same length"
    );
    mask.iter()
        .zip(r_0)
        .zip(vasc_only)
        // Multiplication order kept as (m * v) * r to match prior results exactly.
        .map(|((&m, &r), &v)| m * v * r)
        .collect()
}
246
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_offset_adjustment_basic() {
        let n = 8;
        let n_total = n * n * n;

        let removed = vec![1.0f64; n_total];
        let lfs = vec![0.0f64; n_total];
        let chi_1 = vec![0.1f64; n_total];
        let chi_2 = vec![0.1f64; n_total];

        let result = adjust_offset(
            &removed, &lfs, &chi_1, &chi_2,
            n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1.0,
        );

        assert_eq!(result.len(), n_total);
        assert!(result.iter().all(|v| v.is_finite()));
        // Spatially uniform inputs must yield a spatially uniform output.
        let first = result[0];
        assert!(result.iter().all(|&v| (v - first).abs() < 1e-9));
    }

    #[test]
    fn test_offset_adjustment_without_removed_voxels_returns_chi_2() {
        // With no (or only negative, hence clamped) removed voxels, chi_1 is
        // masked out and the fitted offset is 0, so the result is exactly chi_2.
        let n = 4;
        let n_total = n * n * n;

        let removed: Vec<f64> = (0..n_total)
            .map(|i| if i % 2 == 0 { 0.0 } else { -1.0 })
            .collect();
        let lfs: Vec<f64> = (0..n_total).map(|i| i as f64 * 0.5).collect();
        let chi_1 = vec![5.0f64; n_total];
        let chi_2: Vec<f64> = (0..n_total).map(|i| 0.25 + i as f64 * 0.01).collect();

        let result = adjust_offset(
            &removed, &lfs, &chi_1, &chi_2,
            n, n, n, 1.0, 1.0, 1.0,
            (0.0, 0.0, 1.0), 1.0,
        );

        assert_eq!(result, chi_2);
    }

    #[test]
    fn test_mask_helpers() {
        let mask = [1.0, 1.0, 0.0, 1.0];
        let r_0 = [1.0, 0.5, 1.0, 1.0];
        let vasc_only = [0.0, 1.0, 1.0, 1.0];

        assert_eq!(
            compute_removed_voxels(&mask, &r_0, &vasc_only),
            vec![1.0, -0.5, -1.0, 0.0]
        );
        assert_eq!(
            compute_weighted_mask_stage1(&mask, &r_0),
            vec![1.0, 0.5, 0.0, 1.0]
        );
        assert_eq!(
            compute_weighted_mask_stage2(&mask, &r_0, &vasc_only),
            vec![0.0, 0.5, 0.0, 1.0]
        );
    }

    #[test]
    fn test_params_for_field_strength() {
        let params_7t = QsmartParams::for_field_strength(7.0);
        let params_3t = QsmartParams::for_field_strength(3.0);
        let defaults = QsmartParams::default();

        assert!(params_7t.ppm > params_3t.ppm);
        // ppm scales linearly with the field strength.
        assert!((params_7t.ppm / params_3t.ppm - 7.0 / 3.0).abs() < 1e-12);
        // The defaults correspond to 7 T.
        assert_eq!(params_7t.ppm, defaults.ppm);
        // Fields other than ppm do not depend on the field strength.
        assert_eq!(params_3t.ilsqr_max_iter, defaults.ilsqr_max_iter);
        assert_eq!(params_3t.frangi_scale_range, defaults.frangi_scale_range);
        assert_eq!(params_3t.b0_dir, defaults.b0_dir);
    }
}