1#![allow(clippy::multiple_crate_versions)] use std::{borrow::Cow, io};
23
24use ndarray::Array1;
25use numcodecs::{
26 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Deserializer, Serialize, Serializer};
31use thiserror::Error;
32use zstd_sys as _;
34
35type ZstdCodecVersion = StaticCodecVersion<0, 1, 0>;
36
37#[derive(Clone, Serialize, Deserialize, JsonSchema)]
38#[serde(deny_unknown_fields)]
39pub struct ZstdCodec {
41 pub level: ZstdLevel,
45 #[serde(default, rename = "_version")]
47 pub version: ZstdCodecVersion,
48}
49
50impl Codec for ZstdCodec {
51 type Error = ZstdCodecError;
52
53 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
54 compress(data.view(), self.level)
55 .map(|bytes| AnyArray::U8(Array1::from_vec(bytes).into_dyn()))
56 }
57
58 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
59 let AnyCowArray::U8(encoded) = encoded else {
60 return Err(ZstdCodecError::EncodedDataNotBytes {
61 dtype: encoded.dtype(),
62 });
63 };
64
65 if !matches!(encoded.shape(), [_]) {
66 return Err(ZstdCodecError::EncodedDataNotOneDimensional {
67 shape: encoded.shape().to_vec(),
68 });
69 }
70
71 decompress(&AnyCowArray::U8(encoded).as_bytes())
72 }
73
74 fn decode_into(
75 &self,
76 encoded: AnyArrayView,
77 decoded: AnyArrayViewMut,
78 ) -> Result<(), Self::Error> {
79 let AnyArrayView::U8(encoded) = encoded else {
80 return Err(ZstdCodecError::EncodedDataNotBytes {
81 dtype: encoded.dtype(),
82 });
83 };
84
85 if !matches!(encoded.shape(), [_]) {
86 return Err(ZstdCodecError::EncodedDataNotOneDimensional {
87 shape: encoded.shape().to_vec(),
88 });
89 }
90
91 decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
92 }
93}
94
95impl StaticCodec for ZstdCodec {
96 const CODEC_ID: &'static str = "zstd.rs";
97
98 type Config<'de> = Self;
99
100 fn from_config(config: Self::Config<'_>) -> Self {
101 config
102 }
103
104 fn get_config(&self) -> StaticCodecConfig<Self> {
105 StaticCodecConfig::from(self)
106 }
107}
108
109#[derive(Clone, Copy, JsonSchema)]
110#[schemars(transparent)]
111pub struct ZstdLevel {
115 level: zstd::zstd_safe::CompressionLevel,
116}
117
118impl Serialize for ZstdLevel {
119 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
120 self.level.serialize(serializer)
121 }
122}
123
124impl<'de> Deserialize<'de> for ZstdLevel {
125 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
126 let level = Deserialize::deserialize(deserializer)?;
127
128 let level_range = zstd::compression_level_range();
129
130 if !level_range.contains(&level) {
131 return Err(serde::de::Error::custom(format!(
132 "level {level} is not in {}..={}",
133 level_range.start(),
134 level_range.end()
135 )));
136 }
137
138 Ok(Self { level })
139 }
140}
141
142#[derive(Debug, Error)]
143pub enum ZstdCodecError {
145 #[error("Zstd failed to encode the header")]
147 HeaderEncodeFailed {
148 source: ZstdHeaderError,
150 },
151 #[error("Zstd failed to decode the encoded data")]
153 ZstdEncodeFailed {
154 source: ZstdCodingError,
156 },
157 #[error(
160 "Zstd can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
161 )]
162 EncodedDataNotBytes {
163 dtype: AnyArrayDType,
165 },
166 #[error("Zstd can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
169 EncodedDataNotOneDimensional {
170 shape: Vec<usize>,
172 },
173 #[error("Zstd failed to decode the header")]
175 HeaderDecodeFailed {
176 source: ZstdHeaderError,
178 },
179 #[error("Zstd decode consumed less encoded data, which contains trailing junk")]
182 DecodeExcessiveEncodedData,
183 #[error("Zstd produced less decoded data than expected")]
185 DecodeProducedLess,
186 #[error("Zstd failed to decode the encoded data")]
188 ZstdDecodeFailed {
189 source: ZstdCodingError,
191 },
192 #[error("Zstd cannot decode into the provided array")]
194 MismatchedDecodeIntoArray {
195 #[from]
197 source: AnyArrayAssignError,
198 },
199}
200
201#[derive(Debug, Error)]
202#[error(transparent)]
203pub struct ZstdHeaderError(postcard::Error);
205
206#[derive(Debug, Error)]
207#[error(transparent)]
208pub struct ZstdCodingError(io::Error);
210
211#[expect(clippy::needless_pass_by_value)]
212pub fn compress(array: AnyArrayView, level: ZstdLevel) -> Result<Vec<u8>, ZstdCodecError> {
225 let mut encoded = postcard::to_extend(
226 &CompressionHeader {
227 dtype: array.dtype(),
228 shape: Cow::Borrowed(array.shape()),
229 version: StaticCodecVersion,
230 },
231 Vec::new(),
232 )
233 .map_err(|err| ZstdCodecError::HeaderEncodeFailed {
234 source: ZstdHeaderError(err),
235 })?;
236
237 zstd::stream::copy_encode(&*array.as_bytes(), &mut encoded, level.level).map_err(|err| {
238 ZstdCodecError::ZstdEncodeFailed {
239 source: ZstdCodingError(err),
240 }
241 })?;
242
243 Ok(encoded)
244}
245
246pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZstdCodecError> {
258 let (header, encoded) =
259 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
260 ZstdCodecError::HeaderDecodeFailed {
261 source: ZstdHeaderError(err),
262 }
263 })?;
264
265 let (decoded, result) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |decoded| {
266 decompress_into_bytes(encoded, decoded)
267 });
268
269 result.map(|()| decoded)
270}
271
272pub fn decompress_into(encoded: &[u8], mut decoded: AnyArrayViewMut) -> Result<(), ZstdCodecError> {
287 let (header, encoded) =
288 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
289 ZstdCodecError::HeaderDecodeFailed {
290 source: ZstdHeaderError(err),
291 }
292 })?;
293
294 if header.dtype != decoded.dtype() {
295 return Err(ZstdCodecError::MismatchedDecodeIntoArray {
296 source: AnyArrayAssignError::DTypeMismatch {
297 src: header.dtype,
298 dst: decoded.dtype(),
299 },
300 });
301 }
302
303 if header.shape != decoded.shape() {
304 return Err(ZstdCodecError::MismatchedDecodeIntoArray {
305 source: AnyArrayAssignError::ShapeMismatch {
306 src: header.shape.into_owned(),
307 dst: decoded.shape().to_vec(),
308 },
309 });
310 }
311
312 decoded.with_bytes_mut(|decoded| decompress_into_bytes(encoded, decoded))
313}
314
315fn decompress_into_bytes(mut encoded: &[u8], mut decoded: &mut [u8]) -> Result<(), ZstdCodecError> {
316 zstd::stream::copy_decode(&mut encoded, &mut decoded).map_err(|err| {
318 ZstdCodecError::ZstdDecodeFailed {
319 source: ZstdCodingError(err),
320 }
321 })?;
322
323 if !encoded.is_empty() {
324 return Err(ZstdCodecError::DecodeExcessiveEncodedData);
325 }
326
327 if !decoded.is_empty() {
328 return Err(ZstdCodecError::DecodeProducedLess);
329 }
330
331 Ok(())
332}
333
334#[derive(Serialize, Deserialize)]
335struct CompressionHeader<'a> {
336 dtype: AnyArrayDType,
337 #[serde(borrow)]
338 shape: Cow<'a, [usize]>,
339 version: ZstdCodecVersion,
340}