ndelement/polynomials/
legendre.rs

1//! Orthonormal polynomials.
2//!
3//! Adapted from the C++ code by Chris Richardson and Matthew Scroggs
4//! <https://github.com/FEniCS/basix/blob/main/cpp/basix/polyset.cpp>
5use super::{derivative_count, polynomial_count};
6use crate::types::ReferenceCellType;
7use rlst::{Array, MutableArrayImpl, RlstScalar, ValueArrayImpl};
8
/// Position of the pair `(i, j)` in the graded ordering used for triangle polynomials and
/// for 2D derivative multi-indices.
///
/// Pairs are grouped by total `i + j` (all pairs of total `t` come after every pair of
/// total `t - 1`); within a group they are ordered by increasing `j`.
fn tri_index(i: usize, j: usize) -> usize {
    let total = i + j;
    // Number of pairs whose total is strictly less than `total`.
    let preceding = total * (total + 1) / 2;
    preceding + j
}
12
/// Row-major position of `(i, j)` in an `(n + 1) x (n + 1)` grid of tensor-product indices.
fn quad_index(i: usize, j: usize, n: usize) -> usize {
    let stride = n + 1;
    j + stride * i
}
16
/// Position of the triple `(i, j, k)` in the graded ordering used for tetrahedron
/// polynomials and 3D derivative multi-indices.
///
/// Triples are grouped by total `i + j + k`. Within a group they are ordered by `j + k`,
/// and then by `k`.
fn tet_index(i: usize, j: usize, k: usize) -> usize {
    let total = i + j + k;
    let tail = j + k;
    // Triples of smaller total, then pairs (j, k) of smaller `j + k`, then the offset `k`.
    total * (total + 1) * (total + 2) / 6 + tail * (tail + 1) / 2 + k
}
20
/// Row-major position of `(i, j, k)` in an `(n + 1)^3` grid of tensor-product indices.
fn hex_index(i: usize, j: usize, k: usize, n: usize) -> usize {
    let stride = n + 1;
    // Horner form of i * stride^2 + j * stride + k.
    (i * stride + j) * stride + k
}
24
25/// The coefficients in the Jacobi Polynomial recurrence relation
26fn jrc<T: RlstScalar>(a: usize, n: usize) -> (T, T, T) {
27    (
28        T::from((a + 2 * n + 1) * (a + 2 * n + 2)).unwrap()
29            / T::from(2 * (n + 1) * (a + n + 1)).unwrap(),
30        T::from(a * a * (a + 2 * n + 1)).unwrap()
31            / T::from(2 * (n + 1) * (a + n + 1) * (a + 2 * n)).unwrap(),
32        T::from(n * (a + n) * (a + 2 * n + 2)).unwrap()
33            / T::from((n + 1) * (a + n + 1) * (a + 2 * n)).unwrap(),
34    )
35}
36
37/// Tabulate orthonormal polynomials on an interval
38fn tabulate_interval<
39    T: RlstScalar,
40    TGeo: RlstScalar,
41    Array2Impl: ValueArrayImpl<TGeo, 2>,
42    Array3MutImpl: MutableArrayImpl<T, 3>,
43>(
44    points: &Array<Array2Impl, 2>,
45    degree: usize,
46    derivatives: usize,
47    data: &mut Array<Array3MutImpl, 3>,
48) {
49    debug_assert!(data.shape()[0] == derivatives + 1);
50    debug_assert!(data.shape()[1] == degree + 1);
51    debug_assert!(data.shape()[2] == points.shape()[1]);
52    debug_assert!(points.shape()[0] == 1);
53
54    for i in 0..data.shape()[2] {
55        *data.get_mut([0, 0, i]).unwrap() = T::from(1.0).unwrap();
56    }
57    for k in 1..data.shape()[0] {
58        for i in 0..data.shape()[2] {
59            *data.get_mut([k, 0, i]).unwrap() = T::from(0.0).unwrap();
60        }
61    }
62
63    for k in 0..=derivatives {
64        for p in 1..=degree {
65            let a = T::from(p - 1).unwrap() / T::from(p).unwrap();
66            let b = (a + T::from(1.0).unwrap())
67                * ((T::from(2.0).unwrap() * T::from(p).unwrap() + T::from(1.0).unwrap())
68                    / (T::from(2.0).unwrap() * T::from(p).unwrap() - T::from(1.0).unwrap()))
69                .sqrt();
70            for i in 0..data.shape()[2] {
71                let d = *data.get([k, p - 1, i]).unwrap();
72                *data.get_mut([k, p, i]).unwrap() =
73                    (T::from(points.get_value([0, i]).unwrap()).unwrap() * T::from(2.0).unwrap()
74                        - T::from(1.0).unwrap())
75                        * d
76                        * b;
77            }
78            if p > 1 {
79                let c = a
80                    * ((T::from(2.0).unwrap() * T::from(p).unwrap() + T::from(1.0).unwrap())
81                        / (T::from(2.0).unwrap() * T::from(p).unwrap() - T::from(3.0).unwrap()))
82                    .sqrt();
83                for i in 0..data.shape()[2] {
84                    let d = *data.get([k, p - 2, i]).unwrap();
85                    *data.get_mut([k, p, i]).unwrap() -= d * c;
86                }
87            }
88            if k > 0 {
89                for i in 0..data.shape()[2] {
90                    let d = *data.get([k - 1, p - 1, i]).unwrap();
91                    *data.get_mut([k, p, i]).unwrap() +=
92                        T::from(2.0).unwrap() * T::from(k).unwrap() * d * b;
93                }
94            }
95        }
96    }
97}
98
99/// Tabulate orthonormal polynomials on a quadrilateral
100fn tabulate_quadrilateral<
101    T: RlstScalar,
102    TGeo: RlstScalar,
103    Array2Impl: ValueArrayImpl<TGeo, 2>,
104    Array3MutImpl: MutableArrayImpl<T, 3>,
105>(
106    points: &Array<Array2Impl, 2>,
107    degree: usize,
108    derivatives: usize,
109    data: &mut Array<Array3MutImpl, 3>,
110) {
111    debug_assert!(data.shape()[0] == (derivatives + 1) * (derivatives + 2) / 2);
112    debug_assert!(data.shape()[1] == (degree + 1) * (degree + 1));
113    debug_assert!(data.shape()[2] == points.shape()[1]);
114    debug_assert!(points.shape()[0] == 2);
115
116    for i in 0..data.shape()[2] {
117        *data
118            .get_mut([tri_index(0, 0), quad_index(0, 0, degree), i])
119            .unwrap() = T::from(1.0).unwrap();
120    }
121
122    // Tabulate polynomials in x
123    for k in 1..=derivatives {
124        for i in 0..data.shape()[2] {
125            *data
126                .get_mut([tri_index(k, 0), quad_index(0, 0, degree), i])
127                .unwrap() = T::from(0.0).unwrap();
128        }
129    }
130
131    for k in 0..=derivatives {
132        for p in 1..=degree {
133            let a = T::from(1.0).unwrap() - T::from(1.0).unwrap() / T::from(p).unwrap();
134            let b = (a + T::from(1.0).unwrap())
135                * ((T::from(2.0).unwrap() * T::from(p).unwrap() + T::from(1.0).unwrap())
136                    / (T::from(2.0).unwrap() * T::from(p).unwrap() - T::from(1.0).unwrap()))
137                .sqrt();
138            for i in 0..data.shape()[2] {
139                let d = *data
140                    .get([tri_index(k, 0), quad_index(p - 1, 0, degree), i])
141                    .unwrap();
142                *data
143                    .get_mut([tri_index(k, 0), quad_index(p, 0, degree), i])
144                    .unwrap() = (T::from(points.get_value([0, i]).unwrap()).unwrap()
145                    * T::from(2.0).unwrap()
146                    - T::from(1.0).unwrap())
147                    * d
148                    * b;
149            }
150            if p > 1 {
151                let c = a
152                    * ((T::from(2.0).unwrap() * T::from(p).unwrap() + T::from(1.0).unwrap())
153                        / (T::from(2.0).unwrap() * T::from(p).unwrap() - T::from(3.0).unwrap()))
154                    .sqrt();
155                for i in 0..data.shape()[2] {
156                    let d = *data
157                        .get([tri_index(k, 0), quad_index(p - 2, 0, degree), i])
158                        .unwrap();
159                    *data
160                        .get_mut([tri_index(k, 0), quad_index(p, 0, degree), i])
161                        .unwrap() -= d * c;
162                }
163            }
164            if k > 0 {
165                for i in 0..data.shape()[2] {
166                    let d = *data
167                        .get([tri_index(k - 1, 0), quad_index(p - 1, 0, degree), i])
168                        .unwrap();
169                    *data
170                        .get_mut([tri_index(k, 0), quad_index(p, 0, degree), i])
171                        .unwrap() += T::from(2.0).unwrap() * T::from(k).unwrap() * d * b;
172                }
173            }
174        }
175    }
176
177    // Tabulate polynomials in y
178    for k in 1..=derivatives {
179        for i in 0..data.shape()[2] {
180            *data
181                .get_mut([tri_index(0, k), quad_index(0, 0, degree), i])
182                .unwrap() = T::from(0.0).unwrap();
183        }
184    }
185
186    for k in 0..=derivatives {
187        for p in 1..=degree {
188            let a = T::from(1.0).unwrap() - T::from(1.0).unwrap() / T::from(p).unwrap();
189            let b = (a + T::from(1.0).unwrap())
190                * ((T::from(2.0).unwrap() * T::from(p).unwrap() + T::from(1.0).unwrap())
191                    / (T::from(2.0).unwrap() * T::from(p).unwrap() - T::from(1.0).unwrap()))
192                .sqrt();
193            for i in 0..data.shape()[2] {
194                let d = *data
195                    .get([tri_index(0, k), quad_index(0, p - 1, degree), i])
196                    .unwrap();
197                *data
198                    .get_mut([tri_index(0, k), quad_index(0, p, degree), i])
199                    .unwrap() = (T::from(points.get_value([1, i]).unwrap()).unwrap()
200                    * T::from(2.0).unwrap()
201                    - T::from(1.0).unwrap())
202                    * d
203                    * b;
204            }
205            if p > 1 {
206                let c = a
207                    * ((T::from(2.0).unwrap() * T::from(p).unwrap() + T::from(1.0).unwrap())
208                        / (T::from(2.0).unwrap() * T::from(p).unwrap() - T::from(3.0).unwrap()))
209                    .sqrt();
210                for i in 0..data.shape()[2] {
211                    let d = *data
212                        .get([tri_index(0, k), quad_index(0, p - 2, degree), i])
213                        .unwrap();
214                    *data
215                        .get_mut([tri_index(0, k), quad_index(0, p, degree), i])
216                        .unwrap() -= d * c;
217                }
218            }
219            if k > 0 {
220                for i in 0..data.shape()[2] {
221                    let d = *data
222                        .get([tri_index(0, k - 1), quad_index(0, p - 1, degree), i])
223                        .unwrap();
224                    *data
225                        .get_mut([tri_index(0, k), quad_index(0, p, degree), i])
226                        .unwrap() += T::from(2.0).unwrap() * T::from(k).unwrap() * d * b;
227                }
228            }
229        }
230    }
231
232    // Fill in the rest of the values as products
233    for kx in 0..=derivatives {
234        for ky in 0..=derivatives - kx {
235            for px in 1..=degree {
236                for py in 1..=degree {
237                    for i in 0..data.shape()[2] {
238                        let d = *data
239                            .get([tri_index(0, ky), quad_index(0, py, degree), i])
240                            .unwrap();
241                        *data
242                            .get_mut([tri_index(kx, ky), quad_index(px, py, degree), i])
243                            .unwrap() = *data
244                            .get([tri_index(kx, 0), quad_index(px, 0, degree), i])
245                            .unwrap()
246                            * d;
247                    }
248                }
249            }
250        }
251    }
252}
253/// Tabulate orthonormal polynomials on a triangle
254fn tabulate_triangle<
255    T: RlstScalar,
256    TGeo: RlstScalar,
257    Array2Impl: ValueArrayImpl<TGeo, 2>,
258    Array3MutImpl: MutableArrayImpl<T, 3>,
259>(
260    points: &Array<Array2Impl, 2>,
261    degree: usize,
262    derivatives: usize,
263    data: &mut Array<Array3MutImpl, 3>,
264) {
265    debug_assert!(data.shape()[0] == (derivatives + 1) * (derivatives + 2) / 2);
266    debug_assert!(data.shape()[1] == (degree + 1) * (degree + 2) / 2);
267    debug_assert!(data.shape()[2] == points.shape()[1]);
268    debug_assert!(points.shape()[0] == 2);
269
270    for i in 0..data.shape()[2] {
271        *data.get_mut([tri_index(0, 0), tri_index(0, 0), i]).unwrap() =
272            T::sqrt(T::from(2.0).unwrap());
273    }
274
275    for k in 1..data.shape()[0] {
276        for i in 0..data.shape()[2] {
277            *data.get_mut([k, tri_index(0, 0), i]).unwrap() = T::from(0.0).unwrap();
278        }
279    }
280
281    for kx in 0..=derivatives {
282        for ky in 0..=derivatives - kx {
283            for p in 1..=degree {
284                let a = T::from(2.0).unwrap() - T::from(1.0).unwrap() / T::from(p).unwrap();
285                let scale1 = T::sqrt(
286                    (T::from(p).unwrap() + T::from(0.5).unwrap())
287                        * (T::from(p).unwrap() + T::from(1.0).unwrap())
288                        / ((T::from(p).unwrap() - T::from(0.5).unwrap()) * T::from(p).unwrap()),
289                );
290                for i in 0..data.shape()[2] {
291                    let d = *data
292                        .get([tri_index(kx, ky), tri_index(0, p - 1), i])
293                        .unwrap();
294                    *data
295                        .get_mut([tri_index(kx, ky), tri_index(0, p), i])
296                        .unwrap() = (T::from(points.get_value([0, i]).unwrap()).unwrap()
297                        * T::from(2.0).unwrap()
298                        + T::from(points.get_value([1, i]).unwrap()).unwrap()
299                        - T::from(1.0).unwrap())
300                        * d
301                        * a
302                        * scale1;
303                }
304                if kx > 0 {
305                    for i in 0..data.shape()[2] {
306                        let d = *data
307                            .get([tri_index(kx - 1, ky), tri_index(0, p - 1), i])
308                            .unwrap();
309                        *data
310                            .get_mut([tri_index(kx, ky), tri_index(0, p), i])
311                            .unwrap() +=
312                            T::from(2.0).unwrap() * T::from(kx).unwrap() * a * d * scale1;
313                    }
314                }
315                if ky > 0 {
316                    for i in 0..data.shape()[2] {
317                        let d = *data
318                            .get([tri_index(kx, ky - 1), tri_index(0, p - 1), i])
319                            .unwrap();
320                        *data
321                            .get_mut([tri_index(kx, ky), tri_index(0, p), i])
322                            .unwrap() += T::from(ky).unwrap() * a * d * scale1;
323                    }
324                }
325                if p > 1 {
326                    let scale2 = T::sqrt(
327                        (T::from(p).unwrap() + T::from(0.5).unwrap())
328                            * (T::from(p).unwrap() + T::from(1.0).unwrap()),
329                    ) / T::sqrt(
330                        (T::from(p).unwrap() - T::from(1.5).unwrap())
331                            * (T::from(p).unwrap() - T::from(1.0).unwrap()),
332                    );
333
334                    for i in 0..data.shape()[2] {
335                        let b = T::from(1.0).unwrap()
336                            - T::from(points.get_value([1, i]).unwrap()).unwrap();
337                        let d = *data
338                            .get([tri_index(kx, ky), tri_index(0, p - 2), i])
339                            .unwrap();
340                        *data
341                            .get_mut([tri_index(kx, ky), tri_index(0, p), i])
342                            .unwrap() -= b * b * d * (a - T::from(1.0).unwrap()) * scale2;
343                    }
344                    if ky > 0 {
345                        for i in 0..data.shape()[2] {
346                            let d = *data
347                                .get([tri_index(kx, ky - 1), tri_index(0, p - 2), i])
348                                .unwrap();
349                            *data
350                                .get_mut([tri_index(kx, ky), tri_index(0, p), i])
351                                .unwrap() -= T::from(2.0).unwrap()
352                                * T::from(ky).unwrap()
353                                * (T::from(points.get_value([1, i]).unwrap()).unwrap()
354                                    - T::from(1.0).unwrap())
355                                * d
356                                * scale2
357                                * (a - T::from(1.0).unwrap());
358                        }
359                    }
360                    if ky > 1 {
361                        for i in 0..data.shape()[2] {
362                            let d = *data
363                                .get([tri_index(kx, ky - 2), tri_index(0, p - 2), i])
364                                .unwrap();
365                            *data
366                                .get_mut([tri_index(kx, ky), tri_index(0, p), i])
367                                .unwrap() -= T::from(ky).unwrap()
368                                * (T::from(ky).unwrap() - T::from(1.0).unwrap())
369                                * d
370                                * scale2
371                                * (a - T::from(1.0).unwrap());
372                        }
373                    }
374                }
375            }
376            for p in 0..degree {
377                let scale3 = T::sqrt(
378                    (T::from(p).unwrap() + T::from(2.0).unwrap())
379                        / (T::from(p).unwrap() + T::from(1.0).unwrap()),
380                );
381                for i in 0..data.shape()[2] {
382                    *data
383                        .get_mut([tri_index(kx, ky), tri_index(1, p), i])
384                        .unwrap() = *data.get([tri_index(kx, ky), tri_index(0, p), i]).unwrap()
385                        * scale3
386                        * ((T::from(points.get_value([1, i]).unwrap()).unwrap()
387                            * T::from(2.0).unwrap()
388                            - T::from(1.0).unwrap())
389                            * (T::from(1.5).unwrap() + T::from(p).unwrap())
390                            + T::from(0.5).unwrap()
391                            + T::from(p).unwrap());
392                }
393                if ky > 0 {
394                    for i in 0..data.shape()[2] {
395                        let d = *data
396                            .get([tri_index(kx, ky - 1), tri_index(0, p), i])
397                            .unwrap();
398                        *data
399                            .get_mut([tri_index(kx, ky), tri_index(1, p), i])
400                            .unwrap() += T::from(2.0).unwrap()
401                            * T::from(ky).unwrap()
402                            * (T::from(1.5).unwrap() + T::from(p).unwrap())
403                            * d
404                            * scale3;
405                    }
406                }
407                for q in 1..degree - p {
408                    let scale4 = T::sqrt(
409                        (T::from(p).unwrap() + T::from(q).unwrap() + T::from(2.0).unwrap())
410                            / (T::from(p).unwrap() + T::from(q).unwrap() + T::from(1.0).unwrap()),
411                    );
412                    let scale5 = T::sqrt(
413                        (T::from(p).unwrap() + T::from(q).unwrap() + T::from(2.0).unwrap())
414                            / (T::from(p).unwrap() + T::from(q).unwrap()),
415                    );
416                    let (a1, a2, a3) = jrc(2 * p + 1, q);
417
418                    for i in 0..data.shape()[2] {
419                        let d = *data.get([tri_index(kx, ky), tri_index(q, p), i]).unwrap();
420                        *data
421                            .get_mut([tri_index(kx, ky), tri_index(q + 1, p), i])
422                            .unwrap() = d
423                            * scale4
424                            * ((T::from(points.get_value([1, i]).unwrap()).unwrap()
425                                * T::from(T::from(2.0).unwrap()).unwrap()
426                                - T::from(T::from(1.0).unwrap()).unwrap())
427                                * a1
428                                + a2)
429                            - *data
430                                .get([tri_index(kx, ky), tri_index(q - 1, p), i])
431                                .unwrap()
432                                * scale5
433                                * a3;
434                    }
435                    if ky > 0 {
436                        for i in 0..data.shape()[2] {
437                            let d = *data
438                                .get([tri_index(kx, ky - 1), tri_index(q, p), i])
439                                .unwrap();
440                            *data
441                                .get_mut([tri_index(kx, ky), tri_index(q + 1, p), i])
442                                .unwrap() += T::from(T::from(2.0).unwrap() * T::from(ky).unwrap())
443                                .unwrap()
444                                * a1
445                                * d
446                                * scale4;
447                        }
448                    }
449                }
450            }
451        }
452    }
453}
454
455/// Tabulate orthonormal polynomials on a tetrahedron
456fn tabulate_tetrahedron<
457    T: RlstScalar,
458    TGeo: RlstScalar,
459    Array2Impl: ValueArrayImpl<TGeo, 2>,
460    Array3MutImpl: MutableArrayImpl<T, 3>,
461>(
462    points: &Array<Array2Impl, 2>,
463    degree: usize,
464    derivatives: usize,
465    data: &mut Array<Array3MutImpl, 3>,
466) {
467    debug_assert!(data.shape()[0] == (derivatives + 1) * (derivatives + 2) * (derivatives + 3) / 6);
468    debug_assert!(data.shape()[1] == (degree + 1) * (degree + 2) * (degree + 3) / 6);
469    debug_assert!(data.shape()[2] == points.shape()[1]);
470    debug_assert!(points.shape()[0] == 3);
471
472    for i in 0..data.shape()[2] {
473        *data
474            .get_mut([tet_index(0, 0, 0), tet_index(0, 0, 0), i])
475            .unwrap() = T::sqrt(T::from(6.0).unwrap());
476    }
477
478    for k in 1..data.shape()[0] {
479        for i in 0..data.shape()[2] {
480            *data.get_mut([k, tet_index(0, 0, 0), i]).unwrap() = T::from(0.0).unwrap();
481        }
482    }
483
484    for kx in 0..=derivatives {
485        for ky in 0..=derivatives - kx {
486            for kz in 0..=derivatives - kx - ky {
487                for p in 1..=degree {
488                    let a = T::from(2 * p - 1).unwrap() / T::from(p).unwrap();
489                    for i in 0..points.shape()[1] {
490                        let d = *data
491                            .get([tet_index(kx, ky, kz), tet_index(0, 0, p - 1), i])
492                            .unwrap();
493                        *data
494                            .get_mut([tet_index(kx, ky, kz), tet_index(0, 0, p), i])
495                            .unwrap() = (T::from(points.get_value([0, i]).unwrap()).unwrap()
496                            * T::from(2.0).unwrap()
497                            + T::from(points.get_value([1, i]).unwrap()).unwrap()
498                            + T::from(points.get_value([2, i]).unwrap()).unwrap()
499                            - T::from(1.0).unwrap())
500                            * a
501                            * d;
502                    }
503                    if kx > 0 {
504                        for i in 0..points.shape()[1] {
505                            let d = *data
506                                .get([tet_index(kx - 1, ky, kz), tet_index(0, 0, p - 1), i])
507                                .unwrap();
508                            *data
509                                .get_mut([tet_index(kx, ky, kz), tet_index(0, 0, p), i])
510                                .unwrap() += T::from(2 * kx).unwrap() * a * d;
511                        }
512                    }
513                    if ky > 0 {
514                        for i in 0..points.shape()[1] {
515                            let d = *data
516                                .get([tet_index(kx, ky - 1, kz), tet_index(0, 0, p - 1), i])
517                                .unwrap();
518                            *data
519                                .get_mut([tet_index(kx, ky, kz), tet_index(0, 0, p), i])
520                                .unwrap() += T::from(ky).unwrap() * a * d;
521                        }
522                    }
523                    if kz > 0 {
524                        for i in 0..points.shape()[1] {
525                            let d = *data
526                                .get([tet_index(kx, ky, kz - 1), tet_index(0, 0, p - 1), i])
527                                .unwrap();
528                            *data
529                                .get_mut([tet_index(kx, ky, kz), tet_index(0, 0, p), i])
530                                .unwrap() += T::from(kz).unwrap() * a * d;
531                        }
532                    }
533                    if p > 1 {
534                        for i in 0..points.shape()[1] {
535                            let d = *data
536                                .get([tet_index(kx, ky, kz), tet_index(0, 0, p - 2), i])
537                                .unwrap();
538                            *data
539                                .get_mut([tet_index(kx, ky, kz), tet_index(0, 0, p), i])
540                                .unwrap() -= (T::from(
541                                points.get_value([1, i]).unwrap()
542                                    + points.get_value([2, i]).unwrap(),
543                            )
544                            .unwrap()
545                                - T::from(1.0).unwrap())
546                            .powi(2)
547                                * d
548                                * (a - T::from(1.0).unwrap());
549                        }
550                        if ky > 0 {
551                            for i in 0..points.shape()[1] {
552                                let d = *data
553                                    .get([tet_index(kx, ky - 1, kz), tet_index(0, 0, p - 2), i])
554                                    .unwrap();
555                                *data
556                                    .get_mut([tet_index(kx, ky, kz), tet_index(0, 0, p), i])
557                                    .unwrap() -= T::from(ky * 2).unwrap()
558                                    * (T::from(
559                                        points.get_value([1, i]).unwrap()
560                                            + points.get_value([2, i]).unwrap(),
561                                    )
562                                    .unwrap()
563                                        - T::from(1.0).unwrap())
564                                    * d
565                                    * (a - T::from(1.0).unwrap());
566                            }
567                        }
568                        if ky > 1 {
569                            for i in 0..points.shape()[1] {
570                                let d = *data
571                                    .get([tet_index(kx, ky - 2, kz), tet_index(0, 0, p - 2), i])
572                                    .unwrap();
573                                *data
574                                    .get_mut([tet_index(kx, ky, kz), tet_index(0, 0, p), i])
575                                    .unwrap() -= T::from(ky * (ky - 1)).unwrap()
576                                    * d
577                                    * (a - T::from(1.0).unwrap());
578                            }
579                        }
580                        if kz > 0 {
581                            for i in 0..points.shape()[1] {
582                                let d = *data
583                                    .get([tet_index(kx, ky, kz - 1), tet_index(0, 0, p - 2), i])
584                                    .unwrap();
585                                *data
586                                    .get_mut([tet_index(kx, ky, kz), tet_index(0, 0, p), i])
587                                    .unwrap() -= T::from(kz * 2).unwrap()
588                                    * (T::from(
589                                        points.get_value([1, i]).unwrap()
590                                            + points.get_value([2, i]).unwrap(),
591                                    )
592                                    .unwrap()
593                                        - T::from(1.0).unwrap())
594                                    * d
595                                    * (a - T::from(1.0).unwrap());
596                            }
597                        }
598                        if kz > 1 {
599                            for i in 0..points.shape()[1] {
600                                let d = *data
601                                    .get([tet_index(kx, ky, kz - 2), tet_index(0, 0, p - 2), i])
602                                    .unwrap();
603                                *data
604                                    .get_mut([tet_index(kx, ky, kz), tet_index(0, 0, p), i])
605                                    .unwrap() -= T::from(kz * (kz - 1)).unwrap()
606                                    * d
607                                    * (a - T::from(1.0).unwrap());
608                            }
609                        }
610                        if ky > 0 && kz > 0 {
611                            for i in 0..points.shape()[1] {
612                                let d = *data
613                                    .get([tet_index(kx, ky - 1, kz - 1), tet_index(0, 0, p - 2), i])
614                                    .unwrap();
615                                *data
616                                    .get_mut([tet_index(kx, ky, kz), tet_index(0, 0, p), i])
617                                    .unwrap() -=
618                                    T::from(2 * ky * kz).unwrap() * d * (a - T::from(1.0).unwrap());
619                            }
620                        }
621                    }
622                }
623                for p in 0..degree {
624                    for i in 0..points.shape()[1] {
625                        let d = *data
626                            .get([tet_index(kx, ky, kz), tet_index(0, 0, p), i])
627                            .unwrap();
628                        *data
629                            .get_mut([tet_index(kx, ky, kz), tet_index(0, 1, p), i])
630                            .unwrap() = d
631                            * (T::from(points.get_value([1, i]).unwrap()).unwrap()
632                                * T::from(2 * p + 3).unwrap()
633                                + T::from(points.get_value([2, i]).unwrap()).unwrap()
634                                - T::from(1).unwrap());
635                    }
636                    if ky > 0 {
637                        for i in 0..points.shape()[1] {
638                            let d = *data
639                                .get([tet_index(kx, ky - 1, kz), tet_index(0, 0, p), i])
640                                .unwrap();
641                            *data
642                                .get_mut([tet_index(kx, ky, kz), tet_index(0, 1, p), i])
643                                .unwrap() += T::from(2 * ky).unwrap()
644                                * d
645                                * (T::from(p).unwrap() + T::from(1.5).unwrap());
646                        }
647                    }
648                    if kz > 0 {
649                        for i in 0..points.shape()[1] {
650                            let d = *data
651                                .get([tet_index(kx, ky, kz - 1), tet_index(0, 0, p), i])
652                                .unwrap();
653                            *data
654                                .get_mut([tet_index(kx, ky, kz), tet_index(0, 1, p), i])
655                                .unwrap() += T::from(kz).unwrap() * d;
656                        }
657                    }
658
659                    for q in 1..degree - p {
660                        let (aq, bq, cq) = jrc::<T>(2 * p + 1, q);
661
662                        for i in 0..points.shape()[1] {
663                            let d = *data
664                                .get([tet_index(kx, ky, kz), tet_index(0, q, p), i])
665                                .unwrap();
666                            let d2 = *data
667                                .get([tet_index(kx, ky, kz), tet_index(0, q - 1, p), i])
668                                .unwrap();
669                            *data
670                                .get_mut([tet_index(kx, ky, kz), tet_index(0, q + 1, p), i])
671                                .unwrap() = d
672                                * (aq
673                                    * (T::from(points.get_value([1, i]).unwrap()).unwrap()
674                                        * T::from(2.0).unwrap()
675                                        - T::from(1.0).unwrap()
676                                        + T::from(points.get_value([2, i]).unwrap()).unwrap())
677                                    + bq * (T::from(1.0).unwrap()
678                                        - T::from(points.get_value([2, i]).unwrap()).unwrap()))
679                                - d2 * cq
680                                    * (T::from(1.0).unwrap()
681                                        - T::from(points.get_value([2, i]).unwrap()).unwrap())
682                                    .powi(2);
683                        }
684
685                        if ky > 0 {
686                            for i in 0..points.shape()[1] {
687                                let d = *data
688                                    .get([tet_index(kx, ky - 1, kz), tet_index(0, q, p), i])
689                                    .unwrap();
690                                *data
691                                    .get_mut([tet_index(kx, ky, kz), tet_index(0, q + 1, p), i])
692                                    .unwrap() += T::from(2 * ky).unwrap() * d * aq;
693                            }
694                        }
695                        if kz > 0 {
696                            for i in 0..points.shape()[1] {
697                                let d = *data
698                                    .get([tet_index(kx, ky, kz - 1), tet_index(0, q, p), i])
699                                    .unwrap();
700                                let d2 = *data
701                                    .get([tet_index(kx, ky, kz - 1), tet_index(0, q - 1, p), i])
702                                    .unwrap();
703                                *data
704                                    .get_mut([tet_index(kx, ky, kz), tet_index(0, q + 1, p), i])
705                                    .unwrap() += T::from(kz).unwrap() * d * (aq - bq)
706                                    + T::from(2 * kz).unwrap()
707                                        * (T::from(1.0).unwrap()
708                                            - T::from(points.get_value([2, i]).unwrap()).unwrap())
709                                        * d2
710                                        * cq;
711                            }
712                        }
713                        if kz > 1 {
714                            for i in 0..points.shape()[1] {
715                                let d = *data
716                                    .get([tet_index(kx, ky, kz - 2), tet_index(0, q - 1, p), i])
717                                    .unwrap();
718                                *data
719                                    .get_mut([tet_index(kx, ky, kz), tet_index(0, q + 1, p), i])
720                                    .unwrap() -= T::from(kz * (kz - 1)).unwrap() * d * cq;
721                            }
722                        }
723                    }
724                }
725
726                for p in 0..degree {
727                    for q in 0..degree - p {
728                        for i in 0..points.shape()[1] {
729                            let d = *data
730                                .get([tet_index(kx, ky, kz), tet_index(0, q, p), i])
731                                .unwrap();
732                            *data
733                                .get_mut([tet_index(kx, ky, kz), tet_index(1, q, p), i])
734                                .unwrap() = d
735                                * (T::from(points.get_value([2, i]).unwrap()).unwrap()
736                                    * T::from(2 + p + q).unwrap()
737                                    * T::from(2.0).unwrap()
738                                    - T::from(1.0).unwrap());
739                        }
740                        if kz > 0 {
741                            for i in 0..points.shape()[1] {
742                                let d = *data
743                                    .get([tet_index(kx, ky, kz - 1), tet_index(0, q, p), i])
744                                    .unwrap();
745                                *data
746                                    .get_mut([tet_index(kx, ky, kz), tet_index(1, q, p), i])
747                                    .unwrap() += T::from(2 * kz * (2 + p + q)).unwrap() * d;
748                            }
749                        }
750                    }
751                }
752
753                if degree > 0 {
754                    for p in 0..degree - 1 {
755                        for q in 0..degree - 1 - p {
756                            for r in 1..degree - p - q {
757                                let (ar, br, cr) = jrc::<T>(2 * p + 2 * q + 2, r);
758
759                                for i in 0..points.shape()[1] {
760                                    let d = *data
761                                        .get([tet_index(kx, ky, kz), tet_index(r, q, p), i])
762                                        .unwrap();
763                                    let d2 = *data
764                                        .get([tet_index(kx, ky, kz), tet_index(r - 1, q, p), i])
765                                        .unwrap();
766                                    *data
767                                        .get_mut([tet_index(kx, ky, kz), tet_index(r + 1, q, p), i])
768                                        .unwrap() = d
769                                        * (ar
770                                            * (T::from(2.0).unwrap()
771                                                * T::from(points.get_value([2, i]).unwrap())
772                                                    .unwrap()
773                                                - T::from(1.0).unwrap())
774                                            + br)
775                                        - d2 * cr;
776                                }
777                                if kz > 0 {
778                                    for i in 0..points.shape()[1] {
779                                        let d = *data
780                                            .get([tet_index(kx, ky, kz - 1), tet_index(r, q, p), i])
781                                            .unwrap();
782                                        *data
783                                            .get_mut([
784                                                tet_index(kx, ky, kz),
785                                                tet_index(r + 1, q, p),
786                                                i,
787                                            ])
788                                            .unwrap() += T::from(2 * kz).unwrap() * ar * d;
789                                    }
790                                }
791                            }
792                        }
793                    }
794                }
795            }
796        }
797    }
798
799    // Normalise
800    for p in 0..=degree {
801        for q in 0..=degree - p {
802            for r in 0..=degree - p - q {
803                let norm = T::sqrt(
804                    (T::from(p).unwrap() + T::from(0.5).unwrap())
805                        * T::from(p + q + 1).unwrap()
806                        * (T::from(p + q + r).unwrap() + T::from(1.5).unwrap()),
807                ) * T::from(2).unwrap()
808                    / T::sqrt(T::from(3).unwrap());
809                for j in 0..data.shape()[2] {
810                    for i in 0..data.shape()[0] {
811                        *data.get_mut([i, tet_index(r, q, p), j]).unwrap() *= norm;
812                    }
813                }
814            }
815        }
816    }
817}
818
819/// Tabulate orthonormal polynomials on a hexahedron
820fn tabulate_hexahedron<
821    T: RlstScalar,
822    TGeo: RlstScalar,
823    Array2Impl: ValueArrayImpl<TGeo, 2>,
824    Array3MutImpl: MutableArrayImpl<T, 3>,
825>(
826    points: &Array<Array2Impl, 2>,
827    degree: usize,
828    derivatives: usize,
829    data: &mut Array<Array3MutImpl, 3>,
830) {
831    debug_assert!(data.shape()[0] == (derivatives + 1) * (derivatives + 2) * (derivatives + 3) / 6);
832    debug_assert!(data.shape()[1] == (degree + 1) * (degree + 1) * (degree + 1));
833    debug_assert!(data.shape()[2] == points.shape()[1]);
834    debug_assert!(points.shape()[0] == 3);
835
836    for i in 0..data.shape()[2] {
837        *data
838            .get_mut([tet_index(0, 0, 0), hex_index(0, 0, 0, degree), i])
839            .unwrap() = T::from(1.0).unwrap();
840    }
841
842    // Tabulate polynomials in x
843    for k in 1..=derivatives {
844        for i in 0..data.shape()[2] {
845            *data
846                .get_mut([tet_index(k, 0, 0), hex_index(0, 0, 0, degree), i])
847                .unwrap() = T::from(0.0).unwrap();
848        }
849    }
850
851    for k in 0..=derivatives {
852        for p in 1..=degree {
853            let a = T::from(1.0).unwrap() - T::from(1.0).unwrap() / T::from(p).unwrap();
854            let b = (a + T::from(1.0).unwrap())
855                * ((T::from(2.0).unwrap() * T::from(p).unwrap() + T::from(1.0).unwrap())
856                    / (T::from(2.0).unwrap() * T::from(p).unwrap() - T::from(1.0).unwrap()))
857                .sqrt();
858            for i in 0..data.shape()[2] {
859                let d = *data
860                    .get([tet_index(k, 0, 0), hex_index(p - 1, 0, 0, degree), i])
861                    .unwrap();
862                *data
863                    .get_mut([tet_index(k, 0, 0), hex_index(p, 0, 0, degree), i])
864                    .unwrap() = (T::from(points.get_value([0, i]).unwrap()).unwrap()
865                    * T::from(2.0).unwrap()
866                    - T::from(1.0).unwrap())
867                    * d
868                    * b;
869            }
870            if p > 1 {
871                let c = a
872                    * ((T::from(2.0).unwrap() * T::from(p).unwrap() + T::from(1.0).unwrap())
873                        / (T::from(2.0).unwrap() * T::from(p).unwrap() - T::from(3.0).unwrap()))
874                    .sqrt();
875                for i in 0..data.shape()[2] {
876                    let d = *data
877                        .get([tet_index(k, 0, 0), hex_index(p - 2, 0, 0, degree), i])
878                        .unwrap();
879                    *data
880                        .get_mut([tet_index(k, 0, 0), hex_index(p, 0, 0, degree), i])
881                        .unwrap() -= d * c;
882                }
883            }
884            if k > 0 {
885                for i in 0..data.shape()[2] {
886                    let d = *data
887                        .get([tet_index(k - 1, 0, 0), hex_index(p - 1, 0, 0, degree), i])
888                        .unwrap();
889                    *data
890                        .get_mut([tet_index(k, 0, 0), hex_index(p, 0, 0, degree), i])
891                        .unwrap() += T::from(2.0).unwrap() * T::from(k).unwrap() * d * b;
892                }
893            }
894        }
895    }
896
897    // Tabulate polynomials in y
898    for k in 1..=derivatives {
899        for i in 0..data.shape()[2] {
900            *data
901                .get_mut([tet_index(0, k, 0), hex_index(0, 0, 0, degree), i])
902                .unwrap() = T::from(0.0).unwrap();
903        }
904    }
905
906    for k in 0..=derivatives {
907        for p in 1..=degree {
908            let a = T::from(1.0).unwrap() - T::from(1.0).unwrap() / T::from(p).unwrap();
909            let b = (a + T::from(1.0).unwrap())
910                * ((T::from(2.0).unwrap() * T::from(p).unwrap() + T::from(1.0).unwrap())
911                    / (T::from(2.0).unwrap() * T::from(p).unwrap() - T::from(1.0).unwrap()))
912                .sqrt();
913            for i in 0..data.shape()[2] {
914                let d = *data
915                    .get([tet_index(0, k, 0), hex_index(0, p - 1, 0, degree), i])
916                    .unwrap();
917                *data
918                    .get_mut([tet_index(0, k, 0), hex_index(0, p, 0, degree), i])
919                    .unwrap() = (T::from(points.get_value([1, i]).unwrap()).unwrap()
920                    * T::from(2.0).unwrap()
921                    - T::from(1.0).unwrap())
922                    * d
923                    * b;
924            }
925            if p > 1 {
926                let c = a
927                    * ((T::from(2.0).unwrap() * T::from(p).unwrap() + T::from(1.0).unwrap())
928                        / (T::from(2.0).unwrap() * T::from(p).unwrap() - T::from(3.0).unwrap()))
929                    .sqrt();
930                for i in 0..data.shape()[2] {
931                    let d = *data
932                        .get([tet_index(0, k, 0), hex_index(0, p - 2, 0, degree), i])
933                        .unwrap();
934                    *data
935                        .get_mut([tet_index(0, k, 0), hex_index(0, p, 0, degree), i])
936                        .unwrap() -= d * c;
937                }
938            }
939            if k > 0 {
940                for i in 0..data.shape()[2] {
941                    let d = *data
942                        .get([tet_index(0, k - 1, 0), hex_index(0, p - 1, 0, degree), i])
943                        .unwrap();
944                    *data
945                        .get_mut([tet_index(0, k, 0), hex_index(0, p, 0, degree), i])
946                        .unwrap() += T::from(2.0).unwrap() * T::from(k).unwrap() * d * b;
947                }
948            }
949        }
950    }
951
952    // Tabulate polynomials in z
953    for k in 1..=derivatives {
954        for i in 0..data.shape()[2] {
955            *data
956                .get_mut([tet_index(0, 0, k), hex_index(0, 0, 0, degree), i])
957                .unwrap() = T::from(0.0).unwrap();
958        }
959    }
960
961    for k in 0..=derivatives {
962        for p in 1..=degree {
963            let a = T::from(1.0).unwrap() - T::from(1.0).unwrap() / T::from(p).unwrap();
964            let b = (a + T::from(1.0).unwrap())
965                * ((T::from(2.0).unwrap() * T::from(p).unwrap() + T::from(1.0).unwrap())
966                    / (T::from(2.0).unwrap() * T::from(p).unwrap() - T::from(1.0).unwrap()))
967                .sqrt();
968            for i in 0..data.shape()[2] {
969                let d = *data
970                    .get([tet_index(0, 0, k), hex_index(0, 0, p - 1, degree), i])
971                    .unwrap();
972                *data
973                    .get_mut([tet_index(0, 0, k), hex_index(0, 0, p, degree), i])
974                    .unwrap() = (T::from(points.get_value([2, i]).unwrap()).unwrap()
975                    * T::from(2.0).unwrap()
976                    - T::from(1.0).unwrap())
977                    * d
978                    * b;
979            }
980            if p > 1 {
981                let c = a
982                    * ((T::from(2.0).unwrap() * T::from(p).unwrap() + T::from(1.0).unwrap())
983                        / (T::from(2.0).unwrap() * T::from(p).unwrap() - T::from(3.0).unwrap()))
984                    .sqrt();
985                for i in 0..data.shape()[2] {
986                    let d = *data
987                        .get([tet_index(0, 0, k), hex_index(0, 0, p - 2, degree), i])
988                        .unwrap();
989                    *data
990                        .get_mut([tet_index(0, 0, k), hex_index(0, 0, p, degree), i])
991                        .unwrap() -= d * c;
992                }
993            }
994            if k > 0 {
995                for i in 0..data.shape()[2] {
996                    let d = *data
997                        .get([tet_index(0, 0, k - 1), hex_index(0, 0, p - 1, degree), i])
998                        .unwrap();
999                    *data
1000                        .get_mut([tet_index(0, 0, k), hex_index(0, 0, p, degree), i])
1001                        .unwrap() += T::from(2.0).unwrap() * T::from(k).unwrap() * d * b;
1002                }
1003            }
1004        }
1005    }
1006
1007    // Fill in the rest of the values as products
1008    for kx in 0..=derivatives {
1009        for ky in 0..=derivatives - kx {
1010            for kz in 0..=derivatives - kx - ky {
1011                for px in 0..=degree {
1012                    for py in if px == 0 { 1 } else { 0 }..=degree {
1013                        for pz in if px * py == 0 { 1 } else { 0 }..=degree {
1014                            for i in 0..data.shape()[2] {
1015                                let dx = *data
1016                                    .get([tet_index(kx, 0, 0), hex_index(px, 0, 0, degree), i])
1017                                    .unwrap();
1018                                let dy = *data
1019                                    .get([tet_index(0, ky, 0), hex_index(0, py, 0, degree), i])
1020                                    .unwrap();
1021                                let dz = *data
1022                                    .get([tet_index(0, 0, kz), hex_index(0, 0, pz, degree), i])
1023                                    .unwrap();
1024                                *data
1025                                    .get_mut([
1026                                        tet_index(kx, ky, kz),
1027                                        hex_index(px, py, pz, degree),
1028                                        i,
1029                                    ])
1030                                    .unwrap() = dx * dy * dz;
1031                            }
1032                        }
1033                    }
1034                }
1035            }
1036        }
1037    }
1038}
1039
1040/// The shape of a table containing the values of Legendre polynomials
1041pub fn shape<T, Array2Impl: ValueArrayImpl<T, 2>>(
1042    cell_type: ReferenceCellType,
1043    points: &Array<Array2Impl, 2>,
1044    degree: usize,
1045    derivatives: usize,
1046) -> [usize; 3] {
1047    [
1048        derivative_count(cell_type, derivatives),
1049        polynomial_count(cell_type, degree),
1050        points.shape()[1],
1051    ]
1052}
1053
1054/// Tabulate orthonormal polynomials
1055pub fn tabulate<
1056    T: RlstScalar,
1057    TGeo: RlstScalar,
1058    Array2Impl: ValueArrayImpl<TGeo, 2>,
1059    Array3MutImpl: MutableArrayImpl<T, 3>,
1060>(
1061    cell_type: ReferenceCellType,
1062    points: &Array<Array2Impl, 2>,
1063    degree: usize,
1064    derivatives: usize,
1065    data: &mut Array<Array3MutImpl, 3>,
1066) {
1067    match cell_type {
1068        ReferenceCellType::Interval => tabulate_interval(points, degree, derivatives, data),
1069        ReferenceCellType::Triangle => tabulate_triangle(points, degree, derivatives, data),
1070        ReferenceCellType::Quadrilateral => {
1071            tabulate_quadrilateral(points, degree, derivatives, data)
1072        }
1073        ReferenceCellType::Tetrahedron => tabulate_tetrahedron(points, degree, derivatives, data),
1074        ReferenceCellType::Hexahedron => tabulate_hexahedron(points, degree, derivatives, data),
1075        _ => {
1076            panic!("Unsupported cell type: {cell_type:?}");
1077        }
1078    };
1079}