numcodecs_uniform_noise/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-uniform-noise
10//! [crates.io]: https://crates.io/crates/numcodecs-uniform-noise
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-uniform-noise
13//! [docs.rs]: https://docs.rs/numcodecs-uniform-noise/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_uniform_noise
17//!
18//! Uniform noise codec implementation for the [`numcodecs`] API.
19
20use std::hash::{Hash, Hasher};
21
22use ndarray::{Array, ArrayBase, Data, Dimension};
23use num_traits::Float;
24use numcodecs::{
25    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
27};
28use rand::{
29    distributions::{Distribution, Open01},
30    SeedableRng,
31};
32use schemars::JsonSchema;
33use serde::{Deserialize, Serialize};
34use thiserror::Error;
35use wyhash::{WyHash, WyRng};
36
37#[derive(Clone, Serialize, Deserialize, JsonSchema)]
38#[serde(deny_unknown_fields)]
39/// Codec that adds `seed`ed `$\text{U}(-0.5 \cdot scale, 0.5 \cdot scale)$`
40/// uniform noise of the given `scale` during encoding and passes through the
41/// input unchanged during decoding.
42///
43/// This codec first hashes the input array data and shape to then seed a
44/// pseudo-random number generator that generates the uniform noise. Therefore,
45/// passing in the same input with the same `seed` will produce the same noise
46/// and thus the same encoded output.
47pub struct UniformNoiseCodec {
48    /// Scale of the uniform noise, which is sampled from
49    /// `$\text{U}(-0.5 \cdot scale, 0.5 \cdot scale)$`
50    pub scale: f64,
51    /// Seed for the random noise generator
52    pub seed: u64,
53    /// The codec's encoding format version. Do not provide this parameter explicitly.
54    #[serde(default, rename = "_version")]
55    pub version: StaticCodecVersion<1, 0, 0>,
56}
57
58impl Codec for UniformNoiseCodec {
59    type Error = UniformNoiseCodecError;
60
61    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
62        match data {
63            #[expect(clippy::cast_possible_truncation)]
64            AnyCowArray::F32(data) => Ok(AnyArray::F32(add_uniform_noise(
65                data,
66                self.scale as f32,
67                self.seed,
68            ))),
69            AnyCowArray::F64(data) => Ok(AnyArray::F64(add_uniform_noise(
70                data, self.scale, self.seed,
71            ))),
72            encoded => Err(UniformNoiseCodecError::UnsupportedDtype(encoded.dtype())),
73        }
74    }
75
76    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
77        match encoded {
78            AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
79            AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
80            encoded => Err(UniformNoiseCodecError::UnsupportedDtype(encoded.dtype())),
81        }
82    }
83
84    fn decode_into(
85        &self,
86        encoded: AnyArrayView,
87        mut decoded: AnyArrayViewMut,
88    ) -> Result<(), Self::Error> {
89        if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
90            return Err(UniformNoiseCodecError::UnsupportedDtype(encoded.dtype()));
91        }
92
93        Ok(decoded.assign(&encoded)?)
94    }
95}
96
97impl StaticCodec for UniformNoiseCodec {
98    const CODEC_ID: &'static str = "uniform-noise.rs";
99
100    type Config<'de> = Self;
101
102    fn from_config(config: Self::Config<'_>) -> Self {
103        config
104    }
105
106    fn get_config(&self) -> StaticCodecConfig<Self> {
107        StaticCodecConfig::from(self)
108    }
109}
110
111#[derive(Debug, Error)]
112/// Errors that may occur when applying the [`UniformNoiseCodec`].
113pub enum UniformNoiseCodecError {
114    /// [`UniformNoiseCodec`] does not support the dtype
115    #[error("UniformNoise does not support the dtype {0}")]
116    UnsupportedDtype(AnyArrayDType),
117    /// [`UniformNoiseCodec`] cannot decode into the provided array
118    #[error("UniformNoise cannot decode into the provided array")]
119    MismatchedDecodeIntoArray {
120        /// The source of the error
121        #[from]
122        source: AnyArrayAssignError,
123    },
124}
125
126/// Adds `$\text{U}(-0.5 \cdot scale, 0.5 \cdot scale)$` uniform random noise
127/// to the input `data`.
128///
129/// This function first hashes the input and its shape to then seed a pseudo-
130/// random number generator that generates the uniform noise. Therefore,
131/// passing in the same input with the same `seed` will produce the same noise
132/// and thus the same output.
133#[must_use]
134pub fn add_uniform_noise<T: FloatExt, S: Data<Elem = T>, D: Dimension>(
135    data: ArrayBase<S, D>,
136    scale: T,
137    seed: u64,
138) -> Array<T, D>
139where
140    Open01: Distribution<T>,
141{
142    let mut hasher = WyHash::with_seed(seed);
143    // hashing the shape provides a prefix for the flattened data
144    data.shape().hash(&mut hasher);
145    // the data must be visited in a defined order
146    data.iter().copied().for_each(|x| x.hash_bits(&mut hasher));
147    let seed = hasher.finish();
148
149    let mut rng: WyRng = WyRng::seed_from_u64(seed);
150
151    let mut encoded = data.into_owned();
152
153    // the data must be visited in a defined order
154    for x in &mut encoded {
155        // x = U(0,1)*scale + (scale*-0.5 + x)
156        // --- is equivalent to ---
157        // x += U(-scale/2, +scale/2)
158        *x = Open01
159            .sample(&mut rng)
160            .mul_add(scale, scale.mul_add(T::NEG_HALF, *x));
161    }
162
163    encoded
164}
165
166/// Floating point types
167pub trait FloatExt: Float {
168    /// -0.5
169    const NEG_HALF: Self;
170
171    /// Hash the binary representation of the floating point value
172    fn hash_bits<H: Hasher>(self, hasher: &mut H);
173}
174
175impl FloatExt for f32 {
176    const NEG_HALF: Self = -0.5;
177
178    fn hash_bits<H: Hasher>(self, hasher: &mut H) {
179        hasher.write_u32(self.to_bits());
180    }
181}
182
183impl FloatExt for f64 {
184    const NEG_HALF: Self = -0.5;
185
186    fn hash_bits<H: Hasher>(self, hasher: &mut H) {
187        hasher.write_u64(self.to_bits());
188    }
189}