1use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Zip};
21use num_traits::{Float, Signed};
22use numcodecs::{
23 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
24 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
25};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use thiserror::Error;
29
30#[derive(Clone, Serialize, Deserialize, JsonSchema)]
31#[serde(deny_unknown_fields)]
32pub struct AsinhCodec {
47 pub linear_width: f64,
50 #[serde(default, rename = "_version")]
52 pub version: StaticCodecVersion<1, 0, 0>,
53}
54
55impl Codec for AsinhCodec {
56 type Error = AsinhCodecError;
57
58 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
59 match data {
60 #[expect(clippy::cast_possible_truncation)]
61 AnyCowArray::F32(data) => Ok(AnyArray::F32(asinh(data, self.linear_width as f32)?)),
62 AnyCowArray::F64(data) => Ok(AnyArray::F64(asinh(data, self.linear_width)?)),
63 encoded => Err(AsinhCodecError::UnsupportedDtype(encoded.dtype())),
64 }
65 }
66
67 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
68 match encoded {
69 #[expect(clippy::cast_possible_truncation)]
70 AnyCowArray::F32(encoded) => {
71 Ok(AnyArray::F32(sinh(encoded, self.linear_width as f32)?))
72 }
73 AnyCowArray::F64(encoded) => Ok(AnyArray::F64(sinh(encoded, self.linear_width)?)),
74 encoded => Err(AsinhCodecError::UnsupportedDtype(encoded.dtype())),
75 }
76 }
77
78 fn decode_into(
79 &self,
80 encoded: AnyArrayView,
81 decoded: AnyArrayViewMut,
82 ) -> Result<(), Self::Error> {
83 match (encoded, decoded) {
84 #[expect(clippy::cast_possible_truncation)]
85 (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
86 sinh_into(encoded, decoded, self.linear_width as f32)
87 }
88 (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
89 sinh_into(encoded, decoded, self.linear_width)
90 }
91 (encoded @ (AnyArrayView::F32(_) | AnyArrayView::F64(_)), decoded) => {
92 Err(AsinhCodecError::MismatchedDecodeIntoArray {
93 source: AnyArrayAssignError::DTypeMismatch {
94 src: encoded.dtype(),
95 dst: decoded.dtype(),
96 },
97 })
98 }
99 (encoded, _decoded) => Err(AsinhCodecError::UnsupportedDtype(encoded.dtype())),
100 }
101 }
102}
103
104impl StaticCodec for AsinhCodec {
105 const CODEC_ID: &'static str = "asinh.rs";
106
107 type Config<'de> = Self;
108
109 fn from_config(config: Self::Config<'_>) -> Self {
110 config
111 }
112
113 fn get_config(&self) -> StaticCodecConfig<Self> {
114 StaticCodecConfig::from(self)
115 }
116}
117
118#[derive(Debug, Error)]
119pub enum AsinhCodecError {
121 #[error("Asinh does not support the dtype {0}")]
123 UnsupportedDtype(AnyArrayDType),
124 #[error("Asinh does not support non-finite (infinite or NaN) floating point data")]
127 NonFiniteData,
128 #[error("Asinh cannot decode into the provided array")]
130 MismatchedDecodeIntoArray {
131 #[from]
133 source: AnyArrayAssignError,
134 },
135}
136
137pub fn asinh<T: Float + Signed, S: Data<Elem = T>, D: Dimension>(
146 data: ArrayBase<S, D>,
147 linear_width: T,
148) -> Result<Array<T, D>, AsinhCodecError> {
149 if !Zip::from(&data).all(|x| x.is_finite()) {
150 return Err(AsinhCodecError::NonFiniteData);
151 }
152
153 let mut data = data.into_owned();
154 data.mapv_inplace(|x| (x / linear_width).asinh() * linear_width);
155
156 Ok(data)
157}
158
159pub fn sinh<T: Float, S: Data<Elem = T>, D: Dimension>(
168 data: ArrayBase<S, D>,
169 linear_width: T,
170) -> Result<Array<T, D>, AsinhCodecError> {
171 if !Zip::from(&data).all(|x| x.is_finite()) {
172 return Err(AsinhCodecError::NonFiniteData);
173 }
174
175 let mut data = data.into_owned();
176 data.mapv_inplace(|x| (x / linear_width).sinh() * linear_width);
177
178 Ok(data)
179}
180
181#[expect(clippy::needless_pass_by_value)]
182pub fn sinh_into<T: Float, D: Dimension>(
193 data: ArrayView<T, D>,
194 mut out: ArrayViewMut<T, D>,
195 linear_width: T,
196) -> Result<(), AsinhCodecError> {
197 if data.shape() != out.shape() {
198 return Err(AsinhCodecError::MismatchedDecodeIntoArray {
199 source: AnyArrayAssignError::ShapeMismatch {
200 src: data.shape().to_vec(),
201 dst: out.shape().to_vec(),
202 },
203 });
204 }
205
206 if !Zip::from(&data).all(|x| x.is_finite()) {
207 return Err(AsinhCodecError::NonFiniteData);
208 }
209
210 for (d, o) in data.iter().zip(out.iter_mut()) {
212 *o = ((*d) / linear_width).sinh() * linear_width;
213 }
214
215 Ok(())
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip() -> Result<(), AsinhCodecError> {
        let data = (-1000..1000).map(f64::from).collect::<Vec<_>>();
        let data = Array::from_vec(data);

        let encoded = asinh(data.view(), 1.0)?;

        for (r, e) in data.iter().zip(encoded.iter()) {
            assert_eq!((*r).asinh().to_bits(), (*e).to_bits());
        }

        let decoded = sinh(encoded, 1.0)?;

        for (r, d) in data.iter().zip(decoded.iter()) {
            assert!(((*r) - (*d)).abs() < 1e-12);
        }

        Ok(())
    }

    #[test]
    fn roundtrip_widths() -> Result<(), AsinhCodecError> {
        let data = (-1000..1000).map(f64::from).collect::<Vec<_>>();
        let data = Array::from_vec(data);

        for linear_width in [-100.0, -10.0, -1.0, -0.1, 0.1, 1.0, 10.0, 100.0] {
            let encoded = asinh(data.view(), linear_width)?;
            let decoded = sinh(encoded, linear_width)?;

            for (r, d) in data.iter().zip(decoded.iter()) {
                assert!(((*r) - (*d)).abs() < 1e-12);
            }
        }

        Ok(())
    }

    /// `sinh_into` must produce bit-identical results to `sinh`, even when
    /// the output array uses a different memory layout than the input.
    #[test]
    fn decode_into_matches_decode() -> Result<(), Box<dyn std::error::Error>> {
        let data = (-1000..1000).map(f64::from).collect::<Vec<_>>();
        let data = Array::from_shape_vec((20, 100), data)?;

        let encoded = asinh(data.view(), 10.0)?;
        let decoded = sinh(encoded.view(), 10.0)?;

        // Decode into a transposed (column-major) view of a row-major buffer.
        let mut decoded_into_t = Array::<f64, _>::zeros((100, 20));
        sinh_into(encoded.view(), decoded_into_t.view_mut().reversed_axes(), 10.0)?;

        for (d, i) in decoded.iter().zip(decoded_into_t.t().iter()) {
            assert_eq!(d.to_bits(), i.to_bits());
        }

        Ok(())
    }

    #[test]
    fn rejects_non_finite() {
        for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let data = Array::from_vec(vec![1.0, bad, 3.0]);

            assert!(matches!(
                asinh(data.view(), 1.0),
                Err(AsinhCodecError::NonFiniteData)
            ));
            assert!(matches!(
                sinh(data.view(), 1.0),
                Err(AsinhCodecError::NonFiniteData)
            ));

            let mut out = Array::<f64, _>::zeros(data.raw_dim());
            assert!(matches!(
                sinh_into(data.view(), out.view_mut(), 1.0),
                Err(AsinhCodecError::NonFiniteData)
            ));
            // the output must not have been partially written
            assert!(out.iter().all(|x| x.to_bits() == 0));
        }
    }

    #[test]
    fn sinh_into_rejects_shape_mismatch() {
        let encoded = Array::from_vec(vec![0.0_f64; 4]);
        let mut out = Array::<f64, _>::zeros(3);

        assert!(matches!(
            sinh_into(encoded.view(), out.view_mut(), 1.0),
            Err(AsinhCodecError::MismatchedDecodeIntoArray {
                source: AnyArrayAssignError::ShapeMismatch { .. }
            })
        ));
    }
}