use core::cmp::Ordering;
use alloc::vec::Vec;
use necsim_core::cogs::{MathsCore, RngCore};
use necsim_core_bond::{ClosedUnitF64, NonNegativeF64};
pub mod packed;
#[allow(clippy::module_name_repetitions)]
#[allow(non_snake_case)]
#[derive(Clone)]
pub struct AliasMethodSampler<E: Copy + PartialEq> {
Us: Vec<ClosedUnitF64>,
Es: Vec<E>,
Ks: Vec<E>,
}
impl<E: Copy + PartialEq> AliasMethodSampler<E> {
#[allow(clippy::no_effect_underscore_binding)]
#[debug_requires(!event_weights.is_empty(), "event_weights is non-empty")]
#[debug_ensures(
ret.Es.iter().eq(old(event_weights).iter().map(|(e, _p)| e)),
"stores exactly the input events"
)]
#[debug_ensures(
ret.Us.iter().zip(ret.Es.iter()).zip(ret.Ks.iter()).all(|((u, e), k)| {
let full_bucket = *u == ClosedUnitF64::one();
!full_bucket || (e == k)
}),
"full buckets sample the same event just in case"
)]
pub fn new(event_weights: &[(E, NonNegativeF64)]) -> Self {
#[allow(non_snake_case)]
let mut Us = Vec::with_capacity(event_weights.len());
#[allow(non_snake_case)]
let mut Es = Vec::with_capacity(event_weights.len());
#[allow(non_snake_case)]
let mut Ks = Vec::with_capacity(event_weights.len());
let total_weight: NonNegativeF64 = event_weights.iter().map(|(_e, p)| *p).sum();
let n = NonNegativeF64::from(event_weights.len());
for (event, weight) in event_weights {
Us.push(*weight * n / total_weight);
Es.push(*event);
Ks.push(*event);
}
let mut overfull_indices: Vec<usize> = (0..Us.len()).filter(|i| Us[*i] > 1.0_f64).collect();
let mut underfull_indices: Vec<usize> =
(0..Us.len()).filter(|i| Us[*i] < 1.0_f64).collect();
while let Some((overfull_index, underfull_index)) =
pop_overfull_underfull_pair_atomic(&mut overfull_indices, &mut underfull_indices)
{
Us[overfull_index] = unsafe {
NonNegativeF64::new_unchecked(
Us[overfull_index].get() + Us[underfull_index].get() - 1.0_f64,
)
};
Ks[underfull_index] = Es[overfull_index];
match Us[overfull_index].cmp(&NonNegativeF64::one()) {
Ordering::Less => underfull_indices.push(overfull_index),
Ordering::Equal => (),
Ordering::Greater => overfull_indices.push(overfull_index),
}
}
overfull_indices
.into_iter()
.for_each(|i| Us[i] = NonNegativeF64::one());
underfull_indices
.into_iter()
.for_each(|i| Us[i] = NonNegativeF64::one());
#[allow(non_snake_case)]
let Us = unsafe { core::mem::transmute::<Vec<NonNegativeF64>, Vec<ClosedUnitF64>>(Us) };
Self { Us, Es, Ks }
}
#[debug_ensures(self.Es.contains(&ret), "returns one of the weighted events")]
pub fn sample_event<M: MathsCore, G: RngCore<M>>(&self, rng: &mut G) -> E {
use necsim_core::cogs::RngSampler;
let x = rng.sample_uniform_closed_open();
#[allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
let i = M::floor(x.get() * (self.Es.len() as f64)) as usize; #[allow(clippy::cast_precision_loss)]
let y = x.get() * (self.Es.len() as f64) - (i as f64); if y < self.Us[i].get() {
self.Es[i]
} else {
self.Ks[i]
}
}
}
fn pop_overfull_underfull_pair_atomic(
overfull_indices: &mut Vec<usize>,
underfull_indices: &mut Vec<usize>,
) -> Option<(usize, usize)> {
match (overfull_indices.pop(), underfull_indices.pop()) {
(Some(overfull_index), Some(underfull_index)) => Some((overfull_index, underfull_index)),
(Some(overfull_index), None) => {
overfull_indices.push(overfull_index);
None
},
(None, Some(underfull_index)) => {
underfull_indices.push(underfull_index);
None
},
(None, None) => None,
}
}