1use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Zip};
21use num_traits::{Float, Signed};
22use numcodecs::{
23 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
24 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
25};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use thiserror::Error;
29
30#[derive(Clone, Serialize, Deserialize, JsonSchema)]
31#[serde(deny_unknown_fields)]
32pub struct LogCodec {
37 #[serde(default, rename = "_version")]
39 pub version: StaticCodecVersion<1, 0, 0>,
40}
41
42impl Codec for LogCodec {
43 type Error = LogCodecError;
44
45 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
46 match data {
47 AnyCowArray::F32(data) => Ok(AnyArray::F32(ln(data)?)),
48 AnyCowArray::F64(data) => Ok(AnyArray::F64(ln(data)?)),
49 encoded => Err(LogCodecError::UnsupportedDtype(encoded.dtype())),
50 }
51 }
52
53 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
54 match encoded {
55 AnyCowArray::F32(encoded) => Ok(AnyArray::F32(exp(encoded)?)),
56 AnyCowArray::F64(encoded) => Ok(AnyArray::F64(exp(encoded)?)),
57 encoded => Err(LogCodecError::UnsupportedDtype(encoded.dtype())),
58 }
59 }
60
61 fn decode_into(
62 &self,
63 encoded: AnyArrayView,
64 decoded: AnyArrayViewMut,
65 ) -> Result<(), Self::Error> {
66 match (encoded, decoded) {
67 (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
68 exp_into(encoded, decoded)
69 }
70 (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
71 exp_into(encoded, decoded)
72 }
73 (encoded @ (AnyArrayView::F32(_) | AnyArrayView::F64(_)), decoded) => {
74 Err(LogCodecError::MismatchedDecodeIntoArray {
75 source: AnyArrayAssignError::DTypeMismatch {
76 src: encoded.dtype(),
77 dst: decoded.dtype(),
78 },
79 })
80 }
81 (encoded, _decoded) => Err(LogCodecError::UnsupportedDtype(encoded.dtype())),
82 }
83 }
84}
85
86impl StaticCodec for LogCodec {
87 const CODEC_ID: &'static str = "log.rs";
88
89 type Config<'de> = Self;
90
91 fn from_config(config: Self::Config<'_>) -> Self {
92 config
93 }
94
95 fn get_config(&self) -> StaticCodecConfig<Self> {
96 StaticCodecConfig::from(self)
97 }
98}
99
100#[derive(Debug, Error)]
101pub enum LogCodecError {
103 #[error("Log does not support the dtype {0}")]
105 UnsupportedDtype(AnyArrayDType),
106 #[error("Log does not support non-positive (negative or zero) floating point data")]
109 NonPositiveData,
110 #[error("Log does not support non-finite (infinite or NaN) floating point data")]
113 NonFiniteData,
114 #[error("Log cannot decode into the provided array")]
116 MismatchedDecodeIntoArray {
117 #[from]
119 source: AnyArrayAssignError,
120 },
121}
122
123pub fn ln<T: Float + Signed, S: Data<Elem = T>, D: Dimension>(
133 data: ArrayBase<S, D>,
134) -> Result<Array<T, D>, LogCodecError> {
135 if !Zip::from(&data).all(T::is_positive) {
136 return Err(LogCodecError::NonPositiveData);
137 }
138
139 if !Zip::from(&data).all(|x| x.is_finite()) {
140 return Err(LogCodecError::NonFiniteData);
141 }
142
143 let mut data = data.into_owned();
144 data.mapv_inplace(T::ln);
145
146 Ok(data)
147}
148
149pub fn exp<T: Float, S: Data<Elem = T>, D: Dimension>(
157 data: ArrayBase<S, D>,
158) -> Result<Array<T, D>, LogCodecError> {
159 if !Zip::from(&data).all(|x| x.is_finite()) {
160 return Err(LogCodecError::NonFiniteData);
161 }
162
163 let mut data = data.into_owned();
164 data.mapv_inplace(T::exp);
165
166 Ok(data)
167}
168
169#[expect(clippy::needless_pass_by_value)]
170pub fn exp_into<T: Float, D: Dimension>(
181 data: ArrayView<T, D>,
182 mut out: ArrayViewMut<T, D>,
183) -> Result<(), LogCodecError> {
184 if data.shape() != out.shape() {
185 return Err(LogCodecError::MismatchedDecodeIntoArray {
186 source: AnyArrayAssignError::ShapeMismatch {
187 src: data.shape().to_vec(),
188 dst: out.shape().to_vec(),
189 },
190 });
191 }
192
193 if !Zip::from(&data).all(|x| x.is_finite()) {
194 return Err(LogCodecError::NonFiniteData);
195 }
196
197 for (d, o) in data.iter().zip(out.iter_mut()) {
199 *o = T::exp(*d);
200 }
201
202 Ok(())
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip() -> Result<(), LogCodecError> {
        let values: Vec<f64> = (1..1000).map(f64::from).collect();
        let original = Array::from_vec(values);

        // Encoding must be bit-identical to applying `f64::ln` directly.
        let logged = ln(original.view())?;
        for (&expected, &actual) in original.iter().zip(logged.iter()) {
            assert_eq!(expected.ln().to_bits(), actual.to_bits());
        }

        // Decoding only has to recover the input up to rounding error.
        let restored = exp(logged)?;
        for (&expected, &actual) in original.iter().zip(restored.iter()) {
            assert!((expected - actual).abs() < 1e-12);
        }

        Ok(())
    }
}