numcodecs_tthresh/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.85.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-tthresh
10//! [crates.io]: https://crates.io/crates/numcodecs-tthresh
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-tthresh
13//! [docs.rs]: https://docs.rs/numcodecs-tthresh/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_tthresh
17//!
18//! tthresh codec implementation for the [`numcodecs`] API.
19
20use std::borrow::Cow;
21
22use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
23use num_traits::Float;
24use numcodecs::{
25    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
27};
28use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
29use serde::{Deserialize, Deserializer, Serialize, Serializer};
30use thiserror::Error;
31
32#[cfg(test)]
33use ::serde_json as _;
34
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
// serde cannot deny unknown fields because of the flatten
// (so the restriction is only declared on the schemars side, see below)
#[schemars(deny_unknown_fields)]
/// Codec providing compression using tthresh
pub struct TthreshCodec {
    /// tthresh error bound
    // flattened: the error bound's own keys (the `eb_mode` tag plus the
    // matching `eb_*` value) sit directly in the codec configuration instead
    // of inside a nested object
    #[serde(flatten)]
    pub error_bound: TthreshErrorBound,
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    // (de)serialized under the key `_version`; when the key is absent the
    // field falls back to the type's `Default` value
    #[serde(default, rename = "_version")]
    pub version: StaticCodecVersion<0, 1, 0>,
}
47
48/// tthresh error bound
49#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
50#[serde(tag = "eb_mode")]
51#[serde(deny_unknown_fields)]
52pub enum TthreshErrorBound {
53    /// Relative error bound
54    #[serde(rename = "eps")]
55    Eps {
56        /// Relative error bound
57        #[serde(rename = "eb_eps")]
58        eps: NonNegative<f64>,
59    },
60    /// Root mean square error bound
61    #[serde(rename = "rmse")]
62    RMSE {
63        /// Peak signal to noise ratio error bound
64        #[serde(rename = "eb_rmse")]
65        rmse: NonNegative<f64>,
66    },
67    /// Peak signal-to-noise ratio error bound
68    #[serde(rename = "psnr")]
69    PSNR {
70        /// Peak signal to noise ratio error bound
71        #[serde(rename = "eb_psnr")]
72        psnr: NonNegative<f64>,
73    },
74}
75
76impl Codec for TthreshCodec {
77    type Error = TthreshCodecError;
78
79    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
80        match data {
81            AnyCowArray::U8(data) => Ok(AnyArray::U8(
82                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
83            )),
84            AnyCowArray::U16(data) => Ok(AnyArray::U8(
85                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
86            )),
87            AnyCowArray::I32(data) => Ok(AnyArray::U8(
88                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
89            )),
90            AnyCowArray::F32(data) => Ok(AnyArray::U8(
91                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
92            )),
93            AnyCowArray::F64(data) => Ok(AnyArray::U8(
94                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
95            )),
96            encoded => Err(TthreshCodecError::UnsupportedDtype(encoded.dtype())),
97        }
98    }
99
100    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
101        let AnyCowArray::U8(encoded) = encoded else {
102            return Err(TthreshCodecError::EncodedDataNotBytes {
103                dtype: encoded.dtype(),
104            });
105        };
106
107        if !matches!(encoded.shape(), [_]) {
108            return Err(TthreshCodecError::EncodedDataNotOneDimensional {
109                shape: encoded.shape().to_vec(),
110            });
111        }
112
113        decompress(&AnyCowArray::U8(encoded).as_bytes())
114    }
115
116    fn decode_into(
117        &self,
118        encoded: AnyArrayView,
119        mut decoded: AnyArrayViewMut,
120    ) -> Result<(), Self::Error> {
121        let decoded_in = self.decode(encoded.cow())?;
122
123        Ok(decoded.assign(&decoded_in)?)
124    }
125}
126
impl StaticCodec for TthreshCodec {
    /// The codec's identifier string.
    const CODEC_ID: &'static str = "tthresh.rs";

    /// The codec is its own configuration type: it is (de)serialized directly
    /// through its serde derives.
    type Config<'de> = Self;

    /// Returns `config` unchanged, since the configuration already is the codec.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Builds the codec's configuration from a reference to `self`.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
140
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`TthreshCodec`].
// The `#[error(...)]` strings are the variants' `Display` messages. Fields
// marked `#[from]` additionally provide a `From` conversion, which is what
// lets `?` turn the source error into the variant.
pub enum TthreshCodecError {
    /// [`TthreshCodec`] does not support the dtype
    ///
    /// Returned when encoding an array whose dtype is not one of `u8`, `u16`,
    /// `i32`, `f32`, or `f64`.
    #[error("Tthresh does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`TthreshCodec`] failed to encode the data
    ///
    /// Returned by [`compress`] when the underlying tthresh call fails.
    #[error("Tthresh failed to encode the data")]
    TthreshEncodeFailed {
        /// Opaque source error
        source: TthreshCodingError,
    },
    /// [`TthreshCodec`] can only decode one-dimensional byte arrays but received
    /// an array of a different dtype
    #[error(
        "Tthresh can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`TthreshCodec`] can only decode one-dimensional byte arrays but received
    /// an array of a different shape
    #[error(
        "Tthresh can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`TthreshCodec`] failed to decode the data
    ///
    /// Returned by [`decompress`] when the underlying tthresh call fails.
    #[error("Tthresh failed to decode the data")]
    TthreshDecodeFailed {
        /// Opaque source error
        source: TthreshCodingError,
    },
    /// [`TthreshCodec`] decoded an invalid array shape header which does not fit
    /// the decoded data
    ///
    /// Returned by [`decompress`] when the decoded buffer cannot be arranged
    /// into an array of the decoded shape.
    #[error("Tthresh decoded an invalid array shape header which does not fit the decoded data")]
    DecodeInvalidShapeHeader {
        /// Source error
        #[from]
        source: ShapeError,
    },
    /// [`TthreshCodec`] cannot decode into the provided array
    ///
    /// Returned when assigning the decoded array to the caller-provided
    /// output array fails.
    #[error("Tthresh cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
193
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding with tthresh fails
// The wrapped `tthresh::Error` lives in a private field, so callers cannot
// name or inspect it directly; `#[error(transparent)]` forwards `Display` and
// `source()` to it.
pub struct TthreshCodingError(tthresh::Error);
198
199#[expect(clippy::needless_pass_by_value)]
200/// Compresses the input `data` array using tthresh with the provided
201/// `error_bound`.
202///
203/// # Errors
204///
205/// Errors with
206/// - [`TthreshCodecError::TthreshEncodeFailed`] if encoding failed with an opaque error
207pub fn compress<T: TthreshElement, S: Data<Elem = T>, D: Dimension>(
208    data: ArrayBase<S, D>,
209    error_bound: &TthreshErrorBound,
210) -> Result<Vec<u8>, TthreshCodecError> {
211    #[expect(clippy::option_if_let_else)]
212    let data_cow = match data.as_slice() {
213        Some(data) => Cow::Borrowed(data),
214        None => Cow::Owned(data.iter().copied().collect()),
215    };
216
217    let compressed = tthresh::compress(
218        &data_cow,
219        data.shape(),
220        match error_bound {
221            TthreshErrorBound::Eps { eps } => tthresh::ErrorBound::Eps(eps.0),
222            TthreshErrorBound::RMSE { rmse } => tthresh::ErrorBound::RMSE(rmse.0),
223            TthreshErrorBound::PSNR { psnr } => tthresh::ErrorBound::PSNR(psnr.0),
224        },
225        false,
226        false,
227    )
228    .map_err(|err| TthreshCodecError::TthreshEncodeFailed {
229        source: TthreshCodingError(err),
230    })?;
231
232    Ok(compressed)
233}
234
235/// Decompresses the `encoded` data into an array.
236///
237/// # Errors
238///
239/// Errors with
240/// - [`TthreshCodecError::TthreshDecodeFailed`] if decoding failed with an opaque error
241pub fn decompress(encoded: &[u8]) -> Result<AnyArray, TthreshCodecError> {
242    let (decompressed, shape) = tthresh::decompress(encoded, false, false).map_err(|err| {
243        TthreshCodecError::TthreshDecodeFailed {
244            source: TthreshCodingError(err),
245        }
246    })?;
247
248    let decoded = match decompressed {
249        tthresh::Buffer::U8(decompressed) => {
250            AnyArray::U8(Array::from_shape_vec(shape, decompressed)?.into_dyn())
251        }
252        tthresh::Buffer::U16(decompressed) => {
253            AnyArray::U16(Array::from_shape_vec(shape, decompressed)?.into_dyn())
254        }
255        tthresh::Buffer::I32(decompressed) => {
256            AnyArray::I32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
257        }
258        tthresh::Buffer::F32(decompressed) => {
259            AnyArray::F32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
260        }
261        tthresh::Buffer::F64(decompressed) => {
262            AnyArray::F64(Array::from_shape_vec(shape, decompressed)?.into_dyn())
263        }
264    };
265
266    Ok(decoded)
267}
268
/// Array element types which can be compressed with tthresh.
///
/// The trait declares no items of its own; it serves as the element bound of
/// [`compress`] and requires `Copy` and `tthresh::Element`.
pub trait TthreshElement: Copy + tthresh::Element {}

// One impl per element type; these are the same five types that the codec's
// `encode` accepts and that `decompress` can return.
impl TthreshElement for u8 {}
impl TthreshElement for u16 {}
impl TthreshElement for i32 {}
impl TthreshElement for f32 {}
impl TthreshElement for f64 {}
277
#[expect(clippy::derive_partial_eq_without_eq)] // floats are not Eq
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Hash)]
/// Non-negative floating point number
// The field is private. In the code visible here a value is only created by
// the `Deserialize` impl for `NonNegative<f64>`, which accepts exactly those
// inputs for which `x >= 0.0` holds.
pub struct NonNegative<T: Float>(T);
282
283impl Serialize for NonNegative<f64> {
284    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
285        serializer.serialize_f64(self.0)
286    }
287}
288
289impl<'de> Deserialize<'de> for NonNegative<f64> {
290    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
291        let x = f64::deserialize(deserializer)?;
292
293        if x >= 0.0 {
294            Ok(Self(x))
295        } else {
296            Err(serde::de::Error::invalid_value(
297                serde::de::Unexpected::Float(x),
298                &"a non-negative value",
299            ))
300        }
301    }
302}
303
impl JsonSchema for NonNegative<f64> {
    /// Short schema name: `NonNegativeF64`.
    fn schema_name() -> Cow<'static, str> {
        Cow::Borrowed("NonNegativeF64")
    }

    /// Schema identifier: this module's path followed by `::NonNegative<f64>`.
    fn schema_id() -> Cow<'static, str> {
        Cow::Borrowed(concat!(module_path!(), "::", "NonNegative<f64>"))
    }

    /// Describes the value as a JSON number with a minimum of zero, mirroring
    /// the `>= 0.0` check applied when deserializing.
    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
        json_schema!({
            "type": "number",
            "minimum": 0.0
        })
    }
}