1use ndarray::{Array, ArrayBase, ArrayView, Data, DataMut, Dimension, ViewRepr};
21use numcodecs::{
22 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
23 ArrayDType, Codec, StaticCodec, StaticCodecConfig,
24};
25use schemars::JsonSchema;
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27use thiserror::Error;
28
#[derive(Clone, JsonSchema)]
#[serde(deny_unknown_fields)]
/// Codec that reinterprets array data between compatible dtypes.
///
/// The supported dtype pairs are validated by [`ReinterpretCodec::try_new`]:
/// an unchanged dtype, any dtype encoded as `U8` bytes, or a signed integer /
/// floating-point dtype encoded as the unsigned integer dtype paired with it
/// there.
pub struct ReinterpretCodec {
    /// Dtype of the encoded data.
    encode_dtype: AnyArrayDType,
    /// Dtype of the decoded data.
    decode_dtype: AnyArrayDType,
}
43
44impl ReinterpretCodec {
45 pub fn try_new(
54 encode_dtype: AnyArrayDType,
55 decode_dtype: AnyArrayDType,
56 ) -> Result<Self, ReinterpretCodecError> {
57 #[expect(clippy::match_same_arms)]
58 match (decode_dtype, encode_dtype) {
59 (ty_a, ty_b) if ty_a == ty_b => (),
61 (_, AnyArrayDType::U8) => (),
63 (AnyArrayDType::I16, AnyArrayDType::U16)
65 | (AnyArrayDType::I32 | AnyArrayDType::F32, AnyArrayDType::U32)
66 | (AnyArrayDType::I64 | AnyArrayDType::F64, AnyArrayDType::U64) => (),
67 (decode_dtype, encode_dtype) => {
68 return Err(ReinterpretCodecError::InvalidReinterpret {
69 decode_dtype,
70 encode_dtype,
71 })
72 }
73 };
74
75 Ok(Self {
76 encode_dtype,
77 decode_dtype,
78 })
79 }
80
81 #[must_use]
82 pub const fn passthrough(dtype: AnyArrayDType) -> Self {
84 Self {
85 encode_dtype: dtype,
86 decode_dtype: dtype,
87 }
88 }
89
90 #[must_use]
91 pub const fn to_bytes(dtype: AnyArrayDType) -> Self {
94 Self {
95 encode_dtype: AnyArrayDType::U8,
96 decode_dtype: dtype,
97 }
98 }
99
100 #[must_use]
101 pub const fn to_binary(dtype: AnyArrayDType) -> Self {
104 Self {
105 encode_dtype: dtype.to_binary(),
106 decode_dtype: dtype,
107 }
108 }
109}
110
111impl Codec for ReinterpretCodec {
112 type Error = ReinterpretCodecError;
113
114 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
115 if data.dtype() != self.decode_dtype {
116 return Err(ReinterpretCodecError::MismatchedEncodeDType {
117 configured: self.decode_dtype,
118 provided: data.dtype(),
119 });
120 }
121
122 let encoded = match (data, self.encode_dtype) {
123 (data, dtype) if data.dtype() == dtype => data.into_owned(),
124 (data, AnyArrayDType::U8) => {
125 let mut shape = data.shape().to_vec();
126 if let Some(last) = shape.last_mut() {
127 *last *= data.dtype().size();
128 }
129 #[expect(unsafe_code)]
130 let encoded =
132 unsafe { Array::from_shape_vec_unchecked(shape, data.as_bytes().into_owned()) };
133 AnyArray::U8(encoded)
134 }
135 (AnyCowArray::I16(data), AnyArrayDType::U16) => {
136 AnyArray::U16(reinterpret_array(data, |x| {
137 u16::from_ne_bytes(x.to_ne_bytes())
138 }))
139 }
140 (AnyCowArray::I32(data), AnyArrayDType::U32) => {
141 AnyArray::U32(reinterpret_array(data, |x| {
142 u32::from_ne_bytes(x.to_ne_bytes())
143 }))
144 }
145 (AnyCowArray::F32(data), AnyArrayDType::U32) => {
146 AnyArray::U32(reinterpret_array(data, f32::to_bits))
147 }
148 (AnyCowArray::I64(data), AnyArrayDType::U64) => {
149 AnyArray::U64(reinterpret_array(data, |x| {
150 u64::from_ne_bytes(x.to_ne_bytes())
151 }))
152 }
153 (AnyCowArray::F64(data), AnyArrayDType::U64) => {
154 AnyArray::U64(reinterpret_array(data, f64::to_bits))
155 }
156 (data, dtype) => {
157 return Err(ReinterpretCodecError::InvalidReinterpret {
158 decode_dtype: data.dtype(),
159 encode_dtype: dtype,
160 });
161 }
162 };
163
164 Ok(encoded)
165 }
166
167 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
168 if encoded.dtype() != self.encode_dtype {
169 return Err(ReinterpretCodecError::MismatchedDecodeDType {
170 configured: self.encode_dtype,
171 provided: encoded.dtype(),
172 });
173 }
174
175 let decoded = match (encoded, self.decode_dtype) {
176 (encoded, dtype) if encoded.dtype() == dtype => encoded.into_owned(),
177 (AnyCowArray::U8(encoded), dtype) => {
178 let mut shape = encoded.shape().to_vec();
179
180 if (encoded.len() % dtype.size()) != 0 {
181 return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
182 }
183
184 if let Some(last) = shape.last_mut() {
185 *last /= dtype.size();
186 }
187
188 let (decoded, ()) = AnyArray::with_zeros_bytes(dtype, &shape, |bytes| {
189 bytes.copy_from_slice(&AnyCowArray::U8(encoded).as_bytes());
190 });
191
192 decoded
193 }
194 (AnyCowArray::U16(encoded), AnyArrayDType::I16) => {
195 AnyArray::I16(reinterpret_array(encoded, |x| {
196 i16::from_ne_bytes(x.to_ne_bytes())
197 }))
198 }
199 (AnyCowArray::U32(encoded), AnyArrayDType::I32) => {
200 AnyArray::I32(reinterpret_array(encoded, |x| {
201 i32::from_ne_bytes(x.to_ne_bytes())
202 }))
203 }
204 (AnyCowArray::U32(encoded), AnyArrayDType::F32) => {
205 AnyArray::F32(reinterpret_array(encoded, f32::from_bits))
206 }
207 (AnyCowArray::U64(encoded), AnyArrayDType::U64) => {
208 AnyArray::I64(reinterpret_array(encoded, |x| {
209 i64::from_ne_bytes(x.to_ne_bytes())
210 }))
211 }
212 (AnyCowArray::U64(encoded), AnyArrayDType::F64) => {
213 AnyArray::F64(reinterpret_array(encoded, f64::from_bits))
214 }
215 (encoded, dtype) => {
216 return Err(ReinterpretCodecError::InvalidReinterpret {
217 decode_dtype: dtype,
218 encode_dtype: encoded.dtype(),
219 });
220 }
221 };
222
223 Ok(decoded)
224 }
225
226 fn decode_into(
227 &self,
228 encoded: AnyArrayView,
229 mut decoded: AnyArrayViewMut,
230 ) -> Result<(), Self::Error> {
231 if encoded.dtype() != self.encode_dtype {
232 return Err(ReinterpretCodecError::MismatchedDecodeDType {
233 configured: self.encode_dtype,
234 provided: encoded.dtype(),
235 });
236 }
237
238 match (encoded, self.decode_dtype) {
239 (encoded, dtype) if encoded.dtype() == dtype => Ok(decoded.assign(&encoded)?),
240 (AnyArrayView::U8(encoded), dtype) => {
241 if decoded.dtype() != dtype {
242 return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
243 source: AnyArrayAssignError::DTypeMismatch {
244 src: dtype,
245 dst: decoded.dtype(),
246 },
247 });
248 }
249
250 let mut shape = encoded.shape().to_vec();
251
252 if (encoded.len() % dtype.size()) != 0 {
253 return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
254 }
255
256 if let Some(last) = shape.last_mut() {
257 *last /= dtype.size();
258 }
259
260 if decoded.shape() != shape {
261 return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
262 source: AnyArrayAssignError::ShapeMismatch {
263 src: shape,
264 dst: decoded.shape().to_vec(),
265 },
266 });
267 }
268
269 let () = decoded.with_bytes_mut(|bytes| {
270 bytes.copy_from_slice(&AnyArrayView::U8(encoded).as_bytes());
271 });
272
273 Ok(())
274 }
275 (AnyArrayView::U16(encoded), AnyArrayDType::I16) => {
276 reinterpret_array_into(encoded, |x| i16::from_ne_bytes(x.to_ne_bytes()), decoded)
277 }
278 (AnyArrayView::U32(encoded), AnyArrayDType::I32) => {
279 reinterpret_array_into(encoded, |x| i32::from_ne_bytes(x.to_ne_bytes()), decoded)
280 }
281 (AnyArrayView::U32(encoded), AnyArrayDType::F32) => {
282 reinterpret_array_into(encoded, f32::from_bits, decoded)
283 }
284 (AnyArrayView::U64(encoded), AnyArrayDType::U64) => {
285 reinterpret_array_into(encoded, |x| i64::from_ne_bytes(x.to_ne_bytes()), decoded)
286 }
287 (AnyArrayView::U64(encoded), AnyArrayDType::F64) => {
288 reinterpret_array_into(encoded, f64::from_bits, decoded)
289 }
290 (encoded, dtype) => Err(ReinterpretCodecError::InvalidReinterpret {
291 decode_dtype: dtype,
292 encode_dtype: encoded.dtype(),
293 }),
294 }?;
295
296 Ok(())
297 }
298}
299
impl StaticCodec for ReinterpretCodec {
    /// Identifier string of this codec.
    const CODEC_ID: &'static str = "reinterpret";

    /// The codec is its own configuration type; validation of the dtype pair
    /// happens in its `Deserialize` implementation.
    type Config<'de> = Self;

    /// Returns the configuration unchanged, since it already is the codec.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Creates the codec's configuration from a reference to `self`.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
313
314impl Serialize for ReinterpretCodec {
315 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
316 ReinterpretCodecConfig {
317 encode_dtype: self.encode_dtype,
318 decode_dtype: self.decode_dtype,
319 }
320 .serialize(serializer)
321 }
322}
323
324impl<'de> Deserialize<'de> for ReinterpretCodec {
325 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
326 let config = ReinterpretCodecConfig::deserialize(deserializer)?;
327
328 Self::try_new(config.encode_dtype, config.decode_dtype).map_err(serde::de::Error::custom)
329 }
330}
331
// Plain serde mirror of `ReinterpretCodec`: it is (de)serialized under the
// name "ReinterpretCodec" so that the codec's own `Deserialize` impl can
// validate the dtype pair before constructing the codec.
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename = "ReinterpretCodec")]
struct ReinterpretCodecConfig {
    // Dtype of the encoded data.
    encode_dtype: AnyArrayDType,
    // Dtype of the decoded data.
    decode_dtype: AnyArrayDType,
}
338
/// Errors that may occur when constructing or applying the
/// [`ReinterpretCodec`].
#[derive(Debug, Error)]
pub enum ReinterpretCodecError {
    /// The pair of decode and encode dtypes has no supported
    /// reinterpretation.
    #[error("Reinterpret cannot bitcast {decode_dtype} as {encode_dtype}")]
    InvalidReinterpret {
        /// Dtype of the decoded data.
        decode_dtype: AnyArrayDType,
        /// Dtype of the encoded data.
        encode_dtype: AnyArrayDType,
    },
    /// The array passed to `encode` does not have the configured decode dtype.
    #[error("Reinterpret cannot encode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedEncodeDType {
        /// Dtype the codec was configured with.
        configured: AnyArrayDType,
        /// Dtype of the array that was provided.
        provided: AnyArrayDType,
    },
    /// The array passed to `decode` or `decode_into` does not have the
    /// configured encode dtype.
    #[error("Reinterpret cannot decode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedDecodeDType {
        /// Dtype the codec was configured with.
        configured: AnyArrayDType,
        /// Dtype of the array that was provided.
        provided: AnyArrayDType,
    },
    /// The shape of an encoded byte array cannot be split into whole elements
    /// of the decode dtype.
    #[error(
        "Reinterpret cannot decode a byte array of shape {shape:?} into an array of {dtype}-s"
    )]
    InvalidEncodedShape {
        /// Shape of the encoded byte array.
        shape: Vec<usize>,
        /// Dtype the bytes were to be decoded into.
        dtype: AnyArrayDType,
    },
    /// The output array passed to `decode_into` has the wrong dtype or shape.
    #[error("Reinterpret cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The underlying dtype or shape mismatch.
        #[from]
        source: AnyArrayAssignError,
    },
}
387
388#[inline]
391pub fn reinterpret_array<T: Copy, U, S: Data<Elem = T>, D: Dimension>(
392 array: ArrayBase<S, D>,
393 reinterpret: impl Fn(T) -> U,
394) -> Array<U, D> {
395 let array = array.into_owned();
396 let (shape, data) = (array.raw_dim(), array.into_raw_vec_and_offset().0);
397
398 let data = data.into_iter().map(reinterpret).collect();
399
400 #[expect(unsafe_code)]
401 let array = unsafe { Array::from_shape_vec_unchecked(shape, data) };
403
404 array
405}
406
407#[expect(clippy::needless_pass_by_value)]
408#[inline]
418pub fn reinterpret_array_into<'a, T: Copy, U: ArrayDType, D: Dimension>(
419 encoded: ArrayView<T, D>,
420 reinterpret: impl Fn(T) -> U,
421 mut decoded: AnyArrayViewMut<'a>,
422) -> Result<(), ReinterpretCodecError>
423where
424 U::RawData<ViewRepr<&'a mut ()>>: DataMut,
425{
426 let Some(decoded) = decoded.as_typed_mut::<U>() else {
427 return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
428 source: AnyArrayAssignError::DTypeMismatch {
429 src: U::DTYPE,
430 dst: decoded.dtype(),
431 },
432 });
433 };
434
435 if encoded.shape() != decoded.shape() {
436 return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
437 source: AnyArrayAssignError::ShapeMismatch {
438 src: encoded.shape().to_vec(),
439 dst: decoded.shape().to_vec(),
440 },
441 });
442 }
443
444 for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
446 *d = reinterpret(*e);
447 }
448
449 Ok(())
450}