numcodecs_sperr/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.87.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-sperr
10//! [crates.io]: https://crates.io/crates/numcodecs-sperr
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-sperr
13//! [docs.rs]: https://docs.rs/numcodecs-sperr/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_sperr
17//!
18//! SPERR codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // embedded-io
21
22#[cfg(test)]
23use ::serde_json as _;
24
25use std::borrow::Cow;
26use std::fmt;
27
28use ndarray::{Array, Array1, ArrayBase, Axis, Data, Dimension, IxDyn, ShapeError};
29use num_traits::{Float, identities::Zero};
30use numcodecs::{
31    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
32    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
33};
34use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
35use serde::{Deserialize, Deserializer, Serialize, Serializer};
36use thiserror::Error;
37
38type SperrCodecVersion = StaticCodecVersion<0, 2, 0>;
39
40#[derive(Clone, Serialize, Deserialize, JsonSchema)]
41// serde cannot deny unknown fields because of the flatten
42#[schemars(deny_unknown_fields)]
43/// Codec providing compression using SPERR.
44///
45/// Arrays that are higher-dimensional than 3D are encoded by compressing each
46/// 3D slice with SPERR independently. Specifically, the array's shape is
47/// interpreted as `[.., depth, height, width]`. If you want to compress 3D
48/// slices along three different axes, you can swizzle the array axes
49/// beforehand.
50pub struct SperrCodec {
51    /// SPERR compression mode
52    #[serde(flatten)]
53    pub mode: SperrCompressionMode,
54    /// The codec's encoding format version. Do not provide this parameter explicitly.
55    #[serde(default, rename = "_version")]
56    pub version: SperrCodecVersion,
57}
58
59#[derive(Clone, Serialize, Deserialize, JsonSchema)]
60/// SPERR compression mode
61#[serde(tag = "mode")]
62pub enum SperrCompressionMode {
63    /// Fixed bit-per-pixel rate
64    #[serde(rename = "bpp")]
65    BitsPerPixel {
66        /// positive bits-per-pixel
67        bpp: Positive<f64>,
68    },
69    /// Fixed peak signal-to-noise ratio
70    #[serde(rename = "psnr")]
71    PeakSignalToNoiseRatio {
72        /// positive peak signal-to-noise ratio
73        psnr: Positive<f64>,
74    },
75    /// Fixed point-wise (absolute) error
76    #[serde(rename = "pwe")]
77    PointwiseError {
78        /// positive point-wise (absolute) error
79        pwe: Positive<f64>,
80    },
81    /// Fixed quantisation step
82    #[serde(rename = "q")]
83    QuantisationStep {
84        /// positive quantisation step
85        q: Positive<f64>,
86    },
87}
88
89impl Codec for SperrCodec {
90    type Error = SperrCodecError;
91
92    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
93        match data {
94            AnyCowArray::F32(data) => Ok(AnyArray::U8(
95                Array1::from(compress(data, &self.mode)?).into_dyn(),
96            )),
97            AnyCowArray::F64(data) => Ok(AnyArray::U8(
98                Array1::from(compress(data, &self.mode)?).into_dyn(),
99            )),
100            encoded => Err(SperrCodecError::UnsupportedDtype(encoded.dtype())),
101        }
102    }
103
104    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
105        let AnyCowArray::U8(encoded) = encoded else {
106            return Err(SperrCodecError::EncodedDataNotBytes {
107                dtype: encoded.dtype(),
108            });
109        };
110
111        if !matches!(encoded.shape(), [_]) {
112            return Err(SperrCodecError::EncodedDataNotOneDimensional {
113                shape: encoded.shape().to_vec(),
114            });
115        }
116
117        decompress(&AnyCowArray::U8(encoded).as_bytes())
118    }
119
120    fn decode_into(
121        &self,
122        encoded: AnyArrayView,
123        mut decoded: AnyArrayViewMut,
124    ) -> Result<(), Self::Error> {
125        let decoded_in = self.decode(encoded.cow())?;
126
127        Ok(decoded.assign(&decoded_in)?)
128    }
129}
130
131impl StaticCodec for SperrCodec {
132    const CODEC_ID: &'static str = "sperr.rs";
133
134    type Config<'de> = Self;
135
136    fn from_config(config: Self::Config<'_>) -> Self {
137        config
138    }
139
140    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
141        StaticCodecConfig::from(self)
142    }
143}
144
145#[derive(Debug, Error)]
146/// Errors that may occur when applying the [`SperrCodec`].
147pub enum SperrCodecError {
148    /// [`SperrCodec`] does not support the dtype
149    #[error("Sperr does not support the dtype {0}")]
150    UnsupportedDtype(AnyArrayDType),
151    /// [`SperrCodec`] failed to encode the header
152    #[error("Sperr failed to encode the header")]
153    HeaderEncodeFailed {
154        /// Opaque source error
155        source: SperrHeaderError,
156    },
157    /// [`SperrCodec`] failed to encode the data
158    #[error("Sperr failed to encode the data")]
159    SperrEncodeFailed {
160        /// Opaque source error
161        source: SperrCodingError,
162    },
163    /// [`SperrCodec`] failed to encode a slice
164    #[error("Sperr failed to encode a slice")]
165    SliceEncodeFailed {
166        /// Opaque source error
167        source: SperrSliceError,
168    },
169    /// [`SperrCodec`] can only decode one-dimensional byte arrays but received
170    /// an array of a different dtype
171    #[error(
172        "Sperr can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
173    )]
174    EncodedDataNotBytes {
175        /// The unexpected dtype of the encoded array
176        dtype: AnyArrayDType,
177    },
178    /// [`SperrCodec`] can only decode one-dimensional byte arrays but received
179    /// an array of a different shape
180    #[error(
181        "Sperr can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
182    )]
183    EncodedDataNotOneDimensional {
184        /// The unexpected shape of the encoded array
185        shape: Vec<usize>,
186    },
187    /// [`SperrCodec`] failed to decode the header
188    #[error("Sperr failed to decode the header")]
189    HeaderDecodeFailed {
190        /// Opaque source error
191        source: SperrHeaderError,
192    },
193    /// [`SperrCodec`] failed to decode a slice
194    #[error("Sperr failed to decode a slice")]
195    SliceDecodeFailed {
196        /// Opaque source error
197        source: SperrSliceError,
198    },
199    /// [`SperrCodec`] failed to decode from an excessive number of slices
200    #[error("Sperr failed to decode from an excessive number of slices")]
201    DecodeTooManySlices,
202    /// [`SperrCodec`] failed to decode the data
203    #[error("Sperr failed to decode the data")]
204    SperrDecodeFailed {
205        /// Opaque source error
206        source: SperrCodingError,
207    },
208    /// [`SperrCodec`] decoded into an invalid shape not matching the data size
209    #[error("Sperr decoded into an invalid shape not matching the data size")]
210    DecodeInvalidShape {
211        /// The source of the error
212        source: ShapeError,
213    },
214    /// [`SperrCodec`] cannot decode into the provided array
215    #[error("Sperr cannot decode into the provided array")]
216    MismatchedDecodeIntoArray {
217        /// The source of the error
218        #[from]
219        source: AnyArrayAssignError,
220    },
221}
222
223#[derive(Debug, Error)]
224#[error(transparent)]
225/// Opaque error for when encoding or decoding the header fails
226pub struct SperrHeaderError(postcard::Error);
227
228#[derive(Debug, Error)]
229#[error(transparent)]
230/// Opaque error for when encoding or decoding a slice fails
231pub struct SperrSliceError(postcard::Error);
232
233#[derive(Debug, Error)]
234#[error(transparent)]
235/// Opaque error for when encoding or decoding with SPERR fails
236pub struct SperrCodingError(sperr::Error);
237
238/// Compress the `data` array using SPERR with the provided `mode`.
239///
240/// # Errors
241///
242/// Errors with
243/// - [`SperrCodecError::HeaderEncodeFailed`] if encoding the header failed
244/// - [`SperrCodecError::SperrEncodeFailed`] if encoding with SPERR failed
245/// - [`SperrCodecError::SliceEncodeFailed`] if encoding a slice failed
246#[allow(clippy::missing_panics_doc)]
247pub fn compress<T: SperrElement, S: Data<Elem = T>, D: Dimension>(
248    data: ArrayBase<S, D>,
249    mode: &SperrCompressionMode,
250) -> Result<Vec<u8>, SperrCodecError> {
251    let mut encoded = postcard::to_extend(
252        &CompressionHeader {
253            dtype: T::DTYPE,
254            shape: Cow::Borrowed(data.shape()),
255            version: StaticCodecVersion,
256        },
257        Vec::new(),
258    )
259    .map_err(|err| SperrCodecError::HeaderEncodeFailed {
260        source: SperrHeaderError(err),
261    })?;
262
263    // SPERR cannot handle zero-length dimensions
264    if data.is_empty() {
265        return Ok(encoded);
266    }
267
268    let mut chunk_size = Vec::from(data.shape());
269    let (width, height, depth) = match *chunk_size.as_mut_slice() {
270        [ref mut rest @ .., depth, height, width] => {
271            for r in rest {
272                *r = 1;
273            }
274            (width, height, depth)
275        }
276        [height, width] => (width, height, 1),
277        [width] => (width, 1, 1),
278        [] => (1, 1, 1),
279    };
280
281    for mut slice in data.into_dyn().exact_chunks(chunk_size.as_slice()) {
282        while slice.ndim() < 3 {
283            slice = slice.insert_axis(Axis(0));
284        }
285        #[allow(clippy::unwrap_used)]
286        // slice must now have at least three axes, and all but the last three
287        //  must be of size 1
288        let slice = slice.into_shape_with_order((depth, height, width)).unwrap();
289
290        let encoded_slice = sperr::compress_3d(
291            slice,
292            match mode {
293                SperrCompressionMode::BitsPerPixel { bpp } => {
294                    sperr::CompressionMode::BitsPerPixel { bpp: bpp.0 }
295                }
296                SperrCompressionMode::PeakSignalToNoiseRatio { psnr } => {
297                    sperr::CompressionMode::PeakSignalToNoiseRatio { psnr: psnr.0 }
298                }
299                SperrCompressionMode::PointwiseError { pwe } => {
300                    sperr::CompressionMode::PointwiseError { pwe: pwe.0 }
301                }
302                SperrCompressionMode::QuantisationStep { q } => {
303                    sperr::CompressionMode::QuantisationStep { q: q.0 }
304                }
305            },
306            (256, 256, 256),
307        )
308        .map_err(|err| SperrCodecError::SperrEncodeFailed {
309            source: SperrCodingError(err),
310        })?;
311
312        encoded = postcard::to_extend(encoded_slice.as_slice(), encoded).map_err(|err| {
313            SperrCodecError::SliceEncodeFailed {
314                source: SperrSliceError(err),
315            }
316        })?;
317    }
318
319    Ok(encoded)
320}
321
322/// Decompress the `encoded` data into an array using SPERR.
323///
324/// # Errors
325///
326/// Errors with
327/// - [`SperrCodecError::HeaderDecodeFailed`] if decoding the header failed
328/// - [`SperrCodecError::SliceDecodeFailed`] if decoding a slice failed
329/// - [`SperrCodecError::SperrDecodeFailed`] if decoding with SPERR failed
330/// - [`SperrCodecError::DecodeInvalidShape`] if the encoded data decodes to
331///   an unexpected shape
332/// - [`SperrCodecError::DecodeTooManySlices`] if the encoded data contains
333///   too many slices
334pub fn decompress(encoded: &[u8]) -> Result<AnyArray, SperrCodecError> {
335    fn decompress_typed<T: SperrElement>(
336        mut encoded: &[u8],
337        shape: &[usize],
338    ) -> Result<Array<T, IxDyn>, SperrCodecError> {
339        let mut decoded = Array::<T, _>::zeros(shape);
340
341        let mut chunk_size = Vec::from(shape);
342        let (width, height, depth) = match *chunk_size.as_mut_slice() {
343            [ref mut rest @ .., depth, height, width] => {
344                for r in rest {
345                    *r = 1;
346                }
347                (width, height, depth)
348            }
349            [height, width] => (width, height, 1),
350            [width] => (width, 1, 1),
351            [] => (1, 1, 1),
352        };
353
354        for mut slice in decoded.exact_chunks_mut(chunk_size.as_slice()) {
355            let (encoded_slice, rest) =
356                postcard::take_from_bytes::<Cow<[u8]>>(encoded).map_err(|err| {
357                    SperrCodecError::SliceDecodeFailed {
358                        source: SperrSliceError(err),
359                    }
360                })?;
361            encoded = rest;
362
363            while slice.ndim() < 3 {
364                slice = slice.insert_axis(Axis(0));
365            }
366            #[allow(clippy::unwrap_used)]
367            // slice must now have at least three axes, and all but the last
368            //  three must be of size 1
369            let slice = slice.into_shape_with_order((depth, height, width)).unwrap();
370
371            sperr::decompress_into_3d(&encoded_slice, slice).map_err(|err| {
372                SperrCodecError::SperrDecodeFailed {
373                    source: SperrCodingError(err),
374                }
375            })?;
376        }
377
378        if !encoded.is_empty() {
379            return Err(SperrCodecError::DecodeTooManySlices);
380        }
381
382        Ok(decoded)
383    }
384
385    let (header, encoded) =
386        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
387            SperrCodecError::HeaderDecodeFailed {
388                source: SperrHeaderError(err),
389            }
390        })?;
391
392    // Return empty data for zero-size arrays
393    if header.shape.iter().copied().product::<usize>() == 0 {
394        return match header.dtype {
395            SperrDType::F32 => Ok(AnyArray::F32(Array::zeros(&*header.shape))),
396            SperrDType::F64 => Ok(AnyArray::F64(Array::zeros(&*header.shape))),
397        };
398    }
399
400    match header.dtype {
401        SperrDType::F32 => Ok(AnyArray::F32(decompress_typed(encoded, &header.shape)?)),
402        SperrDType::F64 => Ok(AnyArray::F64(decompress_typed(encoded, &header.shape)?)),
403    }
404}
405
406/// Array element types which can be compressed with SPERR.
407pub trait SperrElement: sperr::Element + Zero {
408    /// The dtype representation of the type
409    const DTYPE: SperrDType;
410}
411
412impl SperrElement for f32 {
413    const DTYPE: SperrDType = SperrDType::F32;
414}
415impl SperrElement for f64 {
416    const DTYPE: SperrDType = SperrDType::F64;
417}
418
419#[expect(clippy::derive_partial_eq_without_eq)] // floats are not Eq
420#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
421/// Positive floating point number
422pub struct Positive<T: Float>(T);
423
424impl<T: Float> Positive<T> {
425    #[must_use]
426    /// Get the positive floating point value
427    pub const fn get(self) -> T {
428        self.0
429    }
430}
431
432impl Serialize for Positive<f64> {
433    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
434        serializer.serialize_f64(self.0)
435    }
436}
437
438impl<'de> Deserialize<'de> for Positive<f64> {
439    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
440        let x = f64::deserialize(deserializer)?;
441
442        if x > 0.0 {
443            Ok(Self(x))
444        } else {
445            Err(serde::de::Error::invalid_value(
446                serde::de::Unexpected::Float(x),
447                &"a positive value",
448            ))
449        }
450    }
451}
452
453impl JsonSchema for Positive<f64> {
454    fn schema_name() -> Cow<'static, str> {
455        Cow::Borrowed("PositiveF64")
456    }
457
458    fn schema_id() -> Cow<'static, str> {
459        Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
460    }
461
462    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
463        json_schema!({
464            "type": "number",
465            "exclusiveMinimum": 0.0
466        })
467    }
468}
469
470#[derive(Serialize, Deserialize)]
471struct CompressionHeader<'a> {
472    dtype: SperrDType,
473    #[serde(borrow)]
474    shape: Cow<'a, [usize]>,
475    version: SperrCodecVersion,
476}
477
478/// Dtypes that SPERR can compress and decompress
479#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
480#[expect(missing_docs)]
481pub enum SperrDType {
482    #[serde(rename = "f32", alias = "float32")]
483    F32,
484    #[serde(rename = "f64", alias = "float64")]
485    F64,
486}
487
488impl fmt::Display for SperrDType {
489    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
490        fmt.write_str(match self {
491            Self::F32 => "f32",
492            Self::F64 => "f64",
493        })
494    }
495}
496
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use ndarray::{Ix0, Ix1, Ix2, Ix3, Ix4};

    use super::*;

    /// Compression mode targeting a peak signal-to-noise ratio of 42
    fn psnr_42() -> SperrCompressionMode {
        SperrCompressionMode::PeakSignalToNoiseRatio {
            psnr: Positive(42.0),
        }
    }

    #[test]
    fn zero_length() {
        let data = Array::<f32, _>::from_shape_vec([3, 0], vec![]).unwrap();

        let encoded = compress(data, &psnr_42()).unwrap();
        let decoded = decompress(&encoded).unwrap();

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[3, 0]);
    }

    #[test]
    fn small_2d() {
        let data = Array::<f32, _>::from_shape_vec([1, 1], vec![42.0]).unwrap();

        let encoded = compress(data, &psnr_42()).unwrap();
        let decoded = decompress(&encoded).unwrap();

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded.shape(), &[1, 1]);
    }

    #[test]
    fn large_3d() {
        let data = Array::<f64, _>::zeros((64, 64, 64));

        let encoded = compress(data, &psnr_42()).unwrap();
        let decoded = decompress(&encoded).unwrap();

        assert_eq!(decoded.dtype(), AnyArrayDType::F64);
        assert_eq!(decoded.len(), 64 * 64 * 64);
        assert_eq!(decoded.shape(), &[64, 64, 64]);
    }

    #[test]
    fn all_modes() {
        let modes = [
            SperrCompressionMode::BitsPerPixel { bpp: Positive(1.0) },
            psnr_42(),
            SperrCompressionMode::PointwiseError { pwe: Positive(0.1) },
            SperrCompressionMode::QuantisationStep { q: Positive(1.5) },
        ];

        for mode in &modes {
            let encoded = compress(Array::<f64, _>::zeros((64, 64, 64)), mode).unwrap();
            let decoded = decompress(&encoded).unwrap();

            assert_eq!(decoded.dtype(), AnyArrayDType::F64);
            assert_eq!(decoded.len(), 64 * 64 * 64);
            assert_eq!(decoded.shape(), &[64, 64, 64]);
        }
    }

    #[test]
    fn many_dimensions() {
        let inputs = [
            Array::<f32, Ix0>::from_shape_vec([], vec![42.0])
                .unwrap()
                .into_dyn(),
            Array::<f32, Ix1>::from_shape_vec([2], vec![1.0, 2.0])
                .unwrap()
                .into_dyn(),
            Array::<f32, Ix2>::from_shape_vec([2, 2], vec![1.0, 2.0, 3.0, 4.0])
                .unwrap()
                .into_dyn(),
            Array::<f32, Ix3>::from_shape_vec(
                [2, 2, 2],
                vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            )
            .unwrap()
            .into_dyn(),
            Array::<f32, Ix4>::from_shape_vec(
                [2, 2, 2, 2],
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
            )
            .unwrap()
            .into_dyn(),
        ];
        // an error bound this tight must reproduce these small integers exactly
        let tight_error = SperrCompressionMode::PointwiseError {
            pwe: Positive(f64::EPSILON),
        };

        for data in inputs {
            let encoded = compress(data.view(), &tight_error).unwrap();
            let decoded = decompress(&encoded).unwrap();

            assert_eq!(decoded, AnyArray::F32(data));
        }
    }
}