numcodecs_zstd/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-zstd
10//! [crates.io]: https://crates.io/crates/numcodecs-zstd
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-zstd
13//! [docs.rs]: https://docs.rs/numcodecs-zstd/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_zstd
17//!
18//! Zstandard codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // embedded-io
21
22use std::{borrow::Cow, io};
23
24use ndarray::Array1;
25use numcodecs::{
26    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Deserializer, Serialize, Serializer};
31use thiserror::Error;
32// Only used to explicitly enable the `no_wasm_shim` feature in zstd/zstd-sys
33use zstd_sys as _;
34
/// Format version marker (`0, 1, 0`) of this codec.
///
/// It is stored in the codec configuration (as `_version`) and written into
/// the header that prefixes every compressed payload.
type ZstdCodecVersion = StaticCodecVersion<0, 1, 0>;
36
37#[derive(Clone, Serialize, Deserialize, JsonSchema)]
38#[serde(deny_unknown_fields)]
39/// Codec providing compression using Zstandard
40pub struct ZstdCodec {
41    /// Zstandard compression level.
42    ///
43    /// The level ranges from small (fastest) to large (best compression).
44    pub level: ZstdLevel,
45    /// The codec's encoding format version. Do not provide this parameter explicitly.
46    #[serde(default, rename = "_version")]
47    pub version: ZstdCodecVersion,
48}
49
50impl Codec for ZstdCodec {
51    type Error = ZstdCodecError;
52
53    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
54        compress(data.view(), self.level)
55            .map(|bytes| AnyArray::U8(Array1::from_vec(bytes).into_dyn()))
56    }
57
58    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
59        let AnyCowArray::U8(encoded) = encoded else {
60            return Err(ZstdCodecError::EncodedDataNotBytes {
61                dtype: encoded.dtype(),
62            });
63        };
64
65        if !matches!(encoded.shape(), [_]) {
66            return Err(ZstdCodecError::EncodedDataNotOneDimensional {
67                shape: encoded.shape().to_vec(),
68            });
69        }
70
71        decompress(&AnyCowArray::U8(encoded).as_bytes())
72    }
73
74    fn decode_into(
75        &self,
76        encoded: AnyArrayView,
77        decoded: AnyArrayViewMut,
78    ) -> Result<(), Self::Error> {
79        let AnyArrayView::U8(encoded) = encoded else {
80            return Err(ZstdCodecError::EncodedDataNotBytes {
81                dtype: encoded.dtype(),
82            });
83        };
84
85        if !matches!(encoded.shape(), [_]) {
86            return Err(ZstdCodecError::EncodedDataNotOneDimensional {
87                shape: encoded.shape().to_vec(),
88            });
89        }
90
91        decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
92    }
93}
94
95impl StaticCodec for ZstdCodec {
96    const CODEC_ID: &'static str = "zstd.rs";
97
98    type Config<'de> = Self;
99
100    fn from_config(config: Self::Config<'_>) -> Self {
101        config
102    }
103
104    fn get_config(&self) -> StaticCodecConfig<Self> {
105        StaticCodecConfig::from(self)
106    }
107}
108
109#[derive(Clone, Copy, JsonSchema)]
110#[schemars(transparent)]
111/// Zstandard compression level.
112///
113/// The level ranges from small (fastest) to large (best compression).
114pub struct ZstdLevel {
115    level: zstd::zstd_safe::CompressionLevel,
116}
117
118impl Serialize for ZstdLevel {
119    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
120        self.level.serialize(serializer)
121    }
122}
123
124impl<'de> Deserialize<'de> for ZstdLevel {
125    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
126        let level = Deserialize::deserialize(deserializer)?;
127
128        let level_range = zstd::compression_level_range();
129
130        if !level_range.contains(&level) {
131            return Err(serde::de::Error::custom(format!(
132                "level {level} is not in {}..={}",
133                level_range.start(),
134                level_range.end()
135            )));
136        }
137
138        Ok(Self { level })
139    }
140}
141
142#[derive(Debug, Error)]
143/// Errors that may occur when applying the [`ZstdCodec`].
144pub enum ZstdCodecError {
145    /// [`ZstdCodec`] failed to encode the header
146    #[error("Zstd failed to encode the header")]
147    HeaderEncodeFailed {
148        /// Opaque source error
149        source: ZstdHeaderError,
150    },
151    /// [`ZstdCodec`] failed to encode the encoded data
152    #[error("Zstd failed to decode the encoded data")]
153    ZstdEncodeFailed {
154        /// Opaque source error
155        source: ZstdCodingError,
156    },
157    /// [`ZstdCodec`] can only decode one-dimensional byte arrays but received
158    /// an array of a different dtype
159    #[error(
160        "Zstd can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
161    )]
162    EncodedDataNotBytes {
163        /// The unexpected dtype of the encoded array
164        dtype: AnyArrayDType,
165    },
166    /// [`ZstdCodec`] can only decode one-dimensional byte arrays but received
167    /// an array of a different shape
168    #[error("Zstd can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
169    EncodedDataNotOneDimensional {
170        /// The unexpected shape of the encoded array
171        shape: Vec<usize>,
172    },
173    /// [`ZstdCodec`] failed to encode the header
174    #[error("Zstd failed to decode the header")]
175    HeaderDecodeFailed {
176        /// Opaque source error
177        source: ZstdHeaderError,
178    },
179    /// [`ZstdCodec`] decode consumed less encoded data, which contains trailing
180    /// junk
181    #[error("Zstd decode consumed less encoded data, which contains trailing junk")]
182    DecodeExcessiveEncodedData,
183    /// [`ZstdCodec`] produced less decoded data than expected
184    #[error("Zstd produced less decoded data than expected")]
185    DecodeProducedLess,
186    /// [`ZstdCodec`] failed to decode the encoded data
187    #[error("Zstd failed to decode the encoded data")]
188    ZstdDecodeFailed {
189        /// Opaque source error
190        source: ZstdCodingError,
191    },
192    /// [`ZstdCodec`] cannot decode into the provided array
193    #[error("Zstd cannot decode into the provided array")]
194    MismatchedDecodeIntoArray {
195        /// The source of the error
196        #[from]
197        source: AnyArrayAssignError,
198    },
199}
200
201#[derive(Debug, Error)]
202#[error(transparent)]
203/// Opaque error for when encoding or decoding the header fails
204pub struct ZstdHeaderError(postcard::Error);
205
206#[derive(Debug, Error)]
207#[error(transparent)]
208/// Opaque error for when encoding or decoding with Zstandard fails
209pub struct ZstdCodingError(io::Error);
210
211#[expect(clippy::needless_pass_by_value)]
212/// Compress the `array` using Zstandard with the provided `level`.
213///
214/// # Errors
215///
216/// Errors with
217/// - [`ZstdCodecError::HeaderEncodeFailed`] if encoding the header to the
218///   output bytevec failed
219/// - [`ZstdCodecError::ZstdEncodeFailed`] if an opaque encoding error occurred
220///
221/// # Panics
222///
223/// Panics if the infallible encoding with Zstd fails.
224pub fn compress(array: AnyArrayView, level: ZstdLevel) -> Result<Vec<u8>, ZstdCodecError> {
225    let mut encoded = postcard::to_extend(
226        &CompressionHeader {
227            dtype: array.dtype(),
228            shape: Cow::Borrowed(array.shape()),
229            version: StaticCodecVersion,
230        },
231        Vec::new(),
232    )
233    .map_err(|err| ZstdCodecError::HeaderEncodeFailed {
234        source: ZstdHeaderError(err),
235    })?;
236
237    zstd::stream::copy_encode(&*array.as_bytes(), &mut encoded, level.level).map_err(|err| {
238        ZstdCodecError::ZstdEncodeFailed {
239            source: ZstdCodingError(err),
240        }
241    })?;
242
243    Ok(encoded)
244}
245
246/// Decompress the `encoded` data into an array using Zstandard.
247///
248/// # Errors
249///
250/// Errors with
251/// - [`ZstdCodecError::HeaderDecodeFailed`] if decoding the header failed
252/// - [`ZstdCodecError::DecodeExcessiveEncodedData`] if the encoded data
253///   contains excessive trailing data junk
254/// - [`ZstdCodecError::DecodeProducedLess`] if decoding produced less data than
255///   expected
256/// - [`ZstdCodecError::ZstdDecodeFailed`] if an opaque decoding error occurred
257pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZstdCodecError> {
258    let (header, encoded) =
259        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
260            ZstdCodecError::HeaderDecodeFailed {
261                source: ZstdHeaderError(err),
262            }
263        })?;
264
265    let (decoded, result) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |decoded| {
266        decompress_into_bytes(encoded, decoded)
267    });
268
269    result.map(|()| decoded)
270}
271
272/// Decompress the `encoded` data into a `decoded` array using Zstandard.
273///
274/// # Errors
275///
276/// Errors with
277/// - [`ZstdCodecError::HeaderDecodeFailed`] if decoding the header failed
278/// - [`ZstdCodecError::MismatchedDecodeIntoArray`] if the `decoded` array is of
279///   the wrong dtype or shape
280/// - [`ZstdCodecError::HeaderDecodeFailed`] if decoding the header failed
281/// - [`ZstdCodecError::DecodeExcessiveEncodedData`] if the encoded data
282///   contains excessive trailing data junk
283/// - [`ZstdCodecError::DecodeProducedLess`] if decoding produced less data than
284///   expected
285/// - [`ZstdCodecError::ZstdDecodeFailed`] if an opaque decoding error occurred
286pub fn decompress_into(encoded: &[u8], mut decoded: AnyArrayViewMut) -> Result<(), ZstdCodecError> {
287    let (header, encoded) =
288        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
289            ZstdCodecError::HeaderDecodeFailed {
290                source: ZstdHeaderError(err),
291            }
292        })?;
293
294    if header.dtype != decoded.dtype() {
295        return Err(ZstdCodecError::MismatchedDecodeIntoArray {
296            source: AnyArrayAssignError::DTypeMismatch {
297                src: header.dtype,
298                dst: decoded.dtype(),
299            },
300        });
301    }
302
303    if header.shape != decoded.shape() {
304        return Err(ZstdCodecError::MismatchedDecodeIntoArray {
305            source: AnyArrayAssignError::ShapeMismatch {
306                src: header.shape.into_owned(),
307                dst: decoded.shape().to_vec(),
308            },
309        });
310    }
311
312    decoded.with_bytes_mut(|decoded| decompress_into_bytes(encoded, decoded))
313}
314
315fn decompress_into_bytes(mut encoded: &[u8], mut decoded: &mut [u8]) -> Result<(), ZstdCodecError> {
316    // we want to check encoded and decoded for full consumption after the decoding
317    zstd::stream::copy_decode(&mut encoded, &mut decoded).map_err(|err| {
318        ZstdCodecError::ZstdDecodeFailed {
319            source: ZstdCodingError(err),
320        }
321    })?;
322
323    if !encoded.is_empty() {
324        return Err(ZstdCodecError::DecodeExcessiveEncodedData);
325    }
326
327    if !decoded.is_empty() {
328        return Err(ZstdCodecError::DecodeProducedLess);
329    }
330
331    Ok(())
332}
333
334#[derive(Serialize, Deserialize)]
335struct CompressionHeader<'a> {
336    dtype: AnyArrayDType,
337    #[serde(borrow)]
338    shape: Cow<'a, [usize]>,
339    version: ZstdCodecVersion,
340}