1#![allow(clippy::multiple_crate_versions)] #[cfg(test)]
23use ::serde_json as _;
24
25use std::borrow::Cow;
26use std::fmt;
27
28use ndarray::{Array, Array1, ArrayBase, Axis, Data, Dimension, IxDyn, ShapeError};
29use num_traits::{Float, identities::Zero};
30use numcodecs::{
31 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
32 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
33};
34use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
35use serde::{Deserialize, Deserializer, Serialize, Serializer};
36use thiserror::Error;
37
38type SperrCodecVersion = StaticCodecVersion<0, 2, 0>;
39
40#[derive(Clone, Serialize, Deserialize, JsonSchema)]
41#[schemars(deny_unknown_fields)]
43pub struct SperrCodec {
51 #[serde(flatten)]
53 pub mode: SperrCompressionMode,
54 #[serde(default, rename = "_version")]
56 pub version: SperrCodecVersion,
57}
58
59#[derive(Clone, Serialize, Deserialize, JsonSchema)]
60#[serde(tag = "mode")]
62pub enum SperrCompressionMode {
63 #[serde(rename = "bpp")]
65 BitsPerPixel {
66 bpp: Positive<f64>,
68 },
69 #[serde(rename = "psnr")]
71 PeakSignalToNoiseRatio {
72 psnr: Positive<f64>,
74 },
75 #[serde(rename = "pwe")]
77 PointwiseError {
78 pwe: Positive<f64>,
80 },
81 #[serde(rename = "q")]
83 QuantisationStep {
84 q: Positive<f64>,
86 },
87}
88
89impl Codec for SperrCodec {
90 type Error = SperrCodecError;
91
92 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
93 match data {
94 AnyCowArray::F32(data) => Ok(AnyArray::U8(
95 Array1::from(compress(data, &self.mode)?).into_dyn(),
96 )),
97 AnyCowArray::F64(data) => Ok(AnyArray::U8(
98 Array1::from(compress(data, &self.mode)?).into_dyn(),
99 )),
100 encoded => Err(SperrCodecError::UnsupportedDtype(encoded.dtype())),
101 }
102 }
103
104 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
105 let AnyCowArray::U8(encoded) = encoded else {
106 return Err(SperrCodecError::EncodedDataNotBytes {
107 dtype: encoded.dtype(),
108 });
109 };
110
111 if !matches!(encoded.shape(), [_]) {
112 return Err(SperrCodecError::EncodedDataNotOneDimensional {
113 shape: encoded.shape().to_vec(),
114 });
115 }
116
117 decompress(&AnyCowArray::U8(encoded).as_bytes())
118 }
119
120 fn decode_into(
121 &self,
122 encoded: AnyArrayView,
123 mut decoded: AnyArrayViewMut,
124 ) -> Result<(), Self::Error> {
125 let decoded_in = self.decode(encoded.cow())?;
126
127 Ok(decoded.assign(&decoded_in)?)
128 }
129}
130
131impl StaticCodec for SperrCodec {
132 const CODEC_ID: &'static str = "sperr.rs";
133
134 type Config<'de> = Self;
135
136 fn from_config(config: Self::Config<'_>) -> Self {
137 config
138 }
139
140 fn get_config(&self) -> StaticCodecConfig<'_, Self> {
141 StaticCodecConfig::from(self)
142 }
143}
144
145#[derive(Debug, Error)]
146pub enum SperrCodecError {
148 #[error("Sperr does not support the dtype {0}")]
150 UnsupportedDtype(AnyArrayDType),
151 #[error("Sperr failed to encode the header")]
153 HeaderEncodeFailed {
154 source: SperrHeaderError,
156 },
157 #[error("Sperr failed to encode the data")]
159 SperrEncodeFailed {
160 source: SperrCodingError,
162 },
163 #[error("Sperr failed to encode a slice")]
165 SliceEncodeFailed {
166 source: SperrSliceError,
168 },
169 #[error(
172 "Sperr can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
173 )]
174 EncodedDataNotBytes {
175 dtype: AnyArrayDType,
177 },
178 #[error(
181 "Sperr can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
182 )]
183 EncodedDataNotOneDimensional {
184 shape: Vec<usize>,
186 },
187 #[error("Sperr failed to decode the header")]
189 HeaderDecodeFailed {
190 source: SperrHeaderError,
192 },
193 #[error("Sperr failed to decode a slice")]
195 SliceDecodeFailed {
196 source: SperrSliceError,
198 },
199 #[error("Sperr failed to decode from an excessive number of slices")]
201 DecodeTooManySlices,
202 #[error("Sperr failed to decode the data")]
204 SperrDecodeFailed {
205 source: SperrCodingError,
207 },
208 #[error("Sperr decoded into an invalid shape not matching the data size")]
210 DecodeInvalidShape {
211 source: ShapeError,
213 },
214 #[error("Sperr cannot decode into the provided array")]
216 MismatchedDecodeIntoArray {
217 #[from]
219 source: AnyArrayAssignError,
220 },
221}
222
223#[derive(Debug, Error)]
224#[error(transparent)]
225pub struct SperrHeaderError(postcard::Error);
227
228#[derive(Debug, Error)]
229#[error(transparent)]
230pub struct SperrSliceError(postcard::Error);
232
233#[derive(Debug, Error)]
234#[error(transparent)]
235pub struct SperrCodingError(sperr::Error);
237
238#[allow(clippy::missing_panics_doc)]
247pub fn compress<T: SperrElement, S: Data<Elem = T>, D: Dimension>(
248 data: ArrayBase<S, D>,
249 mode: &SperrCompressionMode,
250) -> Result<Vec<u8>, SperrCodecError> {
251 let mut encoded = postcard::to_extend(
252 &CompressionHeader {
253 dtype: T::DTYPE,
254 shape: Cow::Borrowed(data.shape()),
255 version: StaticCodecVersion,
256 },
257 Vec::new(),
258 )
259 .map_err(|err| SperrCodecError::HeaderEncodeFailed {
260 source: SperrHeaderError(err),
261 })?;
262
263 if data.is_empty() {
265 return Ok(encoded);
266 }
267
268 let mut chunk_size = Vec::from(data.shape());
269 let (width, height, depth) = match *chunk_size.as_mut_slice() {
270 [ref mut rest @ .., depth, height, width] => {
271 for r in rest {
272 *r = 1;
273 }
274 (width, height, depth)
275 }
276 [height, width] => (width, height, 1),
277 [width] => (width, 1, 1),
278 [] => (1, 1, 1),
279 };
280
281 for mut slice in data.into_dyn().exact_chunks(chunk_size.as_slice()) {
282 while slice.ndim() < 3 {
283 slice = slice.insert_axis(Axis(0));
284 }
285 #[allow(clippy::unwrap_used)]
286 let slice = slice.into_shape_with_order((depth, height, width)).unwrap();
289
290 let encoded_slice = sperr::compress_3d(
291 slice,
292 match mode {
293 SperrCompressionMode::BitsPerPixel { bpp } => {
294 sperr::CompressionMode::BitsPerPixel { bpp: bpp.0 }
295 }
296 SperrCompressionMode::PeakSignalToNoiseRatio { psnr } => {
297 sperr::CompressionMode::PeakSignalToNoiseRatio { psnr: psnr.0 }
298 }
299 SperrCompressionMode::PointwiseError { pwe } => {
300 sperr::CompressionMode::PointwiseError { pwe: pwe.0 }
301 }
302 SperrCompressionMode::QuantisationStep { q } => {
303 sperr::CompressionMode::QuantisationStep { q: q.0 }
304 }
305 },
306 (256, 256, 256),
307 )
308 .map_err(|err| SperrCodecError::SperrEncodeFailed {
309 source: SperrCodingError(err),
310 })?;
311
312 encoded = postcard::to_extend(encoded_slice.as_slice(), encoded).map_err(|err| {
313 SperrCodecError::SliceEncodeFailed {
314 source: SperrSliceError(err),
315 }
316 })?;
317 }
318
319 Ok(encoded)
320}
321
322pub fn decompress(encoded: &[u8]) -> Result<AnyArray, SperrCodecError> {
335 fn decompress_typed<T: SperrElement>(
336 mut encoded: &[u8],
337 shape: &[usize],
338 ) -> Result<Array<T, IxDyn>, SperrCodecError> {
339 let mut decoded = Array::<T, _>::zeros(shape);
340
341 let mut chunk_size = Vec::from(shape);
342 let (width, height, depth) = match *chunk_size.as_mut_slice() {
343 [ref mut rest @ .., depth, height, width] => {
344 for r in rest {
345 *r = 1;
346 }
347 (width, height, depth)
348 }
349 [height, width] => (width, height, 1),
350 [width] => (width, 1, 1),
351 [] => (1, 1, 1),
352 };
353
354 for mut slice in decoded.exact_chunks_mut(chunk_size.as_slice()) {
355 let (encoded_slice, rest) =
356 postcard::take_from_bytes::<Cow<[u8]>>(encoded).map_err(|err| {
357 SperrCodecError::SliceDecodeFailed {
358 source: SperrSliceError(err),
359 }
360 })?;
361 encoded = rest;
362
363 while slice.ndim() < 3 {
364 slice = slice.insert_axis(Axis(0));
365 }
366 #[allow(clippy::unwrap_used)]
367 let slice = slice.into_shape_with_order((depth, height, width)).unwrap();
370
371 sperr::decompress_into_3d(&encoded_slice, slice).map_err(|err| {
372 SperrCodecError::SperrDecodeFailed {
373 source: SperrCodingError(err),
374 }
375 })?;
376 }
377
378 if !encoded.is_empty() {
379 return Err(SperrCodecError::DecodeTooManySlices);
380 }
381
382 Ok(decoded)
383 }
384
385 let (header, encoded) =
386 postcard::take_from_bytes::<CompressionHeader>(encoded).map_err(|err| {
387 SperrCodecError::HeaderDecodeFailed {
388 source: SperrHeaderError(err),
389 }
390 })?;
391
392 if header.shape.iter().copied().product::<usize>() == 0 {
394 return match header.dtype {
395 SperrDType::F32 => Ok(AnyArray::F32(Array::zeros(&*header.shape))),
396 SperrDType::F64 => Ok(AnyArray::F64(Array::zeros(&*header.shape))),
397 };
398 }
399
400 match header.dtype {
401 SperrDType::F32 => Ok(AnyArray::F32(decompress_typed(encoded, &header.shape)?)),
402 SperrDType::F64 => Ok(AnyArray::F64(decompress_typed(encoded, &header.shape)?)),
403 }
404}
405
406pub trait SperrElement: sperr::Element + Zero {
408 const DTYPE: SperrDType;
410}
411
412impl SperrElement for f32 {
413 const DTYPE: SperrDType = SperrDType::F32;
414}
415impl SperrElement for f64 {
416 const DTYPE: SperrDType = SperrDType::F64;
417}
418
419#[expect(clippy::derive_partial_eq_without_eq)] #[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
421pub struct Positive<T: Float>(T);
423
424impl<T: Float> Positive<T> {
425 #[must_use]
426 pub const fn get(self) -> T {
428 self.0
429 }
430}
431
432impl Serialize for Positive<f64> {
433 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
434 serializer.serialize_f64(self.0)
435 }
436}
437
438impl<'de> Deserialize<'de> for Positive<f64> {
439 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
440 let x = f64::deserialize(deserializer)?;
441
442 if x > 0.0 {
443 Ok(Self(x))
444 } else {
445 Err(serde::de::Error::invalid_value(
446 serde::de::Unexpected::Float(x),
447 &"a positive value",
448 ))
449 }
450 }
451}
452
453impl JsonSchema for Positive<f64> {
454 fn schema_name() -> Cow<'static, str> {
455 Cow::Borrowed("PositiveF64")
456 }
457
458 fn schema_id() -> Cow<'static, str> {
459 Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
460 }
461
462 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
463 json_schema!({
464 "type": "number",
465 "exclusiveMinimum": 0.0
466 })
467 }
468}
469
470#[derive(Serialize, Deserialize)]
471struct CompressionHeader<'a> {
472 dtype: SperrDType,
473 #[serde(borrow)]
474 shape: Cow<'a, [usize]>,
475 version: SperrCodecVersion,
476}
477
478#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
480#[expect(missing_docs)]
481pub enum SperrDType {
482 #[serde(rename = "f32", alias = "float32")]
483 F32,
484 #[serde(rename = "f64", alias = "float64")]
485 F64,
486}
487
488impl fmt::Display for SperrDType {
489 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
490 fmt.write_str(match self {
491 Self::F32 => "f32",
492 Self::F64 => "f64",
493 })
494 }
495}
496
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use ndarray::{Ix0, Ix1, Ix2, Ix3, Ix4};

    use super::*;

    /// The PSNR mode shared by the shape-focused tests.
    fn psnr_mode() -> SperrCompressionMode {
        SperrCompressionMode::PeakSignalToNoiseRatio {
            psnr: Positive(42.0),
        }
    }

    /// Compresses `data` with `mode` and immediately decompresses it again.
    fn roundtrip<T: SperrElement, S: Data<Elem = T>, D: Dimension>(
        data: ArrayBase<S, D>,
        mode: &SperrCompressionMode,
    ) -> AnyArray {
        let encoded = compress(data, mode).unwrap();
        decompress(&encoded).unwrap()
    }

    #[test]
    fn zero_length() {
        let decoded = roundtrip(
            Array::<f32, _>::from_shape_vec([3, 0], vec![]).unwrap(),
            &psnr_mode(),
        );

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert!(decoded.is_empty());
        assert_eq!(decoded.shape(), &[3, 0]);
    }

    #[test]
    fn small_2d() {
        let decoded = roundtrip(
            Array::<f32, _>::from_shape_vec([1, 1], vec![42.0]).unwrap(),
            &psnr_mode(),
        );

        assert_eq!(decoded.dtype(), AnyArrayDType::F32);
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded.shape(), &[1, 1]);
    }

    #[test]
    fn large_3d() {
        let decoded = roundtrip(Array::<f64, _>::zeros((64, 64, 64)), &psnr_mode());

        assert_eq!(decoded.dtype(), AnyArrayDType::F64);
        assert_eq!(decoded.len(), 64 * 64 * 64);
        assert_eq!(decoded.shape(), &[64, 64, 64]);
    }

    #[test]
    fn all_modes() {
        let modes = [
            SperrCompressionMode::BitsPerPixel { bpp: Positive(1.0) },
            SperrCompressionMode::PeakSignalToNoiseRatio {
                psnr: Positive(42.0),
            },
            SperrCompressionMode::PointwiseError { pwe: Positive(0.1) },
            SperrCompressionMode::QuantisationStep { q: Positive(1.5) },
        ];

        for mode in &modes {
            let decoded = roundtrip(Array::<f64, _>::zeros((64, 64, 64)), mode);

            assert_eq!(decoded.dtype(), AnyArrayDType::F64);
            assert_eq!(decoded.len(), 64 * 64 * 64);
            assert_eq!(decoded.shape(), &[64, 64, 64]);
        }
    }

    #[test]
    fn many_dimensions() {
        // a tiny point-wise error bound, so the round trip must be exact
        let exact = SperrCompressionMode::PointwiseError {
            pwe: Positive(f64::EPSILON),
        };

        let inputs = [
            Array::<f32, Ix0>::from_shape_vec([], vec![42.0])
                .unwrap()
                .into_dyn(),
            Array::<f32, Ix1>::from_shape_vec([2], vec![1.0, 2.0])
                .unwrap()
                .into_dyn(),
            Array::<f32, Ix2>::from_shape_vec([2, 2], vec![1.0, 2.0, 3.0, 4.0])
                .unwrap()
                .into_dyn(),
            Array::<f32, Ix3>::from_shape_vec(
                [2, 2, 2],
                vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            )
            .unwrap()
            .into_dyn(),
            Array::<f32, Ix4>::from_shape_vec(
                [2, 2, 2, 2],
                vec![
                    1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                    15.0, 16.0,
                ],
            )
            .unwrap()
            .into_dyn(),
        ];

        for data in inputs {
            let decoded = roundtrip(data.view(), &exact);

            assert_eq!(decoded, AnyArray::F32(data));
        }
    }
}