numcodecs_zstd/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.85.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-zstd
10//! [crates.io]: https://crates.io/crates/numcodecs-zstd
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-zstd
13//! [docs.rs]: https://docs.rs/numcodecs-zstd/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_zstd
17//!
18//! Zstandard codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // embedded-io
21
22use std::{borrow::Cow, io};
23
24use ndarray::Array1;
25use numcodecs::{
26    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Deserializer, Serialize, Serializer};
31use thiserror::Error;
32// Only used to explicitly enable the `no_wasm_shim` feature in zstd/zstd-sys
33use zstd_sys as _;
34
/// Encoding format version of the [`ZstdCodec`], presumably major/minor/patch
/// `0.1.0` (TODO confirm the parameter meaning against `StaticCodecVersion`).
///
/// It is stored both in the codec config (`_version`) and in the header that
/// [`compress`] prepends to every encoded buffer.
type ZstdCodecVersion = StaticCodecVersion<0, 1, 0>;
36
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
/// Codec providing compression using Zstandard
///
/// Encoding produces a one-dimensional `u8` array. Decoding only accepts such
/// one-dimensional `u8` arrays and rejects any other dtype or shape.
///
/// Unknown fields in a deserialized configuration are rejected.
pub struct ZstdCodec {
    /// Zstandard compression level.
    ///
    /// The level ranges from small (fastest) to large (best compression).
    pub level: ZstdLevel,
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    // Serialized under the `_version` key; defaulted when absent from a config.
    #[serde(default, rename = "_version")]
    pub version: ZstdCodecVersion,
}
49
50impl Codec for ZstdCodec {
51    type Error = ZstdCodecError;
52
53    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
54        compress(data.view(), self.level)
55            .map(|bytes| AnyArray::U8(Array1::from_vec(bytes).into_dyn()))
56    }
57
58    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
59        let AnyCowArray::U8(encoded) = encoded else {
60            return Err(ZstdCodecError::EncodedDataNotBytes {
61                dtype: encoded.dtype(),
62            });
63        };
64
65        if !matches!(encoded.shape(), [_]) {
66            return Err(ZstdCodecError::EncodedDataNotOneDimensional {
67                shape: encoded.shape().to_vec(),
68            });
69        }
70
71        decompress(&AnyCowArray::U8(encoded).as_bytes())
72    }
73
74    fn decode_into(
75        &self,
76        encoded: AnyArrayView,
77        decoded: AnyArrayViewMut,
78    ) -> Result<(), Self::Error> {
79        let AnyArrayView::U8(encoded) = encoded else {
80            return Err(ZstdCodecError::EncodedDataNotBytes {
81                dtype: encoded.dtype(),
82            });
83        };
84
85        if !matches!(encoded.shape(), [_]) {
86            return Err(ZstdCodecError::EncodedDataNotOneDimensional {
87                shape: encoded.shape().to_vec(),
88            });
89        }
90
91        decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
92    }
93}
94
impl StaticCodec for ZstdCodec {
    /// Codec identifier, presumably used to select this codec from a config.
    const CODEC_ID: &'static str = "zstd.rs";

    // The codec is its own configuration: `ZstdCodec` already derives
    // `Serialize`, `Deserialize` and `JsonSchema`.
    type Config<'de> = Self;

    /// Builds the codec from its configuration, which is the codec itself.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Returns this codec's configuration, constructed from a borrow of `self`.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
108
109#[derive(Clone, Copy, JsonSchema)]
110#[schemars(transparent)]
111/// Zstandard compression level.
112///
113/// The level ranges from small (fastest) to large (best compression).
114pub struct ZstdLevel {
115    level: zstd::zstd_safe::CompressionLevel,
116}
117
118impl Serialize for ZstdLevel {
119    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
120        self.level.serialize(serializer)
121    }
122}
123
124impl<'de> Deserialize<'de> for ZstdLevel {
125    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
126        let level = Deserialize::deserialize(deserializer)?;
127
128        let level_range = zstd::compression_level_range();
129
130        if !level_range.contains(&level) {
131            return Err(serde::de::Error::custom(format!(
132                "level {level} is not in {}..={}",
133                level_range.start(),
134                level_range.end()
135            )));
136        }
137
138        Ok(Self { level })
139    }
140}
141
142#[derive(Debug, Error)]
143/// Errors that may occur when applying the [`ZstdCodec`].
144pub enum ZstdCodecError {
145    /// [`ZstdCodec`] failed to encode the header
146    #[error("Zstd failed to encode the header")]
147    HeaderEncodeFailed {
148        /// Opaque source error
149        source: ZstdHeaderError,
150    },
151    /// [`ZstdCodec`] failed to encode the encoded data
152    #[error("Zstd failed to decode the encoded data")]
153    ZstdEncodeFailed {
154        /// Opaque source error
155        source: ZstdCodingError,
156    },
157    /// [`ZstdCodec`] can only decode one-dimensional byte arrays but received
158    /// an array of a different dtype
159    #[error(
160        "Zstd can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
161    )]
162    EncodedDataNotBytes {
163        /// The unexpected dtype of the encoded array
164        dtype: AnyArrayDType,
165    },
166    /// [`ZstdCodec`] can only decode one-dimensional byte arrays but received
167    /// an array of a different shape
168    #[error(
169        "Zstd can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
170    )]
171    EncodedDataNotOneDimensional {
172        /// The unexpected shape of the encoded array
173        shape: Vec<usize>,
174    },
175    /// [`ZstdCodec`] failed to encode the header
176    #[error("Zstd failed to decode the header")]
177    HeaderDecodeFailed {
178        /// Opaque source error
179        source: ZstdHeaderError,
180    },
181    /// [`ZstdCodec`] decode consumed less encoded data, which contains trailing
182    /// junk
183    #[error("Zstd decode consumed less encoded data, which contains trailing junk")]
184    DecodeExcessiveEncodedData,
185    /// [`ZstdCodec`] produced less decoded data than expected
186    #[error("Zstd produced less decoded data than expected")]
187    DecodeProducedLess,
188    /// [`ZstdCodec`] failed to decode the encoded data
189    #[error("Zstd failed to decode the encoded data")]
190    ZstdDecodeFailed {
191        /// Opaque source error
192        source: ZstdCodingError,
193    },
194    /// [`ZstdCodec`] cannot decode into the provided array
195    #[error("Zstd cannot decode into the provided array")]
196    MismatchedDecodeIntoArray {
197        /// The source of the error
198        #[from]
199        source: AnyArrayAssignError,
200    },
201}
202
#[derive(Debug, Error)]
// `transparent` forwards `Display` and `source()` to the wrapped postcard error.
#[error(transparent)]
/// Opaque error for when encoding or decoding the header fails
///
/// Wraps the postcard error without exposing it in the public API.
pub struct ZstdHeaderError(postcard::Error);
207
#[derive(Debug, Error)]
// `transparent` forwards `Display` and `source()` to the wrapped I/O error.
#[error(transparent)]
/// Opaque error for when encoding or decoding with Zstandard fails
///
/// Wraps the [`io::Error`] reported by the Zstandard stream helpers.
pub struct ZstdCodingError(io::Error);
212
213#[expect(clippy::needless_pass_by_value)]
214/// Compress the `array` using Zstandard with the provided `level`.
215///
216/// # Errors
217///
218/// Errors with
219/// - [`ZstdCodecError::HeaderEncodeFailed`] if encoding the header to the
220///   output bytevec failed
221/// - [`ZstdCodecError::ZstdEncodeFailed`] if an opaque encoding error occurred
222///
223/// # Panics
224///
225/// Panics if the infallible encoding with Zstd fails.
226pub fn compress(array: AnyArrayView, level: ZstdLevel) -> Result<Vec<u8>, ZstdCodecError> {
227    let mut encoded = postcard::to_extend(
228        &CompressionHeader {
229            dtype: array.dtype(),
230            shape: Cow::Borrowed(array.shape()),
231            version: StaticCodecVersion,
232        },
233        Vec::new(),
234    )
235    .map_err(|err| ZstdCodecError::HeaderEncodeFailed {
236        source: ZstdHeaderError(err),
237    })?;
238
239    zstd::stream::copy_encode(&*array.as_bytes(), &mut encoded, level.level).map_err(|err| {
240        ZstdCodecError::ZstdEncodeFailed {
241            source: ZstdCodingError(err),
242        }
243    })?;
244
245    Ok(encoded)
246}
247
248/// Decompress the `encoded` data into an array using Zstandard.
249///
250/// # Errors
251///
252/// Errors with
253/// - [`ZstdCodecError::HeaderDecodeFailed`] if decoding the header failed
254/// - [`ZstdCodecError::DecodeExcessiveEncodedData`] if the encoded data
255///   contains excessive trailing data junk
256/// - [`ZstdCodecError::DecodeProducedLess`] if decoding produced less data than
257///   expected
258/// - [`ZstdCodecError::ZstdDecodeFailed`] if an opaque decoding error occurred
259pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZstdCodecError> {
260    let (header, encoded) =
261        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
262            ZstdCodecError::HeaderDecodeFailed {
263                source: ZstdHeaderError(err),
264            }
265        })?;
266
267    let (decoded, result) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |decoded| {
268        decompress_into_bytes(encoded, decoded)
269    });
270
271    result.map(|()| decoded)
272}
273
274/// Decompress the `encoded` data into a `decoded` array using Zstandard.
275///
276/// # Errors
277///
278/// Errors with
279/// - [`ZstdCodecError::HeaderDecodeFailed`] if decoding the header failed
280/// - [`ZstdCodecError::MismatchedDecodeIntoArray`] if the `decoded` array is of
281///   the wrong dtype or shape
282/// - [`ZstdCodecError::HeaderDecodeFailed`] if decoding the header failed
283/// - [`ZstdCodecError::DecodeExcessiveEncodedData`] if the encoded data
284///   contains excessive trailing data junk
285/// - [`ZstdCodecError::DecodeProducedLess`] if decoding produced less data than
286///   expected
287/// - [`ZstdCodecError::ZstdDecodeFailed`] if an opaque decoding error occurred
288pub fn decompress_into(encoded: &[u8], mut decoded: AnyArrayViewMut) -> Result<(), ZstdCodecError> {
289    let (header, encoded) =
290        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
291            ZstdCodecError::HeaderDecodeFailed {
292                source: ZstdHeaderError(err),
293            }
294        })?;
295
296    if header.dtype != decoded.dtype() {
297        return Err(ZstdCodecError::MismatchedDecodeIntoArray {
298            source: AnyArrayAssignError::DTypeMismatch {
299                src: header.dtype,
300                dst: decoded.dtype(),
301            },
302        });
303    }
304
305    if header.shape != decoded.shape() {
306        return Err(ZstdCodecError::MismatchedDecodeIntoArray {
307            source: AnyArrayAssignError::ShapeMismatch {
308                src: header.shape.into_owned(),
309                dst: decoded.shape().to_vec(),
310            },
311        });
312    }
313
314    decoded.with_bytes_mut(|decoded| decompress_into_bytes(encoded, decoded))
315}
316
317fn decompress_into_bytes(mut encoded: &[u8], mut decoded: &mut [u8]) -> Result<(), ZstdCodecError> {
318    // we want to check encoded and decoded for full consumption after the decoding
319    zstd::stream::copy_decode(&mut encoded, &mut decoded).map_err(|err| {
320        ZstdCodecError::ZstdDecodeFailed {
321            source: ZstdCodingError(err),
322        }
323    })?;
324
325    if !encoded.is_empty() {
326        return Err(ZstdCodecError::DecodeExcessiveEncodedData);
327    }
328
329    if !decoded.is_empty() {
330        return Err(ZstdCodecError::DecodeProducedLess);
331    }
332
333    Ok(())
334}
335
/// Header that [`compress`] prepends to the Zstandard payload (via postcard),
/// and that [`decompress`] / [`decompress_into`] read back first.
// NOTE(review): postcard is not self-describing, so the field order below is
// presumably part of the encoded format — do not reorder without a version bump.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    /// Dtype of the original array.
    dtype: AnyArrayDType,
    /// Shape of the original array; borrowed from the array when encoding
    /// (`Cow::Borrowed` in `compress`).
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    /// Codec format version the data was encoded with.
    version: ZstdCodecVersion,
}