1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
use core::{
    fmt,
    iter::Empty,
    marker::PhantomData,
    ops::{Deref, DerefMut, RangeFrom},
};

use necsim_core::cogs::MathsCore;
use necsim_core_bond::ClosedUnitF64;
use necsim_partitioning_core::partition::Partition;

/// Reciprocal of the golden ratio, `1 / φ = (√5 − 1) / 2 ≈ 0.618`.
///
/// Used as the additive step of the quasi-random sequence
/// `q ← (q + 1/φ) mod 1` in `OriginPreSampler::percentage`.
const INV_PHI: f64 = 0.618_033_988_749_894_9_f64;

/// An iterator of pre-sampled origin indices, paired with the fraction of the
/// full index sequence that it is expected to yield.
///
/// The wrapped iterator is reachable through `Deref`/`DerefMut`.
// The public type name intentionally overlaps its module's name.
#[allow(clippy::module_name_repetitions)]
pub struct OriginPreSampler<M, I>
where
    M: MathsCore,
    I: Iterator<Item = u64>,
{
    /// Iterator producing the (already thinned) origin indices.
    inner: I,
    /// Expected fraction of all indices that `inner` yields
    /// (one for `all()`, zero for `none()`).
    proportion: ClosedUnitF64,
    /// Binds the sampler to the maths backend `M` without storing a value of it.
    _marker: PhantomData<M>,
}

impl<M, I> OriginPreSampler<M, I>
where
    M: MathsCore,
    I: Iterator<Item = u64>,
{
    /// Returns the expected fraction of the full index sequence that this
    /// pre-sampler yields.
    pub fn get_sample_proportion(&self) -> ClosedUnitF64 {
        let Self { proportion, .. } = self;
        *proportion
    }
}

impl<M, I> Deref for OriginPreSampler<M, I>
where
    M: MathsCore,
    I: Iterator<Item = u64>,
{
    type Target = I;

    /// Borrows the wrapped index iterator immutably.
    fn deref(&self) -> &I {
        let Self { inner, .. } = self;
        inner
    }
}

impl<M, I> DerefMut for OriginPreSampler<M, I>
where
    M: MathsCore,
    I: Iterator<Item = u64>,
{
    /// Borrows the wrapped index iterator mutably, e.g. to advance it.
    fn deref_mut(&mut self) -> &mut I {
        let Self { inner, .. } = self;
        inner
    }
}

impl<M, I> fmt::Debug for OriginPreSampler<M, I>
where
    M: MathsCore,
    I: Iterator<Item = u64>,
{
    /// Prints only the sample proportion; the wrapped iterator is elided
    /// (rendered as `..`) because `I` is not required to implement `Debug`.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        let mut debug = formatter.debug_struct("OriginPreSampler");
        debug.field("proportion", &self.proportion);
        debug.finish_non_exhaustive()
    }
}

impl<M> OriginPreSampler<M, RangeFrom<u64>>
where
    M: MathsCore,
{
    /// Creates a pre-sampler over every index `0, 1, 2, …`, i.e. with a
    /// sample proportion of one.
    #[must_use]
    pub fn all() -> Self {
        let every_index: RangeFrom<u64> = 0..;

        Self {
            inner: every_index,
            proportion: ClosedUnitF64::one(),
            _marker: PhantomData,
        }
    }
}

impl<M> OriginPreSampler<M, Empty<u64>>
where
    M: MathsCore,
{
    /// Creates a pre-sampler that yields no indices at all, i.e. with a
    /// sample proportion of zero.
    #[must_use]
    pub fn none() -> Self {
        let no_index: Empty<u64> = Empty::default();

        Self {
            inner: no_index,
            proportion: ClosedUnitF64::zero(),
            _marker: PhantomData,
        }
    }
}

impl<M: MathsCore, I: Iterator<Item = u64>> OriginPreSampler<M, I> {
    /// Thins this pre-sampler so that each of its indices is kept with
    /// probability `percentage`.
    ///
    /// Instead of flipping a coin per index, the gap between two kept indices
    /// is drawn from a geometric distribution, `skip = ⌊ln(q) / ln(1 − p)⌋`.
    /// Here `q` is the next element of the additive quasi-random sequence
    /// `q ← (q + 1/φ) mod 1`, seeded at `0.5`. No random number generator is
    /// consumed, so the thinning is fully deterministic.
    ///
    /// Special cases:
    /// - a `percentage` of zero yields no indices;
    /// - a `percentage` of one yields every index of `self` unchanged.
    ///
    /// The returned sampler's proportion is
    /// `self.get_sample_proportion() * percentage`.
    #[must_use]
    pub fn percentage(
        mut self,
        percentage: ClosedUnitF64,
    ) -> OriginPreSampler<M, impl Iterator<Item = u64>> {
        // These decisions do not depend on the iteration state: evaluate them once.
        let sample_none = percentage <= 0.0_f64;
        let sample_all = percentage >= 1.0_f64;

        let one_minus_p = 1.0_f64 - percentage.get();
        // For tiny but positive percentages, `1 - p` rounds to exactly `1.0`.
        // Then `ln(1 - p)` would be `0.0` and its reciprocal `+inf`, so every
        // skip would become `floor(-inf) as usize == 0` and *all* indices would
        // be sampled. In that regime `ln(1 - p) = -p - p²/2 - …` equals `-p` to
        // double precision, so use that instead. All other inputs take the
        // original `ln(1 - p)` path unchanged.
        let ln_one_minus_p = if one_minus_p < 1.0_f64 {
            M::ln(one_minus_p)
        } else {
            -percentage.get()
        };
        let inv_geometric_sample_rate = ln_one_minus_p.recip();

        OriginPreSampler {
            // Read (copy) the proportion before `self` is moved into the closure.
            proportion: self.proportion * percentage,
            inner: core::iter::repeat(()).scan(0.5_f64, move |quasi_random, ()| {
                if sample_none {
                    return None;
                }

                if sample_all {
                    return self.next();
                }

                // q = (q + INV_PHI) % 1  where q >= 0
                *quasi_random += INV_PHI;
                *quasi_random -= M::floor(*quasi_random);

                // Geometric draw by inverse transform of the quasi-random `q`.
                // `ln(q)` and `inv_geometric_sample_rate` are both <= 0, so the
                // product is >= 0. The float-to-int cast saturates, so
                // astronomically large gaps are clamped to `usize::MAX`.
                #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
                let skip = M::floor(M::ln(*quasi_random) * inv_geometric_sample_rate) as usize;

                self.nth(skip)
            }),
            _marker: PhantomData::<M>,
        }
    }

    /// Restricts this pre-sampler to the share of one member of `partition`.
    ///
    /// Indices are dealt round-robin: the member with rank `r` in a partition
    /// of size `s` keeps the elements at positions `r, r + s, r + 2s, …` of
    /// `self`.
    ///
    /// The returned sampler's proportion is
    /// `self.get_sample_proportion() / partition.size()`.
    #[must_use]
    pub fn partition(
        mut self,
        partition: Partition,
    ) -> OriginPreSampler<M, impl Iterator<Item = u64>> {
        // Skip this member's offset. If `self` holds fewer than `rank` elements,
        // `advance_by` returns an error that is deliberately ignored: this member
        // then receives no indices, which is the correct outcome.
        let _ = self.advance_by(partition.rank() as usize);

        OriginPreSampler {
            proportion: self.proportion / partition.size().0,
            inner: self.inner.step_by(partition.size().get() as usize),
            _marker: PhantomData::<M>,
        }
    }
}