numcodecs_round/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-round
10//! [crates.io]: https://crates.io/crates/numcodecs-round
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-round
13//! [docs.rs]: https://docs.rs/numcodecs-round/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_round
17//!
18//! Rounding codec implementation for the [`numcodecs`] API.
19
20use std::borrow::Cow;
21
22use ndarray::{Array, ArrayBase, Data, Dimension};
23use num_traits::Float;
24use numcodecs::{
25    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26    Codec, StaticCodec, StaticCodecConfig,
27};
28use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
29use serde::{Deserialize, Deserializer, Serialize, Serializer};
30use thiserror::Error;
31
// Configuration and state of the rounding codec; the codec is its own config.
// `deny_unknown_fields` makes deserialization fail on any field other than
//  `precision`.
// NOTE(review): the `///` doc comments below are read by the `JsonSchema`
//  derive (presumably as schema descriptions — confirm), so their text is part
//  of the generated schema and is intentionally left unchanged.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
/// Codec that rounds the data on encoding and passes through the input
/// unchanged during decoding.
///
/// The codec only supports floating point data.
pub struct RoundCodec {
    /// Precision of the rounding operation
    // Must be strictly positive: the `Positive<f64>` deserializer rejects any
    //  value that is not `> 0.0` (including NaN).
    pub precision: Positive<f64>,
}
42
43impl Codec for RoundCodec {
44    type Error = RoundCodecError;
45
46    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
47        match data {
48            #[expect(clippy::cast_possible_truncation)]
49            AnyCowArray::F32(data) => Ok(AnyArray::F32(round(
50                data,
51                Positive(self.precision.0 as f32),
52            ))),
53            AnyCowArray::F64(data) => Ok(AnyArray::F64(round(data, self.precision))),
54            encoded => Err(RoundCodecError::UnsupportedDtype(encoded.dtype())),
55        }
56    }
57
58    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
59        match encoded {
60            AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
61            AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
62            encoded => Err(RoundCodecError::UnsupportedDtype(encoded.dtype())),
63        }
64    }
65
66    fn decode_into(
67        &self,
68        encoded: AnyArrayView,
69        mut decoded: AnyArrayViewMut,
70    ) -> Result<(), Self::Error> {
71        if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
72            return Err(RoundCodecError::UnsupportedDtype(encoded.dtype()));
73        }
74
75        Ok(decoded.assign(&encoded)?)
76    }
77}
78
79impl StaticCodec for RoundCodec {
80    const CODEC_ID: &'static str = "round";
81
82    type Config<'de> = Self;
83
84    fn from_config(config: Self::Config<'_>) -> Self {
85        config
86    }
87
88    fn get_config(&self) -> StaticCodecConfig<Self> {
89        StaticCodecConfig::from(self)
90    }
91}
92
93#[expect(clippy::derive_partial_eq_without_eq)] // floats are not Eq
94#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
95/// Positive floating point number
96pub struct Positive<T: Float>(T);
97
98impl Serialize for Positive<f64> {
99    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
100        serializer.serialize_f64(self.0)
101    }
102}
103
104impl<'de> Deserialize<'de> for Positive<f64> {
105    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
106        let x = f64::deserialize(deserializer)?;
107
108        if x > 0.0 {
109            Ok(Self(x))
110        } else {
111            Err(serde::de::Error::invalid_value(
112                serde::de::Unexpected::Float(x),
113                &"a positive value",
114            ))
115        }
116    }
117}
118
119impl JsonSchema for Positive<f64> {
120    fn schema_name() -> Cow<'static, str> {
121        Cow::Borrowed("PositiveF64")
122    }
123
124    fn schema_id() -> Cow<'static, str> {
125        Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
126    }
127
128    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
129        json_schema!({
130            "type": "number",
131            "exclusiveMinimum": 0.0
132        })
133    }
134}
135
136#[derive(Debug, Error)]
137/// Errors that may occur when applying the [`RoundCodec`].
138pub enum RoundCodecError {
139    /// [`RoundCodec`] does not support the dtype
140    #[error("Round does not support the dtype {0}")]
141    UnsupportedDtype(AnyArrayDType),
142    /// [`RoundCodec`] cannot decode into the provided array
143    #[error("Round cannot decode into the provided array")]
144    MismatchedDecodeIntoArray {
145        /// The source of the error
146        #[from]
147        source: AnyArrayAssignError,
148    },
149}
150
151#[must_use]
152/// Rounds the input `data` using
153/// `$c = \text{round}\left( \frac{x}{precision} \right) \cdot precision$`
154pub fn round<T: Float, S: Data<Elem = T>, D: Dimension>(
155    data: ArrayBase<S, D>,
156    precision: Positive<T>,
157) -> Array<T, D> {
158    let mut encoded = data.into_owned();
159    encoded.mapv_inplace(|x| (x / precision.0).round() * precision.0);
160    encoded
161}