numcodecs_fixed_offset_scale/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-fixed-offset-scale
10//! [crates.io]: https://crates.io/crates/numcodecs-fixed-offset-scale
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-fixed-offset-scale
13//! [docs.rs]: https://docs.rs/numcodecs-fixed-offset-scale/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_fixed_offset_scale
17//!
18//! `$\frac{x - o}{s}$` codec implementation for the [`numcodecs`] API.
19
20use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension};
21use num_traits::Float;
22use numcodecs::{
23    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
24    Codec, StaticCodec, StaticCodecConfig,
25};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use thiserror::Error;
29
30#[derive(Clone, Serialize, Deserialize, JsonSchema)]
31#[serde(deny_unknown_fields)]
32/// Fixed offset-scale codec which calculates `$c = \frac{x - o}{s}$` on
33/// encoding and `$d = (c \cdot s) + o$` on decoding.
34///
35/// - Setting `$o = \text{mean}(x)$` and `$s = \text{std}(x)$` normalizes that
36///   data.
37/// - Setting `$o = \text{min}(x)$` and `$s = \text{max}(x) - \text{min}(x)$`
38///   standardizes the data.
39///
40/// The codec only supports floating point numbers.
41pub struct FixedOffsetScaleCodec {
42    /// The offset of the data.
43    pub offset: f64,
44    /// The scale of the data.
45    pub scale: f64,
46}
47
48impl Codec for FixedOffsetScaleCodec {
49    type Error = FixedOffsetScaleCodecError;
50
51    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
52        match data {
53            #[expect(clippy::cast_possible_truncation)]
54            AnyCowArray::F32(data) => Ok(AnyArray::F32(scale(
55                data,
56                self.offset as f32,
57                self.scale as f32,
58            ))),
59            AnyCowArray::F64(data) => Ok(AnyArray::F64(scale(data, self.offset, self.scale))),
60            encoded => Err(FixedOffsetScaleCodecError::UnsupportedDtype(
61                encoded.dtype(),
62            )),
63        }
64    }
65
66    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
67        match encoded {
68            #[expect(clippy::cast_possible_truncation)]
69            AnyCowArray::F32(encoded) => Ok(AnyArray::F32(unscale(
70                encoded,
71                self.offset as f32,
72                self.scale as f32,
73            ))),
74            AnyCowArray::F64(encoded) => {
75                Ok(AnyArray::F64(unscale(encoded, self.offset, self.scale)))
76            }
77            encoded => Err(FixedOffsetScaleCodecError::UnsupportedDtype(
78                encoded.dtype(),
79            )),
80        }
81    }
82
83    fn decode_into(
84        &self,
85        encoded: AnyArrayView,
86        decoded: AnyArrayViewMut,
87    ) -> Result<(), Self::Error> {
88        match (encoded, decoded) {
89            #[expect(clippy::cast_possible_truncation)]
90            (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
91                unscale_into(encoded, decoded, self.offset as f32, self.scale as f32)
92            }
93            (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
94                unscale_into(encoded, decoded, self.offset, self.scale)
95            }
96            (encoded @ (AnyArrayView::F32(_) | AnyArrayView::F64(_)), decoded) => {
97                Err(FixedOffsetScaleCodecError::MismatchedDecodeIntoArray {
98                    source: AnyArrayAssignError::DTypeMismatch {
99                        src: encoded.dtype(),
100                        dst: decoded.dtype(),
101                    },
102                })
103            }
104            (encoded, _decoded) => Err(FixedOffsetScaleCodecError::UnsupportedDtype(
105                encoded.dtype(),
106            )),
107        }
108    }
109}
110
111impl StaticCodec for FixedOffsetScaleCodec {
112    const CODEC_ID: &'static str = "fixed-offset-scale";
113
114    type Config<'de> = Self;
115
116    fn from_config(config: Self::Config<'_>) -> Self {
117        config
118    }
119
120    fn get_config(&self) -> StaticCodecConfig<Self> {
121        StaticCodecConfig::from(self)
122    }
123}
124
125#[derive(Debug, Error)]
126/// Errors that may occur when applying the [`FixedOffsetScaleCodec`].
127pub enum FixedOffsetScaleCodecError {
128    /// [`FixedOffsetScaleCodec`] does not support the dtype
129    #[error("FixedOffsetScale does not support the dtype {0}")]
130    UnsupportedDtype(AnyArrayDType),
131    /// [`FixedOffsetScaleCodec`] cannot decode into the provided array
132    #[error("FixedOffsetScale cannot decode into the provided array")]
133    MismatchedDecodeIntoArray {
134        /// The source of the error
135        #[from]
136        source: AnyArrayAssignError,
137    },
138}
139
140/// Compute `$\frac{x - o}{s}$` over the elements of the input `data` array.
141pub fn scale<T: Float, S: Data<Elem = T>, D: Dimension>(
142    data: ArrayBase<S, D>,
143    offset: T,
144    scale: T,
145) -> Array<T, D> {
146    let inverse_scale = scale.recip();
147
148    let mut data = data.into_owned();
149    data.mapv_inplace(|x| (x - offset) * inverse_scale);
150
151    data
152}
153
154/// Compute `$(x \cdot s) + o$` over the elements of the input `data` array.
155pub fn unscale<T: Float, S: Data<Elem = T>, D: Dimension>(
156    data: ArrayBase<S, D>,
157    offset: T,
158    scale: T,
159) -> Array<T, D> {
160    let mut data = data.into_owned();
161    data.mapv_inplace(|x| x.mul_add(scale, offset));
162
163    data
164}
165
166#[expect(clippy::needless_pass_by_value)]
167/// Compute `$(x \cdot s) + o$` over the elements of the input `data` array and
168/// write them into the `out`put array.
169///
170/// # Errors
171///
172/// Errors with
173/// - [`FixedOffsetScaleCodecError::MismatchedDecodeIntoArray`] if the `data`
174///   array's shape does not match the `out`put array's shape
175pub fn unscale_into<T: Float, D: Dimension>(
176    data: ArrayView<T, D>,
177    mut out: ArrayViewMut<T, D>,
178    offset: T,
179    scale: T,
180) -> Result<(), FixedOffsetScaleCodecError> {
181    if data.shape() != out.shape() {
182        return Err(FixedOffsetScaleCodecError::MismatchedDecodeIntoArray {
183            source: AnyArrayAssignError::ShapeMismatch {
184                src: data.shape().to_vec(),
185                dst: out.shape().to_vec(),
186            },
187        });
188    }
189
190    // iteration must occur in synchronised (standard) order
191    for (d, o) in data.iter().zip(out.iter_mut()) {
192        *o = (*d).mul_add(scale, offset);
193    }
194
195    Ok(())
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the paired-up elements of both sequences have identical
    /// bit patterns.
    fn assert_bitwise_eq<'a, 'b>(
        expected: impl IntoIterator<Item = &'a f64>,
        actual: impl IntoIterator<Item = &'b f64>,
    ) {
        for (e, a) in expected.into_iter().zip(actual) {
            assert_eq!(e.to_bits(), a.to_bits());
        }
    }

    /// With offset 0 and scale 1, encoding and decoding must both leave the
    /// data bitwise unchanged.
    #[test]
    fn identity() {
        let data = Array::from_vec((0..1000).map(f64::from).collect::<Vec<f64>>());

        let encoded = scale(data.view(), 0.0, 1.0);
        assert_bitwise_eq(&data, &encoded);

        let decoded = unscale(encoded, 0.0, 1.0);
        assert_bitwise_eq(&data, &decoded);
    }

    /// With offset 512 and scale 64, encoding must match the direct formula
    /// bitwise, and decoding must restore the original data bitwise.
    #[test]
    fn roundtrip() {
        let data = Array::from_vec((0..1000).map(f64::from).collect::<Vec<f64>>());

        let encoded = scale(data.view(), 512.0, 64.0);
        let expected = data.mapv(|x| (x - 512.0) / 64.0);
        assert_bitwise_eq(&expected, &encoded);

        let decoded = unscale(encoded, 512.0, 64.0);
        assert_bitwise_eq(&data, &decoded);
    }
}
237}