ndelement/
math.rs

1//! Mathematical functions.
2use rlst::{Array, MutableArrayImpl, RlstScalar, ValueArrayImpl};
3
4/// Orthogonalise the rows of a matrix, starting with the row numbered `start`
5pub(crate) fn orthogonalise_3<T: RlstScalar, Array3MutImpl: MutableArrayImpl<T, 3>>(
6    mat: &mut Array<Array3MutImpl, 3>,
7    start: usize,
8) {
9    for row in start..mat.shape()[0] {
10        let norm = (0..mat.shape()[1])
11            .map(|i| {
12                (0..mat.shape()[2])
13                    .map(|j| mat.get([row, i, j]).unwrap().powi(2))
14                    .sum::<T>()
15            })
16            .sum::<T>()
17            .sqrt();
18        for i in 0..mat.shape()[1] {
19            for j in 0..mat.shape()[2] {
20                *mat.get_mut([row, i, j]).unwrap() /= norm;
21            }
22        }
23        for r in row + 1..mat.shape()[0] {
24            let dot = (0..mat.shape()[1])
25                .map(|i| {
26                    (0..mat.shape()[2])
27                        .map(|j| *mat.get([row, i, j]).unwrap() * *mat.get([r, i, j]).unwrap())
28                        .sum::<T>()
29                })
30                .sum::<T>();
31            for i in 0..mat.shape()[1] {
32                for j in 0..mat.shape()[2] {
33                    let sub = dot * *mat.get([row, i, j]).unwrap();
34                    *mat.get_mut([r, i, j]).unwrap() -= sub;
35                }
36            }
37        }
38    }
39}
40
41/// Swap two entries in a matrix
42unsafe fn entry_swap<const N: usize, T: RlstScalar, ArrayMut: MutableArrayImpl<T, N>>(
43    mat: &mut Array<ArrayMut, N>,
44    mindex0: [usize; N],
45    mindex1: [usize; N],
46) {
47    unsafe {
48        let value = *mat.get_unchecked(mindex0);
49        *mat.get_unchecked_mut(mindex0) = *mat.get_unchecked(mindex1);
50        *mat.get_unchecked_mut(mindex1) = value;
51    }
52}
53
54/// Compute the LU decomposition of the transpose of a square matrix
55pub fn lu_transpose<T: RlstScalar, Array2MutImpl: MutableArrayImpl<T, 2>>(
56    mat: &mut Array<Array2MutImpl, 2>,
57) -> Vec<usize> {
58    let dim = mat.shape()[0];
59    assert_eq!(mat.shape()[1], dim);
60    let mut perm = (0..dim).collect::<Vec<_>>();
61    if dim > 0 {
62        for i in 0..dim - 1 {
63            let mut max_col = i;
64            let mut max_value = unsafe { mat.get_unchecked([i, i]).abs() };
65            for j in i + 1..dim {
66                let value = unsafe { mat.get_unchecked([i, j]).abs() };
67                if value > max_value {
68                    max_col = j;
69                    max_value = value;
70                }
71            }
72            for j in 0..dim {
73                unsafe {
74                    entry_swap(mat, [j, i], [j, max_col]);
75                }
76            }
77            perm.swap(i, max_col);
78
79            let diag = unsafe { *mat.get_unchecked([i, i]) };
80            for j in i + 1..dim {
81                unsafe {
82                    *mat.get_unchecked_mut([i, j]) /= diag;
83                }
84                for k in i + 1..dim {
85                    unsafe {
86                        let sub = *mat.get_unchecked([i, j]) * *mat.get_unchecked([k, i]);
87                        *mat.get_unchecked_mut([k, j]) -= sub;
88                    }
89                }
90            }
91        }
92    }
93    perm
94}
95
/// Convert a permutation in place into the swap-sequence format expected by `apply_permutation`.
///
/// The input `perm[i]` is the index of the entry that should end up in position `i`. After
/// conversion, `apply_permutation(perm, data)` leaves `data[i]` equal to the original
/// `data[perm[i]]`.
///
/// Each converted entry `perm[i]` is at least `i`. It names the position that position `i`
/// is swapped with when the swaps are performed in order.
pub fn prepare_permutation(perm: &mut [usize]) {
    for position in 0..perm.len() {
        // Positions before `position` have already been swapped. Follow earlier swaps until
        // we reach an index that has not yet been processed.
        let mut source = perm[position];
        while source < position {
            source = perm[source];
        }
        perm[position] = source;
    }
}
104
/// Apply a permutation, in the format produced by `prepare_permutation`, to some data in place.
///
/// `data` is split into `perm.len()` consecutive blocks of equal size. For `i = 0, 1, ...`
/// in order, block `i` is swapped with block `perm[i]`. An empty `perm` leaves `data`
/// untouched.
///
/// # Panics
///
/// Panics if any `perm[i]` is not below `perm.len()`. In debug builds, also panics if
/// `data.len()` is not a multiple of `perm.len()`; for an empty `perm`, this requires
/// `data` to be empty too.
pub fn apply_permutation<T>(perm: &[usize], data: &mut [T]) {
    debug_assert!(data.len().is_multiple_of(perm.len()));
    if perm.is_empty() {
        // Nothing to permute. Returning here also avoids dividing by zero below.
        return;
    }
    let block_size = data.len() / perm.len();
    for (i, &j) in perm.iter().enumerate() {
        for k in 0..block_size {
            data.swap(i * block_size + k, j * block_size + k);
        }
    }
}
115
116/// Convert a linear transformation info the format expected by `apply_matrix` and return the premutation to pass into `apply_matrix`
117pub fn prepare_matrix<T: RlstScalar, Array2Mut: MutableArrayImpl<T, 2>>(
118    mat: &mut Array<Array2Mut, 2>,
119) -> Vec<usize> {
120    let mut perm = lu_transpose(mat);
121    prepare_permutation(&mut perm);
122    perm
123}
124
125/// Apply a permutation and a matrix to some data
126pub fn apply_perm_and_matrix<T: RlstScalar, Array2Impl: ValueArrayImpl<T, 2>>(
127    mat: &Array<Array2Impl, 2>,
128    perm: &[usize],
129    data: &mut [T],
130) {
131    apply_permutation(perm, data);
132    apply_matrix(mat, data);
133}
134
135/// Apply a matrix to some data
136pub fn apply_matrix<T: RlstScalar, Array2Impl: ValueArrayImpl<T, 2>>(
137    mat: &Array<Array2Impl, 2>,
138    data: &mut [T],
139) {
140    let dim = mat.shape()[0];
141    debug_assert!(data.len().is_multiple_of(dim));
142    let block_size = data.len() / dim;
143    for i in 0..dim {
144        for j in i + 1..dim {
145            for k in 0..block_size {
146                data[i * block_size + k] +=
147                    mat.get_value([i, j]).unwrap() * data[j * block_size + k];
148            }
149        }
150    }
151    for i in 1..=dim {
152        for k in 0..block_size {
153            data[(dim - i) * block_size + k] *= mat.get_value([dim - i, dim - i]).unwrap();
154        }
155        for j in 0..dim - i {
156            for k in 0..block_size {
157                data[(dim - i) * block_size + k] +=
158                    mat.get_value([dim - i, j]).unwrap() * data[j * block_size + k];
159            }
160        }
161    }
162}
163
#[cfg(test)]
mod test {
    use super::*;
    use approx::*;
    use itertools::izip;
    use rlst::rlst_dynamic_array;

    /// Check `actual` against `expected` entry by entry, with a relative tolerance.
    fn assert_all_relative_eq(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in izip!(actual, expected) {
            assert_relative_eq!(*a, *e);
        }
    }

    #[test]
    fn test_permutation() {
        let perm = vec![1, 4, 3, 0, 6, 5, 2];
        let data = vec![9, 4, 1, 5, 3, 2, 10];

        let mut prepared = perm.clone();
        prepare_permutation(&mut prepared);

        // Block size 1: position `i` must receive the original entry `perm[i]`.
        let mut permuted = data.clone();
        apply_permutation(&prepared, &mut permuted);
        for (value, p) in izip!(&permuted, &perm) {
            assert_eq!(*value, data[*p]);
        }

        // Block size 3: each group of three consecutive entries moves as a unit.
        let labels = (0..21).map(|i| format!("{i}")).collect::<Vec<_>>();
        let mut permuted_labels = labels.clone();
        apply_permutation(&prepared, &mut permuted_labels);
        for (block, p) in izip!(permuted_labels.chunks(3), &perm) {
            assert_eq!(block, &labels[3 * p..3 * p + 3]);
        }
    }

    #[test]
    fn test_matrix_2by2() {
        let entries = [[0.5, 1.5], [1.0, 1.0]];
        let mut matrix = rlst_dynamic_array!(f64, [2, 2]);
        for (i, row) in entries.iter().enumerate() {
            for (j, value) in row.iter().enumerate() {
                matrix[[i, j]] = *value;
            }
        }

        let perm = prepare_matrix(&mut matrix);
        assert_eq!(perm, [1, 1]);

        let factored = [[1.5, 1.0 / 3.0], [1.0, 2.0 / 3.0]];
        for (i, row) in factored.iter().enumerate() {
            for (j, value) in row.iter().enumerate() {
                assert_relative_eq!(*matrix.get([i, j]).unwrap(), *value);
            }
        }

        // One entry per block.
        let mut data = vec![1.0, 2.0];
        apply_perm_and_matrix(&matrix, &perm, &mut data);
        assert_all_relative_eq(&data, &[3.5, 3.0]);

        // Two entries per block.
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        apply_perm_and_matrix(&matrix, &perm, &mut data);
        assert_all_relative_eq(&data, &[5.0, 7.0, 4.0, 6.0]);
    }

    #[test]
    fn test_matrix_3by3() {
        let entries = [[0.5, 1.5, 1.0], [1.0, 1.0, 1.0], [0.5, 1.0, 0.5]];
        let mut matrix = rlst_dynamic_array!(f64, [3, 3]);
        for (i, row) in entries.iter().enumerate() {
            for (j, value) in row.iter().enumerate() {
                matrix[[i, j]] = *value;
            }
        }

        let perm = prepare_matrix(&mut matrix);
        assert_eq!(perm, [1, 1, 2]);

        let factored = [
            [1.5, 1.0 / 3.0, 2.0 / 3.0],
            [1.0, 2.0 / 3.0, 0.5],
            [1.0, 1.0 / 6.0, -0.25],
        ];
        for (i, row) in factored.iter().enumerate() {
            for (j, value) in row.iter().enumerate() {
                assert_relative_eq!(*matrix.get([i, j]).unwrap(), *value);
            }
        }

        // One entry per block.
        let mut data = vec![1.0, 2.0, 3.0];
        apply_perm_and_matrix(&matrix, &perm, &mut data);
        assert_all_relative_eq(&data, &[6.5, 6.0, 4.0]);

        // Two entries per block.
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        apply_perm_and_matrix(&matrix, &perm, &mut data);
        assert_all_relative_eq(&data, &[10.0, 13.0, 9.0, 12.0, 6.0, 8.0]);
    }
}