numcodecs_stochastic_rounding/
lib.rs1use std::borrow::Cow;
21use std::hash::{Hash, Hasher};
22
23use ndarray::{Array, ArrayBase, Data, Dimension};
24use num_traits::Float;
25use numcodecs::{
26 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use rand::{
30 SeedableRng,
31 distr::{Distribution, Open01},
32};
33use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
34use serde::{Deserialize, Deserializer, Serialize, Serializer};
35use thiserror::Error;
36use wyhash::{WyHash, WyRng};
37
38#[derive(Clone, Serialize, Deserialize, JsonSchema)]
39#[serde(deny_unknown_fields)]
40pub struct StochasticRoundingCodec {
53 pub precision: NonNegative<f64>,
55 pub seed: u64,
57 #[serde(default, rename = "_version")]
59 pub version: StaticCodecVersion<1, 0, 0>,
60}
61
62impl Codec for StochasticRoundingCodec {
63 type Error = StochasticRoundingCodecError;
64
65 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
66 match data {
67 #[expect(clippy::cast_possible_truncation)]
68 AnyCowArray::F32(data) => Ok(AnyArray::F32(stochastic_rounding(
69 data,
70 NonNegative(self.precision.0 as f32),
71 self.seed,
72 ))),
73 AnyCowArray::F64(data) => Ok(AnyArray::F64(stochastic_rounding(
74 data,
75 self.precision,
76 self.seed,
77 ))),
78 encoded => Err(StochasticRoundingCodecError::UnsupportedDtype(
79 encoded.dtype(),
80 )),
81 }
82 }
83
84 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
85 match encoded {
86 AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
87 AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
88 encoded => Err(StochasticRoundingCodecError::UnsupportedDtype(
89 encoded.dtype(),
90 )),
91 }
92 }
93
94 fn decode_into(
95 &self,
96 encoded: AnyArrayView,
97 mut decoded: AnyArrayViewMut,
98 ) -> Result<(), Self::Error> {
99 if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
100 return Err(StochasticRoundingCodecError::UnsupportedDtype(
101 encoded.dtype(),
102 ));
103 }
104
105 Ok(decoded.assign(&encoded)?)
106 }
107}
108
109impl StaticCodec for StochasticRoundingCodec {
110 const CODEC_ID: &'static str = "stochastic-rounding.rs";
111
112 type Config<'de> = Self;
113
114 fn from_config(config: Self::Config<'_>) -> Self {
115 config
116 }
117
118 fn get_config(&self) -> StaticCodecConfig<'_, Self> {
119 StaticCodecConfig::from(self)
120 }
121}
122
123#[expect(clippy::derive_partial_eq_without_eq)] #[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
125pub struct NonNegative<T: Float>(T);
127
128impl Serialize for NonNegative<f64> {
129 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
130 serializer.serialize_f64(self.0)
131 }
132}
133
134impl<'de> Deserialize<'de> for NonNegative<f64> {
135 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
136 let x = f64::deserialize(deserializer)?;
137
138 if x >= 0.0 {
139 Ok(Self(x))
140 } else {
141 Err(serde::de::Error::invalid_value(
142 serde::de::Unexpected::Float(x),
143 &"a non-negative value",
144 ))
145 }
146 }
147}
148
149impl JsonSchema for NonNegative<f64> {
150 fn schema_name() -> Cow<'static, str> {
151 Cow::Borrowed("NonNegativeF64")
152 }
153
154 fn schema_id() -> Cow<'static, str> {
155 Cow::Borrowed(concat!(module_path!(), "::", "NonNegative<f64>"))
156 }
157
158 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
159 json_schema!({
160 "type": "number",
161 "minimum": 0.0
162 })
163 }
164}
165
166#[derive(Debug, Error)]
167pub enum StochasticRoundingCodecError {
169 #[error("StochasticRounding does not support the dtype {0}")]
171 UnsupportedDtype(AnyArrayDType),
172 #[error("StochasticRounding cannot decode into the provided array")]
174 MismatchedDecodeIntoArray {
175 #[from]
177 source: AnyArrayAssignError,
178 },
179}
180
181#[must_use]
192pub fn stochastic_rounding<T: FloatExt, S: Data<Elem = T>, D: Dimension>(
193 data: ArrayBase<S, D>,
194 precision: NonNegative<T>,
195 seed: u64,
196) -> Array<T, D>
197where
198 Open01: Distribution<T>,
199{
200 let mut encoded = data.into_owned();
201
202 if precision.0.is_zero() {
203 return encoded;
204 }
205
206 let mut hasher = WyHash::with_seed(seed);
207 encoded.shape().hash(&mut hasher);
209 encoded
211 .iter()
212 .copied()
213 .for_each(|x| x.hash_bits(&mut hasher));
214 let seed = hasher.finish();
215
216 let mut rng: WyRng = WyRng::seed_from_u64(seed);
217
218 for x in &mut encoded {
220 if !x.is_finite() {
221 continue;
222 }
223
224 let remainder = x.rem_euclid(precision.0);
225
226 let mut lower = *x - remainder;
230 if (*x - lower) > precision.0 {
231 lower = lower.next_up();
232 }
233 let mut upper = *x + (precision.0 - remainder);
234 if (upper - *x) > precision.0 {
235 upper = upper.next_down();
236 }
237
238 let threshold = remainder / precision.0;
239
240 let u01: T = Open01.sample(&mut rng);
241
242 *x = if u01 >= threshold { lower } else { upper };
245 }
246
247 encoded
248}
249
250pub trait FloatExt: Float {
252 const NEG_HALF: Self;
254
255 fn hash_bits<H: Hasher>(self, hasher: &mut H);
257
258 #[must_use]
260 fn rem_euclid(self, rhs: Self) -> Self;
261
262 #[must_use]
264 fn next_up(self) -> Self;
265
266 #[must_use]
268 fn next_down(self) -> Self;
269}
270
271impl FloatExt for f32 {
272 const NEG_HALF: Self = -0.5;
273
274 fn hash_bits<H: Hasher>(self, hasher: &mut H) {
275 hasher.write_u32(self.to_bits());
276 }
277
278 fn rem_euclid(self, rhs: Self) -> Self {
279 Self::rem_euclid(self, rhs)
280 }
281
282 fn next_up(self) -> Self {
283 Self::next_up(self)
284 }
285
286 fn next_down(self) -> Self {
287 Self::next_down(self)
288 }
289}
290
291impl FloatExt for f64 {
292 const NEG_HALF: Self = -0.5;
293
294 fn hash_bits<H: Hasher>(self, hasher: &mut H) {
295 hasher.write_u64(self.to_bits());
296 }
297
298 fn rem_euclid(self, rhs: Self) -> Self {
299 Self::rem_euclid(self, rhs)
300 }
301
302 fn next_up(self) -> Self {
303 Self::next_up(self)
304 }
305
306 fn next_down(self) -> Self {
307 Self::next_down(self)
308 }
309}
310
#[cfg(test)]
mod tests {
    use ndarray::{array, linspace};

    use super::*;

    /// A precision of zero must leave the data untouched.
    #[test]
    fn round_zero_precision() {
        let input = array![1.1, 2.1];

        let output = stochastic_rounding(input.view(), NonNegative(0.0), 42);

        assert_eq!(input, output);
    }

    /// With an infinite precision, positive values are rounded to zero.
    #[test]
    fn round_infinite_precision() {
        let input = array![1.1, 2.1];

        let output = stochastic_rounding(input.view(), NonNegative(f64::INFINITY), 42);

        assert_eq!(output, array![0.0, 0.0]);
    }

    /// The smallest positive normal precision must not change the values,
    /// even though dividing by it overflows to infinity.
    #[test]
    fn round_minimal_precision() {
        let input = array![0.1, 1.0, 11.0, 21.0];

        assert_eq!(11.0 / f64::MIN_POSITIVE, f64::INFINITY);
        let output = stochastic_rounding(input.view(), NonNegative(f64::MIN_POSITIVE), 42);

        assert_eq!(input, output);
    }

    /// Special values either stay bit-identical or end up within one
    /// precision step of the original value.
    #[test]
    fn round_edge_cases() {
        let input = array![
            -f64::NAN,
            -f64::INFINITY,
            -42.0,
            -4.2,
            -0.0,
            0.0,
            4.2,
            42.0,
            f64::INFINITY,
            f64::NAN
        ];
        let precision = 1.0;

        let output = stochastic_rounding(input.view(), NonNegative(precision), 42);

        for (original, rounded) in input.into_iter().zip(output) {
            assert!(
                (rounded - original).abs() <= precision || original.to_bits() == rounded.to_bits()
            );
        }
    }

    /// The rounding error never exceeds the precision over a dense range.
    #[test]
    fn round_rounding_errors() {
        let input = Array::from_iter(linspace(-100.0, 100.0, 3741));
        let precision = 0.1;

        let output = stochastic_rounding(input.view(), NonNegative(precision), 42);

        for (original, rounded) in input.into_iter().zip(output) {
            assert!((rounded - original).abs() <= precision);
        }
    }

    /// Regression test: these `f32` values must stay within one precision
    /// step of their original value.
    #[test]
    fn test_rounding_bug() {
        let input = array![
            -1.23540_f32,
            -1.23539_f32,
            -1.23538_f32,
            -1.23537_f32,
            -1.23536_f32,
            -1.23535_f32,
            -1.23534_f32,
            -1.23533_f32,
            -1.23532_f32,
            -1.23531_f32,
            -1.23530_f32,
            1.23540_f32,
            1.23539_f32,
            1.23538_f32,
            1.23537_f32,
            1.23536_f32,
            1.23535_f32,
            1.23534_f32,
            1.23533_f32,
            1.23532_f32,
            1.23531_f32,
            1.23530_f32,
        ];
        let precision = 0.00018_f32;

        let output = stochastic_rounding(input.view(), NonNegative(precision), 42);

        for (original, rounded) in input.into_iter().zip(output) {
            assert!((rounded - original).abs() <= precision);
        }
    }
}