numcodecs_log/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-log
10//! [crates.io]: https://crates.io/crates/numcodecs-log
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-log
13//! [docs.rs]: https://docs.rs/numcodecs-log/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_log
17//!
18//! `$\ln(x)$` codec implementation for the [`numcodecs`] API.
19
20use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Zip};
21use num_traits::{Float, Signed};
22use numcodecs::{
23    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
24    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
25};
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use thiserror::Error;
29
// NOTE(review): this struct derives `JsonSchema`, whose derive presumably
// picks up the `///` doc comments below (on the struct and its field) as
// schema descriptions. Changing their wording may therefore change the
// generated config schema. Confirm before editing them.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
// Configs containing any field other than `_version` are rejected.
#[serde(deny_unknown_fields)]
/// Log codec which calculates `$c = \ln(x)$` on encoding and `$d = {e}^{c}$`
/// on decoding.
///
/// The codec only supports finite positive floating point numbers.
pub struct LogCodec {
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    // Serialized under the key `_version`. When absent from a config, it is
    // filled in with its `Default` value. The const generics presumably pin
    // the format version to 1.0.0 (TODO confirm against `StaticCodecVersion`).
    #[serde(default, rename = "_version")]
    pub version: StaticCodecVersion<1, 0, 0>,
}
41
42impl Codec for LogCodec {
43    type Error = LogCodecError;
44
45    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
46        match data {
47            AnyCowArray::F32(data) => Ok(AnyArray::F32(ln(data)?)),
48            AnyCowArray::F64(data) => Ok(AnyArray::F64(ln(data)?)),
49            encoded => Err(LogCodecError::UnsupportedDtype(encoded.dtype())),
50        }
51    }
52
53    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
54        match encoded {
55            AnyCowArray::F32(encoded) => Ok(AnyArray::F32(exp(encoded)?)),
56            AnyCowArray::F64(encoded) => Ok(AnyArray::F64(exp(encoded)?)),
57            encoded => Err(LogCodecError::UnsupportedDtype(encoded.dtype())),
58        }
59    }
60
61    fn decode_into(
62        &self,
63        encoded: AnyArrayView,
64        decoded: AnyArrayViewMut,
65    ) -> Result<(), Self::Error> {
66        match (encoded, decoded) {
67            (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
68                exp_into(encoded, decoded)
69            }
70            (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
71                exp_into(encoded, decoded)
72            }
73            (encoded @ (AnyArrayView::F32(_) | AnyArrayView::F64(_)), decoded) => {
74                Err(LogCodecError::MismatchedDecodeIntoArray {
75                    source: AnyArrayAssignError::DTypeMismatch {
76                        src: encoded.dtype(),
77                        dst: decoded.dtype(),
78                    },
79                })
80            }
81            (encoded, _decoded) => Err(LogCodecError::UnsupportedDtype(encoded.dtype())),
82        }
83    }
84}
85
// Implements `StaticCodec` for `LogCodec`. The codec's configuration type is
// the codec struct itself, so a config and a codec are the same value.
impl StaticCodec for LogCodec {
    // Identifier string for this codec. The `.rs` suffix presumably separates
    // it from a non-Rust `log` codec (TODO confirm).
    const CODEC_ID: &'static str = "log.rs";

    type Config<'de> = Self;

    // The config *is* the codec, so construction is the identity.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    // Builds the config wrapper from a reference to `self`, via
    // `StaticCodecConfig`'s `From` conversion.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
99
100#[derive(Debug, Error)]
101/// Errors that may occur when applying the [`LogCodec`].
102pub enum LogCodecError {
103    /// [`LogCodec`] does not support the dtype
104    #[error("Log does not support the dtype {0}")]
105    UnsupportedDtype(AnyArrayDType),
106    /// [`LogCodec`] does not support non-positive (negative or zero) floating
107    /// point data
108    #[error("Log does not support non-positive (negative or zero) floating point data")]
109    NonPositiveData,
110    /// [`LogCodec`] does not support non-finite (infinite or NaN) floating
111    /// point data
112    #[error("Log does not support non-finite (infinite or NaN) floating point data")]
113    NonFiniteData,
114    /// [`LogCodec`] cannot decode into the provided array
115    #[error("Log cannot decode into the provided array")]
116    MismatchedDecodeIntoArray {
117        /// The source of the error
118        #[from]
119        source: AnyArrayAssignError,
120    },
121}
122
123/// Compute `$\ln(x)$` over the elements of the input `data` array.
124///
125/// # Errors
126///
127/// Errors with
128/// - [`LogCodecError::NonPositiveData`] if any data element is non-positive
129///   (negative or zero)
130/// - [`LogCodecError::NonFiniteData`] if any data element is non-finite
131///   (infinite or NaN)
132pub fn ln<T: Float + Signed, S: Data<Elem = T>, D: Dimension>(
133    data: ArrayBase<S, D>,
134) -> Result<Array<T, D>, LogCodecError> {
135    if !Zip::from(&data).all(T::is_positive) {
136        return Err(LogCodecError::NonPositiveData);
137    }
138
139    if !Zip::from(&data).all(|x| x.is_finite()) {
140        return Err(LogCodecError::NonFiniteData);
141    }
142
143    let mut data = data.into_owned();
144    data.mapv_inplace(T::ln);
145
146    Ok(data)
147}
148
149/// Compute `${e}^{x}$` over the elements of the input `data` array.
150///
151/// # Errors
152///
153/// Errors with
154/// - [`LogCodecError::NonFiniteData`] if any data element is non-finite
155///   (infinite or NaN)
156pub fn exp<T: Float, S: Data<Elem = T>, D: Dimension>(
157    data: ArrayBase<S, D>,
158) -> Result<Array<T, D>, LogCodecError> {
159    if !Zip::from(&data).all(|x| x.is_finite()) {
160        return Err(LogCodecError::NonFiniteData);
161    }
162
163    let mut data = data.into_owned();
164    data.mapv_inplace(T::exp);
165
166    Ok(data)
167}
168
169#[expect(clippy::needless_pass_by_value)]
170/// Compute `${e}^{x}$` over the elements of the input `data` array and write
171/// them into the `out`put array.
172///
173/// # Errors
174///
175/// Errors with
176/// - [`LogCodecError::NonFiniteData`] if any data element is non-finite
177///   (infinite or NaN)
178/// - [`LogCodecError::MismatchedDecodeIntoArray`] if the `data` array's shape
179///   does not match the `out`put array's shape
180pub fn exp_into<T: Float, D: Dimension>(
181    data: ArrayView<T, D>,
182    mut out: ArrayViewMut<T, D>,
183) -> Result<(), LogCodecError> {
184    if data.shape() != out.shape() {
185        return Err(LogCodecError::MismatchedDecodeIntoArray {
186            source: AnyArrayAssignError::ShapeMismatch {
187                src: data.shape().to_vec(),
188                dst: out.shape().to_vec(),
189            },
190        });
191    }
192
193    if !Zip::from(&data).all(|x| x.is_finite()) {
194        return Err(LogCodecError::NonFiniteData);
195    }
196
197    // iteration must occur in synchronised (standard) order
198    for (d, o) in data.iter().zip(out.iter_mut()) {
199        *o = T::exp(*d);
200    }
201
202    Ok(())
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding must match `f64::ln` bit-for-bit, and decoding must recover
    /// the original values up to floating point error.
    #[test]
    fn roundtrip() -> Result<(), LogCodecError> {
        let data = (1..1000).map(f64::from).collect::<Vec<_>>();
        let data = Array::from_vec(data);

        let encoded = ln(data.view())?;

        for (r, e) in data.iter().zip(encoded.iter()) {
            assert_eq!((*r).ln().to_bits(), (*e).to_bits());
        }

        let decoded = exp(encoded)?;

        for (r, d) in data.iter().zip(decoded.iter()) {
            assert!(((*r) - (*d)).abs() < 1e-12);
        }

        Ok(())
    }

    /// Negative inputs have no real logarithm and must be rejected.
    #[test]
    fn ln_rejects_negative_data() {
        let data = Array::from_vec(vec![1.0_f64, -2.0, 3.0]);
        assert!(matches!(
            ln(data.view()),
            Err(LogCodecError::NonPositiveData)
        ));
    }

    /// Positive infinity passes the sign check but is not finite.
    #[test]
    fn ln_rejects_infinite_data() {
        let data = Array::from_vec(vec![1.0_f64, f64::INFINITY]);
        assert!(matches!(
            ln(data.view()),
            Err(LogCodecError::NonFiniteData)
        ));
    }

    /// Non-finite encoded values must not be silently decoded.
    #[test]
    fn exp_rejects_infinite_data() {
        let data = Array::from_vec(vec![0.0_f32, f32::NEG_INFINITY]);
        assert!(matches!(
            exp(data.view()),
            Err(LogCodecError::NonFiniteData)
        ));
    }

    /// Decoding into a pre-allocated array must agree with allocating decode.
    #[test]
    fn exp_into_matches_exp() -> Result<(), LogCodecError> {
        let encoded = Array::from_vec(vec![-1.0_f64, 0.0, 0.5, 2.0]);
        let mut decoded = Array::<f64, _>::zeros(encoded.dim());

        exp_into(encoded.view(), decoded.view_mut())?;

        assert_eq!(decoded, exp(encoded.view())?);

        Ok(())
    }

    /// Decoding into an array of a different shape must report the mismatch.
    #[test]
    fn exp_into_rejects_shape_mismatch() {
        let encoded = Array::from_vec(vec![0.0_f64; 3]);
        let mut decoded = Array::from_vec(vec![0.0_f64; 4]);

        assert!(matches!(
            exp_into(encoded.view(), decoded.view_mut()),
            Err(LogCodecError::MismatchedDecodeIntoArray {
                source: AnyArrayAssignError::ShapeMismatch { .. }
            })
        ));
    }
}