1use std::borrow::Cow;
21
22use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
23use num_traits::Float;
24use numcodecs::{
25 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
27};
28use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
29use serde::{Deserialize, Deserializer, Serialize, Serializer};
30use thiserror::Error;
31
32#[cfg(test)]
33use ::serde_json as _;
34
35#[derive(Clone, Serialize, Deserialize, JsonSchema)]
36#[schemars(deny_unknown_fields)]
38pub struct TthreshCodec {
40 #[serde(flatten)]
42 pub error_bound: TthreshErrorBound,
43 #[serde(default, rename = "_version")]
45 pub version: StaticCodecVersion<0, 1, 0>,
46}
47
48#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
50#[serde(tag = "eb_mode")]
51#[serde(deny_unknown_fields)]
52pub enum TthreshErrorBound {
53 #[serde(rename = "eps")]
55 Eps {
56 #[serde(rename = "eb_eps")]
58 eps: NonNegative<f64>,
59 },
60 #[serde(rename = "rmse")]
62 RMSE {
63 #[serde(rename = "eb_rmse")]
65 rmse: NonNegative<f64>,
66 },
67 #[serde(rename = "psnr")]
69 PSNR {
70 #[serde(rename = "eb_psnr")]
72 psnr: NonNegative<f64>,
73 },
74}
75
76impl Codec for TthreshCodec {
77 type Error = TthreshCodecError;
78
79 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
80 match data {
81 AnyCowArray::U8(data) => Ok(AnyArray::U8(
82 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
83 )),
84 AnyCowArray::U16(data) => Ok(AnyArray::U8(
85 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
86 )),
87 AnyCowArray::I32(data) => Ok(AnyArray::U8(
88 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
89 )),
90 AnyCowArray::F32(data) => Ok(AnyArray::U8(
91 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
92 )),
93 AnyCowArray::F64(data) => Ok(AnyArray::U8(
94 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
95 )),
96 encoded => Err(TthreshCodecError::UnsupportedDtype(encoded.dtype())),
97 }
98 }
99
100 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
101 let AnyCowArray::U8(encoded) = encoded else {
102 return Err(TthreshCodecError::EncodedDataNotBytes {
103 dtype: encoded.dtype(),
104 });
105 };
106
107 if !matches!(encoded.shape(), [_]) {
108 return Err(TthreshCodecError::EncodedDataNotOneDimensional {
109 shape: encoded.shape().to_vec(),
110 });
111 }
112
113 decompress(&AnyCowArray::U8(encoded).as_bytes())
114 }
115
116 fn decode_into(
117 &self,
118 encoded: AnyArrayView,
119 mut decoded: AnyArrayViewMut,
120 ) -> Result<(), Self::Error> {
121 let decoded_in = self.decode(encoded.cow())?;
122
123 Ok(decoded.assign(&decoded_in)?)
124 }
125}
126
127impl StaticCodec for TthreshCodec {
128 const CODEC_ID: &'static str = "tthresh.rs";
129
130 type Config<'de> = Self;
131
132 fn from_config(config: Self::Config<'_>) -> Self {
133 config
134 }
135
136 fn get_config(&self) -> StaticCodecConfig<Self> {
137 StaticCodecConfig::from(self)
138 }
139}
140
141#[derive(Debug, Error)]
142pub enum TthreshCodecError {
144 #[error("Tthresh does not support the dtype {0}")]
146 UnsupportedDtype(AnyArrayDType),
147 #[error("Tthresh failed to encode the data")]
149 TthreshEncodeFailed {
150 source: TthreshCodingError,
152 },
153 #[error(
156 "Tthresh can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
157 )]
158 EncodedDataNotBytes {
159 dtype: AnyArrayDType,
161 },
162 #[error(
165 "Tthresh can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
166 )]
167 EncodedDataNotOneDimensional {
168 shape: Vec<usize>,
170 },
171 #[error("Tthresh failed to decode the data")]
173 TthreshDecodeFailed {
174 source: TthreshCodingError,
176 },
177 #[error("Tthresh decoded an invalid array shape header which does not fit the decoded data")]
180 DecodeInvalidShapeHeader {
181 #[from]
183 source: ShapeError,
184 },
185 #[error("Tthresh cannot decode into the provided array")]
187 MismatchedDecodeIntoArray {
188 #[from]
190 source: AnyArrayAssignError,
191 },
192}
193
194#[derive(Debug, Error)]
195#[error(transparent)]
196pub struct TthreshCodingError(tthresh::Error);
198
199#[expect(clippy::needless_pass_by_value)]
200pub fn compress<T: TthreshElement, S: Data<Elem = T>, D: Dimension>(
208 data: ArrayBase<S, D>,
209 error_bound: &TthreshErrorBound,
210) -> Result<Vec<u8>, TthreshCodecError> {
211 #[expect(clippy::option_if_let_else)]
212 let data_cow = match data.as_slice() {
213 Some(data) => Cow::Borrowed(data),
214 None => Cow::Owned(data.iter().copied().collect()),
215 };
216
217 let compressed = tthresh::compress(
218 &data_cow,
219 data.shape(),
220 match error_bound {
221 TthreshErrorBound::Eps { eps } => tthresh::ErrorBound::Eps(eps.0),
222 TthreshErrorBound::RMSE { rmse } => tthresh::ErrorBound::RMSE(rmse.0),
223 TthreshErrorBound::PSNR { psnr } => tthresh::ErrorBound::PSNR(psnr.0),
224 },
225 false,
226 false,
227 )
228 .map_err(|err| TthreshCodecError::TthreshEncodeFailed {
229 source: TthreshCodingError(err),
230 })?;
231
232 Ok(compressed)
233}
234
235pub fn decompress(encoded: &[u8]) -> Result<AnyArray, TthreshCodecError> {
242 let (decompressed, shape) = tthresh::decompress(encoded, false, false).map_err(|err| {
243 TthreshCodecError::TthreshDecodeFailed {
244 source: TthreshCodingError(err),
245 }
246 })?;
247
248 let decoded = match decompressed {
249 tthresh::Buffer::U8(decompressed) => {
250 AnyArray::U8(Array::from_shape_vec(shape, decompressed)?.into_dyn())
251 }
252 tthresh::Buffer::U16(decompressed) => {
253 AnyArray::U16(Array::from_shape_vec(shape, decompressed)?.into_dyn())
254 }
255 tthresh::Buffer::I32(decompressed) => {
256 AnyArray::I32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
257 }
258 tthresh::Buffer::F32(decompressed) => {
259 AnyArray::F32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
260 }
261 tthresh::Buffer::F64(decompressed) => {
262 AnyArray::F64(Array::from_shape_vec(shape, decompressed)?.into_dyn())
263 }
264 };
265
266 Ok(decoded)
267}
268
269pub trait TthreshElement: Copy + tthresh::Element {}
271
272impl TthreshElement for u8 {}
273impl TthreshElement for u16 {}
274impl TthreshElement for i32 {}
275impl TthreshElement for f32 {}
276impl TthreshElement for f64 {}
277
278#[expect(clippy::derive_partial_eq_without_eq)] #[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Hash)]
280pub struct NonNegative<T: Float>(T);
282
283impl Serialize for NonNegative<f64> {
284 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
285 serializer.serialize_f64(self.0)
286 }
287}
288
289impl<'de> Deserialize<'de> for NonNegative<f64> {
290 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
291 let x = f64::deserialize(deserializer)?;
292
293 if x >= 0.0 {
294 Ok(Self(x))
295 } else {
296 Err(serde::de::Error::invalid_value(
297 serde::de::Unexpected::Float(x),
298 &"a non-negative value",
299 ))
300 }
301 }
302}
303
304impl JsonSchema for NonNegative<f64> {
305 fn schema_name() -> Cow<'static, str> {
306 Cow::Borrowed("NonNegativeF64")
307 }
308
309 fn schema_id() -> Cow<'static, str> {
310 Cow::Borrowed(concat!(module_path!(), "::", "NonNegative<f64>"))
311 }
312
313 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
314 json_schema!({
315 "type": "number",
316 "minimum": 0.0
317 })
318 }
319}