1#![allow(clippy::multiple_crate_versions)] use std::{borrow::Cow, fmt};
37
38use ndarray::{Array, Array1, ArrayView, Dimension, Zip};
39use numcodecs::{
40 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
41 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
42};
43use schemars::JsonSchema;
44use serde::{Deserialize, Serialize};
45use thiserror::Error;
46
47#[cfg(test)]
48use ::serde_json as _;
49
50mod ffi;
51
/// Encoding-format version of the [`ZfpCodec`].
///
/// It appears in the codec configuration (as the `_version` field) and in the
/// metadata header at the start of every compressed byte array.
type ZfpCodecVersion = StaticCodecVersion<0, 2, 0>;
53
/// Codec providing compression using ZFP.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
// `mode` is flattened into this struct, and serde's `deny_unknown_fields` is
// not supported on a struct with a flattened field. Unknown fields are
// therefore denied in the JSON schema here.
#[schemars(deny_unknown_fields)]
pub struct ZfpCodec {
    /// ZFP compression mode.
    ///
    /// It is flattened, so its `mode` tag and parameters appear at the top
    /// level of the configuration.
    #[serde(flatten)]
    pub mode: ZfpCompressionMode,
    /// How non-finite (infinite or NaN) values are handled. Defaults to
    /// [`ZfpNonFiniteValuesMode::Deny`].
    #[serde(default)]
    pub non_finite: ZfpNonFiniteValuesMode,
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    #[serde(default, rename = "_version")]
    pub version: ZfpCodecVersion,
}
69
/// ZFP compression mode.
///
/// This is serialized as an internally tagged enum. The variant name is stored
/// in a `mode` field next to the variant's parameters.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "mode")]
#[serde(deny_unknown_fields)]
pub enum ZfpCompressionMode {
    /// The most general mode, which can describe all four other modes.
    #[serde(rename = "expert")]
    Expert {
        /// Minimum number of compressed bits used to represent a block.
        min_bits: u32,
        /// Maximum number of bits used to represent a block.
        max_bits: u32,
        /// Maximum number of bit planes encoded.
        max_prec: u32,
        /// Smallest absolute bit plane number encoded.
        ///
        /// This parameter applies to floating-point data only and is ignored
        /// for integer data.
        min_exp: i32,
    },
    /// Each compressed block is stored using a fixed number of bits.
    ///
    /// That per-block bit budget is amortized over the block's values to give
    /// `rate` in bits per value.
    #[serde(rename = "fixed-rate")]
    FixedRate {
        /// Rate in bits per value.
        rate: f64,
    },
    /// The number of bits per block may vary, but the number of bit planes
    /// (the precision) encoded for the transform coefficients is fixed.
    #[serde(rename = "fixed-precision")]
    FixedPrecision {
        /// Number of bit planes encoded.
        precision: u32,
    },
    /// All transform coefficient bit planes are encoded down to a minimum bit
    /// plane derived from the absolute error `tolerance`.
    ///
    /// This mode is not supported for integer data.
    #[serde(rename = "fixed-accuracy")]
    FixedAccuracy {
        /// Absolute error tolerance.
        tolerance: f64,
    },
    /// Lossless per-block compression that preserves integer and floating
    /// point bit patterns.
    ///
    /// Non-finite values are always accepted in this mode.
    #[serde(rename = "reversible")]
    Reversible,
}
121
/// How non-finite (infinite or NaN) floating-point values are handled during
/// compression.
#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, JsonSchema)]
pub enum ZfpNonFiniteValuesMode {
    /// Reject non-finite data with [`ZfpCodecError::NonFiniteData`] unless the
    /// reversible mode is used. This is the default.
    #[default]
    #[serde(rename = "deny")]
    Deny,
    /// Compress non-finite data even in lossy (non-reversible) modes.
    ///
    /// ZFP does not support non-finite values in those modes, so the
    /// decompressed output is not guaranteed to be meaningful.
    #[serde(rename = "allow-unsafe")]
    AllowUnsafe,
}
135
136impl Codec for ZfpCodec {
137 type Error = ZfpCodecError;
138
139 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
140 if matches!(data.dtype(), AnyArrayDType::I32 | AnyArrayDType::I64)
141 && matches!(
142 self.mode,
143 ZfpCompressionMode::FixedAccuracy { tolerance: _ }
144 )
145 {
146 return Err(ZfpCodecError::FixedAccuracyModeIntegerData);
147 }
148
149 match data {
150 AnyCowArray::I32(data) => Ok(AnyArray::U8(
151 Array1::from(compress(data.view(), &self.mode, self.non_finite)?).into_dyn(),
152 )),
153 AnyCowArray::I64(data) => Ok(AnyArray::U8(
154 Array1::from(compress(data.view(), &self.mode, self.non_finite)?).into_dyn(),
155 )),
156 AnyCowArray::F32(data) => Ok(AnyArray::U8(
157 Array1::from(compress(data.view(), &self.mode, self.non_finite)?).into_dyn(),
158 )),
159 AnyCowArray::F64(data) => Ok(AnyArray::U8(
160 Array1::from(compress(data.view(), &self.mode, self.non_finite)?).into_dyn(),
161 )),
162 encoded => Err(ZfpCodecError::UnsupportedDtype(encoded.dtype())),
163 }
164 }
165
166 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
167 let AnyCowArray::U8(encoded) = encoded else {
168 return Err(ZfpCodecError::EncodedDataNotBytes {
169 dtype: encoded.dtype(),
170 });
171 };
172
173 if !matches!(encoded.shape(), [_]) {
174 return Err(ZfpCodecError::EncodedDataNotOneDimensional {
175 shape: encoded.shape().to_vec(),
176 });
177 }
178
179 decompress(&AnyCowArray::U8(encoded).as_bytes())
180 }
181
182 fn decode_into(
183 &self,
184 encoded: AnyArrayView,
185 decoded: AnyArrayViewMut,
186 ) -> Result<(), Self::Error> {
187 let AnyArrayView::U8(encoded) = encoded else {
188 return Err(ZfpCodecError::EncodedDataNotBytes {
189 dtype: encoded.dtype(),
190 });
191 };
192
193 if !matches!(encoded.shape(), [_]) {
194 return Err(ZfpCodecError::EncodedDataNotOneDimensional {
195 shape: encoded.shape().to_vec(),
196 });
197 }
198
199 decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
200 }
201}
202
203impl StaticCodec for ZfpCodec {
204 const CODEC_ID: &'static str = "zfp.rs";
205
206 type Config<'de> = Self;
207
208 fn from_config(config: Self::Config<'_>) -> Self {
209 config
210 }
211
212 fn get_config(&self) -> StaticCodecConfig<'_, Self> {
213 StaticCodecConfig::from(self)
214 }
215}
216
/// Errors that may occur when applying the [`ZfpCodec`].
#[derive(Debug, Error)]
pub enum ZfpCodecError {
    /// [`ZfpCodec`] does not support the dtype of the array to encode.
    #[error("Zfp does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`ZfpCodec`] does not support the fixed accuracy mode for integer data.
    #[error("Zfp does not support the fixed accuracy mode for integer data")]
    FixedAccuracyModeIntegerData,
    /// [`ZfpCodec`] only supports 1-4 dimensional data.
    #[error("Zfp only supports 1-4 dimensional data but found shape {shape:?}")]
    ExcessiveDimensionality {
        /// Shape of the rejected array.
        shape: Vec<usize>,
    },
    /// [`ZfpCodec`] was configured with invalid expert mode parameters.
    #[error("Zfp was configured with an invalid expert mode {mode:?}")]
    InvalidExpertMode {
        /// The rejected compression mode.
        mode: ZfpCompressionMode,
    },
    /// The data contains non-finite values, but neither the reversible mode
    /// nor [`ZfpNonFiniteValuesMode::AllowUnsafe`] was selected.
    #[error(
        "Zfp does not support non-finite (infinite or NaN) floating point data in non-reversible lossy compression"
    )]
    NonFiniteData,
    /// Encoding the ZFP header failed.
    #[error("Zfp failed to encode the header")]
    HeaderEncodeFailed,
    /// Encoding the array metadata header (dtype, shape, version) failed.
    #[error("Zfp failed to encode the array metadata header")]
    MetaHeaderEncodeFailed {
        /// Opaque source error.
        source: ZfpHeaderError,
    },
    /// An opaque error occurred while ZFP encoded the data.
    #[error("Zfp failed to encode the data")]
    ZfpEncodeFailed,
    /// The encoded data passed to decode is not a byte (`u8`) array.
    #[error(
        "Zfp can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array.
        dtype: AnyArrayDType,
    },
    /// The encoded byte array passed to decode is not one-dimensional.
    #[error(
        "Zfp can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array.
        shape: Vec<usize>,
    },
    /// Decoding the ZFP header failed.
    #[error("Zfp failed to decode the header")]
    HeaderDecodeFailed,
    /// Decoding the array metadata header (dtype, shape, version) failed.
    #[error("Zfp failed to decode the array metadata header")]
    MetaHeaderDecodeFailed {
        /// Opaque source error.
        source: ZfpHeaderError,
    },
    /// The array to decode into does not match the shape or dtype recorded in
    /// the encoded data.
    #[error("ZfpCodec cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error.
        #[from]
        source: AnyArrayAssignError,
    },
    /// An opaque error occurred while ZFP decoded the data.
    #[error("Zfp failed to decode the data")]
    ZfpDecodeFailed,
}
294
/// Opaque error raised when the array metadata header (dtype, shape and codec
/// version) cannot be encoded or decoded with `postcard`.
///
/// The wrapped error is private. `#[error(transparent)]` forwards its
/// `Display` and `source` to the inner `postcard::Error`.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct ZfpHeaderError(postcard::Error);
299
300pub fn compress<T: ffi::ZfpCompressible, D: Dimension>(
317 data: ArrayView<T, D>,
318 mode: &ZfpCompressionMode,
319 non_finite: ZfpNonFiniteValuesMode,
320) -> Result<Vec<u8>, ZfpCodecError> {
321 if !matches!(mode, ZfpCompressionMode::Reversible)
322 && !matches!(non_finite, ZfpNonFiniteValuesMode::AllowUnsafe)
323 && !Zip::from(&data).all(|x| x.is_finite())
324 {
325 return Err(ZfpCodecError::NonFiniteData);
326 }
327
328 let mut encoded = postcard::to_extend(
329 &CompressionHeader {
330 dtype: <T as ffi::ZfpCompressible>::D_TYPE,
331 shape: Cow::Borrowed(data.shape()),
332 version: StaticCodecVersion,
333 },
334 Vec::new(),
335 )
336 .map_err(|err| ZfpCodecError::MetaHeaderEncodeFailed {
337 source: ZfpHeaderError(err),
338 })?;
339
340 if data.is_empty() {
342 return Ok(encoded);
343 }
344
345 let field = ffi::ZfpField::new(data.into_dyn().squeeze())?;
348 let stream = ffi::ZfpCompressionStream::new(&field, mode)?;
349
350 let stream = stream.with_bitstream(field, &mut encoded);
353
354 let stream = stream.write_header()?;
356
357 stream.compress()?;
359
360 Ok(encoded)
361}
362
363pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZfpCodecError> {
373 let (header, encoded) =
374 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
375 ZfpCodecError::MetaHeaderDecodeFailed {
376 source: ZfpHeaderError(err),
377 }
378 })?;
379
380 if header.shape.iter().copied().product::<usize>() == 0 {
382 let decoded = match header.dtype {
383 ZfpDType::I32 => AnyArray::I32(Array::zeros(&*header.shape)),
384 ZfpDType::I64 => AnyArray::I64(Array::zeros(&*header.shape)),
385 ZfpDType::F32 => AnyArray::F32(Array::zeros(&*header.shape)),
386 ZfpDType::F64 => AnyArray::F64(Array::zeros(&*header.shape)),
387 };
388 return Ok(decoded);
389 }
390
391 let stream = ffi::ZfpDecompressionStream::new(encoded);
393
394 let stream = stream.read_header()?;
396
397 match header.dtype {
399 ZfpDType::I32 => {
400 let mut decompressed = Array::zeros(&*header.shape);
401 stream.decompress_into(decompressed.view_mut().squeeze())?;
402 Ok(AnyArray::I32(decompressed))
403 }
404 ZfpDType::I64 => {
405 let mut decompressed = Array::zeros(&*header.shape);
406 stream.decompress_into(decompressed.view_mut().squeeze())?;
407 Ok(AnyArray::I64(decompressed))
408 }
409 ZfpDType::F32 => {
410 let mut decompressed = Array::zeros(&*header.shape);
411 stream.decompress_into(decompressed.view_mut().squeeze())?;
412 Ok(AnyArray::F32(decompressed))
413 }
414 ZfpDType::F64 => {
415 let mut decompressed = Array::zeros(&*header.shape);
416 stream.decompress_into(decompressed.view_mut().squeeze())?;
417 Ok(AnyArray::F64(decompressed))
418 }
419 }
420}
421
422pub fn decompress_into(encoded: &[u8], decoded: AnyArrayViewMut) -> Result<(), ZfpCodecError> {
434 let (header, encoded) =
435 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
436 ZfpCodecError::MetaHeaderDecodeFailed {
437 source: ZfpHeaderError(err),
438 }
439 })?;
440
441 if decoded.shape() != &*header.shape {
442 return Err(ZfpCodecError::MismatchedDecodeIntoArray {
443 source: AnyArrayAssignError::ShapeMismatch {
444 src: header.shape.into_owned(),
445 dst: decoded.shape().to_vec(),
446 },
447 });
448 }
449
450 if decoded.is_empty() {
452 return Ok(());
453 }
454
455 let stream = ffi::ZfpDecompressionStream::new(encoded);
457
458 let stream = stream.read_header()?;
460
461 match (decoded, header.dtype) {
463 (AnyArrayViewMut::I32(decoded), ZfpDType::I32) => stream.decompress_into(decoded.squeeze()),
464 (AnyArrayViewMut::I64(decoded), ZfpDType::I64) => stream.decompress_into(decoded.squeeze()),
465 (AnyArrayViewMut::F32(decoded), ZfpDType::F32) => stream.decompress_into(decoded.squeeze()),
466 (AnyArrayViewMut::F64(decoded), ZfpDType::F64) => stream.decompress_into(decoded.squeeze()),
467 (decoded, dtype) => Err(ZfpCodecError::MismatchedDecodeIntoArray {
468 source: AnyArrayAssignError::DTypeMismatch {
469 src: dtype.into_dtype(),
470 dst: decoded.dtype(),
471 },
472 }),
473 }
474}
475
/// Metadata header that is postcard-encoded at the start of every compressed
/// byte array, ahead of any ZFP payload.
///
/// postcard encodes struct fields in declaration order, so the field order
/// here is part of the encoded format and must not change.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a> {
    // element type of the original array
    dtype: ZfpDType,
    // full shape of the original array, recorded before any axes are squeezed
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    // codec version that produced the encoded data
    version: ZfpCodecVersion,
}
483
/// Dtypes that Zfp can compress and decompress.
///
/// Each variant is serialized under its short lowercase name (e.g. `"i32"`).
/// The longer alias (e.g. `"int32"`) is also accepted when deserializing.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
// The variants intentionally have no doc comments; `expect` records that.
#[expect(missing_docs)]
pub enum ZfpDType {
    // 32-bit signed integer
    #[serde(rename = "i32", alias = "int32")]
    I32,
    // 64-bit signed integer
    #[serde(rename = "i64", alias = "int64")]
    I64,
    // 32-bit floating point
    #[serde(rename = "f32", alias = "float32")]
    F32,
    // 64-bit floating point
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
497
498impl ZfpDType {
499 #[must_use]
501 pub const fn into_dtype(self) -> AnyArrayDType {
502 match self {
503 Self::I32 => AnyArrayDType::I32,
504 Self::I64 => AnyArrayDType::I64,
505 Self::F32 => AnyArrayDType::F32,
506 Self::F64 => AnyArrayDType::F64,
507 }
508 }
509}
510
511impl fmt::Display for ZfpDType {
512 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
513 fmt.write_str(match self {
514 Self::I32 => "i32",
515 Self::I64 => "i64",
516 Self::F32 => "f32",
517 Self::F64 => "f64",
518 })
519 }
520}
521
#[cfg(test)]
// `unwrap` is acceptable here: any failure should simply fail the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    /// An array with a zero-length axis round-trips to an empty array that
    /// keeps its dtype and its full shape.
    #[test]
    fn zero_length() {
        let encoded = compress(
            Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![])
                .unwrap()
                .view(),
            &ZfpCompressionMode::FixedPrecision { precision: 7 },
            ZfpNonFiniteValuesMode::Deny,
        )
        .unwrap();
        let decoded = decompress(&encoded).unwrap();

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);
    }

    /// A six-dimensional array with only two non-unit axes can be compressed
    /// and round-trips exactly, including its original six-dimensional shape.
    #[test]
    fn one_dimension() {
        // The element type (f32) is inferred from the final assertion.
        let data = Array::from_shape_vec(
            [2_usize, 1, 2, 1, 1, 1].as_slice(),
            vec![1.0, 2.0, 3.0, 4.0],
        )
        .unwrap();

        let encoded = compress(
            data.view(),
            &ZfpCompressionMode::FixedAccuracy { tolerance: 0.1 },
            ZfpNonFiniteValuesMode::Deny,
        )
        .unwrap();
        let decoded = decompress(&encoded).unwrap();

        assert_eq!(decoded, AnyArray::F32(data));
    }

    /// Very small one-dimensional f64 inputs (0 to 4 elements) round-trip
    /// exactly.
    #[test]
    fn small_state() {
        for data in [
            &[][..],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ] {
            let encoded = compress(
                ArrayView1::from(data),
                &ZfpCompressionMode::FixedAccuracy { tolerance: 0.1 },
                ZfpNonFiniteValuesMode::Deny,
            )
            .unwrap();
            let decoded = decompress(&encoded).unwrap();

            assert_eq!(
                decoded,
                AnyArray::F64(Array1::from_vec(data.to_vec()).into_dyn())
            );
        }
    }
}