numcodecs_sz3/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-sz3
10//! [crates.io]: https://crates.io/crates/numcodecs-sz3
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-sz3
13//! [docs.rs]: https://docs.rs/numcodecs-sz3/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_sz3
17//!
18//! SZ3 codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // embedded-io
21
22use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
25use numcodecs::{
26    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27    Codec, StaticCodec, StaticCodecConfig,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use thiserror::Error;
32
33// Only included to explicitly enable the `no_wasm_shim` feature for
34// sz3-sys/Sz3-sys
35use ::zstd_sys as _;
36
37#[cfg(test)]
38use ::serde_json as _;
39
40#[derive(Clone, Serialize, Deserialize, JsonSchema)]
41// serde cannot deny unknown fields because of the flatten
42#[schemars(deny_unknown_fields)]
43/// Codec providing compression using SZ3
44pub struct Sz3Codec {
45    /// Predictor
46    #[serde(default = "default_predictor")]
47    pub predictor: Option<Sz3Predictor>,
48    /// SZ3 error bound
49    #[serde(flatten)]
50    pub error_bound: Sz3ErrorBound,
51    /// Encoder
52    #[serde(default = "default_encoder")]
53    pub encoder: Option<Sz3Encoder>,
54    /// Lossless compressor
55    #[serde(default = "default_lossless_compressor")]
56    pub lossless: Option<Sz3LosslessCompressor>,
57}
58
/// SZ3 error bound
// Internally tagged: the `eb_mode` field selects the variant and the
// variant's own fields are renamed to `eb_*` keys. Since `Sz3Codec` flattens
// this enum, a config looks like `{ "eb_mode": "abs", "eb_abs": 0.1, ... }`.
//
// NOTE(review): the `///` docs are presumably exported as descriptions by the
// `JsonSchema` derive, so they are left byte-identical here - confirm before
// rewording them.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "eb_mode")]
#[serde(deny_unknown_fields)]
pub enum Sz3ErrorBound {
    /// Errors are bounded by *both* the absolute and relative error, i.e. by
    /// whichever bound is stricter
    // mapped to `sz3::ErrorBound::AbsoluteAndRelative` in `compress`
    #[serde(rename = "abs-and-rel")]
    AbsoluteAndRelative {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Errors are bounded by *either* the absolute or relative error, i.e. by
    /// whichever bound is weaker
    // mapped to `sz3::ErrorBound::AbsoluteOrRelative` in `compress`
    #[serde(rename = "abs-or-rel")]
    AbsoluteOrRelative {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Absolute error bound
    // mapped to `sz3::ErrorBound::Absolute` in `compress`
    #[serde(rename = "abs")]
    Absolute {
        /// Absolute error bound
        #[serde(rename = "eb_abs")]
        abs: f64,
    },
    /// Relative error bound
    // mapped to `sz3::ErrorBound::Relative` in `compress`
    #[serde(rename = "rel")]
    Relative {
        /// Relative error bound
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Peak signal to noise ratio error bound
    // mapped to `sz3::ErrorBound::PSNR` in `compress`; the variant name is
    // part of the public API and serialized as "psnr"
    #[serde(rename = "psnr")]
    PS2NR {
        /// Peak signal to noise ratio error bound
        #[serde(rename = "eb_psnr")]
        psnr: f64,
    },
    /// Peak L2 norm error bound
    // mapped to `sz3::ErrorBound::L2Norm` in `compress`
    #[serde(rename = "l2")]
    L2Norm {
        /// Peak L2 norm error bound
        #[serde(rename = "eb_l2")]
        l2: f64,
    },
}
115
/// SZ3 predictor
// Each variant is translated in `compress` into a pair of an (optional)
// `sz3::InterpolationAlgorithm` and a `sz3::CompressionAlgorithm`.
//
// NOTE(review): the `///` docs are presumably exported as descriptions by the
// `JsonSchema` derive, so they are left byte-identical here - confirm before
// rewording them.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub enum Sz3Predictor {
    /// Linear interpolation
    // -> `InterpolationAlgorithm::Linear` + `CompressionAlgorithm::Interpolation`
    #[serde(rename = "linear-interpolation")]
    LinearInterpolation,
    /// Cubic interpolation
    // -> `InterpolationAlgorithm::Cubic` + `CompressionAlgorithm::Interpolation`
    #[serde(rename = "cubic-interpolation")]
    CubicInterpolation,
    /// Linear interpolation + Lorenzo predictor
    // -> `InterpolationAlgorithm::Linear` + `CompressionAlgorithm::InterpolationLorenzo`
    #[serde(rename = "linear-interpolation-lorenzo")]
    LinearInterpolationLorenzo,
    /// Cubic interpolation + Lorenzo predictor
    // -> `InterpolationAlgorithm::Cubic` + `CompressionAlgorithm::InterpolationLorenzo`
    #[serde(rename = "cubic-interpolation-lorenzo")]
    CubicInterpolationLorenzo,
    /// Lorenzo predictor + regression
    // -> no interpolation algorithm + `CompressionAlgorithm::lorenzo_regression()`
    #[serde(rename = "lorenzo-regression")]
    LorenzoRegression,
}
136
137#[expect(clippy::unnecessary_wraps)]
138const fn default_predictor() -> Option<Sz3Predictor> {
139    Some(Sz3Predictor::CubicInterpolationLorenzo)
140}
141
/// SZ3 encoder
// A `None` encoder in `Sz3Codec` is mapped to `sz3::Encoder::SkipEncoder` by
// `compress`; the variants below map to the SZ3 encoders of the same name.
//
// NOTE(review): the `///` docs are presumably exported as descriptions by the
// `JsonSchema` derive, so they are left byte-identical here - confirm before
// rewording them.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub enum Sz3Encoder {
    /// Huffman coding
    // -> `sz3::Encoder::HuffmanEncoder`
    #[serde(rename = "huffman")]
    Huffman,
    /// Arithmetic coding
    // -> `sz3::Encoder::ArithmeticEncoder`
    #[serde(rename = "arithmetic")]
    Arithmetic,
}
153
154#[expect(clippy::unnecessary_wraps)]
155const fn default_encoder() -> Option<Sz3Encoder> {
156    Some(Sz3Encoder::Huffman)
157}
158
/// SZ3 lossless compressor
// A `None` lossless compressor in `Sz3Codec` is mapped to
// `sz3::LossLess::LossLessBypass` by `compress`.
//
// NOTE(review): the `///` docs are presumably exported as descriptions by the
// `JsonSchema` derive, so they are left byte-identical here - confirm before
// rewording them.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub enum Sz3LosslessCompressor {
    /// Zstandard
    // -> `sz3::LossLess::ZSTD`
    #[serde(rename = "zstd")]
    Zstd,
}
167
168#[expect(clippy::unnecessary_wraps)]
169const fn default_lossless_compressor() -> Option<Sz3LosslessCompressor> {
170    Some(Sz3LosslessCompressor::Zstd)
171}
172
173impl Codec for Sz3Codec {
174    type Error = Sz3CodecError;
175
176    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
177        match data {
178            AnyCowArray::I32(data) => Ok(AnyArray::U8(
179                Array1::from(compress(
180                    data,
181                    self.predictor.as_ref(),
182                    &self.error_bound,
183                    self.encoder.as_ref(),
184                    self.lossless.as_ref(),
185                )?)
186                .into_dyn(),
187            )),
188            AnyCowArray::I64(data) => Ok(AnyArray::U8(
189                Array1::from(compress(
190                    data,
191                    self.predictor.as_ref(),
192                    &self.error_bound,
193                    self.encoder.as_ref(),
194                    self.lossless.as_ref(),
195                )?)
196                .into_dyn(),
197            )),
198            AnyCowArray::F32(data) => Ok(AnyArray::U8(
199                Array1::from(compress(
200                    data,
201                    self.predictor.as_ref(),
202                    &self.error_bound,
203                    self.encoder.as_ref(),
204                    self.lossless.as_ref(),
205                )?)
206                .into_dyn(),
207            )),
208            AnyCowArray::F64(data) => Ok(AnyArray::U8(
209                Array1::from(compress(
210                    data,
211                    self.predictor.as_ref(),
212                    &self.error_bound,
213                    self.encoder.as_ref(),
214                    self.lossless.as_ref(),
215                )?)
216                .into_dyn(),
217            )),
218            encoded => Err(Sz3CodecError::UnsupportedDtype(encoded.dtype())),
219        }
220    }
221
222    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
223        let AnyCowArray::U8(encoded) = encoded else {
224            return Err(Sz3CodecError::EncodedDataNotBytes {
225                dtype: encoded.dtype(),
226            });
227        };
228
229        if !matches!(encoded.shape(), [_]) {
230            return Err(Sz3CodecError::EncodedDataNotOneDimensional {
231                shape: encoded.shape().to_vec(),
232            });
233        }
234
235        decompress(&AnyCowArray::U8(encoded).as_bytes())
236    }
237
238    fn decode_into(
239        &self,
240        encoded: AnyArrayView,
241        mut decoded: AnyArrayViewMut,
242    ) -> Result<(), Self::Error> {
243        let decoded_in = self.decode(encoded.cow())?;
244
245        Ok(decoded.assign(&decoded_in)?)
246    }
247}
248
249impl StaticCodec for Sz3Codec {
250    const CODEC_ID: &'static str = "sz3";
251
252    type Config<'de> = Self;
253
254    fn from_config(config: Self::Config<'_>) -> Self {
255        config
256    }
257
258    fn get_config(&self) -> StaticCodecConfig<Self> {
259        StaticCodecConfig::from(self)
260    }
261}
262
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`Sz3Codec`].
///
/// The first group of variants is produced while encoding (see
/// [`compress`]), the second group while decoding (see [`decompress`] and
/// the [`Codec`] implementation of [`Sz3Codec`]).
pub enum Sz3CodecError {
    /// [`Sz3Codec`] does not support the dtype
    ///
    /// Returned when encoding an array whose dtype is not one of `i32`,
    /// `i64`, `f32`, or `f64`.
    #[error("Sz3 does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`Sz3Codec`] failed to encode the header
    ///
    /// The header stores the dtype and shape of the array and precedes the
    /// SZ3-compressed data in the encoded bytes.
    #[error("Sz3 failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// [`Sz3Codec`] cannot encode an array of `shape`
    ///
    /// Returned when SZ3 rejects the dimensions of the array.
    #[error("Sz3 cannot encode an array of shape {shape:?}")]
    InvalidEncodeShape {
        /// Opaque source error
        source: Sz3CodingError,
        /// The invalid shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`Sz3Codec`] failed to encode the data
    #[error("Sz3 failed to encode the data")]
    Sz3EncodeFailed {
        /// Opaque source error
        source: Sz3CodingError,
    },
    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
    /// an array of a different dtype
    #[error(
        "Sz3 can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`Sz3Codec`] can only decode one-dimensional byte arrays but received
    /// an array of a different shape
    #[error("Sz3 can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`Sz3Codec`] failed to decode the header
    #[error("Sz3 failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// [`Sz3Codec`] decoded an invalid array shape header which does not fit
    /// the decoded data
    ///
    /// Automatically converted from a [`ShapeError`], e.g. when using `?`.
    #[error("Sz3 decoded an invalid array shape header which does not fit the decoded data")]
    DecodeInvalidShapeHeader {
        /// Source error
        #[from]
        source: ShapeError,
    },
    /// [`Sz3Codec`] cannot decode into the provided array
    ///
    /// Automatically converted from an [`AnyArrayAssignError`], e.g. when
    /// using `?`.
    #[error("Sz3 cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
327
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding the header fails
///
/// Wraps the underlying [`postcard::Error`] without exposing it in the public
/// API; its message and source are forwarded transparently.
pub struct Sz3HeaderError(postcard::Error);
332
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding with SZ3 fails
///
/// Wraps the underlying [`sz3::SZ3Error`] without exposing it in the public
/// API; its message and source are forwarded transparently.
pub struct Sz3CodingError(sz3::SZ3Error);
337
338#[expect(clippy::needless_pass_by_value)]
339/// Compresses the input `data` array using SZ3, which consists of an optional
340/// `predictor`, an `error_bound`, an optional `encoder`, and an optional
341/// `lossless` compressor.
342///
343/// # Errors
344///
345/// Errors with
346/// - [`Sz3CodecError::HeaderEncodeFailed`] if encoding the header failed
347/// - [`Sz3CodecError::InvalidEncodeShape`] if the array shape is invalid
348/// - [`Sz3CodecError::Sz3EncodeFailed`] if encoding failed with an opaque error
349pub fn compress<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
350    data: ArrayBase<S, D>,
351    predictor: Option<&Sz3Predictor>,
352    error_bound: &Sz3ErrorBound,
353    encoder: Option<&Sz3Encoder>,
354    lossless: Option<&Sz3LosslessCompressor>,
355) -> Result<Vec<u8>, Sz3CodecError> {
356    let mut encoded_bytes = postcard::to_extend(
357        &CompressionHeader {
358            dtype: <T as Sz3Element>::DTYPE,
359            shape: Cow::Borrowed(data.shape()),
360        },
361        Vec::new(),
362    )
363    .map_err(|err| Sz3CodecError::HeaderEncodeFailed {
364        source: Sz3HeaderError(err),
365    })?;
366
367    // sz3::DimensionedDataBuilder cannot handle zero-length dimensions
368    if data.is_empty() {
369        return Ok(encoded_bytes);
370    }
371
372    #[expect(clippy::option_if_let_else)]
373    let data_cow = if let Some(data) = data.as_slice() {
374        Cow::Borrowed(data)
375    } else {
376        Cow::Owned(data.iter().copied().collect())
377    };
378    let mut builder = sz3::DimensionedData::build(&data_cow);
379
380    for length in data.shape() {
381        // Sz3 ignores dimensions of length 1 and panics on length zero
382        // Since they carry no information for Sz3 and we already encode them
383        //  in our custom header, we just skip them here
384        if *length > 1 {
385            builder = builder
386                .dim(*length)
387                .map_err(|err| Sz3CodecError::InvalidEncodeShape {
388                    source: Sz3CodingError(err),
389                    shape: data.shape().to_vec(),
390                })?;
391        }
392    }
393
394    if data.len() == 1 {
395        // If there is only one element, all dimensions will have been skipped,
396        //  so we explicitly encode one dimension of size 1 here
397        builder = builder
398            .dim(1)
399            .map_err(|err| Sz3CodecError::InvalidEncodeShape {
400                source: Sz3CodingError(err),
401                shape: data.shape().to_vec(),
402            })?;
403    }
404
405    let data = builder
406        .finish()
407        .map_err(|err| Sz3CodecError::InvalidEncodeShape {
408            source: Sz3CodingError(err),
409            shape: data.shape().to_vec(),
410        })?;
411
412    // configure the error bound
413    let error_bound = match error_bound {
414        Sz3ErrorBound::AbsoluteAndRelative { abs, rel } => sz3::ErrorBound::AbsoluteAndRelative {
415            absolute_bound: *abs,
416            relative_bound: *rel,
417        },
418        Sz3ErrorBound::AbsoluteOrRelative { abs, rel } => sz3::ErrorBound::AbsoluteOrRelative {
419            absolute_bound: *abs,
420            relative_bound: *rel,
421        },
422        Sz3ErrorBound::Absolute { abs } => sz3::ErrorBound::Absolute(*abs),
423        Sz3ErrorBound::Relative { rel } => sz3::ErrorBound::Relative(*rel),
424        Sz3ErrorBound::PS2NR { psnr } => sz3::ErrorBound::PSNR(*psnr),
425        Sz3ErrorBound::L2Norm { l2 } => sz3::ErrorBound::L2Norm(*l2),
426    };
427    let mut config = sz3::Config::new(error_bound);
428
429    // configure the interpolation mode, if necessary
430    let interpolation = match predictor {
431        Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::LinearInterpolationLorenzo) => {
432            Some(sz3::InterpolationAlgorithm::Linear)
433        }
434        Some(Sz3Predictor::CubicInterpolation | Sz3Predictor::CubicInterpolationLorenzo) => {
435            Some(sz3::InterpolationAlgorithm::Cubic)
436        }
437        Some(Sz3Predictor::LorenzoRegression) | None => None,
438    };
439    if let Some(interpolation) = interpolation {
440        config = config.interpolation_algorithm(interpolation);
441    }
442
443    // configure the predictor (compression algorithm)
444    let predictor = match predictor {
445        Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::CubicInterpolation) => {
446            sz3::CompressionAlgorithm::Interpolation
447        }
448        Some(
449            Sz3Predictor::LinearInterpolationLorenzo | Sz3Predictor::CubicInterpolationLorenzo,
450        ) => sz3::CompressionAlgorithm::InterpolationLorenzo,
451        Some(Sz3Predictor::LorenzoRegression) => sz3::CompressionAlgorithm::lorenzo_regression(),
452        None => sz3::CompressionAlgorithm::NoPrediction,
453    };
454    config = config.compression_algorithm(predictor);
455
456    // configure the encoder
457    let encoder = match encoder {
458        None => sz3::Encoder::SkipEncoder,
459        Some(Sz3Encoder::Huffman) => sz3::Encoder::HuffmanEncoder,
460        Some(Sz3Encoder::Arithmetic) => sz3::Encoder::ArithmeticEncoder,
461    };
462    config = config.encoder(encoder);
463
464    // configure the lossless compressor
465    let lossless = match lossless {
466        None => sz3::LossLess::LossLessBypass,
467        Some(Sz3LosslessCompressor::Zstd) => sz3::LossLess::ZSTD,
468    };
469    config = config.lossless(lossless);
470
471    // TODO: avoid extra allocation here
472    let compressed = sz3::compress_with_config(&data, &config).map_err(|err| {
473        Sz3CodecError::Sz3EncodeFailed {
474            source: Sz3CodingError(err),
475        }
476    })?;
477    encoded_bytes.extend_from_slice(&compressed);
478
479    Ok(encoded_bytes)
480}
481
482/// Decompresses the `encoded` data into an array.
483///
484/// # Errors
485///
486/// Errors with
487/// - [`Sz3CodecError::HeaderDecodeFailed`] if decoding the header failed
488pub fn decompress(encoded: &[u8]) -> Result<AnyArray, Sz3CodecError> {
489    let (header, data) =
490        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
491            Sz3CodecError::HeaderDecodeFailed {
492                source: Sz3HeaderError(err),
493            }
494        })?;
495
496    let decoded = if header.shape.iter().copied().product::<usize>() == 0 {
497        match header.dtype {
498            Sz3DType::I32 => {
499                AnyArray::I32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
500            }
501            Sz3DType::I64 => {
502                AnyArray::I64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
503            }
504            Sz3DType::F32 => {
505                AnyArray::F32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
506            }
507            Sz3DType::F64 => {
508                AnyArray::F64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
509            }
510        }
511    } else {
512        // TODO: avoid extra allocation here
513        match header.dtype {
514            Sz3DType::I32 => AnyArray::I32(Array::from_shape_vec(
515                &*header.shape,
516                Vec::from(sz3::decompress(data).1.data()),
517            )?),
518            Sz3DType::I64 => AnyArray::I64(Array::from_shape_vec(
519                &*header.shape,
520                Vec::from(sz3::decompress(data).1.data()),
521            )?),
522            Sz3DType::F32 => AnyArray::F32(Array::from_shape_vec(
523                &*header.shape,
524                Vec::from(sz3::decompress(data).1.data()),
525            )?),
526            Sz3DType::F64 => AnyArray::F64(Array::from_shape_vec(
527                &*header.shape,
528                Vec::from(sz3::decompress(data).1.data()),
529            )?),
530        }
531    };
532
533    Ok(decoded)
534}
535
536/// Array element types which can be compressed with SZ3.
537pub trait Sz3Element: Copy + sz3::SZ3Compressible {
538    /// The dtype representation of the type
539    const DTYPE: Sz3DType;
540}
541
542impl Sz3Element for i32 {
543    const DTYPE: Sz3DType = Sz3DType::I32;
544}
545
546impl Sz3Element for i64 {
547    const DTYPE: Sz3DType = Sz3DType::I64;
548}
549
550impl Sz3Element for f32 {
551    const DTYPE: Sz3DType = Sz3DType::F32;
552}
553
554impl Sz3Element for f64 {
555    const DTYPE: Sz3DType = Sz3DType::F64;
556}
557
/// Header that precedes the SZ3-compressed data in the encoded bytes.
///
/// It is written with `postcard` by `compress` and read back by `decompress`
/// to recover the dtype and the full shape of the array, including the
/// length-one and length-zero axes that are not passed on to SZ3.
// NOTE(review): the order and types of the fields presumably determine the
// serialized layout of already encoded data - confirm before changing them.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    /// Element type of the encoded array
    dtype: Sz3DType,
    /// Shape of the encoded array, borrowed from the array when encoding
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
}
564
/// Dtypes that SZ3 can compress and decompress
///
/// The dtype is stored in the header of the encoded data and selects the
/// element type that is decoded by [`decompress`].
// The variants are intentionally left without doc comments, which is what the
// `#[expect(missing_docs)]` below relies on.
// NOTE(review): the order of the variants presumably is part of the
// serialized header format - confirm before reordering them.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[expect(missing_docs)]
pub enum Sz3DType {
    // each variant serializes under its short name and additionally accepts
    // the long name as an alias when deserializing
    #[serde(rename = "i32", alias = "int32")]
    I32,
    #[serde(rename = "i64", alias = "int64")]
    I64,
    #[serde(rename = "f32", alias = "float32")]
    F32,
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
578
579impl fmt::Display for Sz3DType {
580    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
581        fmt.write_str(match self {
582            Self::I32 => "i32",
583            Self::I64 => "i64",
584            Self::F32 => "f32",
585            Self::F64 => "f64",
586        })
587    }
588}
589
#[cfg(test)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    /// Compresses `data` with the given `error_bound` and the codec's default
    /// predictor, encoder, and lossless compressor.
    fn compress_with_defaults<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
        data: ArrayBase<S, D>,
        error_bound: &Sz3ErrorBound,
    ) -> Result<Vec<u8>, Sz3CodecError> {
        compress(
            data,
            default_predictor().as_ref(),
            error_bound,
            default_encoder().as_ref(),
            default_lossless_compressor().as_ref(),
        )
    }

    /// An array with a zero-length axis roundtrips with its dtype and shape
    #[test]
    fn zero_length() -> Result<(), Sz3CodecError> {
        let empty = Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])?;

        let encoded = compress_with_defaults(empty, &Sz3ErrorBound::L2Norm { l2: 27.0 })?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);

        Ok(())
    }

    /// Length-one axes are restored in the decoded array
    #[test]
    fn one_dimension() -> Result<(), Sz3CodecError> {
        let data = Array::from_shape_vec([2_usize, 1, 2, 1].as_slice(), vec![1, 2, 3, 4])?;

        let encoded = compress_with_defaults(data.view(), &Sz3ErrorBound::Absolute { abs: 0.1 })?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded, AnyArray::I32(data));

        Ok(())
    }

    /// Very small one-dimensional inputs, including the empty one, roundtrip
    #[test]
    fn small_state() -> Result<(), Sz3CodecError> {
        let inputs = [
            &[][..],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ];

        for input in inputs {
            let encoded = compress_with_defaults(
                ArrayView1::from(input),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let decoded = decompress(&encoded)?;

            let expected = AnyArray::F64(Array1::from_vec(input.to_vec()).into_dyn());
            assert_eq!(decoded, expected);
        }

        Ok(())
    }
}