1use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Zip};
21use num_traits::{Float, Signed};
22use numcodecs::{
23 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
24 Codec, StaticCodec, StaticCodecConfig,
25};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use thiserror::Error;
29
/// Codec which computes `c = ln(x)` on encoding and `d = exp(c)` on decoding.
///
/// Only `f32` and `f64` arrays are supported, and the codec is meant for
/// finite positive floating point data only.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct LogCodec {
    // empty: the codec has no configuration parameters
}
39
40impl Codec for LogCodec {
41 type Error = LogCodecError;
42
43 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
44 match data {
45 AnyCowArray::F32(data) => Ok(AnyArray::F32(ln(data)?)),
46 AnyCowArray::F64(data) => Ok(AnyArray::F64(ln(data)?)),
47 encoded => Err(LogCodecError::UnsupportedDtype(encoded.dtype())),
48 }
49 }
50
51 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
52 match encoded {
53 AnyCowArray::F32(encoded) => Ok(AnyArray::F32(exp(encoded)?)),
54 AnyCowArray::F64(encoded) => Ok(AnyArray::F64(exp(encoded)?)),
55 encoded => Err(LogCodecError::UnsupportedDtype(encoded.dtype())),
56 }
57 }
58
59 fn decode_into(
60 &self,
61 encoded: AnyArrayView,
62 decoded: AnyArrayViewMut,
63 ) -> Result<(), Self::Error> {
64 match (encoded, decoded) {
65 (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
66 exp_into(encoded, decoded)
67 }
68 (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
69 exp_into(encoded, decoded)
70 }
71 (encoded @ (AnyArrayView::F32(_) | AnyArrayView::F64(_)), decoded) => {
72 Err(LogCodecError::MismatchedDecodeIntoArray {
73 source: AnyArrayAssignError::DTypeMismatch {
74 src: encoded.dtype(),
75 dst: decoded.dtype(),
76 },
77 })
78 }
79 (encoded, _decoded) => Err(LogCodecError::UnsupportedDtype(encoded.dtype())),
80 }
81 }
82}
83
impl StaticCodec for LogCodec {
    // Identifier string for this codec.
    const CODEC_ID: &'static str = "log";

    // The codec struct itself serves as its configuration type.
    type Config<'de> = Self;

    // Since the configuration *is* the codec, it is returned unchanged.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    // Wraps a reference to this codec (its own configuration) into a
    // `StaticCodecConfig`.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
97
/// Errors that may occur when applying the [`LogCodec`].
#[derive(Debug, Error)]
pub enum LogCodecError {
    /// [`LogCodec`] does not support the given dtype (only `f32` and `f64`
    /// are accepted).
    #[error("Log does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`LogCodec`] does not support non-positive (negative or zero) floating
    /// point data.
    #[error("Log does not support non-positive (negative or zero) floating point data")]
    NonPositiveData,
    /// [`LogCodec`] does not support non-finite (infinite or NaN) floating
    /// point data.
    #[error("Log does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// [`LogCodec`] cannot decode into the provided array, e.g. because its
    /// dtype or shape does not match the encoded data.
    #[error("Log cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The underlying assignment error (also convertible via `From`).
        #[from]
        source: AnyArrayAssignError,
    },
}
120
121pub fn ln<T: Float + Signed, S: Data<Elem = T>, D: Dimension>(
131 data: ArrayBase<S, D>,
132) -> Result<Array<T, D>, LogCodecError> {
133 if !Zip::from(&data).all(T::is_positive) {
134 return Err(LogCodecError::NonPositiveData);
135 }
136
137 if !Zip::from(&data).all(|x| x.is_finite()) {
138 return Err(LogCodecError::NonFiniteData);
139 }
140
141 let mut data = data.into_owned();
142 data.mapv_inplace(T::ln);
143
144 Ok(data)
145}
146
147pub fn exp<T: Float, S: Data<Elem = T>, D: Dimension>(
155 data: ArrayBase<S, D>,
156) -> Result<Array<T, D>, LogCodecError> {
157 if !Zip::from(&data).all(|x| x.is_finite()) {
158 return Err(LogCodecError::NonFiniteData);
159 }
160
161 let mut data = data.into_owned();
162 data.mapv_inplace(T::exp);
163
164 Ok(data)
165}
166
167#[expect(clippy::needless_pass_by_value)]
168pub fn exp_into<T: Float, D: Dimension>(
179 data: ArrayView<T, D>,
180 mut out: ArrayViewMut<T, D>,
181) -> Result<(), LogCodecError> {
182 if data.shape() != out.shape() {
183 return Err(LogCodecError::MismatchedDecodeIntoArray {
184 source: AnyArrayAssignError::ShapeMismatch {
185 src: data.shape().to_vec(),
186 dst: out.shape().to_vec(),
187 },
188 });
189 }
190
191 if !Zip::from(&data).all(|x| x.is_finite()) {
192 return Err(LogCodecError::NonFiniteData);
193 }
194
195 for (d, o) in data.iter().zip(out.iter_mut()) {
197 *o = T::exp(*d);
198 }
199
200 Ok(())
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip() -> Result<(), LogCodecError> {
        let data = Array::from_iter((1..1000).map(f64::from));

        let encoded = ln(data.view())?;

        // Encoding must be bit-identical to the scalar natural logarithm.
        for (original, logged) in data.iter().zip(&encoded) {
            assert_eq!(original.ln().to_bits(), logged.to_bits());
        }

        let decoded = exp(encoded)?;

        // Decoding may only deviate by floating point rounding.
        for (original, restored) in data.iter().zip(&decoded) {
            assert!((original - restored).abs() < 1e-12);
        }

        Ok(())
    }
}