1use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Zip};
21use num_traits::{Float, Signed};
22use numcodecs::{
23 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
24 Codec, StaticCodec, StaticCodecConfig,
25};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use thiserror::Error;
29
30#[derive(Clone, Serialize, Deserialize, JsonSchema)]
31#[serde(deny_unknown_fields)]
32pub struct AsinhCodec {
47 linear_width: f64,
50}
51
52impl Codec for AsinhCodec {
53 type Error = AsinhCodecError;
54
55 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
56 match data {
57 #[expect(clippy::cast_possible_truncation)]
58 AnyCowArray::F32(data) => Ok(AnyArray::F32(asinh(data, self.linear_width as f32)?)),
59 AnyCowArray::F64(data) => Ok(AnyArray::F64(asinh(data, self.linear_width)?)),
60 encoded => Err(AsinhCodecError::UnsupportedDtype(encoded.dtype())),
61 }
62 }
63
64 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
65 match encoded {
66 #[expect(clippy::cast_possible_truncation)]
67 AnyCowArray::F32(encoded) => {
68 Ok(AnyArray::F32(sinh(encoded, self.linear_width as f32)?))
69 }
70 AnyCowArray::F64(encoded) => Ok(AnyArray::F64(sinh(encoded, self.linear_width)?)),
71 encoded => Err(AsinhCodecError::UnsupportedDtype(encoded.dtype())),
72 }
73 }
74
75 fn decode_into(
76 &self,
77 encoded: AnyArrayView,
78 decoded: AnyArrayViewMut,
79 ) -> Result<(), Self::Error> {
80 match (encoded, decoded) {
81 #[expect(clippy::cast_possible_truncation)]
82 (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
83 sinh_into(encoded, decoded, self.linear_width as f32)
84 }
85 (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
86 sinh_into(encoded, decoded, self.linear_width)
87 }
88 (encoded @ (AnyArrayView::F32(_) | AnyArrayView::F64(_)), decoded) => {
89 Err(AsinhCodecError::MismatchedDecodeIntoArray {
90 source: AnyArrayAssignError::DTypeMismatch {
91 src: encoded.dtype(),
92 dst: decoded.dtype(),
93 },
94 })
95 }
96 (encoded, _decoded) => Err(AsinhCodecError::UnsupportedDtype(encoded.dtype())),
97 }
98 }
99}
100
101impl StaticCodec for AsinhCodec {
102 const CODEC_ID: &'static str = "asinh";
103
104 type Config<'de> = Self;
105
106 fn from_config(config: Self::Config<'_>) -> Self {
107 config
108 }
109
110 fn get_config(&self) -> StaticCodecConfig<Self> {
111 StaticCodecConfig::from(self)
112 }
113}
114
/// Errors that may occur when applying the [`AsinhCodec`].
#[derive(Debug, Error)]
pub enum AsinhCodecError {
    /// [`AsinhCodec`] does not support the dtype; only `f32` and `f64` data
    /// is accepted.
    #[error("Asinh does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`AsinhCodec`] does not support non-finite (infinite or NaN) floating
    /// point data
    #[error("Asinh does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// [`AsinhCodec`] cannot decode into the provided array, because its dtype
    /// or shape does not match the encoded data
    #[error("Asinh cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
133
134pub fn asinh<T: Float + Signed, S: Data<Elem = T>, D: Dimension>(
143 data: ArrayBase<S, D>,
144 linear_width: T,
145) -> Result<Array<T, D>, AsinhCodecError> {
146 if !Zip::from(&data).all(|x| x.is_finite()) {
147 return Err(AsinhCodecError::NonFiniteData);
148 }
149
150 let mut data = data.into_owned();
151 data.mapv_inplace(|x| (x / linear_width).asinh() * linear_width);
152
153 Ok(data)
154}
155
156pub fn sinh<T: Float, S: Data<Elem = T>, D: Dimension>(
165 data: ArrayBase<S, D>,
166 linear_width: T,
167) -> Result<Array<T, D>, AsinhCodecError> {
168 if !Zip::from(&data).all(|x| x.is_finite()) {
169 return Err(AsinhCodecError::NonFiniteData);
170 }
171
172 let mut data = data.into_owned();
173 data.mapv_inplace(|x| (x / linear_width).sinh() * linear_width);
174
175 Ok(data)
176}
177
178#[expect(clippy::needless_pass_by_value)]
179pub fn sinh_into<T: Float, D: Dimension>(
190 data: ArrayView<T, D>,
191 mut out: ArrayViewMut<T, D>,
192 linear_width: T,
193) -> Result<(), AsinhCodecError> {
194 if data.shape() != out.shape() {
195 return Err(AsinhCodecError::MismatchedDecodeIntoArray {
196 source: AnyArrayAssignError::ShapeMismatch {
197 src: data.shape().to_vec(),
198 dst: out.shape().to_vec(),
199 },
200 });
201 }
202
203 if !Zip::from(&data).all(|x| x.is_finite()) {
204 return Err(AsinhCodecError::NonFiniteData);
205 }
206
207 for (d, o) in data.iter().zip(out.iter_mut()) {
209 *o = ((*d) / linear_width).sinh() * linear_width;
210 }
211
212 Ok(())
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip() -> Result<(), AsinhCodecError> {
        let data = (-1000..1000).map(f64::from).collect::<Vec<_>>();
        let data = Array::from_vec(data);

        let encoded = asinh(data.view(), 1.0)?;

        for (r, e) in data.iter().zip(encoded.iter()) {
            assert_eq!((*r).asinh().to_bits(), (*e).to_bits());
        }

        let decoded = sinh(encoded, 1.0)?;

        for (r, d) in data.iter().zip(decoded.iter()) {
            assert!(((*r) - (*d)).abs() < 1e-12);
        }

        Ok(())
    }

    #[test]
    fn roundtrip_widths() -> Result<(), AsinhCodecError> {
        let data = (-1000..1000).map(f64::from).collect::<Vec<_>>();
        let data = Array::from_vec(data);

        for linear_width in [-100.0, -10.0, -1.0, -0.1, 0.1, 1.0, 10.0, 100.0] {
            let encoded = asinh(data.view(), linear_width)?;
            let decoded = sinh(encoded, linear_width)?;

            for (r, d) in data.iter().zip(decoded.iter()) {
                assert!(((*r) - (*d)).abs() < 1e-12);
            }
        }

        Ok(())
    }

    /// `sinh_into` must produce bit-identical results to `sinh`.
    #[test]
    fn sinh_into_matches_sinh() -> Result<(), AsinhCodecError> {
        let data = Array::from_vec((-1000..1000).map(f64::from).collect::<Vec<_>>());

        for linear_width in [-10.0, -1.0, 0.1, 1.0, 100.0] {
            let encoded = asinh(data.view(), linear_width)?;
            let expected = sinh(encoded.view(), linear_width)?;

            let mut decoded = Array::from_elem(encoded.len(), 0.0);
            sinh_into(encoded.view(), decoded.view_mut(), linear_width)?;

            for (e, d) in expected.iter().zip(decoded.iter()) {
                assert_eq!(e.to_bits(), d.to_bits());
            }
        }

        Ok(())
    }

    /// NaN and infinities are rejected by every entry point, and `sinh_into`
    /// leaves its output untouched when it does so.
    #[test]
    fn rejects_non_finite() {
        for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let data = Array::from_vec(vec![1.0, bad, -1.0]);

            assert!(matches!(
                asinh(data.view(), 1.0),
                Err(AsinhCodecError::NonFiniteData)
            ));
            assert!(matches!(
                sinh(data.view(), 1.0),
                Err(AsinhCodecError::NonFiniteData)
            ));

            let mut out = Array::from_elem(data.len(), 0.0);
            assert!(matches!(
                sinh_into(data.view(), out.view_mut(), 1.0),
                Err(AsinhCodecError::NonFiniteData)
            ));
            assert!(out.iter().all(|x| x.to_bits() == 0));
        }
    }

    /// `sinh_into` reports a shape mismatch instead of writing a partial result.
    #[test]
    fn sinh_into_rejects_shape_mismatch() {
        let data = Array::from_vec(vec![0.0, 1.0]);
        let mut out = Array::from_elem(3, 0.0);

        assert!(matches!(
            sinh_into(data.view(), out.view_mut(), 1.0),
            Err(AsinhCodecError::MismatchedDecodeIntoArray {
                source: AnyArrayAssignError::ShapeMismatch { .. }
            })
        ));
        assert!(out.iter().all(|x| x.to_bits() == 0));
    }
}