1use std::borrow::Cow;
21
22use ndarray::{Array, Array1, ArrayBase, Data, Dimension, ShapeError};
23use num_traits::Float;
24use numcodecs::{
25 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
26 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
27};
28use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
29use serde::{Deserialize, Deserializer, Serialize, Serializer};
30use thiserror::Error;
31
32#[cfg(test)]
33use ::serde_json as _;
34
/// Codec providing lossy compression using tthresh.
///
/// The configuration is the flattened [`TthreshErrorBound`] plus a `_version`
/// field.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
// `deny_unknown_fields` is declared for the JSON schema only: serde does not
// support its `deny_unknown_fields` together with `#[serde(flatten)]`. The
// flattened `TthreshErrorBound` carries serde's `deny_unknown_fields` itself.
#[schemars(deny_unknown_fields)]
pub struct TthreshCodec {
    /// Tthresh error bound, flattened into the codec configuration (serialized
    /// as the `eb_mode` tag plus the matching `eb_*` field).
    #[serde(flatten)]
    pub error_bound: TthreshErrorBound,
    /// The codec's encoding format version, serialized as `_version`.
    ///
    /// Falls back to its `Default` when absent from the configuration. The type
    /// pins the version to 0.1.0.
    #[serde(default, rename = "_version")]
    pub version: StaticCodecVersion<0, 1, 0>,
}
47
/// Error bound mode and value used by tthresh.
///
/// Serialized as an internally tagged object: `eb_mode` selects the variant
/// (`"eps"`, `"rmse"` or `"psnr"`), and the bound value is stored in the
/// matching `eb_eps`, `eb_rmse` or `eb_psnr` field. Unknown fields are rejected.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "eb_mode")]
#[serde(deny_unknown_fields)]
pub enum TthreshErrorBound {
    /// Epsilon error bound, forwarded to tthresh as `tthresh::ErrorBound::Eps`.
    ///
    /// NOTE(review): presumably tthresh's relative (norm-based) error target;
    /// confirm against the tthresh documentation.
    #[serde(rename = "eps")]
    Eps {
        /// Non-negative epsilon bound.
        #[serde(rename = "eb_eps")]
        eps: NonNegative<f64>,
    },
    /// Root-mean-square error bound, forwarded to tthresh as
    /// `tthresh::ErrorBound::RMSE`.
    #[serde(rename = "rmse")]
    RMSE {
        /// Non-negative RMSE bound.
        #[serde(rename = "eb_rmse")]
        rmse: NonNegative<f64>,
    },
    /// Peak signal-to-noise ratio bound, forwarded to tthresh as
    /// `tthresh::ErrorBound::PSNR`.
    #[serde(rename = "psnr")]
    PSNR {
        /// Non-negative PSNR bound.
        #[serde(rename = "eb_psnr")]
        psnr: NonNegative<f64>,
    },
}
75
76impl Codec for TthreshCodec {
77 type Error = TthreshCodecError;
78
79 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
80 match data {
81 AnyCowArray::U8(data) => Ok(AnyArray::U8(
82 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
83 )),
84 AnyCowArray::U16(data) => Ok(AnyArray::U8(
85 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
86 )),
87 AnyCowArray::I32(data) => Ok(AnyArray::U8(
88 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
89 )),
90 AnyCowArray::F32(data) => Ok(AnyArray::U8(
91 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
92 )),
93 AnyCowArray::F64(data) => Ok(AnyArray::U8(
94 Array1::from(compress(data, &self.error_bound)?).into_dyn(),
95 )),
96 encoded => Err(TthreshCodecError::UnsupportedDtype(encoded.dtype())),
97 }
98 }
99
100 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
101 let AnyCowArray::U8(encoded) = encoded else {
102 return Err(TthreshCodecError::EncodedDataNotBytes {
103 dtype: encoded.dtype(),
104 });
105 };
106
107 if !matches!(encoded.shape(), [_]) {
108 return Err(TthreshCodecError::EncodedDataNotOneDimensional {
109 shape: encoded.shape().to_vec(),
110 });
111 }
112
113 decompress(&AnyCowArray::U8(encoded).as_bytes())
114 }
115
116 fn decode_into(
117 &self,
118 encoded: AnyArrayView,
119 mut decoded: AnyArrayViewMut,
120 ) -> Result<(), Self::Error> {
121 let decoded_in = self.decode(encoded.cow())?;
122
123 Ok(decoded.assign(&decoded_in)?)
124 }
125}
126
127impl StaticCodec for TthreshCodec {
128 const CODEC_ID: &'static str = "tthresh.rs";
129
130 type Config<'de> = Self;
131
132 fn from_config(config: Self::Config<'_>) -> Self {
133 config
134 }
135
136 fn get_config(&self) -> StaticCodecConfig<Self> {
137 StaticCodecConfig::from(self)
138 }
139}
140
/// Errors that may occur when applying the [`TthreshCodec`].
#[derive(Debug, Error)]
pub enum TthreshCodecError {
    /// The array passed to `encode` has a dtype other than `u8`, `u16`, `i32`,
    /// `f32` or `f64`.
    #[error("Tthresh does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// Tthresh reported an error while compressing the data.
    #[error("Tthresh failed to encode the data")]
    TthreshEncodeFailed {
        /// Opaque source error (`thiserror` treats a field named `source` as
        /// the error source).
        source: TthreshCodingError,
    },
    /// The array passed to `decode` is not a `u8` array.
    #[error(
        "Tthresh can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array.
        dtype: AnyArrayDType,
    },
    /// The byte array passed to `decode` is not one-dimensional.
    #[error("Tthresh can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array.
        shape: Vec<usize>,
    },
    /// Tthresh reported an error while decompressing the data.
    #[error("Tthresh failed to decode the data")]
    TthreshDecodeFailed {
        /// Opaque source error.
        source: TthreshCodingError,
    },
    /// The shape decoded from the header does not match the number of decoded
    /// elements (raised when building the output array in `decompress`).
    #[error("Tthresh decoded an invalid array shape header which does not fit the decoded data")]
    DecodeInvalidShapeHeader {
        /// Source error
        #[from]
        source: ShapeError,
    },
    /// The decoded array could not be assigned into the array provided to
    /// `decode_into`.
    #[error("Tthresh cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
191
/// Opaque error for when encoding or decoding with tthresh fails.
///
/// Wraps a `tthresh::Error` in a private tuple field. `#[error(transparent)]`
/// forwards `Display` and `source` to the wrapped error.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct TthreshCodingError(tthresh::Error);
196
197#[expect(clippy::needless_pass_by_value)]
198pub fn compress<T: TthreshElement, S: Data<Elem = T>, D: Dimension>(
206 data: ArrayBase<S, D>,
207 error_bound: &TthreshErrorBound,
208) -> Result<Vec<u8>, TthreshCodecError> {
209 #[expect(clippy::option_if_let_else)]
210 let data_cow = if let Some(data) = data.as_slice() {
211 Cow::Borrowed(data)
212 } else {
213 Cow::Owned(data.iter().copied().collect())
214 };
215
216 let compressed = tthresh::compress(
217 &data_cow,
218 data.shape(),
219 match error_bound {
220 TthreshErrorBound::Eps { eps } => tthresh::ErrorBound::Eps(eps.0),
221 TthreshErrorBound::RMSE { rmse } => tthresh::ErrorBound::RMSE(rmse.0),
222 TthreshErrorBound::PSNR { psnr } => tthresh::ErrorBound::PSNR(psnr.0),
223 },
224 false,
225 false,
226 )
227 .map_err(|err| TthreshCodecError::TthreshEncodeFailed {
228 source: TthreshCodingError(err),
229 })?;
230
231 Ok(compressed)
232}
233
234pub fn decompress(encoded: &[u8]) -> Result<AnyArray, TthreshCodecError> {
241 let (decompressed, shape) = tthresh::decompress(encoded, false, false).map_err(|err| {
242 TthreshCodecError::TthreshDecodeFailed {
243 source: TthreshCodingError(err),
244 }
245 })?;
246
247 let decoded = match decompressed {
248 tthresh::Buffer::U8(decompressed) => {
249 AnyArray::U8(Array::from_shape_vec(shape, decompressed)?.into_dyn())
250 }
251 tthresh::Buffer::U16(decompressed) => {
252 AnyArray::U16(Array::from_shape_vec(shape, decompressed)?.into_dyn())
253 }
254 tthresh::Buffer::I32(decompressed) => {
255 AnyArray::I32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
256 }
257 tthresh::Buffer::F32(decompressed) => {
258 AnyArray::F32(Array::from_shape_vec(shape, decompressed)?.into_dyn())
259 }
260 tthresh::Buffer::F64(decompressed) => {
261 AnyArray::F64(Array::from_shape_vec(shape, decompressed)?.into_dyn())
262 }
263 };
264
265 Ok(decoded)
266}
267
/// Array element types that [`compress`] can hand to tthresh.
///
/// Implemented for `u8`, `u16`, `i32`, `f32` and `f64`, the same dtypes that
/// [`TthreshCodec`] accepts when encoding. The `Copy` bound lets `compress`
/// gather a non-contiguous array into a contiguous buffer with `.copied()`.
pub trait TthreshElement: Copy + tthresh::Element {}

impl TthreshElement for u8 {}
impl TthreshElement for u16 {}
impl TthreshElement for i32 {}
impl TthreshElement for f32 {}
impl TthreshElement for f64 {}
276
/// Floating-point value that is checked to be non-negative when deserialized.
///
/// The inner field is private. Deserialization only accepts values for which
/// `x >= 0.0` holds, which also rejects NaN.
// Floating-point types do not implement `Eq`, so `PartialEq` is derived alone.
#[expect(clippy::derive_partial_eq_without_eq)]
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd, Hash)]
pub struct NonNegative<T: Float>(T);
281
282impl Serialize for NonNegative<f64> {
283 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
284 serializer.serialize_f64(self.0)
285 }
286}
287
288impl<'de> Deserialize<'de> for NonNegative<f64> {
289 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
290 let x = f64::deserialize(deserializer)?;
291
292 if x >= 0.0 {
293 Ok(Self(x))
294 } else {
295 Err(serde::de::Error::invalid_value(
296 serde::de::Unexpected::Float(x),
297 &"a non-negative value",
298 ))
299 }
300 }
301}
302
303impl JsonSchema for NonNegative<f64> {
304 fn schema_name() -> Cow<'static, str> {
305 Cow::Borrowed("NonNegativeF64")
306 }
307
308 fn schema_id() -> Cow<'static, str> {
309 Cow::Borrowed(concat!(module_path!(), "::", "NonNegative<f64>"))
310 }
311
312 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
313 json_schema!({
314 "type": "number",
315 "minimum": 0.0
316 })
317 }
318}