numcodecs_uniform_noise/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-uniform-noise
10//! [crates.io]: https://crates.io/crates/numcodecs-uniform-noise
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-uniform-noise
13//! [docs.rs]: https://docs.rs/numcodecs-uniform-noise/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_uniform_noise
17//!
18//! Uniform noise codec implementation for the [`numcodecs`] API.
19
20use std::hash::{Hash, Hasher};
21
22use ndarray::{Array, ArrayBase, Data, Dimension};
23use num_traits::Float;
24use numcodecs::{
25    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26    Codec, StaticCodec, StaticCodecConfig,
27};
28use rand::{
29    distributions::{Distribution, Open01},
30    SeedableRng,
31};
32use schemars::JsonSchema;
33use serde::{Deserialize, Serialize};
34use thiserror::Error;
35use wyhash::{WyHash, WyRng};
36
37#[derive(Clone, Serialize, Deserialize, JsonSchema)]
38#[serde(deny_unknown_fields)]
39/// Codec that adds `seed`ed `$\text{U}(-0.5 \cdot scale, 0.5 \cdot scale)$`
40/// uniform noise of the given `scale` during encoding and passes through the
41/// input unchanged during decoding.
42///
43/// This codec first hashes the input array data and shape to then seed a
44/// pseudo-random number generator that generates the uniform noise. Therefore,
45/// passing in the same input with the same `seed` will produce the same noise
46/// and thus the same encoded output.
47pub struct UniformNoiseCodec {
48    /// Scale of the uniform noise, which is sampled from
49    /// `$\text{U}(-0.5 \cdot scale, 0.5 \cdot scale)$`
50    pub scale: f64,
51    /// Seed for the random noise generator
52    pub seed: u64,
53}
54
55impl Codec for UniformNoiseCodec {
56    type Error = UniformNoiseCodecError;
57
58    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
59        match data {
60            #[expect(clippy::cast_possible_truncation)]
61            AnyCowArray::F32(data) => Ok(AnyArray::F32(add_uniform_noise(
62                data,
63                self.scale as f32,
64                self.seed,
65            ))),
66            AnyCowArray::F64(data) => Ok(AnyArray::F64(add_uniform_noise(
67                data, self.scale, self.seed,
68            ))),
69            encoded => Err(UniformNoiseCodecError::UnsupportedDtype(encoded.dtype())),
70        }
71    }
72
73    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
74        match encoded {
75            AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
76            AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
77            encoded => Err(UniformNoiseCodecError::UnsupportedDtype(encoded.dtype())),
78        }
79    }
80
81    fn decode_into(
82        &self,
83        encoded: AnyArrayView,
84        mut decoded: AnyArrayViewMut,
85    ) -> Result<(), Self::Error> {
86        if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
87            return Err(UniformNoiseCodecError::UnsupportedDtype(encoded.dtype()));
88        }
89
90        Ok(decoded.assign(&encoded)?)
91    }
92}
93
94impl StaticCodec for UniformNoiseCodec {
95    const CODEC_ID: &'static str = "uniform-noise";
96
97    type Config<'de> = Self;
98
99    fn from_config(config: Self::Config<'_>) -> Self {
100        config
101    }
102
103    fn get_config(&self) -> StaticCodecConfig<Self> {
104        StaticCodecConfig::from(self)
105    }
106}
107
108#[derive(Debug, Error)]
109/// Errors that may occur when applying the [`UniformNoiseCodec`].
110pub enum UniformNoiseCodecError {
111    /// [`UniformNoiseCodec`] does not support the dtype
112    #[error("UniformNoise does not support the dtype {0}")]
113    UnsupportedDtype(AnyArrayDType),
114    /// [`UniformNoiseCodec`] cannot decode into the provided array
115    #[error("UniformNoise cannot decode into the provided array")]
116    MismatchedDecodeIntoArray {
117        /// The source of the error
118        #[from]
119        source: AnyArrayAssignError,
120    },
121}
122
123/// Adds `$\text{U}(-0.5 \cdot scale, 0.5 \cdot scale)$` uniform random noise
124/// to the input `data`.
125///
126/// This function first hashes the input and its shape to then seed a pseudo-
127/// random number generator that generates the uniform noise. Therefore,
128/// passing in the same input with the same `seed` will produce the same noise
129/// and thus the same output.
130#[must_use]
131pub fn add_uniform_noise<T: FloatExt, S: Data<Elem = T>, D: Dimension>(
132    data: ArrayBase<S, D>,
133    scale: T,
134    seed: u64,
135) -> Array<T, D>
136where
137    Open01: Distribution<T>,
138{
139    let mut hasher = WyHash::with_seed(seed);
140    // hashing the shape provides a prefix for the flattened data
141    data.shape().hash(&mut hasher);
142    // the data must be visited in a defined order
143    data.iter().copied().for_each(|x| x.hash_bits(&mut hasher));
144    let seed = hasher.finish();
145
146    let mut rng: WyRng = WyRng::seed_from_u64(seed);
147
148    let mut encoded = data.into_owned();
149
150    // the data must be visited in a defined order
151    for x in &mut encoded {
152        // x = U(0,1)*scale + (scale*-0.5 + x)
153        // --- is equivalent to ---
154        // x += U(-scale/2, +scale/2)
155        *x = Open01
156            .sample(&mut rng)
157            .mul_add(scale, scale.mul_add(T::NEG_HALF, *x));
158    }
159
160    encoded
161}
162
163/// Floating point types
164pub trait FloatExt: Float {
165    /// -0.5
166    const NEG_HALF: Self;
167
168    /// Hash the binary representation of the floating point value
169    fn hash_bits<H: Hasher>(self, hasher: &mut H);
170}
171
172impl FloatExt for f32 {
173    const NEG_HALF: Self = -0.5;
174
175    fn hash_bits<H: Hasher>(self, hasher: &mut H) {
176        hasher.write_u32(self.to_bits());
177    }
178}
179
180impl FloatExt for f64 {
181    const NEG_HALF: Self = -0.5;
182
183    fn hash_bits<H: Hasher>(self, hasher: &mut H) {
184        hasher.write_u64(self.to_bits());
185    }
186}