#![deny(clippy::pedantic)]
#[macro_use]
extern crate contracts;
use std::{fmt, mem::ManuallyDrop, num::NonZeroU32, time::Duration};
use anyhow::Context;
use humantime_serde::re::humantime::format_duration;
use mpi::{
datatype::PartitionMut,
environment::Universe,
topology::{Communicator, Rank, SimpleCommunicator},
traits::CommunicatorCollectives,
Count, Tag,
};
use serde::{ser::SerializeStruct, Deserialize, Serialize, Serializer};
use serde_derive_state::DeserializeState;
use serde_state::{DeserializeState, Deserializer};
use thiserror::Error;
use necsim_core::{
lineage::MigratingLineage,
reporter::{
boolean::{False, True},
FilteredReporter, Reporter,
},
};
use necsim_impls_std::event_log::recorder::EventLogConfig;
use necsim_partitioning_core::{
partition::PartitionSize,
reporter::{FinalisableReporter, OpaqueFinalisableReporter, ReporterContext},
Data, Partitioning,
};
mod partition;
mod request;
pub use partition::{MpiLocalPartition, MpiParallelPartition, MpiRootPartition};
use request::{reduce_scope, DataOrRequest};
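/// Error which can occur when initialising the [`MpiPartitioning`] backend.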
#[derive(Error, Debug)]
pub enum MpiPartitioningError {
#[error("MPI has already been initialised.")]
AlreadyInitialised,
#[error("MPI must be initialised with at least two partitions.")]
NoParallelism,
}
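/// Error which can occur when constructing an MPI local partition.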
#[derive(Error, Debug)]
pub enum MpiLocalPartitionError {
#[error("MPI partitioning requires an event log.")]
MissingEventLog,
#[error("Failed to create the event sub-log.")]
InvalidEventSubLog,
}
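/// MPI-based [`Partitioning`] backend which distributes a simulation across
/// all ranks of the MPI world communicator.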
pub struct MpiPartitioning {
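    // The universe is wrapped in ManuallyDrop so that dropping the
    // partitioning does not finalise MPI; instead, the universe is handed
    // over to the local partition in `with_local_partition`.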
universe: ManuallyDrop<Universe>,
world: SimpleCommunicator,
migration_interval: Duration,
progress_interval: Duration,
}
impl fmt::Debug for MpiPartitioning {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
struct FormattedDuration(Duration);
impl fmt::Debug for FormattedDuration {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.write_str(&format_duration(self.0).to_string())
}
}
fmt.debug_struct(stringify!(MpiPartitioning))
.field("world", &self.get_size().get())
.field(
"migration_interval",
&FormattedDuration(self.migration_interval),
)
.field(
"progress_interval",
&FormattedDuration(self.progress_interval),
)
.finish_non_exhaustive()
}
}
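// The partitioning is serialised as its world size together with its
// human-readable migration and progress intervals.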
impl Serialize for MpiPartitioning {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut args = serializer.serialize_struct(stringify!(MpiPartitioning), 3)?;
args.serialize_field("world", &self.get_size())?;
args.serialize_field(
"migration",
&format_duration(self.migration_interval).to_string(),
)?;
args.serialize_field(
"progress",
&format_duration(self.progress_interval).to_string(),
)?;
args.end()
}
}
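// Deserialisation initialises MPI itself, so an `MpiPartitioning` can only be
// deserialised if MPI has not yet been initialised in this process.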
impl<'de> Deserialize<'de> for MpiPartitioning {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let mut partitioning = Self::initialise().map_err(serde::de::Error::custom)?;
        let mut world_size = partitioning.get_size();
        let raw = MpiPartitioningRaw::deserialize_state(&mut world_size, deserializer)?;
partitioning.set_migration_interval(raw.migration_interval);
partitioning.set_progress_interval(raw.progress_interval);
Ok(partitioning)
}
}
impl MpiPartitioning {
const MPI_DEFAULT_MIGRATION_INTERVAL: Duration = Duration::from_millis(100_u64);
const MPI_DEFAULT_PROGRESS_INTERVAL: Duration = Duration::from_millis(100_u64);
const MPI_MIGRATION_TAG: Tag = 1;
const MPI_PROGRESS_TAG: Tag = 0;
const ROOT_RANK: Rank = 0;
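    /// Initialises MPI and creates a partitioning over the MPI world
    /// communicator.
    ///
    /// # Errors
    ///
    /// Returns `AlreadyInitialised` if MPI has already been initialised in
    /// this process, or `NoParallelism` if the MPI world contains fewer than
    /// two ranks.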
pub fn initialise() -> Result<Self, MpiPartitioningError> {
let universe =
ManuallyDrop::new(mpi::initialize().ok_or(MpiPartitioningError::AlreadyInitialised)?);
let world = universe.world();
if world.size() > 1 {
Ok(Self {
universe,
world,
migration_interval: Self::MPI_DEFAULT_MIGRATION_INTERVAL,
progress_interval: Self::MPI_DEFAULT_PROGRESS_INTERVAL,
})
} else {
Err(MpiPartitioningError::NoParallelism)
}
}
pub fn set_migration_interval(&mut self, migration_interval: Duration) {
self.migration_interval = migration_interval;
}
pub fn set_progress_interval(&mut self, progress_interval: Duration) {
self.progress_interval = progress_interval;
}
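    /// Checks whether this rank is the root rank of the MPI world, i.e.
    /// whether it will become the root partition.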
#[debug_ensures(
self.get_size().is_monolithic() -> ret,
"monolithic partition is always root"
)]
#[must_use]
pub fn peek_is_root(&self) -> bool {
self.world.rank() == MpiPartitioning::ROOT_RANK
}
}
impl Partitioning for MpiPartitioning {
type Auxiliary = Option<EventLogConfig>;
type FinalisableReporter<R: Reporter> = FinalisableMpiReporter<R>;
type LocalPartition<'p, R: Reporter> = MpiLocalPartition<'p, R>;
fn get_size(&self) -> PartitionSize {
#[allow(clippy::cast_sign_loss)]
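        // SAFETY: an MPI communicator always contains at least one rank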
let size = unsafe { NonZeroU32::new_unchecked(self.world.size() as u32) };
PartitionSize(size)
}
fn with_local_partition<
R: Reporter,
P: ReporterContext<Reporter = R>,
A: Data,
Q: Data + serde::Serialize + serde::de::DeserializeOwned,
>(
self,
reporter_context: P,
event_log: Self::Auxiliary,
args: A,
inner: for<'p> fn(&mut Self::LocalPartition<'p, R>, A) -> Q,
fold: fn(Q, Q) -> Q,
) -> anyhow::Result<(Q, Self::FinalisableReporter<R>)> {
let Some(event_log) = event_log else {
anyhow::bail!(MpiLocalPartitionError::MissingEventLog)
};
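        // Each rank records its events into its own sub-log of the shared event log.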
let partition_event_log = event_log
.new_child_log(&self.world.rank().to_string())
.and_then(EventLogConfig::create)
.context(MpiLocalPartitionError::InvalidEventSubLog)?;
let mut mpi_local_global_wait = (false, false);
let mut mpi_local_remaining = 0_u64;
#[allow(clippy::cast_sign_loss)]
let world_size = self.world.size() as usize;
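        // One outgoing buffer of migrating lineages per rank in the MPI world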
let mut mpi_emigration_buffers: Vec<Vec<MigratingLineage>> = Vec::with_capacity(world_size);
mpi_emigration_buffers.resize_with(world_size, Vec::new);
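        // Asynchronous MPI requests which borrow the buffers above are confined
        // to this request scope.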
mpi::request::scope(|scope| {
let scope = reduce_scope(scope);
let mpi_local_global_wait = DataOrRequest::new(&mut mpi_local_global_wait, scope);
let mpi_local_remaining = DataOrRequest::new(&mut mpi_local_remaining, scope);
let mpi_emigration_buffers = mpi_emigration_buffers
.iter_mut()
.map(|buffer| DataOrRequest::new(buffer, scope))
.collect::<Vec<_>>()
.into_boxed_slice();
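            // The root rank builds the reporter; all other ranks run parallel
            // partitions without one.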
let mut local_partition = if self.world.rank() == MpiPartitioning::ROOT_RANK {
MpiLocalPartition::Root(Box::new(MpiRootPartition::new(
ManuallyDrop::into_inner(self.universe),
mpi_local_global_wait,
mpi_emigration_buffers,
reporter_context.try_build()?,
partition_event_log,
self.migration_interval,
self.progress_interval,
)))
} else {
MpiLocalPartition::Parallel(Box::new(MpiParallelPartition::new(
ManuallyDrop::into_inner(self.universe),
mpi_local_global_wait,
mpi_local_remaining,
mpi_emigration_buffers,
partition_event_log,
self.migration_interval,
self.progress_interval,
)))
};
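            // Run the simulation closure on this rank's partition, then fold the
            // per-partition results across the entire MPI world.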
let local_result = inner(&mut local_partition, args);
let result = reduce_partitioning_data(&self.world, local_result, fold)?;
Ok((result, local_partition.into_reporter()))
})
}
}
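/// Raw configuration format for deserialising an [`MpiPartitioning`], whose
/// optional `world` field must match the size of the already-initialised MPI
/// world.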
#[derive(DeserializeState)]
#[serde(rename = "MpiPartitioning")]
#[serde(deny_unknown_fields)]
#[serde(deserialize_state = "PartitionSize")]
#[serde(default)]
#[allow(dead_code)]
struct MpiPartitioningRaw {
#[serde(deserialize_state_with = "deserialize_state_mpi_world")]
world: Option<PartitionSize>,
#[serde(alias = "migration")]
#[serde(with = "humantime_serde")]
migration_interval: Duration,
#[serde(alias = "progress")]
#[serde(with = "humantime_serde")]
progress_interval: Duration,
}
impl Default for MpiPartitioningRaw {
fn default() -> Self {
Self {
world: None,
migration_interval: MpiPartitioning::MPI_DEFAULT_MIGRATION_INTERVAL,
progress_interval: MpiPartitioning::MPI_DEFAULT_PROGRESS_INTERVAL,
}
}
}
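/// Deserialises the optional `world` field and checks that, if present, it
/// matches the size of the actual MPI world.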
fn deserialize_state_mpi_world<'de, D: Deserializer<'de>>(
mpi_world: &mut PartitionSize,
deserializer: D,
) -> Result<Option<PartitionSize>, D::Error> {
let maybe_world = Option::<PartitionSize>::deserialize(deserializer)?;
match maybe_world {
None => Ok(None),
Some(world) if world == *mpi_world => Ok(Some(world)),
Some(_) => Err(serde::de::Error::custom(format!(
"mismatch with MPI world size of {mpi_world}"
))),
}
}
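/// Serialises the local partition result, gathers the serialised results of
/// all partitions on every rank, and folds them into a single combined value.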
fn reduce_partitioning_data<T: serde::Serialize + serde::de::DeserializeOwned>(
world: &SimpleCommunicator,
data: T,
fold: fn(T, T) -> T,
) -> anyhow::Result<T> {
let local_ser =
postcard::to_stdvec(&data).context("MPI local partition result failed to serialize")?;
std::mem::drop(data);
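    // Gather the length of every partition's serialised result on every rank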
let local_ser_len = Count::try_from(local_ser.len())
.context("MPI local partition result is too big to share")?;
#[allow(clippy::cast_sign_loss)]
let mut counts = vec![0 as Count; world.size() as usize];
world.all_gather_into(&local_ser_len, &mut counts);
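    // Compute each partition's offset into the gathered buffer, checking that
    // the combined length does not overflow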
let offsets = counts
.iter()
.scan(0 as Count, |acc, &x| {
let tmp = *acc;
if let Some(a) = (*acc).checked_add(x) {
*acc = a;
} else {
return Some(Err(anyhow::anyhow!(
"MPI combined local partition results are too big to share"
)));
}
Some(Ok(tmp))
})
.collect::<Result<Vec<_>, _>>()?;
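    // Gather the variable-length serialised results of all partitions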
#[allow(clippy::cast_sign_loss)]
let mut all_sers = vec![0_u8; counts.iter().copied().sum::<Count>() as usize];
world.all_gather_varcount_into(
local_ser.as_slice(),
&mut PartitionMut::new(all_sers.as_mut_slice(), counts.as_slice(), offsets),
);
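    // Deserialise each partition's result and fold them into a single value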
let folded: Option<T> = counts
.iter()
.scan(0_usize, |acc, &x| {
let pre = *acc;
#[allow(clippy::cast_sign_loss)]
{
*acc += x as usize;
}
let post = *acc;
            let de: anyhow::Result<T> = postcard::from_bytes(&all_sers[pre..post])
                .context("MPI local partition result failed to deserialize");
Some(de)
})
.try_fold(None, |acc, x| match (acc, x) {
(_, Err(err)) => Err(err),
(Some(acc), Ok(x)) => Ok(Some(fold(acc, x))),
(None, Ok(x)) => Ok(Some(x)),
})?;
let folded = folded.expect("at least one MPI partitioning result");
Ok(folded)
}
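/// Reporter finaliser of the MPI partitioning: only the root partition owns a
/// reporter that needs to be finalised.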
pub enum FinalisableMpiReporter<R: Reporter> {
Root(OpaqueFinalisableReporter<FilteredReporter<R, False, False, True>>),
Parallel,
}
impl<R: Reporter> FinalisableReporter for FinalisableMpiReporter<R> {
fn finalise(self) {
if let Self::Root(reporter) = self {
reporter.finalise();
}
}
}