numcodecs_zstd/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-zstd
10//! [crates.io]: https://crates.io/crates/numcodecs-zstd
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-zstd
13//! [docs.rs]: https://docs.rs/numcodecs-zstd/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_zstd
17//!
18//! Zstandard codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)] // embedded-io
21
22use std::{borrow::Cow, io};
23
24use ndarray::Array1;
25use numcodecs::{
26    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27    Codec, StaticCodec, StaticCodecConfig,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Deserializer, Serialize, Serializer};
31use thiserror::Error;
32// Only used to explicitly enable the `no_wasm_shim` feature in zstd/zstd-sys
33use zstd_sys as _;
34
35#[derive(Clone, Serialize, Deserialize, JsonSchema)]
36#[serde(deny_unknown_fields)]
37/// Codec providing compression using Zstandard
38pub struct ZstdCodec {
39    /// Zstandard compression level.
40    ///
41    /// The level ranges from small (fastest) to large (best compression).
42    pub level: ZstdLevel,
43}
44
45impl Codec for ZstdCodec {
46    type Error = ZstdCodecError;
47
48    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
49        compress(data.view(), self.level)
50            .map(|bytes| AnyArray::U8(Array1::from_vec(bytes).into_dyn()))
51    }
52
53    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
54        let AnyCowArray::U8(encoded) = encoded else {
55            return Err(ZstdCodecError::EncodedDataNotBytes {
56                dtype: encoded.dtype(),
57            });
58        };
59
60        if !matches!(encoded.shape(), [_]) {
61            return Err(ZstdCodecError::EncodedDataNotOneDimensional {
62                shape: encoded.shape().to_vec(),
63            });
64        }
65
66        decompress(&AnyCowArray::U8(encoded).as_bytes())
67    }
68
69    fn decode_into(
70        &self,
71        encoded: AnyArrayView,
72        decoded: AnyArrayViewMut,
73    ) -> Result<(), Self::Error> {
74        let AnyArrayView::U8(encoded) = encoded else {
75            return Err(ZstdCodecError::EncodedDataNotBytes {
76                dtype: encoded.dtype(),
77            });
78        };
79
80        if !matches!(encoded.shape(), [_]) {
81            return Err(ZstdCodecError::EncodedDataNotOneDimensional {
82                shape: encoded.shape().to_vec(),
83            });
84        }
85
86        decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
87    }
88}
89
90impl StaticCodec for ZstdCodec {
91    const CODEC_ID: &'static str = "zstd";
92
93    type Config<'de> = Self;
94
95    fn from_config(config: Self::Config<'_>) -> Self {
96        config
97    }
98
99    fn get_config(&self) -> StaticCodecConfig<Self> {
100        StaticCodecConfig::from(self)
101    }
102}
103
104#[derive(Clone, Copy, JsonSchema)]
105#[schemars(transparent)]
106/// Zstandard compression level.
107///
108/// The level ranges from small (fastest) to large (best compression).
109pub struct ZstdLevel {
110    level: zstd::zstd_safe::CompressionLevel,
111}
112
113impl Serialize for ZstdLevel {
114    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
115        self.level.serialize(serializer)
116    }
117}
118
119impl<'de> Deserialize<'de> for ZstdLevel {
120    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
121        let level = Deserialize::deserialize(deserializer)?;
122
123        let level_range = zstd::compression_level_range();
124
125        if !level_range.contains(&level) {
126            return Err(serde::de::Error::custom(format!(
127                "level {level} is not in {}..={}",
128                level_range.start(),
129                level_range.end()
130            )));
131        }
132
133        Ok(Self { level })
134    }
135}
136
137#[derive(Debug, Error)]
138/// Errors that may occur when applying the [`ZstdCodec`].
139pub enum ZstdCodecError {
140    /// [`ZstdCodec`] failed to encode the header
141    #[error("Zstd failed to encode the header")]
142    HeaderEncodeFailed {
143        /// Opaque source error
144        source: ZstdHeaderError,
145    },
146    /// [`ZstdCodec`] failed to encode the encoded data
147    #[error("Zstd failed to decode the encoded data")]
148    ZstdEncodeFailed {
149        /// Opaque source error
150        source: ZstdCodingError,
151    },
152    /// [`ZstdCodec`] can only decode one-dimensional byte arrays but received
153    /// an array of a different dtype
154    #[error(
155        "Zstd can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
156    )]
157    EncodedDataNotBytes {
158        /// The unexpected dtype of the encoded array
159        dtype: AnyArrayDType,
160    },
161    /// [`ZstdCodec`] can only decode one-dimensional byte arrays but received
162    /// an array of a different shape
163    #[error("Zstd can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
164    EncodedDataNotOneDimensional {
165        /// The unexpected shape of the encoded array
166        shape: Vec<usize>,
167    },
168    /// [`ZstdCodec`] failed to encode the header
169    #[error("Zstd failed to decode the header")]
170    HeaderDecodeFailed {
171        /// Opaque source error
172        source: ZstdHeaderError,
173    },
174    /// [`ZstdCodec`] decode consumed less encoded data, which contains trailing
175    /// junk
176    #[error("Zstd decode consumed less encoded data, which contains trailing junk")]
177    DecodeExcessiveEncodedData,
178    /// [`ZstdCodec`] produced less decoded data than expected
179    #[error("Zstd produced less decoded data than expected")]
180    DecodeProducedLess,
181    /// [`ZstdCodec`] failed to decode the encoded data
182    #[error("Zstd failed to decode the encoded data")]
183    ZstdDecodeFailed {
184        /// Opaque source error
185        source: ZstdCodingError,
186    },
187    /// [`ZstdCodec`] cannot decode into the provided array
188    #[error("Zstd cannot decode into the provided array")]
189    MismatchedDecodeIntoArray {
190        /// The source of the error
191        #[from]
192        source: AnyArrayAssignError,
193    },
194}
195
196#[derive(Debug, Error)]
197#[error(transparent)]
198/// Opaque error for when encoding or decoding the header fails
199pub struct ZstdHeaderError(postcard::Error);
200
201#[derive(Debug, Error)]
202#[error(transparent)]
203/// Opaque error for when encoding or decoding with Zstandard fails
204pub struct ZstdCodingError(io::Error);
205
206#[expect(clippy::needless_pass_by_value)]
207/// Compress the `array` using Zstandard with the provided `level`.
208///
209/// # Errors
210///
211/// Errors with
212/// - [`ZstdCodecError::HeaderEncodeFailed`] if encoding the header to the
213///   output bytevec failed
214/// - [`ZstdCodecError::ZstdEncodeFailed`] if an opaque encoding error occurred
215///
216/// # Panics
217///
218/// Panics if the infallible encoding with Zstd fails.
219pub fn compress(array: AnyArrayView, level: ZstdLevel) -> Result<Vec<u8>, ZstdCodecError> {
220    let mut encoded = postcard::to_extend(
221        &CompressionHeader {
222            dtype: array.dtype(),
223            shape: Cow::Borrowed(array.shape()),
224        },
225        Vec::new(),
226    )
227    .map_err(|err| ZstdCodecError::HeaderEncodeFailed {
228        source: ZstdHeaderError(err),
229    })?;
230
231    zstd::stream::copy_encode(&*array.as_bytes(), &mut encoded, level.level).map_err(|err| {
232        ZstdCodecError::ZstdEncodeFailed {
233            source: ZstdCodingError(err),
234        }
235    })?;
236
237    Ok(encoded)
238}
239
240/// Decompress the `encoded` data into an array using Zstandard.
241///
242/// # Errors
243///
244/// Errors with
245/// - [`ZstdCodecError::HeaderDecodeFailed`] if decoding the header failed
246/// - [`ZstdCodecError::DecodeExcessiveEncodedData`] if the encoded data
247///   contains excessive trailing data junk
248/// - [`ZstdCodecError::DecodeProducedLess`] if decoding produced less data than
249///   expected
250/// - [`ZstdCodecError::ZstdDecodeFailed`] if an opaque decoding error occurred
251pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZstdCodecError> {
252    let (header, encoded) =
253        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
254            ZstdCodecError::HeaderDecodeFailed {
255                source: ZstdHeaderError(err),
256            }
257        })?;
258
259    let (decoded, result) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |decoded| {
260        decompress_into_bytes(encoded, decoded)
261    });
262
263    result.map(|()| decoded)
264}
265
266/// Decompress the `encoded` data into a `decoded` array using Zstandard.
267///
268/// # Errors
269///
270/// Errors with
271/// - [`ZstdCodecError::HeaderDecodeFailed`] if decoding the header failed
272/// - [`ZstdCodecError::MismatchedDecodeIntoArray`] if the `decoded` array is of
273///   the wrong dtype or shape
274/// - [`ZstdCodecError::HeaderDecodeFailed`] if decoding the header failed
275/// - [`ZstdCodecError::DecodeExcessiveEncodedData`] if the encoded data
276///   contains excessive trailing data junk
277/// - [`ZstdCodecError::DecodeProducedLess`] if decoding produced less data than
278///   expected
279/// - [`ZstdCodecError::ZstdDecodeFailed`] if an opaque decoding error occurred
280pub fn decompress_into(encoded: &[u8], mut decoded: AnyArrayViewMut) -> Result<(), ZstdCodecError> {
281    let (header, encoded) =
282        postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
283            ZstdCodecError::HeaderDecodeFailed {
284                source: ZstdHeaderError(err),
285            }
286        })?;
287
288    if header.dtype != decoded.dtype() {
289        return Err(ZstdCodecError::MismatchedDecodeIntoArray {
290            source: AnyArrayAssignError::DTypeMismatch {
291                src: header.dtype,
292                dst: decoded.dtype(),
293            },
294        });
295    }
296
297    if header.shape != decoded.shape() {
298        return Err(ZstdCodecError::MismatchedDecodeIntoArray {
299            source: AnyArrayAssignError::ShapeMismatch {
300                src: header.shape.into_owned(),
301                dst: decoded.shape().to_vec(),
302            },
303        });
304    }
305
306    decoded.with_bytes_mut(|decoded| decompress_into_bytes(encoded, decoded))
307}
308
309fn decompress_into_bytes(mut encoded: &[u8], mut decoded: &mut [u8]) -> Result<(), ZstdCodecError> {
310    // we want to check encoded and decoded for full consumption after the decoding
311    zstd::stream::copy_decode(&mut encoded, &mut decoded).map_err(|err| {
312        ZstdCodecError::ZstdDecodeFailed {
313            source: ZstdCodingError(err),
314        }
315    })?;
316
317    if !encoded.is_empty() {
318        return Err(ZstdCodecError::DecodeExcessiveEncodedData);
319    }
320
321    if !decoded.is_empty() {
322        return Err(ZstdCodecError::DecodeProducedLess);
323    }
324
325    Ok(())
326}
327
328#[derive(Serialize, Deserialize)]
329struct CompressionHeader<'a> {
330    dtype: AnyArrayDType,
331    #[serde(borrow)]
332    shape: Cow<'a, [usize]>,
333}