numcodecs_sz3/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-sz3
10//! [crates.io]: https://crates.io/crates/numcodecs-sz3
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-sz3
13//! [docs.rs]: https://docs.rs/numcodecs-sz3/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_sz3
17//!
18//! SZ3 codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // embedded-io
21
22use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
25use numcodecs::{
26    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use thiserror::Error;
32
33// Only included to explicitly enable the `no_wasm_shim` feature for
34// sz3-sys/Sz3-sys
35use ::zstd_sys as _;
36
37#[cfg(test)]
38use ::serde_json as _;
39
/// Encoding format version of the [`Sz3Codec`].
///
/// Used both for the codec's `_version` config field and for the version
/// stored in the header that `compress` prepends to the encoded bytes.
// NOTE(review): the three const parameters are presumably
//  major.minor.patch, i.e. 0.1.0 -- confirm against `StaticCodecVersion`
type Sz3CodecVersion = StaticCodecVersion<0, 1, 0>;
41
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
// serde cannot deny unknown fields because of the flatten
#[schemars(deny_unknown_fields)]
/// Codec providing compression using SZ3
pub struct Sz3Codec {
    // A config that omits `predictor` falls back to `default_predictor()`,
    //  i.e. cubic interpolation + Lorenzo. A `None` predictor is passed on
    //  to `compress`, which then runs SZ3 without a prediction step.
    /// Predictor
    #[serde(default = "default_predictor")]
    pub predictor: Option<Sz3Predictor>,
    // Flattened: the `eb_mode` tag and the `eb_*` fields of the error bound
    //  are read from / written to the same level as the codec's own keys
    /// SZ3 error bound
    #[serde(flatten)]
    pub error_bound: Sz3ErrorBound,
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    #[serde(default, rename = "_version")]
    pub version: Sz3CodecVersion,
}
57
/// SZ3 error bound
// Internally tagged: the `eb_mode` key selects the variant and the
//  remaining `eb_*` keys carry the bound value(s) of that variant.
// The doc comments below are kept short since they are also picked up by
//  the `JsonSchema` derive.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "eb_mode")]
#[serde(deny_unknown_fields)]
pub enum Sz3ErrorBound {
    /// Errors are bounded by *both* the absolute and relative error, i.e. by
    /// whichever bound is stricter
    #[serde(rename = "abs-and-rel")]
    AbsoluteAndRelative {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Errors are bounded by *either* the absolute or relative error, i.e. by
    /// whichever bound is weaker
    #[serde(rename = "abs-or-rel")]
    AbsoluteOrRelative {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Absolute error bound
    #[serde(rename = "abs")]
    Absolute {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
    },
    /// Relative error bound
    #[serde(rename = "rel")]
    Relative {
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    // Mapped to `sz3::ErrorBound::PSNR` by `compress`
    /// Peak signal to noise ratio error bound
    #[serde(rename = "psnr")]
    PS2NR {
        /// Peak signal to noise ratio error bound
        #[serde(rename = "eb_psnr")]
        psnr: f64,
    },
    // Mapped to `sz3::ErrorBound::L2Norm` by `compress`
    /// Peak L2 norm error bound
    #[serde(rename = "l2")]
    L2Norm {
        /// Peak L2 norm error bound
        #[serde(rename = "eb_l2")]
        l2: f64,
    },
}
114
115/// SZ3 predictor
116#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
117#[serde(deny_unknown_fields)]
118pub enum Sz3Predictor {
119    /// Linear interpolation
120    #[serde(rename = "linear-interpolation")]
121    LinearInterpolation,
122    /// Cubic interpolation
123    #[serde(rename = "cubic-interpolation")]
124    CubicInterpolation,
125    /// Linear interpolation + Lorenzo predictor
126    #[serde(rename = "linear-interpolation-lorenzo")]
127    LinearInterpolationLorenzo,
128    /// Cubic interpolation + Lorenzo predictor
129    #[serde(rename = "cubic-interpolation-lorenzo")]
130    CubicInterpolationLorenzo,
131    /// 1st order regression
132    #[serde(rename = "regression")]
133    Regression,
134    /// 2nd order regression
135    #[serde(rename = "regression2")]
136    RegressionSecondOrder,
137    /// 1st+2nd order regression
138    #[serde(rename = "regression-regression2")]
139    RegressionFirstSecondOrder,
140    /// 2nd order Lorenzo predictor
141    #[serde(rename = "lorenzo2")]
142    LorenzoSecondOrder,
143    /// 2nd order Lorenzo predictor + 2nd order regression
144    #[serde(rename = "lorenzo2-regression2")]
145    LorenzoSecondOrderRegressionSecondOrder,
146    /// 2nd order Lorenzo predictor + 1st order regression
147    #[serde(rename = "lorenzo2-regression")]
148    LorenzoSecondOrderRegression,
149    /// 2nd order Lorenzo predictor + 1st order regression
150    #[serde(rename = "lorenzo2-regression-regression2")]
151    LorenzoSecondOrderRegressionFirstSecondOrder,
152    /// 1st order Lorenzo predictor
153    #[serde(rename = "lorenzo")]
154    Lorenzo,
155    /// 1st order Lorenzo predictor + 2nd order regression
156    #[serde(rename = "lorenzo-regression2")]
157    LorenzoRegressionSecondOrder,
158    /// 1st order Lorenzo predictor + 1st order regression
159    #[serde(rename = "lorenzo-regression")]
160    LorenzoRegression,
161    /// 1st order Lorenzo predictor + 1st and 2nd order regression
162    #[serde(rename = "lorenzo-regression-regression2")]
163    LorenzoRegressionFirstSecondOrder,
164    /// 1st+2nd order Lorenzo predictor
165    #[serde(rename = "lorenzo-lorenzo2")]
166    LorenzoFirstSecondOrder,
167    /// 1st+2nd order Lorenzo predictor + 2nd order regression
168    #[serde(rename = "lorenzo-lorenzo2-regression2")]
169    LorenzoFirstSecondOrderRegressionSecondOrder,
170    /// 1st+2nd order Lorenzo predictor + 1st order regression
171    #[serde(rename = "lorenzo-lorenzo2-regression")]
172    LorenzoFirstSecondOrderRegression,
173    /// 1st+2nd order Lorenzo predictor + 1st+2nd order regression
174    #[serde(rename = "lorenzo-lorenzo2-regression-regression2")]
175    LorenzoFirstSecondOrderRegressionFirstSecondOrder,
176}
177
/// Default for [`Sz3Codec::predictor`] when the config does not provide one:
/// cubic interpolation combined with the Lorenzo predictor.
// The `Option` wrapper is required since serde's `default = "..."` function
//  must return the field's type, which is `Option<Sz3Predictor>`
#[expect(clippy::unnecessary_wraps)]
const fn default_predictor() -> Option<Sz3Predictor> {
    Some(Sz3Predictor::CubicInterpolationLorenzo)
}
182
183impl Codec for Sz3Codec {
184    type Error = Sz3CodecError;
185
186    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
187        match data {
188            AnyCowArray::I32(data) => Ok(AnyArray::U8(
189                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
190                    .into_dyn(),
191            )),
192            AnyCowArray::I64(data) => Ok(AnyArray::U8(
193                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
194                    .into_dyn(),
195            )),
196            AnyCowArray::F32(data) => Ok(AnyArray::U8(
197                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
198                    .into_dyn(),
199            )),
200            AnyCowArray::F64(data) => Ok(AnyArray::U8(
201                Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
202                    .into_dyn(),
203            )),
204            encoded => Err(Sz3CodecError::UnsupportedDtype(encoded.dtype())),
205        }
206    }
207
208    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
209        let AnyCowArray::U8(encoded) = encoded else {
210            return Err(Sz3CodecError::EncodedDataNotBytes {
211                dtype: encoded.dtype(),
212            });
213        };
214
215        if !matches!(encoded.shape(), [_]) {
216            return Err(Sz3CodecError::EncodedDataNotOneDimensional {
217                shape: encoded.shape().to_vec(),
218            });
219        }
220
221        decompress(&AnyCowArray::U8(encoded).as_bytes())
222    }
223
224    fn decode_into(
225        &self,
226        encoded: AnyArrayView,
227        mut decoded: AnyArrayViewMut,
228    ) -> Result<(), Self::Error> {
229        let decoded_in = self.decode(encoded.cow())?;
230
231        Ok(decoded.assign(&decoded_in)?)
232    }
233}
234
impl StaticCodec for Sz3Codec {
    // Identifier of this codec
    const CODEC_ID: &'static str = "sz3.rs";

    // The codec is its own configuration type, i.e. it is (de)serialized
    //  directly using its serde implementations
    type Config<'de> = Self;

    // The configuration already is the codec, so it is returned as-is
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    // Builds the codec's configuration from a reference to the codec itself
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
248
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`Sz3Codec`].
pub enum Sz3CodecError {
    /// [`Sz3Codec`] does not support the dtype
    #[error("Sz3 does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`Sz3Codec`] failed to encode the header
    #[error("Sz3 failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// [`Sz3Codec`] cannot encode an array of `shape`
    #[error("Sz3 cannot encode an array of shape {shape:?}")]
    InvalidEncodeShape {
        /// Opaque source error
        source: Sz3CodingError,
        /// The invalid shape of the array that was to be encoded
        shape: Vec<usize>,
    },
    /// [`Sz3Codec`] failed to encode the data
    #[error("Sz3 failed to encode the data")]
    Sz3EncodeFailed {
        /// Opaque source error
        source: Sz3CodingError,
    },
    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
    /// an array of a different dtype
    #[error(
        "Sz3 can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
    /// an array of a different shape
    #[error("Sz3 can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`Sz3Codec`] failed to decode the header
    #[error("Sz3 failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// [`Sz3Codec`] decoded an invalid array shape header which does not fit
    /// the decoded data
    #[error("Sz3 decoded an invalid array shape header which does not fit the decoded data")]
    DecodeInvalidShapeHeader {
        /// Source error
        // `#[from]` lets `decompress` use `?` directly on `ShapeError`s
        #[from]
        source: ShapeError,
    },
    /// [`Sz3Codec`] cannot decode into the provided array
    #[error("Sz3 cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        // `#[from]` lets `decode_into` use `?` directly on assignment errors
        #[from]
        source: AnyArrayAssignError,
    },
}
313
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding the header fails
// The inner `postcard::Error` is private so that the serialization format
//  of the header does not become part of the public API
pub struct Sz3HeaderError(postcard::Error);
318
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding with SZ3 fails
// The inner `sz3::SZ3Error` is private so that the error type of the `sz3`
//  crate does not become part of the public API
pub struct Sz3CodingError(sz3::SZ3Error);
323
324#[expect(clippy::needless_pass_by_value, clippy::too_many_lines)]
325/// Compresses the input `data` array using SZ3, which consists of an optional
326/// `predictor`, an `error_bound`, an optional `encoder`, and an optional
327/// `lossless` compressor.
328///
329/// # Errors
330///
331/// Errors with
332/// - [`Sz3CodecError::HeaderEncodeFailed`] if encoding the header failed
333/// - [`Sz3CodecError::InvalidEncodeShape`] if the array shape is invalid
334/// - [`Sz3CodecError::Sz3EncodeFailed`] if encoding failed with an opaque error
335pub fn compress<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
336    data: ArrayBase<S, D>,
337    predictor: Option<&Sz3Predictor>,
338    error_bound: &Sz3ErrorBound,
339) -> Result<Vec<u8>, Sz3CodecError> {
340    let mut encoded_bytes = postcard::to_extend(
341        &CompressionHeader {
342            dtype: <T as Sz3Element>::DTYPE,
343            shape: Cow::Borrowed(data.shape()),
344            version: StaticCodecVersion,
345        },
346        Vec::new(),
347    )
348    .map_err(|err| Sz3CodecError::HeaderEncodeFailed {
349        source: Sz3HeaderError(err),
350    })?;
351
352    // sz3::DimensionedDataBuilder cannot handle zero-length dimensions
353    if data.is_empty() {
354        return Ok(encoded_bytes);
355    }
356
357    #[expect(clippy::option_if_let_else)]
358    let data_cow = if let Some(data) = data.as_slice() {
359        Cow::Borrowed(data)
360    } else {
361        Cow::Owned(data.iter().copied().collect())
362    };
363    let mut builder = sz3::DimensionedData::build(&data_cow);
364
365    for length in data.shape() {
366        // Sz3 ignores dimensions of length 1 and panics on length zero
367        // Since they carry no information for Sz3 and we already encode them
368        //  in our custom header, we just skip them here
369        if *length > 1 {
370            builder = builder
371                .dim(*length)
372                .map_err(|err| Sz3CodecError::InvalidEncodeShape {
373                    source: Sz3CodingError(err),
374                    shape: data.shape().to_vec(),
375                })?;
376        }
377    }
378
379    if data.len() == 1 {
380        // If there is only one element, all dimensions will have been skipped,
381        //  so we explicitly encode one dimension of size 1 here
382        builder = builder
383            .dim(1)
384            .map_err(|err| Sz3CodecError::InvalidEncodeShape {
385                source: Sz3CodingError(err),
386                shape: data.shape().to_vec(),
387            })?;
388    }
389
390    let data = builder
391        .finish()
392        .map_err(|err| Sz3CodecError::InvalidEncodeShape {
393            source: Sz3CodingError(err),
394            shape: data.shape().to_vec(),
395        })?;
396
397    // configure the error bound
398    let error_bound = match error_bound {
399        Sz3ErrorBound::AbsoluteAndRelative { abs, rel } => sz3::ErrorBound::AbsoluteAndRelative {
400            absolute_bound: *abs,
401            relative_bound: *rel,
402        },
403        Sz3ErrorBound::AbsoluteOrRelative { abs, rel } => sz3::ErrorBound::AbsoluteOrRelative {
404            absolute_bound: *abs,
405            relative_bound: *rel,
406        },
407        Sz3ErrorBound::Absolute { abs } => sz3::ErrorBound::Absolute(*abs),
408        Sz3ErrorBound::Relative { rel } => sz3::ErrorBound::Relative(*rel),
409        Sz3ErrorBound::PS2NR { psnr } => sz3::ErrorBound::PSNR(*psnr),
410        Sz3ErrorBound::L2Norm { l2 } => sz3::ErrorBound::L2Norm(*l2),
411    };
412    let mut config = sz3::Config::new(error_bound);
413
414    // configure the interpolation mode, if necessary
415    let interpolation = match predictor {
416        Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::LinearInterpolationLorenzo) => {
417            Some(sz3::InterpolationAlgorithm::Linear)
418        }
419        Some(Sz3Predictor::CubicInterpolation | Sz3Predictor::CubicInterpolationLorenzo) => {
420            Some(sz3::InterpolationAlgorithm::Cubic)
421        }
422        Some(
423            Sz3Predictor::Regression
424            | Sz3Predictor::RegressionSecondOrder
425            | Sz3Predictor::RegressionFirstSecondOrder
426            | Sz3Predictor::LorenzoSecondOrder
427            | Sz3Predictor::LorenzoSecondOrderRegressionSecondOrder
428            | Sz3Predictor::LorenzoSecondOrderRegression
429            | Sz3Predictor::LorenzoSecondOrderRegressionFirstSecondOrder
430            | Sz3Predictor::Lorenzo
431            | Sz3Predictor::LorenzoRegressionSecondOrder
432            | Sz3Predictor::LorenzoRegression
433            | Sz3Predictor::LorenzoRegressionFirstSecondOrder
434            | Sz3Predictor::LorenzoFirstSecondOrder
435            | Sz3Predictor::LorenzoFirstSecondOrderRegressionSecondOrder
436            | Sz3Predictor::LorenzoFirstSecondOrderRegression
437            | Sz3Predictor::LorenzoFirstSecondOrderRegressionFirstSecondOrder,
438        )
439        | None => None,
440    };
441    if let Some(interpolation) = interpolation {
442        config = config.interpolation_algorithm(interpolation);
443    }
444
445    // configure the predictor (compression algorithm)
446    let predictor = match predictor {
447        Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::CubicInterpolation) => {
448            sz3::CompressionAlgorithm::Interpolation
449        }
450        Some(
451            Sz3Predictor::LinearInterpolationLorenzo | Sz3Predictor::CubicInterpolationLorenzo,
452        ) => sz3::CompressionAlgorithm::InterpolationLorenzo,
453        Some(Sz3Predictor::RegressionSecondOrder) => sz3::CompressionAlgorithm::LorenzoRegression {
454            lorenzo: false,
455            lorenzo_second_order: false,
456            regression: false,
457            regression_second_order: true,
458            prediction_dimension: None,
459        },
460        Some(Sz3Predictor::Regression) => sz3::CompressionAlgorithm::LorenzoRegression {
461            lorenzo: false,
462            lorenzo_second_order: false,
463            regression: true,
464            regression_second_order: false,
465            prediction_dimension: None,
466        },
467        Some(Sz3Predictor::RegressionFirstSecondOrder) => {
468            sz3::CompressionAlgorithm::LorenzoRegression {
469                lorenzo: false,
470                lorenzo_second_order: false,
471                regression: true,
472                regression_second_order: true,
473                prediction_dimension: None,
474            }
475        }
476        Some(Sz3Predictor::LorenzoSecondOrder) => sz3::CompressionAlgorithm::LorenzoRegression {
477            lorenzo: false,
478            lorenzo_second_order: true,
479            regression: false,
480            regression_second_order: false,
481            prediction_dimension: None,
482        },
483        Some(Sz3Predictor::LorenzoSecondOrderRegressionSecondOrder) => {
484            sz3::CompressionAlgorithm::LorenzoRegression {
485                lorenzo: false,
486                lorenzo_second_order: true,
487                regression: false,
488                regression_second_order: true,
489                prediction_dimension: None,
490            }
491        }
492        Some(Sz3Predictor::LorenzoSecondOrderRegression) => {
493            sz3::CompressionAlgorithm::LorenzoRegression {
494                lorenzo: false,
495                lorenzo_second_order: true,
496                regression: true,
497                regression_second_order: false,
498                prediction_dimension: None,
499            }
500        }
501        Some(Sz3Predictor::LorenzoSecondOrderRegressionFirstSecondOrder) => {
502            sz3::CompressionAlgorithm::LorenzoRegression {
503                lorenzo: false,
504                lorenzo_second_order: true,
505                regression: true,
506                regression_second_order: true,
507                prediction_dimension: None,
508            }
509        }
510        Some(Sz3Predictor::Lorenzo) => sz3::CompressionAlgorithm::LorenzoRegression {
511            lorenzo: true,
512            lorenzo_second_order: false,
513            regression: false,
514            regression_second_order: false,
515            prediction_dimension: None,
516        },
517        Some(Sz3Predictor::LorenzoRegressionSecondOrder) => {
518            sz3::CompressionAlgorithm::LorenzoRegression {
519                lorenzo: true,
520                lorenzo_second_order: false,
521                regression: false,
522                regression_second_order: true,
523                prediction_dimension: None,
524            }
525        }
526        Some(Sz3Predictor::LorenzoRegression) => sz3::CompressionAlgorithm::LorenzoRegression {
527            lorenzo: true,
528            lorenzo_second_order: false,
529            regression: true,
530            regression_second_order: false,
531            prediction_dimension: None,
532        },
533        Some(Sz3Predictor::LorenzoRegressionFirstSecondOrder) => {
534            sz3::CompressionAlgorithm::LorenzoRegression {
535                lorenzo: true,
536                lorenzo_second_order: false,
537                regression: true,
538                regression_second_order: true,
539                prediction_dimension: None,
540            }
541        }
542        Some(Sz3Predictor::LorenzoFirstSecondOrder) => {
543            sz3::CompressionAlgorithm::LorenzoRegression {
544                lorenzo: true,
545                lorenzo_second_order: true,
546                regression: false,
547                regression_second_order: false,
548                prediction_dimension: None,
549            }
550        }
551        Some(Sz3Predictor::LorenzoFirstSecondOrderRegressionSecondOrder) => {
552            sz3::CompressionAlgorithm::LorenzoRegression {
553                lorenzo: true,
554                lorenzo_second_order: true,
555                regression: false,
556                regression_second_order: true,
557                prediction_dimension: None,
558            }
559        }
560        Some(Sz3Predictor::LorenzoFirstSecondOrderRegression) => {
561            sz3::CompressionAlgorithm::LorenzoRegression {
562                lorenzo: true,
563                lorenzo_second_order: true,
564                regression: true,
565                regression_second_order: false,
566                prediction_dimension: None,
567            }
568        }
569        Some(Sz3Predictor::LorenzoFirstSecondOrderRegressionFirstSecondOrder) => {
570            sz3::CompressionAlgorithm::LorenzoRegression {
571                lorenzo: true,
572                lorenzo_second_order: true,
573                regression: true,
574                regression_second_order: true,
575                prediction_dimension: None,
576            }
577        }
578        None => sz3::CompressionAlgorithm::NoPrediction,
579    };
580    config = config.compression_algorithm(predictor);
581
582    // TODO: avoid extra allocation here
583    let compressed = sz3::compress_with_config(&data, &config).map_err(|err| {
584        Sz3CodecError::Sz3EncodeFailed {
585            source: Sz3CodingError(err),
586        }
587    })?;
588    encoded_bytes.extend_from_slice(&compressed);
589
590    Ok(encoded_bytes)
591}
592
593/// Decompresses the `encoded` data into an array.
594///
595/// # Errors
596///
597/// Errors with
598/// - [`Sz3CodecError::HeaderDecodeFailed`] if decoding the header failed
599pub fn decompress(encoded: &[u8]) -> Result<AnyArray, Sz3CodecError> {
600    let (header, data) =
601        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
602            Sz3CodecError::HeaderDecodeFailed {
603                source: Sz3HeaderError(err),
604            }
605        })?;
606
607    let decoded = if header.shape.iter().copied().product::<usize>() == 0 {
608        match header.dtype {
609            Sz3DType::I32 => {
610                AnyArray::I32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
611            }
612            Sz3DType::I64 => {
613                AnyArray::I64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
614            }
615            Sz3DType::F32 => {
616                AnyArray::F32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
617            }
618            Sz3DType::F64 => {
619                AnyArray::F64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
620            }
621        }
622    } else {
623        // TODO: avoid extra allocation here
624        match header.dtype {
625            Sz3DType::I32 => AnyArray::I32(Array::from_shape_vec(
626                &*header.shape,
627                Vec::from(sz3::decompress(data).1.data()),
628            )?),
629            Sz3DType::I64 => AnyArray::I64(Array::from_shape_vec(
630                &*header.shape,
631                Vec::from(sz3::decompress(data).1.data()),
632            )?),
633            Sz3DType::F32 => AnyArray::F32(Array::from_shape_vec(
634                &*header.shape,
635                Vec::from(sz3::decompress(data).1.data()),
636            )?),
637            Sz3DType::F64 => AnyArray::F64(Array::from_shape_vec(
638                &*header.shape,
639                Vec::from(sz3::decompress(data).1.data()),
640            )?),
641        }
642    };
643
644    Ok(decoded)
645}
646
/// Array element types which can be compressed with SZ3.
///
/// The trait links each Rust element type to the [`Sz3DType`] that
/// [`compress`] stores in the encoded header.
pub trait Sz3Element: Copy + sz3::SZ3Compressible {
    /// The dtype representation of the type
    const DTYPE: Sz3DType;
}
652
// One implementation per supported element type, which must stay in sync
//  with the variants of `Sz3DType` and the dtypes handled by `decompress`

impl Sz3Element for i32 {
    const DTYPE: Sz3DType = Sz3DType::I32;
}

impl Sz3Element for i64 {
    const DTYPE: Sz3DType = Sz3DType::I64;
}

impl Sz3Element for f32 {
    const DTYPE: Sz3DType = Sz3DType::F32;
}

impl Sz3Element for f64 {
    const DTYPE: Sz3DType = Sz3DType::F64;
}
668
/// Header that `compress` writes in front of the SZ3-compressed bytes and
/// that `decompress` reads back to reconstruct the decoded array.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    /// Element type of the encoded array
    dtype: Sz3DType,
    /// Full shape of the encoded array, including axes of length zero and
    /// one, which are not passed on to SZ3
    // `compress` borrows the shape directly from the input array
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    /// Encoding format version of the codec that produced the header
    version: Sz3CodecVersion,
}
676
/// Dtypes that SZ3 can compress and decompress
// Each variant is serialized under its short name (e.g. `i32`) and can
//  also be deserialized from a longer alias (e.g. `int32`).
// The variants are intentionally left without doc comments, which is what
//  the `missing_docs` expectation below covers.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[expect(missing_docs)]
pub enum Sz3DType {
    #[serde(rename = "i32", alias = "int32")]
    I32,
    #[serde(rename = "i64", alias = "int64")]
    I64,
    #[serde(rename = "f32", alias = "float32")]
    F32,
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
690
691impl fmt::Display for Sz3DType {
692    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
693        fmt.write_str(match self {
694            Self::I32 => "i32",
695            Self::I64 => "i64",
696            Self::F32 => "f32",
697            Self::F64 => "f64",
698        })
699    }
700}
701
#[cfg(test)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    // An array with a zero-length axis must round-trip to an empty array
    //  with the same dtype and the same full shape
    #[test]
    fn zero_length() -> Result<(), Sz3CodecError> {
        let encoded = compress(
            Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])?,
            default_predictor().as_ref(),
            &Sz3ErrorBound::L2Norm { l2: 27.0 },
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);

        Ok(())
    }

    // Axes of length one are skipped when describing the shape to SZ3, so
    //  this checks that the full shape is still restored from the header
    #[test]
    fn one_dimension() -> Result<(), Sz3CodecError> {
        let data = Array::from_shape_vec([2_usize, 1, 2, 1].as_slice(), vec![1, 2, 3, 4])?;

        let encoded = compress(
            data.view(),
            default_predictor().as_ref(),
            &Sz3ErrorBound::Absolute { abs: 0.1 },
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded, AnyArray::I32(data));

        Ok(())
    }

    // Very small one-dimensional inputs, including the empty and the
    //  single-element array, which take special paths in `compress`
    #[test]
    fn small_state() -> Result<(), Sz3CodecError> {
        for data in [
            &[][..],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ] {
            let encoded = compress(
                ArrayView1::from(data),
                default_predictor().as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let decoded = decompress(&encoded)?;

            assert_eq!(
                decoded,
                AnyArray::F64(Array1::from_vec(data.to_vec()).into_dyn())
            );
        }

        Ok(())
    }

    // Smoke test: only checks that encoding and decoding succeed for each
    //  listed predictor, the decoded values are not compared
    // NOTE(review): the four interpolation predictors are not part of this
    //  list; of those, only `CubicInterpolationLorenzo` is exercised (via
    //  `default_predictor` in the tests above) -- confirm if intentional
    #[test]
    fn all_predictors() -> Result<(), Sz3CodecError> {
        let data = Array::linspace(-42.0, 42.0, 85);

        for predictor in [
            None,
            Some(Sz3Predictor::Regression),
            Some(Sz3Predictor::RegressionSecondOrder),
            Some(Sz3Predictor::RegressionFirstSecondOrder),
            Some(Sz3Predictor::LorenzoSecondOrder),
            Some(Sz3Predictor::LorenzoSecondOrderRegressionSecondOrder),
            Some(Sz3Predictor::LorenzoSecondOrderRegression),
            Some(Sz3Predictor::LorenzoSecondOrderRegressionFirstSecondOrder),
            Some(Sz3Predictor::Lorenzo),
            Some(Sz3Predictor::LorenzoRegressionSecondOrder),
            Some(Sz3Predictor::LorenzoRegression),
            Some(Sz3Predictor::LorenzoRegressionFirstSecondOrder),
            Some(Sz3Predictor::LorenzoFirstSecondOrder),
            Some(Sz3Predictor::LorenzoFirstSecondOrderRegressionSecondOrder),
            Some(Sz3Predictor::LorenzoFirstSecondOrderRegression),
            Some(Sz3Predictor::LorenzoFirstSecondOrderRegressionFirstSecondOrder),
        ] {
            let encoded = compress(
                data.view(),
                predictor.as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let _decoded = decompress(&encoded)?;
        }

        Ok(())
    }
}