numcodecs_tthresh/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-tthresh
10//! [crates.io]: https://crates.io/crates/numcodecs-tthresh
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-tthresh
13//! [docs.rs]: https://docs.rs/numcodecs-tthresh/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_tthresh
17//!
18//! tthresh codec implementation for the [`numcodecs`] API.
19
20use std::borrow::Cow;
21
22use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
23use num_traits::Float;
24use numcodecs::{
25    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26    Codec, StaticCodec, StaticCodecConfig,
27};
28use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
29use serde::{Deserialize, Deserializer, Serialize, Serializer};
30use thiserror::Error;
31
32#[cfg(test)]
33use ::serde_json as _;
34
35#[derive(Clone, Serialize, Deserialize, JsonSchema)]
36// serde cannot deny unknown fields because of the flatten
37#[schemars(deny_unknown_fields)]
38/// Codec providing compression using tthresh
39pub struct TthreshCodec {
40    /// tthresh error bound
41    #[serde(flatten)]
42    pub error_bound: TthreshErrorBound,
43}
44
45/// tthresh error bound
46#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
47#[serde(tag = "eb_mode")]
48#[serde(deny_unknown_fields)]
49pub enum TthreshErrorBound {
50    /// Relative error bound
51    #[serde(rename = "eps")]
52    Eps {
53        /// Relative error bound
54        #[serde(rename = "eb_eps")]
55        eps: NonNegative<f64>,
56    },
57    /// Root mean square error bound
58    #[serde(rename = "rmse")]
59    RMSE {
60        /// Peak signal to noise ratio error bound
61        #[serde(rename = "eb_rmse")]
62        rmse: NonNegative<f64>,
63    },
64    /// Peak signal-to-noise ratio error bound
65    #[serde(rename = "psnr")]
66    PSNR {
67        /// Peak signal to noise ratio error bound
68        #[serde(rename = "eb_psnr")]
69        psnr: NonNegative<f64>,
70    },
71}
72
73impl Codec for TthreshCodec {
74    type Error = TthreshCodecError;
75
76    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
77        match data {
78            AnyCowArray::U8(data) => Ok(AnyArray::U8(
79                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
80            )),
81            AnyCowArray::U16(data) => Ok(AnyArray::U8(
82                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
83            )),
84            AnyCowArray::I32(data) => Ok(AnyArray::U8(
85                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
86            )),
87            AnyCowArray::F32(data) => Ok(AnyArray::U8(
88                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
89            )),
90            AnyCowArray::F64(data) => Ok(AnyArray::U8(
91                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
92            )),
93            encoded => Err(TthreshCodecError::UnsupportedDtype(encoded.dtype())),
94        }
95    }
96
97    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
98        let AnyCowArray::U8(encoded) = encoded else {
99            return Err(TthreshCodecError::EncodedDataNotBytes {
100                dtype: encoded.dtype(),
101            });
102        };
103
104        if !matches!(encoded.shape(), [_]) {
105            return Err(TthreshCodecError::EncodedDataNotOneDimensional {
106                shape: encoded.shape().to_vec(),
107            });
108        }
109
110        decompress(&AnyCowArray::U8(encoded).as_bytes())
111    }
112
113    fn decode_into(
114        &self,
115        encoded: AnyArrayView,
116        mut decoded: AnyArrayViewMut,
117    ) -> Result<(), Self::Error> {
118        let decoded_in = self.decode(encoded.cow())?;
119
120        Ok(decoded.assign(&decoded_in)?)
121    }
122}
123
124impl StaticCodec for TthreshCodec {
125    const CODEC_ID: &'static str = "tthresh";
126
127    type Config<'de> = Self;
128
129    fn from_config(config: Self::Config<'_>) -> Self {
130        config
131    }
132
133    fn get_config(&self) -> StaticCodecConfig<Self> {
134        StaticCodecConfig::from(self)
135    }
136}
137
138#[derive(Debug, Error)]
139/// Errors that may occur when applying the [`TthreshCodec`].
140pub enum TthreshCodecError {
141    /// [`TthreshCodec`] does not support the dtype
142    #[error("Tthresh does not support the dtype {0}")]
143    UnsupportedDtype(AnyArrayDType),
144    /// [`TthreshCodec`] failed to encode the data
145    #[error("Tthresh failed to encode the data")]
146    TthreshEncodeFailed {
147        /// Opaque source error
148        source: TthreshCodingError,
149    },
150    /// [`TthreshCodec`] can only decode one-dimensional byte arrays but received
151    /// an array of a different dtype
152    #[error(
153        "Tthresh can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
154    )]
155    EncodedDataNotBytes {
156        /// The unexpected dtype of the encoded array
157        dtype: AnyArrayDType,
158    },
159    /// [`TthreshCodec`] can only decode one-dimensional byte arrays but received
160    /// an array of a different shape
161    #[error("Tthresh can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
162    EncodedDataNotOneDimensional {
163        /// The unexpected shape of the encoded array
164        shape: Vec<usize>,
165    },
166    /// [`TthreshCodec`] failed to decode the data
167    #[error("Tthresh failed to decode the data")]
168    TthreshDecodeFailed {
169        /// Opaque source error
170        source: TthreshCodingError,
171    },
172    /// [`TthreshCodec`] decoded an invalid array shape header which does not fit
173    /// the decoded data
174    #[error("Tthresh decoded an invalid array shape header which does not fit the decoded data")]
175    DecodeInvalidShapeHeader {
176        /// Source error
177        #[from]
178        source: ShapeError,
179    },
180    /// [`TthreshCodec`] cannot decode into the provided array
181    #[error("Tthresh cannot decode into the provided array")]
182    MismatchedDecodeIntoArray {
183        /// The source of the error
184        #[from]
185        source: AnyArrayAssignError,
186    },
187}
188
189#[derive(Debug, Error)]
190#[error(transparent)]
191/// Opaque error for when encoding or decoding with tthresh fails
192pub struct TthreshCodingError(tthresh::Error);
193
194#[expect(clippy::needless_pass_by_value)]
195/// Compresses the input `data` array using tthresh with the provided
196/// `error_bound`.
197///
198/// # Errors
199///
200/// Errors with
201/// - [`TthreshCodecError::TthreshEncodeFailed`] if encoding failed with an opaque error
202pub fn compress<T: TthreshElement, S: Data<Elem = T>, D: Dimension>(
203    data: ArrayBase<S, D>,
204    error_bound: &TthreshErrorBound,
205) -> Result<Vec<u8>, TthreshCodecError> {
206    #[expect(clippy::option_if_let_else)]
207    let data_cow = if let Some(data) = data.as_slice() {
208        Cow::Borrowed(data)
209    } else {
210        Cow::Owned(data.iter().copied().collect())
211    };
212
213    let compressed = tthresh::compress(
214        &data_cow,
215        data.shape(),
216        match error_bound {
217            TthreshErrorBound::Eps { eps } => tthresh::ErrorBound::Eps(eps.0),
218            TthreshErrorBound::RMSE { rmse } => tthresh::ErrorBound::RMSE(rmse.0),
219            TthreshErrorBound::PSNR { psnr } => tthresh::ErrorBound::PSNR(psnr.0),
220        },
221        false,
222        false,
223    )
224    .map_err(|err| TthreshCodecError::TthreshEncodeFailed {
225        source: TthreshCodingError(err),
226    })?;
227
228    Ok(compressed)
229}
230
231/// Decompresses the `encoded` data into an array.
232///
233/// # Errors
234///
235/// Errors with
236/// - [`TthreshCodecError::TthreshDecodeFailed`] if decoding failed with an opaque error
237pub fn decompress(encoded: &[u8]) -> Result<AnyArray, TthreshCodecError> {
238    let (decompressed, shape) = tthresh::decompress(encoded, false, false).map_err(|err| {
239        TthreshCodecError::TthreshDecodeFailed {
240            source: TthreshCodingError(err),
241        }
242    })?;
243
244    let decoded = match decompressed {
245        tthresh::Buffer::U8(decompressed) => {
246            AnyArray::U8(Array::from_shape_vec(shape, decompressed)?.into_dyn())
247        }
248        tthresh::Buffer::U16(decompressed) => {
249            AnyArray::U16(Array::from_shape_vec(shape, decompressed)?.into_dyn())
250        }
251        tthresh::Buffer::I32(decompressed) => {
252            AnyArray::I32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
253        }
254        tthresh::Buffer::F32(decompressed) => {
255            AnyArray::F32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
256        }
257        tthresh::Buffer::F64(decompressed) => {
258            AnyArray::F64(Array::from_shape_vec(shape, decompressed)?.into_dyn())
259        }
260    };
261
262    Ok(decoded)
263}
264
265/// Array element types which can be compressed with tthresh.
266pub trait TthreshElement: Copy + tthresh::Element {}
267
268impl TthreshElement for u8 {}
269impl TthreshElement for u16 {}
270impl TthreshElement for i32 {}
271impl TthreshElement for f32 {}
272impl TthreshElement for f64 {}
273
274#[expect(clippy::derive_partial_eq_without_eq)] // floats are not Eq
275#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Hash)]
276/// Non-negative floating point number
277pub struct NonNegative<T: Float>(T);
278
279impl Serialize for NonNegative<f64> {
280    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
281        serializer.serialize_f64(self.0)
282    }
283}
284
285impl<'de> Deserialize<'de> for NonNegative<f64> {
286    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
287        let x = f64::deserialize(deserializer)?;
288
289        if x >= 0.0 {
290            Ok(Self(x))
291        } else {
292            Err(serde::de::Error::invalid_value(
293                serde::de::Unexpected::Float(x),
294                &"a non-negative value",
295            ))
296        }
297    }
298}
299
300impl JsonSchema for NonNegative<f64> {
301    fn schema_name() -> Cow<'static, str> {
302        Cow::Borrowed("NonNegativeF64")
303    }
304
305    fn schema_id() -> Cow<'static, str> {
306        Cow::Borrowed(concat!(module_path!(), "::", "NonNegative<f64>"))
307    }
308
309    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
310        json_schema!({
311            "type": "number",
312            "minimum": 0.0
313        })
314    }
315}