1#![allow(clippy::multiple_crate_versions)] use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayView, Dimension, Zip};
25use numcodecs::{
26 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
27 Codec, StaticCodec, StaticCodecConfig,
28};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use thiserror::Error;
32
33mod ffi;
34
35#[derive(Clone, Serialize, Deserialize, JsonSchema)]
36#[serde(transparent)]
37pub struct ZfpCodec {
39 pub mode: ZfpCompressionMode,
41}
42
43#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
44#[serde(tag = "mode")]
45#[serde(deny_unknown_fields)]
46pub enum ZfpCompressionMode {
48 #[serde(rename = "expert")]
49 Expert {
51 min_bits: u32,
53 max_bits: u32,
55 max_prec: u32,
57 min_exp: i32,
62 },
63 #[serde(rename = "fixed-rate")]
68 FixedRate {
69 rate: f64,
71 },
72 #[serde(rename = "fixed-precision")]
76 FixedPrecision {
77 precision: u32,
79 },
80 #[serde(rename = "fixed-accuracy")]
85 FixedAccuracy {
86 tolerance: f64,
88 },
89 #[serde(rename = "reversible")]
92 Reversible,
93}
94
95impl Codec for ZfpCodec {
96 type Error = ZfpCodecError;
97
98 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
99 if matches!(data.dtype(), AnyArrayDType::I32 | AnyArrayDType::I64)
100 && matches!(
101 self.mode,
102 ZfpCompressionMode::FixedAccuracy { tolerance: _ }
103 )
104 {
105 return Err(ZfpCodecError::FixedAccuracyModeIntegerData);
106 }
107
108 match data {
109 AnyCowArray::I32(data) => Ok(AnyArray::U8(
110 Array1::from(compress(data.view(), &self.mode)?).into_dyn(),
111 )),
112 AnyCowArray::I64(data) => Ok(AnyArray::U8(
113 Array1::from(compress(data.view(), &self.mode)?).into_dyn(),
114 )),
115 AnyCowArray::F32(data) => Ok(AnyArray::U8(
116 Array1::from(compress(data.view(), &self.mode)?).into_dyn(),
117 )),
118 AnyCowArray::F64(data) => Ok(AnyArray::U8(
119 Array1::from(compress(data.view(), &self.mode)?).into_dyn(),
120 )),
121 encoded => Err(ZfpCodecError::UnsupportedDtype(encoded.dtype())),
122 }
123 }
124
125 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
126 let AnyCowArray::U8(encoded) = encoded else {
127 return Err(ZfpCodecError::EncodedDataNotBytes {
128 dtype: encoded.dtype(),
129 });
130 };
131
132 if !matches!(encoded.shape(), [_]) {
133 return Err(ZfpCodecError::EncodedDataNotOneDimensional {
134 shape: encoded.shape().to_vec(),
135 });
136 }
137
138 decompress(&AnyCowArray::U8(encoded).as_bytes())
139 }
140
141 fn decode_into(
142 &self,
143 encoded: AnyArrayView,
144 decoded: AnyArrayViewMut,
145 ) -> Result<(), Self::Error> {
146 let AnyArrayView::U8(encoded) = encoded else {
147 return Err(ZfpCodecError::EncodedDataNotBytes {
148 dtype: encoded.dtype(),
149 });
150 };
151
152 if !matches!(encoded.shape(), [_]) {
153 return Err(ZfpCodecError::EncodedDataNotOneDimensional {
154 shape: encoded.shape().to_vec(),
155 });
156 }
157
158 decompress_into(&AnyArrayView::U8(encoded).as_bytes(), decoded)
159 }
160}
161
162impl StaticCodec for ZfpCodec {
163 const CODEC_ID: &'static str = "zfp";
164
165 type Config<'de> = Self;
166
167 fn from_config(config: Self::Config<'_>) -> Self {
168 config
169 }
170
171 fn get_config(&self) -> StaticCodecConfig<Self> {
172 StaticCodecConfig::from(self)
173 }
174}
175
176#[derive(Debug, Error)]
177pub enum ZfpCodecError {
179 #[error("Zfp does not support the dtype {0}")]
181 UnsupportedDtype(AnyArrayDType),
182 #[error("Zfp does not support the fixed accuracy mode for integer data")]
184 FixedAccuracyModeIntegerData,
185 #[error("Zfp only supports 1-4 dimensional data but found shape {shape:?}")]
187 ExcessiveDimensionality {
188 shape: Vec<usize>,
190 },
191 #[error("Zfp was configured with an invalid expert mode {mode:?}")]
193 InvalidExpertMode {
194 mode: ZfpCompressionMode,
196 },
197 #[error("Zfp does not support non-finite (infinite or NaN) floating point data in non-reversible lossy compression")]
200 NonFiniteData,
201 #[error("Zfp failed to encode the header")]
203 HeaderEncodeFailed,
204 #[error("Zfp failed to encode the array metadata header")]
206 MetaHeaderEncodeFailed {
207 source: ZfpHeaderError,
209 },
210 #[error("Zfp failed to encode the data")]
212 ZfpEncodeFailed,
213 #[error(
216 "Zfp can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
217 )]
218 EncodedDataNotBytes {
219 dtype: AnyArrayDType,
221 },
222 #[error("Zfp can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
225 EncodedDataNotOneDimensional {
226 shape: Vec<usize>,
228 },
229 #[error("Zfp failed to decode the header")]
231 HeaderDecodeFailed,
232 #[error("Zfp failed to decode the array metadata header")]
234 MetaHeaderDecodeFailed {
235 source: ZfpHeaderError,
237 },
238 #[error("ZfpCodec cannot decode into the provided array")]
240 MismatchedDecodeIntoArray {
241 #[from]
243 source: AnyArrayAssignError,
244 },
245 #[error("Zfp failed to decode the data")]
247 ZfpDecodeFailed,
248}
249
250#[derive(Debug, Error)]
251#[error(transparent)]
252pub struct ZfpHeaderError(postcard::Error);
254
255pub fn compress<T: ffi::ZfpCompressible, D: Dimension>(
271 data: ArrayView<T, D>,
272 mode: &ZfpCompressionMode,
273) -> Result<Vec<u8>, ZfpCodecError> {
274 if !matches!(mode, ZfpCompressionMode::Reversible) && !Zip::from(&data).all(|x| x.is_finite()) {
275 return Err(ZfpCodecError::NonFiniteData);
276 }
277
278 let mut encoded = postcard::to_extend(
279 &CompressionHeader {
280 dtype: <T as ffi::ZfpCompressible>::D_TYPE,
281 shape: Cow::Borrowed(data.shape()),
282 },
283 Vec::new(),
284 )
285 .map_err(|err| ZfpCodecError::MetaHeaderEncodeFailed {
286 source: ZfpHeaderError(err),
287 })?;
288
289 if data.is_empty() {
291 return Ok(encoded);
292 }
293
294 let field = ffi::ZfpField::new(data.into_dyn().squeeze())?;
297 let stream = ffi::ZfpCompressionStream::new(&field, mode)?;
298
299 let stream = stream.with_bitstream(field, &mut encoded);
302
303 let stream = stream.write_header()?;
305
306 stream.compress()?;
308
309 Ok(encoded)
310}
311
312pub fn decompress(encoded: &[u8]) -> Result<AnyArray, ZfpCodecError> {
322 let (header, encoded) =
323 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
324 ZfpCodecError::MetaHeaderDecodeFailed {
325 source: ZfpHeaderError(err),
326 }
327 })?;
328
329 if header.shape.iter().copied().product::<usize>() == 0 {
331 let decoded = match header.dtype {
332 ZfpDType::I32 => AnyArray::I32(Array::zeros(&*header.shape)),
333 ZfpDType::I64 => AnyArray::I64(Array::zeros(&*header.shape)),
334 ZfpDType::F32 => AnyArray::F32(Array::zeros(&*header.shape)),
335 ZfpDType::F64 => AnyArray::F64(Array::zeros(&*header.shape)),
336 };
337 return Ok(decoded);
338 }
339
340 let stream = ffi::ZfpDecompressionStream::new(encoded);
342
343 let stream = stream.read_header()?;
345
346 match header.dtype {
348 ZfpDType::I32 => {
349 let mut decompressed = Array::zeros(&*header.shape);
350 stream.decompress_into(decompressed.view_mut().squeeze())?;
351 Ok(AnyArray::I32(decompressed))
352 }
353 ZfpDType::I64 => {
354 let mut decompressed = Array::zeros(&*header.shape);
355 stream.decompress_into(decompressed.view_mut().squeeze())?;
356 Ok(AnyArray::I64(decompressed))
357 }
358 ZfpDType::F32 => {
359 let mut decompressed = Array::zeros(&*header.shape);
360 stream.decompress_into(decompressed.view_mut().squeeze())?;
361 Ok(AnyArray::F32(decompressed))
362 }
363 ZfpDType::F64 => {
364 let mut decompressed = Array::zeros(&*header.shape);
365 stream.decompress_into(decompressed.view_mut().squeeze())?;
366 Ok(AnyArray::F64(decompressed))
367 }
368 }
369}
370
371pub fn decompress_into(encoded: &[u8], decoded: AnyArrayViewMut) -> Result<(), ZfpCodecError> {
383 let (header, encoded) =
384 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
385 ZfpCodecError::MetaHeaderDecodeFailed {
386 source: ZfpHeaderError(err),
387 }
388 })?;
389
390 if decoded.shape() != &*header.shape {
391 return Err(ZfpCodecError::MismatchedDecodeIntoArray {
392 source: AnyArrayAssignError::ShapeMismatch {
393 src: header.shape.into_owned(),
394 dst: decoded.shape().to_vec(),
395 },
396 });
397 }
398
399 if decoded.is_empty() {
401 return Ok(());
402 }
403
404 let stream = ffi::ZfpDecompressionStream::new(encoded);
406
407 let stream = stream.read_header()?;
409
410 match (decoded, header.dtype) {
412 (AnyArrayViewMut::I32(decoded), ZfpDType::I32) => stream.decompress_into(decoded.squeeze()),
413 (AnyArrayViewMut::I64(decoded), ZfpDType::I64) => stream.decompress_into(decoded.squeeze()),
414 (AnyArrayViewMut::F32(decoded), ZfpDType::F32) => stream.decompress_into(decoded.squeeze()),
415 (AnyArrayViewMut::F64(decoded), ZfpDType::F64) => stream.decompress_into(decoded.squeeze()),
416 (decoded, dtype) => Err(ZfpCodecError::MismatchedDecodeIntoArray {
417 source: AnyArrayAssignError::DTypeMismatch {
418 src: dtype.into_dtype(),
419 dst: decoded.dtype(),
420 },
421 }),
422 }
423}
424
425#[derive(Serialize, Deserialize)]
426struct CompressionHeader<'a> {
427 dtype: ZfpDType,
428 #[serde(borrow)]
429 shape: Cow<'a, [usize]>,
430}
431
432#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
434#[expect(missing_docs)]
435pub enum ZfpDType {
436 #[serde(rename = "i32", alias = "int32")]
437 I32,
438 #[serde(rename = "i64", alias = "int64")]
439 I64,
440 #[serde(rename = "f32", alias = "float32")]
441 F32,
442 #[serde(rename = "f64", alias = "float64")]
443 F64,
444}
445
446impl ZfpDType {
447 #[must_use]
449 pub const fn into_dtype(self) -> AnyArrayDType {
450 match self {
451 Self::I32 => AnyArrayDType::I32,
452 Self::I64 => AnyArrayDType::I64,
453 Self::F32 => AnyArrayDType::F32,
454 Self::F64 => AnyArrayDType::F64,
455 }
456 }
457}
458
459impl fmt::Display for ZfpDType {
460 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
461 fmt.write_str(match self {
462 Self::I32 => "i32",
463 Self::I64 => "i64",
464 Self::F32 => "f32",
465 Self::F64 => "f64",
466 })
467 }
468}
469
#[cfg(test)]
// A failing `unwrap` is simply a failing test here.
#[allow(clippy::unwrap_used)]
mod tests {
    use ndarray::ArrayView1;

    use super::*;

    // An array with a zero-length axis round-trips to an empty array that
    // keeps its dtype and its full shape.
    #[test]
    fn zero_length() {
        let empty = Array::<f32, _>::from_shape_vec([1, 27, 0].as_slice(), vec![]).unwrap();
        let mode = ZfpCompressionMode::FixedPrecision { precision: 7 };

        let encoded = compress(empty.view(), &mode).unwrap();
        let decoded = decompress(&encoded).unwrap();

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[1, 27, 0]);
    }

    // A six-dimensional array whose extents are mostly one round-trips and
    // gets its original shape back.
    #[test]
    fn one_dimension() {
        let data = Array::from_shape_vec(
            [2_usize, 1, 2, 1, 1, 1].as_slice(),
            vec![1.0, 2.0, 3.0, 4.0],
        )
        .unwrap();
        let mode = ZfpCompressionMode::FixedAccuracy { tolerance: 0.1 };

        let encoded = compress(data.view(), &mode).unwrap();
        let decoded = decompress(&encoded).unwrap();

        assert_eq!(decoded, AnyArray::F32(data));
    }

    // Arrays with zero to four elements round-trip.
    #[test]
    fn small_state() {
        let mode = ZfpCompressionMode::FixedAccuracy { tolerance: 0.1 };
        let inputs: [&[f64]; 5] = [
            &[],
            &[0.0],
            &[0.0, 1.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0],
        ];

        for input in inputs {
            let encoded = compress(ArrayView1::from(input), &mode).unwrap();
            let decoded = decompress(&encoded).unwrap();
            let expected = AnyArray::F64(Array1::from_vec(input.to_vec()).into_dyn());

            assert_eq!(decoded, expected);
        }
    }
}