1use ndarray::{Array, ArrayBase, Data, Dimension};
21use numcodecs::{
22 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
23 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
24};
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use thiserror::Error;
28
29#[derive(Clone, Serialize, Deserialize, JsonSchema)]
30#[serde(deny_unknown_fields)]
31pub struct BitRoundCodec {
41 pub keepbits: u8,
48 #[serde(default, rename = "_version")]
50 pub version: StaticCodecVersion<1, 0, 0>,
51}
52
53impl Codec for BitRoundCodec {
54 type Error = BitRoundCodecError;
55
56 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
57 match data {
58 AnyCowArray::F32(data) => Ok(AnyArray::F32(bit_round(data, self.keepbits)?)),
59 AnyCowArray::F64(data) => Ok(AnyArray::F64(bit_round(data, self.keepbits)?)),
60 encoded => Err(BitRoundCodecError::UnsupportedDtype(encoded.dtype())),
61 }
62 }
63
64 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
65 match encoded {
66 AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
67 AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
68 encoded => Err(BitRoundCodecError::UnsupportedDtype(encoded.dtype())),
69 }
70 }
71
72 fn decode_into(
73 &self,
74 encoded: AnyArrayView,
75 mut decoded: AnyArrayViewMut,
76 ) -> Result<(), Self::Error> {
77 if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
78 return Err(BitRoundCodecError::UnsupportedDtype(encoded.dtype()));
79 }
80
81 Ok(decoded.assign(&encoded)?)
82 }
83}
84
85impl StaticCodec for BitRoundCodec {
86 const CODEC_ID: &'static str = "bit-round.rs";
87
88 type Config<'de> = Self;
89
90 fn from_config(config: Self::Config<'_>) -> Self {
91 config
92 }
93
94 fn get_config(&self) -> StaticCodecConfig<Self> {
95 StaticCodecConfig::from(self)
96 }
97}
98
99#[derive(Debug, Error)]
100pub enum BitRoundCodecError {
102 #[error("BitRound does not support the dtype {0}")]
104 UnsupportedDtype(AnyArrayDType),
105 #[error("BitRound encode {keepbits} bits exceed the mantissa size for {dtype}")]
107 ExcessiveKeepBits {
108 keepbits: u8,
110 dtype: AnyArrayDType,
112 },
113 #[error("BitRound cannot decode into the provided array")]
115 MismatchedDecodeIntoArray {
116 #[from]
118 source: AnyArrayAssignError,
119 },
120}
121
122pub fn bit_round<T: Float, S: Data<Elem = T>, D: Dimension>(
133 data: ArrayBase<S, D>,
134 keepbits: u8,
135) -> Result<Array<T, D>, BitRoundCodecError> {
136 if u32::from(keepbits) > T::MANITSSA_BITS {
137 return Err(BitRoundCodecError::ExcessiveKeepBits {
138 keepbits,
139 dtype: T::TY,
140 });
141 }
142
143 let mut encoded = data.into_owned();
144
145 if u32::from(keepbits) == T::MANITSSA_BITS {
148 return Ok(encoded);
149 }
150
151 let ulp_half = T::MANTISSA_MASK >> (u32::from(keepbits) + 1);
153 let keep_mask = !(T::MANTISSA_MASK >> u32::from(keepbits));
155 let shift = T::MANITSSA_BITS - u32::from(keepbits);
157
158 encoded.mapv_inplace(|x| {
159 let mut bits = T::to_binary(x);
160
161 bits += ulp_half + ((bits >> shift) & T::BINARY_ONE);
163
164 bits &= keep_mask;
166
167 T::from_binary(bits)
168 });
169
170 Ok(encoded)
171}
172
173pub trait Float: Sized + Copy {
175 const MANITSSA_BITS: u32;
177 const MANTISSA_MASK: Self::Binary;
179 const BINARY_ONE: Self::Binary;
181
182 const TY: AnyArrayDType;
184
185 type Binary: Copy
187 + std::ops::Not<Output = Self::Binary>
188 + std::ops::Shr<u32, Output = Self::Binary>
189 + std::ops::Add<Self::Binary, Output = Self::Binary>
190 + std::ops::AddAssign<Self::Binary>
191 + std::ops::BitAnd<Self::Binary, Output = Self::Binary>
192 + std::ops::BitAndAssign<Self::Binary>;
193
194 fn to_binary(self) -> Self::Binary;
196 fn from_binary(u: Self::Binary) -> Self;
198}
199
200impl Float for f32 {
201 type Binary = u32;
202
203 const BINARY_ONE: Self::Binary = 1;
204 const MANITSSA_BITS: u32 = Self::MANTISSA_DIGITS - 1;
205 const MANTISSA_MASK: Self::Binary = (1 << Self::MANITSSA_BITS) - 1;
206 const TY: AnyArrayDType = AnyArrayDType::F32;
207
208 fn to_binary(self) -> Self::Binary {
209 self.to_bits()
210 }
211
212 fn from_binary(u: Self::Binary) -> Self {
213 Self::from_bits(u)
214 }
215}
216
217impl Float for f64 {
218 type Binary = u64;
219
220 const BINARY_ONE: Self::Binary = 1;
221 const MANITSSA_BITS: u32 = Self::MANTISSA_DIGITS - 1;
222 const MANTISSA_MASK: Self::Binary = (1 << Self::MANITSSA_BITS) - 1;
223 const TY: AnyArrayDType = AnyArrayDType::F64;
224
225 fn to_binary(self) -> Self::Binary {
226 self.to_bits()
227 }
228
229 fn from_binary(u: Self::Binary) -> Self {
230 Self::from_bits(u)
231 }
232}
233
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use ndarray::{Array1, ArrayView1};

    use super::*;

    /// `(input, expected)` pairs for rounding to zero explicit mantissa bits,
    /// i.e. to zero or a power of two, with ties to even.
    const NO_MANTISSA_CASES: [(f32, f32); 12] = [
        (0.0, 0.0),
        (1.0, 1.0),
        (1.5, 2.0),
        (2.0, 2.0),
        (2.5, 2.0),
        (3.0, 2.0),
        (3.5, 4.0),
        (4.0, 4.0),
        (5.0, 4.0),
        (6.0, 8.0),
        (7.0, 8.0),
        (8.0, 8.0),
    ];

    /// Bit-rounds the single value `x` and returns the one-element result.
    fn round_single<T: Float>(x: T, keepbits: u8) -> Array1<T> {
        bit_round(ArrayView1::from(&[x]), keepbits).unwrap()
    }

    #[test]
    fn no_mantissa() {
        // Every case is exactly representable as f64 too, so the same table
        // also drives the f64 checks.
        for (input, expected) in NO_MANTISSA_CASES {
            assert_eq!(round_single(input, 0), Array1::from_vec(vec![expected]));
            assert_eq!(
                round_single(f64::from(input), 0),
                Array1::from_vec(vec![f64::from(expected)])
            );
        }
    }

    #[test]
    #[expect(clippy::cast_possible_truncation)]
    fn full_mantissa() {
        /// Adds [`Float::MANTISSA_MASK`] to the raw bit pattern of `x`.
        fn full<T: Float>(x: T) -> T {
            T::from_binary(T::to_binary(x) + T::MANTISSA_MASK)
        }

        let keepbits_f32 = f32::MANITSSA_BITS as u8;
        for v in [0.0_f32, 1.0, 2.0, 3.0, 4.0] {
            let x = full(v);
            assert_eq!(round_single(x, keepbits_f32), Array1::from_vec(vec![x]));
        }

        let keepbits_f64 = f64::MANITSSA_BITS as u8;
        for v in [0.0_f64, 1.0, 2.0, 3.0, 4.0] {
            let x = full(v);
            assert_eq!(round_single(x, keepbits_f64), Array1::from_vec(vec![x]));
        }
    }
}