numcodecs_uniform_noise/lib.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
//!
//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
//!
//! [MSRV]: https://img.shields.io/badge/MSRV-1.81.0-blue
//! [repo]: https://github.com/juntyr/numcodecs-rs
//!
//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-uniform-noise
//! [crates.io]: https://crates.io/crates/numcodecs-uniform-noise
//!
//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-uniform-noise
//! [docs.rs]: https://docs.rs/numcodecs-uniform-noise/
//!
//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_uniform_noise
//!
//! Uniform noise codec implementation for the [`numcodecs`] API.
use std::hash::{Hash, Hasher};
use ndarray::{Array, ArrayBase, Data, Dimension};
use num_traits::Float;
use numcodecs::{
AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
Codec, StaticCodec, StaticCodecConfig,
};
use rand::{
distributions::{Distribution, Open01},
SeedableRng,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use wyhash::{WyHash, WyRng};
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
/// Codec that adds `seed`ed `U(-scale/2, scale/2)` uniform noise of the given
/// `scale` during encoding and passes through the input unchanged during
/// decoding.
///
/// This codec first hashes the input array data and shape to then seed a
/// pseudo-random number generator that generates the uniform noise. Therefore,
/// passing in the same input with the same `seed` will produce the same noise
/// and thus the same encoded output.
pub struct UniformNoiseCodec {
    /// Scale of the uniform noise, which is sampled from
    /// `U(-scale/2, +scale/2)`
    pub scale: f64,
    /// Seed for the random noise generator
    pub seed: u64,
}
impl Codec for UniformNoiseCodec {
    type Error = UniformNoiseCodecError;

    /// Adds seeded uniform noise to `data`, which must be an `f32` or `f64`
    /// array; any other dtype is rejected.
    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let noisy = match data {
            AnyCowArray::F32(array) => {
                // the f64 config scale is narrowed deliberately so that the
                // noise is generated in the same precision as the data
                #[allow(clippy::cast_possible_truncation)]
                let scale = self.scale as f32;
                AnyArray::F32(add_uniform_noise(array, scale, self.seed))
            }
            AnyCowArray::F64(array) => {
                AnyArray::F64(add_uniform_noise(array, self.scale, self.seed))
            }
            unsupported => {
                return Err(UniformNoiseCodecError::UnsupportedDtype(
                    unsupported.dtype(),
                ))
            }
        };

        Ok(noisy)
    }

    /// Passes an `f32` or `f64` array through unchanged, converting it into
    /// an owned array; any other dtype is rejected.
    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let decoded = match encoded {
            AnyCowArray::F32(array) => AnyArray::F32(array.into_owned()),
            AnyCowArray::F64(array) => AnyArray::F64(array.into_owned()),
            unsupported => {
                return Err(UniformNoiseCodecError::UnsupportedDtype(
                    unsupported.dtype(),
                ))
            }
        };

        Ok(decoded)
    }

    /// Copies the `f32` or `f64` `encoded` array unchanged into `decoded`.
    ///
    /// Any other dtype is rejected. A failed assignment, e.g. a mismatched
    /// output array, is reported as
    /// [`UniformNoiseCodecError::MismatchedDecodeIntoArray`].
    fn decode_into(
        &self,
        encoded: AnyArrayView,
        mut decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        match encoded.dtype() {
            AnyArrayDType::F32 | AnyArrayDType::F64 => {
                decoded.assign(&encoded)?;
                Ok(())
            }
            unsupported => Err(UniformNoiseCodecError::UnsupportedDtype(unsupported)),
        }
    }
}
impl StaticCodec for UniformNoiseCodec {
    /// The codec's identifier string
    const CODEC_ID: &'static str = "uniform-noise";

    /// The codec struct doubles as its own configuration
    type Config<'de> = Self;

    fn from_config(config: Self::Config<'_>) -> Self {
        // the configuration already is the codec, so no conversion is needed
        config
    }

    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
/// Error type returned when encoding or decoding with the
/// [`UniformNoiseCodec`] fails.
#[derive(Debug, Error)]
pub enum UniformNoiseCodecError {
    /// The array's dtype is not one that [`UniformNoiseCodec`] handles
    /// (only `f32` and `f64` are accepted)
    #[error("UniformNoise does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// Copying the decoded data into the caller-provided output array of
    /// [`UniformNoiseCodec`] failed
    #[error("UniformNoise cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The underlying assignment error
        #[from]
        source: AnyArrayAssignError,
    },
}
/// Adds `U(-scale/2, scale/2)` uniform random noise to the input `data`.
///
/// This function first hashes the input and its shape to then seed a pseudo-
/// random number generator that generates the uniform noise. Therefore,
/// passing in the same input with the same `seed` will produce the same noise
/// and thus the same output.
///
/// The shape is hashed using fixed-width `u64` dimensions, so the derived
/// noise does not depend on the target's pointer width (e.g. wasm32 vs
/// x86_64). The hashed bytes are still in native endianness.
#[must_use]
pub fn add_uniform_noise<T: FloatExt, S: Data<Elem = T>, D: Dimension>(
    data: ArrayBase<S, D>,
    scale: T,
    seed: u64,
) -> Array<T, D>
where
    Open01: Distribution<T>,
{
    let mut hasher = WyHash::with_seed(seed);

    // hashing the shape provides a prefix for the flattened data
    // - `usize` is 32 bits wide on some targets, so hashing `&[usize]`
    //   directly would give a different seed on those targets; widening to
    //   `u64` makes the seed pointer-width independent
    // - a `u64` length prefix followed by a single `hash_slice` write mirrors
    //   `<[usize] as Hash>::hash` (with the default `Hasher::write_usize`), so
    //   the hashed bytes on 64-bit targets are unchanged
    let shape: Vec<u64> = data.shape().iter().map(|&dim| dim as u64).collect();
    hasher.write_u64(shape.len() as u64);
    u64::hash_slice(&shape, &mut hasher);

    // the data must be visited in a defined order
    data.iter().copied().for_each(|x| x.hash_bits(&mut hasher));
    let seed = hasher.finish();

    let mut rng: WyRng = WyRng::seed_from_u64(seed);

    let mut encoded = data.into_owned();

    // the data must be visited in a defined order
    for x in &mut encoded {
        // x = U(0,1)*scale + (scale*-0.5 + x)
        // --- is equivalent to ---
        // x += U(-scale/2, +scale/2)
        *x = Open01
            .sample(&mut rng)
            .mul_add(scale, scale.mul_add(T::NEG_HALF, *x));
    }

    encoded
}
/// Floating point element types that [`add_uniform_noise`] can operate on
pub trait FloatExt: Float {
    /// The constant `-0.5`, used to centre the sampled noise around zero
    const NEG_HALF: Self;

    /// Feeds the raw bit pattern of `self` (as returned by `to_bits`) into
    /// `hasher`
    fn hash_bits<H: Hasher>(self, hasher: &mut H);
}

impl FloatExt for f32 {
    const NEG_HALF: Self = -0.5;

    fn hash_bits<H: Hasher>(self, hasher: &mut H) {
        // hashing the bits (not the value) keeps e.g. `0.0` and `-0.0`
        // distinct, since their bit patterns differ
        let bits: u32 = self.to_bits();
        Hasher::write_u32(hasher, bits);
    }
}

impl FloatExt for f64 {
    const NEG_HALF: Self = -0.5;

    fn hash_bits<H: Hasher>(self, hasher: &mut H) {
        // hashing the bits (not the value) keeps e.g. `0.0` and `-0.0`
        // distinct, since their bit patterns differ
        let bits: u64 = self.to_bits();
        Hasher::write_u64(hasher, bits);
    }
}