1#![expect(clippy::multiple_crate_versions)] use std::borrow::Cow;
23
24use ndarray::Array1;
25use numcodecs::{
26 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use schemars::{JsonSchema, JsonSchema_repr};
30use serde::{Deserialize, Serialize};
31use serde_repr::{Deserialize_repr, Serialize_repr};
32use thiserror::Error;
33
// Version marker type shared by the codec configuration and the header that
// is written in front of every compressed payload.
// NOTE(review): the const parameters `<0, 1, 0>` presumably read as
// major.minor.patch, i.e. version 0.1.0 — confirm against `StaticCodecVersion`.
type ZlibCodecVersion = StaticCodecVersion<0, 1, 0>;
35
/// Codec that compresses arrays with Zlib.
///
/// The struct doubles as the codec's configuration: it is (de)serializable
/// with serde and rejects unknown fields on deserialization.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ZlibCodec {
    /// Zlib compression level used when encoding.
    pub level: ZlibLevel,
    /// Version marker of the codec.
    ///
    /// Serialized under the key `_version`; falls back to its default value
    /// when the key is absent from the configuration.
    #[serde(default, rename = "_version")]
    pub version: ZlibCodecVersion,
}
48
/// Zlib compression level, stored as a `u8` discriminant in the range 0 to 9.
///
/// The level is (de)serialized through its numeric representation rather
/// than through the variant name.
#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, JsonSchema_repr)]
#[repr(u8)]
// The variants are intentionally left without doc comments; the expectation
// below relies on the `missing_docs` lint firing for them.
#[expect(missing_docs)]
pub enum ZlibLevel {
    ZNoCompression = 0,
    ZBestSpeed = 1,
    ZLevel2 = 2,
    ZLevel3 = 3,
    ZLevel4 = 4,
    ZLevel5 = 5,
    ZLevel6 = 6,
    ZLevel7 = 7,
    ZLevel8 = 8,
    ZBestCompression = 9,
}
67
68impl Codec for ZlibCodec {
69 type Error = ZlibCodecError;
70
71 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
72 compress(data.view(), self.level)
73 .map(|bytes| AnyArray::U8(Array1::from_vec(bytes).into_dyn()))
74 }
75
76 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
77 let AnyCowArray::U8(encoded) = encoded else {
78 return Err(ZlibCodecError::EncodedDataNotBytes {
79 dtype: encoded.dtype(),
80 });
81 };
82
83 if !matches!(encoded.shape(), [_]) {
84 return Err(ZlibCodecError::EncodedDataNotOneDimensional {
85 shape: encoded.shape().to_vec(),
86 });
87 }
88
89 decompress(&AnyCowArray::U8(encoded).as_bytes())
90 }
91
92 fn decode_into(
93 &self,
94 encoded: AnyArrayView,
95 decoded: AnyArrayViewMut,
96 ) -> Result<(), Self::Error> {
97 let AnyArrayView::U8(encoded) = encoded else {
98 return Err(ZlibCodecError::EncodedDataNotBytes {
99 dtype: encoded.dtype(),
100 });
101 };
102
103 if !matches!(encoded.shape(), [_]) {
104 return Err(ZlibCodecError::EncodedDataNotOneDimensional {
105 shape: encoded.shape().to_vec(),
106 });
107 }
108
109 decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
110 }
111}
112
impl StaticCodec for ZlibCodec {
    /// Identifier string of this codec.
    const CODEC_ID: &'static str = "zlib.rs";

    /// The codec struct is its own configuration type.
    type Config<'de> = Self;

    /// Builds the codec from its configuration; because the configuration
    /// type is `Self`, the value is returned unchanged.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Returns the codec's configuration, built from a reference to `self`.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
126
/// Errors that may occur when encoding or decoding with the [`ZlibCodec`].
#[derive(Debug, Error)]
pub enum ZlibCodecError {
    /// The header could not be serialized in front of the compressed data.
    #[error("Zlib failed to encode the header")]
    HeaderEncodeFailed {
        /// Underlying header serialization error.
        source: ZlibHeaderError,
    },
    /// The encoded input handed to the decoder is not a `u8` array.
    #[error(
        "Zlib can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// Dtype of the array that was received instead.
        dtype: AnyArrayDType,
    },
    /// The encoded input is a byte array but not one-dimensional.
    #[error("Zlib can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
    EncodedDataNotOneDimensional {
        /// Shape of the byte array that was received.
        shape: Vec<usize>,
    },
    /// The header at the start of the encoded data could not be deserialized.
    #[error("Zlib failed to decode the header")]
    HeaderDecodeFailed {
        /// Underlying header deserialization error.
        source: ZlibHeaderError,
    },
    /// Decompression finished without consuming all of the encoded bytes.
    #[error("Zlib decode consumed less encoded data, which contains trailing junk")]
    DecodeExcessiveEncodedData,
    /// Decompression finished but wrote fewer bytes than the output buffer
    /// holds.
    #[error("Zlib produced less decoded data than expected")]
    DecodeProducedLess,
    /// The decompressor ended with a status other than `Done`.
    #[error("Zlib failed to decode the encoded data")]
    ZlibDecodeFailed {
        /// Underlying decompression error carrying the reported status.
        source: ZlibDecodeError,
    },
    /// The dtype or shape recorded in the header does not match the array
    /// that should receive the decoded data.
    #[error("Zlib cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// Description of the dtype or shape mismatch.
        #[from]
        source: AnyArrayAssignError,
    },
}
179
/// Opaque wrapper around the `postcard` error raised while serializing or
/// deserializing the header; its message is forwarded transparently.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct ZlibHeaderError(postcard::Error);
184
/// Opaque wrapper around the `miniz_oxide` decompression error; its message
/// is forwarded transparently.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct ZlibDecodeError(miniz_oxide::inflate::DecompressError);
189
190#[expect(clippy::needless_pass_by_value)]
191pub fn compress(array: AnyArrayView, level: ZlibLevel) -> Result<Vec<u8>, ZlibCodecError> {
202 let data = array.as_bytes();
203
204 let mut encoded = postcard::to_extend(
205 &CompressionHeader {
206 dtype: array.dtype(),
207 shape: Cow::Borrowed(array.shape()),
208 version: StaticCodecVersion,
209 },
210 Vec::new(),
211 )
212 .map_err(|err| ZlibCodecError::HeaderEncodeFailed {
213 source: ZlibHeaderError(err),
214 })?;
215
216 let mut in_pos = 0;
217 let mut out_pos = encoded.len();
218
219 let flags =
222 miniz_oxide::deflate::core::create_comp_flags_from_zip_params((level as u8).into(), 1, 0);
223 let mut compressor = miniz_oxide::deflate::core::CompressorOxide::new(flags);
224 encoded.resize(encoded.len() + (data.len() / 2).max(2), 0);
225
226 loop {
227 let (Some(data_left), Some(encoded_left)) =
228 (data.get(in_pos..), encoded.get_mut(out_pos..))
229 else {
230 #[expect(clippy::panic)] {
232 panic!("Zlib encode bug: input or output is out of bounds")
233 }
234 };
235
236 let (status, bytes_in, bytes_out) = miniz_oxide::deflate::core::compress(
237 &mut compressor,
238 data_left,
239 encoded_left,
240 miniz_oxide::deflate::core::TDEFLFlush::Finish,
241 );
242
243 out_pos += bytes_out;
244 in_pos += bytes_in;
245
246 match status {
247 miniz_oxide::deflate::core::TDEFLStatus::Okay => {
248 if encoded.len().saturating_sub(out_pos) < 30 {
250 encoded.resize(encoded.len() * 2, 0);
251 }
252 }
253 miniz_oxide::deflate::core::TDEFLStatus::Done => {
254 encoded.truncate(out_pos);
255
256 assert!(
257 in_pos == data.len(),
258 "Zlib encode bug: consumed less input than expected"
259 );
260
261 return Ok(encoded);
262 }
263 #[expect(clippy::panic)] err => panic!("Zlib encode bug: {err:?}"),
265 }
266 }
267}
268
269pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZlibCodecError> {
281 let (header, encoded) =
282 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
283 ZlibCodecError::HeaderDecodeFailed {
284 source: ZlibHeaderError(err),
285 }
286 })?;
287
288 let (decoded, result) = AnyArray::with_zeros_bytes(header.dtype, &header.shape, |decoded| {
289 decompress_into_bytes(encoded, decoded)
290 });
291
292 result.map(|()| decoded)
293}
294
295pub fn decompress_into(encoded: &[u8], mut decoded: AnyArrayViewMut) -> Result<(), ZlibCodecError> {
310 let (header, encoded) =
311 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
312 ZlibCodecError::HeaderDecodeFailed {
313 source: ZlibHeaderError(err),
314 }
315 })?;
316
317 if header.dtype != decoded.dtype() {
318 return Err(ZlibCodecError::MismatchedDecodeIntoArray {
319 source: AnyArrayAssignError::DTypeMismatch {
320 src: header.dtype,
321 dst: decoded.dtype(),
322 },
323 });
324 }
325
326 if header.shape != decoded.shape() {
327 return Err(ZlibCodecError::MismatchedDecodeIntoArray {
328 source: AnyArrayAssignError::ShapeMismatch {
329 src: header.shape.into_owned(),
330 dst: decoded.shape().to_vec(),
331 },
332 });
333 }
334
335 decoded.with_bytes_mut(|decoded| decompress_into_bytes(encoded, decoded))
336}
337
338fn decompress_into_bytes(encoded: &[u8], decoded: &mut [u8]) -> Result<(), ZlibCodecError> {
339 let flags = miniz_oxide::inflate::core::inflate_flags::TINFL_FLAG_PARSE_ZLIB_HEADER
340 | miniz_oxide::inflate::core::inflate_flags::TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF;
341
342 let mut decomp = Box::<miniz_oxide::inflate::core::DecompressorOxide>::default();
343
344 let (status, in_consumed, out_consumed) =
345 miniz_oxide::inflate::core::decompress(&mut decomp, encoded, decoded, 0, flags);
346
347 match status {
348 miniz_oxide::inflate::TINFLStatus::Done => {
349 if in_consumed != encoded.len() {
350 Err(ZlibCodecError::DecodeExcessiveEncodedData)
351 } else if out_consumed == decoded.len() {
352 Ok(())
353 } else {
354 Err(ZlibCodecError::DecodeProducedLess)
355 }
356 }
357 status => Err(ZlibCodecError::ZlibDecodeFailed {
358 source: ZlibDecodeError(miniz_oxide::inflate::DecompressError {
359 status,
360 output: Vec::new(),
361 }),
362 }),
363 }
364}
365
/// Header serialized with postcard in front of the compressed payload,
/// describing the array that was compressed.
// NOTE(review): postcard is not self-describing, so the field order below is
// presumably part of the encoded layout — do not reorder without confirming.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    /// Dtype of the original array.
    dtype: AnyArrayDType,
    /// Shape of the original array; deserialization is allowed to borrow it
    /// from the input.
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    /// Version marker of the codec that produced the data.
    version: ZlibCodecVersion,
}