numcodecs_stochastic_rounding/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.86.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-stochastic-rounding
10//! [crates.io]: https://crates.io/crates/numcodecs-stochastic-rounding
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-stochastic-rounding
13//! [docs.rs]: https://docs.rs/numcodecs-stochastic-rounding/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_stochastic_rounding
17//!
18//! Stochastic rounding codec implementation for the [`numcodecs`] API.
19
20use std::borrow::Cow;
21use std::hash::{Hash, Hasher};
22
23use ndarray::{Array, ArrayBase, Data, Dimension};
24use num_traits::Float;
25use numcodecs::{
26    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use rand::{
30    SeedableRng,
31    distr::{Distribution, Open01},
32};
33use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
34use serde::{Deserialize, Deserializer, Serialize, Serializer};
35use thiserror::Error;
36use wyhash::{WyHash, WyRng};
37
38#[derive(Clone, Serialize, Deserialize, JsonSchema)]
39#[serde(deny_unknown_fields)]
40/// Codec that stochastically rounds the data to the nearest multiple of
41/// `precision` on encoding and passes through the input unchanged during
42/// decoding.
43///
44/// The nearest representable multiple is chosen such that the absolute
45/// difference between the original value and the rounded value do not exceed
46/// the precision. Therefore, the rounded value may have a non-zero remainder.
47///
48/// This codec first hashes the input array data and shape to then `seed` a
49/// pseudo-random number generator that is used to sample the stochasticity for
50/// rounding. Therefore, passing in the same input with the same `seed` will
51/// produce the same stochasticity and thus the same encoded output.
52pub struct StochasticRoundingCodec {
53    /// The precision of the rounding operation
54    pub precision: NonNegative<f64>,
55    /// Seed for the random generator
56    pub seed: u64,
57    /// The codec's encoding format version. Do not provide this parameter explicitly.
58    #[serde(default, rename = "_version")]
59    pub version: StaticCodecVersion<1, 0, 0>,
60}
61
62impl Codec for StochasticRoundingCodec {
63    type Error = StochasticRoundingCodecError;
64
65    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
66        match data {
67            #[expect(clippy::cast_possible_truncation)]
68            AnyCowArray::F32(data) => Ok(AnyArray::F32(stochastic_rounding(
69                data,
70                NonNegative(self.precision.0 as f32),
71                self.seed,
72            ))),
73            AnyCowArray::F64(data) => Ok(AnyArray::F64(stochastic_rounding(
74                data,
75                self.precision,
76                self.seed,
77            ))),
78            encoded => Err(StochasticRoundingCodecError::UnsupportedDtype(
79                encoded.dtype(),
80            )),
81        }
82    }
83
84    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
85        match encoded {
86            AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
87            AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
88            encoded => Err(StochasticRoundingCodecError::UnsupportedDtype(
89                encoded.dtype(),
90            )),
91        }
92    }
93
94    fn decode_into(
95        &self,
96        encoded: AnyArrayView,
97        mut decoded: AnyArrayViewMut,
98    ) -> Result<(), Self::Error> {
99        if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
100            return Err(StochasticRoundingCodecError::UnsupportedDtype(
101                encoded.dtype(),
102            ));
103        }
104
105        Ok(decoded.assign(&encoded)?)
106    }
107}
108
109impl StaticCodec for StochasticRoundingCodec {
110    const CODEC_ID: &'static str = "stochastic-rounding.rs";
111
112    type Config<'de> = Self;
113
114    fn from_config(config: Self::Config<'_>) -> Self {
115        config
116    }
117
118    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
119        StaticCodecConfig::from(self)
120    }
121}
122
123#[expect(clippy::derive_partial_eq_without_eq)] // floats are not Eq
124#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
125/// Non-negative floating point number
126pub struct NonNegative<T: Float>(T);
127
128impl Serialize for NonNegative<f64> {
129    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
130        serializer.serialize_f64(self.0)
131    }
132}
133
134impl<'de> Deserialize<'de> for NonNegative<f64> {
135    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
136        let x = f64::deserialize(deserializer)?;
137
138        if x >= 0.0 {
139            Ok(Self(x))
140        } else {
141            Err(serde::de::Error::invalid_value(
142                serde::de::Unexpected::Float(x),
143                &"a non-negative value",
144            ))
145        }
146    }
147}
148
149impl JsonSchema for NonNegative<f64> {
150    fn schema_name() -> Cow<'static, str> {
151        Cow::Borrowed("NonNegativeF64")
152    }
153
154    fn schema_id() -> Cow<'static, str> {
155        Cow::Borrowed(concat!(module_path!(), "::", "NonNegative<f64>"))
156    }
157
158    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
159        json_schema!({
160            "type": "number",
161            "minimum": 0.0
162        })
163    }
164}
165
166#[derive(Debug, Error)]
167/// Errors that may occur when applying the [`StochasticRoundingCodec`].
168pub enum StochasticRoundingCodecError {
169    /// [`StochasticRoundingCodec`] does not support the dtype
170    #[error("StochasticRounding does not support the dtype {0}")]
171    UnsupportedDtype(AnyArrayDType),
172    /// [`StochasticRoundingCodec`] cannot decode into the provided array
173    #[error("StochasticRounding cannot decode into the provided array")]
174    MismatchedDecodeIntoArray {
175        /// The source of the error
176        #[from]
177        source: AnyArrayAssignError,
178    },
179}
180
181/// Stochastically rounds the `data` to the nearest multiple of the `precision`.
182///
183/// The nearest representable multiple is chosen such that the absolute
184/// difference between the original value and the rounded value do not exceed
185/// the precision. Therefore, the rounded value may have a non-zero remainder.
186///
187/// This function first hashes the input array data and shape to then `seed` a
188/// pseudo-random number generator that is used to sample the stochasticity for
189/// rounding. Therefore, passing in the same input with the same `seed` will
190/// produce the same stochasticity and thus the same encoded output.
191#[must_use]
192pub fn stochastic_rounding<T: FloatExt, S: Data<Elem = T>, D: Dimension>(
193    data: ArrayBase<S, D>,
194    precision: NonNegative<T>,
195    seed: u64,
196) -> Array<T, D>
197where
198    Open01: Distribution<T>,
199{
200    let mut encoded = data.into_owned();
201
202    if precision.0.is_zero() {
203        return encoded;
204    }
205
206    let mut hasher = WyHash::with_seed(seed);
207    // hashing the shape provides a prefix for the flattened data
208    encoded.shape().hash(&mut hasher);
209    // the data must be visited in a defined order
210    encoded
211        .iter()
212        .copied()
213        .for_each(|x| x.hash_bits(&mut hasher));
214    let seed = hasher.finish();
215
216    let mut rng: WyRng = WyRng::seed_from_u64(seed);
217
218    // the data must be visited in a defined order
219    for x in &mut encoded {
220        if !x.is_finite() {
221            continue;
222        }
223
224        let remainder = x.rem_euclid(precision.0);
225
226        // compute the nearest multiples of precision based on the remainder
227        // correct max 1 ULP rounding errors to ensure that the nearest
228        //  multiples are at most precision away from the original value
229        let mut lower = *x - remainder;
230        if (*x - lower) > precision.0 {
231            lower = lower.next_up();
232        }
233        let mut upper = *x + (precision.0 - remainder);
234        if (upper - *x) > precision.0 {
235            upper = upper.next_down();
236        }
237
238        let threshold = remainder / precision.0;
239
240        let u01: T = Open01.sample(&mut rng);
241
242        // if remainder = 0, U(0, 1) >= 0, so lower (i.e. a) is always picked
243        // if threshold = 1/2, U(0, 1) picks lower and upper with equal chance
244        *x = if u01 >= threshold { lower } else { upper };
245    }
246
247    encoded
248}
249
250/// Floating point types
251pub trait FloatExt: Float {
252    /// -0.5
253    const NEG_HALF: Self;
254
255    /// Hash the binary representation of the floating point value
256    fn hash_bits<H: Hasher>(self, hasher: &mut H);
257
258    /// Calculates the least nonnegative remainder of self (mod rhs).
259    #[must_use]
260    fn rem_euclid(self, rhs: Self) -> Self;
261
262    /// Returns the least number greater than `self`.
263    #[must_use]
264    fn next_up(self) -> Self;
265
266    /// Returns the greatest number less than `self`.
267    #[must_use]
268    fn next_down(self) -> Self;
269}
270
271impl FloatExt for f32 {
272    const NEG_HALF: Self = -0.5;
273
274    fn hash_bits<H: Hasher>(self, hasher: &mut H) {
275        hasher.write_u32(self.to_bits());
276    }
277
278    fn rem_euclid(self, rhs: Self) -> Self {
279        Self::rem_euclid(self, rhs)
280    }
281
282    fn next_up(self) -> Self {
283        Self::next_up(self)
284    }
285
286    fn next_down(self) -> Self {
287        Self::next_down(self)
288    }
289}
290
291impl FloatExt for f64 {
292    const NEG_HALF: Self = -0.5;
293
294    fn hash_bits<H: Hasher>(self, hasher: &mut H) {
295        hasher.write_u64(self.to_bits());
296    }
297
298    fn rem_euclid(self, rhs: Self) -> Self {
299        Self::rem_euclid(self, rhs)
300    }
301
302    fn next_up(self) -> Self {
303        Self::next_up(self)
304    }
305
306    fn next_down(self) -> Self {
307        Self::next_down(self)
308    }
309}
310
#[cfg(test)]
mod tests {
    use ndarray::{array, linspace};

    use super::*;

    /// A precision of zero must leave the data untouched.
    #[test]
    fn round_zero_precision() {
        let input = array![1.1, 2.1];

        let output = stochastic_rounding(input.view(), NonNegative(0.0), 42);

        assert_eq!(input, output);
    }

    /// With an infinite precision, positive values are rounded to zero.
    #[test]
    fn round_infinite_precision() {
        let input = array![1.1, 2.1];

        let output = stochastic_rounding(input.view(), NonNegative(f64::INFINITY), 42);

        assert_eq!(output, array![0.0, 0.0]);
    }

    /// The smallest positive normal precision must not change these values,
    /// even though dividing by it overflows to infinity.
    #[test]
    fn round_minimal_precision() {
        let input = array![0.1, 1.0, 11.0, 21.0];

        assert_eq!(11.0 / f64::MIN_POSITIVE, f64::INFINITY);
        let output = stochastic_rounding(input.view(), NonNegative(f64::MIN_POSITIVE), 42);

        assert_eq!(input, output);
    }

    /// Finite values stay within the precision, all other values must be
    /// preserved bit-for-bit.
    #[test]
    fn round_edge_cases() {
        let input = array![
            -f64::NAN,
            -f64::INFINITY,
            -42.0,
            -4.2,
            -0.0,
            0.0,
            4.2,
            42.0,
            f64::INFINITY,
            f64::NAN
        ];
        let precision = 1.0;

        let output = stochastic_rounding(input.view(), NonNegative(precision), 42);

        for (original, rounded) in input.iter().zip(output.iter()) {
            assert!(
                (*rounded - *original).abs() <= precision
                    || original.to_bits() == rounded.to_bits()
            );
        }
    }

    /// Floating point rounding errors must not push the rounded value further
    /// than the precision away from the original value.
    #[test]
    fn round_rounding_errors() {
        let input = Array::from_iter(linspace(-100.0, 100.0, 3741));
        let precision = 0.1;

        let output = stochastic_rounding(input.view(), NonNegative(precision), 42);

        for (original, rounded) in input.iter().zip(output.iter()) {
            assert!((*rounded - *original).abs() <= precision);
        }
    }

    /// Checks the error bound for `f32` values around +-1.2353 with a
    /// precision of 0.00018.
    #[test]
    fn test_rounding_bug() {
        let input = array![
            -1.23540_f32,
            -1.23539_f32,
            -1.23538_f32,
            -1.23537_f32,
            -1.23536_f32,
            -1.23535_f32,
            -1.23534_f32,
            -1.23533_f32,
            -1.23532_f32,
            -1.23531_f32,
            -1.23530_f32,
            1.23540_f32,
            1.23539_f32,
            1.23538_f32,
            1.23537_f32,
            1.23536_f32,
            1.23535_f32,
            1.23534_f32,
            1.23533_f32,
            1.23532_f32,
            1.23531_f32,
            1.23530_f32,
        ];
        let precision = 0.00018_f32;

        let output = stochastic_rounding(input.view(), NonNegative(precision), 42);

        for (original, rounded) in input.iter().zip(output.iter()) {
            assert!((*rounded - *original).abs() <= precision);
        }
    }
}