numcodecs_tthresh/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-tthresh
10//! [crates.io]: https://crates.io/crates/numcodecs-tthresh
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-tthresh
13//! [docs.rs]: https://docs.rs/numcodecs-tthresh/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_tthresh
17//!
18//! tthresh codec implementation for the [`numcodecs`] API.
19
20use std::borrow::Cow;
21
22use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
23use num_traits::Float;
24use numcodecs::{
25    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
27};
28use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
29use serde::{Deserialize, Deserializer, Serialize, Serializer};
30use thiserror::Error;
31
32#[cfg(test)]
33use ::serde_json as _;
34
35#[derive(Clone, Serialize, Deserialize, JsonSchema)]
36// serde cannot deny unknown fields because of the flatten
37#[schemars(deny_unknown_fields)]
38/// Codec providing compression using tthresh
39pub struct TthreshCodec {
40    /// tthresh error bound
41    #[serde(flatten)]
42    pub error_bound: TthreshErrorBound,
43    /// The codec's encoding format version. Do not provide this parameter explicitly.
44    #[serde(default, rename = "_version")]
45    pub version: StaticCodecVersion<0, 1, 0>,
46}
47
48/// tthresh error bound
49#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
50#[serde(tag = "eb_mode")]
51#[serde(deny_unknown_fields)]
52pub enum TthreshErrorBound {
53    /// Relative error bound
54    #[serde(rename = "eps")]
55    Eps {
56        /// Relative error bound
57        #[serde(rename = "eb_eps")]
58        eps: NonNegative<f64>,
59    },
60    /// Root mean square error bound
61    #[serde(rename = "rmse")]
62    RMSE {
63        /// Peak signal to noise ratio error bound
64        #[serde(rename = "eb_rmse")]
65        rmse: NonNegative<f64>,
66    },
67    /// Peak signal-to-noise ratio error bound
68    #[serde(rename = "psnr")]
69    PSNR {
70        /// Peak signal to noise ratio error bound
71        #[serde(rename = "eb_psnr")]
72        psnr: NonNegative<f64>,
73    },
74}
75
76impl Codec for TthreshCodec {
77    type Error = TthreshCodecError;
78
79    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
80        match data {
81            AnyCowArray::U8(data) => Ok(AnyArray::U8(
82                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
83            )),
84            AnyCowArray::U16(data) => Ok(AnyArray::U8(
85                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
86            )),
87            AnyCowArray::I32(data) => Ok(AnyArray::U8(
88                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
89            )),
90            AnyCowArray::F32(data) => Ok(AnyArray::U8(
91                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
92            )),
93            AnyCowArray::F64(data) => Ok(AnyArray::U8(
94                Array1::from(compress(data, &self.error_bound)?).into_dyn(),
95            )),
96            encoded => Err(TthreshCodecError::UnsupportedDtype(encoded.dtype())),
97        }
98    }
99
100    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
101        let AnyCowArray::U8(encoded) = encoded else {
102            return Err(TthreshCodecError::EncodedDataNotBytes {
103                dtype: encoded.dtype(),
104            });
105        };
106
107        if !matches!(encoded.shape(), [_]) {
108            return Err(TthreshCodecError::EncodedDataNotOneDimensional {
109                shape: encoded.shape().to_vec(),
110            });
111        }
112
113        decompress(&AnyCowArray::U8(encoded).as_bytes())
114    }
115
116    fn decode_into(
117        &self,
118        encoded: AnyArrayView,
119        mut decoded: AnyArrayViewMut,
120    ) -> Result<(), Self::Error> {
121        let decoded_in = self.decode(encoded.cow())?;
122
123        Ok(decoded.assign(&decoded_in)?)
124    }
125}
126
127impl StaticCodec for TthreshCodec {
128    const CODEC_ID: &'static str = "tthresh.rs";
129
130    type Config<'de> = Self;
131
132    fn from_config(config: Self::Config<'_>) -> Self {
133        config
134    }
135
136    fn get_config(&self) -> StaticCodecConfig<Self> {
137        StaticCodecConfig::from(self)
138    }
139}
140
141#[derive(Debug, Error)]
142/// Errors that may occur when applying the [`TthreshCodec`].
143pub enum TthreshCodecError {
144    /// [`TthreshCodec`] does not support the dtype
145    #[error("Tthresh does not support the dtype {0}")]
146    UnsupportedDtype(AnyArrayDType),
147    /// [`TthreshCodec`] failed to encode the data
148    #[error("Tthresh failed to encode the data")]
149    TthreshEncodeFailed {
150        /// Opaque source error
151        source: TthreshCodingError,
152    },
153    /// [`TthreshCodec`] can only decode one-dimensional byte arrays but received
154    /// an array of a different dtype
155    #[error(
156        "Tthresh can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
157    )]
158    EncodedDataNotBytes {
159        /// The unexpected dtype of the encoded array
160        dtype: AnyArrayDType,
161    },
162    /// [`TthreshCodec`] can only decode one-dimensional byte arrays but received
163    /// an array of a different shape
164    #[error("Tthresh can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
165    EncodedDataNotOneDimensional {
166        /// The unexpected shape of the encoded array
167        shape: Vec<usize>,
168    },
169    /// [`TthreshCodec`] failed to decode the data
170    #[error("Tthresh failed to decode the data")]
171    TthreshDecodeFailed {
172        /// Opaque source error
173        source: TthreshCodingError,
174    },
175    /// [`TthreshCodec`] decoded an invalid array shape header which does not fit
176    /// the decoded data
177    #[error("Tthresh decoded an invalid array shape header which does not fit the decoded data")]
178    DecodeInvalidShapeHeader {
179        /// Source error
180        #[from]
181        source: ShapeError,
182    },
183    /// [`TthreshCodec`] cannot decode into the provided array
184    #[error("Tthresh cannot decode into the provided array")]
185    MismatchedDecodeIntoArray {
186        /// The source of the error
187        #[from]
188        source: AnyArrayAssignError,
189    },
190}
191
192#[derive(Debug, Error)]
193#[error(transparent)]
194/// Opaque error for when encoding or decoding with tthresh fails
195pub struct TthreshCodingError(tthresh::Error);
196
197#[expect(clippy::needless_pass_by_value)]
198/// Compresses the input `data` array using tthresh with the provided
199/// `error_bound`.
200///
201/// # Errors
202///
203/// Errors with
204/// - [`TthreshCodecError::TthreshEncodeFailed`] if encoding failed with an opaque error
205pub fn compress<T: TthreshElement, S: Data<Elem = T>, D: Dimension>(
206    data: ArrayBase<S, D>,
207    error_bound: &TthreshErrorBound,
208) -> Result<Vec<u8>, TthreshCodecError> {
209    #[expect(clippy::option_if_let_else)]
210    let data_cow = if let Some(data) = data.as_slice() {
211        Cow::Borrowed(data)
212    } else {
213        Cow::Owned(data.iter().copied().collect())
214    };
215
216    let compressed = tthresh::compress(
217        &data_cow,
218        data.shape(),
219        match error_bound {
220            TthreshErrorBound::Eps { eps } => tthresh::ErrorBound::Eps(eps.0),
221            TthreshErrorBound::RMSE { rmse } => tthresh::ErrorBound::RMSE(rmse.0),
222            TthreshErrorBound::PSNR { psnr } => tthresh::ErrorBound::PSNR(psnr.0),
223        },
224        false,
225        false,
226    )
227    .map_err(|err| TthreshCodecError::TthreshEncodeFailed {
228        source: TthreshCodingError(err),
229    })?;
230
231    Ok(compressed)
232}
233
234/// Decompresses the `encoded` data into an array.
235///
236/// # Errors
237///
238/// Errors with
239/// - [`TthreshCodecError::TthreshDecodeFailed`] if decoding failed with an opaque error
240pub fn decompress(encoded: &[u8]) -> Result<AnyArray, TthreshCodecError> {
241    let (decompressed, shape) = tthresh::decompress(encoded, false, false).map_err(|err| {
242        TthreshCodecError::TthreshDecodeFailed {
243            source: TthreshCodingError(err),
244        }
245    })?;
246
247    let decoded = match decompressed {
248        tthresh::Buffer::U8(decompressed) => {
249            AnyArray::U8(Array::from_shape_vec(shape, decompressed)?.into_dyn())
250        }
251        tthresh::Buffer::U16(decompressed) => {
252            AnyArray::U16(Array::from_shape_vec(shape, decompressed)?.into_dyn())
253        }
254        tthresh::Buffer::I32(decompressed) => {
255            AnyArray::I32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
256        }
257        tthresh::Buffer::F32(decompressed) => {
258            AnyArray::F32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
259        }
260        tthresh::Buffer::F64(decompressed) => {
261            AnyArray::F64(Array::from_shape_vec(shape, decompressed)?.into_dyn())
262        }
263    };
264
265    Ok(decoded)
266}
267
268/// Array element types which can be compressed with tthresh.
269pub trait TthreshElement: Copy + tthresh::Element {}
270
271impl TthreshElement for u8 {}
272impl TthreshElement for u16 {}
273impl TthreshElement for i32 {}
274impl TthreshElement for f32 {}
275impl TthreshElement for f64 {}
276
277#[expect(clippy::derive_partial_eq_without_eq)] // floats are not Eq
278#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Hash)]
279/// Non-negative floating point number
280pub struct NonNegative<T: Float>(T);
281
282impl Serialize for NonNegative<f64> {
283    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
284        serializer.serialize_f64(self.0)
285    }
286}
287
288impl<'de> Deserialize<'de> for NonNegative<f64> {
289    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
290        let x = f64::deserialize(deserializer)?;
291
292        if x >= 0.0 {
293            Ok(Self(x))
294        } else {
295            Err(serde::de::Error::invalid_value(
296                serde::de::Unexpected::Float(x),
297                &"a non-negative value",
298            ))
299        }
300    }
301}
302
303impl JsonSchema for NonNegative<f64> {
304    fn schema_name() -> Cow<'static, str> {
305        Cow::Borrowed("NonNegativeF64")
306    }
307
308    fn schema_id() -> Cow<'static, str> {
309        Cow::Borrowed(concat!(module_path!(), "::", "NonNegative<f64>"))
310    }
311
312    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
313        json_schema!({
314            "type": "number",
315            "minimum": 0.0
316        })
317    }
318}