1use rlst::{Array, MutableArrayImpl, RlstScalar, ValueArrayImpl};
3
4pub(crate) fn orthogonalise_3<T: RlstScalar, Array3MutImpl: MutableArrayImpl<T, 3>>(
6 mat: &mut Array<Array3MutImpl, 3>,
7 start: usize,
8) {
9 for row in start..mat.shape()[0] {
10 let norm = (0..mat.shape()[1])
11 .map(|i| {
12 (0..mat.shape()[2])
13 .map(|j| mat.get([row, i, j]).unwrap().powi(2))
14 .sum::<T>()
15 })
16 .sum::<T>()
17 .sqrt();
18 for i in 0..mat.shape()[1] {
19 for j in 0..mat.shape()[2] {
20 *mat.get_mut([row, i, j]).unwrap() /= norm;
21 }
22 }
23 for r in row + 1..mat.shape()[0] {
24 let dot = (0..mat.shape()[1])
25 .map(|i| {
26 (0..mat.shape()[2])
27 .map(|j| *mat.get([row, i, j]).unwrap() * *mat.get([r, i, j]).unwrap())
28 .sum::<T>()
29 })
30 .sum::<T>();
31 for i in 0..mat.shape()[1] {
32 for j in 0..mat.shape()[2] {
33 let sub = dot * *mat.get([row, i, j]).unwrap();
34 *mat.get_mut([r, i, j]).unwrap() -= sub;
35 }
36 }
37 }
38 }
39}
40
41unsafe fn entry_swap<const N: usize, T: RlstScalar, ArrayMut: MutableArrayImpl<T, N>>(
43 mat: &mut Array<ArrayMut, N>,
44 mindex0: [usize; N],
45 mindex1: [usize; N],
46) {
47 unsafe {
48 let value = *mat.get_unchecked(mindex0);
49 *mat.get_unchecked_mut(mindex0) = *mat.get_unchecked(mindex1);
50 *mat.get_unchecked_mut(mindex1) = value;
51 }
52}
53
54pub fn lu_transpose<T: RlstScalar, Array2MutImpl: MutableArrayImpl<T, 2>>(
56 mat: &mut Array<Array2MutImpl, 2>,
57) -> Vec<usize> {
58 let dim = mat.shape()[0];
59 assert_eq!(mat.shape()[1], dim);
60 let mut perm = (0..dim).collect::<Vec<_>>();
61 if dim > 0 {
62 for i in 0..dim - 1 {
63 let mut max_col = i;
64 let mut max_value = unsafe { mat.get_unchecked([i, i]).abs() };
65 for j in i + 1..dim {
66 let value = unsafe { mat.get_unchecked([i, j]).abs() };
67 if value > max_value {
68 max_col = j;
69 max_value = value;
70 }
71 }
72 for j in 0..dim {
73 unsafe {
74 entry_swap(mat, [j, i], [j, max_col]);
75 }
76 }
77 perm.swap(i, max_col);
78
79 let diag = unsafe { *mat.get_unchecked([i, i]) };
80 for j in i + 1..dim {
81 unsafe {
82 *mat.get_unchecked_mut([i, j]) /= diag;
83 }
84 for k in i + 1..dim {
85 unsafe {
86 let sub = *mat.get_unchecked([i, j]) * *mat.get_unchecked([k, i]);
87 *mat.get_unchecked_mut([k, j]) -= sub;
88 }
89 }
90 }
91 }
92 }
93 perm
94}
95
/// Convert a permutation, in place, into the swap-sequence form used by
/// `apply_permutation`.
///
/// On input, `perm[i]` is the index of the item that should end up at position `i`.
/// On output, `perm[i]` is the position to swap with position `i`. Applying these
/// swaps in increasing `i` performs the original permutation. Every output entry
/// satisfies `perm[i] >= i`.
pub fn prepare_permutation(perm: &mut [usize]) {
    for row in 0..perm.len() {
        // Follow the chain through earlier positions: the item wanted here has already
        // been moved by one of the earlier swaps.
        let mut target = perm[row];
        while target < row {
            target = perm[target];
        }
        perm[row] = target;
    }
}
104
/// Apply a permutation, given in swap-sequence form, to `data` in place.
///
/// `data` is split into `perm.len()` consecutive blocks of `data.len() / perm.len()`
/// entries each. For each position `i`, in increasing order, block `i` is exchanged
/// with block `perm[i]`. This is the form produced by `prepare_permutation`.
///
/// An empty `perm` leaves `data` untouched.
///
/// # Panics
///
/// Panics if some `perm[i]` is not a valid block index while blocks are non-empty.
/// In debug builds, also panics if `data.len()` is not a multiple of `perm.len()`.
pub fn apply_permutation<T>(perm: &[usize], data: &mut [T]) {
    debug_assert!(data.len().is_multiple_of(perm.len()));
    if perm.is_empty() {
        // No blocks to permute; this also avoids dividing by zero below.
        return;
    }
    let block_size = data.len() / perm.len();
    for (i, &j) in perm.iter().enumerate() {
        if i == j {
            continue;
        }
        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        // Split so that the two blocks live in disjoint mutable slices.
        let (head, tail) = data.split_at_mut(hi * block_size);
        head[lo * block_size..(lo + 1) * block_size].swap_with_slice(&mut tail[..block_size]);
    }
}
115
116pub fn prepare_matrix<T: RlstScalar, Array2Mut: MutableArrayImpl<T, 2>>(
118 mat: &mut Array<Array2Mut, 2>,
119) -> Vec<usize> {
120 let mut perm = lu_transpose(mat);
121 prepare_permutation(&mut perm);
122 perm
123}
124
125pub fn apply_perm_and_matrix<T: RlstScalar, Array2Impl: ValueArrayImpl<T, 2>>(
127 mat: &Array<Array2Impl, 2>,
128 perm: &[usize],
129 data: &mut [T],
130) {
131 apply_permutation(perm, data);
132 apply_matrix(mat, data);
133}
134
135pub fn apply_matrix<T: RlstScalar, Array2Impl: ValueArrayImpl<T, 2>>(
137 mat: &Array<Array2Impl, 2>,
138 data: &mut [T],
139) {
140 let dim = mat.shape()[0];
141 debug_assert!(data.len().is_multiple_of(dim));
142 let block_size = data.len() / dim;
143 for i in 0..dim {
144 for j in i + 1..dim {
145 for k in 0..block_size {
146 data[i * block_size + k] +=
147 mat.get_value([i, j]).unwrap() * data[j * block_size + k];
148 }
149 }
150 }
151 for i in 1..=dim {
152 for k in 0..block_size {
153 data[(dim - i) * block_size + k] *= mat.get_value([dim - i, dim - i]).unwrap();
154 }
155 for j in 0..dim - i {
156 for k in 0..block_size {
157 data[(dim - i) * block_size + k] +=
158 mat.get_value([dim - i, j]).unwrap() * data[j * block_size + k];
159 }
160 }
161 }
162}
163
#[cfg(test)]
mod test {
    use super::*;
    use approx::*;
    use itertools::izip;
    use rlst::rlst_dynamic_array;

    /// Swap-sequence permutations move `data[perm[i]]` to position `i`, both for
    /// single-entry blocks and for blocks of three entries.
    #[test]
    fn test_permutation() {
        let perm = vec![1, 4, 3, 0, 6, 5, 2];
        let mut swaps = perm.clone();
        prepare_permutation(&mut swaps);

        // One entry per block.
        let original = vec![9, 4, 1, 5, 3, 2, 10];
        let mut permuted = original.clone();
        apply_permutation(&swaps, &mut permuted);
        for (value, &source) in izip!(&permuted, &perm) {
            assert_eq!(*value, original[source]);
        }

        // Three entries per block.
        let original = (0..21).map(|i| format!("{i}")).collect::<Vec<_>>();
        let mut permuted = original.clone();
        apply_permutation(&swaps, &mut permuted);
        for (dest_block, source) in izip!(permuted.chunks_exact(3), &perm) {
            assert_eq!(dest_block, &original[3 * source..3 * source + 3]);
        }
    }

    /// Factorise a 2x2 matrix, then check the factors and the matrix-vector products.
    #[test]
    fn test_matrix_2by2() {
        let entries = [[0.5, 1.5], [1.0, 1.0]];
        let mut matrix = rlst_dynamic_array!(f64, [2, 2]);
        for (r, row) in entries.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                matrix[[r, c]] = value;
            }
        }

        let perm = prepare_matrix(&mut matrix);
        assert_eq!(perm, [1, 1]);

        let factorised = [[1.5, 1.0 / 3.0], [1.0, 2.0 / 3.0]];
        for (r, row) in factorised.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                assert_relative_eq!(*matrix.get([r, c]).unwrap(), value);
            }
        }

        let mut data = vec![1.0, 2.0];
        apply_perm_and_matrix(&matrix, &perm, &mut data);
        for (actual, expected) in izip!(&data, &[3.5, 3.0]) {
            assert_relative_eq!(*actual, *expected);
        }

        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        apply_perm_and_matrix(&matrix, &perm, &mut data);
        for (actual, expected) in izip!(&data, &[5.0, 7.0, 4.0, 6.0]) {
            assert_relative_eq!(*actual, *expected);
        }
    }

    /// Factorise a 3x3 matrix, then check the factors and the matrix-vector products.
    #[test]
    fn test_matrix_3by3() {
        let entries = [[0.5, 1.5, 1.0], [1.0, 1.0, 1.0], [0.5, 1.0, 0.5]];
        let mut matrix = rlst_dynamic_array!(f64, [3, 3]);
        for (r, row) in entries.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                matrix[[r, c]] = value;
            }
        }

        let perm = prepare_matrix(&mut matrix);
        assert_eq!(perm, [1, 1, 2]);

        let factorised = [
            [1.5, 1.0 / 3.0, 2.0 / 3.0],
            [1.0, 2.0 / 3.0, 0.5],
            [1.0, 1.0 / 6.0, -0.25],
        ];
        for (r, row) in factorised.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                assert_relative_eq!(*matrix.get([r, c]).unwrap(), value);
            }
        }

        let mut data = vec![1.0, 2.0, 3.0];
        apply_perm_and_matrix(&matrix, &perm, &mut data);
        for (actual, expected) in izip!(&data, &[6.5, 6.0, 4.0]) {
            assert_relative_eq!(*actual, *expected);
        }

        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        apply_perm_and_matrix(&matrix, &perm, &mut data);
        for (actual, expected) in izip!(&data, &[10.0, 13.0, 9.0, 12.0, 6.0, 8.0]) {
            assert_relative_eq!(*actual, *expected);
        }
    }
}