numcodecs_fixed_offset_scale/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-fixed-offset-scale
10//! [crates.io]: https://crates.io/crates/numcodecs-fixed-offset-scale
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-fixed-offset-scale
13//! [docs.rs]: https://docs.rs/numcodecs-fixed-offset-scale/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_fixed_offset_scale
17//!
18//! `$\frac{x - o}{s}$` codec implementation for the [`numcodecs`] API.
19
20use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension};
21use num_traits::Float;
22use numcodecs::{
23    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
24    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
25};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use thiserror::Error;
29
30#[derive(Clone, Serialize, Deserialize, JsonSchema)]
31#[serde(deny_unknown_fields)]
32/// Fixed offset-scale codec which calculates `$c = \frac{x - o}{s}$` on
33/// encoding and `$d = (c \cdot s) + o$` on decoding.
34///
35/// - Setting `$o = \text{mean}(x)$` and `$s = \text{std}(x)$` normalizes that
36///   data.
37/// - Setting `$o = \text{min}(x)$` and `$s = \text{max}(x) - \text{min}(x)$`
38///   standardizes the data.
39///
40/// The codec only supports floating point numbers.
41pub struct FixedOffsetScaleCodec {
42    /// The offset of the data.
43    pub offset: f64,
44    /// The scale of the data.
45    pub scale: f64,
46    /// The codec's encoding format version. Do not provide this parameter explicitly.
47    #[serde(default, rename = "_version")]
48    pub version: StaticCodecVersion<1, 0, 0>,
49}
50
51impl Codec for FixedOffsetScaleCodec {
52    type Error = FixedOffsetScaleCodecError;
53
54    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
55        match data {
56            #[expect(clippy::cast_possible_truncation)]
57            AnyCowArray::F32(data) => Ok(AnyArray::F32(scale(
58                data,
59                self.offset as f32,
60                self.scale as f32,
61            ))),
62            AnyCowArray::F64(data) => Ok(AnyArray::F64(scale(data, self.offset, self.scale))),
63            encoded => Err(FixedOffsetScaleCodecError::UnsupportedDtype(
64                encoded.dtype(),
65            )),
66        }
67    }
68
69    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
70        match encoded {
71            #[expect(clippy::cast_possible_truncation)]
72            AnyCowArray::F32(encoded) => Ok(AnyArray::F32(unscale(
73                encoded,
74                self.offset as f32,
75                self.scale as f32,
76            ))),
77            AnyCowArray::F64(encoded) => {
78                Ok(AnyArray::F64(unscale(encoded, self.offset, self.scale)))
79            }
80            encoded => Err(FixedOffsetScaleCodecError::UnsupportedDtype(
81                encoded.dtype(),
82            )),
83        }
84    }
85
86    fn decode_into(
87        &self,
88        encoded: AnyArrayView,
89        decoded: AnyArrayViewMut,
90    ) -> Result<(), Self::Error> {
91        match (encoded, decoded) {
92            #[expect(clippy::cast_possible_truncation)]
93            (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
94                unscale_into(encoded, decoded, self.offset as f32, self.scale as f32)
95            }
96            (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
97                unscale_into(encoded, decoded, self.offset, self.scale)
98            }
99            (encoded @ (AnyArrayView::F32(_) | AnyArrayView::F64(_)), decoded) => {
100                Err(FixedOffsetScaleCodecError::MismatchedDecodeIntoArray {
101                    source: AnyArrayAssignError::DTypeMismatch {
102                        src: encoded.dtype(),
103                        dst: decoded.dtype(),
104                    },
105                })
106            }
107            (encoded, _decoded) => Err(FixedOffsetScaleCodecError::UnsupportedDtype(
108                encoded.dtype(),
109            )),
110        }
111    }
112}
113
114impl StaticCodec for FixedOffsetScaleCodec {
115    const CODEC_ID: &'static str = "fixed-offset-scale.rs";
116
117    type Config<'de> = Self;
118
119    fn from_config(config: Self::Config<'_>) -> Self {
120        config
121    }
122
123    fn get_config(&self) -> StaticCodecConfig<Self> {
124        StaticCodecConfig::from(self)
125    }
126}
127
128#[derive(Debug, Error)]
129/// Errors that may occur when applying the [`FixedOffsetScaleCodec`].
130pub enum FixedOffsetScaleCodecError {
131    /// [`FixedOffsetScaleCodec`] does not support the dtype
132    #[error("FixedOffsetScale does not support the dtype {0}")]
133    UnsupportedDtype(AnyArrayDType),
134    /// [`FixedOffsetScaleCodec`] cannot decode into the provided array
135    #[error("FixedOffsetScale cannot decode into the provided array")]
136    MismatchedDecodeIntoArray {
137        /// The source of the error
138        #[from]
139        source: AnyArrayAssignError,
140    },
141}
142
143/// Compute `$\frac{x - o}{s}$` over the elements of the input `data` array.
144pub fn scale<T: Float, S: Data<Elem = T>, D: Dimension>(
145    data: ArrayBase<S, D>,
146    offset: T,
147    scale: T,
148) -> Array<T, D> {
149    let inverse_scale = scale.recip();
150
151    let mut data = data.into_owned();
152    data.mapv_inplace(|x| (x - offset) * inverse_scale);
153
154    data
155}
156
157/// Compute `$(x \cdot s) + o$` over the elements of the input `data` array.
158pub fn unscale<T: Float, S: Data<Elem = T>, D: Dimension>(
159    data: ArrayBase<S, D>,
160    offset: T,
161    scale: T,
162) -> Array<T, D> {
163    let mut data = data.into_owned();
164    data.mapv_inplace(|x| x.mul_add(scale, offset));
165
166    data
167}
168
169#[expect(clippy::needless_pass_by_value)]
170/// Compute `$(x \cdot s) + o$` over the elements of the input `data` array and
171/// write them into the `out`put array.
172///
173/// # Errors
174///
175/// Errors with
176/// - [`FixedOffsetScaleCodecError::MismatchedDecodeIntoArray`] if the `data`
177///   array's shape does not match the `out`put array's shape
178pub fn unscale_into<T: Float, D: Dimension>(
179    data: ArrayView<T, D>,
180    mut out: ArrayViewMut<T, D>,
181    offset: T,
182    scale: T,
183) -> Result<(), FixedOffsetScaleCodecError> {
184    if data.shape() != out.shape() {
185        return Err(FixedOffsetScaleCodecError::MismatchedDecodeIntoArray {
186            source: AnyArrayAssignError::ShapeMismatch {
187                src: data.shape().to_vec(),
188                dst: out.shape().to_vec(),
189            },
190        });
191    }
192
193    // iteration must occur in synchronised (standard) order
194    for (d, o) in data.iter().zip(out.iter_mut()) {
195        *o = (*d).mul_add(scale, offset);
196    }
197
198    Ok(())
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn identity() {
        let data = (0..1000).map(f64::from).collect::<Vec<_>>();
        let data = Array::from_vec(data);

        let encoded = scale(data.view(), 0.0, 1.0);

        for (r, e) in data.iter().zip(encoded.iter()) {
            assert_eq!((*r).to_bits(), (*e).to_bits());
        }

        let decoded = unscale(encoded, 0.0, 1.0);

        for (r, d) in data.iter().zip(decoded.iter()) {
            assert_eq!((*r).to_bits(), (*d).to_bits());
        }
    }

    #[test]
    fn roundtrip() {
        let data = (0..1000).map(f64::from).collect::<Vec<_>>();
        let data = Array::from_vec(data);

        let encoded = scale(data.view(), 512.0, 64.0);

        for (r, e) in data.iter().zip(encoded.iter()) {
            assert_eq!((((*r) - 512.0) / 64.0).to_bits(), (*e).to_bits());
        }

        let decoded = unscale(encoded, 512.0, 64.0);

        for (r, d) in data.iter().zip(decoded.iter()) {
            assert_eq!((*r).to_bits(), (*d).to_bits());
        }
    }

    #[test]
    fn unscale_into_matches_unscale() {
        let data = (0..20).map(f64::from).collect::<Vec<_>>();
        let data = Array::from_shape_vec((4, 5), data).expect("20 elements fit a 4x5 shape");

        let encoded = scale(data.view(), 3.0, 2.0);
        let expected = unscale(encoded.view(), 3.0, 2.0);

        let mut decoded = Array::<f64, _>::zeros((4, 5));
        unscale_into(encoded.view(), decoded.view_mut(), 3.0, 2.0).expect("shapes match");

        for (e, d) in expected.iter().zip(decoded.iter()) {
            assert_eq!((*e).to_bits(), (*d).to_bits());
        }

        for (r, d) in data.iter().zip(decoded.iter()) {
            assert_eq!((*r).to_bits(), (*d).to_bits());
        }
    }

    #[test]
    fn unscale_into_pairs_elements_by_logical_index() {
        let data = (0..20).map(f64::from).collect::<Vec<_>>();
        let encoded = Array::from_shape_vec((4, 5), data).expect("20 elements fit a 4x5 shape");

        // the transposed input view has a different memory layout than the
        // freshly allocated standard-layout output array
        let mut decoded = Array::<f64, _>::zeros((5, 4));
        unscale_into(encoded.t(), decoded.view_mut(), 3.0, 2.0).expect("shapes match");

        for ((i, j), d) in decoded.indexed_iter() {
            assert_eq!(encoded[[j, i]].mul_add(2.0, 3.0).to_bits(), (*d).to_bits());
        }
    }

    #[test]
    fn unscale_into_rejects_shape_mismatch() {
        let encoded = Array::<f64, _>::ones((4, 5));
        let mut decoded = Array::<f64, _>::zeros((5, 4));

        let result = unscale_into(encoded.view(), decoded.view_mut(), 3.0, 2.0);

        assert!(matches!(
            result,
            Err(FixedOffsetScaleCodecError::MismatchedDecodeIntoArray {
                source: AnyArrayAssignError::ShapeMismatch { .. }
            })
        ));

        // the output must be left untouched on error
        for d in &decoded {
            assert_eq!((*d).to_bits(), 0.0_f64.to_bits());
        }
    }
}