1use ndarray::{Array, ArrayBase, ArrayView, Data, DataMut, Dimension, ViewRepr};
21use numcodecs::{
22 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
23 ArrayDType, Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
24};
25use schemars::JsonSchema;
26use serde::{Deserialize, Deserializer, Serialize, Serializer};
27use thiserror::Error;
28
/// Codec that reinterprets the data type of an array between a decoded and an encoded dtype.
///
/// Instances are validated by [`ReinterpretCodec::try_new`] or built by one of the
/// infallible constructors `passthrough`, `to_bytes` and `to_binary`.
#[derive(Clone, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ReinterpretCodec {
    /// Dtype of the encoded data, i.e. the output of `encode` and the input of `decode`.
    encode_dtype: AnyArrayDType,
    /// Dtype of the decoded data, i.e. the input of `encode` and the output of `decode`.
    decode_dtype: AnyArrayDType,
    /// Version marker of the codec (`1.0.0`).
    #[serde(default)]
    _version: StaticCodecVersion<1, 0, 0>,
}
46
47impl ReinterpretCodec {
48 pub fn try_new(
57 encode_dtype: AnyArrayDType,
58 decode_dtype: AnyArrayDType,
59 ) -> Result<Self, ReinterpretCodecError> {
60 #[expect(clippy::match_same_arms)]
61 match (decode_dtype, encode_dtype) {
62 (ty_a, ty_b) if ty_a == ty_b => (),
64 (_, AnyArrayDType::U8) => (),
66 (AnyArrayDType::I16, AnyArrayDType::U16)
68 | (AnyArrayDType::I32 | AnyArrayDType::F32, AnyArrayDType::U32)
69 | (AnyArrayDType::I64 | AnyArrayDType::F64, AnyArrayDType::U64) => (),
70 (decode_dtype, encode_dtype) => {
71 return Err(ReinterpretCodecError::InvalidReinterpret {
72 decode_dtype,
73 encode_dtype,
74 })
75 }
76 }
77
78 Ok(Self {
79 encode_dtype,
80 decode_dtype,
81 _version: StaticCodecVersion,
82 })
83 }
84
85 #[must_use]
86 pub const fn passthrough(dtype: AnyArrayDType) -> Self {
88 Self {
89 encode_dtype: dtype,
90 decode_dtype: dtype,
91 _version: StaticCodecVersion,
92 }
93 }
94
95 #[must_use]
96 pub const fn to_bytes(dtype: AnyArrayDType) -> Self {
99 Self {
100 encode_dtype: AnyArrayDType::U8,
101 decode_dtype: dtype,
102 _version: StaticCodecVersion,
103 }
104 }
105
106 #[must_use]
107 pub const fn to_binary(dtype: AnyArrayDType) -> Self {
110 Self {
111 encode_dtype: dtype.to_binary(),
112 decode_dtype: dtype,
113 _version: StaticCodecVersion,
114 }
115 }
116}
117
118impl Codec for ReinterpretCodec {
119 type Error = ReinterpretCodecError;
120
121 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
122 if data.dtype() != self.decode_dtype {
123 return Err(ReinterpretCodecError::MismatchedEncodeDType {
124 configured: self.decode_dtype,
125 provided: data.dtype(),
126 });
127 }
128
129 let encoded = match (data, self.encode_dtype) {
130 (data, dtype) if data.dtype() == dtype => data.into_owned(),
131 (data, AnyArrayDType::U8) => {
132 let mut shape = data.shape().to_vec();
133 if let Some(last) = shape.last_mut() {
134 *last *= data.dtype().size();
135 }
136 #[expect(unsafe_code)]
137 let encoded =
139 unsafe { Array::from_shape_vec_unchecked(shape, data.as_bytes().into_owned()) };
140 AnyArray::U8(encoded)
141 }
142 (AnyCowArray::I16(data), AnyArrayDType::U16) => {
143 AnyArray::U16(reinterpret_array(data, |x| {
144 u16::from_ne_bytes(x.to_ne_bytes())
145 }))
146 }
147 (AnyCowArray::I32(data), AnyArrayDType::U32) => {
148 AnyArray::U32(reinterpret_array(data, |x| {
149 u32::from_ne_bytes(x.to_ne_bytes())
150 }))
151 }
152 (AnyCowArray::F32(data), AnyArrayDType::U32) => {
153 AnyArray::U32(reinterpret_array(data, f32::to_bits))
154 }
155 (AnyCowArray::I64(data), AnyArrayDType::U64) => {
156 AnyArray::U64(reinterpret_array(data, |x| {
157 u64::from_ne_bytes(x.to_ne_bytes())
158 }))
159 }
160 (AnyCowArray::F64(data), AnyArrayDType::U64) => {
161 AnyArray::U64(reinterpret_array(data, f64::to_bits))
162 }
163 (data, dtype) => {
164 return Err(ReinterpretCodecError::InvalidReinterpret {
165 decode_dtype: data.dtype(),
166 encode_dtype: dtype,
167 });
168 }
169 };
170
171 Ok(encoded)
172 }
173
174 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
175 if encoded.dtype() != self.encode_dtype {
176 return Err(ReinterpretCodecError::MismatchedDecodeDType {
177 configured: self.encode_dtype,
178 provided: encoded.dtype(),
179 });
180 }
181
182 let decoded = match (encoded, self.decode_dtype) {
183 (encoded, dtype) if encoded.dtype() == dtype => encoded.into_owned(),
184 (AnyCowArray::U8(encoded), dtype) => {
185 let mut shape = encoded.shape().to_vec();
186
187 if (encoded.len() % dtype.size()) != 0 {
188 return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
189 }
190
191 if let Some(last) = shape.last_mut() {
192 *last /= dtype.size();
193 }
194
195 let (decoded, ()) = AnyArray::with_zeros_bytes(dtype, &shape, |bytes| {
196 bytes.copy_from_slice(&AnyCowArray::U8(encoded).as_bytes());
197 });
198
199 decoded
200 }
201 (AnyCowArray::U16(encoded), AnyArrayDType::I16) => {
202 AnyArray::I16(reinterpret_array(encoded, |x| {
203 i16::from_ne_bytes(x.to_ne_bytes())
204 }))
205 }
206 (AnyCowArray::U32(encoded), AnyArrayDType::I32) => {
207 AnyArray::I32(reinterpret_array(encoded, |x| {
208 i32::from_ne_bytes(x.to_ne_bytes())
209 }))
210 }
211 (AnyCowArray::U32(encoded), AnyArrayDType::F32) => {
212 AnyArray::F32(reinterpret_array(encoded, f32::from_bits))
213 }
214 (AnyCowArray::U64(encoded), AnyArrayDType::U64) => {
215 AnyArray::I64(reinterpret_array(encoded, |x| {
216 i64::from_ne_bytes(x.to_ne_bytes())
217 }))
218 }
219 (AnyCowArray::U64(encoded), AnyArrayDType::F64) => {
220 AnyArray::F64(reinterpret_array(encoded, f64::from_bits))
221 }
222 (encoded, dtype) => {
223 return Err(ReinterpretCodecError::InvalidReinterpret {
224 decode_dtype: dtype,
225 encode_dtype: encoded.dtype(),
226 });
227 }
228 };
229
230 Ok(decoded)
231 }
232
233 fn decode_into(
234 &self,
235 encoded: AnyArrayView,
236 mut decoded: AnyArrayViewMut,
237 ) -> Result<(), Self::Error> {
238 if encoded.dtype() != self.encode_dtype {
239 return Err(ReinterpretCodecError::MismatchedDecodeDType {
240 configured: self.encode_dtype,
241 provided: encoded.dtype(),
242 });
243 }
244
245 match (encoded, self.decode_dtype) {
246 (encoded, dtype) if encoded.dtype() == dtype => Ok(decoded.assign(&encoded)?),
247 (AnyArrayView::U8(encoded), dtype) => {
248 if decoded.dtype() != dtype {
249 return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
250 source: AnyArrayAssignError::DTypeMismatch {
251 src: dtype,
252 dst: decoded.dtype(),
253 },
254 });
255 }
256
257 let mut shape = encoded.shape().to_vec();
258
259 if (encoded.len() % dtype.size()) != 0 {
260 return Err(ReinterpretCodecError::InvalidEncodedShape { shape, dtype });
261 }
262
263 if let Some(last) = shape.last_mut() {
264 *last /= dtype.size();
265 }
266
267 if decoded.shape() != shape {
268 return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
269 source: AnyArrayAssignError::ShapeMismatch {
270 src: shape,
271 dst: decoded.shape().to_vec(),
272 },
273 });
274 }
275
276 let () = decoded.with_bytes_mut(|bytes| {
277 bytes.copy_from_slice(&AnyArrayView::U8(encoded).as_bytes());
278 });
279
280 Ok(())
281 }
282 (AnyArrayView::U16(encoded), AnyArrayDType::I16) => {
283 reinterpret_array_into(encoded, |x| i16::from_ne_bytes(x.to_ne_bytes()), decoded)
284 }
285 (AnyArrayView::U32(encoded), AnyArrayDType::I32) => {
286 reinterpret_array_into(encoded, |x| i32::from_ne_bytes(x.to_ne_bytes()), decoded)
287 }
288 (AnyArrayView::U32(encoded), AnyArrayDType::F32) => {
289 reinterpret_array_into(encoded, f32::from_bits, decoded)
290 }
291 (AnyArrayView::U64(encoded), AnyArrayDType::U64) => {
292 reinterpret_array_into(encoded, |x| i64::from_ne_bytes(x.to_ne_bytes()), decoded)
293 }
294 (AnyArrayView::U64(encoded), AnyArrayDType::F64) => {
295 reinterpret_array_into(encoded, f64::from_bits, decoded)
296 }
297 (encoded, dtype) => Err(ReinterpretCodecError::InvalidReinterpret {
298 decode_dtype: dtype,
299 encode_dtype: encoded.dtype(),
300 }),
301 }?;
302
303 Ok(())
304 }
305}
306
307impl StaticCodec for ReinterpretCodec {
308 const CODEC_ID: &'static str = "reinterpret.rs";
309
310 type Config<'de> = Self;
311
312 fn from_config(config: Self::Config<'_>) -> Self {
313 config
314 }
315
316 fn get_config(&self) -> StaticCodecConfig<Self> {
317 StaticCodecConfig::from(self)
318 }
319}
320
321impl Serialize for ReinterpretCodec {
322 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
323 ReinterpretCodecConfig {
324 encode_dtype: self.encode_dtype,
325 decode_dtype: self.decode_dtype,
326 }
327 .serialize(serializer)
328 }
329}
330
331impl<'de> Deserialize<'de> for ReinterpretCodec {
332 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
333 let config = ReinterpretCodecConfig::deserialize(deserializer)?;
334
335 Self::try_new(config.encode_dtype, config.decode_dtype).map_err(serde::de::Error::custom)
336 }
337}
338
// Plain serialization shape of the `ReinterpretCodec`, renamed so that it (de)serializes
// under the codec's own name. The codec's `Serialize` and `Deserialize` impls go through
// this type, and deserialization validates the dtype pair with `ReinterpretCodec::try_new`.
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename = "ReinterpretCodec")]
struct ReinterpretCodecConfig {
    // Dtype of the encoded data.
    encode_dtype: AnyArrayDType,
    // Dtype of the decoded data.
    decode_dtype: AnyArrayDType,
}
345
/// Errors that may occur when configuring or applying the [`ReinterpretCodec`].
#[derive(Debug, Error)]
pub enum ReinterpretCodecError {
    /// The pair of decoded and encoded dtypes is not a supported reinterpretation.
    #[error("Reinterpret cannot bitcast {decode_dtype} as {encode_dtype}")]
    InvalidReinterpret {
        /// Dtype of the decoded data.
        decode_dtype: AnyArrayDType,
        /// Dtype of the encoded data.
        encode_dtype: AnyArrayDType,
    },
    /// The array passed to `encode` does not have the configured decoded dtype.
    #[error("Reinterpret cannot encode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedEncodeDType {
        /// Dtype that the codec was configured to encode.
        configured: AnyArrayDType,
        /// Dtype of the array that was provided.
        provided: AnyArrayDType,
    },
    /// The array passed to `decode` or `decode_into` does not have the configured encoded
    /// dtype.
    #[error("Reinterpret cannot decode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedDecodeDType {
        /// Dtype that the codec was configured to decode.
        configured: AnyArrayDType,
        /// Dtype of the array that was provided.
        provided: AnyArrayDType,
    },
    /// The encoded byte array cannot be split into whole elements of the decoded dtype.
    #[error(
        "Reinterpret cannot decode a byte array of shape {shape:?} into an array of {dtype}-s"
    )]
    InvalidEncodedShape {
        /// Shape of the encoded byte array.
        shape: Vec<usize>,
        /// Dtype that the bytes were to be decoded into.
        dtype: AnyArrayDType,
    },
    /// The output array passed to `decode_into` has the wrong dtype or shape.
    #[error("Reinterpret cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The underlying dtype or shape mismatch.
        #[from]
        source: AnyArrayAssignError,
    },
}
394
395#[inline]
398pub fn reinterpret_array<T: Copy, U, S: Data<Elem = T>, D: Dimension>(
399 array: ArrayBase<S, D>,
400 reinterpret: impl Fn(T) -> U,
401) -> Array<U, D> {
402 let array = array.into_owned();
403 let (shape, data) = (array.raw_dim(), array.into_raw_vec_and_offset().0);
404
405 let data = data.into_iter().map(reinterpret).collect();
406
407 #[expect(unsafe_code)]
408 let array = unsafe { Array::from_shape_vec_unchecked(shape, data) };
410
411 array
412}
413
414#[expect(clippy::needless_pass_by_value)]
415#[inline]
425pub fn reinterpret_array_into<'a, T: Copy, U: ArrayDType, D: Dimension>(
426 encoded: ArrayView<T, D>,
427 reinterpret: impl Fn(T) -> U,
428 mut decoded: AnyArrayViewMut<'a>,
429) -> Result<(), ReinterpretCodecError>
430where
431 U::RawData<ViewRepr<&'a mut ()>>: DataMut,
432{
433 let Some(decoded) = decoded.as_typed_mut::<U>() else {
434 return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
435 source: AnyArrayAssignError::DTypeMismatch {
436 src: U::DTYPE,
437 dst: decoded.dtype(),
438 },
439 });
440 };
441
442 if encoded.shape() != decoded.shape() {
443 return Err(ReinterpretCodecError::MismatchedDecodeIntoArray {
444 source: AnyArrayAssignError::ShapeMismatch {
445 src: encoded.shape().to_vec(),
446 dst: decoded.shape().to_vec(),
447 },
448 });
449 }
450
451 for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
453 *d = reinterpret(*e);
454 }
455
456 Ok(())
457}