numcodecs_sz3/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.87.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-sz3
10//! [crates.io]: https://crates.io/crates/numcodecs-sz3
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-sz3
13//! [docs.rs]: https://docs.rs/numcodecs-sz3/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_sz3
17//!
18//! SZ3 codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // embedded-io
21
22use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, ArrayViewMut, Data, Dimension, IxDyn, ShapeError};
25use numcodecs::{
26    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27    ArrayDType, ArrayDataMutExt, Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use thiserror::Error;
32
33// Only included to explicitly enable the `no_wasm_shim` feature for
34// sz3-sys/zstd-sys
35use ::zstd_sys as _;
36
37#[cfg(test)]
38use ::serde_json as _;
39
40type Sz3CodecVersion = StaticCodecVersion<0, 2, 0>;
41
42#[derive(Clone, Serialize, Deserialize, JsonSchema)]
43// serde cannot deny unknown fields because of the flatten
44#[schemars(deny_unknown_fields)]
45/// Codec providing compression using SZ3
46pub struct Sz3Codec {
47    /// Predictor
48    #[serde(default = "default_predictor")]
49    pub predictor: Option<Sz3Predictor>,
50    /// SZ3 error bound
51    #[serde(flatten)]
52    pub error_bound: Sz3ErrorBound,
53    /// The codec's encoding format version. Do not provide this parameter explicitly.
54    #[serde(default, rename = "_version")]
55    pub version: Sz3CodecVersion,
56}
57
58/// SZ3 error bound
59#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
60#[serde(tag = "eb_mode")]
61#[serde(deny_unknown_fields)]
62pub enum Sz3ErrorBound {
63    /// Errors are bounded by *both* the absolute and relative error, i.e. by
64    /// whichever bound is stricter
65    #[serde(rename = "abs-and-rel")]
66    AbsoluteAndRelative {
67        /// Absolute error bound
68        #[serde(rename = "eb_abs")]
69        abs: f64,
70        /// Relative error bound
71        #[serde(rename = "eb_rel")]
72        rel: f64,
73    },
74    /// Errors are bounded by *either* the absolute or relative error, i.e. by
75    /// whichever bound is weaker
76    #[serde(rename = "abs-or-rel")]
77    AbsoluteOrRelative {
78        /// Absolute error bound
79        #[serde(rename = "eb_abs")]
80        abs: f64,
81        /// Relative error bound
82        #[serde(rename = "eb_rel")]
83        rel: f64,
84    },
85    /// Absolute error bound
86    #[serde(rename = "abs")]
87    Absolute {
88        /// Absolute error bound
89        #[serde(rename = "eb_abs")]
90        abs: f64,
91    },
92    /// Relative error bound
93    #[serde(rename = "rel")]
94    Relative {
95        /// Relative error bound
96        #[serde(rename = "eb_rel")]
97        rel: f64,
98    },
99    /// Peak signal to noise ratio error bound
100    #[serde(rename = "psnr")]
101    PS2NR {
102        /// Peak signal to noise ratio error bound
103        #[serde(rename = "eb_psnr")]
104        psnr: f64,
105    },
106    /// Peak L2 norm error bound
107    #[serde(rename = "l2")]
108    L2Norm {
109        /// Peak L2 norm error bound
110        #[serde(rename = "eb_l2")]
111        l2: f64,
112    },
113}
114
115/// SZ3 predictor
116#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
117#[serde(deny_unknown_fields)]
118pub enum Sz3Predictor {
119    /// Interpolation
120    #[serde(rename = "interpolation")]
121    Interpolation,
122    /// Interpolation + Lorenzo predictor
123    #[serde(rename = "interpolation-lorenzo")]
124    InterpolationLorenzo,
125    /// 1st order regression
126    #[serde(rename = "regression")]
127    Regression,
128    /// 2nd order Lorenzo predictor
129    #[serde(rename = "lorenzo2")]
130    LorenzoSecondOrder,
131    /// 2nd order Lorenzo predictor + 1st order regression
132    #[serde(rename = "lorenzo2-regression")]
133    LorenzoSecondOrderRegression,
134    /// 1st order Lorenzo predictor
135    #[serde(rename = "lorenzo")]
136    Lorenzo,
137    /// 1st order Lorenzo predictor + 1st order regression
138    #[serde(rename = "lorenzo-regression")]
139    LorenzoRegression,
140    /// 1st+2nd order Lorenzo predictor
141    #[serde(rename = "lorenzo-lorenzo2")]
142    LorenzoFirstSecondOrder,
143    /// 1st+2nd order Lorenzo predictor + 1st order regression
144    #[serde(rename = "lorenzo-lorenzo2-regression")]
145    LorenzoFirstSecondOrderRegression,
146}
147
148#[expect(clippy::unnecessary_wraps)]
149const fn default_predictor() -> Option<Sz3Predictor> {
150    Some(Sz3Predictor::InterpolationLorenzo)
151}
152
153impl Codec for Sz3Codec {
154    type Error = Sz3CodecError;
155
156    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
157        match data {
158            AnyCowArray::I32(data) => Ok(AnyArray::U8(
159                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
160                    .into_dyn(),
161            )),
162            AnyCowArray::I64(data) => Ok(AnyArray::U8(
163                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
164                    .into_dyn(),
165            )),
166            AnyCowArray::F32(data) => Ok(AnyArray::U8(
167                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
168                    .into_dyn(),
169            )),
170            AnyCowArray::F64(data) => Ok(AnyArray::U8(
171                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
172                    .into_dyn(),
173            )),
174            encoded => Err(Sz3CodecError::UnsupportedDtype(encoded.dtype())),
175        }
176    }
177
178    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
179        let AnyCowArray::U8(encoded) = encoded else {
180            return Err(Sz3CodecError::EncodedDataNotBytes {
181                dtype: encoded.dtype(),
182            });
183        };
184
185        if !matches!(encoded.shape(), [_]) {
186            return Err(Sz3CodecError::EncodedDataNotOneDimensional {
187                shape: encoded.shape().to_vec(),
188            });
189        }
190
191        decompress(&AnyCowArray::U8(encoded).as_bytes())
192    }
193
194    fn decode_into(
195        &self,
196        encoded: AnyArrayView,
197        decoded: AnyArrayViewMut,
198    ) -> Result<(), Self::Error> {
199        let AnyArrayView::U8(encoded) = encoded else {
200            return Err(Sz3CodecError::EncodedDataNotBytes {
201                dtype: encoded.dtype(),
202            });
203        };
204
205        if !matches!(encoded.shape(), [_]) {
206            return Err(Sz3CodecError::EncodedDataNotOneDimensional {
207                shape: encoded.shape().to_vec(),
208            });
209        }
210
211        decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
212    }
213}
214
215impl StaticCodec for Sz3Codec {
216    const CODEC_ID: &'static str = "sz3.rs";
217
218    type Config<'de> = Self;
219
220    fn from_config(config: Self::Config<'_>) -> Self {
221        config
222    }
223
224    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
225        StaticCodecConfig::from(self)
226    }
227}
228
229#[derive(Debug, Error)]
230/// Errors that may occur when applying the [`Sz3Codec`].
231pub enum Sz3CodecError {
232    /// [`Sz3Codec`] does not support the dtype
233    #[error("Sz3 does not support the dtype {0}")]
234    UnsupportedDtype(AnyArrayDType),
235    /// [`Sz3Codec`] failed to encode the header
236    #[error("Sz3 failed to encode the header")]
237    HeaderEncodeFailed {
238        /// Opaque source error
239        source: Sz3HeaderError,
240    },
241    /// [`Sz3Codec`] cannot encode an array of `shape`
242    #[error("Sz3 cannot encode an array of shape {shape:?}")]
243    InvalidEncodeShape {
244        /// Opaque source error
245        source: Sz3CodingError,
246        /// The invalid shape of the encoded array
247        shape: Vec<usize>,
248    },
249    /// [`Sz3Codec`] failed to encode the data
250    #[error("Sz3 failed to encode the data")]
251    Sz3EncodeFailed {
252        /// Opaque source error
253        source: Sz3CodingError,
254    },
255    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
256    /// an array of a different dtype
257    #[error(
258        "Sz3 can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
259    )]
260    EncodedDataNotBytes {
261        /// The unexpected dtype of the encoded array
262        dtype: AnyArrayDType,
263    },
264    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
265    /// an array of a different shape
266    #[error(
267        "Sz3 can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
268    )]
269    EncodedDataNotOneDimensional {
270        /// The unexpected shape of the encoded array
271        shape: Vec<usize>,
272    },
273    /// [`Sz3Codec`] failed to decode the header
274    #[error("Sz3 failed to decode the header")]
275    HeaderDecodeFailed {
276        /// Opaque source error
277        source: Sz3HeaderError,
278    },
279    /// [`Sz3Codec`] failed to decode the data
280    #[error("Sz3 failed to decode the data")]
281    Sz3DecodeFailed {
282        /// Opaque source error
283        source: Sz3CodingError,
284    },
285    /// [`Sz3Codec`] decoded an invalid array shape header which does not fit
286    /// the decoded data
287    #[error("Sz3 decoded an invalid array shape header which does not fit the decoded data")]
288    DecodeInvalidShapeHeader {
289        /// Source error
290        #[from]
291        source: ShapeError,
292    },
293    /// [`Sz3Codec`] cannot decode into the provided array
294    #[error("Sz3 cannot decode into the provided array")]
295    MismatchedDecodeIntoArray {
296        /// The source of the error
297        #[from]
298        source: AnyArrayAssignError,
299    },
300}
301
302#[derive(Debug, Error)]
303#[error(transparent)]
304/// Opaque error for when encoding or decoding the header fails
305pub struct Sz3HeaderError(postcard::Error);
306
307#[derive(Debug, Error)]
308#[error(transparent)]
309/// Opaque error for when encoding or decoding with SZ3 fails
310pub struct Sz3CodingError(sz3::SZ3Error);
311
312#[expect(clippy::needless_pass_by_value, clippy::too_many_lines)]
313/// Compresses the input `data` array using SZ3, which consists of an optional
314/// `predictor`, an `error_bound`, an optional `encoder`, and an optional
315/// `lossless` compressor.
316///
317/// # Errors
318///
319/// Errors with
320/// - [`Sz3CodecError::HeaderEncodeFailed`] if encoding the header failed
321/// - [`Sz3CodecError::InvalidEncodeShape`] if the array shape is invalid
322/// - [`Sz3CodecError::Sz3EncodeFailed`] if encoding failed with an opaque error
323pub fn compress<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
324    data: ArrayBase<S, D>,
325    predictor: Option<&Sz3Predictor>,
326    error_bound: &Sz3ErrorBound,
327) -> Result<Vec<u8>, Sz3CodecError> {
328    let mut encoded_bytes = postcard::to_extend(
329        &CompressionHeader {
330            dtype: <T as Sz3Element>::DTYPE,
331            shape: Cow::Borrowed(data.shape()),
332            version: StaticCodecVersion,
333        },
334        Vec::new(),
335    )
336    .map_err(|err| Sz3CodecError::HeaderEncodeFailed {
337        source: Sz3HeaderError(err),
338    })?;
339
340    // sz3::DimensionedDataBuilder cannot handle zero-length dimensions
341    if data.is_empty() {
342        return Ok(encoded_bytes);
343    }
344
345    #[expect(clippy::option_if_let_else)]
346    let data_cow = match data.as_slice() {
347        Some(data) => Cow::Borrowed(data),
348        None => Cow::Owned(data.iter().copied().collect()),
349    };
350    let mut builder = sz3::DimensionedData::build(&data_cow);
351
352    for length in data.shape() {
353        // Sz3 ignores dimensions of length 1 and panics on length zero
354        // Since they carry no information for Sz3 and we already encode them
355        //  in our custom header, we just skip them here
356        if *length > 1 {
357            builder = builder
358                .dim(*length)
359                .map_err(|err| Sz3CodecError::InvalidEncodeShape {
360                    source: Sz3CodingError(err),
361                    shape: data.shape().to_vec(),
362                })?;
363        }
364    }
365
366    if data.len() == 1 {
367        // If there is only one element, all dimensions will have been skipped,
368        //  so we explicitly encode one dimension of size 1 here
369        builder = builder
370            .dim(1)
371            .map_err(|err| Sz3CodecError::InvalidEncodeShape {
372                source: Sz3CodingError(err),
373                shape: data.shape().to_vec(),
374            })?;
375    }
376
377    let data = builder
378        .finish()
379        .map_err(|err| Sz3CodecError::InvalidEncodeShape {
380            source: Sz3CodingError(err),
381            shape: data.shape().to_vec(),
382        })?;
383
384    // configure the error bound
385    let error_bound = match error_bound {
386        Sz3ErrorBound::AbsoluteAndRelative { abs, rel } => sz3::ErrorBound::AbsoluteAndRelative {
387            absolute_bound: *abs,
388            relative_bound: *rel,
389        },
390        Sz3ErrorBound::AbsoluteOrRelative { abs, rel } => sz3::ErrorBound::AbsoluteOrRelative {
391            absolute_bound: *abs,
392            relative_bound: *rel,
393        },
394        Sz3ErrorBound::Absolute { abs } => sz3::ErrorBound::Absolute(*abs),
395        Sz3ErrorBound::Relative { rel } => sz3::ErrorBound::Relative(*rel),
396        Sz3ErrorBound::PS2NR { psnr } => sz3::ErrorBound::PSNR(*psnr),
397        Sz3ErrorBound::L2Norm { l2 } => sz3::ErrorBound::L2Norm(*l2),
398    };
399    let mut config = sz3::Config::new(error_bound);
400
401    // configure the predictor (compression algorithm)
402    let predictor = match predictor {
403        Some(Sz3Predictor::Interpolation) => sz3::CompressionAlgorithm::Interpolation,
404        Some(Sz3Predictor::InterpolationLorenzo) => sz3::CompressionAlgorithm::InterpolationLorenzo,
405        Some(Sz3Predictor::Regression) => sz3::CompressionAlgorithm::LorenzoRegression {
406            lorenzo: false,
407            lorenzo_second_order: false,
408            regression: true,
409        },
410        Some(Sz3Predictor::LorenzoSecondOrder) => sz3::CompressionAlgorithm::LorenzoRegression {
411            lorenzo: false,
412            lorenzo_second_order: true,
413            regression: false,
414        },
415        Some(Sz3Predictor::LorenzoSecondOrderRegression) => {
416            sz3::CompressionAlgorithm::LorenzoRegression {
417                lorenzo: false,
418                lorenzo_second_order: true,
419                regression: true,
420            }
421        }
422        Some(Sz3Predictor::Lorenzo) => sz3::CompressionAlgorithm::LorenzoRegression {
423            lorenzo: true,
424            lorenzo_second_order: false,
425            regression: false,
426        },
427        Some(Sz3Predictor::LorenzoRegression) => sz3::CompressionAlgorithm::LorenzoRegression {
428            lorenzo: true,
429            lorenzo_second_order: false,
430            regression: true,
431        },
432        Some(Sz3Predictor::LorenzoFirstSecondOrder) => {
433            sz3::CompressionAlgorithm::LorenzoRegression {
434                lorenzo: true,
435                lorenzo_second_order: true,
436                regression: false,
437            }
438        }
439        Some(Sz3Predictor::LorenzoFirstSecondOrderRegression) => {
440            sz3::CompressionAlgorithm::LorenzoRegression {
441                lorenzo: true,
442                lorenzo_second_order: true,
443                regression: true,
444            }
445        }
446        None => sz3::CompressionAlgorithm::NoPrediction,
447    };
448    config = config.compression_algorithm(predictor);
449
450    sz3::compress_into_with_config(&data, &config, &mut encoded_bytes).map_err(|err| {
451        Sz3CodecError::Sz3EncodeFailed {
452            source: Sz3CodingError(err),
453        }
454    })?;
455
456    Ok(encoded_bytes)
457}
458
459/// Decompresses the `encoded` data into an array using SZ3.
460///
461/// # Errors
462///
463/// Errors with
464/// - [`Sz3CodecError::HeaderDecodeFailed`] if decoding the header failed
465/// - [`Sz3CodecError::Sz3DecodeFailed`] if decoding failed with an opaque error
466pub fn decompress(encoded: &[u8]) -> Result<AnyArray, Sz3CodecError> {
467    fn decompress_typed<T: Sz3Element>(
468        encoded: &[u8],
469        shape: &[usize],
470    ) -> Result<Array<T, IxDyn>, Sz3CodecError> {
471        if shape.iter().copied().any(|s| s == 0) {
472            return Ok(Array::from_shape_vec(shape, Vec::new())?);
473        }
474
475        let (_config, decompressed) =
476            sz3::decompress(encoded).map_err(|err| Sz3CodecError::Sz3DecodeFailed {
477                source: Sz3CodingError(err),
478            })?;
479
480        Ok(Array::from_shape_vec(shape, decompressed.into_data())?)
481    }
482
483    let (header, data) =
484        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
485            Sz3CodecError::HeaderDecodeFailed {
486                source: Sz3HeaderError(err),
487            }
488        })?;
489
490    let decoded = match header.dtype {
491        Sz3DType::U8 => AnyArray::U8(decompress_typed(data, &header.shape)?),
492        Sz3DType::I8 => AnyArray::I8(decompress_typed(data, &header.shape)?),
493        Sz3DType::U16 => AnyArray::U16(decompress_typed(data, &header.shape)?),
494        Sz3DType::I16 => AnyArray::I16(decompress_typed(data, &header.shape)?),
495        Sz3DType::U32 => AnyArray::U32(decompress_typed(data, &header.shape)?),
496        Sz3DType::I32 => AnyArray::I32(decompress_typed(data, &header.shape)?),
497        Sz3DType::U64 => AnyArray::U64(decompress_typed(data, &header.shape)?),
498        Sz3DType::I64 => AnyArray::I64(decompress_typed(data, &header.shape)?),
499        Sz3DType::F32 => AnyArray::F32(decompress_typed(data, &header.shape)?),
500        Sz3DType::F64 => AnyArray::F64(decompress_typed(data, &header.shape)?),
501    };
502
503    Ok(decoded)
504}
505
506/// Decompresses the `encoded` data into a `decoded` array using SZ3.
507///
508/// # Errors
509///
510/// Errors with
511/// - [`Sz3CodecError::HeaderDecodeFailed`] if decoding the header failed
512/// - [`Sz3CodecError::MismatchedDecodeIntoArray`] if the `decoded` array is of
513///   the wrong dtype or shape
514/// - [`Sz3CodecError::Sz3DecodeFailed`] if decoding failed with an opaque error
515pub fn decompress_into(encoded: &[u8], decoded: AnyArrayViewMut) -> Result<(), Sz3CodecError> {
516    fn decompress_into_typed<T: Sz3Element>(
517        encoded: &[u8],
518        mut decoded: ArrayViewMut<T, IxDyn>,
519    ) -> Result<(), Sz3CodecError> {
520        if decoded.is_empty() {
521            return Ok(());
522        }
523
524        let decoded_shape = decoded.shape().to_vec();
525
526        decoded.with_slice_mut(|mut decoded| {
527            let decoded_len = decoded.len();
528
529            let mut builder = sz3::DimensionedData::build_mut(&mut decoded);
530
531            for length in &decoded_shape {
532                // Sz3 ignores dimensions of length 1 and panics on length zero
533                // Since they carry no information for Sz3 and we already encode them
534                //  in our custom header, we just skip them here
535                if *length > 1 {
536                    builder = builder
537                        .dim(*length)
538                        // FIXME: different error code
539                        .map_err(|err| Sz3CodecError::InvalidEncodeShape {
540                            source: Sz3CodingError(err),
541                            shape: decoded_shape.clone(),
542                        })?;
543                }
544            }
545
546            if decoded_len == 1 {
547                // If there is only one element, all dimensions will have been skipped,
548                //  so we explicitly encode one dimension of size 1 here
549                builder = builder
550                    .dim(1)
551                    // FIXME: different error code
552                    .map_err(|err| Sz3CodecError::InvalidEncodeShape {
553                        source: Sz3CodingError(err),
554                        shape: decoded_shape.clone(),
555                    })?;
556            }
557
558            let mut decoded = builder
559                .finish()
560                // FIXME: different error code
561                .map_err(|err| Sz3CodecError::InvalidEncodeShape {
562                    source: Sz3CodingError(err),
563                    shape: decoded_shape,
564                })?;
565
566            sz3::decompress_into_dimensioned(encoded, &mut decoded).map_err(|err| {
567                Sz3CodecError::Sz3DecodeFailed {
568                    source: Sz3CodingError(err),
569                }
570            })
571        })?;
572
573        Ok(())
574    }
575
576    let (header, data) =
577        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
578            Sz3CodecError::HeaderDecodeFailed {
579                source: Sz3HeaderError(err),
580            }
581        })?;
582
583    if decoded.shape() != &*header.shape {
584        return Err(Sz3CodecError::MismatchedDecodeIntoArray {
585            source: AnyArrayAssignError::ShapeMismatch {
586                src: header.shape.into_owned(),
587                dst: decoded.shape().to_vec(),
588            },
589        });
590    }
591
592    match (decoded, header.dtype) {
593        (AnyArrayViewMut::U8(decoded), Sz3DType::U8) => decompress_into_typed(data, decoded),
594        (AnyArrayViewMut::I8(decoded), Sz3DType::I8) => decompress_into_typed(data, decoded),
595        (AnyArrayViewMut::U16(decoded), Sz3DType::U16) => decompress_into_typed(data, decoded),
596        (AnyArrayViewMut::I16(decoded), Sz3DType::I16) => decompress_into_typed(data, decoded),
597        (AnyArrayViewMut::U32(decoded), Sz3DType::U32) => decompress_into_typed(data, decoded),
598        (AnyArrayViewMut::I32(decoded), Sz3DType::I32) => decompress_into_typed(data, decoded),
599        (AnyArrayViewMut::U64(decoded), Sz3DType::U64) => decompress_into_typed(data, decoded),
600        (AnyArrayViewMut::I64(decoded), Sz3DType::I64) => decompress_into_typed(data, decoded),
601        (AnyArrayViewMut::F32(decoded), Sz3DType::F32) => decompress_into_typed(data, decoded),
602        (AnyArrayViewMut::F64(decoded), Sz3DType::F64) => decompress_into_typed(data, decoded),
603        (decoded, dtype) => Err(Sz3CodecError::MismatchedDecodeIntoArray {
604            source: AnyArrayAssignError::DTypeMismatch {
605                src: dtype.into_dtype(),
606                dst: decoded.dtype(),
607            },
608        }),
609    }
610}
611
612/// Array element types which can be compressed with SZ3.
613pub trait Sz3Element: Copy + sz3::SZ3Compressible + ArrayDType {
614    /// The dtype representation of the type
615    const DTYPE: Sz3DType;
616}
617
618impl Sz3Element for u8 {
619    const DTYPE: Sz3DType = Sz3DType::U8;
620}
621
622impl Sz3Element for i8 {
623    const DTYPE: Sz3DType = Sz3DType::I8;
624}
625
626impl Sz3Element for u16 {
627    const DTYPE: Sz3DType = Sz3DType::U16;
628}
629
630impl Sz3Element for i16 {
631    const DTYPE: Sz3DType = Sz3DType::I16;
632}
633
634impl Sz3Element for u32 {
635    const DTYPE: Sz3DType = Sz3DType::U32;
636}
637
638impl Sz3Element for i32 {
639    const DTYPE: Sz3DType = Sz3DType::I32;
640}
641
642impl Sz3Element for u64 {
643    const DTYPE: Sz3DType = Sz3DType::U64;
644}
645
646impl Sz3Element for i64 {
647    const DTYPE: Sz3DType = Sz3DType::I64;
648}
649
650impl Sz3Element for f32 {
651    const DTYPE: Sz3DType = Sz3DType::F32;
652}
653
654impl Sz3Element for f64 {
655    const DTYPE: Sz3DType = Sz3DType::F64;
656}
657
658#[derive(Serialize, Deserialize)]
659struct CompressionHeader<'a> {
660    dtype: Sz3DType,
661    #[serde(borrow)]
662    shape: Cow<'a, [usize]>,
663    version: Sz3CodecVersion,
664}
665
666/// Dtypes that SZ3 can compress and decompress
667#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
668#[expect(missing_docs)]
669pub enum Sz3DType {
670    #[serde(rename = "u8", alias = "uint8")]
671    U8,
672    #[serde(rename = "u16", alias = "uint16")]
673    U16,
674    #[serde(rename = "u32", alias = "uint32")]
675    U32,
676    #[serde(rename = "u64", alias = "uint64")]
677    U64,
678    #[serde(rename = "i8", alias = "int8")]
679    I8,
680    #[serde(rename = "i16", alias = "int16")]
681    I16,
682    #[serde(rename = "i32", alias = "int32")]
683    I32,
684    #[serde(rename = "i64", alias = "int64")]
685    I64,
686    #[serde(rename = "f32", alias = "float32")]
687    F32,
688    #[serde(rename = "f64", alias = "float64")]
689    F64,
690}
691
692impl Sz3DType {
693    /// Get the corresponding [`AnyArrayDType`]
694    #[must_use]
695    pub const fn into_dtype(self) -> AnyArrayDType {
696        match self {
697            Self::U8 => AnyArrayDType::U8,
698            Self::U16 => AnyArrayDType::U16,
699            Self::U32 => AnyArrayDType::U32,
700            Self::U64 => AnyArrayDType::U64,
701            Self::I8 => AnyArrayDType::I8,
702            Self::I16 => AnyArrayDType::I16,
703            Self::I32 => AnyArrayDType::I32,
704            Self::I64 => AnyArrayDType::I64,
705            Self::F32 => AnyArrayDType::F32,
706            Self::F64 => AnyArrayDType::F64,
707        }
708    }
709}
710
711impl fmt::Display for Sz3DType {
712    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
713        fmt.write_str(match self {
714            Self::U8 => "u8",
715            Self::U16 => "u16",
716            Self::U32 => "u32",
717            Self::U64 => "u64",
718            Self::I8 => "i8",
719            Self::I16 => "i16",
720            Self::I32 => "i32",
721            Self::I64 => "i64",
722            Self::F32 => "f32",
723            Self::F64 => "f64",
724        })
725    }
726}
727
#[cfg(test)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    /// An array with a zero-length axis decodes to an empty array of the
    /// same dtype and shape.
    #[test]
    fn zero_length() -> Result<(), Sz3CodecError> {
        let empty = Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])?;
        let predictor = default_predictor();

        let encoded = compress(
            empty,
            predictor.as_ref(),
            &Sz3ErrorBound::L2Norm { l2: 27.0 },
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);

        Ok(())
    }

    /// Axes of length one are restored when decoding, both into a new array
    /// and into a preallocated one.
    #[test]
    fn one_dimension() -> Result<(), Sz3CodecError> {
        let original = Array::from_shape_vec([2_usize, 1, 2, 1].as_slice(), vec![1, 2, 3, 4])?;
        let error_bound = Sz3ErrorBound::Absolute { abs: 0.1 };

        let encoded = compress(
            original.view(),
            default_predictor().as_ref(),
            &error_bound,
        )?;

        let decoded = decompress(&encoded)?;
        assert_eq!(decoded, AnyArray::I32(original.clone()));

        let mut decoded_into = Array::zeros(original.dim());
        decompress_into(&encoded, AnyArrayViewMut::I32(decoded_into.view_mut()))?;
        assert_eq!(decoded_into, original);

        Ok(())
    }

    /// Very short one-dimensional inputs, including the empty one, decode
    /// to the original values.
    #[test]
    fn small_state() -> Result<(), Sz3CodecError> {
        let cases: [&[f64]; 5] = [
            &[],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ];

        for data in cases {
            let expected = Array1::from_vec(data.to_vec());

            let encoded = compress(
                ArrayView1::from(data),
                default_predictor().as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;

            let decoded = decompress(&encoded)?;
            assert_eq!(decoded, AnyArray::F64(expected.clone().into_dyn()));

            let mut decoded_into = Array::zeros([data.len()]);
            decompress_into(
                &encoded,
                AnyArrayViewMut::F64(decoded_into.view_mut().into_dyn()),
            )?;
            assert_eq!(decoded_into, expected);
        }

        Ok(())
    }

    /// Encoding and both ways of decoding succeed for every predictor,
    /// including no predictor at all.
    #[test]
    fn all_predictors() -> Result<(), Sz3CodecError> {
        let data = Array::linspace(-42.0, 42.0, 85);
        let error_bound = Sz3ErrorBound::Absolute { abs: 0.1 };

        let predictors = [
            None,
            Some(Sz3Predictor::Interpolation),
            Some(Sz3Predictor::InterpolationLorenzo),
            Some(Sz3Predictor::Regression),
            Some(Sz3Predictor::LorenzoSecondOrder),
            Some(Sz3Predictor::LorenzoSecondOrderRegression),
            Some(Sz3Predictor::Lorenzo),
            Some(Sz3Predictor::LorenzoRegression),
            Some(Sz3Predictor::LorenzoFirstSecondOrder),
            Some(Sz3Predictor::LorenzoFirstSecondOrderRegression),
        ];

        for predictor in predictors {
            let encoded = compress(data.view(), predictor.as_ref(), &error_bound)?;
            let _decoded = decompress(&encoded)?;

            let mut decoded_into = Array::zeros(data.dim());
            decompress_into(
                &encoded,
                AnyArrayViewMut::F64(decoded_into.view_mut().into_dyn()),
            )?;
        }

        Ok(())
    }

    /// Encoding and both ways of decoding succeed for every supported
    /// element type.
    #[test]
    fn all_dtypes() -> Result<(), Sz3CodecError> {
        /// Compresses the values from `iter` and decodes them again, once
        /// into a new array and once into a zeroed array wrapped by
        /// `view_mut`.
        fn compress_decompress<T: Sz3Element + num_traits::identities::Zero>(
            iter: impl Clone + IntoIterator<Item = T>,
            view_mut: impl for<'a> Fn(ArrayViewMut<'a, T, IxDyn>) -> AnyArrayViewMut<'a>,
        ) -> Result<(), Sz3CodecError> {
            let input = Array::from_iter(iter.clone());
            let error_bound = Sz3ErrorBound::Absolute { abs: 2.0 };

            let encoded = compress(input.view(), None, &error_bound)?;
            let _decoded = decompress(&encoded)?;

            let len = iter.into_iter().count();
            let mut decoded_into = Array::<T, _>::zeros([len]).into_dyn();
            decompress_into(&encoded, view_mut(decoded_into.view_mut().into_dyn()))?;

            Ok(())
        }

        compress_decompress(0_u8..42, |x| AnyArrayViewMut::U8(x))?;
        compress_decompress(-42_i8..42, |x| AnyArrayViewMut::I8(x))?;
        compress_decompress(0_u16..42, |x| AnyArrayViewMut::U16(x))?;
        compress_decompress(-42_i16..42, |x| AnyArrayViewMut::I16(x))?;
        compress_decompress(0_u32..42, |x| AnyArrayViewMut::U32(x))?;
        compress_decompress(-42_i32..42, |x| AnyArrayViewMut::I32(x))?;
        compress_decompress(0_u64..42, |x| AnyArrayViewMut::U64(x))?;
        compress_decompress(-42_i64..42, |x| AnyArrayViewMut::I64(x))?;
        compress_decompress((-42_i16..42).map(f32::from), |x| AnyArrayViewMut::F32(x))?;
        compress_decompress((-42_i16..42).map(f64::from), |x| AnyArrayViewMut::F64(x))?;

        Ok(())
    }
}