1#![expect(clippy::multiple_crate_versions)] use std::borrow::Cow;
23
24use ndarray::Array1;
25use numcodecs::{
26 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27 Codec, StaticCodec, StaticCodecConfig,
28};
29use schemars::{JsonSchema, JsonSchema_repr};
30use serde::{Deserialize, Serialize};
31use serde_repr::{Deserialize_repr, Serialize_repr};
32use thiserror::Error;
33
34#[derive(Clone, Serialize, Deserialize, JsonSchema)]
35#[serde(deny_unknown_fields)]
36pub struct ZlibCodec {
38 pub level: ZlibLevel,
42}
43
44#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, JsonSchema_repr)]
45#[repr(u8)]
46#[expect(missing_docs)]
50pub enum ZlibLevel {
51 ZNoCompression = 0,
52 ZBestSpeed = 1,
53 ZLevel2 = 2,
54 ZLevel3 = 3,
55 ZLevel4 = 4,
56 ZLevel5 = 5,
57 ZLevel6 = 6,
58 ZLevel7 = 7,
59 ZLevel8 = 8,
60 ZBestCompression = 9,
61}
62
63impl Codec for ZlibCodec {
64 type Error = ZlibCodecError;
65
66 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
67 compress(data.view(), self.level)
68 .map(|bytes| AnyArray::U8(Array1::from_vec(bytes).into_dyn()))
69 }
70
71 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
72 let AnyCowArray::U8(encoded) = encoded else {
73 return Err(ZlibCodecError::EncodedDataNotBytes {
74 dtype: encoded.dtype(),
75 });
76 };
77
78 if !matches!(encoded.shape(), [_]) {
79 return Err(ZlibCodecError::EncodedDataNotOneDimensional {
80 shape: encoded.shape().to_vec(),
81 });
82 }
83
84 decompress(&AnyCowArray::U8(encoded).as_bytes())
85 }
86
87 fn decode_into(
88 &self,
89 encoded: AnyArrayView,
90 decoded: AnyArrayViewMut,
91 ) -> Result<(), Self::Error> {
92 let AnyArrayView::U8(encoded) = encoded else {
93 return Err(ZlibCodecError::EncodedDataNotBytes {
94 dtype: encoded.dtype(),
95 });
96 };
97
98 if !matches!(encoded.shape(), [_]) {
99 return Err(ZlibCodecError::EncodedDataNotOneDimensional {
100 shape: encoded.shape().to_vec(),
101 });
102 }
103
104 decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
105 }
106}
107
108impl StaticCodec for ZlibCodec {
109 const CODEC_ID: &'static str = "zlib";
110
111 type Config<'de> = Self;
112
113 fn from_config(config: Self::Config<'_>) -> Self {
114 config
115 }
116
117 fn get_config(&self) -> StaticCodecConfig<Self> {
118 StaticCodecConfig::from(self)
119 }
120}
121
122#[derive(Debug, Error)]
123pub enum ZlibCodecError {
125 #[error("Zlib failed to encode the header")]
127 HeaderEncodeFailed {
128 source: ZlibHeaderError,
130 },
131 #[error(
134 "Zlib can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
135 )]
136 EncodedDataNotBytes {
137 dtype: AnyArrayDType,
139 },
140 #[error("Zlib can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
143 EncodedDataNotOneDimensional {
144 shape: Vec<usize>,
146 },
147 #[error("Zlib failed to decode the header")]
149 HeaderDecodeFailed {
150 source: ZlibHeaderError,
152 },
153 #[error("Zlib decode consumed less encoded data, which contains trailing junk")]
156 DecodeExcessiveEncodedData,
157 #[error("Zlib produced less decoded data than expected")]
159 DecodeProducedLess,
160 #[error("Zlib failed to decode the encoded data")]
162 ZlibDecodeFailed {
163 source: ZlibDecodeError,
165 },
166 #[error("Zlib cannot decode into the provided array")]
168 MismatchedDecodeIntoArray {
169 #[from]
171 source: AnyArrayAssignError,
172 },
173}
174
175#[derive(Debug, Error)]
176#[error(transparent)]
177pub struct ZlibHeaderError(postcard::Error);
179
180#[derive(Debug, Error)]
181#[error(transparent)]
182pub struct ZlibDecodeError(miniz_oxide::inflate::DecompressError);
184
185#[expect(clippy::needless_pass_by_value)]
186pub fn compress(array: AnyArrayView, level: ZlibLevel) -> Result<Vec<u8>, ZlibCodecError> {
197 let data = array.as_bytes();
198
199 let mut encoded = postcard::to_extend(
200 &CompressionHeader {
201 dtype: array.dtype(),
202 shape: Cow::Borrowed(array.shape()),
203 },
204 Vec::new(),
205 )
206 .map_err(|err| ZlibCodecError::HeaderEncodeFailed {
207 source: ZlibHeaderError(err),
208 })?;
209
210 let mut in_pos = 0;
211 let mut out_pos = encoded.len();
212
213 let flags =
216 miniz_oxide::deflate::core::create_comp_flags_from_zip_params((level as u8).into(), 1, 0);
217 let mut compressor = miniz_oxide::deflate::core::CompressorOxide::new(flags);
218 encoded.resize(encoded.len() + (data.len() / 2).max(2), 0);
219
220 loop {
221 let (Some(data_left), Some(encoded_left)) =
222 (data.get(in_pos..), encoded.get_mut(out_pos..))
223 else {
224 #[expect(clippy::panic)] {
226 panic!("Zlib encode bug: input or output is out of bounds")
227 }
228 };
229
230 let (status, bytes_in, bytes_out) = miniz_oxide::deflate::core::compress(
231 &mut compressor,
232 data_left,
233 encoded_left,
234 miniz_oxide::deflate::core::TDEFLFlush::Finish,
235 );
236
237 out_pos += bytes_out;
238 in_pos += bytes_in;
239
240 match status {
241 miniz_oxide::deflate::core::TDEFLStatus::Okay => {
242 if encoded.len().saturating_sub(out_pos) < 30 {
244 encoded.resize(encoded.len() * 2, 0);
245 }
246 }
247 miniz_oxide::deflate::core::TDEFLStatus::Done => {
248 encoded.truncate(out_pos);
249
250 assert!(
251 in_pos == data.len(),
252 "Zlib encode bug: consumed less input than expected"
253 );
254
255 return Ok(encoded);
256 }
257 #[expect(clippy::panic)] err => panic!("Zlib encode bug: {err:?}"),
259 }
260 }
261}
262
263pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZlibCodecError> {
275 let (header, encoded) =
276 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
277 ZlibCodecError::HeaderDecodeFailed {
278 source: ZlibHeaderError(err),
279 }
280 })?;
281
282 let (decoded, result) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |decoded| {
283 decompress_into_bytes(encoded, decoded)
284 });
285
286 result.map(|()| decoded)
287}
288
289pub fn decompress_into(encoded: &[u8], mut decoded: AnyArrayViewMut) -> Result<(), ZlibCodecError> {
304 let (header, encoded) =
305 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
306 ZlibCodecError::HeaderDecodeFailed {
307 source: ZlibHeaderError(err),
308 }
309 })?;
310
311 if header.dtype != decoded.dtype() {
312 return Err(ZlibCodecError::MismatchedDecodeIntoArray {
313 source: AnyArrayAssignError::DTypeMismatch {
314 src: header.dtype,
315 dst: decoded.dtype(),
316 },
317 });
318 }
319
320 if header.shape != decoded.shape() {
321 return Err(ZlibCodecError::MismatchedDecodeIntoArray {
322 source: AnyArrayAssignError::ShapeMismatch {
323 src: header.shape.into_owned(),
324 dst: decoded.shape().to_vec(),
325 },
326 });
327 }
328
329 decoded.with_bytes_mut(|decoded| decompress_into_bytes(encoded, decoded))
330}
331
332fn decompress_into_bytes(encoded: &[u8], decoded: &mut [u8]) -> Result<(), ZlibCodecError> {
333 let flags = miniz_oxide::inflate::core::inflate_flags::TINFL_FLAG_PARSE_ZLIB_HEADER
334 | miniz_oxide::inflate::core::inflate_flags::TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF;
335
336 let mut decomp = Box::<miniz_oxide::inflate::core::DecompressorOxide>::default();
337
338 let (status, in_consumed, out_consumed) =
339 miniz_oxide::inflate::core::decompress(&mut decomp, encoded, decoded, 0, flags);
340
341 match status {
342 miniz_oxide::inflate::TINFLStatus::Done => {
343 if in_consumed != encoded.len() {
344 Err(ZlibCodecError::DecodeExcessiveEncodedData)
345 } else if out_consumed == decoded.len() {
346 Ok(())
347 } else {
348 Err(ZlibCodecError::DecodeProducedLess)
349 }
350 }
351 status => Err(ZlibCodecError::ZlibDecodeFailed {
352 source: ZlibDecodeError(miniz_oxide::inflate::DecompressError {
353 status,
354 output: Vec::new(),
355 }),
356 }),
357 }
358}
359
360#[derive(Serialize, Deserialize)]
361struct CompressionHeader<'a> {
362 dtype: AnyArrayDType,
363 #[serde(borrow)]
364 shape: Cow<'a, [usize]>,
365}