1#![allow(clippy::multiple_crate_versions)] use std::{borrow::Cow, io};
23
24use ndarray::Array1;
25use numcodecs::{
26 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27 Codec, StaticCodec, StaticCodecConfig,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Deserializer, Serialize, Serializer};
31use thiserror::Error;
32use zstd_sys as _;
34
35#[derive(Clone, Serialize, Deserialize, JsonSchema)]
36#[serde(deny_unknown_fields)]
37pub struct ZstdCodec {
39 pub level: ZstdLevel,
43}
44
45impl Codec for ZstdCodec {
46 type Error = ZstdCodecError;
47
48 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
49 compress(data.view(), self.level)
50 .map(|bytes| AnyArray::U8(Array1::from_vec(bytes).into_dyn()))
51 }
52
53 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
54 let AnyCowArray::U8(encoded) = encoded else {
55 return Err(ZstdCodecError::EncodedDataNotBytes {
56 dtype: encoded.dtype(),
57 });
58 };
59
60 if !matches!(encoded.shape(), [_]) {
61 return Err(ZstdCodecError::EncodedDataNotOneDimensional {
62 shape: encoded.shape().to_vec(),
63 });
64 }
65
66 decompress(&AnyCowArray::U8(encoded).as_bytes())
67 }
68
69 fn decode_into(
70 &self,
71 encoded: AnyArrayView,
72 decoded: AnyArrayViewMut,
73 ) -> Result<(), Self::Error> {
74 let AnyArrayView::U8(encoded) = encoded else {
75 return Err(ZstdCodecError::EncodedDataNotBytes {
76 dtype: encoded.dtype(),
77 });
78 };
79
80 if !matches!(encoded.shape(), [_]) {
81 return Err(ZstdCodecError::EncodedDataNotOneDimensional {
82 shape: encoded.shape().to_vec(),
83 });
84 }
85
86 decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
87 }
88}
89
90impl StaticCodec for ZstdCodec {
91 const CODEC_ID: &'static str = "zstd";
92
93 type Config<'de> = Self;
94
95 fn from_config(config: Self::Config<'_>) -> Self {
96 config
97 }
98
99 fn get_config(&self) -> StaticCodecConfig<Self> {
100 StaticCodecConfig::from(self)
101 }
102}
103
104#[derive(Clone, Copy, JsonSchema)]
105#[schemars(transparent)]
106pub struct ZstdLevel {
110 level: zstd::zstd_safe::CompressionLevel,
111}
112
113impl Serialize for ZstdLevel {
114 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
115 self.level.serialize(serializer)
116 }
117}
118
119impl<'de> Deserialize<'de> for ZstdLevel {
120 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
121 let level = Deserialize::deserialize(deserializer)?;
122
123 let level_range = zstd::compression_level_range();
124
125 if !level_range.contains(&level) {
126 return Err(serde::de::Error::custom(format!(
127 "level {level} is not in {}..={}",
128 level_range.start(),
129 level_range.end()
130 )));
131 }
132
133 Ok(Self { level })
134 }
135}
136
137#[derive(Debug, Error)]
138pub enum ZstdCodecError {
140 #[error("Zstd failed to encode the header")]
142 HeaderEncodeFailed {
143 source: ZstdHeaderError,
145 },
146 #[error("Zstd failed to decode the encoded data")]
148 ZstdEncodeFailed {
149 source: ZstdCodingError,
151 },
152 #[error(
155 "Zstd can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
156 )]
157 EncodedDataNotBytes {
158 dtype: AnyArrayDType,
160 },
161 #[error("Zstd can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
164 EncodedDataNotOneDimensional {
165 shape: Vec<usize>,
167 },
168 #[error("Zstd failed to decode the header")]
170 HeaderDecodeFailed {
171 source: ZstdHeaderError,
173 },
174 #[error("Zstd decode consumed less encoded data, which contains trailing junk")]
177 DecodeExcessiveEncodedData,
178 #[error("Zstd produced less decoded data than expected")]
180 DecodeProducedLess,
181 #[error("Zstd failed to decode the encoded data")]
183 ZstdDecodeFailed {
184 source: ZstdCodingError,
186 },
187 #[error("Zstd cannot decode into the provided array")]
189 MismatchedDecodeIntoArray {
190 #[from]
192 source: AnyArrayAssignError,
193 },
194}
195
196#[derive(Debug, Error)]
197#[error(transparent)]
198pub struct ZstdHeaderError(postcard::Error);
200
201#[derive(Debug, Error)]
202#[error(transparent)]
203pub struct ZstdCodingError(io::Error);
205
206#[expect(clippy::needless_pass_by_value)]
207pub fn compress(array: AnyArrayView, level: ZstdLevel) -> Result<Vec<u8>, ZstdCodecError> {
220 let mut encoded = postcard::to_extend(
221 &CompressionHeader {
222 dtype: array.dtype(),
223 shape: Cow::Borrowed(array.shape()),
224 },
225 Vec::new(),
226 )
227 .map_err(|err| ZstdCodecError::HeaderEncodeFailed {
228 source: ZstdHeaderError(err),
229 })?;
230
231 zstd::stream::copy_encode(&*array.as_bytes(), &mut encoded, level.level).map_err(|err| {
232 ZstdCodecError::ZstdEncodeFailed {
233 source: ZstdCodingError(err),
234 }
235 })?;
236
237 Ok(encoded)
238}
239
240pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZstdCodecError> {
252 let (header, encoded) =
253 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
254 ZstdCodecError::HeaderDecodeFailed {
255 source: ZstdHeaderError(err),
256 }
257 })?;
258
259 let (decoded, result) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |decoded| {
260 decompress_into_bytes(encoded, decoded)
261 });
262
263 result.map(|()| decoded)
264}
265
266pub fn decompress_into(encoded: &[u8], mut decoded: AnyArrayViewMut) -> Result<(), ZstdCodecError> {
281 let (header, encoded) =
282 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
283 ZstdCodecError::HeaderDecodeFailed {
284 source: ZstdHeaderError(err),
285 }
286 })?;
287
288 if header.dtype != decoded.dtype() {
289 return Err(ZstdCodecError::MismatchedDecodeIntoArray {
290 source: AnyArrayAssignError::DTypeMismatch {
291 src: header.dtype,
292 dst: decoded.dtype(),
293 },
294 });
295 }
296
297 if header.shape != decoded.shape() {
298 return Err(ZstdCodecError::MismatchedDecodeIntoArray {
299 source: AnyArrayAssignError::ShapeMismatch {
300 src: header.shape.into_owned(),
301 dst: decoded.shape().to_vec(),
302 },
303 });
304 }
305
306 decoded.with_bytes_mut(|decoded| decompress_into_bytes(encoded, decoded))
307}
308
309fn decompress_into_bytes(mut encoded: &[u8], mut decoded: &mut [u8]) -> Result<(), ZstdCodecError> {
310 zstd::stream::copy_decode(&mut encoded, &mut decoded).map_err(|err| {
312 ZstdCodecError::ZstdDecodeFailed {
313 source: ZstdCodingError(err),
314 }
315 })?;
316
317 if !encoded.is_empty() {
318 return Err(ZstdCodecError::DecodeExcessiveEncodedData);
319 }
320
321 if !decoded.is_empty() {
322 return Err(ZstdCodecError::DecodeProducedLess);
323 }
324
325 Ok(())
326}
327
328#[derive(Serialize, Deserialize)]
329struct CompressionHeader<'a> {
330 dtype: AnyArrayDType,
331 #[serde(borrow)]
332 shape: Cow<'a, [usize]>,
333}