bempp_octree/
parsort.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
//! Implementation of a parallel samplesort.

use std::fmt::Display;

use itertools::Itertools;
use mpi::traits::CommunicatorCollectives;
use mpi::traits::Equivalence;
use rand::{seq::SliceRandom, Rng};

use crate::morton::MortonKey;
use crate::tools::{gather_to_all, global_max, global_min, redistribute_by_bins};

/// Upper bound on the number of elements each rank draws at random from its
/// local array as splitter candidates (fewer are drawn if the local array is
/// shorter than this).
const OVERSAMPLING: usize = 8;

/// An internal struct. Every array element is converted into this struct
/// before sorting. The idea is that this is guaranteed to be unique
/// as it encodes not only the element but also its rank and index.
///
/// The derived ordering compares the fields in declaration order: `value`
/// first, then `rank`, then `index`. Equal keys are therefore ordered by
/// where they came from, so the field order must not be changed.
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Equivalence)]
struct UniqueItem {
    /// The Morton key being sorted.
    pub value: MortonKey,
    /// Rank of the process that supplied the element.
    pub rank: usize,
    /// Position of the element in that process's input array.
    pub index: usize,
}

/// Human-readable rendering of a [`UniqueItem`], listing the key together
/// with the rank and index that make it unique.
impl Display for UniqueItem {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "(value: {}, rank: {}, index: {})",
            self.value, self.rank, self.index
        ))
    }
}

impl UniqueItem {
    pub fn new(value: MortonKey, rank: usize, index: usize) -> Self {
        Self { value, rank, index }
    }
}

/// Tags every key in `arr` with `rank` and its position within `arr`,
/// producing one [`UniqueItem`] per input key in the same order.
fn to_unique_item(arr: &[MortonKey], rank: usize) -> Vec<UniqueItem> {
    let tagged = arr
        .iter()
        .copied()
        .enumerate()
        .map(|(position, key)| UniqueItem::new(key, rank, position));

    tagged.collect()
}

/// Determines the bucket boundaries used to redistribute the elements.
///
/// Each rank draws up to `OVERSAMPLING` random samples from its local `arr`.
/// The samples of all ranks are gathered on every rank and sorted, and the
/// global minimum and maximum elements are added at the two ends if they are
/// not already there. The resulting list is cut into chunks with
/// `split(.., comm.size())` and the first element of each chunk is returned.
///
/// This function calls `global_min`, `global_max` and `gather_to_all` on
/// `comm`, in that order.
/// NOTE(review): these are presumably collective operations, so every rank of
/// `comm` has to call this function - confirm against `crate::tools`.
///
/// # Panics
///
/// Panics if the gathered sample list is empty.
fn get_buckets<C, R>(arr: &[UniqueItem], comm: &C, rng: &mut R) -> Vec<UniqueItem>
where
    C: CommunicatorCollectives,
    R: Rng + ?Sized,
{
    let n_ranks = comm.size() as usize;

    // Draw at most `OVERSAMPLING` local samples; a shorter local array
    // contributes all of its elements.
    let n_samples = arr.len().min(OVERSAMPLING);

    // The global extremes become the outermost splitters further below.
    let global_min_elem = global_min(arr, comm);
    let global_max_elem = global_max(arr, comm);

    let local_samples: Vec<UniqueItem> = arr.choose_multiple(rng, n_samples).copied().collect();

    // After gathering and sorting, every rank holds the same ordered list of
    // splitter candidates.
    let mut candidates = gather_to_all(&local_samples, comm);
    candidates.sort_unstable();

    // Make sure the list starts with the global minimum and ends with the
    // global maximum, adding them only if the sampling did not pick them.
    if candidates.first().unwrap() != &global_min_elem {
        candidates.insert(0, global_min_elem);
    }
    if candidates.last().unwrap() != &global_max_elem {
        candidates.push(global_max_elem);
    }

    // One bucket per chunk; each bucket is identified by the first
    // (smallest) candidate in its chunk.
    split(&candidates, n_ranks)
        .map(|chunk| *chunk.first().unwrap())
        .collect()
}

/// Parallel sort
///
/// Sorts the Morton keys in `arr` across the ranks of `comm` with a
/// samplesort: the local keys are made unique and sorted, bucket boundaries
/// are computed with `get_buckets`, the elements are exchanged with
/// `redistribute_by_bins`, and each rank sorts what it received.
///
/// Returns the sorted keys held by the calling rank after the exchange. If
/// `comm` has a single rank, this is simply a sorted copy of `arr`.
///
/// `rng` is used to pick the random samples from which the bucket boundaries
/// are derived.
pub fn parsort<C: CommunicatorCollectives, R: Rng + ?Sized>(
    arr: &[MortonKey],
    comm: &C,
    rng: &mut R,
) -> Vec<MortonKey> {
    let n_ranks = comm.size() as usize;
    let my_rank = comm.rank() as usize;

    // With a single rank there is nothing to exchange: sort a local copy.
    if n_ranks == 1 {
        let mut sorted = arr.to_vec();
        sorted.sort_unstable();
        return sorted;
    }

    // Attach rank and index to every key so that no two elements of the
    // global sorting set compare equal, then sort locally.
    let mut tagged = to_unique_item(arr, my_rank);
    tagged.sort_unstable();

    // Compute the bucket boundaries and exchange the elements accordingly.
    let buckets = get_buckets(&tagged, comm, rng);
    let mut received = redistribute_by_bins(&tagged, &buckets, comm);

    // Sort what arrived on this rank and strip the rank/index tags again.
    received.sort_unstable();
    received.iter().map(|item| item.value).collect_vec()
}

// The following is a simple iterator that splits a slice into n
// chunks. It is from https://users.rust-lang.org/t/how-to-split-a-slice-into-n-chunks/40008/3

/// Returns an iterator over consecutive, non-empty chunks of `slice`.
///
/// The slice is divided into `n` chunks whose lengths differ by at most one;
/// the leading chunks are the longer ones. If `slice` has fewer than `n`
/// elements, only `slice.len()` chunks of one element each are produced.
///
/// # Panics
///
/// Panics if `n` is zero.
fn split<T>(slice: &[T], n: usize) -> impl Iterator<Item = &[T]> {
    Split {
        slice,
        len: slice.len() / n,
        rem: slice.len() % n,
    }
}

/// Iterator state for [`split`].
struct Split<'a, T> {
    /// The part of the input that has not been handed out yet.
    slice: &'a [T],
    /// Base length of every chunk.
    len: usize,
    /// Number of upcoming chunks that still receive one extra element.
    rem: usize,
}

impl<'a, T> Iterator for Split<'a, T> {
    type Item = &'a [T];

    fn next(&mut self) -> Option<Self::Item> {
        if self.slice.is_empty() {
            return None;
        }

        // As long as leftover elements remain, the current chunk takes one.
        let take = match self.rem.checked_sub(1) {
            Some(still_left) => {
                self.rem = still_left;
                self.len + 1
            }
            None => self.len,
        };

        let (head, tail) = self.slice.split_at(take);
        self.slice = tail;
        Some(head)
    }
}