numcodecs_sz3/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.85.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-sz3
10//! [crates.io]: https://crates.io/crates/numcodecs-sz3
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-sz3
13//! [docs.rs]: https://docs.rs/numcodecs-sz3/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_sz3
17//!
18//! SZ3 codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // embedded-io
21
22use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
25use numcodecs::{
26    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use thiserror::Error;
32
// Only imported to explicitly enable the `no_wasm_shim` feature of the
// `zstd-sys` crate on behalf of sz3-sys; nothing from it is used directly
35use ::zstd_sys as _;
36
37#[cfg(test)]
38use ::serde_json as _;
39
// Encoding format version (0.1.0) of this codec; it is stored both in the
// codec config (`Sz3Codec::version`) and in every encoded `CompressionHeader`
type Sz3CodecVersion = StaticCodecVersion<0, 1, 0>;
41
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
// serde cannot deny unknown fields because of the flatten
// (unknown fields are therefore only denied at the JSON schema level)
#[schemars(deny_unknown_fields)]
/// Codec providing compression using SZ3
pub struct Sz3Codec {
    /// Predictor
    // Falls back to `default_predictor()` when the key is missing from the
    // configuration
    #[serde(default = "default_predictor")]
    pub predictor: Option<Sz3Predictor>,
    /// SZ3 error bound
    // Flattened: the `eb_mode` tag and the `eb_*` fields of `Sz3ErrorBound`
    // sit next to `predictor` in the serialized configuration
    #[serde(flatten)]
    pub error_bound: Sz3ErrorBound,
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    // Serialized under the `_version` key and defaulted when it is missing
    #[serde(default, rename = "_version")]
    pub version: Sz3CodecVersion,
}
57
/// SZ3 error bound
// Internally tagged: the variant is selected by the value of the `eb_mode`
// field, and each variant's parameters are stored under `eb_*` keys
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "eb_mode")]
#[serde(deny_unknown_fields)]
pub enum Sz3ErrorBound {
    /// Errors are bounded by *both* the absolute and relative error, i.e. by
    /// whichever bound is stricter
    #[serde(rename = "abs-and-rel")]
    AbsoluteAndRelative {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Errors are bounded by *either* the absolute or relative error, i.e. by
    /// whichever bound is weaker
    #[serde(rename = "abs-or-rel")]
    AbsoluteOrRelative {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Absolute error bound
    #[serde(rename = "abs")]
    Absolute {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
    },
    /// Relative error bound
    #[serde(rename = "rel")]
    Relative {
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Peak signal to noise ratio error bound
    // NOTE(review): the variant is spelled `PS2NR` although it is serialized
    // as "psnr"; the name is part of the public API and is kept as is
    #[serde(rename = "psnr")]
    PS2NR {
        /// Peak signal to noise ratio error bound
        #[serde(rename = "eb_psnr")]
        psnr: f64,
    },
    /// Peak L2 norm error bound
    #[serde(rename = "l2")]
    L2Norm {
        /// Peak L2 norm error bound
        #[serde(rename = "eb_l2")]
        l2: f64,
    },
}
114
/// SZ3 predictor
// Each variant is serialized as the kebab-case string given in its
// `#[serde(rename = ...)]` attribute
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub enum Sz3Predictor {
    /// Linear interpolation
    #[serde(rename = "linear-interpolation")]
    LinearInterpolation,
    /// Cubic interpolation
    #[serde(rename = "cubic-interpolation")]
    CubicInterpolation,
    /// Linear interpolation + Lorenzo predictor
    #[serde(rename = "linear-interpolation-lorenzo")]
    LinearInterpolationLorenzo,
    /// Cubic interpolation + Lorenzo predictor
    #[serde(rename = "cubic-interpolation-lorenzo")]
    CubicInterpolationLorenzo,
    /// 1st order regression
    #[serde(rename = "regression")]
    Regression,
    /// 2nd order regression
    #[serde(rename = "regression2")]
    RegressionSecondOrder,
    /// 1st+2nd order regression
    #[serde(rename = "regression-regression2")]
    RegressionFirstSecondOrder,
    /// 2nd order Lorenzo predictor
    #[serde(rename = "lorenzo2")]
    LorenzoSecondOrder,
    /// 2nd order Lorenzo predictor + 2nd order regression
    #[serde(rename = "lorenzo2-regression2")]
    LorenzoSecondOrderRegressionSecondOrder,
    /// 2nd order Lorenzo predictor + 1st order regression
    #[serde(rename = "lorenzo2-regression")]
    LorenzoSecondOrderRegression,
    /// 2nd order Lorenzo predictor + 1st+2nd order regression
    #[serde(rename = "lorenzo2-regression-regression2")]
    LorenzoSecondOrderRegressionFirstSecondOrder,
    /// 1st order Lorenzo predictor
    #[serde(rename = "lorenzo")]
    Lorenzo,
    /// 1st order Lorenzo predictor + 2nd order regression
    #[serde(rename = "lorenzo-regression2")]
    LorenzoRegressionSecondOrder,
    /// 1st order Lorenzo predictor + 1st order regression
    #[serde(rename = "lorenzo-regression")]
    LorenzoRegression,
    /// 1st order Lorenzo predictor + 1st+2nd order regression
    #[serde(rename = "lorenzo-regression-regression2")]
    LorenzoRegressionFirstSecondOrder,
    /// 1st+2nd order Lorenzo predictor
    #[serde(rename = "lorenzo-lorenzo2")]
    LorenzoFirstSecondOrder,
    /// 1st+2nd order Lorenzo predictor + 2nd order regression
    #[serde(rename = "lorenzo-lorenzo2-regression2")]
    LorenzoFirstSecondOrderRegressionSecondOrder,
    /// 1st+2nd order Lorenzo predictor + 1st order regression
    #[serde(rename = "lorenzo-lorenzo2-regression")]
    LorenzoFirstSecondOrderRegression,
    /// 1st+2nd order Lorenzo predictor + 1st+2nd order regression
    #[serde(rename = "lorenzo-lorenzo2-regression-regression2")]
    LorenzoFirstSecondOrderRegressionFirstSecondOrder,
}
177
178#[expect(clippy::unnecessary_wraps)]
179const fn default_predictor() -> Option<Sz3Predictor> {
180    Some(Sz3Predictor::CubicInterpolationLorenzo)
181}
182
183impl Codec for Sz3Codec {
184    type Error = Sz3CodecError;
185
186    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
187        match data {
188            AnyCowArray::I32(data) => Ok(AnyArray::U8(
189                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
190                    .into_dyn(),
191            )),
192            AnyCowArray::I64(data) => Ok(AnyArray::U8(
193                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
194                    .into_dyn(),
195            )),
196            AnyCowArray::F32(data) => Ok(AnyArray::U8(
197                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
198                    .into_dyn(),
199            )),
200            AnyCowArray::F64(data) => Ok(AnyArray::U8(
201                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
202                    .into_dyn(),
203            )),
204            encoded => Err(Sz3CodecError::UnsupportedDtype(encoded.dtype())),
205        }
206    }
207
208    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
209        let AnyCowArray::U8(encoded) = encoded else {
210            return Err(Sz3CodecError::EncodedDataNotBytes {
211                dtype: encoded.dtype(),
212            });
213        };
214
215        if !matches!(encoded.shape(), [_]) {
216            return Err(Sz3CodecError::EncodedDataNotOneDimensional {
217                shape: encoded.shape().to_vec(),
218            });
219        }
220
221        decompress(&AnyCowArray::U8(encoded).as_bytes())
222    }
223
224    fn decode_into(
225        &self,
226        encoded: AnyArrayView,
227        mut decoded: AnyArrayViewMut,
228    ) -> Result<(), Self::Error> {
229        let decoded_in = self.decode(encoded.cow())?;
230
231        Ok(decoded.assign(&decoded_in)?)
232    }
233}
234
// The codec is its own configuration type, so converting from and to the
// configuration does not need any translation step
impl StaticCodec for Sz3Codec {
    // Identifier under which this codec is known
    const CODEC_ID: &'static str = "sz3.rs";

    type Config<'de> = Self;

    // The configuration already is the codec, so it is returned unchanged
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    // Builds the configuration from a borrow of the codec itself
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
248
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`Sz3Codec`].
// Variants with a `source` field expose the underlying error through
// `std::error::Error::source`; the two `#[from]` variants can additionally be
// created from their source error with the `?` operator
pub enum Sz3CodecError {
    /// [`Sz3Codec`] does not support the dtype
    #[error("Sz3 does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`Sz3Codec`] failed to encode the header
    #[error("Sz3 failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// [`Sz3Codec`] cannot encode an array of `shape`
    #[error("Sz3 cannot encode an array of shape {shape:?}")]
    InvalidEncodeShape {
        /// Opaque source error
        source: Sz3CodingError,
        /// The invalid shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`Sz3Codec`] failed to encode the data
    #[error("Sz3 failed to encode the data")]
    Sz3EncodeFailed {
        /// Opaque source error
        source: Sz3CodingError,
    },
    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
    /// an array of a different dtype
    #[error(
        "Sz3 can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
    /// an array of a different shape
    #[error(
        "Sz3 can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`Sz3Codec`] failed to decode the header
    #[error("Sz3 failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// [`Sz3Codec`] decoded an invalid array shape header which does not fit
    /// the decoded data
    #[error("Sz3 decoded an invalid array shape header which does not fit the decoded data")]
    DecodeInvalidShapeHeader {
        /// Source error
        // Raised via `?` when building the decoded array from the header
        // shape fails in `decompress`
        #[from]
        source: ShapeError,
    },
    /// [`Sz3Codec`] cannot decode into the provided array
    #[error("Sz3 cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        // Raised via `?` when assigning the decoded array in `decode_into`
        #[from]
        source: AnyArrayAssignError,
    },
}
315
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding the header fails
// Wraps the postcard error in a private field so that postcard does not
// become part of the public API; `transparent` forwards the error message
// and source of the wrapped error
pub struct Sz3HeaderError(postcard::Error);
320
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding with SZ3 fails
// Wraps the sz3 error in a private field so that the sz3 error type does not
// become part of the public API; `transparent` forwards the error message
// and source of the wrapped error
pub struct Sz3CodingError(sz3::SZ3Error);
325
326#[expect(clippy::needless_pass_by_value, clippy::too_many_lines)]
327/// Compresses the input `data` array using SZ3, which consists of an optional
328/// `predictor`, an `error_bound`, an optional `encoder`, and an optional
329/// `lossless` compressor.
330///
331/// # Errors
332///
333/// Errors with
334/// - [`Sz3CodecError::HeaderEncodeFailed`] if encoding the header failed
335/// - [`Sz3CodecError::InvalidEncodeShape`] if the array shape is invalid
336/// - [`Sz3CodecError::Sz3EncodeFailed`] if encoding failed with an opaque error
337pub fn compress<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
338    data: ArrayBase<S, D>,
339    predictor: Option<&Sz3Predictor>,
340    error_bound: &Sz3ErrorBound,
341) -> Result<Vec<u8>, Sz3CodecError> {
342    let mut encoded_bytes = postcard::to_extend(
343        &CompressionHeader {
344            dtype: <T as Sz3Element>::DTYPE,
345            shape: Cow::Borrowed(data.shape()),
346            version: StaticCodecVersion,
347        },
348        Vec::new(),
349    )
350    .map_err(|err| Sz3CodecError::HeaderEncodeFailed {
351        source: Sz3HeaderError(err),
352    })?;
353
354    // sz3::DimensionedDataBuilder cannot handle zero-length dimensions
355    if data.is_empty() {
356        return Ok(encoded_bytes);
357    }
358
359    #[expect(clippy::option_if_let_else)]
360    let data_cow = match data.as_slice() {
361        Some(data) => Cow::Borrowed(data),
362        None => Cow::Owned(data.iter().copied().collect()),
363    };
364    let mut builder = sz3::DimensionedData::build(&data_cow);
365
366    for length in data.shape() {
367        // Sz3 ignores dimensions of length 1 and panics on length zero
368        // Since they carry no information for Sz3 and we already encode them
369        //  in our custom header, we just skip them here
370        if *length > 1 {
371            builder = builder
372                .dim(*length)
373                .map_err(|err| Sz3CodecError::InvalidEncodeShape {
374                    source: Sz3CodingError(err),
375                    shape: data.shape().to_vec(),
376                })?;
377        }
378    }
379
380    if data.len() == 1 {
381        // If there is only one element, all dimensions will have been skipped,
382        //  so we explicitly encode one dimension of size 1 here
383        builder = builder
384            .dim(1)
385            .map_err(|err| Sz3CodecError::InvalidEncodeShape {
386                source: Sz3CodingError(err),
387                shape: data.shape().to_vec(),
388            })?;
389    }
390
391    let data = builder
392        .finish()
393        .map_err(|err| Sz3CodecError::InvalidEncodeShape {
394            source: Sz3CodingError(err),
395            shape: data.shape().to_vec(),
396        })?;
397
398    // configure the error bound
399    let error_bound = match error_bound {
400        Sz3ErrorBound::AbsoluteAndRelative { abs, rel } => sz3::ErrorBound::AbsoluteAndRelative {
401            absolute_bound: *abs,
402            relative_bound: *rel,
403        },
404        Sz3ErrorBound::AbsoluteOrRelative { abs, rel } => sz3::ErrorBound::AbsoluteOrRelative {
405            absolute_bound: *abs,
406            relative_bound: *rel,
407        },
408        Sz3ErrorBound::Absolute { abs } => sz3::ErrorBound::Absolute(*abs),
409        Sz3ErrorBound::Relative { rel } => sz3::ErrorBound::Relative(*rel),
410        Sz3ErrorBound::PS2NR { psnr } => sz3::ErrorBound::PSNR(*psnr),
411        Sz3ErrorBound::L2Norm { l2 } => sz3::ErrorBound::L2Norm(*l2),
412    };
413    let mut config = sz3::Config::new(error_bound);
414
415    // configure the interpolation mode, if necessary
416    let interpolation = match predictor {
417        Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::LinearInterpolationLorenzo) => {
418            Some(sz3::InterpolationAlgorithm::Linear)
419        }
420        Some(Sz3Predictor::CubicInterpolation | Sz3Predictor::CubicInterpolationLorenzo) => {
421            Some(sz3::InterpolationAlgorithm::Cubic)
422        }
423        Some(
424            Sz3Predictor::Regression
425            | Sz3Predictor::RegressionSecondOrder
426            | Sz3Predictor::RegressionFirstSecondOrder
427            | Sz3Predictor::LorenzoSecondOrder
428            | Sz3Predictor::LorenzoSecondOrderRegressionSecondOrder
429            | Sz3Predictor::LorenzoSecondOrderRegression
430            | Sz3Predictor::LorenzoSecondOrderRegressionFirstSecondOrder
431            | Sz3Predictor::Lorenzo
432            | Sz3Predictor::LorenzoRegressionSecondOrder
433            | Sz3Predictor::LorenzoRegression
434            | Sz3Predictor::LorenzoRegressionFirstSecondOrder
435            | Sz3Predictor::LorenzoFirstSecondOrder
436            | Sz3Predictor::LorenzoFirstSecondOrderRegressionSecondOrder
437            | Sz3Predictor::LorenzoFirstSecondOrderRegression
438            | Sz3Predictor::LorenzoFirstSecondOrderRegressionFirstSecondOrder,
439        )
440        | None => None,
441    };
442    if let Some(interpolation) = interpolation {
443        config = config.interpolation_algorithm(interpolation);
444    }
445
446    // configure the predictor (compression algorithm)
447    let predictor = match predictor {
448        Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::CubicInterpolation) => {
449            sz3::CompressionAlgorithm::Interpolation
450        }
451        Some(
452            Sz3Predictor::LinearInterpolationLorenzo | Sz3Predictor::CubicInterpolationLorenzo,
453        ) => sz3::CompressionAlgorithm::InterpolationLorenzo,
454        Some(Sz3Predictor::RegressionSecondOrder) => sz3::CompressionAlgorithm::LorenzoRegression {
455            lorenzo: false,
456            lorenzo_second_order: false,
457            regression: false,
458            regression_second_order: true,
459            prediction_dimension: None,
460        },
461        Some(Sz3Predictor::Regression) => sz3::CompressionAlgorithm::LorenzoRegression {
462            lorenzo: false,
463            lorenzo_second_order: false,
464            regression: true,
465            regression_second_order: false,
466            prediction_dimension: None,
467        },
468        Some(Sz3Predictor::RegressionFirstSecondOrder) => {
469            sz3::CompressionAlgorithm::LorenzoRegression {
470                lorenzo: false,
471                lorenzo_second_order: false,
472                regression: true,
473                regression_second_order: true,
474                prediction_dimension: None,
475            }
476        }
477        Some(Sz3Predictor::LorenzoSecondOrder) => sz3::CompressionAlgorithm::LorenzoRegression {
478            lorenzo: false,
479            lorenzo_second_order: true,
480            regression: false,
481            regression_second_order: false,
482            prediction_dimension: None,
483        },
484        Some(Sz3Predictor::LorenzoSecondOrderRegressionSecondOrder) => {
485            sz3::CompressionAlgorithm::LorenzoRegression {
486                lorenzo: false,
487                lorenzo_second_order: true,
488                regression: false,
489                regression_second_order: true,
490                prediction_dimension: None,
491            }
492        }
493        Some(Sz3Predictor::LorenzoSecondOrderRegression) => {
494            sz3::CompressionAlgorithm::LorenzoRegression {
495                lorenzo: false,
496                lorenzo_second_order: true,
497                regression: true,
498                regression_second_order: false,
499                prediction_dimension: None,
500            }
501        }
502        Some(Sz3Predictor::LorenzoSecondOrderRegressionFirstSecondOrder) => {
503            sz3::CompressionAlgorithm::LorenzoRegression {
504                lorenzo: false,
505                lorenzo_second_order: true,
506                regression: true,
507                regression_second_order: true,
508                prediction_dimension: None,
509            }
510        }
511        Some(Sz3Predictor::Lorenzo) => sz3::CompressionAlgorithm::LorenzoRegression {
512            lorenzo: true,
513            lorenzo_second_order: false,
514            regression: false,
515            regression_second_order: false,
516            prediction_dimension: None,
517        },
518        Some(Sz3Predictor::LorenzoRegressionSecondOrder) => {
519            sz3::CompressionAlgorithm::LorenzoRegression {
520                lorenzo: true,
521                lorenzo_second_order: false,
522                regression: false,
523                regression_second_order: true,
524                prediction_dimension: None,
525            }
526        }
527        Some(Sz3Predictor::LorenzoRegression) => sz3::CompressionAlgorithm::LorenzoRegression {
528            lorenzo: true,
529            lorenzo_second_order: false,
530            regression: true,
531            regression_second_order: false,
532            prediction_dimension: None,
533        },
534        Some(Sz3Predictor::LorenzoRegressionFirstSecondOrder) => {
535            sz3::CompressionAlgorithm::LorenzoRegression {
536                lorenzo: true,
537                lorenzo_second_order: false,
538                regression: true,
539                regression_second_order: true,
540                prediction_dimension: None,
541            }
542        }
543        Some(Sz3Predictor::LorenzoFirstSecondOrder) => {
544            sz3::CompressionAlgorithm::LorenzoRegression {
545                lorenzo: true,
546                lorenzo_second_order: true,
547                regression: false,
548                regression_second_order: false,
549                prediction_dimension: None,
550            }
551        }
552        Some(Sz3Predictor::LorenzoFirstSecondOrderRegressionSecondOrder) => {
553            sz3::CompressionAlgorithm::LorenzoRegression {
554                lorenzo: true,
555                lorenzo_second_order: true,
556                regression: false,
557                regression_second_order: true,
558                prediction_dimension: None,
559            }
560        }
561        Some(Sz3Predictor::LorenzoFirstSecondOrderRegression) => {
562            sz3::CompressionAlgorithm::LorenzoRegression {
563                lorenzo: true,
564                lorenzo_second_order: true,
565                regression: true,
566                regression_second_order: false,
567                prediction_dimension: None,
568            }
569        }
570        Some(Sz3Predictor::LorenzoFirstSecondOrderRegressionFirstSecondOrder) => {
571            sz3::CompressionAlgorithm::LorenzoRegression {
572                lorenzo: true,
573                lorenzo_second_order: true,
574                regression: true,
575                regression_second_order: true,
576                prediction_dimension: None,
577            }
578        }
579        None => sz3::CompressionAlgorithm::NoPrediction,
580    };
581    config = config.compression_algorithm(predictor);
582
583    // TODO: avoid extra allocation here
584    let compressed = sz3::compress_with_config(&data, &config).map_err(|err| {
585        Sz3CodecError::Sz3EncodeFailed {
586            source: Sz3CodingError(err),
587        }
588    })?;
589    encoded_bytes.extend_from_slice(&compressed);
590
591    Ok(encoded_bytes)
592}
593
594/// Decompresses the `encoded` data into an array.
595///
596/// # Errors
597///
598/// Errors with
599/// - [`Sz3CodecError::HeaderDecodeFailed`] if decoding the header failed
600pub fn decompress(encoded: &[u8]) -> Result<AnyArray, Sz3CodecError> {
601    let (header, data) =
602        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
603            Sz3CodecError::HeaderDecodeFailed {
604                source: Sz3HeaderError(err),
605            }
606        })?;
607
608    let decoded = if header.shape.iter().copied().product::<usize>() == 0 {
609        match header.dtype {
610            Sz3DType::I32 => {
611                AnyArray::I32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
612            }
613            Sz3DType::I64 => {
614                AnyArray::I64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
615            }
616            Sz3DType::F32 => {
617                AnyArray::F32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
618            }
619            Sz3DType::F64 => {
620                AnyArray::F64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
621            }
622        }
623    } else {
624        // TODO: avoid extra allocation here
625        match header.dtype {
626            Sz3DType::I32 => AnyArray::I32(Array::from_shape_vec(
627                &*header.shape,
628                Vec::from(sz3::decompress(data).1.data()),
629            )?),
630            Sz3DType::I64 => AnyArray::I64(Array::from_shape_vec(
631                &*header.shape,
632                Vec::from(sz3::decompress(data).1.data()),
633            )?),
634            Sz3DType::F32 => AnyArray::F32(Array::from_shape_vec(
635                &*header.shape,
636                Vec::from(sz3::decompress(data).1.data()),
637            )?),
638            Sz3DType::F64 => AnyArray::F64(Array::from_shape_vec(
639                &*header.shape,
640                Vec::from(sz3::decompress(data).1.data()),
641            )?),
642        }
643    };
644
645    Ok(decoded)
646}
647
/// Array element types which can be compressed with SZ3.
///
/// Implemented for [`i32`], [`i64`], [`f32`], and [`f64`].
pub trait Sz3Element: Copy + sz3::SZ3Compressible {
    /// The dtype representation of the type
    // Written into the header of the encoded data by `compress`
    const DTYPE: Sz3DType;
}
653
// Each supported element type is tied to the dtype tag that `compress`
// stores in the header of the encoded data

// 32-bit signed integers
impl Sz3Element for i32 {
    const DTYPE: Sz3DType = Sz3DType::I32;
}

// 64-bit signed integers
impl Sz3Element for i64 {
    const DTYPE: Sz3DType = Sz3DType::I64;
}

// 32-bit floating point numbers
impl Sz3Element for f32 {
    const DTYPE: Sz3DType = Sz3DType::F32;
}

// 64-bit floating point numbers
impl Sz3Element for f64 {
    const DTYPE: Sz3DType = Sz3DType::F64;
}
669
/// Header that `compress` writes in front of the SZ3-compressed data and that
/// `decompress` reads back to restore the dtype and shape of the array.
// NOTE(review): the header is (de)serialized with postcard, which presumably
// encodes the fields positionally, so reordering them would change the
// encoded format - confirm against the postcard documentation
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    /// Element type of the encoded array
    dtype: Sz3DType,
    /// Full shape of the encoded array, including axes of length zero or one
    // `compress` borrows the shape from the input array
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    /// Encoding format version of the codec
    version: Sz3CodecVersion,
}
677
/// Dtypes that SZ3 can compress and decompress
// The variants are intentionally left without doc comments, which the
// `#[expect(missing_docs)]` below relies on. Each variant is serialized under
// its short name, and the long name is accepted as an alias
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[expect(missing_docs)]
pub enum Sz3DType {
    // 32-bit signed integer
    #[serde(rename = "i32", alias = "int32")]
    I32,
    // 64-bit signed integer
    #[serde(rename = "i64", alias = "int64")]
    I64,
    // 32-bit floating point number
    #[serde(rename = "f32", alias = "float32")]
    F32,
    // 64-bit floating point number
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
691
692impl fmt::Display for Sz3DType {
693    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
694        fmt.write_str(match self {
695            Self::I32 => "i32",
696            Self::I64 => "i64",
697            Self::F32 => "f32",
698            Self::F64 => "f64",
699        })
700    }
701}
702
// Round-trip tests that call `compress` and `decompress` directly
#[cfg(test)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    // An array with a zero-length axis must round-trip to an empty array of
    // the same dtype and shape
    #[test]
    fn zero_length() -> Result<(), Sz3CodecError> {
        let encoded = compress(
            Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])?,
            default_predictor().as_ref(),
            &Sz3ErrorBound::L2Norm { l2: 27.0 },
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);

        Ok(())
    }

    // Axes of length one are not passed on to SZ3 by `compress`, but the
    // decoded array must still have the full original shape
    #[test]
    fn one_dimension() -> Result<(), Sz3CodecError> {
        let data = Array::from_shape_vec([2_usize, 1, 2, 1].as_slice(), vec![1, 2, 3, 4])?;

        let encoded = compress(
            data.view(),
            default_predictor().as_ref(),
            &Sz3ErrorBound::Absolute { abs: 0.1 },
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded, AnyArray::I32(data));

        Ok(())
    }

    // Very small one-dimensional inputs, including the empty and the
    // single-element array, must round-trip to equal arrays
    #[test]
    fn small_state() -> Result<(), Sz3CodecError> {
        for data in [
            &[][..],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ] {
            let encoded = compress(
                ArrayView1::from(data),
                default_predictor().as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let decoded = decompress(&encoded)?;

            assert_eq!(
                decoded,
                AnyArray::F64(Array1::from_vec(data.to_vec()).into_dyn())
            );
        }

        Ok(())
    }

    // Smoke test: encoding and decoding must succeed for no predictor and for
    // every Lorenzo/regression predictor; the decoded values are not checked
    // NOTE(review): the four interpolation predictors are not in this list,
    // only the default one is exercised by the other tests - confirm whether
    // leaving them out here is intentional
    #[test]
    fn all_predictors() -> Result<(), Sz3CodecError> {
        let data = Array::linspace(-42.0, 42.0, 85);

        for predictor in [
            None,
            Some(Sz3Predictor::Regression),
            Some(Sz3Predictor::RegressionSecondOrder),
            Some(Sz3Predictor::RegressionFirstSecondOrder),
            Some(Sz3Predictor::LorenzoSecondOrder),
            Some(Sz3Predictor::LorenzoSecondOrderRegressionSecondOrder),
            Some(Sz3Predictor::LorenzoSecondOrderRegression),
            Some(Sz3Predictor::LorenzoSecondOrderRegressionFirstSecondOrder),
            Some(Sz3Predictor::Lorenzo),
            Some(Sz3Predictor::LorenzoRegressionSecondOrder),
            Some(Sz3Predictor::LorenzoRegression),
            Some(Sz3Predictor::LorenzoRegressionFirstSecondOrder),
            Some(Sz3Predictor::LorenzoFirstSecondOrder),
            Some(Sz3Predictor::LorenzoFirstSecondOrderRegressionSecondOrder),
            Some(Sz3Predictor::LorenzoFirstSecondOrderRegression),
            Some(Sz3Predictor::LorenzoFirstSecondOrderRegressionFirstSecondOrder),
        ] {
            let encoded = compress(
                data.view(),
                predictor.as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let _decoded = decompress(&encoded)?;
        }

        Ok(())
    }
}