1#![allow(clippy::multiple_crate_versions)] use std::{borrow::Cow, io};
23
24use ndarray::Array1;
25use numcodecs::{
26 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Deserializer, Serialize, Serializer};
31use thiserror::Error;
32use zstd_sys as _;
34
35type ZstdCodecVersion = StaticCodecVersion<0, 1, 0>;
36
37#[derive(Clone, Serialize, Deserialize, JsonSchema)]
38#[serde(deny_unknown_fields)]
39pub struct ZstdCodec {
41 pub level: ZstdLevel,
45 #[serde(default, rename = "_version")]
47 pub version: ZstdCodecVersion,
48}
49
50impl Codec for ZstdCodec {
51 type Error = ZstdCodecError;
52
53 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
54 compress(data.view(), self.level)
55 .map(|bytes| AnyArray::U8(Array1::from_vec(bytes).into_dyn()))
56 }
57
58 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
59 let AnyCowArray::U8(encoded) = encoded else {
60 return Err(ZstdCodecError::EncodedDataNotBytes {
61 dtype: encoded.dtype(),
62 });
63 };
64
65 if !matches!(encoded.shape(), [_]) {
66 return Err(ZstdCodecError::EncodedDataNotOneDimensional {
67 shape: encoded.shape().to_vec(),
68 });
69 }
70
71 decompress(&AnyCowArray::U8(encoded).as_bytes())
72 }
73
74 fn decode_into(
75 &self,
76 encoded: AnyArrayView,
77 decoded: AnyArrayViewMut,
78 ) -> Result<(), Self::Error> {
79 let AnyArrayView::U8(encoded) = encoded else {
80 return Err(ZstdCodecError::EncodedDataNotBytes {
81 dtype: encoded.dtype(),
82 });
83 };
84
85 if !matches!(encoded.shape(), [_]) {
86 return Err(ZstdCodecError::EncodedDataNotOneDimensional {
87 shape: encoded.shape().to_vec(),
88 });
89 }
90
91 decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
92 }
93}
94
95impl StaticCodec for ZstdCodec {
96 const CODEC_ID: &'static str = "zstd.rs";
97
98 type Config<'de> = Self;
99
100 fn from_config(config: Self::Config<'_>) -> Self {
101 config
102 }
103
104 fn get_config(&self) -> StaticCodecConfig<Self> {
105 StaticCodecConfig::from(self)
106 }
107}
108
109#[derive(Clone, Copy, JsonSchema)]
110#[schemars(transparent)]
111pub struct ZstdLevel {
115 level: zstd::zstd_safe::CompressionLevel,
116}
117
118impl Serialize for ZstdLevel {
119 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
120 self.level.serialize(serializer)
121 }
122}
123
124impl<'de> Deserialize<'de> for ZstdLevel {
125 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
126 let level = Deserialize::deserialize(deserializer)?;
127
128 let level_range = zstd::compression_level_range();
129
130 if !level_range.contains(&level) {
131 return Err(serde::de::Error::custom(format!(
132 "level {level} is not in {}..={}",
133 level_range.start(),
134 level_range.end()
135 )));
136 }
137
138 Ok(Self { level })
139 }
140}
141
142#[derive(Debug, Error)]
143pub enum ZstdCodecError {
145 #[error("Zstd failed to encode the header")]
147 HeaderEncodeFailed {
148 source: ZstdHeaderError,
150 },
151 #[error("Zstd failed to decode the encoded data")]
153 ZstdEncodeFailed {
154 source: ZstdCodingError,
156 },
157 #[error(
160 "Zstd can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
161 )]
162 EncodedDataNotBytes {
163 dtype: AnyArrayDType,
165 },
166 #[error(
169 "Zstd can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
170 )]
171 EncodedDataNotOneDimensional {
172 shape: Vec<usize>,
174 },
175 #[error("Zstd failed to decode the header")]
177 HeaderDecodeFailed {
178 source: ZstdHeaderError,
180 },
181 #[error("Zstd decode consumed less encoded data, which contains trailing junk")]
184 DecodeExcessiveEncodedData,
185 #[error("Zstd produced less decoded data than expected")]
187 DecodeProducedLess,
188 #[error("Zstd failed to decode the encoded data")]
190 ZstdDecodeFailed {
191 source: ZstdCodingError,
193 },
194 #[error("Zstd cannot decode into the provided array")]
196 MismatchedDecodeIntoArray {
197 #[from]
199 source: AnyArrayAssignError,
200 },
201}
202
203#[derive(Debug, Error)]
204#[error(transparent)]
205pub struct ZstdHeaderError(postcard::Error);
207
208#[derive(Debug, Error)]
209#[error(transparent)]
210pub struct ZstdCodingError(io::Error);
212
213#[expect(clippy::needless_pass_by_value)]
214pub fn compress(array: AnyArrayView, level: ZstdLevel) -> Result<Vec<u8>, ZstdCodecError> {
227 let mut encoded = postcard::to_extend(
228 &CompressionHeader {
229 dtype: array.dtype(),
230 shape: Cow::Borrowed(array.shape()),
231 version: StaticCodecVersion,
232 },
233 Vec::new(),
234 )
235 .map_err(|err| ZstdCodecError::HeaderEncodeFailed {
236 source: ZstdHeaderError(err),
237 })?;
238
239 zstd::stream::copy_encode(&*array.as_bytes(), &mut encoded, level.level).map_err(|err| {
240 ZstdCodecError::ZstdEncodeFailed {
241 source: ZstdCodingError(err),
242 }
243 })?;
244
245 Ok(encoded)
246}
247
248pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZstdCodecError> {
260 let (header, encoded) =
261 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
262 ZstdCodecError::HeaderDecodeFailed {
263 source: ZstdHeaderError(err),
264 }
265 })?;
266
267 let (decoded, result) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |decoded| {
268 decompress_into_bytes(encoded, decoded)
269 });
270
271 result.map(|()| decoded)
272}
273
274pub fn decompress_into(encoded: &[u8], mut decoded: AnyArrayViewMut) -> Result<(), ZstdCodecError> {
289 let (header, encoded) =
290 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
291 ZstdCodecError::HeaderDecodeFailed {
292 source: ZstdHeaderError(err),
293 }
294 })?;
295
296 if header.dtype != decoded.dtype() {
297 return Err(ZstdCodecError::MismatchedDecodeIntoArray {
298 source: AnyArrayAssignError::DTypeMismatch {
299 src: header.dtype,
300 dst: decoded.dtype(),
301 },
302 });
303 }
304
305 if header.shape != decoded.shape() {
306 return Err(ZstdCodecError::MismatchedDecodeIntoArray {
307 source: AnyArrayAssignError::ShapeMismatch {
308 src: header.shape.into_owned(),
309 dst: decoded.shape().to_vec(),
310 },
311 });
312 }
313
314 decoded.with_bytes_mut(|decoded| decompress_into_bytes(encoded, decoded))
315}
316
317fn decompress_into_bytes(mut encoded: &[u8], mut decoded: &mut [u8]) -> Result<(), ZstdCodecError> {
318 zstd::stream::copy_decode(&mut encoded, &mut decoded).map_err(|err| {
320 ZstdCodecError::ZstdDecodeFailed {
321 source: ZstdCodingError(err),
322 }
323 })?;
324
325 if !encoded.is_empty() {
326 return Err(ZstdCodecError::DecodeExcessiveEncodedData);
327 }
328
329 if !decoded.is_empty() {
330 return Err(ZstdCodecError::DecodeProducedLess);
331 }
332
333 Ok(())
334}
335
336#[derive(Serialize, Deserialize)]
337struct CompressionHeader<'a> {
338 dtype: AnyArrayDType,
339 #[serde(borrow)]
340 shape: Cow<'a, [usize]>,
341 version: ZstdCodecVersion,
342}