numcodecs_reinterpret/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.85.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-reinterpret
10//! [crates.io]: https://crates.io/crates/numcodecs-reinterpret
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-reinterpret
13//! [docs.rs]: https://docs.rs/numcodecs-reinterpret/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_reinterpret
17//!
18//! Binary reinterpret codec implementation for the [`numcodecs`] API.
19
20use ndarray::{Array, ArrayBase, ArrayView, Data, DataMut, Dimension, ViewRepr};
21use numcodecs::{
22    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
23    ArrayDType, Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
24};
25use schemars::JsonSchema;
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27use thiserror::Error;
28
// Configuration of the reinterpret codec: the dtype pair it converts between
// and a type-level marker for the encoding format version.
//
// Only `Clone` and `JsonSchema` are derived here; `Serialize` / `Deserialize`
// are implemented by hand elsewhere in this file, so the `#[serde(...)]`
// attributes below are only seen by the `JsonSchema` derive.
//
// NOTE(review): the `///` doc comments on this type and its fields are
// presumably surfaced as descriptions in the generated JSON schema, so
// rewording them (e.g. fixing the "without the changing the dtype" typo)
// would change the schema output — confirm against any schema snapshots
// before editing them.
#[derive(Clone, JsonSchema)]
#[serde(deny_unknown_fields)]
/// Codec to reinterpret data between different compatible types.
///
/// Note that no conversion happens, only the meaning of the bits changes.
///
/// Reinterpreting to bytes, or to a same-sized unsigned integer type, or
/// without the changing the dtype are supported.
pub struct ReinterpretCodec {
    /// Dtype of the encoded data.
    encode_dtype: AnyArrayDType,
    /// Dtype of the decoded data
    decode_dtype: AnyArrayDType,
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    #[serde(default)]
    _version: StaticCodecVersion<1, 0, 0>,
}
46
47impl ReinterpretCodec {
48    /// Try to create a [`ReinterpretCodec`] that reinterprets the input data
49    /// from `decode_dtype` to `encode_dtype` on encoding, and from
50    /// `encode_dtype` back to `decode_dtype` on decoding.
51    ///
52    /// # Errors
53    ///
54    /// Errors with [`ReinterpretCodecError::InvalidReinterpret`] if
55    /// `encode_dtype` and `decode_dtype` are incompatible.
56    pub fn try_new(
57        encode_dtype: AnyArrayDType,
58        decode_dtype: AnyArrayDType,
59    ) -> Result<Self, ReinterpretCodecError> {
60        #[expect(clippy::match_same_arms)]
61        match (decode_dtype, encode_dtype) {
62            // performing no conversion always works
63            (ty_a, ty_b) if ty_a == ty_b => (),
64            // converting to bytes always works
65            (_, AnyArrayDType::U8) => (),
66            // converting from signed / floating to same-size binary always works
67            (AnyArrayDType::I16, AnyArrayDType::U16)
68            | (AnyArrayDType::I32 | AnyArrayDType::F32, AnyArrayDType::U32)
69            | (AnyArrayDType::I64 | AnyArrayDType::F64, AnyArrayDType::U64) => (),
70            (decode_dtype, encode_dtype) => {
71                return Err(ReinterpretCodecError::InvalidReinterpret {
72                    decode_dtype,
73                    encode_dtype,
74                });
75            }
76        }
77
78        Ok(Self {
79            encode_dtype,
80            decode_dtype,
81            _version: StaticCodecVersion,
82        })
83    }
84
85    #[must_use]
86    /// Create a [`ReinterpretCodec`] that does not change the `dtype`.
87    pub const fn passthrough(dtype: AnyArrayDType) -> Self {
88        Self {
89            encode_dtype: dtype,
90            decode_dtype: dtype,
91            _version: StaticCodecVersion,
92        }
93    }
94
95    #[must_use]
96    /// Create a [`ReinterpretCodec`] that reinterprets `dtype` as
97    /// [bytes][`AnyArrayDType::U8`].
98    pub const fn to_bytes(dtype: AnyArrayDType) -> Self {
99        Self {
100            encode_dtype: AnyArrayDType::U8,
101            decode_dtype: dtype,
102            _version: StaticCodecVersion,
103        }
104    }
105
106    #[must_use]
107    /// Create a  [`ReinterpretCodec`] that reinterprets `dtype` as its
108    /// [binary][`AnyArrayDType::to_binary`] equivalent.
109    pub const fn to_binary(dtype: AnyArrayDType) -> Self {
110        Self {
111            encode_dtype: dtype.to_binary(),
112            decode_dtype: dtype,
113            _version: StaticCodecVersion,
114        }
115    }
116}
117
118impl Codec for ReinterpretCodec {
119    type Error = ReinterpretCodecError;
120
121    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
122        if data.dtype() != self.decode_dtype {
123            return Err(ReinterpretCodecError::MismatchedEncodeDType {
124                configured: self.decode_dtype,
125                provided: data.dtype(),
126            });
127        }
128
129        let encoded = match (data, self.encode_dtype) {
130            (data, dtype) if data.dtype() == dtype => data.into_owned(),
131            (data, AnyArrayDType::U8) => {
132                let mut shape = data.shape().to_vec();
133                if let Some(last) = shape.last_mut() {
134                    *last *= data.dtype().size();
135                }
136                #[expect(unsafe_code)]
137                // Safety: the shape is extended to match the expansion into bytes
138                let encoded =
139                    unsafe { Array::from_shape_vec_unchecked(shape, data.as_bytes().into_owned()) };
140                AnyArray::U8(encoded)
141            }
142            (AnyCowArray::I16(data), AnyArrayDType::U16) => {
143                AnyArray::U16(reinterpret_array(data, |x| {
144                    u16::from_ne_bytes(x.to_ne_bytes())
145                }))
146            }
147            (AnyCowArray::I32(data), AnyArrayDType::U32) => {
148                AnyArray::U32(reinterpret_array(data, |x| {
149                    u32::from_ne_bytes(x.to_ne_bytes())
150                }))
151            }
152            (AnyCowArray::F32(data), AnyArrayDType::U32) => {
153                AnyArray::U32(reinterpret_array(data, f32::to_bits))
154            }
155            (AnyCowArray::I64(data), AnyArrayDType::U64) => {
156                AnyArray::U64(reinterpret_array(data, |x| {
157                    u64::from_ne_bytes(x.to_ne_bytes())
158                }))
159            }
160            (AnyCowArray::F64(data), AnyArrayDType::U64) => {
161                AnyArray::U64(reinterpret_array(data, f64::to_bits))
162            }
163            (data, dtype) => {
164                return Err(ReinterpretCodecError::InvalidReinterpret {
165                    decode_dtype: data.dtype(),
166                    encode_dtype: dtype,
167                });
168            }
169        };
170
171        Ok(encoded)
172    }
173
174    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
175        if encoded.dtype() != self.encode_dtype {
176            return Err(ReinterpretCodecError::MismatchedDecodeDType {
177                configured: self.encode_dtype,
178                provided: encoded.dtype(),
179            });
180        }
181
182        let decoded = match (encoded, self.decode_dtype) {
183            (encoded, dtype) if encoded.dtype() == dtype => encoded.into_owned(),
184            (AnyCowArray::U8(encoded), dtype) => {
185                let mut shape = encoded.shape().to_vec();
186
187                if (encoded.len() % dtype.size()) != 0 {
188                    return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
189                }
190
191                if let Some(last) = shape.last_mut() {
192                    *last /= dtype.size();
193                }
194
195                let (decoded, ()) = AnyArray::with_zeros_bytes(dtype, &shape, |bytes| {
196                    bytes.copy_from_slice(&AnyCowArray::U8(encoded).as_bytes());
197                });
198
199                decoded
200            }
201            (AnyCowArray::U16(encoded), AnyArrayDType::I16) => {
202                AnyArray::I16(reinterpret_array(encoded, |x| {
203                    i16::from_ne_bytes(x.to_ne_bytes())
204                }))
205            }
206            (AnyCowArray::U32(encoded), AnyArrayDType::I32) => {
207                AnyArray::I32(reinterpret_array(encoded, |x| {
208                    i32::from_ne_bytes(x.to_ne_bytes())
209                }))
210            }
211            (AnyCowArray::U32(encoded), AnyArrayDType::F32) => {
212                AnyArray::F32(reinterpret_array(encoded, f32::from_bits))
213            }
214            (AnyCowArray::U64(encoded), AnyArrayDType::U64) => {
215                AnyArray::I64(reinterpret_array(encoded, |x| {
216                    i64::from_ne_bytes(x.to_ne_bytes())
217                }))
218            }
219            (AnyCowArray::U64(encoded), AnyArrayDType::F64) => {
220                AnyArray::F64(reinterpret_array(encoded, f64::from_bits))
221            }
222            (encoded, dtype) => {
223                return Err(ReinterpretCodecError::InvalidReinterpret {
224                    decode_dtype: dtype,
225                    encode_dtype: encoded.dtype(),
226                });
227            }
228        };
229
230        Ok(decoded)
231    }
232
233    fn decode_into(
234        &self,
235        encoded: AnyArrayView,
236        mut decoded: AnyArrayViewMut,
237    ) -> Result<(), Self::Error> {
238        if encoded.dtype() != self.encode_dtype {
239            return Err(ReinterpretCodecError::MismatchedDecodeDType {
240                configured: self.encode_dtype,
241                provided: encoded.dtype(),
242            });
243        }
244
245        match (encoded, self.decode_dtype) {
246            (encoded, dtype) if encoded.dtype() == dtype => Ok(decoded.assign(&encoded)?),
247            (AnyArrayView::U8(encoded), dtype) => {
248                if decoded.dtype() != dtype {
249                    return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
250                        source: AnyArrayAssignError::DTypeMismatch {
251                            src: dtype,
252                            dst: decoded.dtype(),
253                        },
254                    });
255                }
256
257                let mut shape = encoded.shape().to_vec();
258
259                if (encoded.len() % dtype.size()) != 0 {
260                    return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
261                }
262
263                if let Some(last) = shape.last_mut() {
264                    *last /= dtype.size();
265                }
266
267                if decoded.shape() != shape {
268                    return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
269                        source: AnyArrayAssignError::ShapeMismatch {
270                            src: shape,
271                            dst: decoded.shape().to_vec(),
272                        },
273                    });
274                }
275
276                let () = decoded.with_bytes_mut(|bytes| {
277                    bytes.copy_from_slice(&AnyArrayView::U8(encoded).as_bytes());
278                });
279
280                Ok(())
281            }
282            (AnyArrayView::U16(encoded), AnyArrayDType::I16) => {
283                reinterpret_array_into(encoded, |x| i16::from_ne_bytes(x.to_ne_bytes()), decoded)
284            }
285            (AnyArrayView::U32(encoded), AnyArrayDType::I32) => {
286                reinterpret_array_into(encoded, |x| i32::from_ne_bytes(x.to_ne_bytes()), decoded)
287            }
288            (AnyArrayView::U32(encoded), AnyArrayDType::F32) => {
289                reinterpret_array_into(encoded, f32::from_bits, decoded)
290            }
291            (AnyArrayView::U64(encoded), AnyArrayDType::U64) => {
292                reinterpret_array_into(encoded, |x| i64::from_ne_bytes(x.to_ne_bytes()), decoded)
293            }
294            (AnyArrayView::U64(encoded), AnyArrayDType::F64) => {
295                reinterpret_array_into(encoded, f64::from_bits, decoded)
296            }
297            (encoded, dtype) => Err(ReinterpretCodecError::InvalidReinterpret {
298                decode_dtype: dtype,
299                encode_dtype: encoded.dtype(),
300            }),
301        }?;
302
303        Ok(())
304    }
305}
306
307impl StaticCodec for ReinterpretCodec {
308    const CODEC_ID: &'static str = "reinterpret.rs";
309
310    type Config<'de> = Self;
311
312    fn from_config(config: Self::Config<'_>) -> Self {
313        config
314    }
315
316    fn get_config(&self) -> StaticCodecConfig<Self> {
317        StaticCodecConfig::from(self)
318    }
319}
320
321impl Serialize for ReinterpretCodec {
322    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
323        ReinterpretCodecConfig {
324            encode_dtype: self.encode_dtype,
325            decode_dtype: self.decode_dtype,
326        }
327        .serialize(serializer)
328    }
329}
330
331impl<'de> Deserialize<'de> for ReinterpretCodec {
332    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
333        let config = ReinterpretCodecConfig::deserialize(deserializer)?;
334
335        Self::try_new(config.encode_dtype, config.decode_dtype).map_err(serde::de::Error::custom)
336    }
337}
338
339#[derive(Clone, Serialize, Deserialize)]
340#[serde(rename = "ReinterpretCodec")]
341struct ReinterpretCodecConfig {
342    encode_dtype: AnyArrayDType,
343    decode_dtype: AnyArrayDType,
344}
345
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`ReinterpretCodec`].
pub enum ReinterpretCodecError {
    /// [`ReinterpretCodec`] cannot bitcast the `decode_dtype` as
    /// `encode_dtype`.
    #[error("Reinterpret cannot bitcast {decode_dtype} as {encode_dtype}")]
    InvalidReinterpret {
        /// Dtype of the configured `decode_dtype`
        decode_dtype: AnyArrayDType,
        /// Dtype of the configured `encode_dtype`
        encode_dtype: AnyArrayDType,
    },
    /// [`ReinterpretCodec`] cannot encode the provided dtype which differs
    /// from the configured dtype.
    #[error(
        "Reinterpret cannot encode the provided dtype {provided} which differs from the configured dtype {configured}"
    )]
    MismatchedEncodeDType {
        /// Dtype of the `configured` `decode_dtype`
        configured: AnyArrayDType,
        /// Dtype of the `provided` array from which the data is to be encoded
        provided: AnyArrayDType,
    },
    /// [`ReinterpretCodec`] cannot decode the provided dtype which differs
    /// from the configured dtype.
    #[error(
        "Reinterpret cannot decode the provided dtype {provided} which differs from the configured dtype {configured}"
    )]
    MismatchedDecodeDType {
        /// Dtype of the `configured` `encode_dtype`
        configured: AnyArrayDType,
        /// Dtype of the `provided` array from which the data is to be decoded
        provided: AnyArrayDType,
    },
    /// [`ReinterpretCodec`] cannot decode a byte array with `shape` into an
    /// array of `dtype`s.
    #[error("Reinterpret cannot decode a byte array of shape {shape:?} into an array of {dtype}-s")]
    InvalidEncodedShape {
        /// Shape of the encoded array
        shape: Vec<usize>,
        /// Dtype of the array into which the encoded data is to be decoded
        dtype: AnyArrayDType,
    },
    /// [`ReinterpretCodec`] cannot decode into the provided array, e.g.
    /// because its dtype or shape does not match the decoded data.
    #[error("Reinterpret cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
396
397/// Reinterpret the data elements of the `array` using the provided `reinterpret`
398/// closure. The shape of the data is preserved.
399#[inline]
400pub fn reinterpret_array<T: Copy, U, S: Data<Elem = T>, D: Dimension>(
401    array: ArrayBase<S, D>,
402    reinterpret: impl Fn(T) -> U,
403) -> Array<U, D> {
404    let array = array.into_owned();
405    let (shape, data) = (array.raw_dim(), array.into_raw_vec_and_offset().0);
406
407    let data = data.into_iter().map(reinterpret).collect();
408
409    #[expect(unsafe_code)]
410    // Safety: we have preserved the shape, which comes from a valid array
411    let array = unsafe { Array::from_shape_vec_unchecked(shape, data) };
412
413    array
414}
415
416#[expect(clippy::needless_pass_by_value)]
417/// Reinterpret the data elements of the `encoded` array using the provided
418/// `reinterpret` closure into the `decoded` array.
419///
420/// # Errors
421///
422/// Errors with
423/// - [`ReinterpretCodecError::MismatchedDecodeIntoArray`] if `decoded` does not
424///   contain an array with elements of type `U` or its shape does not match the
425///   `encoded` array's shape
426#[inline]
427pub fn reinterpret_array_into<'a, T: Copy, U: ArrayDType, D: Dimension>(
428    encoded: ArrayView<T, D>,
429    reinterpret: impl Fn(T) -> U,
430    mut decoded: AnyArrayViewMut<'a>,
431) -> Result<(), ReinterpretCodecError>
432where
433    U::RawData<ViewRepr<&'a mut ()>>: DataMut,
434{
435    let Some(decoded) = decoded.as_typed_mut::<U>() else {
436        return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
437            source: AnyArrayAssignError::DTypeMismatch {
438                src: U::DTYPE,
439                dst: decoded.dtype(),
440            },
441        });
442    };
443
444    if encoded.shape() != decoded.shape() {
445        return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
446            source: AnyArrayAssignError::ShapeMismatch {
447                src: encoded.shape().to_vec(),
448                dst: decoded.shape().to_vec(),
449            },
450        });
451    }
452
453    // iterate over the elements in standard order
454    for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
455        *d = reinterpret(*e);
456    }
457
458    Ok(())
459}