1#![allow(clippy::multiple_crate_versions)] use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, ArrayViewMut, Data, Dimension, IxDyn, ShapeError};
25use numcodecs::{
26 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27 ArrayDType, ArrayDataMutExt, Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use thiserror::Error;
32
33use ::zstd_sys as _;
36
37#[cfg(test)]
38use ::serde_json as _;
39
/// Version marker of the SZ3 codec's encoding format.
///
/// It is stored both in the codec configuration (as the `_version` field of
/// [`Sz3Codec`]) and in the [`CompressionHeader`] that is prepended to every
/// encoded buffer.
// NOTE(review): the const parameters `0, 2, 0` presumably read as
// major.minor.patch — confirm against the `numcodecs` documentation.
type Sz3CodecVersion = StaticCodecVersion<0, 2, 0>;
41
/// Codec that compresses arrays with SZ3.
///
/// Encoding goes through [`compress`], decoding through [`decompress`] and
/// [`decompress_into`]; the struct itself only carries the configuration.
// Unknown fields are only denied in the generated JSON schema.
// NOTE(review): presumably serde's own `deny_unknown_fields` is left off
// because of the flattened `error_bound` field — confirm.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[schemars(deny_unknown_fields)]
pub struct Sz3Codec {
    /// Predictor used during compression.
    ///
    /// When the field is missing from the configuration it defaults to
    /// `Some(Sz3Predictor::InterpolationLorenzo)`. `None` makes [`compress`]
    /// select `sz3::CompressionAlgorithm::NoPrediction`.
    #[serde(default = "default_predictor")]
    pub predictor: Option<Sz3Predictor>,
    /// Error bound handed to SZ3.
    ///
    /// Its `eb_mode` tag and `eb_*` fields are flattened into the codec's own
    /// configuration object.
    #[serde(flatten)]
    pub error_bound: Sz3ErrorBound,
    /// Encoding format version, (de)serialized as `_version`.
    ///
    /// Falls back to the type's `Default` when absent from the configuration.
    #[serde(default, rename = "_version")]
    pub version: Sz3CodecVersion,
}
57
/// Error bound modes that can be configured for SZ3.
///
/// The enum is internally tagged with `eb_mode`; every variant is forwarded to
/// the matching `sz3::ErrorBound` variant by [`compress`].
// NOTE(review): the exact semantics of each bound (e.g. whether "and" means
// the stricter and "or" the weaker of the two) are defined by SZ3, not by this
// file — confirm against the SZ3 documentation before relying on them.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "eb_mode")]
#[serde(deny_unknown_fields)]
pub enum Sz3ErrorBound {
    /// Combined absolute-and-relative bound (`eb_mode = "abs-and-rel"`).
    #[serde(rename = "abs-and-rel")]
    AbsoluteAndRelative {
        /// Absolute error bound, configured as `eb_abs`.
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound, configured as `eb_rel`.
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Combined absolute-or-relative bound (`eb_mode = "abs-or-rel"`).
    #[serde(rename = "abs-or-rel")]
    AbsoluteOrRelative {
        /// Absolute error bound, configured as `eb_abs`.
        #[serde(rename = "eb_abs")]
        abs: f64,
        /// Relative error bound, configured as `eb_rel`.
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// Absolute error bound (`eb_mode = "abs"`).
    #[serde(rename = "abs")]
    Absolute {
        /// Absolute error bound, configured as `eb_abs`.
        #[serde(rename = "eb_abs")]
        abs: f64,
    },
    /// Relative error bound (`eb_mode = "rel"`).
    #[serde(rename = "rel")]
    Relative {
        /// Relative error bound, configured as `eb_rel`.
        #[serde(rename = "eb_rel")]
        rel: f64,
    },
    /// PSNR-based error bound (`eb_mode = "psnr"`), forwarded as
    /// `sz3::ErrorBound::PSNR`.
    #[serde(rename = "psnr")]
    PS2NR {
        /// PSNR bound, configured as `eb_psnr`.
        #[serde(rename = "eb_psnr")]
        psnr: f64,
    },
    /// L2-norm error bound (`eb_mode = "l2"`).
    #[serde(rename = "l2")]
    L2Norm {
        /// L2-norm bound, configured as `eb_l2`.
        #[serde(rename = "eb_l2")]
        l2: f64,
    },
}
114
/// Predictors that can be configured for SZ3.
///
/// [`compress`] translates each variant into an `sz3::CompressionAlgorithm`:
/// the two interpolation variants map to their namesakes, every other variant
/// maps to `LorenzoRegression` with the `lorenzo`, `lorenzo_second_order` and
/// `regression` flags switched on as the variant name spells out.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub enum Sz3Predictor {
    /// `"interpolation"`: `sz3::CompressionAlgorithm::Interpolation`.
    #[serde(rename = "interpolation")]
    Interpolation,
    /// `"interpolation-lorenzo"`:
    /// `sz3::CompressionAlgorithm::InterpolationLorenzo`. This is the default
    /// predictor of [`Sz3Codec`].
    #[serde(rename = "interpolation-lorenzo")]
    InterpolationLorenzo,
    /// `"regression"`: only the `regression` flag is set.
    #[serde(rename = "regression")]
    Regression,
    /// `"lorenzo2"`: only the `lorenzo_second_order` flag is set.
    #[serde(rename = "lorenzo2")]
    LorenzoSecondOrder,
    /// `"lorenzo2-regression"`: `lorenzo_second_order` and `regression` are set.
    #[serde(rename = "lorenzo2-regression")]
    LorenzoSecondOrderRegression,
    /// `"lorenzo"`: only the `lorenzo` flag is set.
    #[serde(rename = "lorenzo")]
    Lorenzo,
    /// `"lorenzo-regression"`: `lorenzo` and `regression` are set.
    #[serde(rename = "lorenzo-regression")]
    LorenzoRegression,
    /// `"lorenzo-lorenzo2"`: `lorenzo` and `lorenzo_second_order` are set.
    #[serde(rename = "lorenzo-lorenzo2")]
    LorenzoFirstSecondOrder,
    /// `"lorenzo-lorenzo2-regression"`: all three flags are set.
    #[serde(rename = "lorenzo-lorenzo2-regression")]
    LorenzoFirstSecondOrderRegression,
}
147
/// Serde default for [`Sz3Codec::predictor`]: the interpolation+Lorenzo
/// predictor.
// The `Option` wrapper is required because serde's `default = "..."` function
// must return the field's exact type, hence the expected clippy lint.
#[expect(clippy::unnecessary_wraps)]
const fn default_predictor() -> Option<Sz3Predictor> {
    Some(Sz3Predictor::InterpolationLorenzo)
}
152
153impl Codec for Sz3Codec {
154 type Error = Sz3CodecError;
155
156 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
157 match data {
158 AnyCowArray::U8(data) => Ok(AnyArray::U8(
159 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
160 .into_dyn(),
161 )),
162 AnyCowArray::I8(data) => Ok(AnyArray::U8(
163 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
164 .into_dyn(),
165 )),
166 AnyCowArray::U16(data) => Ok(AnyArray::U8(
167 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
168 .into_dyn(),
169 )),
170 AnyCowArray::I16(data) => Ok(AnyArray::U8(
171 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
172 .into_dyn(),
173 )),
174 AnyCowArray::U32(data) => Ok(AnyArray::U8(
175 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
176 .into_dyn(),
177 )),
178 AnyCowArray::I32(data) => Ok(AnyArray::U8(
179 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
180 .into_dyn(),
181 )),
182 AnyCowArray::U64(data) => Ok(AnyArray::U8(
183 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
184 .into_dyn(),
185 )),
186 AnyCowArray::I64(data) => Ok(AnyArray::U8(
187 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
188 .into_dyn(),
189 )),
190 AnyCowArray::F32(data) => Ok(AnyArray::U8(
191 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
192 .into_dyn(),
193 )),
194 AnyCowArray::F64(data) => Ok(AnyArray::U8(
195 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
196 .into_dyn(),
197 )),
198 encoded => Err(Sz3CodecError::UnsupportedDtype(encoded.dtype())),
199 }
200 }
201
202 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
203 let AnyCowArray::U8(encoded) = encoded else {
204 return Err(Sz3CodecError::EncodedDataNotBytes {
205 dtype: encoded.dtype(),
206 });
207 };
208
209 if !matches!(encoded.shape(), [_]) {
210 return Err(Sz3CodecError::EncodedDataNotOneDimensional {
211 shape: encoded.shape().to_vec(),
212 });
213 }
214
215 decompress(&AnyCowArray::U8(encoded).as_bytes())
216 }
217
218 fn decode_into(
219 &self,
220 encoded: AnyArrayView,
221 decoded: AnyArrayViewMut,
222 ) -> Result<(), Self::Error> {
223 let AnyArrayView::U8(encoded) = encoded else {
224 return Err(Sz3CodecError::EncodedDataNotBytes {
225 dtype: encoded.dtype(),
226 });
227 };
228
229 if !matches!(encoded.shape(), [_]) {
230 return Err(Sz3CodecError::EncodedDataNotOneDimensional {
231 shape: encoded.shape().to_vec(),
232 });
233 }
234
235 decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
236 }
237}
238
impl StaticCodec for Sz3Codec {
    /// Identifier under which this codec is registered.
    const CODEC_ID: &'static str = "sz3.rs";

    /// The codec is its own configuration type.
    type Config<'de> = Self;

    /// Builds the codec from its configuration, which is the codec itself.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Returns the codec's configuration by borrowing the codec.
    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
        StaticCodecConfig::from(self)
    }
}
252
/// Errors that may occur when applying the [`Sz3Codec`].
#[derive(Debug, Error)]
pub enum Sz3CodecError {
    /// [`Sz3Codec::encode`] received an array whose dtype it does not handle.
    #[error("Sz3 does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`compress`] failed to serialize the [`CompressionHeader`].
    #[error("Sz3 failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// The SZ3 dimension builder rejected the array's shape.
    ///
    /// Produced by [`compress`]; [`decompress_into`] currently reuses this
    /// variant when the builder rejects the shape of the output array.
    #[error("Sz3 cannot encode an array of shape {shape:?}")]
    InvalidEncodeShape {
        /// Opaque source error
        source: Sz3CodingError,
        /// The shape that was rejected
        shape: Vec<usize>,
    },
    /// SZ3 itself failed while compressing the data.
    #[error("Sz3 failed to encode the data")]
    Sz3EncodeFailed {
        /// Opaque source error
        source: Sz3CodingError,
    },
    /// The array passed for decoding is not a `u8` array.
    #[error(
        "Sz3 can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The dtype that was received instead
        dtype: AnyArrayDType,
    },
    /// The byte array passed for decoding is not one-dimensional.
    #[error(
        "Sz3 can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The shape that was received instead
        shape: Vec<usize>,
    },
    /// The [`CompressionHeader`] could not be deserialized from the encoded
    /// bytes.
    #[error("Sz3 failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: Sz3HeaderError,
    },
    /// SZ3 itself failed while decompressing the data.
    #[error("Sz3 failed to decode the data")]
    Sz3DecodeFailed {
        /// Opaque source error
        source: Sz3CodingError,
    },
    /// The shape stored in the header does not match the number of decoded
    /// elements, so the output array could not be built.
    #[error("Sz3 decoded an invalid array shape header which does not fit the decoded data")]
    DecodeInvalidShapeHeader {
        /// Source error from building the array
        #[from]
        source: ShapeError,
    },
    /// The array provided to [`decompress_into`] has a different shape or
    /// dtype than the one recorded in the header.
    #[error("Sz3 cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// Source error describing the mismatch
        #[from]
        source: AnyArrayAssignError,
    },
}
325
/// Opaque error for when encoding or decoding the header fails.
///
/// Wraps the underlying `postcard::Error` and displays it transparently.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Sz3HeaderError(postcard::Error);
330
/// Opaque error for when encoding or decoding with SZ3 fails.
///
/// Wraps the underlying `sz3::SZ3Error` and displays it transparently.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Sz3CodingError(sz3::SZ3Error);
335
/// Compresses the input `data` array using SZ3 with the provided `predictor`
/// and `error_bound`.
///
/// The returned buffer starts with a postcard-serialized
/// [`CompressionHeader`] (dtype, full shape, format version) and is followed
/// by the SZ3 payload. For an empty array the buffer holds the header only.
///
/// # Errors
///
/// - [`Sz3CodecError::HeaderEncodeFailed`] if the header cannot be serialized
/// - [`Sz3CodecError::InvalidEncodeShape`] if the SZ3 dimension builder rejects
///   the array's shape
/// - [`Sz3CodecError::Sz3EncodeFailed`] if SZ3 fails to compress the data
#[expect(clippy::needless_pass_by_value, clippy::too_many_lines)]
pub fn compress<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
    data: ArrayBase<S, D>,
    predictor: Option<&Sz3Predictor>,
    error_bound: &Sz3ErrorBound,
) -> Result<Vec<u8>, Sz3CodecError> {
    // The header is written first; the SZ3 payload is appended to the same
    // buffer further down.
    let mut encoded_bytes = postcard::to_extend(
        &CompressionHeader {
            dtype: <T as Sz3Element>::DTYPE,
            shape: Cow::Borrowed(data.shape()),
            version: StaticCodecVersion,
        },
        Vec::new(),
    )
    .map_err(|err| Sz3CodecError::HeaderEncodeFailed {
        source: Sz3HeaderError(err),
    })?;

    // An empty array is fully described by its header, so SZ3 is never
    // invoked for it; `decompress` mirrors this for zero-length axes.
    if data.is_empty() {
        return Ok(encoded_bytes);
    }

    // Borrow the elements when the array exposes them as a slice, otherwise
    // copy them into an owned buffer in iteration order.
    #[expect(clippy::option_if_let_else)]
    let data_cow = match data.as_slice() {
        Some(data) => Cow::Borrowed(data),
        None => Cow::Owned(data.iter().copied().collect()),
    };
    let mut builder = sz3::DimensionedData::build(&data_cow);

    for length in data.shape() {
        // Only axes longer than one are registered with SZ3; the complete
        // shape is already stored in the header.
        // NOTE(review): presumably SZ3 ignores or rejects unit-length axes —
        // confirm against the sz3 crate documentation.
        if *length > 1 {
            builder = builder
                .dim(*length)
                .map_err(|err| Sz3CodecError::InvalidEncodeShape {
                    source: Sz3CodingError(err),
                    shape: data.shape().to_vec(),
                })?;
        }
    }

    // With a single element every axis was skipped above, so one axis of
    // length one is registered explicitly.
    if data.len() == 1 {
        builder = builder
            .dim(1)
            .map_err(|err| Sz3CodecError::InvalidEncodeShape {
                source: Sz3CodingError(err),
                shape: data.shape().to_vec(),
            })?;
    }

    let data = builder
        .finish()
        .map_err(|err| Sz3CodecError::InvalidEncodeShape {
            source: Sz3CodingError(err),
            shape: data.shape().to_vec(),
        })?;

    // Translate the configured error bound into its sz3 counterpart.
    let error_bound = match error_bound {
        Sz3ErrorBound::AbsoluteAndRelative { abs, rel } => sz3::ErrorBound::AbsoluteAndRelative {
            absolute_bound: *abs,
            relative_bound: *rel,
        },
        Sz3ErrorBound::AbsoluteOrRelative { abs, rel } => sz3::ErrorBound::AbsoluteOrRelative {
            absolute_bound: *abs,
            relative_bound: *rel,
        },
        Sz3ErrorBound::Absolute { abs } => sz3::ErrorBound::Absolute(*abs),
        Sz3ErrorBound::Relative { rel } => sz3::ErrorBound::Relative(*rel),
        Sz3ErrorBound::PS2NR { psnr } => sz3::ErrorBound::PSNR(*psnr),
        Sz3ErrorBound::L2Norm { l2 } => sz3::ErrorBound::L2Norm(*l2),
    };
    let mut config = sz3::Config::new(error_bound);

    // Translate the configured predictor into the sz3 compression algorithm;
    // all Lorenzo/regression combinations share one variant with three flags.
    let predictor = match predictor {
        Some(Sz3Predictor::Interpolation) => sz3::CompressionAlgorithm::Interpolation,
        Some(Sz3Predictor::InterpolationLorenzo) => sz3::CompressionAlgorithm::InterpolationLorenzo,
        Some(Sz3Predictor::Regression) => sz3::CompressionAlgorithm::LorenzoRegression {
            lorenzo: false,
            lorenzo_second_order: false,
            regression: true,
        },
        Some(Sz3Predictor::LorenzoSecondOrder) => sz3::CompressionAlgorithm::LorenzoRegression {
            lorenzo: false,
            lorenzo_second_order: true,
            regression: false,
        },
        Some(Sz3Predictor::LorenzoSecondOrderRegression) => {
            sz3::CompressionAlgorithm::LorenzoRegression {
                lorenzo: false,
                lorenzo_second_order: true,
                regression: true,
            }
        }
        Some(Sz3Predictor::Lorenzo) => sz3::CompressionAlgorithm::LorenzoRegression {
            lorenzo: true,
            lorenzo_second_order: false,
            regression: false,
        },
        Some(Sz3Predictor::LorenzoRegression) => sz3::CompressionAlgorithm::LorenzoRegression {
            lorenzo: true,
            lorenzo_second_order: false,
            regression: true,
        },
        Some(Sz3Predictor::LorenzoFirstSecondOrder) => {
            sz3::CompressionAlgorithm::LorenzoRegression {
                lorenzo: true,
                lorenzo_second_order: true,
                regression: false,
            }
        }
        Some(Sz3Predictor::LorenzoFirstSecondOrderRegression) => {
            sz3::CompressionAlgorithm::LorenzoRegression {
                lorenzo: true,
                lorenzo_second_order: true,
                regression: true,
            }
        }
        None => sz3::CompressionAlgorithm::NoPrediction,
    };
    config = config.compression_algorithm(predictor);

    // Append the SZ3 payload directly after the header bytes.
    sz3::compress_into_with_config(&data, &config, &mut encoded_bytes).map_err(|err| {
        Sz3CodecError::Sz3EncodeFailed {
            source: Sz3CodingError(err),
        }
    })?;

    Ok(encoded_bytes)
}
482
483pub fn decompress(encoded: &[u8]) -> Result<AnyArray, Sz3CodecError> {
491 fn decompress_typed<T: Sz3Element>(
492 encoded: &[u8],
493 shape: &[usize],
494 ) -> Result<Array<T, IxDyn>, Sz3CodecError> {
495 if shape.iter().copied().any(|s| s == 0) {
496 return Ok(Array::from_shape_vec(shape, Vec::new())?);
497 }
498
499 let (_config, decompressed) =
500 sz3::decompress(encoded).map_err(|err| Sz3CodecError::Sz3DecodeFailed {
501 source: Sz3CodingError(err),
502 })?;
503
504 Ok(Array::from_shape_vec(shape, decompressed.into_data())?)
505 }
506
507 let (header, data) =
508 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
509 Sz3CodecError::HeaderDecodeFailed {
510 source: Sz3HeaderError(err),
511 }
512 })?;
513
514 let decoded = match header.dtype {
515 Sz3DType::U8 => AnyArray::U8(decompress_typed(data, &header.shape)?),
516 Sz3DType::I8 => AnyArray::I8(decompress_typed(data, &header.shape)?),
517 Sz3DType::U16 => AnyArray::U16(decompress_typed(data, &header.shape)?),
518 Sz3DType::I16 => AnyArray::I16(decompress_typed(data, &header.shape)?),
519 Sz3DType::U32 => AnyArray::U32(decompress_typed(data, &header.shape)?),
520 Sz3DType::I32 => AnyArray::I32(decompress_typed(data, &header.shape)?),
521 Sz3DType::U64 => AnyArray::U64(decompress_typed(data, &header.shape)?),
522 Sz3DType::I64 => AnyArray::I64(decompress_typed(data, &header.shape)?),
523 Sz3DType::F32 => AnyArray::F32(decompress_typed(data, &header.shape)?),
524 Sz3DType::F64 => AnyArray::F64(decompress_typed(data, &header.shape)?),
525 };
526
527 Ok(decoded)
528}
529
530pub fn decompress_into(encoded: &[u8], decoded: AnyArrayViewMut) -> Result<(), Sz3CodecError> {
540 fn decompress_into_typed<T: Sz3Element>(
541 encoded: &[u8],
542 mut decoded: ArrayViewMut<T, IxDyn>,
543 ) -> Result<(), Sz3CodecError> {
544 if decoded.is_empty() {
545 return Ok(());
546 }
547
548 let decoded_shape = decoded.shape().to_vec();
549
550 decoded.with_slice_mut(|mut decoded| {
551 let decoded_len = decoded.len();
552
553 let mut builder = sz3::DimensionedData::build_mut(&mut decoded);
554
555 for length in &decoded_shape {
556 if *length > 1 {
560 builder = builder
561 .dim(*length)
562 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
564 source: Sz3CodingError(err),
565 shape: decoded_shape.clone(),
566 })?;
567 }
568 }
569
570 if decoded_len == 1 {
571 builder = builder
574 .dim(1)
575 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
577 source: Sz3CodingError(err),
578 shape: decoded_shape.clone(),
579 })?;
580 }
581
582 let mut decoded = builder
583 .finish()
584 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
586 source: Sz3CodingError(err),
587 shape: decoded_shape,
588 })?;
589
590 sz3::decompress_into_dimensioned(encoded, &mut decoded).map_err(|err| {
591 Sz3CodecError::Sz3DecodeFailed {
592 source: Sz3CodingError(err),
593 }
594 })
595 })?;
596
597 Ok(())
598 }
599
600 let (header, data) =
601 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
602 Sz3CodecError::HeaderDecodeFailed {
603 source: Sz3HeaderError(err),
604 }
605 })?;
606
607 if decoded.shape() != &*header.shape {
608 return Err(Sz3CodecError::MismatchedDecodeIntoArray {
609 source: AnyArrayAssignError::ShapeMismatch {
610 src: header.shape.into_owned(),
611 dst: decoded.shape().to_vec(),
612 },
613 });
614 }
615
616 match (decoded, header.dtype) {
617 (AnyArrayViewMut::U8(decoded), Sz3DType::U8) => decompress_into_typed(data, decoded),
618 (AnyArrayViewMut::I8(decoded), Sz3DType::I8) => decompress_into_typed(data, decoded),
619 (AnyArrayViewMut::U16(decoded), Sz3DType::U16) => decompress_into_typed(data, decoded),
620 (AnyArrayViewMut::I16(decoded), Sz3DType::I16) => decompress_into_typed(data, decoded),
621 (AnyArrayViewMut::U32(decoded), Sz3DType::U32) => decompress_into_typed(data, decoded),
622 (AnyArrayViewMut::I32(decoded), Sz3DType::I32) => decompress_into_typed(data, decoded),
623 (AnyArrayViewMut::U64(decoded), Sz3DType::U64) => decompress_into_typed(data, decoded),
624 (AnyArrayViewMut::I64(decoded), Sz3DType::I64) => decompress_into_typed(data, decoded),
625 (AnyArrayViewMut::F32(decoded), Sz3DType::F32) => decompress_into_typed(data, decoded),
626 (AnyArrayViewMut::F64(decoded), Sz3DType::F64) => decompress_into_typed(data, decoded),
627 (decoded, dtype) => Err(Sz3CodecError::MismatchedDecodeIntoArray {
628 source: AnyArrayAssignError::DTypeMismatch {
629 src: dtype.into_dtype(),
630 dst: decoded.dtype(),
631 },
632 }),
633 }
634}
635
/// Array element types which can be compressed with SZ3.
///
/// Implemented in this file for `u8`, `i8`, `u16`, `i16`, `u32`, `i32`,
/// `u64`, `i64`, `f32` and `f64`.
pub trait Sz3Element: Copy + sz3::SZ3Compressible + ArrayDType {
    /// The dtype tag that [`compress`] writes into the header for this
    /// element type.
    const DTYPE: Sz3DType;
}
641
642impl Sz3Element for u8 {
643 const DTYPE: Sz3DType = Sz3DType::U8;
644}
645
646impl Sz3Element for i8 {
647 const DTYPE: Sz3DType = Sz3DType::I8;
648}
649
650impl Sz3Element for u16 {
651 const DTYPE: Sz3DType = Sz3DType::U16;
652}
653
654impl Sz3Element for i16 {
655 const DTYPE: Sz3DType = Sz3DType::I16;
656}
657
658impl Sz3Element for u32 {
659 const DTYPE: Sz3DType = Sz3DType::U32;
660}
661
662impl Sz3Element for i32 {
663 const DTYPE: Sz3DType = Sz3DType::I32;
664}
665
666impl Sz3Element for u64 {
667 const DTYPE: Sz3DType = Sz3DType::U64;
668}
669
670impl Sz3Element for i64 {
671 const DTYPE: Sz3DType = Sz3DType::I64;
672}
673
674impl Sz3Element for f32 {
675 const DTYPE: Sz3DType = Sz3DType::F32;
676}
677
678impl Sz3Element for f64 {
679 const DTYPE: Sz3DType = Sz3DType::F64;
680}
681
/// Header that [`compress`] serializes with postcard in front of the SZ3
/// payload, and that [`decompress`] and [`decompress_into`] read back.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    /// Element type of the encoded array.
    dtype: Sz3DType,
    /// Full shape of the encoded array, including unit- and zero-length axes
    /// that are not passed on to SZ3. Borrowed from the input when encoding.
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    /// Encoding format version marker.
    version: Sz3CodecVersion,
}
689
/// Element data types supported by the SZ3 codec, as stored in the
/// compression header.
///
/// Each variant serializes under its short name (e.g. `u8`) and additionally
/// accepts the long alias (e.g. `uint8`) when deserializing.
// The variants are intentionally left without doc comments: the
// `#[expect(missing_docs)]` below must stay fulfilled.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[expect(missing_docs)]
pub enum Sz3DType {
    #[serde(rename = "u8", alias = "uint8")]
    U8,
    #[serde(rename = "u16", alias = "uint16")]
    U16,
    #[serde(rename = "u32", alias = "uint32")]
    U32,
    #[serde(rename = "u64", alias = "uint64")]
    U64,
    #[serde(rename = "i8", alias = "int8")]
    I8,
    #[serde(rename = "i16", alias = "int16")]
    I16,
    #[serde(rename = "i32", alias = "int32")]
    I32,
    #[serde(rename = "i64", alias = "int64")]
    I64,
    #[serde(rename = "f32", alias = "float32")]
    F32,
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
715
716impl Sz3DType {
717 #[must_use]
719 pub const fn into_dtype(self) -> AnyArrayDType {
720 match self {
721 Self::U8 => AnyArrayDType::U8,
722 Self::U16 => AnyArrayDType::U16,
723 Self::U32 => AnyArrayDType::U32,
724 Self::U64 => AnyArrayDType::U64,
725 Self::I8 => AnyArrayDType::I8,
726 Self::I16 => AnyArrayDType::I16,
727 Self::I32 => AnyArrayDType::I32,
728 Self::I64 => AnyArrayDType::I64,
729 Self::F32 => AnyArrayDType::F32,
730 Self::F64 => AnyArrayDType::F64,
731 }
732 }
733}
734
735impl fmt::Display for Sz3DType {
736 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
737 fmt.write_str(match self {
738 Self::U8 => "u8",
739 Self::U16 => "u16",
740 Self::U32 => "u32",
741 Self::U64 => "u64",
742 Self::I8 => "i8",
743 Self::I16 => "i16",
744 Self::I32 => "i32",
745 Self::I64 => "i64",
746 Self::F32 => "f32",
747 Self::F64 => "f64",
748 })
749 }
750}
751
// Round-trip tests for `compress`, `decompress` and `decompress_into`.
#[cfg(test)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    // An array with a zero-length axis encodes to a header only and decodes
    // back to an empty array with the original dtype and shape.
    #[test]
    fn zero_length() -> Result<(), Sz3CodecError> {
        let encoded = compress(
            Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])?,
            default_predictor().as_ref(),
            &Sz3ErrorBound::L2Norm { l2: 27.0 },
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);

        Ok(())
    }

    // A shape containing unit-length axes round-trips through both decode
    // paths, exercising the axis-skipping logic.
    #[test]
    fn one_dimension() -> Result<(), Sz3CodecError> {
        let data = Array::from_shape_vec([2_usize, 1, 2, 1].as_slice(), vec![1, 2, 3, 4])?;

        let encoded = compress(
            data.view(),
            default_predictor().as_ref(),
            &Sz3ErrorBound::Absolute { abs: 0.1 },
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded, AnyArray::I32(data.clone()));

        let mut decoded = Array::zeros(data.dim());
        decompress_into(&encoded, AnyArrayViewMut::I32(decoded.view_mut()))?;

        assert_eq!(decoded, data);

        Ok(())
    }

    // Very small 1D inputs (zero to four elements) round-trip exactly through
    // both decode paths.
    #[test]
    fn small_state() -> Result<(), Sz3CodecError> {
        for data in [
            &[][..],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ] {
            let encoded = compress(
                ArrayView1::from(data),
                default_predictor().as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let decoded = decompress(&encoded)?;

            assert_eq!(
                decoded,
                AnyArray::F64(Array1::from_vec(data.to_vec()).into_dyn())
            );

            let mut decoded = Array::zeros([data.len()]);
            decompress_into(
                &encoded,
                AnyArrayViewMut::F64(decoded.view_mut().into_dyn()),
            )?;

            assert_eq!(decoded, Array1::from_vec(data.to_vec()));
        }

        Ok(())
    }

    // Smoke test: every predictor choice, including `None`, encodes and
    // decodes without error. The decoded values are not compared.
    #[test]
    fn all_predictors() -> Result<(), Sz3CodecError> {
        let data = Array::linspace(-42.0, 42.0, 85);

        for predictor in [
            None,
            Some(Sz3Predictor::Interpolation),
            Some(Sz3Predictor::InterpolationLorenzo),
            Some(Sz3Predictor::Regression),
            Some(Sz3Predictor::LorenzoSecondOrder),
            Some(Sz3Predictor::LorenzoSecondOrderRegression),
            Some(Sz3Predictor::Lorenzo),
            Some(Sz3Predictor::LorenzoRegression),
            Some(Sz3Predictor::LorenzoFirstSecondOrder),
            Some(Sz3Predictor::LorenzoFirstSecondOrderRegression),
        ] {
            let encoded = compress(
                data.view(),
                predictor.as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let _decoded = decompress(&encoded)?;

            let mut decoded = Array::zeros(data.dim());
            decompress_into(
                &encoded,
                AnyArrayViewMut::F64(decoded.view_mut().into_dyn()),
            )?;
        }

        Ok(())
    }

    // Smoke test: every supported element type encodes and decodes without
    // error through both decode paths. The decoded values are not compared.
    #[test]
    fn all_dtypes() -> Result<(), Sz3CodecError> {
        // Encodes the items of `iter` as a 1D array, then decodes them once
        // into a fresh array and once into a zeroed array wrapped by
        // `view_mut` into the matching `AnyArrayViewMut` variant.
        fn compress_decompress<T: Sz3Element + num_traits::identities::Zero>(
            iter: impl Clone + IntoIterator<Item = T>,
            view_mut: impl for<'a> Fn(ArrayViewMut<'a, T, IxDyn>) -> AnyArrayViewMut<'a>,
        ) -> Result<(), Sz3CodecError> {
            let encoded = compress(
                Array::from_iter(iter.clone()).view(),
                default_predictor().as_ref(),
                &Sz3ErrorBound::Absolute { abs: 2.0 },
            )?;
            let _decoded = decompress(&encoded)?;

            let mut decoded = Array::<T, _>::zeros([iter.into_iter().count()]).into_dyn();
            decompress_into(&encoded, view_mut(decoded.view_mut().into_dyn()))?;

            Ok(())
        }

        compress_decompress(0_u8..42, |x| AnyArrayViewMut::U8(x))?;
        compress_decompress(-42_i8..42, |x| AnyArrayViewMut::I8(x))?;
        compress_decompress(0_u16..42, |x| AnyArrayViewMut::U16(x))?;
        compress_decompress(-42_i16..42, |x| AnyArrayViewMut::I16(x))?;
        compress_decompress(0_u32..42, |x| AnyArrayViewMut::U32(x))?;
        compress_decompress(-42_i32..42, |x| AnyArrayViewMut::I32(x))?;
        compress_decompress(0_u64..42, |x| AnyArrayViewMut::U64(x))?;
        compress_decompress(-42_i64..42, |x| AnyArrayViewMut::I64(x))?;
        compress_decompress((-42_i16..42).map(f32::from), |x| AnyArrayViewMut::F32(x))?;
        compress_decompress((-42_i16..42).map(f64::from), |x| AnyArrayViewMut::F64(x))?;

        Ok(())
    }
}