numcodecs_fixed_offset_scale/
lib.rs1use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension};
21use num_traits::Float;
22use numcodecs::{
23 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
24 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
25};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use thiserror::Error;
29
30#[derive(Clone, Serialize, Deserialize, JsonSchema)]
31#[serde(deny_unknown_fields)]
32pub struct FixedOffsetScaleCodec {
42 pub offset: f64,
44 pub scale: f64,
46 #[serde(default, rename = "_version")]
48 pub version: StaticCodecVersion<1, 0, 0>,
49}
50
51impl Codec for FixedOffsetScaleCodec {
52 type Error = FixedOffsetScaleCodecError;
53
54 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
55 match data {
56 #[expect(clippy::cast_possible_truncation)]
57 AnyCowArray::F32(data) => Ok(AnyArray::F32(scale(
58 data,
59 self.offset as f32,
60 self.scale as f32,
61 ))),
62 AnyCowArray::F64(data) => Ok(AnyArray::F64(scale(data, self.offset, self.scale))),
63 encoded => Err(FixedOffsetScaleCodecError::UnsupportedDtype(
64 encoded.dtype(),
65 )),
66 }
67 }
68
69 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
70 match encoded {
71 #[expect(clippy::cast_possible_truncation)]
72 AnyCowArray::F32(encoded) => Ok(AnyArray::F32(unscale(
73 encoded,
74 self.offset as f32,
75 self.scale as f32,
76 ))),
77 AnyCowArray::F64(encoded) => {
78 Ok(AnyArray::F64(unscale(encoded, self.offset, self.scale)))
79 }
80 encoded => Err(FixedOffsetScaleCodecError::UnsupportedDtype(
81 encoded.dtype(),
82 )),
83 }
84 }
85
86 fn decode_into(
87 &self,
88 encoded: AnyArrayView,
89 decoded: AnyArrayViewMut,
90 ) -> Result<(), Self::Error> {
91 match (encoded, decoded) {
92 #[expect(clippy::cast_possible_truncation)]
93 (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
94 unscale_into(encoded, decoded, self.offset as f32, self.scale as f32)
95 }
96 (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
97 unscale_into(encoded, decoded, self.offset, self.scale)
98 }
99 (encoded @ (AnyArrayView::F32(_) | AnyArrayView::F64(_)), decoded) => {
100 Err(FixedOffsetScaleCodecError::MismatchedDecodeIntoArray {
101 source: AnyArrayAssignError::DTypeMismatch {
102 src: encoded.dtype(),
103 dst: decoded.dtype(),
104 },
105 })
106 }
107 (encoded, _decoded) => Err(FixedOffsetScaleCodecError::UnsupportedDtype(
108 encoded.dtype(),
109 )),
110 }
111 }
112}
113
114impl StaticCodec for FixedOffsetScaleCodec {
115 const CODEC_ID: &'static str = "fixed-offset-scale.rs";
116
117 type Config<'de> = Self;
118
119 fn from_config(config: Self::Config<'_>) -> Self {
120 config
121 }
122
123 fn get_config(&self) -> StaticCodecConfig<Self> {
124 StaticCodecConfig::from(self)
125 }
126}
127
128#[derive(Debug, Error)]
129pub enum FixedOffsetScaleCodecError {
131 #[error("FixedOffsetScale does not support the dtype {0}")]
133 UnsupportedDtype(AnyArrayDType),
134 #[error("FixedOffsetScale cannot decode into the provided array")]
136 MismatchedDecodeIntoArray {
137 #[from]
139 source: AnyArrayAssignError,
140 },
141}
142
143pub fn scale<T: Float, S: Data<Elem = T>, D: Dimension>(
145 data: ArrayBase<S, D>,
146 offset: T,
147 scale: T,
148) -> Array<T, D> {
149 let inverse_scale = scale.recip();
150
151 let mut data = data.into_owned();
152 data.mapv_inplace(|x| (x - offset) * inverse_scale);
153
154 data
155}
156
157pub fn unscale<T: Float, S: Data<Elem = T>, D: Dimension>(
159 data: ArrayBase<S, D>,
160 offset: T,
161 scale: T,
162) -> Array<T, D> {
163 let mut data = data.into_owned();
164 data.mapv_inplace(|x| x.mul_add(scale, offset));
165
166 data
167}
168
169#[expect(clippy::needless_pass_by_value)]
170pub fn unscale_into<T: Float, D: Dimension>(
179 data: ArrayView<T, D>,
180 mut out: ArrayViewMut<T, D>,
181 offset: T,
182 scale: T,
183) -> Result<(), FixedOffsetScaleCodecError> {
184 if data.shape() != out.shape() {
185 return Err(FixedOffsetScaleCodecError::MismatchedDecodeIntoArray {
186 source: AnyArrayAssignError::ShapeMismatch {
187 src: data.shape().to_vec(),
188 dst: out.shape().to_vec(),
189 },
190 });
191 }
192
193 for (d, o) in data.iter().zip(out.iter_mut()) {
195 *o = (*d).mul_add(scale, offset);
196 }
197
198 Ok(())
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `expected` and `actual` agree bit-for-bit, pairwise.
    ///
    /// Like a plain `zip`, comparison stops at the end of the shorter input.
    fn assert_bitwise_eq<'a>(
        expected: impl IntoIterator<Item = f64>,
        actual: impl IntoIterator<Item = &'a f64>,
    ) {
        for (want, got) in expected.into_iter().zip(actual) {
            assert_eq!(want.to_bits(), got.to_bits());
        }
    }

    #[test]
    fn identity() {
        let data = Array::from_vec((0..1000).map(f64::from).collect());

        let encoded = scale(data.view(), 0.0, 1.0);
        assert_bitwise_eq(data.iter().copied(), encoded.iter());

        let decoded = unscale(encoded, 0.0, 1.0);
        assert_bitwise_eq(data.iter().copied(), decoded.iter());
    }

    #[test]
    fn roundtrip() {
        // A power-of-two scale keeps encoding and decoding exact.
        const OFFSET: f64 = 512.0;
        const SCALE: f64 = 64.0;

        let data = Array::from_vec((0..1000).map(f64::from).collect());

        let encoded = scale(data.view(), OFFSET, SCALE);
        assert_bitwise_eq(data.iter().map(|&x| (x - OFFSET) / SCALE), encoded.iter());

        let decoded = unscale(encoded, OFFSET, SCALE);
        assert_bitwise_eq(data.iter().copied(), decoded.iter());
    }

    #[test]
    fn unscale_into_matches_unscale() {
        let encoded = Array::from_vec((0..100).map(f64::from).collect());
        let mut decoded = Array::from_vec(vec![0.0; 100]);

        unscale_into(encoded.view(), decoded.view_mut(), 3.0, 0.5)
            .expect("shapes are equal");

        let expected = unscale(encoded.view(), 3.0, 0.5);
        assert_bitwise_eq(expected.iter().copied(), decoded.iter());
    }

    #[test]
    fn unscale_into_rejects_shape_mismatch() {
        let encoded = Array::from_vec(vec![1.0_f64; 4]);
        let mut decoded = Array::from_vec(vec![0.0_f64; 3]);

        let result = unscale_into(encoded.view(), decoded.view_mut(), 0.0, 1.0);
        assert!(matches!(
            result,
            Err(FixedOffsetScaleCodecError::MismatchedDecodeIntoArray {
                source: AnyArrayAssignError::ShapeMismatch { .. },
            })
        ));

        // The output must not be touched when the shapes disagree.
        assert!(decoded.iter().all(|&x| x == 0.0));
    }
}