Skip to main content

numcodecs_sz3/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.87.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-sz3
10//! [crates.io]: https://crates.io/crates/numcodecs-sz3
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-sz3
13//! [docs.rs]: https://docs.rs/numcodecs-sz3/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_sz3
17//!
18//! SZ3 codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // embedded-io
21
22use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, ArrayViewMut, Data, Dimension, IxDyn, ShapeError};
25use numcodecs::{
26    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27    ArrayDType, ArrayDataMutExt, Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use thiserror::Error;
32
33// Only included to explicitly enable the `no_wasm_shim` feature for
34// sz3-sys/zstd-sys
35use ::zstd_sys as _;
36
37#[cfg(test)]
38use ::serde_json as _;
39
/// Version (0.2.0) of this codec's encoding format. It is stored both in the
/// codec configuration (as `_version`) and in the header of every buffer
/// produced by `compress`.
type Sz3CodecVersion = StaticCodecVersion<0, 2, 0>;
41
// NOTE: `schemars` turns the `///` doc comments on this type and its fields
// into JSON-schema descriptions, so editing them changes the published schema.
// Implementation notes therefore use plain `//` comments.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
// serde cannot deny unknown fields because of the flatten
#[schemars(deny_unknown_fields)]
/// Codec providing compression using SZ3
pub struct Sz3Codec {
    /// Predictor
    // When omitted, this defaults to `default_predictor()`, i.e. interpolation
    // combined with the Lorenzo predictor. An explicit `None` makes
    // `compress` use `sz3::CompressionAlgorithm::NoPrediction`.
    #[serde(default = "default_predictor")]
    pub predictor: Option<Sz3Predictor>,
    /// SZ3 error bound
    // Flattened: the `eb_mode` tag and the `eb_*` parameters appear at the top
    // level of the codec config instead of in a nested object.
    #[serde(flatten)]
    pub error_bound: Sz3ErrorBound,
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    #[serde(default, rename = "_version")]
    pub version: Sz3CodecVersion,
}
57
/// SZ3 error bound
// Internally tagged by `eb_mode`. The renamed variant tags and `eb_*` field
// names below form the config's serialized format. The `///` comments become
// JSON-schema descriptions (schemars), so notes use `//` instead.
// Each variant is translated 1:1 into an `sz3::ErrorBound` when compressing.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "eb_mode")]
#[serde(deny_unknown_fields)]
pub enum Sz3ErrorBound {
    /// Errors are bounded by *both* the absolute and relative error, i.e. by
    /// whichever bound is stricter
    #[serde(rename = "abs-and-rel")]
    AbsoluteAndRelative {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Errors are bounded by *either* the absolute or relative error, i.e. by
    /// whichever bound is weaker
    #[serde(rename = "abs-or-rel")]
    AbsoluteOrRelative {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Absolute error bound
    #[serde(rename = "abs")]
    Absolute {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
    },
    /// Relative error bound
    #[serde(rename = "rel")]
    Relative {
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    // NOTE(review): the Rust variant is spelled `PS2NR`, but it serializes as
    // `psnr` and maps to `sz3::ErrorBound::PSNR`. Renaming the variant would
    // be a breaking API change.
    /// Peak signal to noise ratio error bound
    #[serde(rename = "psnr")]
    PS2NR {
        /// Peak signal to noise ratio error bound
        #[serde(rename = "eb_psnr")]
        psnr: f64,
    },
    /// Peak L2 norm error bound
    #[serde(rename = "l2")]
    L2Norm {
        /// Peak L2 norm error bound
        #[serde(rename = "eb_l2")]
        l2: f64,
    },
}
114
/// SZ3 predictor
// Each variant serializes as the string given in its `rename`. When
// compressing, it maps to an `sz3::CompressionAlgorithm`:
// - `Interpolation` and `InterpolationLorenzo` map to their namesakes.
// - Every other variant maps to `LorenzoRegression` with the corresponding
//   `lorenzo` / `lorenzo_second_order` / `regression` flags enabled.
// The `///` comments become JSON-schema descriptions (schemars).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub enum Sz3Predictor {
    /// Interpolation
    #[serde(rename = "interpolation")]
    Interpolation,
    /// Interpolation + Lorenzo predictor
    #[serde(rename = "interpolation-lorenzo")]
    InterpolationLorenzo,
    /// 1st order regression
    #[serde(rename = "regression")]
    Regression,
    /// 2nd order Lorenzo predictor
    #[serde(rename = "lorenzo2")]
    LorenzoSecondOrder,
    /// 2nd order Lorenzo predictor + 1st order regression
    #[serde(rename = "lorenzo2-regression")]
    LorenzoSecondOrderRegression,
    /// 1st order Lorenzo predictor
    #[serde(rename = "lorenzo")]
    Lorenzo,
    /// 1st order Lorenzo predictor + 1st order regression
    #[serde(rename = "lorenzo-regression")]
    LorenzoRegression,
    /// 1st+2nd order Lorenzo predictor
    #[serde(rename = "lorenzo-lorenzo2")]
    LorenzoFirstSecondOrder,
    /// 1st+2nd order Lorenzo predictor + 1st order regression
    #[serde(rename = "lorenzo-lorenzo2-regression")]
    LorenzoFirstSecondOrderRegression,
}
147
148#[expect(clippy::unnecessary_wraps)]
149const fn default_predictor() -> Option<Sz3Predictor> {
150    Some(Sz3Predictor::InterpolationLorenzo)
151}
152
153impl Codec for Sz3Codec {
154    type Error = Sz3CodecError;
155
156    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
157        match data {
158            AnyCowArray::U8(data) => Ok(AnyArray::U8(
159                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
160                    .into_dyn(),
161            )),
162            AnyCowArray::I8(data) => Ok(AnyArray::U8(
163                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
164                    .into_dyn(),
165            )),
166            AnyCowArray::U16(data) => Ok(AnyArray::U8(
167                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
168                    .into_dyn(),
169            )),
170            AnyCowArray::I16(data) => Ok(AnyArray::U8(
171                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
172                    .into_dyn(),
173            )),
174            AnyCowArray::U32(data) => Ok(AnyArray::U8(
175                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
176                    .into_dyn(),
177            )),
178            AnyCowArray::I32(data) => Ok(AnyArray::U8(
179                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
180                    .into_dyn(),
181            )),
182            AnyCowArray::U64(data) => Ok(AnyArray::U8(
183                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
184                    .into_dyn(),
185            )),
186            AnyCowArray::I64(data) => Ok(AnyArray::U8(
187                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
188                    .into_dyn(),
189            )),
190            AnyCowArray::F32(data) => Ok(AnyArray::U8(
191                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
192                    .into_dyn(),
193            )),
194            AnyCowArray::F64(data) => Ok(AnyArray::U8(
195                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
196                    .into_dyn(),
197            )),
198            encoded => Err(Sz3CodecError::UnsupportedDtype(encoded.dtype())),
199        }
200    }
201
202    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
203        let AnyCowArray::U8(encoded) = encoded else {
204            return Err(Sz3CodecError::EncodedDataNotBytes {
205                dtype: encoded.dtype(),
206            });
207        };
208
209        if !matches!(encoded.shape(), [_]) {
210            return Err(Sz3CodecError::EncodedDataNotOneDimensional {
211                shape: encoded.shape().to_vec(),
212            });
213        }
214
215        decompress(&AnyCowArray::U8(encoded).as_bytes())
216    }
217
218    fn decode_into(
219        &self,
220        encoded: AnyArrayView,
221        decoded: AnyArrayViewMut,
222    ) -> Result<(), Self::Error> {
223        let AnyArrayView::U8(encoded) = encoded else {
224            return Err(Sz3CodecError::EncodedDataNotBytes {
225                dtype: encoded.dtype(),
226            });
227        };
228
229        if !matches!(encoded.shape(), [_]) {
230            return Err(Sz3CodecError::EncodedDataNotOneDimensional {
231                shape: encoded.shape().to_vec(),
232            });
233        }
234
235        decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
236    }
237}
238
impl StaticCodec for Sz3Codec {
    // Identifier of this codec in the numcodecs API.
    const CODEC_ID: &'static str = "sz3.rs";

    // The codec is its own configuration: the serde representation of
    // `Sz3Codec` *is* the config, so no separate config type is needed.
    type Config<'de> = Self;

    // Since the config is the codec, constructing from it is the identity.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    // Borrows `self` as the config view.
    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
        StaticCodecConfig::from(self)
    }
}
252
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`Sz3Codec`].
pub enum Sz3CodecError {
    /// [`Sz3Codec`] does not support the dtype
    #[error("Sz3 does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`Sz3Codec`] failed to encode the header
    #[error("Sz3 failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// [`Sz3Codec`] cannot encode an array of `shape`
    ///
    /// This is raised when SZ3 rejects an array shape. It is currently also
    /// reported when SZ3 rejects the shape of the output array in
    /// [`decompress_into`].
    #[error("Sz3 cannot encode an array of shape {shape:?}")]
    InvalidEncodeShape {
        /// Opaque source error
        source: Sz3CodingError,
        /// The shape of the array that SZ3 rejected
        shape: Vec<usize>,
    },
    /// [`Sz3Codec`] failed to encode the data
    #[error("Sz3 failed to encode the data")]
    Sz3EncodeFailed {
        /// Opaque source error
        source: Sz3CodingError,
    },
    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
    /// an array of a different dtype
    #[error(
        "Sz3 can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
    /// an array of a different shape
    #[error(
        "Sz3 can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`Sz3Codec`] failed to decode the header
    #[error("Sz3 failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// [`Sz3Codec`] failed to decode the data
    #[error("Sz3 failed to decode the data")]
    Sz3DecodeFailed {
        /// Opaque source error
        source: Sz3CodingError,
    },
    /// [`Sz3Codec`] decoded an invalid array shape header which does not fit
    /// the decoded data
    #[error("Sz3 decoded an invalid array shape header which does not fit the decoded data")]
    DecodeInvalidShapeHeader {
        /// Source error
        // `#[from]` lets `?` convert `ndarray::ShapeError`s directly.
        #[from]
        source: ShapeError,
    },
    /// [`Sz3Codec`] cannot decode into the provided array
    #[error("Sz3 cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
325
#[derive(Debug, Error)]
// `transparent` forwards `Display` and `source()` to the wrapped error.
#[error(transparent)]
/// Opaque error for when encoding or decoding the header fails
// Wraps the `postcard` error without exposing the `postcard` type in the
// public API.
pub struct Sz3HeaderError(postcard::Error);
330
#[derive(Debug, Error)]
// `transparent` forwards `Display` and `source()` to the wrapped error.
#[error(transparent)]
/// Opaque error for when encoding or decoding with SZ3 fails
// Wraps the `sz3` crate's error without exposing that type in the public API.
pub struct Sz3CodingError(sz3::SZ3Error);
335
336#[expect(clippy::needless_pass_by_value, clippy::too_many_lines)]
337/// Compresses the input `data` array using SZ3, which consists of an optional
338/// `predictor`, an `error_bound`, an optional `encoder`, and an optional
339/// `lossless` compressor.
340///
341/// # Errors
342///
343/// Errors with
344/// - [`Sz3CodecError::HeaderEncodeFailed`] if encoding the header failed
345/// - [`Sz3CodecError::InvalidEncodeShape`] if the array shape is invalid
346/// - [`Sz3CodecError::Sz3EncodeFailed`] if encoding failed with an opaque error
347pub fn compress<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
348    data: ArrayBase<S, D>,
349    predictor: Option<&Sz3Predictor>,
350    error_bound: &Sz3ErrorBound,
351) -> Result<Vec<u8>, Sz3CodecError> {
352    let mut encoded_bytes = postcard::to_extend(
353        &CompressionHeader {
354            dtype: <T as Sz3Element>::DTYPE,
355            shape: Cow::Borrowed(data.shape()),
356            version: StaticCodecVersion,
357        },
358        Vec::new(),
359    )
360    .map_err(|err| Sz3CodecError::HeaderEncodeFailed {
361        source: Sz3HeaderError(err),
362    })?;
363
364    // sz3::DimensionedDataBuilder cannot handle zero-length dimensions
365    if data.is_empty() {
366        return Ok(encoded_bytes);
367    }
368
369    #[expect(clippy::option_if_let_else)]
370    let data_cow = match data.as_slice() {
371        Some(data) => Cow::Borrowed(data),
372        None => Cow::Owned(data.iter().copied().collect()),
373    };
374    let mut builder = sz3::DimensionedData::build(&data_cow);
375
376    for length in data.shape() {
377        // Sz3 ignores dimensions of length 1 and panics on length zero
378        // Since they carry no information for Sz3 and we already encode them
379        //  in our custom header, we just skip them here
380        if *length > 1 {
381            builder = builder
382                .dim(*length)
383                .map_err(|err| Sz3CodecError::InvalidEncodeShape {
384                    source: Sz3CodingError(err),
385                    shape: data.shape().to_vec(),
386                })?;
387        }
388    }
389
390    if data.len() == 1 {
391        // If there is only one element, all dimensions will have been skipped,
392        //  so we explicitly encode one dimension of size 1 here
393        builder = builder
394            .dim(1)
395            .map_err(|err| Sz3CodecError::InvalidEncodeShape {
396                source: Sz3CodingError(err),
397                shape: data.shape().to_vec(),
398            })?;
399    }
400
401    let data = builder
402        .finish()
403        .map_err(|err| Sz3CodecError::InvalidEncodeShape {
404            source: Sz3CodingError(err),
405            shape: data.shape().to_vec(),
406        })?;
407
408    // configure the error bound
409    let error_bound = match error_bound {
410        Sz3ErrorBound::AbsoluteAndRelative { abs, rel } => sz3::ErrorBound::AbsoluteAndRelative {
411            absolute_bound: *abs,
412            relative_bound: *rel,
413        },
414        Sz3ErrorBound::AbsoluteOrRelative { abs, rel } => sz3::ErrorBound::AbsoluteOrRelative {
415            absolute_bound: *abs,
416            relative_bound: *rel,
417        },
418        Sz3ErrorBound::Absolute { abs } => sz3::ErrorBound::Absolute(*abs),
419        Sz3ErrorBound::Relative { rel } => sz3::ErrorBound::Relative(*rel),
420        Sz3ErrorBound::PS2NR { psnr } => sz3::ErrorBound::PSNR(*psnr),
421        Sz3ErrorBound::L2Norm { l2 } => sz3::ErrorBound::L2Norm(*l2),
422    };
423    let mut config = sz3::Config::new(error_bound);
424
425    // configure the predictor (compression algorithm)
426    let predictor = match predictor {
427        Some(Sz3Predictor::Interpolation) => sz3::CompressionAlgorithm::Interpolation,
428        Some(Sz3Predictor::InterpolationLorenzo) => sz3::CompressionAlgorithm::InterpolationLorenzo,
429        Some(Sz3Predictor::Regression) => sz3::CompressionAlgorithm::LorenzoRegression {
430            lorenzo: false,
431            lorenzo_second_order: false,
432            regression: true,
433        },
434        Some(Sz3Predictor::LorenzoSecondOrder) => sz3::CompressionAlgorithm::LorenzoRegression {
435            lorenzo: false,
436            lorenzo_second_order: true,
437            regression: false,
438        },
439        Some(Sz3Predictor::LorenzoSecondOrderRegression) => {
440            sz3::CompressionAlgorithm::LorenzoRegression {
441                lorenzo: false,
442                lorenzo_second_order: true,
443                regression: true,
444            }
445        }
446        Some(Sz3Predictor::Lorenzo) => sz3::CompressionAlgorithm::LorenzoRegression {
447            lorenzo: true,
448            lorenzo_second_order: false,
449            regression: false,
450        },
451        Some(Sz3Predictor::LorenzoRegression) => sz3::CompressionAlgorithm::LorenzoRegression {
452            lorenzo: true,
453            lorenzo_second_order: false,
454            regression: true,
455        },
456        Some(Sz3Predictor::LorenzoFirstSecondOrder) => {
457            sz3::CompressionAlgorithm::LorenzoRegression {
458                lorenzo: true,
459                lorenzo_second_order: true,
460                regression: false,
461            }
462        }
463        Some(Sz3Predictor::LorenzoFirstSecondOrderRegression) => {
464            sz3::CompressionAlgorithm::LorenzoRegression {
465                lorenzo: true,
466                lorenzo_second_order: true,
467                regression: true,
468            }
469        }
470        None => sz3::CompressionAlgorithm::NoPrediction,
471    };
472    config = config.compression_algorithm(predictor);
473
474    sz3::compress_into_with_config(&data, &config, &mut encoded_bytes).map_err(|err| {
475        Sz3CodecError::Sz3EncodeFailed {
476            source: Sz3CodingError(err),
477        }
478    })?;
479
480    Ok(encoded_bytes)
481}
482
483/// Decompresses the `encoded` data into an array using SZ3.
484///
485/// # Errors
486///
487/// Errors with
488/// - [`Sz3CodecError::HeaderDecodeFailed`] if decoding the header failed
489/// - [`Sz3CodecError::Sz3DecodeFailed`] if decoding failed with an opaque error
490pub fn decompress(encoded: &[u8]) -> Result<AnyArray, Sz3CodecError> {
491    fn decompress_typed<T: Sz3Element>(
492        encoded: &[u8],
493        shape: &[usize],
494    ) -> Result<Array<T, IxDyn>, Sz3CodecError> {
495        if shape.iter().copied().any(|s| s == 0) {
496            return Ok(Array::from_shape_vec(shape, Vec::new())?);
497        }
498
499        let (_config, decompressed) =
500            sz3::decompress(encoded).map_err(|err| Sz3CodecError::Sz3DecodeFailed {
501                source: Sz3CodingError(err),
502            })?;
503
504        Ok(Array::from_shape_vec(shape, decompressed.into_data())?)
505    }
506
507    let (header, data) =
508        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
509            Sz3CodecError::HeaderDecodeFailed {
510                source: Sz3HeaderError(err),
511            }
512        })?;
513
514    let decoded = match header.dtype {
515        Sz3DType::U8 => AnyArray::U8(decompress_typed(data, &header.shape)?),
516        Sz3DType::I8 => AnyArray::I8(decompress_typed(data, &header.shape)?),
517        Sz3DType::U16 => AnyArray::U16(decompress_typed(data, &header.shape)?),
518        Sz3DType::I16 => AnyArray::I16(decompress_typed(data, &header.shape)?),
519        Sz3DType::U32 => AnyArray::U32(decompress_typed(data, &header.shape)?),
520        Sz3DType::I32 => AnyArray::I32(decompress_typed(data, &header.shape)?),
521        Sz3DType::U64 => AnyArray::U64(decompress_typed(data, &header.shape)?),
522        Sz3DType::I64 => AnyArray::I64(decompress_typed(data, &header.shape)?),
523        Sz3DType::F32 => AnyArray::F32(decompress_typed(data, &header.shape)?),
524        Sz3DType::F64 => AnyArray::F64(decompress_typed(data, &header.shape)?),
525    };
526
527    Ok(decoded)
528}
529
/// Decompresses the `encoded` data into a `decoded` array using SZ3.
///
/// The header written by [`compress`] must match the dtype and shape of
/// `decoded`. If the recorded shape is empty, `decoded` is left untouched.
///
/// # Errors
///
/// Errors with
/// - [`Sz3CodecError::HeaderDecodeFailed`] if decoding the header failed
/// - [`Sz3CodecError::MismatchedDecodeIntoArray`] if the `decoded` array is of
///   the wrong dtype or shape
/// - [`Sz3CodecError::InvalidEncodeShape`] if SZ3 rejects the shape of the
///   `decoded` array
/// - [`Sz3CodecError::Sz3DecodeFailed`] if decoding failed with an opaque error
pub fn decompress_into(encoded: &[u8], decoded: AnyArrayViewMut) -> Result<(), Sz3CodecError> {
    // Decodes the SZ3 payload directly into `decoded`, whose shape has already
    // been checked against the header by the caller below.
    fn decompress_into_typed<T: Sz3Element>(
        encoded: &[u8],
        mut decoded: ArrayViewMut<T, IxDyn>,
    ) -> Result<(), Sz3CodecError> {
        // `compress` writes no SZ3 payload for empty arrays, so nothing to do
        if decoded.is_empty() {
            return Ok(());
        }

        let decoded_shape = decoded.shape().to_vec();

        // `with_slice_mut` (from `ArrayDataMutExt`) hands the closure a mutable
        // slice of the elements. The closure's `Result` is propagated by `?`.
        decoded.with_slice_mut(|mut decoded| {
            let decoded_len = decoded.len();

            let mut builder = sz3::DimensionedData::build_mut(&mut decoded);

            // The dimensions given to SZ3 must mirror the ones `compress` used.
            for length in &decoded_shape {
                // Sz3 ignores dimensions of length 1 and panics on length zero
                // Since they carry no information for Sz3 and we already encode them
                //  in our custom header, we just skip them here
                if *length > 1 {
                    builder = builder
                        .dim(*length)
                        // FIXME: different error code
                        // (adding a new variant to the exhaustive public
                        // `Sz3CodecError` enum would be a breaking change)
                        .map_err(|err| Sz3CodecError::InvalidEncodeShape {
                            source: Sz3CodingError(err),
                            shape: decoded_shape.clone(),
                        })?;
                }
            }

            if decoded_len == 1 {
                // If there is only one element, all dimensions will have been skipped,
                //  so we explicitly encode one dimension of size 1 here
                builder = builder
                    .dim(1)
                    // FIXME: different error code
                    .map_err(|err| Sz3CodecError::InvalidEncodeShape {
                        source: Sz3CodingError(err),
                        shape: decoded_shape.clone(),
                    })?;
            }

            let mut decoded = builder
                .finish()
                // FIXME: different error code
                .map_err(|err| Sz3CodecError::InvalidEncodeShape {
                    source: Sz3CodingError(err),
                    shape: decoded_shape,
                })?;

            sz3::decompress_into_dimensioned(encoded, &mut decoded).map_err(|err| {
                Sz3CodecError::Sz3DecodeFailed {
                    source: Sz3CodingError(err),
                }
            })
        })?;

        Ok(())
    }

    // Split the postcard header off the front. `data` is the remaining SZ3 payload.
    let (header, data) =
        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
            Sz3CodecError::HeaderDecodeFailed {
                source: Sz3HeaderError(err),
            }
        })?;

    // The output array must have exactly the shape recorded at encode time.
    if decoded.shape() != &*header.shape {
        return Err(Sz3CodecError::MismatchedDecodeIntoArray {
            source: AnyArrayAssignError::ShapeMismatch {
                src: header.shape.into_owned(),
                dst: decoded.shape().to_vec(),
            },
        });
    }

    // Dispatch on the (output dtype, encoded dtype) pair. Any mismatch is an error.
    match (decoded, header.dtype) {
        (AnyArrayViewMut::U8(decoded), Sz3DType::U8) => decompress_into_typed(data, decoded),
        (AnyArrayViewMut::I8(decoded), Sz3DType::I8) => decompress_into_typed(data, decoded),
        (AnyArrayViewMut::U16(decoded), Sz3DType::U16) => decompress_into_typed(data, decoded),
        (AnyArrayViewMut::I16(decoded), Sz3DType::I16) => decompress_into_typed(data, decoded),
        (AnyArrayViewMut::U32(decoded), Sz3DType::U32) => decompress_into_typed(data, decoded),
        (AnyArrayViewMut::I32(decoded), Sz3DType::I32) => decompress_into_typed(data, decoded),
        (AnyArrayViewMut::U64(decoded), Sz3DType::U64) => decompress_into_typed(data, decoded),
        (AnyArrayViewMut::I64(decoded), Sz3DType::I64) => decompress_into_typed(data, decoded),
        (AnyArrayViewMut::F32(decoded), Sz3DType::F32) => decompress_into_typed(data, decoded),
        (AnyArrayViewMut::F64(decoded), Sz3DType::F64) => decompress_into_typed(data, decoded),
        (decoded, dtype) => Err(Sz3CodecError::MismatchedDecodeIntoArray {
            source: AnyArrayAssignError::DTypeMismatch {
                src: dtype.into_dtype(),
                dst: decoded.dtype(),
            },
        }),
    }
}
635
636/// Array element types which can be compressed with SZ3.
637pub trait Sz3Element: Copy + sz3::SZ3Compressible + ArrayDType {
638    /// The dtype representation of the type
639    const DTYPE: Sz3DType;
640}
641
642impl Sz3Element for u8 {
643    const DTYPE: Sz3DType = Sz3DType::U8;
644}
645
646impl Sz3Element for i8 {
647    const DTYPE: Sz3DType = Sz3DType::I8;
648}
649
650impl Sz3Element for u16 {
651    const DTYPE: Sz3DType = Sz3DType::U16;
652}
653
654impl Sz3Element for i16 {
655    const DTYPE: Sz3DType = Sz3DType::I16;
656}
657
658impl Sz3Element for u32 {
659    const DTYPE: Sz3DType = Sz3DType::U32;
660}
661
662impl Sz3Element for i32 {
663    const DTYPE: Sz3DType = Sz3DType::I32;
664}
665
666impl Sz3Element for u64 {
667    const DTYPE: Sz3DType = Sz3DType::U64;
668}
669
670impl Sz3Element for i64 {
671    const DTYPE: Sz3DType = Sz3DType::I64;
672}
673
674impl Sz3Element for f32 {
675    const DTYPE: Sz3DType = Sz3DType::F32;
676}
677
678impl Sz3Element for f64 {
679    const DTYPE: Sz3DType = Sz3DType::F64;
680}
681
// Header that `compress` postcard-encodes in front of the SZ3 payload, and
// that `decompress` / `decompress_into` read back. The field order is part of
// the encoded format, so do not reorder the fields.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    // Element type of the compressed array
    dtype: Sz3DType,
    // Full array shape, including length-0 and length-1 dimensions that are
    // not passed to SZ3. `compress` borrows it from the input array.
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    // Encoding format version of the codec
    version: Sz3CodecVersion,
}
689
/// Dtypes that SZ3 can compress and decompress
// NOTE: this enum is stored in every `CompressionHeader`. Postcard encodes
// enum variants by their index, so reordering or inserting variants would
// break decoding of existing data; append new variants at the end.
// The `alias`es only affect deserialization: they additionally accept the
// numpy-style names (e.g. "uint8").
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[expect(missing_docs)]
pub enum Sz3DType {
    #[serde(rename = "u8", alias = "uint8")]
    U8,
    #[serde(rename = "u16", alias = "uint16")]
    U16,
    #[serde(rename = "u32", alias = "uint32")]
    U32,
    #[serde(rename = "u64", alias = "uint64")]
    U64,
    #[serde(rename = "i8", alias = "int8")]
    I8,
    #[serde(rename = "i16", alias = "int16")]
    I16,
    #[serde(rename = "i32", alias = "int32")]
    I32,
    #[serde(rename = "i64", alias = "int64")]
    I64,
    #[serde(rename = "f32", alias = "float32")]
    F32,
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
715
716impl Sz3DType {
717    /// Get the corresponding [`AnyArrayDType`]
718    #[must_use]
719    pub const fn into_dtype(self) -> AnyArrayDType {
720        match self {
721            Self::U8 => AnyArrayDType::U8,
722            Self::U16 => AnyArrayDType::U16,
723            Self::U32 => AnyArrayDType::U32,
724            Self::U64 => AnyArrayDType::U64,
725            Self::I8 => AnyArrayDType::I8,
726            Self::I16 => AnyArrayDType::I16,
727            Self::I32 => AnyArrayDType::I32,
728            Self::I64 => AnyArrayDType::I64,
729            Self::F32 => AnyArrayDType::F32,
730            Self::F64 => AnyArrayDType::F64,
731        }
732    }
733}
734
735impl fmt::Display for Sz3DType {
736    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
737        fmt.write_str(match self {
738            Self::U8 => "u8",
739            Self::U16 => "u16",
740            Self::U32 => "u32",
741            Self::U64 => "u64",
742            Self::I8 => "i8",
743            Self::I16 => "i16",
744            Self::I32 => "i32",
745            Self::I64 => "i64",
746            Self::F32 => "f32",
747            Self::F64 => "f64",
748        })
749    }
750}
751
#[cfg(test)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    // An array with a zero-length dimension round-trips through the
    // header-only encoding and keeps its full shape and dtype.
    #[test]
    fn zero_length() -> Result<(), Sz3CodecError> {
        let encoded = compress(
            Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])?,
            default_predictor().as_ref(),
            &Sz3ErrorBound::L2Norm { l2: 27.0 },
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);

        Ok(())
    }

    // Length-1 dimensions are skipped when talking to SZ3 but restored from
    // the header, via both `decompress` and `decompress_into`.
    #[test]
    fn one_dimension() -> Result<(), Sz3CodecError> {
        let data = Array::from_shape_vec([2_usize, 1, 2, 1].as_slice(), vec![1, 2, 3, 4])?;

        let encoded = compress(
            data.view(),
            default_predictor().as_ref(),
            &Sz3ErrorBound::Absolute { abs: 0.1 },
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded, AnyArray::I32(data.clone()));

        let mut decoded = Array::zeros(data.dim());
        decompress_into(&encoded, AnyArrayViewMut::I32(decoded.view_mut()))?;

        assert_eq!(decoded, data);

        Ok(())
    }

    // Very short 1D inputs (0 to 4 elements) round-trip exactly, covering the
    // empty and single-element special cases.
    #[test]
    fn small_state() -> Result<(), Sz3CodecError> {
        for data in [
            &[][..],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ] {
            let encoded = compress(
                ArrayView1::from(data),
                default_predictor().as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let decoded = decompress(&encoded)?;

            assert_eq!(
                decoded,
                AnyArray::F64(Array1::from_vec(data.to_vec()).into_dyn())
            );

            let mut decoded = Array::zeros([data.len()]);
            decompress_into(
                &encoded,
                AnyArrayViewMut::F64(decoded.view_mut().into_dyn()),
            )?;

            assert_eq!(decoded, Array1::from_vec(data.to_vec()));
        }

        Ok(())
    }

    // Smoke test: every predictor setting (including none) compresses and
    // decompresses without error. The decoded values are not compared.
    #[test]
    fn all_predictors() -> Result<(), Sz3CodecError> {
        let data = Array::linspace(-42.0, 42.0, 85);

        for predictor in [
            None,
            Some(Sz3Predictor::Interpolation),
            Some(Sz3Predictor::InterpolationLorenzo),
            Some(Sz3Predictor::Regression),
            Some(Sz3Predictor::LorenzoSecondOrder),
            Some(Sz3Predictor::LorenzoSecondOrderRegression),
            Some(Sz3Predictor::Lorenzo),
            Some(Sz3Predictor::LorenzoRegression),
            Some(Sz3Predictor::LorenzoFirstSecondOrder),
            Some(Sz3Predictor::LorenzoFirstSecondOrderRegression),
        ] {
            let encoded = compress(
                data.view(),
                predictor.as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let _decoded = decompress(&encoded)?;

            let mut decoded = Array::zeros(data.dim());
            decompress_into(
                &encoded,
                AnyArrayViewMut::F64(decoded.view_mut().into_dyn()),
            )?;
        }

        Ok(())
    }

    // Smoke test: every supported element type round-trips through both
    // decode paths without error. The decoded values are not compared.
    #[test]
    fn all_dtypes() -> Result<(), Sz3CodecError> {
        fn compress_decompress<T: Sz3Element + num_traits::identities::Zero>(
            iter: impl Clone + IntoIterator<Item = T>,
            view_mut: impl for<'a> Fn(ArrayViewMut<'a, T, IxDyn>) -> AnyArrayViewMut<'a>,
        ) -> Result<(), Sz3CodecError> {
            let encoded = compress(
                Array::from_iter(iter.clone()).view(),
                default_predictor().as_ref(),
                &Sz3ErrorBound::Absolute { abs: 2.0 },
            )?;
            let _decoded = decompress(&encoded)?;

            let mut decoded = Array::<T, _>::zeros([iter.into_iter().count()]).into_dyn();
            decompress_into(&encoded, view_mut(decoded.view_mut().into_dyn()))?;

            Ok(())
        }

        compress_decompress(0_u8..42, |x| AnyArrayViewMut::U8(x))?;
        compress_decompress(-42_i8..42, |x| AnyArrayViewMut::I8(x))?;
        compress_decompress(0_u16..42, |x| AnyArrayViewMut::U16(x))?;
        compress_decompress(-42_i16..42, |x| AnyArrayViewMut::I16(x))?;
        compress_decompress(0_u32..42, |x| AnyArrayViewMut::U32(x))?;
        compress_decompress(-42_i32..42, |x| AnyArrayViewMut::I32(x))?;
        compress_decompress(0_u64..42, |x| AnyArrayViewMut::U64(x))?;
        compress_decompress(-42_i64..42, |x| AnyArrayViewMut::I64(x))?;
        compress_decompress((-42_i16..42).map(f32::from), |x| AnyArrayViewMut::F32(x))?;
        compress_decompress((-42_i16..42).map(f64::from), |x| AnyArrayViewMut::F64(x))?;

        Ok(())
    }
}