1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use std::{fmt, marker::PhantomData};

use pcg_rand::{seeds::PcgSeeder, PCGStateInfo, Pcg64};
use rand_core::{RngCore as _, SeedableRng};
use serde::{Deserialize, Serialize};

use necsim_core::cogs::{MathsCore, RngCore, SplittableRng};

/// A PCG random number generator backed by [`Pcg64`].
///
/// The type parameter `M` selects the [`MathsCore`] implementation this
/// generator is paired with; it is only carried as a zero-sized
/// [`PhantomData`] marker.
///
/// (De)serialisation goes through `PcgState`, i.e. the generator is stored as
/// its raw 128-bit `state` and `increment`.
// The type name intentionally repeats its module's name.
#[allow(clippy::module_name_repetitions)]
#[derive(Deserialize, Serialize)]
#[serde(into = "PcgState", from = "PcgState")]
pub struct Pcg<M>
where
    M: MathsCore,
{
    inner: Pcg64,
    marker: PhantomData<M>,
}

impl<M: MathsCore> Clone for Pcg<M> {
    /// Duplicates the generator by snapshotting the inner [`Pcg64`] state and
    /// restoring that snapshot (without verification) into a fresh engine.
    fn clone(&self) -> Self {
        let snapshot = self.inner.get_state();
        let engine = Pcg64::restore_state_with_no_verification(snapshot);

        Self {
            inner: engine,
            marker: PhantomData,
        }
    }
}

impl<M: MathsCore> fmt::Debug for Pcg<M> {
    /// Formats as `Pcg { state, stream }`.
    ///
    /// Instead of the raw increment, the stream number `increment >> 1` is shown.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let info = self.inner.get_state();
        let stream = info.increment >> 1;

        let mut builder = fmt.debug_struct("Pcg");
        builder.field("state", &info.state);
        builder.field("stream", &stream);
        builder.finish()
    }
}

impl<M: MathsCore> RngCore<M> for Pcg<M> {
    type Seed = [u8; 16];

    /// Creates a generator on stream `0` from a 16-byte seed.
    ///
    /// The seed bytes are interpreted as a little-endian `u128` and handed to
    /// [`PcgSeeder::seed_with_stream`] together with stream `0`.
    #[must_use]
    #[inline]
    fn from_seed(seed: Self::Seed) -> Self {
        let seed_value = u128::from_le_bytes(seed);
        let seeder = PcgSeeder::seed_with_stream(seed_value, 0_u128);
        let engine = Pcg64::from_seed(seeder);

        Self {
            inner: engine,
            marker: PhantomData,
        }
    }

    /// Returns the next 64 random bits produced by the inner [`Pcg64`].
    #[must_use]
    #[inline]
    fn sample_u64(&mut self) -> u64 {
        self.inner.next_u64()
    }
}

impl<M: MathsCore> SplittableRng<M> for Pcg<M> {
    /// Splits the generator into a `(left, right)` pair.
    ///
    /// Both children keep the parent's current state but move to different
    /// streams. With parent stream `s = increment >> 1`:
    /// - the left child gets stream `2s`;
    /// - the right child gets stream `2s + 1`.
    ///
    /// Rebuilding the odd increment shifts out the top bit, so the new stream
    /// is taken modulo `2^127`.
    fn split(self) -> (Self, Self) {
        let child_on_branch = |branch: u128| {
            let mut child_state = self.inner.get_state();
            let parent_stream = child_state.increment >> 1;
            // `parent_stream < 2^127`, so `<< 1` is the same as `* 2` here.
            child_state.increment = (((parent_stream << 1) | branch) << 1) | 1;

            Self {
                inner: Pcg64::restore_state_with_no_verification(child_state),
                marker: PhantomData,
            }
        };

        (child_on_branch(0), child_on_branch(1))
    }

    /// Moves the generator onto `stream`, keeping its current state.
    ///
    /// The increment is replaced by `(stream << 1) | 1`, which discards the
    /// previous stream entirely.
    fn split_to_stream(self, stream: u64) -> Self {
        let increment = (u128::from(stream) << 1) | 1;

        let mut target_state = self.inner.get_state();
        target_state.increment = increment;

        Self {
            inner: Pcg64::restore_state_with_no_verification(target_state),
            marker: PhantomData,
        }
    }
}

/// Serialised representation of [`Pcg`]: its raw 128-bit state and increment.
///
/// It is (de)serialised under the name `Pcg`, and unknown fields are rejected.
#[derive(Deserialize, Serialize)]
#[serde(rename = "Pcg", deny_unknown_fields)]
struct PcgState {
    /// The generator's internal 128-bit state.
    state: u128,
    /// The raw increment; `Pcg`'s `Debug` output reports `increment >> 1` as
    /// the stream.
    increment: u128,
}

impl<M: MathsCore> From<Pcg<M>> for PcgState {
    fn from(rng: Pcg<M>) -> Self {
        let state_info = rng.inner.get_state();

        Self {
            state: state_info.state,
            increment: state_info.increment,
        }
    }
}

impl<M: MathsCore> From<PcgState> for Pcg<M> {
    /// Restores a [`Pcg`] from its deserialised [`PcgState`].
    ///
    /// Only `state` and `increment` are serialised. The remaining
    /// `PCGStateInfo` parameters are fixed to the values describing a
    /// [`Pcg64`]:
    /// - the default multiplier;
    /// - a 128-bit internal state;
    /// - 64-bit output;
    /// - the DXSM output mixin.
    ///
    /// PCG's underlying LCG step needs an odd increment to reach its full
    /// period, and this module always stores increments as `(stream << 1) | 1`.
    /// A `#[serde(from = ...)]` conversion cannot fail. An even increment
    /// coming from hand-written or corrupted input is therefore normalised by
    /// setting its lowest bit. This keeps the same stream `increment >> 1`
    /// (the value the `Debug` impl reports) instead of silently building a
    /// short-period generator. Odd increments are restored unchanged.
    fn from(state: PcgState) -> Self {
        use pcg_rand::{
            multiplier::{DefaultMultiplier, Multiplier},
            outputmix::{DXsMMixin, OutputMixin},
        };

        let state_info = PCGStateInfo {
            state: state.state,
            // Enforce the odd-increment invariant (see the doc comment above).
            increment: state.increment | 1,
            multiplier: DefaultMultiplier::multiplier(),
            internal_width: u128::BITS as usize,
            output_width: u64::BITS as usize,
            output_mixin: <DXsMMixin as OutputMixin<u128, u64>>::SERIALIZER_ID.into(),
        };

        Self {
            inner: Pcg64::restore_state_with_no_verification(state_info),
            marker: PhantomData::<M>,
        }
    }
}