numcodecs_reinterpret/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-reinterpret
10//! [crates.io]: https://crates.io/crates/numcodecs-reinterpret
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-reinterpret
13//! [docs.rs]: https://docs.rs/numcodecs-reinterpret/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_reinterpret
17//!
18//! Binary reinterpret codec implementation for the [`numcodecs`] API.
19
20use ndarray::{Array, ArrayBase, ArrayView, Data, DataMut, Dimension, ViewRepr};
21use numcodecs::{
22    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
23    ArrayDType, Codec, StaticCodec, StaticCodecConfig,
24};
25use schemars::JsonSchema;
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27use thiserror::Error;
28
// NOTE(review): this type derives `JsonSchema`, and schemars' derive
// presumably copies the doc comments below into the generated schema's
// descriptions -- confirm before rewording them (e.g. the "without the
// changing" typo), since any schema snapshot would change as well.
//
// Both fields are private: codecs are only built through the constructors on
// `ReinterpretCodec`, or through its `Deserialize` impl, which validates the
// dtype pair with `ReinterpretCodec::try_new`.
#[derive(Clone, JsonSchema)]
#[serde(deny_unknown_fields)]
/// Codec to reinterpret data between different compatible types.
///
/// Note that no conversion happens, only the meaning of the bits changes.
///
/// Reinterpreting to bytes, or to a same-sized unsigned integer type, or
/// without the changing the dtype are supported.
pub struct ReinterpretCodec {
    /// Dtype of the encoded data.
    encode_dtype: AnyArrayDType,
    /// Dtype of the decoded data
    decode_dtype: AnyArrayDType,
}
43
44impl ReinterpretCodec {
45    /// Try to create a [`ReinterpretCodec`] that reinterprets the input data
46    /// from `decode_dtype` to `encode_dtype` on encoding, and from
47    /// `encode_dtype` back to `decode_dtype` on decoding.
48    ///
49    /// # Errors
50    ///
51    /// Errors with [`ReinterpretCodecError::InvalidReinterpret`] if
52    /// `encode_dtype` and `decode_dtype` are incompatible.
53    pub fn try_new(
54        encode_dtype: AnyArrayDType,
55        decode_dtype: AnyArrayDType,
56    ) -> Result<Self, ReinterpretCodecError> {
57        #[expect(clippy::match_same_arms)]
58        match (decode_dtype, encode_dtype) {
59            // performing no conversion always works
60            (ty_a, ty_b) if ty_a == ty_b => (),
61            // converting to bytes always works
62            (_, AnyArrayDType::U8) => (),
63            // converting from signed / floating to same-size binary always works
64            (AnyArrayDType::I16, AnyArrayDType::U16)
65            | (AnyArrayDType::I32 | AnyArrayDType::F32, AnyArrayDType::U32)
66            | (AnyArrayDType::I64 | AnyArrayDType::F64, AnyArrayDType::U64) => (),
67            (decode_dtype, encode_dtype) => {
68                return Err(ReinterpretCodecError::InvalidReinterpret {
69                    decode_dtype,
70                    encode_dtype,
71                })
72            }
73        };
74
75        Ok(Self {
76            encode_dtype,
77            decode_dtype,
78        })
79    }
80
81    #[must_use]
82    /// Create a [`ReinterpretCodec`] that does not change the `dtype`.
83    pub const fn passthrough(dtype: AnyArrayDType) -> Self {
84        Self {
85            encode_dtype: dtype,
86            decode_dtype: dtype,
87        }
88    }
89
90    #[must_use]
91    /// Create a [`ReinterpretCodec`] that reinterprets `dtype` as
92    /// [bytes][`AnyArrayDType::U8`].
93    pub const fn to_bytes(dtype: AnyArrayDType) -> Self {
94        Self {
95            encode_dtype: AnyArrayDType::U8,
96            decode_dtype: dtype,
97        }
98    }
99
100    #[must_use]
101    /// Create a  [`ReinterpretCodec`] that reinterprets `dtype` as its
102    /// [binary][`AnyArrayDType::to_binary`] equivalent.
103    pub const fn to_binary(dtype: AnyArrayDType) -> Self {
104        Self {
105            encode_dtype: dtype.to_binary(),
106            decode_dtype: dtype,
107        }
108    }
109}
110
111impl Codec for ReinterpretCodec {
112    type Error = ReinterpretCodecError;
113
114    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
115        if data.dtype() != self.decode_dtype {
116            return Err(ReinterpretCodecError::MismatchedEncodeDType {
117                configured: self.decode_dtype,
118                provided: data.dtype(),
119            });
120        }
121
122        let encoded = match (data, self.encode_dtype) {
123            (data, dtype) if data.dtype() == dtype => data.into_owned(),
124            (data, AnyArrayDType::U8) => {
125                let mut shape = data.shape().to_vec();
126                if let Some(last) = shape.last_mut() {
127                    *last *= data.dtype().size();
128                }
129                #[expect(unsafe_code)]
130                // Safety: the shape is extended to match the expansion into bytes
131                let encoded =
132                    unsafe { Array::from_shape_vec_unchecked(shape, data.as_bytes().into_owned()) };
133                AnyArray::U8(encoded)
134            }
135            (AnyCowArray::I16(data), AnyArrayDType::U16) => {
136                AnyArray::U16(reinterpret_array(data, |x| {
137                    u16::from_ne_bytes(x.to_ne_bytes())
138                }))
139            }
140            (AnyCowArray::I32(data), AnyArrayDType::U32) => {
141                AnyArray::U32(reinterpret_array(data, |x| {
142                    u32::from_ne_bytes(x.to_ne_bytes())
143                }))
144            }
145            (AnyCowArray::F32(data), AnyArrayDType::U32) => {
146                AnyArray::U32(reinterpret_array(data, f32::to_bits))
147            }
148            (AnyCowArray::I64(data), AnyArrayDType::U64) => {
149                AnyArray::U64(reinterpret_array(data, |x| {
150                    u64::from_ne_bytes(x.to_ne_bytes())
151                }))
152            }
153            (AnyCowArray::F64(data), AnyArrayDType::U64) => {
154                AnyArray::U64(reinterpret_array(data, f64::to_bits))
155            }
156            (data, dtype) => {
157                return Err(ReinterpretCodecError::InvalidReinterpret {
158                    decode_dtype: data.dtype(),
159                    encode_dtype: dtype,
160                });
161            }
162        };
163
164        Ok(encoded)
165    }
166
167    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
168        if encoded.dtype() != self.encode_dtype {
169            return Err(ReinterpretCodecError::MismatchedDecodeDType {
170                configured: self.encode_dtype,
171                provided: encoded.dtype(),
172            });
173        }
174
175        let decoded = match (encoded, self.decode_dtype) {
176            (encoded, dtype) if encoded.dtype() == dtype => encoded.into_owned(),
177            (AnyCowArray::U8(encoded), dtype) => {
178                let mut shape = encoded.shape().to_vec();
179
180                if (encoded.len() % dtype.size()) != 0 {
181                    return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
182                }
183
184                if let Some(last) = shape.last_mut() {
185                    *last /= dtype.size();
186                }
187
188                let (decoded, ()) = AnyArray::with_zeros_bytes(dtype, &shape, |bytes| {
189                    bytes.copy_from_slice(&AnyCowArray::U8(encoded).as_bytes());
190                });
191
192                decoded
193            }
194            (AnyCowArray::U16(encoded), AnyArrayDType::I16) => {
195                AnyArray::I16(reinterpret_array(encoded, |x| {
196                    i16::from_ne_bytes(x.to_ne_bytes())
197                }))
198            }
199            (AnyCowArray::U32(encoded), AnyArrayDType::I32) => {
200                AnyArray::I32(reinterpret_array(encoded, |x| {
201                    i32::from_ne_bytes(x.to_ne_bytes())
202                }))
203            }
204            (AnyCowArray::U32(encoded), AnyArrayDType::F32) => {
205                AnyArray::F32(reinterpret_array(encoded, f32::from_bits))
206            }
207            (AnyCowArray::U64(encoded), AnyArrayDType::U64) => {
208                AnyArray::I64(reinterpret_array(encoded, |x| {
209                    i64::from_ne_bytes(x.to_ne_bytes())
210                }))
211            }
212            (AnyCowArray::U64(encoded), AnyArrayDType::F64) => {
213                AnyArray::F64(reinterpret_array(encoded, f64::from_bits))
214            }
215            (encoded, dtype) => {
216                return Err(ReinterpretCodecError::InvalidReinterpret {
217                    decode_dtype: dtype,
218                    encode_dtype: encoded.dtype(),
219                });
220            }
221        };
222
223        Ok(decoded)
224    }
225
226    fn decode_into(
227        &self,
228        encoded: AnyArrayView,
229        mut decoded: AnyArrayViewMut,
230    ) -> Result<(), Self::Error> {
231        if encoded.dtype() != self.encode_dtype {
232            return Err(ReinterpretCodecError::MismatchedDecodeDType {
233                configured: self.encode_dtype,
234                provided: encoded.dtype(),
235            });
236        }
237
238        match (encoded, self.decode_dtype) {
239            (encoded, dtype) if encoded.dtype() == dtype => Ok(decoded.assign(&encoded)?),
240            (AnyArrayView::U8(encoded), dtype) => {
241                if decoded.dtype() != dtype {
242                    return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
243                        source: AnyArrayAssignError::DTypeMismatch {
244                            src: dtype,
245                            dst: decoded.dtype(),
246                        },
247                    });
248                }
249
250                let mut shape = encoded.shape().to_vec();
251
252                if (encoded.len() % dtype.size()) != 0 {
253                    return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
254                }
255
256                if let Some(last) = shape.last_mut() {
257                    *last /= dtype.size();
258                }
259
260                if decoded.shape() != shape {
261                    return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
262                        source: AnyArrayAssignError::ShapeMismatch {
263                            src: shape,
264                            dst: decoded.shape().to_vec(),
265                        },
266                    });
267                }
268
269                let () = decoded.with_bytes_mut(|bytes| {
270                    bytes.copy_from_slice(&AnyArrayView::U8(encoded).as_bytes());
271                });
272
273                Ok(())
274            }
275            (AnyArrayView::U16(encoded), AnyArrayDType::I16) => {
276                reinterpret_array_into(encoded, |x| i16::from_ne_bytes(x.to_ne_bytes()), decoded)
277            }
278            (AnyArrayView::U32(encoded), AnyArrayDType::I32) => {
279                reinterpret_array_into(encoded, |x| i32::from_ne_bytes(x.to_ne_bytes()), decoded)
280            }
281            (AnyArrayView::U32(encoded), AnyArrayDType::F32) => {
282                reinterpret_array_into(encoded, f32::from_bits, decoded)
283            }
284            (AnyArrayView::U64(encoded), AnyArrayDType::U64) => {
285                reinterpret_array_into(encoded, |x| i64::from_ne_bytes(x.to_ne_bytes()), decoded)
286            }
287            (AnyArrayView::U64(encoded), AnyArrayDType::F64) => {
288                reinterpret_array_into(encoded, f64::from_bits, decoded)
289            }
290            (encoded, dtype) => Err(ReinterpretCodecError::InvalidReinterpret {
291                decode_dtype: dtype,
292                encode_dtype: encoded.dtype(),
293            }),
294        }?;
295
296        Ok(())
297    }
298}
299
impl StaticCodec for ReinterpretCodec {
    // identifier under which this codec is registered
    const CODEC_ID: &'static str = "reinterpret";

    // the codec is its own configuration type, so building it from a config
    // is the identity; validation of the dtype pair happens in the
    // `Deserialize` impl of `ReinterpretCodec`, which calls `try_new`
    type Config<'de> = Self;

    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
313
314impl Serialize for ReinterpretCodec {
315    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
316        ReinterpretCodecConfig {
317            encode_dtype: self.encode_dtype,
318            decode_dtype: self.decode_dtype,
319        }
320        .serialize(serializer)
321    }
322}
323
324impl<'de> Deserialize<'de> for ReinterpretCodec {
325    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
326        let config = ReinterpretCodecConfig::deserialize(deserializer)?;
327
328        Self::try_new(config.encode_dtype, config.decode_dtype).map_err(serde::de::Error::custom)
329    }
330}
331
/// Unvalidated serialization mirror of [`ReinterpretCodec`].
///
/// The `Serialize` and `Deserialize` impls of [`ReinterpretCodec`] go through
/// this type, so that deserialized dtype pairs can be checked with
/// [`ReinterpretCodec::try_new`] before a codec is constructed.
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename = "ReinterpretCodec")]
struct ReinterpretCodecConfig {
    /// Dtype of the encoded data.
    encode_dtype: AnyArrayDType,
    /// Dtype of the decoded data.
    decode_dtype: AnyArrayDType,
}
338
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`ReinterpretCodec`].
pub enum ReinterpretCodecError {
    /// [`ReinterpretCodec`] cannot bitcast the `decode_dtype` as
    /// `encode_dtype`
    #[error("Reinterpret cannot bitcast {decode_dtype} as {encode_dtype}")]
    InvalidReinterpret {
        /// Dtype of the configured `decode_dtype`
        decode_dtype: AnyArrayDType,
        /// Dtype of the configured `encode_dtype`
        encode_dtype: AnyArrayDType,
    },
    /// [`ReinterpretCodec`] cannot encode the provided dtype which differs
    /// from the configured dtype
    #[error("Reinterpret cannot encode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedEncodeDType {
        /// Dtype of the `configured` `decode_dtype`
        configured: AnyArrayDType,
        /// Dtype of the `provided` array from which the data is to be encoded
        provided: AnyArrayDType,
    },
    /// [`ReinterpretCodec`] cannot decode the provided dtype which differs
    /// from the configured dtype
    #[error("Reinterpret cannot decode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedDecodeDType {
        /// Dtype of the `configured` `encode_dtype`
        configured: AnyArrayDType,
        /// Dtype of the `provided` array from which the data is to be decoded
        provided: AnyArrayDType,
    },
    /// [`ReinterpretCodec`] cannot decode a byte array with `shape` into an
    /// array of `dtype`s
    #[error(
        "Reinterpret cannot decode a byte array of shape {shape:?} into an array of {dtype}-s"
    )]
    InvalidEncodedShape {
        /// Shape of the encoded array
        shape: Vec<usize>,
        /// Dtype of the array into which the encoded data is to be decoded
        dtype: AnyArrayDType,
    },
    /// [`ReinterpretCodec`] cannot decode into the provided array, e.g.
    /// because its dtype or shape does not match the decoded data
    #[error("Reinterpret cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
387
388/// Reinterpret the data elements of the `array` using the provided `reinterpret`
389/// closure. The shape of the data is preserved.
390#[inline]
391pub fn reinterpret_array<T: Copy, U, S: Data<Elem = T>, D: Dimension>(
392    array: ArrayBase<S, D>,
393    reinterpret: impl Fn(T) -> U,
394) -> Array<U, D> {
395    let array = array.into_owned();
396    let (shape, data) = (array.raw_dim(), array.into_raw_vec_and_offset().0);
397
398    let data = data.into_iter().map(reinterpret).collect();
399
400    #[expect(unsafe_code)]
401    // Safety: we have preserved the shape, which comes from a valid array
402    let array = unsafe { Array::from_shape_vec_unchecked(shape, data) };
403
404    array
405}
406
407#[expect(clippy::needless_pass_by_value)]
408/// Reinterpret the data elements of the `encoded` array using the provided
409/// `reinterpret` closure into the `decoded` array.
410///
411/// # Errors
412///
413/// Errors with
414/// - [`ReinterpretCodecError::MismatchedDecodeIntoArray`] if `decoded` does not
415///   contain an array with elements of type `U` or its shape does not match the
416///   `encoded` array's shape
417#[inline]
418pub fn reinterpret_array_into<'a, T: Copy, U: ArrayDType, D: Dimension>(
419    encoded: ArrayView<T, D>,
420    reinterpret: impl Fn(T) -> U,
421    mut decoded: AnyArrayViewMut<'a>,
422) -> Result<(), ReinterpretCodecError>
423where
424    U::RawData<ViewRepr<&'a mut ()>>: DataMut,
425{
426    let Some(decoded) = decoded.as_typed_mut::<U>() else {
427        return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
428            source: AnyArrayAssignError::DTypeMismatch {
429                src: U::DTYPE,
430                dst: decoded.dtype(),
431            },
432        });
433    };
434
435    if encoded.shape() != decoded.shape() {
436        return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
437            source: AnyArrayAssignError::ShapeMismatch {
438                src: encoded.shape().to_vec(),
439                dst: decoded.shape().to_vec(),
440            },
441        });
442    }
443
444    // iterate over the elements in standard order
445    for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
446        *d = reinterpret(*e);
447    }
448
449    Ok(())
450}