use std::borrow::Cow;
use ndarray::{Array, ArrayBase, Data, Dimension};
use num_traits::Float;
use numcodecs::{
AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
Codec, StaticCodec, StaticCodecConfig,
};
use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;
/// Codec that rounds floating-point data to a fixed precision on encoding and
/// passes the data through unchanged on decoding.
///
/// Only `f32` and `f64` arrays are supported.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct RoundCodec {
    /// Precision of the rounding: on encoding, every value is rounded to the
    /// nearest integer multiple of `precision`, which must be strictly positive.
    pub precision: Positive<f64>,
}
impl Codec for RoundCodec {
    type Error = RoundCodecError;

    /// Rounds every element of `data` to the nearest integer multiple of
    /// [`RoundCodec::precision`].
    ///
    /// # Errors
    ///
    /// Returns [`RoundCodecError::UnsupportedDtype`] if `data` is neither an
    /// `f32` nor an `f64` array.
    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        match data {
            AnyCowArray::F32(data) => {
                // Narrowing the `f64` precision to `f32` can underflow to zero or
                // overflow to infinity. Either would break the `Positive`
                // invariant and make the f32 rounding produce NaN everywhere.
                #[allow(clippy::cast_possible_truncation)]
                let precision = self.precision.0 as f32;
                if precision > 0.0 && precision.is_finite() {
                    Ok(AnyArray::F32(round(data, Positive(precision))))
                } else {
                    // The precision is not representable as f32: round in f64 and
                    // narrow the result back to f32, the dtype of the input.
                    #[allow(clippy::cast_possible_truncation)]
                    let rounded = round(data.mapv(f64::from), self.precision)
                        .mapv(|x| x as f32);
                    Ok(AnyArray::F32(rounded))
                }
            }
            AnyCowArray::F64(data) => Ok(AnyArray::F64(round(data, self.precision))),
            encoded => Err(RoundCodecError::UnsupportedDtype(encoded.dtype())),
        }
    }

    /// Decoding is the identity: the encoded array is returned as an owned copy.
    ///
    /// # Errors
    ///
    /// Returns [`RoundCodecError::UnsupportedDtype`] if `encoded` is neither an
    /// `f32` nor an `f64` array.
    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        match encoded {
            AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
            AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
            encoded => Err(RoundCodecError::UnsupportedDtype(encoded.dtype())),
        }
    }

    /// Decodes `encoded` by assigning it unchanged into `decoded`.
    ///
    /// # Errors
    ///
    /// Returns [`RoundCodecError::UnsupportedDtype`] if `encoded` is neither an
    /// `f32` nor an `f64` array, and
    /// [`RoundCodecError::MismatchedDecodeIntoArray`] if the assignment into
    /// `decoded` fails.
    fn decode_into(
        &self,
        encoded: AnyArrayView,
        mut decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
            return Err(RoundCodecError::UnsupportedDtype(encoded.dtype()));
        }

        Ok(decoded.assign(&encoded)?)
    }
}
impl StaticCodec for RoundCodec {
    /// Codec identifier of the round codec.
    const CODEC_ID: &'static str = "round";

    /// The codec struct doubles as its own configuration.
    type Config<'de> = Self;

    /// Builds the codec from its configuration, which is already a codec.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Returns the configuration describing this codec.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::<Self>::from(self)
    }
}
/// Newtype around a floating-point value that is meant to be strictly
/// greater than zero.
///
/// The wrapped value is private to this module. Code elsewhere therefore
/// obtains a `Positive<f64>` through deserialization, which rejects values
/// that are not strictly positive.
// `T` is a floating-point type, which is `PartialEq` but not `Eq`, so the
// clippy suggestion to also derive `Eq` does not apply.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
pub struct Positive<T: Float>(T);
impl Serialize for Positive<f64> {
    /// Serializes the wrapped value as a plain `f64`.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let Self(value) = *self;
        value.serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for Positive<f64> {
    /// Deserializes an `f64` and accepts it only if it is strictly positive
    /// and finite.
    ///
    /// # Errors
    ///
    /// Fails if the underlying `f64` cannot be deserialized, or if the value is
    /// zero, negative, NaN, or infinite.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let x = f64::deserialize(deserializer)?;

        // `NaN > 0.0` is false, so NaN fails the first check. Infinity must be
        // rejected explicitly: rounding to an infinite precision computes
        // `0 * inf`, which turns every finite value into NaN. JSON cannot express
        // infinity anyway, so this matches the advertised JSON schema.
        if x > 0.0 && x.is_finite() {
            Ok(Self(x))
        } else {
            Err(serde::de::Error::invalid_value(
                serde::de::Unexpected::Float(x),
                &"a positive finite value",
            ))
        }
    }
}
impl JsonSchema for Positive<f64> {
    /// Short, human-readable name of the schema.
    fn schema_name() -> Cow<'static, str> {
        "PositiveF64".into()
    }

    /// Identifier qualified with this module's path, which distinguishes this
    /// type from any other type that uses the same schema name.
    fn schema_id() -> Cow<'static, str> {
        concat!(module_path!(), "::Positive<f64>").into()
    }

    /// A JSON number that is strictly greater than zero.
    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
        json_schema!({
            "type": "number",
            "exclusiveMinimum": 0.0
        })
    }
}
/// Errors that may occur when applying the [`RoundCodec`].
#[derive(Debug, Error)]
pub enum RoundCodecError {
    /// [`RoundCodec`] was given an array whose dtype it does not support;
    /// only `f32` and `f64` arrays are accepted.
    #[error("Round does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`RoundCodec`] could not assign the decoded data into the provided
    /// output array.
    #[error("Round cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The error returned when assigning the encoded data to the output array.
        #[from]
        source: AnyArrayAssignError,
    },
}
#[must_use]
pub fn round<T: Float, S: Data<Elem = T>, D: Dimension>(
data: ArrayBase<S, D>,
precision: Positive<T>,
) -> Array<T, D> {
let mut encoded = data.into_owned();
encoded.mapv_inplace(|x| (x / precision.0).round() * precision.0);
encoded
}