numcodecs_sperr/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.85.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-sperr
10//! [crates.io]: https://crates.io/crates/numcodecs-sperr
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-sperr
13//! [docs.rs]: https://docs.rs/numcodecs-sperr/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_sperr
17//!
18//! SPERR codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // embedded-io
21
22#[cfg(test)]
23use ::serde_json as _;
24
25use std::borrow::Cow;
26use std::fmt;
27
28use ndarray::{Array, Array1, ArrayBase, Axis, Data, Dimension, IxDyn, ShapeError};
29use num_traits::{Float, identities::Zero};
30use numcodecs::{
31    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
32    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
33};
34use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
35use serde::{Deserialize, Deserializer, Serialize, Serializer};
36use thiserror::Error;
37
/// Encoding format version of this codec.
///
/// It is stored both in the codec configuration (as `_version`) and in the
/// header that is written at the start of every compressed buffer.
type SperrCodecVersion = StaticCodecVersion<0, 1, 0>;
39
// NOTE(review): the `///` doc comments on this type are presumably picked up
// by the `JsonSchema` derive as schema descriptions -- confirm before
// rewording them, since that would change the generated schema.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
// serde cannot deny unknown fields because of the flatten
#[schemars(deny_unknown_fields)]
/// Codec providing compression using SPERR.
///
/// Arrays that are higher-dimensional than 3D are encoded by compressing each
/// 3D slice with SPERR independently. Specifically, the array's shape is
/// interpreted as `[.., depth, height, width]`. If you want to compress 3D
/// slices along three different axes, you can swizzle the array axes
/// beforehand.
pub struct SperrCodec {
    /// SPERR compression mode
    // flattened, so the mode's `mode` tag and its parameter are stored as
    // siblings of `_version` in the codec configuration
    #[serde(flatten)]
    pub mode: SperrCompressionMode,
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    // falls back to `SperrCodecVersion::default()` when `_version` is absent
    #[serde(default, rename = "_version")]
    pub version: SperrCodecVersion,
}
58
// NOTE(review): the `///` doc comments on this type are presumably picked up
// by the `JsonSchema` derive as schema descriptions -- confirm before
// rewording them, since that would change the generated schema.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
/// SPERR compression mode
// internally tagged: the `mode` key selects the variant (`bpp`, `psnr`, or
// `pwe`) and the variant's single parameter is stored next to it
#[serde(tag = "mode")]
pub enum SperrCompressionMode {
    /// Fixed bit-per-pixel rate
    #[serde(rename = "bpp")]
    BitsPerPixel {
        /// positive bits-per-pixel
        // `Positive` rejects values that are not `> 0.0` on deserialization
        bpp: Positive<f64>,
    },
    /// Fixed peak signal-to-noise ratio
    #[serde(rename = "psnr")]
    PeakSignalToNoiseRatio {
        /// positive peak signal-to-noise ratio
        // `Positive` rejects values that are not `> 0.0` on deserialization
        psnr: Positive<f64>,
    },
    /// Fixed point-wise (absolute) error
    #[serde(rename = "pwe")]
    PointwiseError {
        /// positive point-wise (absolute) error
        // `Positive` rejects values that are not `> 0.0` on deserialization
        pwe: Positive<f64>,
    },
}
82
83impl Codec for SperrCodec {
84    type Error = SperrCodecError;
85
86    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
87        match data {
88            AnyCowArray::F32(data) => Ok(AnyArray::U8(
89                Array1::from(compress(data, &self.mode)?).into_dyn(),
90            )),
91            AnyCowArray::F64(data) => Ok(AnyArray::U8(
92                Array1::from(compress(data, &self.mode)?).into_dyn(),
93            )),
94            encoded => Err(SperrCodecError::UnsupportedDtype(encoded.dtype())),
95        }
96    }
97
98    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
99        let AnyCowArray::U8(encoded) = encoded else {
100            return Err(SperrCodecError::EncodedDataNotBytes {
101                dtype: encoded.dtype(),
102            });
103        };
104
105        if !matches!(encoded.shape(), [_]) {
106            return Err(SperrCodecError::EncodedDataNotOneDimensional {
107                shape: encoded.shape().to_vec(),
108            });
109        }
110
111        decompress(&AnyCowArray::U8(encoded).as_bytes())
112    }
113
114    fn decode_into(
115        &self,
116        encoded: AnyArrayView,
117        mut decoded: AnyArrayViewMut,
118    ) -> Result<(), Self::Error> {
119        let decoded_in = self.decode(encoded.cow())?;
120
121        Ok(decoded.assign(&decoded_in)?)
122    }
123}
124
impl StaticCodec for SperrCodec {
    /// Identifier under which this codec is registered
    const CODEC_ID: &'static str = "sperr.rs";

    /// The codec is its own configuration type
    type Config<'de> = Self;

    /// Builds the codec from its configuration, which is the codec itself.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Returns the configuration of this codec, created from `&self`.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
138
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`SperrCodec`].
///
/// The `source` fields of the header, slice, and coding variants wrap the
/// underlying library error in an opaque newtype.
pub enum SperrCodecError {
    /// [`SperrCodec`] does not support the dtype
    ///
    /// Returned by `encode` for every input that is neither `f32` nor `f64`.
    #[error("Sperr does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`SperrCodec`] failed to encode the header
    ///
    /// Returned by [`compress`] when serializing the header fails.
    #[error("Sperr failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: SperrHeaderError,
    },
    /// [`SperrCodec`] failed to encode the data
    ///
    /// Returned by [`compress`] when SPERR fails to compress a 3D slice.
    #[error("Sperr failed to encode the data")]
    SperrEncodeFailed {
        /// Opaque source error
        source: SperrCodingError,
    },
    /// [`SperrCodec`] failed to encode a slice
    ///
    /// Returned by [`compress`] when appending a compressed slice to the
    /// output buffer fails.
    #[error("Sperr failed to encode a slice")]
    SliceEncodeFailed {
        /// Opaque source error
        source: SperrSliceError,
    },
    /// [`SperrCodec`] can only decode one-dimensional byte arrays but received
    /// an array of a different dtype
    #[error(
        "Sperr can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`SperrCodec`] can only decode one-dimensional byte arrays but received
    /// an array of a different shape
    #[error(
        "Sperr can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`SperrCodec`] failed to decode the header
    ///
    /// Returned by [`decompress`] when the header cannot be read from the
    /// start of the encoded bytes.
    #[error("Sperr failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: SperrHeaderError,
    },
    /// [`SperrCodec`] failed to decode a slice
    ///
    /// Returned by [`decompress`] when the bytes of the next compressed
    /// slice cannot be read from the encoded data.
    #[error("Sperr failed to decode a slice")]
    SliceDecodeFailed {
        /// Opaque source error
        source: SperrSliceError,
    },
    /// [`SperrCodec`] failed to decode from an excessive number of slices
    ///
    /// Returned by [`decompress`] when encoded bytes are left over after
    /// every slice of the output array has been decoded.
    #[error("Sperr failed to decode from an excessive number of slices")]
    DecodeTooManySlices,
    /// [`SperrCodec`] failed to decode the data
    ///
    /// Returned by [`decompress`] when SPERR fails to decompress a slice.
    #[error("Sperr failed to decode the data")]
    SperrDecodeFailed {
        /// Opaque source error
        source: SperrCodingError,
    },
    /// [`SperrCodec`] decoded into an invalid shape not matching the data size
    #[error("Sperr decoded into an invalid shape not matching the data size")]
    DecodeInvalidShape {
        /// The source of the error
        source: ShapeError,
    },
    /// [`SperrCodec`] cannot decode into the provided array
    ///
    /// Created via `From<AnyArrayAssignError>`, i.e. when assigning the
    /// decoded data to the output array fails in `decode_into`.
    #[error("Sperr cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
216
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding the header fails
///
/// Wraps the `postcard` error without exposing it in the public API; the
/// transparent `#[error]` attribute forwards its message.
pub struct SperrHeaderError(postcard::Error);
221
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding a slice fails
///
/// Wraps the `postcard` error without exposing it in the public API; the
/// transparent `#[error]` attribute forwards its message.
pub struct SperrSliceError(postcard::Error);
226
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding with SPERR fails
///
/// Wraps the `sperr` error without exposing it in the public API; the
/// transparent `#[error]` attribute forwards its message.
pub struct SperrCodingError(sperr::Error);
231
232/// Compress the `data` array using SPERR with the provided `mode`.
233///
234/// # Errors
235///
236/// Errors with
237/// - [`SperrCodecError::HeaderEncodeFailed`] if encoding the header failed
238/// - [`SperrCodecError::SperrEncodeFailed`] if encoding with SPERR failed
239/// - [`SperrCodecError::SliceEncodeFailed`] if encoding a slice failed
240#[allow(clippy::missing_panics_doc)]
241pub fn compress<T: SperrElement, S: Data<Elem = T>, D: Dimension>(
242    data: ArrayBase<S, D>,
243    mode: &SperrCompressionMode,
244) -> Result<Vec<u8>, SperrCodecError> {
245    let mut encoded = postcard::to_extend(
246        &CompressionHeader {
247            dtype: T::DTYPE,
248            shape: Cow::Borrowed(data.shape()),
249            version: StaticCodecVersion,
250        },
251        Vec::new(),
252    )
253    .map_err(|err| SperrCodecError::HeaderEncodeFailed {
254        source: SperrHeaderError(err),
255    })?;
256
257    // SPERR cannot handle zero-length dimensions
258    if data.is_empty() {
259        return Ok(encoded);
260    }
261
262    let mut chunk_size = Vec::from(data.shape());
263    let (width, height, depth) = match *chunk_size.as_mut_slice() {
264        [ref mut rest @ .., depth, height, width] => {
265            for r in rest {
266                *r = 1;
267            }
268            (width, height, depth)
269        }
270        [height, width] => (width, height, 1),
271        [width] => (width, 1, 1),
272        [] => (1, 1, 1),
273    };
274
275    for mut slice in data.into_dyn().exact_chunks(chunk_size.as_slice()) {
276        while slice.ndim() < 3 {
277            slice = slice.insert_axis(Axis(0));
278        }
279        #[allow(clippy::unwrap_used)]
280        // slice must now have at least three axes, and all but the last three
281        //  must be of size 1
282        let slice = slice.into_shape_with_order((depth, height, width)).unwrap();
283
284        let encoded_slice = sperr::compress_3d(
285            slice,
286            match mode {
287                SperrCompressionMode::BitsPerPixel { bpp } => {
288                    sperr::CompressionMode::BitsPerPixel { bpp: bpp.0 }
289                }
290                SperrCompressionMode::PeakSignalToNoiseRatio { psnr } => {
291                    sperr::CompressionMode::PeakSignalToNoiseRatio { psnr: psnr.0 }
292                }
293                SperrCompressionMode::PointwiseError { pwe } => {
294                    sperr::CompressionMode::PointwiseError { pwe: pwe.0 }
295                }
296            },
297            (256, 256, 256),
298        )
299        .map_err(|err| SperrCodecError::SperrEncodeFailed {
300            source: SperrCodingError(err),
301        })?;
302
303        encoded = postcard::to_extend(encoded_slice.as_slice(), encoded).map_err(|err| {
304            SperrCodecError::SliceEncodeFailed {
305                source: SperrSliceError(err),
306            }
307        })?;
308    }
309
310    Ok(encoded)
311}
312
313/// Decompress the `encoded` data into an array using SPERR.
314///
315/// # Errors
316///
317/// Errors with
318/// - [`SperrCodecError::HeaderDecodeFailed`] if decoding the header failed
319/// - [`SperrCodecError::SliceDecodeFailed`] if decoding a slice failed
320/// - [`SperrCodecError::SperrDecodeFailed`] if decoding with SPERR failed
321/// - [`SperrCodecError::DecodeInvalidShape`] if the encoded data decodes to
322///   an unexpected shape
323/// - [`SperrCodecError::DecodeTooManySlices`] if the encoded data contains
324///   too many slices
325pub fn decompress(encoded: &[u8]) -> Result<AnyArray, SperrCodecError> {
326    fn decompress_typed<T: SperrElement>(
327        mut encoded: &[u8],
328        shape: &[usize],
329    ) -> Result<Array<T, IxDyn>, SperrCodecError> {
330        let mut decoded = Array::<T, _>::zeros(shape);
331
332        let mut chunk_size = Vec::from(shape);
333        let (width, height, depth) = match *chunk_size.as_mut_slice() {
334            [ref mut rest @ .., depth, height, width] => {
335                for r in rest {
336                    *r = 1;
337                }
338                (width, height, depth)
339            }
340            [height, width] => (width, height, 1),
341            [width] => (width, 1, 1),
342            [] => (1, 1, 1),
343        };
344
345        for mut slice in decoded.exact_chunks_mut(chunk_size.as_slice()) {
346            let (encoded_slice, rest) =
347                postcard::take_from_bytes::<Cow<[u8]>>(encoded).map_err(|err| {
348                    SperrCodecError::SliceDecodeFailed {
349                        source: SperrSliceError(err),
350                    }
351                })?;
352            encoded = rest;
353
354            while slice.ndim() < 3 {
355                slice = slice.insert_axis(Axis(0));
356            }
357            #[allow(clippy::unwrap_used)]
358            // slice must now have at least three axes, and all but the last
359            //  three must be of size 1
360            let slice = slice.into_shape_with_order((depth, height, width)).unwrap();
361
362            sperr::decompress_into_3d(&encoded_slice, slice).map_err(|err| {
363                SperrCodecError::SperrDecodeFailed {
364                    source: SperrCodingError(err),
365                }
366            })?;
367        }
368
369        if !encoded.is_empty() {
370            return Err(SperrCodecError::DecodeTooManySlices);
371        }
372
373        Ok(decoded)
374    }
375
376    let (header, encoded) =
377        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
378            SperrCodecError::HeaderDecodeFailed {
379                source: SperrHeaderError(err),
380            }
381        })?;
382
383    // Return empty data for zero-size arrays
384    if header.shape.iter().copied().product::<usize>() == 0 {
385        return match header.dtype {
386            SperrDType::F32 => Ok(AnyArray::F32(Array::zeros(&*header.shape))),
387            SperrDType::F64 => Ok(AnyArray::F64(Array::zeros(&*header.shape))),
388        };
389    }
390
391    match header.dtype {
392        SperrDType::F32 => Ok(AnyArray::F32(decompress_typed(encoded, &header.shape)?)),
393        SperrDType::F64 => Ok(AnyArray::F64(decompress_typed(encoded, &header.shape)?)),
394    }
395}
396
/// Array element types which can be compressed with SPERR.
///
/// Implemented in this file for `f32` and `f64`. The [`Zero`] bound is used
/// by [`decompress`] to allocate the output array with `Array::zeros`.
pub trait SperrElement: sperr::Element + Zero {
    /// The dtype representation of the type
    ///
    /// [`compress`] writes this value into the header of the encoded data.
    const DTYPE: SperrDType;
}
402
// `f32` arrays are tagged as `SperrDType::F32` in the compression header
impl SperrElement for f32 {
    const DTYPE: SperrDType = SperrDType::F32;
}
// `f64` arrays are tagged as `SperrDType::F64` in the compression header
impl SperrElement for f64 {
    const DTYPE: SperrDType = SperrDType::F64;
}
409
410#[expect(clippy::derive_partial_eq_without_eq)] // floats are not Eq
411#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
412/// Positive floating point number
413pub struct Positive<T: Float>(T);
414
impl<T: Float> Positive<T> {
    #[must_use]
    /// Get the positive floating point value
    ///
    /// Takes the wrapper by value, which is a cheap copy since
    /// [`Positive`] is `Copy`, and returns the wrapped value unchanged.
    pub const fn get(self) -> T {
        self.0
    }
}
422
423impl Serialize for Positive<f64> {
424    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
425        serializer.serialize_f64(self.0)
426    }
427}
428
429impl<'de> Deserialize<'de> for Positive<f64> {
430    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
431        let x = f64::deserialize(deserializer)?;
432
433        if x > 0.0 {
434            Ok(Self(x))
435        } else {
436            Err(serde::de::Error::invalid_value(
437                serde::de::Unexpected::Float(x),
438                &"a positive value",
439            ))
440        }
441    }
442}
443
444impl JsonSchema for Positive<f64> {
445    fn schema_name() -> Cow<'static, str> {
446        Cow::Borrowed("PositiveF64")
447    }
448
449    fn schema_id() -> Cow<'static, str> {
450        Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
451    }
452
453    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
454        json_schema!({
455            "type": "number",
456            "exclusiveMinimum": 0.0
457        })
458    }
459}
460
/// Header that [`compress`] writes at the start of the encoded data and that
/// [`decompress`] reads back before decoding any slice.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    /// Element type of the compressed array
    dtype: SperrDType,
    /// Shape of the compressed array; [`compress`] borrows it from the
    /// input array instead of copying it
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    /// Encoding format version of the codec that produced the data
    version: SperrCodecVersion,
}
468
/// Dtypes that SPERR can compress and decompress
///
/// The dtype is stored in the header of the encoded data so that
/// [`decompress`] can pick the element type of the output array.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
// the variants are intentionally left without doc comments so that the lint
// expectation below stays fulfilled
#[expect(missing_docs)]
pub enum SperrDType {
    // serialized as "f32"; "float32" is also accepted when deserializing
    #[serde(rename = "f32", alias = "float32")]
    F32,
    // serialized as "f64"; "float64" is also accepted when deserializing
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
478
479impl fmt::Display for SperrDType {
480    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
481        fmt.write_str(match self {
482            Self::F32 => "f32",
483            Self::F64 => "f64",
484        })
485    }
486}
487
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use ndarray::{Ix0, Ix1, Ix2, Ix3, Ix4};

    use super::*;

    /// PSNR mode shared by the tests that only check dtype and shape.
    fn psnr_42() -> SperrCompressionMode {
        SperrCompressionMode::PeakSignalToNoiseRatio {
            psnr: Positive(42.0),
        }
    }

    /// Compresses `data` with `mode` and decompresses the result again.
    fn roundtrip<T: SperrElement, S: Data<Elem = T>, D: Dimension>(
        data: ArrayBase<S, D>,
        mode: &SperrCompressionMode,
    ) -> AnyArray {
        let encoded = compress(data, mode).unwrap();

        decompress(&encoded).unwrap()
    }

    /// Arrays with a zero-length axis survive the roundtrip.
    #[test]
    fn zero_length() {
        let empty = Array::<f32, _>::from_shape_vec([3, 0], vec![]).unwrap();

        let decoded = roundtrip(empty, &psnr_42());

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[3, 0]);
    }

    /// A single-element 2D array keeps its dtype and shape.
    #[test]
    fn small_2d() {
        let single = Array::<f32, _>::from_shape_vec([1, 1], vec![42.0]).unwrap();

        let decoded = roundtrip(single, &psnr_42());

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded.shape(), &[1, 1]);
    }

    /// A 64x64x64 volume keeps its dtype and shape.
    #[test]
    fn large_3d() {
        let volume = Array::<f64, _>::zeros((64, 64, 64));

        let decoded = roundtrip(volume, &psnr_42());

        assert_eq!(decoded.dtype(), AnyArrayDType::F64);
        assert_eq!(decoded.len(), 64 * 64 * 64);
        assert_eq!(decoded.shape(), &[64, 64, 64]);
    }

    /// Every compression mode can roundtrip a 64x64x64 volume.
    #[test]
    fn all_modes() {
        let modes = [
            SperrCompressionMode::BitsPerPixel { bpp: Positive(1.0) },
            psnr_42(),
            SperrCompressionMode::PointwiseError { pwe: Positive(0.1) },
        ];

        for mode in modes {
            let decoded = roundtrip(Array::<f64, _>::zeros((64, 64, 64)), &mode);

            assert_eq!(decoded.dtype(), AnyArrayDType::F64);
            assert_eq!(decoded.len(), 64 * 64 * 64);
            assert_eq!(decoded.shape(), &[64, 64, 64]);
        }
    }

    /// Arrays with zero to four axes roundtrip exactly with a tiny
    /// point-wise error bound.
    #[test]
    fn many_dimensions() {
        let mode = SperrCompressionMode::PointwiseError {
            pwe: Positive(f64::EPSILON),
        };

        let arrays = [
            Array::<f32, Ix0>::from_shape_vec([], vec![42.0])
                .unwrap()
                .into_dyn(),
            Array::<f32, Ix1>::from_shape_vec([2], vec![1.0, 2.0])
                .unwrap()
                .into_dyn(),
            Array::<f32, Ix2>::from_shape_vec([2, 2], vec![1.0, 2.0, 3.0, 4.0])
                .unwrap()
                .into_dyn(),
            Array::<f32, Ix3>::from_shape_vec(
                [2, 2, 2],
                vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            )
            .unwrap()
            .into_dyn(),
            Array::<f32, Ix4>::from_shape_vec(
                [2, 2, 2, 2],
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
            )
            .unwrap()
            .into_dyn(),
        ];

        for data in arrays {
            let decoded = roundtrip(data.view(), &mode);

            assert_eq!(decoded, AnyArray::F32(data));
        }
    }
}