numcodecs_uniform_noise/
lib.rs1use std::hash::{Hash, Hasher};
21
22use ndarray::{Array, ArrayBase, Data, Dimension};
23use num_traits::Float;
24use numcodecs::{
25 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26 Codec, StaticCodec, StaticCodecConfig,
27};
28use rand::{
29 distributions::{Distribution, Open01},
30 SeedableRng,
31};
32use schemars::JsonSchema;
33use serde::{Deserialize, Serialize};
34use thiserror::Error;
35use wyhash::{WyHash, WyRng};
36
37#[derive(Clone, Serialize, Deserialize, JsonSchema)]
38#[serde(deny_unknown_fields)]
39pub struct UniformNoiseCodec {
48 pub scale: f64,
51 pub seed: u64,
53}
54
55impl Codec for UniformNoiseCodec {
56 type Error = UniformNoiseCodecError;
57
58 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
59 match data {
60 #[expect(clippy::cast_possible_truncation)]
61 AnyCowArray::F32(data) => Ok(AnyArray::F32(add_uniform_noise(
62 data,
63 self.scale as f32,
64 self.seed,
65 ))),
66 AnyCowArray::F64(data) => Ok(AnyArray::F64(add_uniform_noise(
67 data, self.scale, self.seed,
68 ))),
69 encoded => Err(UniformNoiseCodecError::UnsupportedDtype(encoded.dtype())),
70 }
71 }
72
73 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
74 match encoded {
75 AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
76 AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
77 encoded => Err(UniformNoiseCodecError::UnsupportedDtype(encoded.dtype())),
78 }
79 }
80
81 fn decode_into(
82 &self,
83 encoded: AnyArrayView,
84 mut decoded: AnyArrayViewMut,
85 ) -> Result<(), Self::Error> {
86 if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
87 return Err(UniformNoiseCodecError::UnsupportedDtype(encoded.dtype()));
88 }
89
90 Ok(decoded.assign(&encoded)?)
91 }
92}
93
94impl StaticCodec for UniformNoiseCodec {
95 const CODEC_ID: &'static str = "uniform-noise";
96
97 type Config<'de> = Self;
98
99 fn from_config(config: Self::Config<'_>) -> Self {
100 config
101 }
102
103 fn get_config(&self) -> StaticCodecConfig<Self> {
104 StaticCodecConfig::from(self)
105 }
106}
107
108#[derive(Debug, Error)]
109pub enum UniformNoiseCodecError {
111 #[error("UniformNoise does not support the dtype {0}")]
113 UnsupportedDtype(AnyArrayDType),
114 #[error("UniformNoise cannot decode into the provided array")]
116 MismatchedDecodeIntoArray {
117 #[from]
119 source: AnyArrayAssignError,
120 },
121}
122
123#[must_use]
131pub fn add_uniform_noise<T: FloatExt, S: Data<Elem = T>, D: Dimension>(
132 data: ArrayBase<S, D>,
133 scale: T,
134 seed: u64,
135) -> Array<T, D>
136where
137 Open01: Distribution<T>,
138{
139 let mut hasher = WyHash::with_seed(seed);
140 data.shape().hash(&mut hasher);
142 data.iter().copied().for_each(|x| x.hash_bits(&mut hasher));
144 let seed = hasher.finish();
145
146 let mut rng: WyRng = WyRng::seed_from_u64(seed);
147
148 let mut encoded = data.into_owned();
149
150 for x in &mut encoded {
152 *x = Open01
156 .sample(&mut rng)
157 .mul_add(scale, scale.mul_add(T::NEG_HALF, *x));
158 }
159
160 encoded
161}
162
163pub trait FloatExt: Float {
165 const NEG_HALF: Self;
167
168 fn hash_bits<H: Hasher>(self, hasher: &mut H);
170}
171
172impl FloatExt for f32 {
173 const NEG_HALF: Self = -0.5;
174
175 fn hash_bits<H: Hasher>(self, hasher: &mut H) {
176 hasher.write_u32(self.to_bits());
177 }
178}
179
180impl FloatExt for f64 {
181 const NEG_HALF: Self = -0.5;
182
183 fn hash_bits<H: Hasher>(self, hasher: &mut H) {
184 hasher.write_u64(self.to_bits());
185 }
186}