numcodecs_reinterpret/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-reinterpret
10//! [crates.io]: https://crates.io/crates/numcodecs-reinterpret
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-reinterpret
13//! [docs.rs]: https://docs.rs/numcodecs-reinterpret/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_reinterpret
17//!
18//! Binary reinterpret codec implementation for the [`numcodecs`] API.
19
20use ndarray::{Array, ArrayBase, ArrayView, Data, DataMut, Dimension, ViewRepr};
21use numcodecs::{
22    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
23    ArrayDType, Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
24};
25use schemars::JsonSchema;
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27use thiserror::Error;
28
#[derive(Clone, JsonSchema)]
#[serde(deny_unknown_fields)]
/// Codec to reinterpret data between different compatible types.
///
/// Note that no conversion happens, only the meaning of the bits changes.
///
/// Reinterpreting to bytes, or to a same-sized unsigned integer type, or
/// without changing the dtype are supported.
pub struct ReinterpretCodec {
    /// Dtype of the encoded data.
    encode_dtype: AnyArrayDType,
    /// Dtype of the decoded data.
    decode_dtype: AnyArrayDType,
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    #[serde(default)]
    _version: StaticCodecVersion<1, 0, 0>,
}
46
47impl ReinterpretCodec {
48    /// Try to create a [`ReinterpretCodec`] that reinterprets the input data
49    /// from `decode_dtype` to `encode_dtype` on encoding, and from
50    /// `encode_dtype` back to `decode_dtype` on decoding.
51    ///
52    /// # Errors
53    ///
54    /// Errors with [`ReinterpretCodecError::InvalidReinterpret`] if
55    /// `encode_dtype` and `decode_dtype` are incompatible.
56    pub fn try_new(
57        encode_dtype: AnyArrayDType,
58        decode_dtype: AnyArrayDType,
59    ) -> Result<Self, ReinterpretCodecError> {
60        #[expect(clippy::match_same_arms)]
61        match (decode_dtype, encode_dtype) {
62            // performing no conversion always works
63            (ty_a, ty_b) if ty_a == ty_b => (),
64            // converting to bytes always works
65            (_, AnyArrayDType::U8) => (),
66            // converting from signed / floating to same-size binary always works
67            (AnyArrayDType::I16, AnyArrayDType::U16)
68            | (AnyArrayDType::I32 | AnyArrayDType::F32, AnyArrayDType::U32)
69            | (AnyArrayDType::I64 | AnyArrayDType::F64, AnyArrayDType::U64) => (),
70            (decode_dtype, encode_dtype) => {
71                return Err(ReinterpretCodecError::InvalidReinterpret {
72                    decode_dtype,
73                    encode_dtype,
74                })
75            }
76        }
77
78        Ok(Self {
79            encode_dtype,
80            decode_dtype,
81            _version: StaticCodecVersion,
82        })
83    }
84
85    #[must_use]
86    /// Create a [`ReinterpretCodec`] that does not change the `dtype`.
87    pub const fn passthrough(dtype: AnyArrayDType) -> Self {
88        Self {
89            encode_dtype: dtype,
90            decode_dtype: dtype,
91            _version: StaticCodecVersion,
92        }
93    }
94
95    #[must_use]
96    /// Create a [`ReinterpretCodec`] that reinterprets `dtype` as
97    /// [bytes][`AnyArrayDType::U8`].
98    pub const fn to_bytes(dtype: AnyArrayDType) -> Self {
99        Self {
100            encode_dtype: AnyArrayDType::U8,
101            decode_dtype: dtype,
102            _version: StaticCodecVersion,
103        }
104    }
105
106    #[must_use]
107    /// Create a  [`ReinterpretCodec`] that reinterprets `dtype` as its
108    /// [binary][`AnyArrayDType::to_binary`] equivalent.
109    pub const fn to_binary(dtype: AnyArrayDType) -> Self {
110        Self {
111            encode_dtype: dtype.to_binary(),
112            decode_dtype: dtype,
113            _version: StaticCodecVersion,
114        }
115    }
116}
117
118impl Codec for ReinterpretCodec {
119    type Error = ReinterpretCodecError;
120
121    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
122        if data.dtype() != self.decode_dtype {
123            return Err(ReinterpretCodecError::MismatchedEncodeDType {
124                configured: self.decode_dtype,
125                provided: data.dtype(),
126            });
127        }
128
129        let encoded = match (data, self.encode_dtype) {
130            (data, dtype) if data.dtype() == dtype => data.into_owned(),
131            (data, AnyArrayDType::U8) => {
132                let mut shape = data.shape().to_vec();
133                if let Some(last) = shape.last_mut() {
134                    *last *= data.dtype().size();
135                }
136                #[expect(unsafe_code)]
137                // Safety: the shape is extended to match the expansion into bytes
138                let encoded =
139                    unsafe { Array::from_shape_vec_unchecked(shape, data.as_bytes().into_owned()) };
140                AnyArray::U8(encoded)
141            }
142            (AnyCowArray::I16(data), AnyArrayDType::U16) => {
143                AnyArray::U16(reinterpret_array(data, |x| {
144                    u16::from_ne_bytes(x.to_ne_bytes())
145                }))
146            }
147            (AnyCowArray::I32(data), AnyArrayDType::U32) => {
148                AnyArray::U32(reinterpret_array(data, |x| {
149                    u32::from_ne_bytes(x.to_ne_bytes())
150                }))
151            }
152            (AnyCowArray::F32(data), AnyArrayDType::U32) => {
153                AnyArray::U32(reinterpret_array(data, f32::to_bits))
154            }
155            (AnyCowArray::I64(data), AnyArrayDType::U64) => {
156                AnyArray::U64(reinterpret_array(data, |x| {
157                    u64::from_ne_bytes(x.to_ne_bytes())
158                }))
159            }
160            (AnyCowArray::F64(data), AnyArrayDType::U64) => {
161                AnyArray::U64(reinterpret_array(data, f64::to_bits))
162            }
163            (data, dtype) => {
164                return Err(ReinterpretCodecError::InvalidReinterpret {
165                    decode_dtype: data.dtype(),
166                    encode_dtype: dtype,
167                });
168            }
169        };
170
171        Ok(encoded)
172    }
173
174    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
175        if encoded.dtype() != self.encode_dtype {
176            return Err(ReinterpretCodecError::MismatchedDecodeDType {
177                configured: self.encode_dtype,
178                provided: encoded.dtype(),
179            });
180        }
181
182        let decoded = match (encoded, self.decode_dtype) {
183            (encoded, dtype) if encoded.dtype() == dtype => encoded.into_owned(),
184            (AnyCowArray::U8(encoded), dtype) => {
185                let mut shape = encoded.shape().to_vec();
186
187                if (encoded.len() % dtype.size()) != 0 {
188                    return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
189                }
190
191                if let Some(last) = shape.last_mut() {
192                    *last /= dtype.size();
193                }
194
195                let (decoded, ()) = AnyArray::with_zeros_bytes(dtype, &shape, |bytes| {
196                    bytes.copy_from_slice(&AnyCowArray::U8(encoded).as_bytes());
197                });
198
199                decoded
200            }
201            (AnyCowArray::U16(encoded), AnyArrayDType::I16) => {
202                AnyArray::I16(reinterpret_array(encoded, |x| {
203                    i16::from_ne_bytes(x.to_ne_bytes())
204                }))
205            }
206            (AnyCowArray::U32(encoded), AnyArrayDType::I32) => {
207                AnyArray::I32(reinterpret_array(encoded, |x| {
208                    i32::from_ne_bytes(x.to_ne_bytes())
209                }))
210            }
211            (AnyCowArray::U32(encoded), AnyArrayDType::F32) => {
212                AnyArray::F32(reinterpret_array(encoded, f32::from_bits))
213            }
214            (AnyCowArray::U64(encoded), AnyArrayDType::U64) => {
215                AnyArray::I64(reinterpret_array(encoded, |x| {
216                    i64::from_ne_bytes(x.to_ne_bytes())
217                }))
218            }
219            (AnyCowArray::U64(encoded), AnyArrayDType::F64) => {
220                AnyArray::F64(reinterpret_array(encoded, f64::from_bits))
221            }
222            (encoded, dtype) => {
223                return Err(ReinterpretCodecError::InvalidReinterpret {
224                    decode_dtype: dtype,
225                    encode_dtype: encoded.dtype(),
226                });
227            }
228        };
229
230        Ok(decoded)
231    }
232
233    fn decode_into(
234        &self,
235        encoded: AnyArrayView,
236        mut decoded: AnyArrayViewMut,
237    ) -> Result<(), Self::Error> {
238        if encoded.dtype() != self.encode_dtype {
239            return Err(ReinterpretCodecError::MismatchedDecodeDType {
240                configured: self.encode_dtype,
241                provided: encoded.dtype(),
242            });
243        }
244
245        match (encoded, self.decode_dtype) {
246            (encoded, dtype) if encoded.dtype() == dtype => Ok(decoded.assign(&encoded)?),
247            (AnyArrayView::U8(encoded), dtype) => {
248                if decoded.dtype() != dtype {
249                    return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
250                        source: AnyArrayAssignError::DTypeMismatch {
251                            src: dtype,
252                            dst: decoded.dtype(),
253                        },
254                    });
255                }
256
257                let mut shape = encoded.shape().to_vec();
258
259                if (encoded.len() % dtype.size()) != 0 {
260                    return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
261                }
262
263                if let Some(last) = shape.last_mut() {
264                    *last /= dtype.size();
265                }
266
267                if decoded.shape() != shape {
268                    return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
269                        source: AnyArrayAssignError::ShapeMismatch {
270                            src: shape,
271                            dst: decoded.shape().to_vec(),
272                        },
273                    });
274                }
275
276                let () = decoded.with_bytes_mut(|bytes| {
277                    bytes.copy_from_slice(&AnyArrayView::U8(encoded).as_bytes());
278                });
279
280                Ok(())
281            }
282            (AnyArrayView::U16(encoded), AnyArrayDType::I16) => {
283                reinterpret_array_into(encoded, |x| i16::from_ne_bytes(x.to_ne_bytes()), decoded)
284            }
285            (AnyArrayView::U32(encoded), AnyArrayDType::I32) => {
286                reinterpret_array_into(encoded, |x| i32::from_ne_bytes(x.to_ne_bytes()), decoded)
287            }
288            (AnyArrayView::U32(encoded), AnyArrayDType::F32) => {
289                reinterpret_array_into(encoded, f32::from_bits, decoded)
290            }
291            (AnyArrayView::U64(encoded), AnyArrayDType::U64) => {
292                reinterpret_array_into(encoded, |x| i64::from_ne_bytes(x.to_ne_bytes()), decoded)
293            }
294            (AnyArrayView::U64(encoded), AnyArrayDType::F64) => {
295                reinterpret_array_into(encoded, f64::from_bits, decoded)
296            }
297            (encoded, dtype) => Err(ReinterpretCodecError::InvalidReinterpret {
298                decode_dtype: dtype,
299                encode_dtype: encoded.dtype(),
300            }),
301        }?;
302
303        Ok(())
304    }
305}
306
307impl StaticCodec for ReinterpretCodec {
308    const CODEC_ID: &'static str = "reinterpret.rs";
309
310    type Config<'de> = Self;
311
312    fn from_config(config: Self::Config<'_>) -> Self {
313        config
314    }
315
316    fn get_config(&self) -> StaticCodecConfig<Self> {
317        StaticCodecConfig::from(self)
318    }
319}
320
321impl Serialize for ReinterpretCodec {
322    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
323        ReinterpretCodecConfig {
324            encode_dtype: self.encode_dtype,
325            decode_dtype: self.decode_dtype,
326        }
327        .serialize(serializer)
328    }
329}
330
331impl<'de> Deserialize<'de> for ReinterpretCodec {
332    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
333        let config = ReinterpretCodecConfig::deserialize(deserializer)?;
334
335        Self::try_new(config.encode_dtype, config.decode_dtype).map_err(serde::de::Error::custom)
336    }
337}
338
/// Plain serde representation of a [`ReinterpretCodec`], used by its manual
/// `Serialize` and `Deserialize` implementations.
///
/// The dtype pair is not validated here; validation happens when the codec
/// is constructed from this configuration during deserialization.
// NOTE(review): this struct has no `deny_unknown_fields` and no `_version`
// field, unlike the `ReinterpretCodec` struct -- confirm that ignoring
// unknown fields on deserialization is intended.
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename = "ReinterpretCodec")]
struct ReinterpretCodecConfig {
    /// Dtype of the encoded data.
    encode_dtype: AnyArrayDType,
    /// Dtype of the decoded data.
    decode_dtype: AnyArrayDType,
}
345
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`ReinterpretCodec`].
pub enum ReinterpretCodecError {
    /// [`ReinterpretCodec`] cannot bitcast the `decode_dtype` as
    /// `encode_dtype`
    #[error("Reinterpret cannot bitcast {decode_dtype} as {encode_dtype}")]
    InvalidReinterpret {
        /// Dtype of the configured `decode_dtype`
        decode_dtype: AnyArrayDType,
        /// Dtype of the configured `encode_dtype`
        encode_dtype: AnyArrayDType,
    },
    /// [`ReinterpretCodec`] cannot encode the provided dtype which differs
    /// from the configured dtype
    #[error("Reinterpret cannot encode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedEncodeDType {
        /// Dtype of the `configured` `decode_dtype`
        configured: AnyArrayDType,
        /// Dtype of the `provided` array from which the data is to be encoded
        provided: AnyArrayDType,
    },
    /// [`ReinterpretCodec`] cannot decode the provided dtype which differs
    /// from the configured dtype
    #[error("Reinterpret cannot decode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedDecodeDType {
        /// Dtype of the `configured` `encode_dtype`
        configured: AnyArrayDType,
        /// Dtype of the `provided` array from which the data is to be decoded
        provided: AnyArrayDType,
    },
    /// [`ReinterpretCodec`] cannot decode a byte array with `shape` into an
    /// array of `dtype`s
    #[error(
        "Reinterpret cannot decode a byte array of shape {shape:?} into an array of {dtype}-s"
    )]
    InvalidEncodedShape {
        /// Shape of the encoded array
        shape: Vec<usize>,
        /// Dtype of the array into which the encoded data is to be decoded
        dtype: AnyArrayDType,
    },
    /// [`ReinterpretCodec`] cannot decode into the provided array, e.g.
    /// because its dtype or shape does not match the decoded data
    #[error("Reinterpret cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
394
395/// Reinterpret the data elements of the `array` using the provided `reinterpret`
396/// closure. The shape of the data is preserved.
397#[inline]
398pub fn reinterpret_array<T: Copy, U, S: Data<Elem = T>, D: Dimension>(
399    array: ArrayBase<S, D>,
400    reinterpret: impl Fn(T) -> U,
401) -> Array<U, D> {
402    let array = array.into_owned();
403    let (shape, data) = (array.raw_dim(), array.into_raw_vec_and_offset().0);
404
405    let data = data.into_iter().map(reinterpret).collect();
406
407    #[expect(unsafe_code)]
408    // Safety: we have preserved the shape, which comes from a valid array
409    let array = unsafe { Array::from_shape_vec_unchecked(shape, data) };
410
411    array
412}
413
414#[expect(clippy::needless_pass_by_value)]
415/// Reinterpret the data elements of the `encoded` array using the provided
416/// `reinterpret` closure into the `decoded` array.
417///
418/// # Errors
419///
420/// Errors with
421/// - [`ReinterpretCodecError::MismatchedDecodeIntoArray`] if `decoded` does not
422///   contain an array with elements of type `U` or its shape does not match the
423///   `encoded` array's shape
424#[inline]
425pub fn reinterpret_array_into<'a, T: Copy, U: ArrayDType, D: Dimension>(
426    encoded: ArrayView<T, D>,
427    reinterpret: impl Fn(T) -> U,
428    mut decoded: AnyArrayViewMut<'a>,
429) -> Result<(), ReinterpretCodecError>
430where
431    U::RawData<ViewRepr<&'a mut ()>>: DataMut,
432{
433    let Some(decoded) = decoded.as_typed_mut::<U>() else {
434        return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
435            source: AnyArrayAssignError::DTypeMismatch {
436                src: U::DTYPE,
437                dst: decoded.dtype(),
438            },
439        });
440    };
441
442    if encoded.shape() != decoded.shape() {
443        return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
444            source: AnyArrayAssignError::ShapeMismatch {
445                src: encoded.shape().to_vec(),
446                dst: decoded.shape().to_vec(),
447            },
448        });
449    }
450
451    // iterate over the elements in standard order
452    for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
453        *d = reinterpret(*e);
454    }
455
456    Ok(())
457}