1#![allow(clippy::multiple_crate_versions)] use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
25use numcodecs::{
26 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27 Codec, StaticCodec, StaticCodecConfig,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use thiserror::Error;
32
33use ::zstd_sys as _;
36
37#[cfg(test)]
38use ::serde_json as _;
39
40#[derive(Clone, Serialize, Deserialize, JsonSchema)]
41#[schemars(deny_unknown_fields)]
43pub struct Sz3Codec {
45 #[serde(default = "default_predictor")]
47 pub predictor: Option<Sz3Predictor>,
48 #[serde(flatten)]
50 pub error_bound: Sz3ErrorBound,
51 #[serde(default = "default_encoder")]
53 pub encoder: Option<Sz3Encoder>,
54 #[serde(default = "default_lossless_compressor")]
56 pub lossless: Option<Sz3LosslessCompressor>,
57}
58
59#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
61#[serde(tag = "eb_mode")]
62#[serde(deny_unknown_fields)]
63pub enum Sz3ErrorBound {
64 #[serde(rename = "abs-and-rel")]
67 AbsoluteAndRelative {
68 #[serde(rename = "eb_abs")]
70 abs: f64,
71 #[serde(rename = "eb_rel")]
73 rel: f64,
74 },
75 #[serde(rename = "abs-or-rel")]
78 AbsoluteOrRelative {
79 #[serde(rename = "eb_abs")]
81 abs: f64,
82 #[serde(rename = "eb_rel")]
84 rel: f64,
85 },
86 #[serde(rename = "abs")]
88 Absolute {
89 #[serde(rename = "eb_abs")]
91 abs: f64,
92 },
93 #[serde(rename = "rel")]
95 Relative {
96 #[serde(rename = "eb_rel")]
98 rel: f64,
99 },
100 #[serde(rename = "psnr")]
102 PS2NR {
103 #[serde(rename = "eb_psnr")]
105 psnr: f64,
106 },
107 #[serde(rename = "l2")]
109 L2Norm {
110 #[serde(rename = "eb_l2")]
112 l2: f64,
113 },
114}
115
116#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
118#[serde(deny_unknown_fields)]
119pub enum Sz3Predictor {
120 #[serde(rename = "linear-interpolation")]
122 LinearInterpolation,
123 #[serde(rename = "cubic-interpolation")]
125 CubicInterpolation,
126 #[serde(rename = "linear-interpolation-lorenzo")]
128 LinearInterpolationLorenzo,
129 #[serde(rename = "cubic-interpolation-lorenzo")]
131 CubicInterpolationLorenzo,
132 #[serde(rename = "lorenzo-regression")]
134 LorenzoRegression,
135}
136
137#[expect(clippy::unnecessary_wraps)]
138const fn default_predictor() -> Option<Sz3Predictor> {
139 Some(Sz3Predictor::CubicInterpolationLorenzo)
140}
141
142#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
144#[serde(deny_unknown_fields)]
145pub enum Sz3Encoder {
146 #[serde(rename = "huffman")]
148 Huffman,
149 #[serde(rename = "arithmetic")]
151 Arithmetic,
152}
153
154#[expect(clippy::unnecessary_wraps)]
155const fn default_encoder() -> Option<Sz3Encoder> {
156 Some(Sz3Encoder::Huffman)
157}
158
159#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
161#[serde(deny_unknown_fields)]
162pub enum Sz3LosslessCompressor {
163 #[serde(rename = "zstd")]
165 Zstd,
166}
167
168#[expect(clippy::unnecessary_wraps)]
169const fn default_lossless_compressor() -> Option<Sz3LosslessCompressor> {
170 Some(Sz3LosslessCompressor::Zstd)
171}
172
173impl Codec for Sz3Codec {
174 type Error = Sz3CodecError;
175
176 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
177 match data {
178 AnyCowArray::I32(data) => Ok(AnyArray::U8(
179 Array1::from(compress(
180 data,
181 self.predictor.as_ref(),
182 &self.error_bound,
183 self.encoder.as_ref(),
184 self.lossless.as_ref(),
185 )?)
186 .into_dyn(),
187 )),
188 AnyCowArray::I64(data) => Ok(AnyArray::U8(
189 Array1::from(compress(
190 data,
191 self.predictor.as_ref(),
192 &self.error_bound,
193 self.encoder.as_ref(),
194 self.lossless.as_ref(),
195 )?)
196 .into_dyn(),
197 )),
198 AnyCowArray::F32(data) => Ok(AnyArray::U8(
199 Array1::from(compress(
200 data,
201 self.predictor.as_ref(),
202 &self.error_bound,
203 self.encoder.as_ref(),
204 self.lossless.as_ref(),
205 )?)
206 .into_dyn(),
207 )),
208 AnyCowArray::F64(data) => Ok(AnyArray::U8(
209 Array1::from(compress(
210 data,
211 self.predictor.as_ref(),
212 &self.error_bound,
213 self.encoder.as_ref(),
214 self.lossless.as_ref(),
215 )?)
216 .into_dyn(),
217 )),
218 encoded => Err(Sz3CodecError::UnsupportedDtype(encoded.dtype())),
219 }
220 }
221
222 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
223 let AnyCowArray::U8(encoded) = encoded else {
224 return Err(Sz3CodecError::EncodedDataNotBytes {
225 dtype: encoded.dtype(),
226 });
227 };
228
229 if !matches!(encoded.shape(), [_]) {
230 return Err(Sz3CodecError::EncodedDataNotOneDimensional {
231 shape: encoded.shape().to_vec(),
232 });
233 }
234
235 decompress(&AnyCowArray::U8(encoded).as_bytes())
236 }
237
238 fn decode_into(
239 &self,
240 encoded: AnyArrayView,
241 mut decoded: AnyArrayViewMut,
242 ) -> Result<(), Self::Error> {
243 let decoded_in = self.decode(encoded.cow())?;
244
245 Ok(decoded.assign(&decoded_in)?)
246 }
247}
248
249impl StaticCodec for Sz3Codec {
250 const CODEC_ID: &'static str = "sz3";
251
252 type Config<'de> = Self;
253
254 fn from_config(config: Self::Config<'_>) -> Self {
255 config
256 }
257
258 fn get_config(&self) -> StaticCodecConfig<Self> {
259 StaticCodecConfig::from(self)
260 }
261}
262
263#[derive(Debug, Error)]
264pub enum Sz3CodecError {
266 #[error("Sz3 does not support the dtype {0}")]
268 UnsupportedDtype(AnyArrayDType),
269 #[error("Sz3 failed to encode the header")]
271 HeaderEncodeFailed {
272 source: Sz3HeaderError,
274 },
275 #[error("Sz3 cannot encode an array of shape {shape:?}")]
277 InvalidEncodeShape {
278 source: Sz3CodingError,
280 shape: Vec<usize>,
282 },
283 #[error("Sz3 failed to encode the data")]
285 Sz3EncodeFailed {
286 source: Sz3CodingError,
288 },
289 #[error(
292 "Sz3 can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
293 )]
294 EncodedDataNotBytes {
295 dtype: AnyArrayDType,
297 },
298 #[error("Sz3 can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
301 EncodedDataNotOneDimensional {
302 shape: Vec<usize>,
304 },
305 #[error("Sz3 failed to decode the header")]
307 HeaderDecodeFailed {
308 source: Sz3HeaderError,
310 },
311 #[error("Sz3 decoded an invalid array shape header which does not fit the decoded data")]
314 DecodeInvalidShapeHeader {
315 #[from]
317 source: ShapeError,
318 },
319 #[error("Sz3 cannot decode into the provided array")]
321 MismatchedDecodeIntoArray {
322 #[from]
324 source: AnyArrayAssignError,
325 },
326}
327
/// Opaque error for when encoding or decoding the compression header fails.
///
/// Wraps a `postcard` error; `#[error(transparent)]` forwards its message and
/// source unchanged.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Sz3HeaderError(postcard::Error);
332
/// Opaque error for when SZ3 rejects the data's dimensions or fails to
/// compress it.
///
/// Wraps an `sz3::SZ3Error`; `#[error(transparent)]` forwards its message and
/// source unchanged.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Sz3CodingError(sz3::SZ3Error);
337
338#[expect(clippy::needless_pass_by_value)]
339pub fn compress<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
350 data: ArrayBase<S, D>,
351 predictor: Option<&Sz3Predictor>,
352 error_bound: &Sz3ErrorBound,
353 encoder: Option<&Sz3Encoder>,
354 lossless: Option<&Sz3LosslessCompressor>,
355) -> Result<Vec<u8>, Sz3CodecError> {
356 let mut encoded_bytes = postcard::to_extend(
357 &CompressionHeader {
358 dtype: <T as Sz3Element>::DTYPE,
359 shape: Cow::Borrowed(data.shape()),
360 },
361 Vec::new(),
362 )
363 .map_err(|err| Sz3CodecError::HeaderEncodeFailed {
364 source: Sz3HeaderError(err),
365 })?;
366
367 if data.is_empty() {
369 return Ok(encoded_bytes);
370 }
371
372 #[expect(clippy::option_if_let_else)]
373 let data_cow = if let Some(data) = data.as_slice() {
374 Cow::Borrowed(data)
375 } else {
376 Cow::Owned(data.iter().copied().collect())
377 };
378 let mut builder = sz3::DimensionedData::build(&data_cow);
379
380 for length in data.shape() {
381 if *length > 1 {
385 builder = builder
386 .dim(*length)
387 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
388 source: Sz3CodingError(err),
389 shape: data.shape().to_vec(),
390 })?;
391 }
392 }
393
394 if data.len() == 1 {
395 builder = builder
398 .dim(1)
399 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
400 source: Sz3CodingError(err),
401 shape: data.shape().to_vec(),
402 })?;
403 }
404
405 let data = builder
406 .finish()
407 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
408 source: Sz3CodingError(err),
409 shape: data.shape().to_vec(),
410 })?;
411
412 let error_bound = match error_bound {
414 Sz3ErrorBound::AbsoluteAndRelative { abs, rel } => sz3::ErrorBound::AbsoluteAndRelative {
415 absolute_bound: *abs,
416 relative_bound: *rel,
417 },
418 Sz3ErrorBound::AbsoluteOrRelative { abs, rel } => sz3::ErrorBound::AbsoluteOrRelative {
419 absolute_bound: *abs,
420 relative_bound: *rel,
421 },
422 Sz3ErrorBound::Absolute { abs } => sz3::ErrorBound::Absolute(*abs),
423 Sz3ErrorBound::Relative { rel } => sz3::ErrorBound::Relative(*rel),
424 Sz3ErrorBound::PS2NR { psnr } => sz3::ErrorBound::PSNR(*psnr),
425 Sz3ErrorBound::L2Norm { l2 } => sz3::ErrorBound::L2Norm(*l2),
426 };
427 let mut config = sz3::Config::new(error_bound);
428
429 let interpolation = match predictor {
431 Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::LinearInterpolationLorenzo) => {
432 Some(sz3::InterpolationAlgorithm::Linear)
433 }
434 Some(Sz3Predictor::CubicInterpolation | Sz3Predictor::CubicInterpolationLorenzo) => {
435 Some(sz3::InterpolationAlgorithm::Cubic)
436 }
437 Some(Sz3Predictor::LorenzoRegression) | None => None,
438 };
439 if let Some(interpolation) = interpolation {
440 config = config.interpolation_algorithm(interpolation);
441 }
442
443 let predictor = match predictor {
445 Some(Sz3Predictor::LinearInterpolation | Sz3Predictor::CubicInterpolation) => {
446 sz3::CompressionAlgorithm::Interpolation
447 }
448 Some(
449 Sz3Predictor::LinearInterpolationLorenzo | Sz3Predictor::CubicInterpolationLorenzo,
450 ) => sz3::CompressionAlgorithm::InterpolationLorenzo,
451 Some(Sz3Predictor::LorenzoRegression) => sz3::CompressionAlgorithm::lorenzo_regression(),
452 None => sz3::CompressionAlgorithm::NoPrediction,
453 };
454 config = config.compression_algorithm(predictor);
455
456 let encoder = match encoder {
458 None => sz3::Encoder::SkipEncoder,
459 Some(Sz3Encoder::Huffman) => sz3::Encoder::HuffmanEncoder,
460 Some(Sz3Encoder::Arithmetic) => sz3::Encoder::ArithmeticEncoder,
461 };
462 config = config.encoder(encoder);
463
464 let lossless = match lossless {
466 None => sz3::LossLess::LossLessBypass,
467 Some(Sz3LosslessCompressor::Zstd) => sz3::LossLess::ZSTD,
468 };
469 config = config.lossless(lossless);
470
471 let compressed = sz3::compress_with_config(&data, &config).map_err(|err| {
473 Sz3CodecError::Sz3EncodeFailed {
474 source: Sz3CodingError(err),
475 }
476 })?;
477 encoded_bytes.extend_from_slice(&compressed);
478
479 Ok(encoded_bytes)
480}
481
482pub fn decompress(encoded: &[u8]) -> Result<AnyArray, Sz3CodecError> {
489 let (header, data) =
490 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
491 Sz3CodecError::HeaderDecodeFailed {
492 source: Sz3HeaderError(err),
493 }
494 })?;
495
496 let decoded = if header.shape.iter().copied().product::<usize>() == 0 {
497 match header.dtype {
498 Sz3DType::I32 => {
499 AnyArray::I32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
500 }
501 Sz3DType::I64 => {
502 AnyArray::I64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
503 }
504 Sz3DType::F32 => {
505 AnyArray::F32(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
506 }
507 Sz3DType::F64 => {
508 AnyArray::F64(Array::from_shape_vec(&*header.shape, Vec::new())?.into_dyn())
509 }
510 }
511 } else {
512 match header.dtype {
514 Sz3DType::I32 => AnyArray::I32(Array::from_shape_vec(
515 &*header.shape,
516 Vec::from(sz3::decompress(data).1.data()),
517 )?),
518 Sz3DType::I64 => AnyArray::I64(Array::from_shape_vec(
519 &*header.shape,
520 Vec::from(sz3::decompress(data).1.data()),
521 )?),
522 Sz3DType::F32 => AnyArray::F32(Array::from_shape_vec(
523 &*header.shape,
524 Vec::from(sz3::decompress(data).1.data()),
525 )?),
526 Sz3DType::F64 => AnyArray::F64(Array::from_shape_vec(
527 &*header.shape,
528 Vec::from(sz3::decompress(data).1.data()),
529 )?),
530 }
531 };
532
533 Ok(decoded)
534}
535
536pub trait Sz3Element: Copy + sz3::SZ3Compressible {
538 const DTYPE: Sz3DType;
540}
541
542impl Sz3Element for i32 {
543 const DTYPE: Sz3DType = Sz3DType::I32;
544}
545
546impl Sz3Element for i64 {
547 const DTYPE: Sz3DType = Sz3DType::I64;
548}
549
550impl Sz3Element for f32 {
551 const DTYPE: Sz3DType = Sz3DType::F32;
552}
553
554impl Sz3Element for f64 {
555 const DTYPE: Sz3DType = Sz3DType::F64;
556}
557
/// Header that [`compress`] serializes with `postcard` in front of the SZ3
/// payload, and that [`decompress`] reads back to restore dtype and shape.
///
/// NOTE(review): postcard encodes fields in declaration order, so reordering
/// or inserting fields changes the wire format of already-encoded data.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    /// Element dtype of the compressed array.
    dtype: Sz3DType,
    /// Full array shape, including unit-length and zero-length axes (SZ3
    /// itself only receives the axes longer than one).
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
}
564
565#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
567#[expect(missing_docs)]
568pub enum Sz3DType {
569 #[serde(rename = "i32", alias = "int32")]
570 I32,
571 #[serde(rename = "i64", alias = "int64")]
572 I64,
573 #[serde(rename = "f32", alias = "float32")]
574 F32,
575 #[serde(rename = "f64", alias = "float64")]
576 F64,
577}
578
579impl fmt::Display for Sz3DType {
580 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
581 fmt.write_str(match self {
582 Self::I32 => "i32",
583 Self::I64 => "i64",
584 Self::F32 => "f32",
585 Self::F64 => "f64",
586 })
587 }
588}
589
#[cfg(test)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    /// Arrays with a zero-length axis round-trip through the header alone.
    #[test]
    fn zero_length() -> Result<(), Sz3CodecError> {
        let encoded = compress(
            Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])?,
            default_predictor().as_ref(),
            &Sz3ErrorBound::L2Norm { l2: 27.0 },
            default_encoder().as_ref(),
            default_lossless_compressor().as_ref(),
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);

        Ok(())
    }

    /// Unit-length axes are dropped for SZ3 but restored from the header.
    #[test]
    fn one_dimension() -> Result<(), Sz3CodecError> {
        let data = Array::from_shape_vec([2_usize, 1, 2, 1].as_slice(), vec![1, 2, 3, 4])?;

        let encoded = compress(
            data.view(),
            default_predictor().as_ref(),
            &Sz3ErrorBound::Absolute { abs: 0.1 },
            default_encoder().as_ref(),
            default_lossless_compressor().as_ref(),
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded, AnyArray::I32(data));

        Ok(())
    }

    /// Very small one-dimensional inputs round-trip exactly.
    #[test]
    fn small_state() -> Result<(), Sz3CodecError> {
        for data in [
            &[][..],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ] {
            let encoded = compress(
                ArrayView1::from(data),
                default_predictor().as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
                default_encoder().as_ref(),
                default_lossless_compressor().as_ref(),
            )?;
            let decoded = decompress(&encoded)?;

            assert_eq!(
                decoded,
                AnyArray::F64(Array1::from_vec(data.to_vec()).into_dyn())
            );
        }

        Ok(())
    }

    /// Inputs that are not in standard layout take the copying path in
    /// `compress` and must still decode in logical element order within the
    /// absolute error bound.
    #[test]
    fn non_standard_layout() -> Result<(), Sz3CodecError> {
        let data = Array::from_shape_vec(
            [2_usize, 3].as_slice(),
            vec![0.0_f64, 1.0, 2.0, 3.0, 4.0, 5.0],
        )?;
        let transposed = data.t();
        assert!(transposed.as_slice().is_none());

        let encoded = compress(
            transposed.view(),
            default_predictor().as_ref(),
            &Sz3ErrorBound::Absolute { abs: 0.1 },
            default_encoder().as_ref(),
            default_lossless_compressor().as_ref(),
        )?;
        let decoded = decompress(&encoded)?;

        let AnyArray::F64(decoded) = decoded else {
            panic!("expected the decoded array to have dtype f64");
        };
        assert_eq!(decoded.shape(), &[3, 2]);
        for (decoded, original) in decoded.iter().zip(transposed.iter()) {
            assert!((decoded - original).abs() <= 0.1);
        }

        Ok(())
    }
}