1#![allow(clippy::multiple_crate_versions)] use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, ArrayViewMut, Data, Dimension, IxDyn, ShapeError};
25use numcodecs::{
26 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27 ArrayDType, ArrayDataMutExt, Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use thiserror::Error;
32
33use ::zstd_sys as _;
36
37#[cfg(test)]
38use ::serde_json as _;
39
40type Sz3CodecVersion = StaticCodecVersion<0, 2, 0>;
41
42#[derive(Clone, Serialize, Deserialize, JsonSchema)]
43#[schemars(deny_unknown_fields)]
45pub struct Sz3Codec {
47 #[serde(default = "default_predictor")]
49 pub predictor: Option<Sz3Predictor>,
50 #[serde(flatten)]
52 pub error_bound: Sz3ErrorBound,
53 #[serde(default, rename = "_version")]
55 pub version: Sz3CodecVersion,
56}
57
58#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
60#[serde(tag = "eb_mode")]
61#[serde(deny_unknown_fields)]
62pub enum Sz3ErrorBound {
63 #[serde(rename = "abs-and-rel")]
66 AbsoluteAndRelative {
67 #[serde(rename = "eb_abs")]
69 abs: f64,
70 #[serde(rename = "eb_rel")]
72 rel: f64,
73 },
74 #[serde(rename = "abs-or-rel")]
77 AbsoluteOrRelative {
78 #[serde(rename = "eb_abs")]
80 abs: f64,
81 #[serde(rename = "eb_rel")]
83 rel: f64,
84 },
85 #[serde(rename = "abs")]
87 Absolute {
88 #[serde(rename = "eb_abs")]
90 abs: f64,
91 },
92 #[serde(rename = "rel")]
94 Relative {
95 #[serde(rename = "eb_rel")]
97 rel: f64,
98 },
99 #[serde(rename = "psnr")]
101 PS2NR {
102 #[serde(rename = "eb_psnr")]
104 psnr: f64,
105 },
106 #[serde(rename = "l2")]
108 L2Norm {
109 #[serde(rename = "eb_l2")]
111 l2: f64,
112 },
113}
114
115#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
117#[serde(deny_unknown_fields)]
118pub enum Sz3Predictor {
119 #[serde(rename = "interpolation")]
121 Interpolation,
122 #[serde(rename = "interpolation-lorenzo")]
124 InterpolationLorenzo,
125 #[serde(rename = "regression")]
127 Regression,
128 #[serde(rename = "lorenzo2")]
130 LorenzoSecondOrder,
131 #[serde(rename = "lorenzo2-regression")]
133 LorenzoSecondOrderRegression,
134 #[serde(rename = "lorenzo")]
136 Lorenzo,
137 #[serde(rename = "lorenzo-regression")]
139 LorenzoRegression,
140 #[serde(rename = "lorenzo-lorenzo2")]
142 LorenzoFirstSecondOrder,
143 #[serde(rename = "lorenzo-lorenzo2-regression")]
145 LorenzoFirstSecondOrderRegression,
146}
147
148#[expect(clippy::unnecessary_wraps)]
149const fn default_predictor() -> Option<Sz3Predictor> {
150 Some(Sz3Predictor::InterpolationLorenzo)
151}
152
153impl Codec for Sz3Codec {
154 type Error = Sz3CodecError;
155
156 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
157 match data {
158 AnyCowArray::I32(data) => Ok(AnyArray::U8(
159 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
160 .into_dyn(),
161 )),
162 AnyCowArray::I64(data) => Ok(AnyArray::U8(
163 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
164 .into_dyn(),
165 )),
166 AnyCowArray::F32(data) => Ok(AnyArray::U8(
167 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
168 .into_dyn(),
169 )),
170 AnyCowArray::F64(data) => Ok(AnyArray::U8(
171 Array1::from(compress(data, self.predictor.as_ref(), &self.error_bound)?)
172 .into_dyn(),
173 )),
174 encoded => Err(Sz3CodecError::UnsupportedDtype(encoded.dtype())),
175 }
176 }
177
178 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
179 let AnyCowArray::U8(encoded) = encoded else {
180 return Err(Sz3CodecError::EncodedDataNotBytes {
181 dtype: encoded.dtype(),
182 });
183 };
184
185 if !matches!(encoded.shape(), [_]) {
186 return Err(Sz3CodecError::EncodedDataNotOneDimensional {
187 shape: encoded.shape().to_vec(),
188 });
189 }
190
191 decompress(&AnyCowArray::U8(encoded).as_bytes())
192 }
193
194 fn decode_into(
195 &self,
196 encoded: AnyArrayView,
197 decoded: AnyArrayViewMut,
198 ) -> Result<(), Self::Error> {
199 let AnyArrayView::U8(encoded) = encoded else {
200 return Err(Sz3CodecError::EncodedDataNotBytes {
201 dtype: encoded.dtype(),
202 });
203 };
204
205 if !matches!(encoded.shape(), [_]) {
206 return Err(Sz3CodecError::EncodedDataNotOneDimensional {
207 shape: encoded.shape().to_vec(),
208 });
209 }
210
211 decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
212 }
213}
214
215impl StaticCodec for Sz3Codec {
216 const CODEC_ID: &'static str = "sz3.rs";
217
218 type Config<'de> = Self;
219
220 fn from_config(config: Self::Config<'_>) -> Self {
221 config
222 }
223
224 fn get_config(&self) -> StaticCodecConfig<'_, Self> {
225 StaticCodecConfig::from(self)
226 }
227}
228
229#[derive(Debug, Error)]
230pub enum Sz3CodecError {
232 #[error("Sz3 does not support the dtype {0}")]
234 UnsupportedDtype(AnyArrayDType),
235 #[error("Sz3 failed to encode the header")]
237 HeaderEncodeFailed {
238 source: Sz3HeaderError,
240 },
241 #[error("Sz3 cannot encode an array of shape {shape:?}")]
243 InvalidEncodeShape {
244 source: Sz3CodingError,
246 shape: Vec<usize>,
248 },
249 #[error("Sz3 failed to encode the data")]
251 Sz3EncodeFailed {
252 source: Sz3CodingError,
254 },
255 #[error(
258 "Sz3 can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
259 )]
260 EncodedDataNotBytes {
261 dtype: AnyArrayDType,
263 },
264 #[error(
267 "Sz3 can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
268 )]
269 EncodedDataNotOneDimensional {
270 shape: Vec<usize>,
272 },
273 #[error("Sz3 failed to decode the header")]
275 HeaderDecodeFailed {
276 source: Sz3HeaderError,
278 },
279 #[error("Sz3 failed to decode the data")]
281 Sz3DecodeFailed {
282 source: Sz3CodingError,
284 },
285 #[error("Sz3 decoded an invalid array shape header which does not fit the decoded data")]
288 DecodeInvalidShapeHeader {
289 #[from]
291 source: ShapeError,
292 },
293 #[error("Sz3 cannot decode into the provided array")]
295 MismatchedDecodeIntoArray {
296 #[from]
298 source: AnyArrayAssignError,
299 },
300}
301
302#[derive(Debug, Error)]
303#[error(transparent)]
304pub struct Sz3HeaderError(postcard::Error);
306
307#[derive(Debug, Error)]
308#[error(transparent)]
309pub struct Sz3CodingError(sz3::SZ3Error);
311
312#[expect(clippy::needless_pass_by_value, clippy::too_many_lines)]
313pub fn compress<T: Sz3Element, S: Data<Elem = T>, D: Dimension>(
324 data: ArrayBase<S, D>,
325 predictor: Option<&Sz3Predictor>,
326 error_bound: &Sz3ErrorBound,
327) -> Result<Vec<u8>, Sz3CodecError> {
328 let mut encoded_bytes = postcard::to_extend(
329 &CompressionHeader {
330 dtype: <T as Sz3Element>::DTYPE,
331 shape: Cow::Borrowed(data.shape()),
332 version: StaticCodecVersion,
333 },
334 Vec::new(),
335 )
336 .map_err(|err| Sz3CodecError::HeaderEncodeFailed {
337 source: Sz3HeaderError(err),
338 })?;
339
340 if data.is_empty() {
342 return Ok(encoded_bytes);
343 }
344
345 #[expect(clippy::option_if_let_else)]
346 let data_cow = match data.as_slice() {
347 Some(data) => Cow::Borrowed(data),
348 None => Cow::Owned(data.iter().copied().collect()),
349 };
350 let mut builder = sz3::DimensionedData::build(&data_cow);
351
352 for length in data.shape() {
353 if *length > 1 {
357 builder = builder
358 .dim(*length)
359 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
360 source: Sz3CodingError(err),
361 shape: data.shape().to_vec(),
362 })?;
363 }
364 }
365
366 if data.len() == 1 {
367 builder = builder
370 .dim(1)
371 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
372 source: Sz3CodingError(err),
373 shape: data.shape().to_vec(),
374 })?;
375 }
376
377 let data = builder
378 .finish()
379 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
380 source: Sz3CodingError(err),
381 shape: data.shape().to_vec(),
382 })?;
383
384 let error_bound = match error_bound {
386 Sz3ErrorBound::AbsoluteAndRelative { abs, rel } => sz3::ErrorBound::AbsoluteAndRelative {
387 absolute_bound: *abs,
388 relative_bound: *rel,
389 },
390 Sz3ErrorBound::AbsoluteOrRelative { abs, rel } => sz3::ErrorBound::AbsoluteOrRelative {
391 absolute_bound: *abs,
392 relative_bound: *rel,
393 },
394 Sz3ErrorBound::Absolute { abs } => sz3::ErrorBound::Absolute(*abs),
395 Sz3ErrorBound::Relative { rel } => sz3::ErrorBound::Relative(*rel),
396 Sz3ErrorBound::PS2NR { psnr } => sz3::ErrorBound::PSNR(*psnr),
397 Sz3ErrorBound::L2Norm { l2 } => sz3::ErrorBound::L2Norm(*l2),
398 };
399 let mut config = sz3::Config::new(error_bound);
400
401 let predictor = match predictor {
403 Some(Sz3Predictor::Interpolation) => sz3::CompressionAlgorithm::Interpolation,
404 Some(Sz3Predictor::InterpolationLorenzo) => sz3::CompressionAlgorithm::InterpolationLorenzo,
405 Some(Sz3Predictor::Regression) => sz3::CompressionAlgorithm::LorenzoRegression {
406 lorenzo: false,
407 lorenzo_second_order: false,
408 regression: true,
409 },
410 Some(Sz3Predictor::LorenzoSecondOrder) => sz3::CompressionAlgorithm::LorenzoRegression {
411 lorenzo: false,
412 lorenzo_second_order: true,
413 regression: false,
414 },
415 Some(Sz3Predictor::LorenzoSecondOrderRegression) => {
416 sz3::CompressionAlgorithm::LorenzoRegression {
417 lorenzo: false,
418 lorenzo_second_order: true,
419 regression: true,
420 }
421 }
422 Some(Sz3Predictor::Lorenzo) => sz3::CompressionAlgorithm::LorenzoRegression {
423 lorenzo: true,
424 lorenzo_second_order: false,
425 regression: false,
426 },
427 Some(Sz3Predictor::LorenzoRegression) => sz3::CompressionAlgorithm::LorenzoRegression {
428 lorenzo: true,
429 lorenzo_second_order: false,
430 regression: true,
431 },
432 Some(Sz3Predictor::LorenzoFirstSecondOrder) => {
433 sz3::CompressionAlgorithm::LorenzoRegression {
434 lorenzo: true,
435 lorenzo_second_order: true,
436 regression: false,
437 }
438 }
439 Some(Sz3Predictor::LorenzoFirstSecondOrderRegression) => {
440 sz3::CompressionAlgorithm::LorenzoRegression {
441 lorenzo: true,
442 lorenzo_second_order: true,
443 regression: true,
444 }
445 }
446 None => sz3::CompressionAlgorithm::NoPrediction,
447 };
448 config = config.compression_algorithm(predictor);
449
450 sz3::compress_into_with_config(&data, &config, &mut encoded_bytes).map_err(|err| {
451 Sz3CodecError::Sz3EncodeFailed {
452 source: Sz3CodingError(err),
453 }
454 })?;
455
456 Ok(encoded_bytes)
457}
458
459pub fn decompress(encoded: &[u8]) -> Result<AnyArray, Sz3CodecError> {
467 fn decompress_typed<T: Sz3Element>(
468 encoded: &[u8],
469 shape: &[usize],
470 ) -> Result<Array<T, IxDyn>, Sz3CodecError> {
471 if shape.iter().copied().any(|s| s == 0) {
472 return Ok(Array::from_shape_vec(shape, Vec::new())?);
473 }
474
475 let (_config, decompressed) =
476 sz3::decompress(encoded).map_err(|err| Sz3CodecError::Sz3DecodeFailed {
477 source: Sz3CodingError(err),
478 })?;
479
480 Ok(Array::from_shape_vec(shape, decompressed.into_data())?)
481 }
482
483 let (header, data) =
484 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
485 Sz3CodecError::HeaderDecodeFailed {
486 source: Sz3HeaderError(err),
487 }
488 })?;
489
490 let decoded = match header.dtype {
491 Sz3DType::U8 => AnyArray::U8(decompress_typed(data, &header.shape)?),
492 Sz3DType::I8 => AnyArray::I8(decompress_typed(data, &header.shape)?),
493 Sz3DType::U16 => AnyArray::U16(decompress_typed(data, &header.shape)?),
494 Sz3DType::I16 => AnyArray::I16(decompress_typed(data, &header.shape)?),
495 Sz3DType::U32 => AnyArray::U32(decompress_typed(data, &header.shape)?),
496 Sz3DType::I32 => AnyArray::I32(decompress_typed(data, &header.shape)?),
497 Sz3DType::U64 => AnyArray::U64(decompress_typed(data, &header.shape)?),
498 Sz3DType::I64 => AnyArray::I64(decompress_typed(data, &header.shape)?),
499 Sz3DType::F32 => AnyArray::F32(decompress_typed(data, &header.shape)?),
500 Sz3DType::F64 => AnyArray::F64(decompress_typed(data, &header.shape)?),
501 };
502
503 Ok(decoded)
504}
505
506pub fn decompress_into(encoded: &[u8], decoded: AnyArrayViewMut) -> Result<(), Sz3CodecError> {
516 fn decompress_into_typed<T: Sz3Element>(
517 encoded: &[u8],
518 mut decoded: ArrayViewMut<T, IxDyn>,
519 ) -> Result<(), Sz3CodecError> {
520 if decoded.is_empty() {
521 return Ok(());
522 }
523
524 let decoded_shape = decoded.shape().to_vec();
525
526 decoded.with_slice_mut(|mut decoded| {
527 let decoded_len = decoded.len();
528
529 let mut builder = sz3::DimensionedData::build_mut(&mut decoded);
530
531 for length in &decoded_shape {
532 if *length > 1 {
536 builder = builder
537 .dim(*length)
538 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
540 source: Sz3CodingError(err),
541 shape: decoded_shape.clone(),
542 })?;
543 }
544 }
545
546 if decoded_len == 1 {
547 builder = builder
550 .dim(1)
551 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
553 source: Sz3CodingError(err),
554 shape: decoded_shape.clone(),
555 })?;
556 }
557
558 let mut decoded = builder
559 .finish()
560 .map_err(|err| Sz3CodecError::InvalidEncodeShape {
562 source: Sz3CodingError(err),
563 shape: decoded_shape,
564 })?;
565
566 sz3::decompress_into_dimensioned(encoded, &mut decoded).map_err(|err| {
567 Sz3CodecError::Sz3DecodeFailed {
568 source: Sz3CodingError(err),
569 }
570 })
571 })?;
572
573 Ok(())
574 }
575
576 let (header, data) =
577 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
578 Sz3CodecError::HeaderDecodeFailed {
579 source: Sz3HeaderError(err),
580 }
581 })?;
582
583 if decoded.shape() != &*header.shape {
584 return Err(Sz3CodecError::MismatchedDecodeIntoArray {
585 source: AnyArrayAssignError::ShapeMismatch {
586 src: header.shape.into_owned(),
587 dst: decoded.shape().to_vec(),
588 },
589 });
590 }
591
592 match (decoded, header.dtype) {
593 (AnyArrayViewMut::U8(decoded), Sz3DType::U8) => decompress_into_typed(data, decoded),
594 (AnyArrayViewMut::I8(decoded), Sz3DType::I8) => decompress_into_typed(data, decoded),
595 (AnyArrayViewMut::U16(decoded), Sz3DType::U16) => decompress_into_typed(data, decoded),
596 (AnyArrayViewMut::I16(decoded), Sz3DType::I16) => decompress_into_typed(data, decoded),
597 (AnyArrayViewMut::U32(decoded), Sz3DType::U32) => decompress_into_typed(data, decoded),
598 (AnyArrayViewMut::I32(decoded), Sz3DType::I32) => decompress_into_typed(data, decoded),
599 (AnyArrayViewMut::U64(decoded), Sz3DType::U64) => decompress_into_typed(data, decoded),
600 (AnyArrayViewMut::I64(decoded), Sz3DType::I64) => decompress_into_typed(data, decoded),
601 (AnyArrayViewMut::F32(decoded), Sz3DType::F32) => decompress_into_typed(data, decoded),
602 (AnyArrayViewMut::F64(decoded), Sz3DType::F64) => decompress_into_typed(data, decoded),
603 (decoded, dtype) => Err(Sz3CodecError::MismatchedDecodeIntoArray {
604 source: AnyArrayAssignError::DTypeMismatch {
605 src: dtype.into_dtype(),
606 dst: decoded.dtype(),
607 },
608 }),
609 }
610}
611
612pub trait Sz3Element: Copy + sz3::SZ3Compressible + ArrayDType {
614 const DTYPE: Sz3DType;
616}
617
618impl Sz3Element for u8 {
619 const DTYPE: Sz3DType = Sz3DType::U8;
620}
621
622impl Sz3Element for i8 {
623 const DTYPE: Sz3DType = Sz3DType::I8;
624}
625
626impl Sz3Element for u16 {
627 const DTYPE: Sz3DType = Sz3DType::U16;
628}
629
630impl Sz3Element for i16 {
631 const DTYPE: Sz3DType = Sz3DType::I16;
632}
633
634impl Sz3Element for u32 {
635 const DTYPE: Sz3DType = Sz3DType::U32;
636}
637
638impl Sz3Element for i32 {
639 const DTYPE: Sz3DType = Sz3DType::I32;
640}
641
642impl Sz3Element for u64 {
643 const DTYPE: Sz3DType = Sz3DType::U64;
644}
645
646impl Sz3Element for i64 {
647 const DTYPE: Sz3DType = Sz3DType::I64;
648}
649
650impl Sz3Element for f32 {
651 const DTYPE: Sz3DType = Sz3DType::F32;
652}
653
654impl Sz3Element for f64 {
655 const DTYPE: Sz3DType = Sz3DType::F64;
656}
657
658#[derive(Serialize, Deserialize)]
659struct CompressionHeader<'a> {
660 dtype: Sz3DType,
661 #[serde(borrow)]
662 shape: Cow<'a, [usize]>,
663 version: Sz3CodecVersion,
664}
665
666#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
668#[expect(missing_docs)]
669pub enum Sz3DType {
670 #[serde(rename = "u8", alias = "uint8")]
671 U8,
672 #[serde(rename = "u16", alias = "uint16")]
673 U16,
674 #[serde(rename = "u32", alias = "uint32")]
675 U32,
676 #[serde(rename = "u64", alias = "uint64")]
677 U64,
678 #[serde(rename = "i8", alias = "int8")]
679 I8,
680 #[serde(rename = "i16", alias = "int16")]
681 I16,
682 #[serde(rename = "i32", alias = "int32")]
683 I32,
684 #[serde(rename = "i64", alias = "int64")]
685 I64,
686 #[serde(rename = "f32", alias = "float32")]
687 F32,
688 #[serde(rename = "f64", alias = "float64")]
689 F64,
690}
691
692impl Sz3DType {
693 #[must_use]
695 pub const fn into_dtype(self) -> AnyArrayDType {
696 match self {
697 Self::U8 => AnyArrayDType::U8,
698 Self::U16 => AnyArrayDType::U16,
699 Self::U32 => AnyArrayDType::U32,
700 Self::U64 => AnyArrayDType::U64,
701 Self::I8 => AnyArrayDType::I8,
702 Self::I16 => AnyArrayDType::I16,
703 Self::I32 => AnyArrayDType::I32,
704 Self::I64 => AnyArrayDType::I64,
705 Self::F32 => AnyArrayDType::F32,
706 Self::F64 => AnyArrayDType::F64,
707 }
708 }
709}
710
711impl fmt::Display for Sz3DType {
712 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
713 fmt.write_str(match self {
714 Self::U8 => "u8",
715 Self::U16 => "u16",
716 Self::U32 => "u32",
717 Self::U64 => "u64",
718 Self::I8 => "i8",
719 Self::I16 => "i16",
720 Self::I32 => "i32",
721 Self::I64 => "i64",
722 Self::F32 => "f32",
723 Self::F64 => "f64",
724 })
725 }
726}
727
#[cfg(test)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    /// An array with a zero-length axis is encoded as the header alone and
    /// must decode back to an empty array with the same dtype and shape.
    #[test]
    fn zero_length() -> Result<(), Sz3CodecError> {
        let encoded = compress(
            Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])?,
            default_predictor().as_ref(),
            &Sz3ErrorBound::L2Norm { l2: 27.0 },
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);

        Ok(())
    }

    /// Length-1 axes are not passed on to SZ3; checks that the full shape is
    /// still restored by both `decompress` and `decompress_into`.
    #[test]
    fn one_dimension() -> Result<(), Sz3CodecError> {
        let data = Array::from_shape_vec([2_usize, 1, 2, 1].as_slice(), vec![1, 2, 3, 4])?;

        let encoded = compress(
            data.view(),
            default_predictor().as_ref(),
            &Sz3ErrorBound::Absolute { abs: 0.1 },
        )?;
        let decoded = decompress(&encoded)?;

        assert_eq!(decoded, AnyArray::I32(data.clone()));

        let mut decoded = Array::zeros(data.dim());
        decompress_into(&encoded, AnyArrayViewMut::I32(decoded.view_mut()))?;

        assert_eq!(decoded, data);

        Ok(())
    }

    /// Round-trips tiny 1D inputs (empty, single element, and a few
    /// elements) through both decoding paths.
    #[test]
    fn small_state() -> Result<(), Sz3CodecError> {
        for data in [
            &[][..],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ] {
            let encoded = compress(
                ArrayView1::from(data),
                default_predictor().as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let decoded = decompress(&encoded)?;

            assert_eq!(
                decoded,
                AnyArray::F64(Array1::from_vec(data.to_vec()).into_dyn())
            );

            let mut decoded = Array::zeros([data.len()]);
            decompress_into(
                &encoded,
                AnyArrayViewMut::F64(decoded.view_mut().into_dyn()),
            )?;

            assert_eq!(decoded, Array1::from_vec(data.to_vec()));
        }

        Ok(())
    }

    /// Every predictor (and no predictor) must compress and decompress
    /// without error.
    // NOTE(review): the decoded values are not compared against `data` here
    #[test]
    fn all_predictors() -> Result<(), Sz3CodecError> {
        let data = Array::linspace(-42.0, 42.0, 85);

        for predictor in [
            None,
            Some(Sz3Predictor::Interpolation),
            Some(Sz3Predictor::InterpolationLorenzo),
            Some(Sz3Predictor::Regression),
            Some(Sz3Predictor::LorenzoSecondOrder),
            Some(Sz3Predictor::LorenzoSecondOrderRegression),
            Some(Sz3Predictor::Lorenzo),
            Some(Sz3Predictor::LorenzoRegression),
            Some(Sz3Predictor::LorenzoFirstSecondOrder),
            Some(Sz3Predictor::LorenzoFirstSecondOrderRegression),
        ] {
            let encoded = compress(
                data.view(),
                predictor.as_ref(),
                &Sz3ErrorBound::Absolute { abs: 0.1 },
            )?;
            let _decoded = decompress(&encoded)?;

            let mut decoded = Array::zeros(data.dim());
            decompress_into(
                &encoded,
                AnyArrayViewMut::F64(decoded.view_mut().into_dyn()),
            )?;
        }

        Ok(())
    }

    /// Every `Sz3Element` dtype must compress and decompress without error
    /// through both decoding paths.
    #[test]
    fn all_dtypes() -> Result<(), Sz3CodecError> {
        // `view_mut` wraps the typed output view in the matching
        //  `AnyArrayViewMut` variant for `decompress_into`
        fn compress_decompress<T: Sz3Element + num_traits::identities::Zero>(
            iter: impl Clone + IntoIterator<Item = T>,
            view_mut: impl for<'a> Fn(ArrayViewMut<'a, T, IxDyn>) -> AnyArrayViewMut<'a>,
        ) -> Result<(), Sz3CodecError> {
            let encoded = compress(
                Array::from_iter(iter.clone()).view(),
                None,
                &Sz3ErrorBound::Absolute { abs: 2.0 },
            )?;
            let _decoded = decompress(&encoded)?;

            let mut decoded = Array::<T, _>::zeros([iter.into_iter().count()]).into_dyn();
            decompress_into(&encoded, view_mut(decoded.view_mut().into_dyn()))?;

            Ok(())
        }

        compress_decompress(0_u8..42, |x| AnyArrayViewMut::U8(x))?;
        compress_decompress(-42_i8..42, |x| AnyArrayViewMut::I8(x))?;
        compress_decompress(0_u16..42, |x| AnyArrayViewMut::U16(x))?;
        compress_decompress(-42_i16..42, |x| AnyArrayViewMut::I16(x))?;
        compress_decompress(0_u32..42, |x| AnyArrayViewMut::U32(x))?;
        compress_decompress(-42_i32..42, |x| AnyArrayViewMut::I32(x))?;
        compress_decompress(0_u64..42, |x| AnyArrayViewMut::U64(x))?;
        compress_decompress(-42_i64..42, |x| AnyArrayViewMut::I64(x))?;
        compress_decompress((-42_i16..42).map(f32::from), |x| AnyArrayViewMut::F32(x))?;
        compress_decompress((-42_i16..42).map(f64::from), |x| AnyArrayViewMut::F64(x))?;

        Ok(())
    }
}