numcodecs_uniform_noise/
lib.rs1use std::hash::{Hash, Hasher};
21
22use ndarray::{Array, ArrayBase, Data, Dimension};
23use num_traits::Float;
24use numcodecs::{
25 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
27};
28use rand::{
29 distributions::{Distribution, Open01},
30 SeedableRng,
31};
32use schemars::JsonSchema;
33use serde::{Deserialize, Serialize};
34use thiserror::Error;
35use wyhash::{WyHash, WyRng};
36
37#[derive(Clone, Serialize, Deserialize, JsonSchema)]
38#[serde(deny_unknown_fields)]
39pub struct UniformNoiseCodec {
48 pub scale: f64,
51 pub seed: u64,
53 #[serde(default, rename = "_version")]
55 pub version: StaticCodecVersion<1, 0, 0>,
56}
57
58impl Codec for UniformNoiseCodec {
59 type Error = UniformNoiseCodecError;
60
61 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
62 match data {
63 #[expect(clippy::cast_possible_truncation)]
64 AnyCowArray::F32(data) => Ok(AnyArray::F32(add_uniform_noise(
65 data,
66 self.scale as f32,
67 self.seed,
68 ))),
69 AnyCowArray::F64(data) => Ok(AnyArray::F64(add_uniform_noise(
70 data, self.scale, self.seed,
71 ))),
72 encoded => Err(UniformNoiseCodecError::UnsupportedDtype(encoded.dtype())),
73 }
74 }
75
76 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
77 match encoded {
78 AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
79 AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
80 encoded => Err(UniformNoiseCodecError::UnsupportedDtype(encoded.dtype())),
81 }
82 }
83
84 fn decode_into(
85 &self,
86 encoded: AnyArrayView,
87 mut decoded: AnyArrayViewMut,
88 ) -> Result<(), Self::Error> {
89 if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
90 return Err(UniformNoiseCodecError::UnsupportedDtype(encoded.dtype()));
91 }
92
93 Ok(decoded.assign(&encoded)?)
94 }
95}
96
97impl StaticCodec for UniformNoiseCodec {
98 const CODEC_ID: &'static str = "uniform-noise.rs";
99
100 type Config<'de> = Self;
101
102 fn from_config(config: Self::Config<'_>) -> Self {
103 config
104 }
105
106 fn get_config(&self) -> StaticCodecConfig<Self> {
107 StaticCodecConfig::from(self)
108 }
109}
110
111#[derive(Debug, Error)]
112pub enum UniformNoiseCodecError {
114 #[error("UniformNoise does not support the dtype {0}")]
116 UnsupportedDtype(AnyArrayDType),
117 #[error("UniformNoise cannot decode into the provided array")]
119 MismatchedDecodeIntoArray {
120 #[from]
122 source: AnyArrayAssignError,
123 },
124}
125
126#[must_use]
134pub fn add_uniform_noise<T: FloatExt, S: Data<Elem = T>, D: Dimension>(
135 data: ArrayBase<S, D>,
136 scale: T,
137 seed: u64,
138) -> Array<T, D>
139where
140 Open01: Distribution<T>,
141{
142 let mut hasher = WyHash::with_seed(seed);
143 data.shape().hash(&mut hasher);
145 data.iter().copied().for_each(|x| x.hash_bits(&mut hasher));
147 let seed = hasher.finish();
148
149 let mut rng: WyRng = WyRng::seed_from_u64(seed);
150
151 let mut encoded = data.into_owned();
152
153 for x in &mut encoded {
155 *x = Open01
159 .sample(&mut rng)
160 .mul_add(scale, scale.mul_add(T::NEG_HALF, *x));
161 }
162
163 encoded
164}
165
166pub trait FloatExt: Float {
168 const NEG_HALF: Self;
170
171 fn hash_bits<H: Hasher>(self, hasher: &mut H);
173}
174
175impl FloatExt for f32 {
176 const NEG_HALF: Self = -0.5;
177
178 fn hash_bits<H: Hasher>(self, hasher: &mut H) {
179 hasher.write_u32(self.to_bits());
180 }
181}
182
183impl FloatExt for f64 {
184 const NEG_HALF: Self = -0.5;
185
186 fn hash_bits<H: Hasher>(self, hasher: &mut H) {
187 hasher.write_u64(self.to_bits());
188 }
189}