numcodecs_fixed_offset_scale/
lib.rs1use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension};
21use num_traits::Float;
22use numcodecs::{
23 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
24 Codec, StaticCodec, StaticCodecConfig,
25};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use thiserror::Error;
29
30#[derive(Clone, Serialize, Deserialize, JsonSchema)]
31#[serde(deny_unknown_fields)]
32pub struct FixedOffsetScaleCodec {
42 pub offset: f64,
44 pub scale: f64,
46}
47
48impl Codec for FixedOffsetScaleCodec {
49 type Error = FixedOffsetScaleCodecError;
50
51 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
52 match data {
53 #[expect(clippy::cast_possible_truncation)]
54 AnyCowArray::F32(data) => Ok(AnyArray::F32(scale(
55 data,
56 self.offset as f32,
57 self.scale as f32,
58 ))),
59 AnyCowArray::F64(data) => Ok(AnyArray::F64(scale(data, self.offset, self.scale))),
60 encoded => Err(FixedOffsetScaleCodecError::UnsupportedDtype(
61 encoded.dtype(),
62 )),
63 }
64 }
65
66 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
67 match encoded {
68 #[expect(clippy::cast_possible_truncation)]
69 AnyCowArray::F32(encoded) => Ok(AnyArray::F32(unscale(
70 encoded,
71 self.offset as f32,
72 self.scale as f32,
73 ))),
74 AnyCowArray::F64(encoded) => {
75 Ok(AnyArray::F64(unscale(encoded, self.offset, self.scale)))
76 }
77 encoded => Err(FixedOffsetScaleCodecError::UnsupportedDtype(
78 encoded.dtype(),
79 )),
80 }
81 }
82
83 fn decode_into(
84 &self,
85 encoded: AnyArrayView,
86 decoded: AnyArrayViewMut,
87 ) -> Result<(), Self::Error> {
88 match (encoded, decoded) {
89 #[expect(clippy::cast_possible_truncation)]
90 (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
91 unscale_into(encoded, decoded, self.offset as f32, self.scale as f32)
92 }
93 (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
94 unscale_into(encoded, decoded, self.offset, self.scale)
95 }
96 (encoded @ (AnyArrayView::F32(_) | AnyArrayView::F64(_)), decoded) => {
97 Err(FixedOffsetScaleCodecError::MismatchedDecodeIntoArray {
98 source: AnyArrayAssignError::DTypeMismatch {
99 src: encoded.dtype(),
100 dst: decoded.dtype(),
101 },
102 })
103 }
104 (encoded, _decoded) => Err(FixedOffsetScaleCodecError::UnsupportedDtype(
105 encoded.dtype(),
106 )),
107 }
108 }
109}
110
111impl StaticCodec for FixedOffsetScaleCodec {
112 const CODEC_ID: &'static str = "fixed-offset-scale";
113
114 type Config<'de> = Self;
115
116 fn from_config(config: Self::Config<'_>) -> Self {
117 config
118 }
119
120 fn get_config(&self) -> StaticCodecConfig<Self> {
121 StaticCodecConfig::from(self)
122 }
123}
124
125#[derive(Debug, Error)]
126pub enum FixedOffsetScaleCodecError {
128 #[error("FixedOffsetScale does not support the dtype {0}")]
130 UnsupportedDtype(AnyArrayDType),
131 #[error("FixedOffsetScale cannot decode into the provided array")]
133 MismatchedDecodeIntoArray {
134 #[from]
136 source: AnyArrayAssignError,
137 },
138}
139
140pub fn scale<T: Float, S: Data<Elem = T>, D: Dimension>(
142 data: ArrayBase<S, D>,
143 offset: T,
144 scale: T,
145) -> Array<T, D> {
146 let inverse_scale = scale.recip();
147
148 let mut data = data.into_owned();
149 data.mapv_inplace(|x| (x - offset) * inverse_scale);
150
151 data
152}
153
154pub fn unscale<T: Float, S: Data<Elem = T>, D: Dimension>(
156 data: ArrayBase<S, D>,
157 offset: T,
158 scale: T,
159) -> Array<T, D> {
160 let mut data = data.into_owned();
161 data.mapv_inplace(|x| x.mul_add(scale, offset));
162
163 data
164}
165
166#[expect(clippy::needless_pass_by_value)]
167pub fn unscale_into<T: Float, D: Dimension>(
176 data: ArrayView<T, D>,
177 mut out: ArrayViewMut<T, D>,
178 offset: T,
179 scale: T,
180) -> Result<(), FixedOffsetScaleCodecError> {
181 if data.shape() != out.shape() {
182 return Err(FixedOffsetScaleCodecError::MismatchedDecodeIntoArray {
183 source: AnyArrayAssignError::ShapeMismatch {
184 src: data.shape().to_vec(),
185 dst: out.shape().to_vec(),
186 },
187 });
188 }
189
190 for (d, o) in data.iter().zip(out.iter_mut()) {
192 *o = (*d).mul_add(scale, offset);
193 }
194
195 Ok(())
196}
197
/// Unit tests for the element-wise `scale` / `unscale` / `unscale_into`
/// helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn identity() {
        let data = (0..1000).map(f64::from).collect::<Vec<_>>();
        let data = Array::from_vec(data);

        let encoded = scale(data.view(), 0.0, 1.0);

        for (r, e) in data.iter().zip(encoded.iter()) {
            assert_eq!((*r).to_bits(), (*e).to_bits());
        }

        let decoded = unscale(encoded, 0.0, 1.0);

        for (r, d) in data.iter().zip(decoded.iter()) {
            assert_eq!((*r).to_bits(), (*d).to_bits());
        }
    }

    #[test]
    fn roundtrip() {
        let data = (0..1000).map(f64::from).collect::<Vec<_>>();
        let data = Array::from_vec(data);

        let encoded = scale(data.view(), 512.0, 64.0);

        for (r, e) in data.iter().zip(encoded.iter()) {
            assert_eq!((((*r) - 512.0) / 64.0).to_bits(), (*e).to_bits());
        }

        let decoded = unscale(encoded, 512.0, 64.0);

        for (r, d) in data.iter().zip(decoded.iter()) {
            assert_eq!((*r).to_bits(), (*d).to_bits());
        }
    }

    #[test]
    fn roundtrip_f32() {
        // With a power-of-two scale, both the quotient and the fused
        // multiply-add are exact, so the roundtrip must be bit-exact.
        let data = Array::from_vec((0..1000_u16).map(f32::from).collect::<Vec<_>>());

        let encoded = scale(data.view(), 512.0_f32, 64.0_f32);

        for (r, e) in data.iter().zip(encoded.iter()) {
            assert_eq!(((*r - 512.0) / 64.0).to_bits(), e.to_bits());
        }

        let decoded = unscale(encoded, 512.0_f32, 64.0_f32);

        for (r, d) in data.iter().zip(decoded.iter()) {
            assert_eq!(r.to_bits(), d.to_bits());
        }
    }

    #[test]
    fn unscale_into_matches_unscale() {
        let encoded = Array::from_shape_vec((4, 5), (0..20).map(f64::from).collect())
            .expect("20 elements fit a 4x5 shape");
        let expected = unscale(encoded.view(), 3.25, 0.5);

        // Decode into a transposed (column-major) view to check that elements
        // are paired by logical index rather than by memory order.
        let mut decoded = Array::<f64, _>::zeros((5, 4));
        unscale_into(encoded.view(), decoded.view_mut().reversed_axes(), 3.25, 0.5)
            .expect("shapes match");

        for (e, d) in expected.iter().zip(decoded.t().iter()) {
            assert_eq!(e.to_bits(), d.to_bits());
        }
    }

    #[test]
    fn unscale_into_rejects_shape_mismatch() {
        let encoded = Array::from_vec(vec![1.0_f64, 2.0, 3.0]);
        let mut decoded = Array::<f64, _>::zeros(4);

        let result = unscale_into(encoded.view(), decoded.view_mut(), 0.0, 1.0);

        match result {
            Err(FixedOffsetScaleCodecError::MismatchedDecodeIntoArray {
                source: AnyArrayAssignError::ShapeMismatch { src, dst },
            }) => {
                assert_eq!(src, [3]);
                assert_eq!(dst, [4]);
            }
            other => panic!("expected a shape mismatch error, got {other:?}"),
        }

        // A rejected decode must not write into the output array.
        assert!(decoded.iter().all(|x| x.to_bits() == 0));
    }
}