numcodecs_bit_round/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-bit-round
10//! [crates.io]: https://crates.io/crates/numcodecs-bit-round
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-bit-round
13//! [docs.rs]: https://docs.rs/numcodecs-bit-round/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_bit_round
17//!
18//! Bit rounding codec implementation for the [`numcodecs`] API.
19
20use ndarray::{Array, ArrayBase, Data, Dimension};
21use numcodecs::{
22    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
23    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
24};
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use thiserror::Error;
28
29#[derive(Clone, Serialize, Deserialize, JsonSchema)]
30#[serde(deny_unknown_fields)]
31/// Codec providing floating-point bit rounding.
32///
33/// Drops the specified number of bits from the floating point mantissa,
34/// leaving an array that is more amenable to compression. The number of
35/// bits to keep should be determined by information analysis of the data
36/// to be compressed.
37///
38/// The approach is based on the paper by Klöwer et al. 2021
39/// (<https://www.nature.com/articles/s43588-021-00156-2>).
40pub struct BitRoundCodec {
41    /// The number of bits of the mantissa to keep.
42    ///
43    /// The valid range depends on the dtype of the input data.
44    ///
45    /// If keepbits is equal to the bitlength of the dtype's mantissa, no
46    /// transformation is performed.
47    pub keepbits: u8,
48    /// The codec's encoding format version. Do not provide this parameter explicitly.
49    #[serde(default, rename = "_version")]
50    pub version: StaticCodecVersion<1, 0, 0>,
51}
52
53impl Codec for BitRoundCodec {
54    type Error = BitRoundCodecError;
55
56    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
57        match data {
58            AnyCowArray::F32(data) => Ok(AnyArray::F32(bit_round(data, self.keepbits)?)),
59            AnyCowArray::F64(data) => Ok(AnyArray::F64(bit_round(data, self.keepbits)?)),
60            encoded => Err(BitRoundCodecError::UnsupportedDtype(encoded.dtype())),
61        }
62    }
63
64    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
65        match encoded {
66            AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
67            AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
68            encoded => Err(BitRoundCodecError::UnsupportedDtype(encoded.dtype())),
69        }
70    }
71
72    fn decode_into(
73        &self,
74        encoded: AnyArrayView,
75        mut decoded: AnyArrayViewMut,
76    ) -> Result<(), Self::Error> {
77        if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
78            return Err(BitRoundCodecError::UnsupportedDtype(encoded.dtype()));
79        }
80
81        Ok(decoded.assign(&encoded)?)
82    }
83}
84
85impl StaticCodec for BitRoundCodec {
86    const CODEC_ID: &'static str = "bit-round.rs";
87
88    type Config<'de> = Self;
89
90    fn from_config(config: Self::Config<'_>) -> Self {
91        config
92    }
93
94    fn get_config(&self) -> StaticCodecConfig<Self> {
95        StaticCodecConfig::from(self)
96    }
97}
98
99#[derive(Debug, Error)]
100/// Errors that may occur when applying the [`BitRoundCodec`].
101pub enum BitRoundCodecError {
102    /// [`BitRoundCodec`] does not support the dtype
103    #[error("BitRound does not support the dtype {0}")]
104    UnsupportedDtype(AnyArrayDType),
105    /// [`BitRoundCodec`] encode `keepbits` exceed the mantissa size for `dtype`
106    #[error("BitRound encode {keepbits} bits exceed the mantissa size for {dtype}")]
107    ExcessiveKeepBits {
108        /// The number of bits of the mantissa to keep
109        keepbits: u8,
110        /// The `dtype` of the data to encode
111        dtype: AnyArrayDType,
112    },
113    /// [`BitRoundCodec`] cannot decode into the provided array
114    #[error("BitRound cannot decode into the provided array")]
115    MismatchedDecodeIntoArray {
116        /// The source of the error
117        #[from]
118        source: AnyArrayAssignError,
119    },
120}
121
122/// Floating-point bit rounding, which drops the specified number of bits from
123/// the floating point mantissa.
124///
125/// See <https://github.com/milankl/BitInformation.jl> for the the original
126/// implementation in Julia.
127///
128/// # Errors
129///
130/// Errors with [`BitRoundCodecError::ExcessiveKeepBits`] if `keepbits` exceeds
131/// [`T::MANITSSA_BITS`][`Float::MANITSSA_BITS`].
132pub fn bit_round<T: Float, S: Data<Elem = T>, D: Dimension>(
133    data: ArrayBase<S, D>,
134    keepbits: u8,
135) -> Result<Array<T, D>, BitRoundCodecError> {
136    if u32::from(keepbits) > T::MANITSSA_BITS {
137        return Err(BitRoundCodecError::ExcessiveKeepBits {
138            keepbits,
139            dtype: T::TY,
140        });
141    }
142
143    let mut encoded = data.into_owned();
144
145    // Early return if no bit rounding needs to happen
146    // - required since the ties to even impl does not work in this case
147    if u32::from(keepbits) == T::MANITSSA_BITS {
148        return Ok(encoded);
149    }
150
151    // half of unit in last place (ulp)
152    let ulp_half = T::MANTISSA_MASK >> (u32::from(keepbits) + 1);
153    // mask to zero out trailing mantissa bits
154    let keep_mask = !(T::MANTISSA_MASK >> u32::from(keepbits));
155    // shift to extract the least significant bit of the exponent
156    let shift = T::MANITSSA_BITS - u32::from(keepbits);
157
158    encoded.mapv_inplace(|x| {
159        let mut bits = T::to_binary(x);
160
161        // add ulp/2 with ties to even
162        bits += ulp_half + ((bits >> shift) & T::BINARY_ONE);
163
164        // set the trailing bits to zero
165        bits &= keep_mask;
166
167        T::from_binary(bits)
168    });
169
170    Ok(encoded)
171}
172
173/// Floating point types.
174pub trait Float: Sized + Copy {
175    /// Number of significant digits in base 2
176    const MANITSSA_BITS: u32;
177    /// Binary mask to extract only the mantissa bits
178    const MANTISSA_MASK: Self::Binary;
179    /// Binary `0x1`
180    const BINARY_ONE: Self::Binary;
181
182    /// Dtype of this type
183    const TY: AnyArrayDType;
184
185    /// Binary representation of this type
186    type Binary: Copy
187        + std::ops::Not<Output = Self::Binary>
188        + std::ops::Shr<u32, Output = Self::Binary>
189        + std::ops::Add<Self::Binary, Output = Self::Binary>
190        + std::ops::AddAssign<Self::Binary>
191        + std::ops::BitAnd<Self::Binary, Output = Self::Binary>
192        + std::ops::BitAndAssign<Self::Binary>;
193
194    /// Bit-cast the floating point value to its binary representation
195    fn to_binary(self) -> Self::Binary;
196    /// Bit-cast the binary representation into a floating point value
197    fn from_binary(u: Self::Binary) -> Self;
198}
199
200impl Float for f32 {
201    type Binary = u32;
202
203    const BINARY_ONE: Self::Binary = 1;
204    const MANITSSA_BITS: u32 = Self::MANTISSA_DIGITS - 1;
205    const MANTISSA_MASK: Self::Binary = (1 << Self::MANITSSA_BITS) - 1;
206    const TY: AnyArrayDType = AnyArrayDType::F32;
207
208    fn to_binary(self) -> Self::Binary {
209        self.to_bits()
210    }
211
212    fn from_binary(u: Self::Binary) -> Self {
213        Self::from_bits(u)
214    }
215}
216
217impl Float for f64 {
218    type Binary = u64;
219
220    const BINARY_ONE: Self::Binary = 1;
221    const MANITSSA_BITS: u32 = Self::MANTISSA_DIGITS - 1;
222    const MANTISSA_MASK: Self::Binary = (1 << Self::MANITSSA_BITS) - 1;
223    const TY: AnyArrayDType = AnyArrayDType::F64;
224
225    fn to_binary(self) -> Self::Binary {
226        self.to_bits()
227    }
228
229    fn from_binary(u: Self::Binary) -> Self {
230        Self::from_bits(u)
231    }
232}
233
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use ndarray::{Array1, ArrayView1};

    use super::*;

    /// `(input, expected)` pairs for rounding with no mantissa bits kept.
    ///
    /// All values are exactly representable, so the same table is reused for
    /// `f64` after a lossless widening.
    const NO_MANTISSA_CASES: [(f32, f32); 12] = [
        (0.0, 0.0),
        (1.0, 1.0),
        // tie to even rounds up as the offset exponent is odd
        (1.5, 2.0),
        (2.0, 2.0),
        (2.5, 2.0),
        // tie to even rounds down as the offset exponent is even
        (3.0, 2.0),
        (3.5, 4.0),
        (4.0, 4.0),
        (5.0, 4.0),
        // tie to even rounds up as the offset exponent is odd
        (6.0, 8.0),
        (7.0, 8.0),
        (8.0, 8.0),
    ];

    /// Bit-rounds the single value `x` and returns the one-element result.
    fn round_one<T: Float>(x: T, keepbits: u8) -> Array1<T> {
        bit_round(ArrayView1::from(&[x]), keepbits).unwrap()
    }

    /// Checks every `(input, expected)` pair with `keepbits = 0`.
    fn assert_no_mantissa<T: Float + PartialEq + std::fmt::Debug>(cases: &[(T, T)]) {
        for &(input, expected) in cases {
            assert_eq!(
                round_one(input, 0),
                Array1::from_vec(vec![expected]),
                "rounding {input:?} with no mantissa bits kept"
            );
        }
    }

    #[test]
    fn no_mantissa() {
        assert_no_mantissa::<f32>(&NO_MANTISSA_CASES);

        let wide = NO_MANTISSA_CASES.map(|(input, expected)| (f64::from(input), f64::from(expected)));
        assert_no_mantissa::<f64>(&wide);
    }

    #[test]
    #[expect(clippy::cast_possible_truncation)]
    fn full_mantissa() {
        /// Sets every mantissa bit of `x` (the inputs used here have
        /// mantissas that make the addition carry-free).
        fn full<T: Float>(x: T) -> T {
            T::from_binary(T::to_binary(x) + T::MANTISSA_MASK)
        }

        // keeping the whole mantissa must leave the values untouched
        for v in [0.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32].map(full) {
            assert_eq!(
                round_one(v, f32::MANITSSA_BITS as u8),
                Array1::from_vec(vec![v])
            );
        }

        for v in [0.0_f64, 1.0_f64, 2.0_f64, 3.0_f64, 4.0_f64].map(full) {
            assert_eq!(
                round_one(v, f64::MANITSSA_BITS as u8),
                Array1::from_vec(vec![v])
            );
        }
    }
}