numcodecs_bit_round/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-bit-round
10//! [crates.io]: https://crates.io/crates/numcodecs-bit-round
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-bit-round
13//! [docs.rs]: https://docs.rs/numcodecs-bit-round/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_bit_round
17//!
18//! Bit rounding codec implementation for the [`numcodecs`] API.
19
20use ndarray::{Array, ArrayBase, Data, Dimension};
21use numcodecs::{
22    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
23    Codec, StaticCodec, StaticCodecConfig,
24};
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use thiserror::Error;
28
29#[derive(Clone, Serialize, Deserialize, JsonSchema)]
30#[serde(deny_unknown_fields)]
31/// Codec providing floating-point bit rounding.
32///
33/// Drops the specified number of bits from the floating point mantissa,
34/// leaving an array that is more amenable to compression. The number of
35/// bits to keep should be determined by information analysis of the data
36/// to be compressed.
37///
38/// The approach is based on the paper by Klöwer et al. 2021
39/// (<https://www.nature.com/articles/s43588-021-00156-2>).
40pub struct BitRoundCodec {
41    /// The number of bits of the mantissa to keep.
42    ///
43    /// The valid range depends on the dtype of the input data.
44    ///
45    /// If keepbits is equal to the bitlength of the dtype's mantissa, no
46    /// transformation is performed.
47    pub keepbits: u8,
48}
49
50impl Codec for BitRoundCodec {
51    type Error = BitRoundCodecError;
52
53    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
54        match data {
55            AnyCowArray::F32(data) => Ok(AnyArray::F32(bit_round(data, self.keepbits)?)),
56            AnyCowArray::F64(data) => Ok(AnyArray::F64(bit_round(data, self.keepbits)?)),
57            encoded => Err(BitRoundCodecError::UnsupportedDtype(encoded.dtype())),
58        }
59    }
60
61    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
62        match encoded {
63            AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
64            AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
65            encoded => Err(BitRoundCodecError::UnsupportedDtype(encoded.dtype())),
66        }
67    }
68
69    fn decode_into(
70        &self,
71        encoded: AnyArrayView,
72        mut decoded: AnyArrayViewMut,
73    ) -> Result<(), Self::Error> {
74        if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
75            return Err(BitRoundCodecError::UnsupportedDtype(encoded.dtype()));
76        }
77
78        Ok(decoded.assign(&encoded)?)
79    }
80}
81
82impl StaticCodec for BitRoundCodec {
83    const CODEC_ID: &'static str = "bit-round";
84
85    type Config<'de> = Self;
86
87    fn from_config(config: Self::Config<'_>) -> Self {
88        config
89    }
90
91    fn get_config(&self) -> StaticCodecConfig<Self> {
92        StaticCodecConfig::from(self)
93    }
94}
95
96#[derive(Debug, Error)]
97/// Errors that may occur when applying the [`BitRoundCodec`].
98pub enum BitRoundCodecError {
99    /// [`BitRoundCodec`] does not support the dtype
100    #[error("BitRound does not support the dtype {0}")]
101    UnsupportedDtype(AnyArrayDType),
102    /// [`BitRoundCodec`] encode `keepbits` exceed the mantissa size for `dtype`
103    #[error("BitRound encode {keepbits} bits exceed the mantissa size for {dtype}")]
104    ExcessiveKeepBits {
105        /// The number of bits of the mantissa to keep
106        keepbits: u8,
107        /// The `dtype` of the data to encode
108        dtype: AnyArrayDType,
109    },
110    /// [`BitRoundCodec`] cannot decode into the provided array
111    #[error("BitRound cannot decode into the provided array")]
112    MismatchedDecodeIntoArray {
113        /// The source of the error
114        #[from]
115        source: AnyArrayAssignError,
116    },
117}
118
119/// Floating-point bit rounding, which drops the specified number of bits from
120/// the floating point mantissa.
121///
122/// See <https://github.com/milankl/BitInformation.jl> for the the original
123/// implementation in Julia.
124///
125/// # Errors
126///
127/// Errors with [`BitRoundCodecError::ExcessiveKeepBits`] if `keepbits` exceeds
128/// [`T::MANITSSA_BITS`][`Float::MANITSSA_BITS`].
129pub fn bit_round<T: Float, S: Data<Elem = T>, D: Dimension>(
130    data: ArrayBase<S, D>,
131    keepbits: u8,
132) -> Result<Array<T, D>, BitRoundCodecError> {
133    if u32::from(keepbits) > T::MANITSSA_BITS {
134        return Err(BitRoundCodecError::ExcessiveKeepBits {
135            keepbits,
136            dtype: T::TY,
137        });
138    }
139
140    let mut encoded = data.into_owned();
141
142    // Early return if no bit rounding needs to happen
143    // - required since the ties to even impl does not work in this case
144    if u32::from(keepbits) == T::MANITSSA_BITS {
145        return Ok(encoded);
146    }
147
148    // half of unit in last place (ulp)
149    let ulp_half = T::MANTISSA_MASK >> (u32::from(keepbits) + 1);
150    // mask to zero out trailing mantissa bits
151    let keep_mask = !(T::MANTISSA_MASK >> u32::from(keepbits));
152    // shift to extract the least significant bit of the exponent
153    let shift = T::MANITSSA_BITS - u32::from(keepbits);
154
155    encoded.mapv_inplace(|x| {
156        let mut bits = T::to_binary(x);
157
158        // add ulp/2 with ties to even
159        bits += ulp_half + ((bits >> shift) & T::BINARY_ONE);
160
161        // set the trailing bits to zero
162        bits &= keep_mask;
163
164        T::from_binary(bits)
165    });
166
167    Ok(encoded)
168}
169
170/// Floating point types.
171pub trait Float: Sized + Copy {
172    /// Number of significant digits in base 2
173    const MANITSSA_BITS: u32;
174    /// Binary mask to extract only the mantissa bits
175    const MANTISSA_MASK: Self::Binary;
176    /// Binary `0x1`
177    const BINARY_ONE: Self::Binary;
178
179    /// Dtype of this type
180    const TY: AnyArrayDType;
181
182    /// Binary representation of this type
183    type Binary: Copy
184        + std::ops::Not<Output = Self::Binary>
185        + std::ops::Shr<u32, Output = Self::Binary>
186        + std::ops::Add<Self::Binary, Output = Self::Binary>
187        + std::ops::AddAssign<Self::Binary>
188        + std::ops::BitAnd<Self::Binary, Output = Self::Binary>
189        + std::ops::BitAndAssign<Self::Binary>;
190
191    /// Bit-cast the floating point value to its binary representation
192    fn to_binary(self) -> Self::Binary;
193    /// Bit-cast the binary representation into a floating point value
194    fn from_binary(u: Self::Binary) -> Self;
195}
196
197impl Float for f32 {
198    type Binary = u32;
199
200    const BINARY_ONE: Self::Binary = 1;
201    const MANITSSA_BITS: u32 = Self::MANTISSA_DIGITS - 1;
202    const MANTISSA_MASK: Self::Binary = (1 << Self::MANITSSA_BITS) - 1;
203    const TY: AnyArrayDType = AnyArrayDType::F32;
204
205    fn to_binary(self) -> Self::Binary {
206        self.to_bits()
207    }
208
209    fn from_binary(u: Self::Binary) -> Self {
210        Self::from_bits(u)
211    }
212}
213
214impl Float for f64 {
215    type Binary = u64;
216
217    const BINARY_ONE: Self::Binary = 1;
218    const MANITSSA_BITS: u32 = Self::MANTISSA_DIGITS - 1;
219    const MANTISSA_MASK: Self::Binary = (1 << Self::MANITSSA_BITS) - 1;
220    const TY: AnyArrayDType = AnyArrayDType::F64;
221
222    fn to_binary(self) -> Self::Binary {
223        self.to_bits()
224    }
225
226    fn from_binary(u: Self::Binary) -> Self {
227        Self::from_bits(u)
228    }
229}
230
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use ndarray::{Array1, ArrayView1};

    use super::*;

    /// `(input, expected)` pairs for rounding with no mantissa bits kept.
    ///
    /// All values are exactly representable as both `f32` and `f64`, so the
    /// same table is used for both dtypes.
    const NO_MANTISSA_CASES: [(f32, f32); 12] = [
        (0.0, 0.0),
        (1.0, 1.0),
        // tie to even rounds up as the offset exponent is odd
        (1.5, 2.0),
        (2.0, 2.0),
        (2.5, 2.0),
        // tie to even rounds down as the offset exponent is even
        (3.0, 2.0),
        (3.5, 4.0),
        (4.0, 4.0),
        (5.0, 4.0),
        // tie to even rounds up as the offset exponent is odd
        (6.0, 8.0),
        (7.0, 8.0),
        (8.0, 8.0),
    ];

    /// Bit-rounds a one-element view of `input` and checks that the result
    /// is the one-element array of `expected`.
    #[track_caller]
    fn assert_rounds_to<T>(input: T, keepbits: u8, expected: T)
    where
        T: Float + PartialEq + std::fmt::Debug,
    {
        let cell = [input];
        let view: ArrayView1<'_, T> = ArrayView1::from(&cell);

        assert_eq!(
            bit_round(view, keepbits).unwrap(),
            Array1::from_vec(vec![expected])
        );
    }

    #[test]
    fn no_mantissa() {
        for (input, expected) in NO_MANTISSA_CASES {
            assert_rounds_to(input, 0, expected);
        }

        for (input, expected) in NO_MANTISSA_CASES {
            assert_rounds_to(f64::from(input), 0, f64::from(expected));
        }
    }

    #[test]
    #[expect(clippy::cast_possible_truncation)]
    fn full_mantissa() {
        /// Adds the mantissa mask to the bit pattern of `x`, which sets all
        /// mantissa bits for the mantissa-free inputs used below.
        fn full<T: Float>(x: T) -> T {
            T::from_binary(T::to_binary(x) + T::MANTISSA_MASK)
        }

        // keeping every mantissa bit must leave the values untouched
        let keep_all_f32 = f32::MANITSSA_BITS as u8;
        for filled in [0.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32].map(full) {
            assert_rounds_to(filled, keep_all_f32, filled);
        }

        let keep_all_f64 = f64::MANITSSA_BITS as u8;
        for filled in [0.0_f64, 1.0_f64, 2.0_f64, 3.0_f64, 4.0_f64].map(full) {
            assert_rounds_to(filled, keep_all_f64, filled);
        }
    }
}