1use std::borrow::Cow;
21
22use ndarray::{Array, ArrayBase, Data, Dimension};
23use num_traits::Float;
24use numcodecs::{
25 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26 Codec, StaticCodec, StaticCodecConfig,
27};
28use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
29use serde::{Deserialize, Deserializer, Serialize, Serializer};
30use thiserror::Error;
31
32#[derive(Clone, Serialize, Deserialize, JsonSchema)]
33#[serde(deny_unknown_fields)]
34pub struct RoundCodec {
39 pub precision: Positive<f64>,
41}
42
43impl Codec for RoundCodec {
44 type Error = RoundCodecError;
45
46 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
47 match data {
48 #[expect(clippy::cast_possible_truncation)]
49 AnyCowArray::F32(data) => Ok(AnyArray::F32(round(
50 data,
51 Positive(self.precision.0 as f32),
52 ))),
53 AnyCowArray::F64(data) => Ok(AnyArray::F64(round(data, self.precision))),
54 encoded => Err(RoundCodecError::UnsupportedDtype(encoded.dtype())),
55 }
56 }
57
58 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
59 match encoded {
60 AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
61 AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
62 encoded => Err(RoundCodecError::UnsupportedDtype(encoded.dtype())),
63 }
64 }
65
66 fn decode_into(
67 &self,
68 encoded: AnyArrayView,
69 mut decoded: AnyArrayViewMut,
70 ) -> Result<(), Self::Error> {
71 if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
72 return Err(RoundCodecError::UnsupportedDtype(encoded.dtype()));
73 }
74
75 Ok(decoded.assign(&encoded)?)
76 }
77}
78
79impl StaticCodec for RoundCodec {
80 const CODEC_ID: &'static str = "round";
81
82 type Config<'de> = Self;
83
84 fn from_config(config: Self::Config<'_>) -> Self {
85 config
86 }
87
88 fn get_config(&self) -> StaticCodecConfig<Self> {
89 StaticCodecConfig::from(self)
90 }
91}
92
93#[expect(clippy::derive_partial_eq_without_eq)] #[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
95pub struct Positive<T: Float>(T);
97
98impl Serialize for Positive<f64> {
99 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
100 serializer.serialize_f64(self.0)
101 }
102}
103
104impl<'de> Deserialize<'de> for Positive<f64> {
105 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
106 let x = f64::deserialize(deserializer)?;
107
108 if x > 0.0 {
109 Ok(Self(x))
110 } else {
111 Err(serde::de::Error::invalid_value(
112 serde::de::Unexpected::Float(x),
113 &"a positive value",
114 ))
115 }
116 }
117}
118
119impl JsonSchema for Positive<f64> {
120 fn schema_name() -> Cow<'static, str> {
121 Cow::Borrowed("PositiveF64")
122 }
123
124 fn schema_id() -> Cow<'static, str> {
125 Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
126 }
127
128 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
129 json_schema!({
130 "type": "number",
131 "exclusiveMinimum": 0.0
132 })
133 }
134}
135
136#[derive(Debug, Error)]
137pub enum RoundCodecError {
139 #[error("Round does not support the dtype {0}")]
141 UnsupportedDtype(AnyArrayDType),
142 #[error("Round cannot decode into the provided array")]
144 MismatchedDecodeIntoArray {
145 #[from]
147 source: AnyArrayAssignError,
148 },
149}
150
151#[must_use]
152pub fn round<T: Float, S: Data<Elem = T>, D: Dimension>(
155 data: ArrayBase<S, D>,
156 precision: Positive<T>,
157) -> Array<T, D> {
158 let mut encoded = data.into_owned();
159 encoded.mapv_inplace(|x| (x / precision.0).round() * precision.0);
160 encoded
161}