1#![allow(clippy::multiple_crate_versions)] #[cfg(test)]
23use ::serde_json as _;
24
25use std::borrow::Cow;
26use std::fmt;
27
28use ndarray::{Array, Array1, ArrayBase, Axis, Data, Dimension, IxDyn, ShapeError};
29use num_traits::{Float, identities::Zero};
30use numcodecs::{
31 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
32 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
33};
34use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
35use serde::{Deserialize, Deserializer, Serialize, Serializer};
36use thiserror::Error;
37
38type SperrCodecVersion = StaticCodecVersion<0, 1, 0>;
39
40#[derive(Clone, Serialize, Deserialize, JsonSchema)]
41#[schemars(deny_unknown_fields)]
43pub struct SperrCodec {
51 #[serde(flatten)]
53 pub mode: SperrCompressionMode,
54 #[serde(default, rename = "_version")]
56 pub version: SperrCodecVersion,
57}
58
59#[derive(Clone, Serialize, Deserialize, JsonSchema)]
60#[serde(tag = "mode")]
62pub enum SperrCompressionMode {
63 #[serde(rename = "bpp")]
65 BitsPerPixel {
66 bpp: Positive<f64>,
68 },
69 #[serde(rename = "psnr")]
71 PeakSignalToNoiseRatio {
72 psnr: Positive<f64>,
74 },
75 #[serde(rename = "pwe")]
77 PointwiseError {
78 pwe: Positive<f64>,
80 },
81}
82
83impl Codec for SperrCodec {
84 type Error = SperrCodecError;
85
86 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
87 match data {
88 AnyCowArray::F32(data) => Ok(AnyArray::U8(
89 Array1::from(compress(data, &self.mode)?).into_dyn(),
90 )),
91 AnyCowArray::F64(data) => Ok(AnyArray::U8(
92 Array1::from(compress(data, &self.mode)?).into_dyn(),
93 )),
94 encoded => Err(SperrCodecError::UnsupportedDtype(encoded.dtype())),
95 }
96 }
97
98 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
99 let AnyCowArray::U8(encoded) = encoded else {
100 return Err(SperrCodecError::EncodedDataNotBytes {
101 dtype: encoded.dtype(),
102 });
103 };
104
105 if !matches!(encoded.shape(), [_]) {
106 return Err(SperrCodecError::EncodedDataNotOneDimensional {
107 shape: encoded.shape().to_vec(),
108 });
109 }
110
111 decompress(&AnyCowArray::U8(encoded).as_bytes())
112 }
113
114 fn decode_into(
115 &self,
116 encoded: AnyArrayView,
117 mut decoded: AnyArrayViewMut,
118 ) -> Result<(), Self::Error> {
119 let decoded_in = self.decode(encoded.cow())?;
120
121 Ok(decoded.assign(&decoded_in)?)
122 }
123}
124
125impl StaticCodec for SperrCodec {
126 const CODEC_ID: &'static str = "sperr.rs";
127
128 type Config<'de> = Self;
129
130 fn from_config(config: Self::Config<'_>) -> Self {
131 config
132 }
133
134 fn get_config(&self) -> StaticCodecConfig<Self> {
135 StaticCodecConfig::from(self)
136 }
137}
138
139#[derive(Debug, Error)]
140pub enum SperrCodecError {
142 #[error("Sperr does not support the dtype {0}")]
144 UnsupportedDtype(AnyArrayDType),
145 #[error("Sperr failed to encode the header")]
147 HeaderEncodeFailed {
148 source: SperrHeaderError,
150 },
151 #[error("Sperr failed to encode the data")]
153 SperrEncodeFailed {
154 source: SperrCodingError,
156 },
157 #[error("Sperr failed to encode a slice")]
159 SliceEncodeFailed {
160 source: SperrSliceError,
162 },
163 #[error(
166 "Sperr can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
167 )]
168 EncodedDataNotBytes {
169 dtype: AnyArrayDType,
171 },
172 #[error(
175 "Sperr can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
176 )]
177 EncodedDataNotOneDimensional {
178 shape: Vec<usize>,
180 },
181 #[error("Sperr failed to decode the header")]
183 HeaderDecodeFailed {
184 source: SperrHeaderError,
186 },
187 #[error("Sperr failed to decode a slice")]
189 SliceDecodeFailed {
190 source: SperrSliceError,
192 },
193 #[error("Sperr failed to decode from an excessive number of slices")]
195 DecodeTooManySlices,
196 #[error("Sperr failed to decode the data")]
198 SperrDecodeFailed {
199 source: SperrCodingError,
201 },
202 #[error("Sperr decoded into an invalid shape not matching the data size")]
204 DecodeInvalidShape {
205 source: ShapeError,
207 },
208 #[error("Sperr cannot decode into the provided array")]
210 MismatchedDecodeIntoArray {
211 #[from]
213 source: AnyArrayAssignError,
214 },
215}
216
217#[derive(Debug, Error)]
218#[error(transparent)]
219pub struct SperrHeaderError(postcard::Error);
221
222#[derive(Debug, Error)]
223#[error(transparent)]
224pub struct SperrSliceError(postcard::Error);
226
227#[derive(Debug, Error)]
228#[error(transparent)]
229pub struct SperrCodingError(sperr::Error);
231
232#[allow(clippy::missing_panics_doc)]
241pub fn compress<T: SperrElement, S: Data<Elem = T>, D: Dimension>(
242 data: ArrayBase<S, D>,
243 mode: &SperrCompressionMode,
244) -> Result<Vec<u8>, SperrCodecError> {
245 let mut encoded = postcard::to_extend(
246 &CompressionHeader {
247 dtype: T::DTYPE,
248 shape: Cow::Borrowed(data.shape()),
249 version: StaticCodecVersion,
250 },
251 Vec::new(),
252 )
253 .map_err(|err| SperrCodecError::HeaderEncodeFailed {
254 source: SperrHeaderError(err),
255 })?;
256
257 if data.is_empty() {
259 return Ok(encoded);
260 }
261
262 let mut chunk_size = Vec::from(data.shape());
263 let (width, height, depth) = match *chunk_size.as_mut_slice() {
264 [ref mut rest @ .., depth, height, width] => {
265 for r in rest {
266 *r = 1;
267 }
268 (width, height, depth)
269 }
270 [height, width] => (width, height, 1),
271 [width] => (width, 1, 1),
272 [] => (1, 1, 1),
273 };
274
275 for mut slice in data.into_dyn().exact_chunks(chunk_size.as_slice()) {
276 while slice.ndim() < 3 {
277 slice = slice.insert_axis(Axis(0));
278 }
279 #[allow(clippy::unwrap_used)]
280 let slice = slice.into_shape_with_order((depth, height, width)).unwrap();
283
284 let encoded_slice = sperr::compress_3d(
285 slice,
286 match mode {
287 SperrCompressionMode::BitsPerPixel { bpp } => {
288 sperr::CompressionMode::BitsPerPixel { bpp: bpp.0 }
289 }
290 SperrCompressionMode::PeakSignalToNoiseRatio { psnr } => {
291 sperr::CompressionMode::PeakSignalToNoiseRatio { psnr: psnr.0 }
292 }
293 SperrCompressionMode::PointwiseError { pwe } => {
294 sperr::CompressionMode::PointwiseError { pwe: pwe.0 }
295 }
296 },
297 (256, 256, 256),
298 )
299 .map_err(|err| SperrCodecError::SperrEncodeFailed {
300 source: SperrCodingError(err),
301 })?;
302
303 encoded = postcard::to_extend(encoded_slice.as_slice(), encoded).map_err(|err| {
304 SperrCodecError::SliceEncodeFailed {
305 source: SperrSliceError(err),
306 }
307 })?;
308 }
309
310 Ok(encoded)
311}
312
313pub fn decompress(encoded: &[u8]) -> Result<AnyArray, SperrCodecError> {
326 fn decompress_typed<T: SperrElement>(
327 mut encoded: &[u8],
328 shape: &[usize],
329 ) -> Result<Array<T, IxDyn>, SperrCodecError> {
330 let mut decoded = Array::<T, _>::zeros(shape);
331
332 let mut chunk_size = Vec::from(shape);
333 let (width, height, depth) = match *chunk_size.as_mut_slice() {
334 [ref mut rest @ .., depth, height, width] => {
335 for r in rest {
336 *r = 1;
337 }
338 (width, height, depth)
339 }
340 [height, width] => (width, height, 1),
341 [width] => (width, 1, 1),
342 [] => (1, 1, 1),
343 };
344
345 for mut slice in decoded.exact_chunks_mut(chunk_size.as_slice()) {
346 let (encoded_slice, rest) =
347 postcard::take_from_bytes::<Cow<[u8]>>(encoded).map_err(|err| {
348 SperrCodecError::SliceDecodeFailed {
349 source: SperrSliceError(err),
350 }
351 })?;
352 encoded = rest;
353
354 while slice.ndim() < 3 {
355 slice = slice.insert_axis(Axis(0));
356 }
357 #[allow(clippy::unwrap_used)]
358 let slice = slice.into_shape_with_order((depth, height, width)).unwrap();
361
362 sperr::decompress_into_3d(&encoded_slice, slice).map_err(|err| {
363 SperrCodecError::SperrDecodeFailed {
364 source: SperrCodingError(err),
365 }
366 })?;
367 }
368
369 if !encoded.is_empty() {
370 return Err(SperrCodecError::DecodeTooManySlices);
371 }
372
373 Ok(decoded)
374 }
375
376 let (header, encoded) =
377 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
378 SperrCodecError::HeaderDecodeFailed {
379 source: SperrHeaderError(err),
380 }
381 })?;
382
383 if header.shape.iter().copied().product::<usize>() == 0 {
385 return match header.dtype {
386 SperrDType::F32 => Ok(AnyArray::F32(Array::zeros(&*header.shape))),
387 SperrDType::F64 => Ok(AnyArray::F64(Array::zeros(&*header.shape))),
388 };
389 }
390
391 match header.dtype {
392 SperrDType::F32 => Ok(AnyArray::F32(decompress_typed(encoded, &header.shape)?)),
393 SperrDType::F64 => Ok(AnyArray::F64(decompress_typed(encoded, &header.shape)?)),
394 }
395}
396
397pub trait SperrElement: sperr::Element + Zero {
399 const DTYPE: SperrDType;
401}
402
403impl SperrElement for f32 {
404 const DTYPE: SperrDType = SperrDType::F32;
405}
406impl SperrElement for f64 {
407 const DTYPE: SperrDType = SperrDType::F64;
408}
409
410#[expect(clippy::derive_partial_eq_without_eq)] #[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
412pub struct Positive<T: Float>(T);
414
415impl<T: Float> Positive<T> {
416 #[must_use]
417 pub const fn get(self) -> T {
419 self.0
420 }
421}
422
423impl Serialize for Positive<f64> {
424 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
425 serializer.serialize_f64(self.0)
426 }
427}
428
429impl<'de> Deserialize<'de> for Positive<f64> {
430 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
431 let x = f64::deserialize(deserializer)?;
432
433 if x > 0.0 {
434 Ok(Self(x))
435 } else {
436 Err(serde::de::Error::invalid_value(
437 serde::de::Unexpected::Float(x),
438 &"a positive value",
439 ))
440 }
441 }
442}
443
444impl JsonSchema for Positive<f64> {
445 fn schema_name() -> Cow<'static, str> {
446 Cow::Borrowed("PositiveF64")
447 }
448
449 fn schema_id() -> Cow<'static, str> {
450 Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
451 }
452
453 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
454 json_schema!({
455 "type": "number",
456 "exclusiveMinimum": 0.0
457 })
458 }
459}
460
461#[derive(Serialize, Deserialize)]
462struct CompressionHeader<'a> {
463 dtype: SperrDType,
464 #[serde(borrow)]
465 shape: Cow<'a, [usize]>,
466 version: SperrCodecVersion,
467}
468
469#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
471#[expect(missing_docs)]
472pub enum SperrDType {
473 #[serde(rename = "f32", alias = "float32")]
474 F32,
475 #[serde(rename = "f64", alias = "float64")]
476 F64,
477}
478
479impl fmt::Display for SperrDType {
480 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
481 fmt.write_str(match self {
482 Self::F32 => "f32",
483 Self::F64 => "f64",
484 })
485 }
486}
487
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use ndarray::{Ix0, Ix1, Ix2, Ix3, Ix4};

    use super::*;

    /// PSNR mode shared by the tests that only check dtype and shape.
    fn psnr_42() -> SperrCompressionMode {
        SperrCompressionMode::PeakSignalToNoiseRatio {
            psnr: Positive(42.0),
        }
    }

    /// Returns the values `1.0, 2.0, ..., len` as `f32`s.
    fn ramp(len: u8) -> Vec<f32> {
        (1..=len).map(f32::from).collect()
    }

    /// Asserts that a zero-filled `64 x 64 x 64` `f64` cube keeps its dtype
    /// and shape through a compress/decompress round trip with `mode`.
    fn assert_cube_roundtrip(mode: &SperrCompressionMode) {
        let encoded = compress(Array::<f64, _>::zeros((64, 64, 64)), mode).unwrap();
        let decoded = decompress(&encoded).unwrap();

        assert_eq!(decoded.dtype(), AnyArrayDType::F64);
        assert_eq!(decoded.len(), 64 * 64 * 64);
        assert_eq!(decoded.shape(), &[64, 64, 64]);
    }

    #[test]
    fn zero_length() {
        let input = Array::<f32, _>::from_shape_vec([3, 0], vec![]).unwrap();

        let encoded = compress(input, &psnr_42()).unwrap();
        let decoded = decompress(&encoded).unwrap();

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[3, 0]);
    }

    #[test]
    fn small_2d() {
        let input = Array::<f32, _>::from_shape_vec([1, 1], vec![42.0]).unwrap();

        let encoded = compress(input, &psnr_42()).unwrap();
        let decoded = decompress(&encoded).unwrap();

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded.shape(), &[1, 1]);
    }

    #[test]
    fn large_3d() {
        assert_cube_roundtrip(&psnr_42());
    }

    #[test]
    fn all_modes() {
        let modes = [
            SperrCompressionMode::BitsPerPixel { bpp: Positive(1.0) },
            psnr_42(),
            SperrCompressionMode::PointwiseError { pwe: Positive(0.1) },
        ];

        for mode in &modes {
            assert_cube_roundtrip(mode);
        }
    }

    #[test]
    fn many_dimensions() {
        let inputs = [
            Array::<f32, Ix0>::from_shape_vec([], vec![42.0])
                .unwrap()
                .into_dyn(),
            Array::<f32, Ix1>::from_shape_vec([2], ramp(2))
                .unwrap()
                .into_dyn(),
            Array::<f32, Ix2>::from_shape_vec([2, 2], ramp(4))
                .unwrap()
                .into_dyn(),
            Array::<f32, Ix3>::from_shape_vec([2, 2, 2], ramp(8))
                .unwrap()
                .into_dyn(),
            Array::<f32, Ix4>::from_shape_vec([2, 2, 2, 2], ramp(16))
                .unwrap()
                .into_dyn(),
        ];

        // A tiny point-wise error bound makes the round trip exact for these
        // small inputs.
        let mode = SperrCompressionMode::PointwiseError {
            pwe: Positive(f64::EPSILON),
        };

        for input in inputs {
            let encoded = compress(input.view(), &mode).unwrap();
            let decoded = decompress(&encoded).unwrap();

            assert_eq!(decoded, AnyArray::F32(input));
        }
    }
}