1use ndarray::{Array, ArrayBase, Data, Dimension};
21use numcodecs::{
22 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
23 Codec, StaticCodec, StaticCodecConfig,
24};
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use thiserror::Error;
28
29#[derive(Clone, Serialize, Deserialize, JsonSchema)]
30#[serde(deny_unknown_fields)]
31pub struct BitRoundCodec {
41 pub keepbits: u8,
48}
49
50impl Codec for BitRoundCodec {
51 type Error = BitRoundCodecError;
52
53 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
54 match data {
55 AnyCowArray::F32(data) => Ok(AnyArray::F32(bit_round(data, self.keepbits)?)),
56 AnyCowArray::F64(data) => Ok(AnyArray::F64(bit_round(data, self.keepbits)?)),
57 encoded => Err(BitRoundCodecError::UnsupportedDtype(encoded.dtype())),
58 }
59 }
60
61 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
62 match encoded {
63 AnyCowArray::F32(encoded) => Ok(AnyArray::F32(encoded.into_owned())),
64 AnyCowArray::F64(encoded) => Ok(AnyArray::F64(encoded.into_owned())),
65 encoded => Err(BitRoundCodecError::UnsupportedDtype(encoded.dtype())),
66 }
67 }
68
69 fn decode_into(
70 &self,
71 encoded: AnyArrayView,
72 mut decoded: AnyArrayViewMut,
73 ) -> Result<(), Self::Error> {
74 if !matches!(encoded.dtype(), AnyArrayDType::F32 | AnyArrayDType::F64) {
75 return Err(BitRoundCodecError::UnsupportedDtype(encoded.dtype()));
76 }
77
78 Ok(decoded.assign(&encoded)?)
79 }
80}
81
82impl StaticCodec for BitRoundCodec {
83 const CODEC_ID: &'static str = "bit-round";
84
85 type Config<'de> = Self;
86
87 fn from_config(config: Self::Config<'_>) -> Self {
88 config
89 }
90
91 fn get_config(&self) -> StaticCodecConfig<Self> {
92 StaticCodecConfig::from(self)
93 }
94}
95
/// Errors that may occur when applying the [`BitRoundCodec`].
#[derive(Debug, Error)]
pub enum BitRoundCodecError {
    /// [`BitRoundCodec`] was given an array whose dtype is neither `f32` nor
    /// `f64`
    #[error("BitRound does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`BitRoundCodec`] was asked to keep more mantissa bits than the dtype's
    /// mantissa has
    #[error("BitRound encode {keepbits} bits exceed the mantissa size for {dtype}")]
    ExcessiveKeepBits {
        /// The requested number of mantissa bits to keep
        keepbits: u8,
        /// The dtype of the data that was to be encoded
        dtype: AnyArrayDType,
    },
    /// [`BitRoundCodec`] failed to assign the decoded data into the provided
    /// output array
    #[error("BitRound cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The underlying assignment error, converted via `?` thanks to
        /// `#[from]`
        #[from]
        source: AnyArrayAssignError,
    },
}
118
119pub fn bit_round<T: Float, S: Data<Elem = T>, D: Dimension>(
130 data: ArrayBase<S, D>,
131 keepbits: u8,
132) -> Result<Array<T, D>, BitRoundCodecError> {
133 if u32::from(keepbits) > T::MANITSSA_BITS {
134 return Err(BitRoundCodecError::ExcessiveKeepBits {
135 keepbits,
136 dtype: T::TY,
137 });
138 }
139
140 let mut encoded = data.into_owned();
141
142 if u32::from(keepbits) == T::MANITSSA_BITS {
145 return Ok(encoded);
146 }
147
148 let ulp_half = T::MANTISSA_MASK >> (u32::from(keepbits) + 1);
150 let keep_mask = !(T::MANTISSA_MASK >> u32::from(keepbits));
152 let shift = T::MANITSSA_BITS - u32::from(keepbits);
154
155 encoded.mapv_inplace(|x| {
156 let mut bits = T::to_binary(x);
157
158 bits += ulp_half + ((bits >> shift) & T::BINARY_ONE);
160
161 bits &= keep_mask;
163
164 T::from_binary(bits)
165 });
166
167 Ok(encoded)
168}
169
170pub trait Float: Sized + Copy {
172 const MANITSSA_BITS: u32;
174 const MANTISSA_MASK: Self::Binary;
176 const BINARY_ONE: Self::Binary;
178
179 const TY: AnyArrayDType;
181
182 type Binary: Copy
184 + std::ops::Not<Output = Self::Binary>
185 + std::ops::Shr<u32, Output = Self::Binary>
186 + std::ops::Add<Self::Binary, Output = Self::Binary>
187 + std::ops::AddAssign<Self::Binary>
188 + std::ops::BitAnd<Self::Binary, Output = Self::Binary>
189 + std::ops::BitAndAssign<Self::Binary>;
190
191 fn to_binary(self) -> Self::Binary;
193 fn from_binary(u: Self::Binary) -> Self;
195}
196
197impl Float for f32 {
198 type Binary = u32;
199
200 const BINARY_ONE: Self::Binary = 1;
201 const MANITSSA_BITS: u32 = Self::MANTISSA_DIGITS - 1;
202 const MANTISSA_MASK: Self::Binary = (1 << Self::MANITSSA_BITS) - 1;
203 const TY: AnyArrayDType = AnyArrayDType::F32;
204
205 fn to_binary(self) -> Self::Binary {
206 self.to_bits()
207 }
208
209 fn from_binary(u: Self::Binary) -> Self {
210 Self::from_bits(u)
211 }
212}
213
214impl Float for f64 {
215 type Binary = u64;
216
217 const BINARY_ONE: Self::Binary = 1;
218 const MANITSSA_BITS: u32 = Self::MANTISSA_DIGITS - 1;
219 const MANTISSA_MASK: Self::Binary = (1 << Self::MANITSSA_BITS) - 1;
220 const TY: AnyArrayDType = AnyArrayDType::F64;
221
222 fn to_binary(self) -> Self::Binary {
223 self.to_bits()
224 }
225
226 fn from_binary(u: Self::Binary) -> Self {
227 Self::from_bits(u)
228 }
229}
230
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use ndarray::{Array1, ArrayView1};

    use super::*;

    /// `(input, expected)` pairs for rounding with zero kept mantissa bits,
    /// i.e. to a power of two with ties rounded to even.
    const NO_MANTISSA_CASES: [(f32, f32); 12] = [
        (0.0, 0.0),
        (1.0, 1.0),
        (1.5, 2.0),
        (2.0, 2.0),
        (2.5, 2.0),
        (3.0, 2.0),
        (3.5, 4.0),
        (4.0, 4.0),
        (5.0, 4.0),
        (6.0, 8.0),
        (7.0, 8.0),
        (8.0, 8.0),
    ];

    #[test]
    fn no_mantissa() {
        for (input, expected) in NO_MANTISSA_CASES {
            assert_eq!(
                bit_round(ArrayView1::from(&[input]), 0).unwrap(),
                Array1::from_vec(vec![expected])
            );
        }

        // every f32 case is exactly representable as f64, so reuse the table
        for (input, expected) in NO_MANTISSA_CASES {
            let (input, expected) = (f64::from(input), f64::from(expected));
            assert_eq!(
                bit_round(ArrayView1::from(&[input]), 0).unwrap(),
                Array1::from_vec(vec![expected])
            );
        }
    }

    #[test]
    #[expect(clippy::cast_possible_truncation)]
    fn full_mantissa() {
        // adds the mantissa mask to the raw bits so that the low mantissa bits
        // are set, which exercises the keep-everything fast path
        fn with_full_mantissa<T: Float>(x: T) -> T {
            T::from_binary(T::to_binary(x) + T::MANTISSA_MASK)
        }

        let keep_all_f32 = f32::MANITSSA_BITS as u8;
        for value in [0.0_f32, 1.0, 2.0, 3.0, 4.0].map(with_full_mantissa) {
            assert_eq!(
                bit_round(ArrayView1::from(&[value]), keep_all_f32).unwrap(),
                Array1::from_vec(vec![value])
            );
        }

        let keep_all_f64 = f64::MANITSSA_BITS as u8;
        for value in [0.0_f64, 1.0, 2.0, 3.0, 4.0].map(with_full_mantissa) {
            assert_eq!(
                bit_round(ArrayView1::from(&[value]), keep_all_f64).unwrap(),
                Array1::from_vec(vec![value])
            );
        }
    }
}