1use std::borrow::Cow;
21
22use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
23use num_traits::Float;
24use numcodecs::{
25 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26 Codec, StaticCodec, StaticCodecConfig,
27};
28use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
29use serde::{Deserialize, Deserializer, Serialize, Serializer};
30use thiserror::Error;
31
32#[cfg(test)]
33use ::serde_json as _;
34
35#[derive(Clone, Serialize, Deserialize, JsonSchema)]
36#[schemars(deny_unknown_fields)]
38pub struct TthreshCodec {
40 #[serde(flatten)]
42 pub error_bound: TthreshErrorBound,
43}
44
45#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
47#[serde(tag = "eb_mode")]
48#[serde(deny_unknown_fields)]
49pub enum TthreshErrorBound {
50 #[serde(rename = "eps")]
52 Eps {
53 #[serde(rename = "eb_eps")]
55 eps: NonNegative<f64>,
56 },
57 #[serde(rename = "rmse")]
59 RMSE {
60 #[serde(rename = "eb_rmse")]
62 rmse: NonNegative<f64>,
63 },
64 #[serde(rename = "psnr")]
66 PSNR {
67 #[serde(rename = "eb_psnr")]
69 psnr: NonNegative<f64>,
70 },
71}
72
73impl Codec for TthreshCodec {
74 type Error = TthreshCodecError;
75
76 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
77 match data {
78 AnyCowArray::U8(data) => Ok(AnyArray::U8(
79 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
80 )),
81 AnyCowArray::U16(data) => Ok(AnyArray::U8(
82 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
83 )),
84 AnyCowArray::I32(data) => Ok(AnyArray::U8(
85 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
86 )),
87 AnyCowArray::F32(data) => Ok(AnyArray::U8(
88 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
89 )),
90 AnyCowArray::F64(data) => Ok(AnyArray::U8(
91 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
92 )),
93 encoded => Err(TthreshCodecError::UnsupportedDtype(encoded.dtype())),
94 }
95 }
96
97 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
98 let AnyCowArray::U8(encoded) = encoded else {
99 return Err(TthreshCodecError::EncodedDataNotBytes {
100 dtype: encoded.dtype(),
101 });
102 };
103
104 if !matches!(encoded.shape(), [_]) {
105 return Err(TthreshCodecError::EncodedDataNotOneDimensional {
106 shape: encoded.shape().to_vec(),
107 });
108 }
109
110 decompress(&AnyCowArray::U8(encoded).as_bytes())
111 }
112
113 fn decode_into(
114 &self,
115 encoded: AnyArrayView,
116 mut decoded: AnyArrayViewMut,
117 ) -> Result<(), Self::Error> {
118 let decoded_in = self.decode(encoded.cow())?;
119
120 Ok(decoded.assign(&decoded_in)?)
121 }
122}
123
124impl StaticCodec for TthreshCodec {
125 const CODEC_ID: &'static str = "tthresh";
126
127 type Config<'de> = Self;
128
129 fn from_config(config: Self::Config<'_>) -> Self {
130 config
131 }
132
133 fn get_config(&self) -> StaticCodecConfig<Self> {
134 StaticCodecConfig::from(self)
135 }
136}
137
/// Errors that may occur when applying the [`TthreshCodec`].
#[derive(Debug, Error)]
pub enum TthreshCodecError {
    /// The array passed to `encode` has a dtype other than
    /// `u8`, `u16`, `i32`, `f32`, or `f64`.
    #[error("Tthresh does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// `tthresh::compress` returned an error.
    #[error("Tthresh failed to encode the data")]
    TthreshEncodeFailed {
        /// Opaque underlying TTHRESH error (exposed as the error `source`).
        source: TthreshCodingError,
    },
    /// The encoded array passed to `decode` is not a `u8` array.
    #[error(
        "Tthresh can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array.
        dtype: AnyArrayDType,
    },
    /// The encoded `u8` array passed to `decode` is not one-dimensional.
    #[error("Tthresh can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array.
        shape: Vec<usize>,
    },
    /// `tthresh::decompress` returned an error.
    #[error("Tthresh failed to decode the data")]
    TthreshDecodeFailed {
        /// Opaque underlying TTHRESH error (exposed as the error `source`).
        source: TthreshCodingError,
    },
    /// The shape returned by TTHRESH does not match the number of decoded
    /// elements (raised when building the output array in `decompress`).
    #[error("Tthresh decoded an invalid array shape header which does not fit the decoded data")]
    DecodeInvalidShapeHeader {
        /// The ndarray shape error; `#[from]` lets `?` convert it.
        #[from]
        source: ShapeError,
    },
    /// `decode_into` could not copy the decoded data into the provided
    /// output array (e.g. mismatched dtype or shape).
    #[error("Tthresh cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The assignment error; `#[from]` lets `?` convert it.
        #[from]
        source: AnyArrayAssignError,
    },
}
188
/// Opaque error for when encoding or decoding with TTHRESH fails.
///
/// Wraps the underlying `tthresh::Error` without exposing it publicly. Because
/// of `#[error(transparent)]`, both `Display` and `source()` are forwarded to
/// the wrapped error.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct TthreshCodingError(tthresh::Error);
193
194#[expect(clippy::needless_pass_by_value)]
195pub fn compress<T: TthreshElement, S: Data<Elem = T>, D: Dimension>(
203 data: ArrayBase<S, D>,
204 error_bound: &TthreshErrorBound,
205) -> Result<Vec<u8>, TthreshCodecError> {
206 #[expect(clippy::option_if_let_else)]
207 let data_cow = if let Some(data) = data.as_slice() {
208 Cow::Borrowed(data)
209 } else {
210 Cow::Owned(data.iter().copied().collect())
211 };
212
213 let compressed = tthresh::compress(
214 &data_cow,
215 data.shape(),
216 match error_bound {
217 TthreshErrorBound::Eps { eps } => tthresh::ErrorBound::Eps(eps.0),
218 TthreshErrorBound::RMSE { rmse } => tthresh::ErrorBound::RMSE(rmse.0),
219 TthreshErrorBound::PSNR { psnr } => tthresh::ErrorBound::PSNR(psnr.0),
220 },
221 false,
222 false,
223 )
224 .map_err(|err| TthreshCodecError::TthreshEncodeFailed {
225 source: TthreshCodingError(err),
226 })?;
227
228 Ok(compressed)
229}
230
231pub fn decompress(encoded: &[u8]) -> Result<AnyArray, TthreshCodecError> {
238 let (decompressed, shape) = tthresh::decompress(encoded, false, false).map_err(|err| {
239 TthreshCodecError::TthreshDecodeFailed {
240 source: TthreshCodingError(err),
241 }
242 })?;
243
244 let decoded = match decompressed {
245 tthresh::Buffer::U8(decompressed) => {
246 AnyArray::U8(Array::from_shape_vec(shape, decompressed)?.into_dyn())
247 }
248 tthresh::Buffer::U16(decompressed) => {
249 AnyArray::U16(Array::from_shape_vec(shape, decompressed)?.into_dyn())
250 }
251 tthresh::Buffer::I32(decompressed) => {
252 AnyArray::I32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
253 }
254 tthresh::Buffer::F32(decompressed) => {
255 AnyArray::F32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
256 }
257 tthresh::Buffer::F64(decompressed) => {
258 AnyArray::F64(Array::from_shape_vec(shape, decompressed)?.into_dyn())
259 }
260 };
261
262 Ok(decoded)
263}
264
265pub trait TthreshElement: Copy + tthresh::Element {}
267
268impl TthreshElement for u8 {}
269impl TthreshElement for u16 {}
270impl TthreshElement for i32 {}
271impl TthreshElement for f32 {}
272impl TthreshElement for f64 {}
273
/// Floating-point wrapper for a value that is meant to be non-negative.
///
/// For `NonNegative<f64>`, deserialization rejects negative values and NaN,
/// and the JSON schema declares `"minimum": 0.0`. The inner field is private
/// to this module.
// NOTE(review): presumably `Eq` is deliberately not derived because the float
// instantiations can never be `Eq`.
#[expect(clippy::derive_partial_eq_without_eq)]
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Hash)]
pub struct NonNegative<T: Float>(T);
278
279impl Serialize for NonNegative<f64> {
280 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
281 serializer.serialize_f64(self.0)
282 }
283}
284
285impl<'de> Deserialize<'de> for NonNegative<f64> {
286 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
287 let x = f64::deserialize(deserializer)?;
288
289 if x >= 0.0 {
290 Ok(Self(x))
291 } else {
292 Err(serde::de::Error::invalid_value(
293 serde::de::Unexpected::Float(x),
294 &"a non-negative value",
295 ))
296 }
297 }
298}
299
300impl JsonSchema for NonNegative<f64> {
301 fn schema_name() -> Cow<'static, str> {
302 Cow::Borrowed("NonNegativeF64")
303 }
304
305 fn schema_id() -> Cow<'static, str> {
306 Cow::Borrowed(concat!(module_path!(), "::", "NonNegative<f64>"))
307 }
308
309 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
310 json_schema!({
311 "type": "number",
312 "minimum": 0.0
313 })
314 }
315}