1#![expect(clippy::multiple_crate_versions)] use std::borrow::Cow;
23
24use ndarray::Array1;
25use numcodecs::{
26 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use schemars::{JsonSchema, JsonSchema_repr};
30use serde::{Deserialize, Serialize};
31use serde_repr::{Deserialize_repr, Serialize_repr};
32use thiserror::Error;
33
/// Version marker for this codec's encoding format.
///
/// It is stored in the [`ZlibCodec`] configuration (serialized as `_version`)
/// and in the header that [`compress`] writes in front of the compressed data.
// NOTE(review): the const parameters presumably spell version 0.1.0 — confirm
// against the `StaticCodecVersion` documentation.
34type ZlibCodecVersion = StaticCodecVersion<0, 1, 0>;
35
/// Codec configuration for Zlib compression.
///
/// Unknown configuration fields are rejected during deserialization.
36#[derive(Clone, Serialize, Deserialize, JsonSchema)]
37#[serde(deny_unknown_fields)]
38pub struct ZlibCodec {
    /// Compression level that [`Codec::encode`] passes on to [`compress`].
40 pub level: ZlibLevel,
    /// Encoding format version of the codec.
    ///
    /// It is serialized under the name `_version` and falls back to its
    /// default value when the field is absent from the configuration.
44 #[serde(default, rename = "_version")]
46 pub version: ZlibCodecVersion,
47}
48
49#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, JsonSchema_repr)]
50#[repr(u8)]
51#[expect(missing_docs)]
55pub enum ZlibLevel {
56 ZNoCompression = 0,
57 ZBestSpeed = 1,
58 ZLevel2 = 2,
59 ZLevel3 = 3,
60 ZLevel4 = 4,
61 ZLevel5 = 5,
62 ZLevel6 = 6,
63 ZLevel7 = 7,
64 ZLevel8 = 8,
65 ZBestCompression = 9,
66}
67
68impl Codec for ZlibCodec {
69 type Error = ZlibCodecError;
70
71 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
72 compress(data.view(), self.level)
73 .map(|bytes| AnyArray::U8(Array1::from_vec(bytes).into_dyn()))
74 }
75
76 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
77 let AnyCowArray::U8(encoded) = encoded else {
78 return Err(ZlibCodecError::EncodedDataNotBytes {
79 dtype: encoded.dtype(),
80 });
81 };
82
83 if !matches!(encoded.shape(), [_]) {
84 return Err(ZlibCodecError::EncodedDataNotOneDimensional {
85 shape: encoded.shape().to_vec(),
86 });
87 }
88
89 decompress(&AnyCowArray::U8(encoded).as_bytes())
90 }
91
92 fn decode_into(
93 &self,
94 encoded: AnyArrayView,
95 decoded: AnyArrayViewMut,
96 ) -> Result<(), Self::Error> {
97 let AnyArrayView::U8(encoded) = encoded else {
98 return Err(ZlibCodecError::EncodedDataNotBytes {
99 dtype: encoded.dtype(),
100 });
101 };
102
103 if !matches!(encoded.shape(), [_]) {
104 return Err(ZlibCodecError::EncodedDataNotOneDimensional {
105 shape: encoded.shape().to_vec(),
106 });
107 }
108
109 decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
110 }
111}
112
113impl StaticCodec for ZlibCodec {
114 const CODEC_ID: &'static str = "zlib.rs";
115
116 type Config<'de> = Self;
117
118 fn from_config(config: Self::Config<'_>) -> Self {
119 config
120 }
121
122 fn get_config(&self) -> StaticCodecConfig<Self> {
123 StaticCodecConfig::from(self)
124 }
125}
126
/// Errors that may occur when encoding or decoding with the [`ZlibCodec`].
127#[derive(Debug, Error)]
128pub enum ZlibCodecError {
    /// The header could not be serialized while compressing.
130 #[error("Zlib failed to encode the header")]
132 HeaderEncodeFailed {
    /// Opaque source error raised by the header serialization.
133 source: ZlibHeaderError,
135 },
    /// The encoded data handed to decode is not a `u8` array.
136 #[error(
139 "Zlib can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
140 )]
141 EncodedDataNotBytes {
    /// The dtype of the array that was received instead.
142 dtype: AnyArrayDType,
144 },
    /// The encoded data handed to decode is a `u8` array that is not
    /// one-dimensional.
145 #[error(
148 "Zlib can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
149 )]
150 EncodedDataNotOneDimensional {
    /// The shape of the byte array that was received instead.
151 shape: Vec<usize>,
153 },
    /// The header at the start of the encoded data could not be deserialized.
154 #[error("Zlib failed to decode the header")]
156 HeaderDecodeFailed {
    /// Opaque source error raised by the header deserialization.
157 source: ZlibHeaderError,
159 },
    /// Decompression finished without consuming all of the encoded data,
    /// i.e. the encoded data has trailing bytes.
160 #[error("Zlib decode consumed less encoded data, which contains trailing junk")]
163 DecodeExcessiveEncodedData,
    /// Decompression finished and consumed all encoded data, but did not fill
    /// the entire output buffer.
164 #[error("Zlib produced less decoded data than expected")]
166 DecodeProducedLess,
    /// Decompression ended with a status other than "done".
167 #[error("Zlib failed to decode the encoded data")]
169 ZlibDecodeFailed {
    /// Opaque source error carrying the decompressor status.
170 source: ZlibDecodeError,
172 },
    /// The array to decode into does not match the dtype or shape recorded in
    /// the header of the encoded data.
173 #[error("Zlib cannot decode into the provided array")]
175 MismatchedDecodeIntoArray {
    /// The source of the error; `#[from]` also provides the
    /// `From<AnyArrayAssignError>` conversion into this variant.
176 #[from]
178 source: AnyArrayAssignError,
179 },
180}
181
/// Opaque error for when encoding or decoding the header fails.
///
/// Wraps the underlying `postcard` error in a private field; the `Display`
/// output is forwarded to the wrapped error (`#[error(transparent)]`).
182#[derive(Debug, Error)]
183#[error(transparent)]
184pub struct ZlibHeaderError(postcard::Error);
186
/// Opaque error for when decompressing the encoded data fails.
///
/// Wraps the `miniz_oxide` decompression error in a private field; the
/// `Display` output is forwarded to the wrapped error
/// (`#[error(transparent)]`).
187#[derive(Debug, Error)]
188#[error(transparent)]
189pub struct ZlibDecodeError(miniz_oxide::inflate::DecompressError);
191
192#[expect(clippy::needless_pass_by_value)]
193pub fn compress(array: AnyArrayView, level: ZlibLevel) -> Result<Vec<u8>, ZlibCodecError> {
204 let data = array.as_bytes();
205
206 let mut encoded = postcard::to_extend(
207 &CompressionHeader {
208 dtype: array.dtype(),
209 shape: Cow::Borrowed(array.shape()),
210 version: StaticCodecVersion,
211 },
212 Vec::new(),
213 )
214 .map_err(|err| ZlibCodecError::HeaderEncodeFailed {
215 source: ZlibHeaderError(err),
216 })?;
217
218 let mut in_pos = 0;
219 let mut out_pos = encoded.len();
220
221 let flags =
224 miniz_oxide::deflate::core::create_comp_flags_from_zip_params((level as u8).into(), 1, 0);
225 let mut compressor = miniz_oxide::deflate::core::CompressorOxide::new(flags);
226 encoded.resize(encoded.len() + (data.len() / 2).max(2), 0);
227
228 loop {
229 let (Some(data_left), Some(encoded_left)) =
230 (data.get(in_pos..), encoded.get_mut(out_pos..))
231 else {
232 #[expect(clippy::panic)] {
234 panic!("Zlib encode bug: input or output is out of bounds")
235 }
236 };
237
238 let (status, bytes_in, bytes_out) = miniz_oxide::deflate::core::compress(
239 &mut compressor,
240 data_left,
241 encoded_left,
242 miniz_oxide::deflate::core::TDEFLFlush::Finish,
243 );
244
245 out_pos += bytes_out;
246 in_pos += bytes_in;
247
248 match status {
249 miniz_oxide::deflate::core::TDEFLStatus::Okay => {
250 if encoded.len().saturating_sub(out_pos) < 30 {
252 encoded.resize(encoded.len() * 2, 0);
253 }
254 }
255 miniz_oxide::deflate::core::TDEFLStatus::Done => {
256 encoded.truncate(out_pos);
257
258 assert!(
259 in_pos == data.len(),
260 "Zlib encode bug: consumed less input than expected"
261 );
262
263 return Ok(encoded);
264 }
265 #[expect(clippy::panic)] err => panic!("Zlib encode bug: {err:?}"),
267 }
268 }
269}
270
271pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZlibCodecError> {
283 let (header, encoded) =
284 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
285 ZlibCodecError::HeaderDecodeFailed {
286 source: ZlibHeaderError(err),
287 }
288 })?;
289
290 let (decoded, result) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |decoded| {
291 decompress_into_bytes(encoded, decoded)
292 });
293
294 result.map(|()| decoded)
295}
296
297pub fn decompress_into(encoded: &[u8], mut decoded: AnyArrayViewMut) -> Result<(), ZlibCodecError> {
312 let (header, encoded) =
313 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
314 ZlibCodecError::HeaderDecodeFailed {
315 source: ZlibHeaderError(err),
316 }
317 })?;
318
319 if header.dtype != decoded.dtype() {
320 return Err(ZlibCodecError::MismatchedDecodeIntoArray {
321 source: AnyArrayAssignError::DTypeMismatch {
322 src: header.dtype,
323 dst: decoded.dtype(),
324 },
325 });
326 }
327
328 if header.shape != decoded.shape() {
329 return Err(ZlibCodecError::MismatchedDecodeIntoArray {
330 source: AnyArrayAssignError::ShapeMismatch {
331 src: header.shape.into_owned(),
332 dst: decoded.shape().to_vec(),
333 },
334 });
335 }
336
337 decoded.with_bytes_mut(|decoded| decompress_into_bytes(encoded, decoded))
338}
339
340fn decompress_into_bytes(encoded: &[u8], decoded: &mut [u8]) -> Result<(), ZlibCodecError> {
341 let flags = miniz_oxide::inflate::core::inflate_flags::TINFL_FLAG_PARSE_ZLIB_HEADER
342 | miniz_oxide::inflate::core::inflate_flags::TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF;
343
344 let mut decomp = Box::<miniz_oxide::inflate::core::DecompressorOxide>::default();
345
346 let (status, in_consumed, out_consumed) =
347 miniz_oxide::inflate::core::decompress(&mut decomp, encoded, decoded, 0, flags);
348
349 match status {
350 miniz_oxide::inflate::TINFLStatus::Done => {
351 if in_consumed != encoded.len() {
352 Err(ZlibCodecError::DecodeExcessiveEncodedData)
353 } else if out_consumed == decoded.len() {
354 Ok(())
355 } else {
356 Err(ZlibCodecError::DecodeProducedLess)
357 }
358 }
359 status => Err(ZlibCodecError::ZlibDecodeFailed {
360 source: ZlibDecodeError(miniz_oxide::inflate::DecompressError {
361 status,
362 output: Vec::new(),
363 }),
364 }),
365 }
366}
367
/// Header that [`compress`] serializes with `postcard` in front of the
/// compressed data and that the decompress functions read back.
368#[derive(Serialize, Deserialize)]
369struct CompressionHeader<'a> {
    /// Dtype of the array that was compressed.
370 dtype: AnyArrayDType,
    /// Shape of the array that was compressed; may borrow from the
    /// deserialization input (`#[serde(borrow)]`).
371 #[serde(borrow)]
372 shape: Cow<'a, [usize]>,
    /// Version marker of the codec's encoding format.
373 version: ZlibCodecVersion,
374}