use ndarray::{Array, ArrayBase, Data, Dimension};
use numcodecs::{
AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
Codec, StaticCodec, StaticCodecConfig,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Codec that rounds floating point data to a fixed number of mantissa bits.
///
/// Encoding applies [`bit_round`] to `f32` and `f64` arrays. Decoding returns
/// the encoded values unchanged, i.e. the rounding is not undone.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct BitRoundCodec {
/// Number of mantissa bits to keep. Encoding fails if this exceeds the
/// number of mantissa bits of the array's floating point type.
pub keepbits: u8,
}
impl Codec for BitRoundCodec {
    type Error = BitRoundCodecError;

    /// Rounds every element of an `F32` or `F64` array to `self.keepbits`
    /// mantissa bits using [`bit_round`].
    ///
    /// # Errors
    ///
    /// Returns [`BitRoundCodecError::UnsupportedDtype`] for any other dtype
    /// and forwards any error reported by [`bit_round`].
    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let rounded = match data {
            AnyCowArray::F32(values) => AnyArray::F32(bit_round(values, self.keepbits)?),
            AnyCowArray::F64(values) => AnyArray::F64(bit_round(values, self.keepbits)?),
            unsupported => {
                return Err(BitRoundCodecError::UnsupportedDtype(unsupported.dtype()));
            }
        };

        Ok(rounded)
    }

    /// Returns the encoded `F32` or `F64` values as an owned array without
    /// modifying them.
    ///
    /// # Errors
    ///
    /// Returns [`BitRoundCodecError::UnsupportedDtype`] for any other dtype.
    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let decoded = match encoded {
            AnyCowArray::F32(values) => AnyArray::F32(values.into_owned()),
            AnyCowArray::F64(values) => AnyArray::F64(values.into_owned()),
            unsupported => {
                return Err(BitRoundCodecError::UnsupportedDtype(unsupported.dtype()));
            }
        };

        Ok(decoded)
    }

    /// Assigns the encoded `F32` or `F64` values into `decoded`.
    ///
    /// # Errors
    ///
    /// Returns [`BitRoundCodecError::UnsupportedDtype`] if `encoded` has any
    /// other dtype (checked before anything is written), and
    /// [`BitRoundCodecError::MismatchedDecodeIntoArray`] if the assignment
    /// into `decoded` fails.
    fn decode_into(
        &self,
        encoded: AnyArrayView,
        mut decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        match encoded.dtype() {
            AnyArrayDType::F32 | AnyArrayDType::F64 => {
                decoded.assign(&encoded)?;
                Ok(())
            }
            unsupported => Err(BitRoundCodecError::UnsupportedDtype(unsupported)),
        }
    }
}
impl StaticCodec for BitRoundCodec {
/// Identifier under which this codec is registered.
const CODEC_ID: &'static str = "bit-round";
/// The codec is its own configuration type.
type Config<'de> = Self;
/// Builds the codec from its configuration; since the configuration type is
/// the codec itself, the value is returned unchanged.
fn from_config(config: Self::Config<'_>) -> Self {
config
}
/// Returns the codec's configuration, constructed from a reference to `self`.
fn get_config(&self) -> StaticCodecConfig<Self> {
StaticCodecConfig::from(self)
}
}
/// Errors that can occur when using the [`BitRoundCodec`].
#[derive(Debug, Error)]
pub enum BitRoundCodecError {
/// The array has a dtype the codec does not handle; the codec only accepts
/// `F32` and `F64` arrays.
#[error("BitRound does not support the dtype {0}")]
UnsupportedDtype(AnyArrayDType),
/// The requested `keepbits` is larger than the number of mantissa bits of
/// the floating point type `dtype`.
#[error("BitRound encode {keepbits} bits exceed the mantissa size for {dtype}")]
ExcessiveKeepBits {
/// The number of mantissa bits that was requested to be kept.
keepbits: u8,
/// The dtype of the data that was being rounded.
dtype: AnyArrayDType,
},
/// Assigning the encoded array into the caller-provided output array
/// failed.
#[error("BitRound cannot decode into the provided array")]
MismatchedDecodeIntoArray {
/// The underlying assignment error; `#[from]` lets `?` convert it.
#[from]
source: AnyArrayAssignError,
},
}
/// Rounds every element of `data` to `keepbits` mantissa bits and returns the
/// result as an owned array of the same shape.
///
/// Rounding is done on the raw bit pattern: half a unit in the last kept
/// place is added (with ties going to the neighbour whose last kept bit is
/// even) and the remaining trailing mantissa bits are cleared. For example,
/// with `keepbits = 0`, `1.5` rounds to `2.0` and `3.0` rounds to `2.0`.
///
/// If `keepbits` equals the full mantissa width, the data is returned
/// unchanged.
///
/// # Errors
///
/// Returns [`BitRoundCodecError::ExcessiveKeepBits`] if `keepbits` is larger
/// than `T::MANITSSA_BITS`.
///
/// NOTE(review): the rounding uses a plain unsigned integer addition, which
/// can overflow for bit patterns close to all-ones (e.g. the `f32` pattern
/// `0xFFC0_0000`, a negative quiet NaN, with `keepbits = 0`). That panics in
/// builds with overflow checks and wraps otherwise; confirm whether NaN
/// inputs need dedicated handling.
pub fn bit_round<T: Float, S: Data<Elem = T>, D: Dimension>(
    data: ArrayBase<S, D>,
    keepbits: u8,
) -> Result<Array<T, D>, BitRoundCodecError> {
    let kept_bits = u32::from(keepbits);

    if kept_bits > T::MANITSSA_BITS {
        return Err(BitRoundCodecError::ExcessiveKeepBits {
            keepbits,
            dtype: T::TY,
        });
    }

    let mut rounded = data.into_owned();

    // Keeping the full mantissa means there is nothing to round; this case
    // must skip the bit arithmetic below, which needs at least one dropped
    // bit for `kept_bits + 1` to stay a valid shift of the mantissa mask.
    if kept_bits < T::MANITSSA_BITS {
        // number of trailing mantissa bits that are discarded
        let dropped_bits = T::MANITSSA_BITS - kept_bits;
        // all discarded bits except the topmost one, i.e. just below half of
        // one unit in the last kept place
        let half_ulp = T::MANTISSA_MASK >> (kept_bits + 1);
        // mask that clears exactly the discarded trailing bits
        let retained_mask = !(T::MANTISSA_MASK >> kept_bits);

        rounded.mapv_inplace(|value| {
            let raw = T::to_binary(value);
            // the lowest retained bit decides which way an exact tie goes
            let tie_to_even = (raw >> dropped_bits) & T::BINARY_ONE;

            T::from_binary((raw + (half_ulp + tie_to_even)) & retained_mask)
        });
    }

    Ok(rounded)
}
/// Floating point types whose bit patterns [`bit_round`] can operate on.
///
/// Implemented in this file for `f32` and `f64`.
pub trait Float: Sized + Copy {
/// Number of mantissa bits of the type (`MANTISSA_DIGITS - 1` in the `f32`
/// and `f64` implementations).
///
/// NOTE(review): the name looks like a misspelling of `MANTISSA_BITS`; it is
/// part of the public trait, so renaming it would change the interface.
const MANITSSA_BITS: u32;
/// Bit mask with the lowest [`Self::MANITSSA_BITS`] bits set (as defined by
/// the `f32` and `f64` implementations).
const MANTISSA_MASK: Self::Binary;
/// The value one in the binary representation type.
const BINARY_ONE: Self::Binary;
/// The array dtype that corresponds to this floating point type.
const TY: AnyArrayDType;
/// Integer type holding the raw bit pattern of the float (`u32` for `f32`,
/// `u64` for `f64`), with the bitwise and arithmetic operators that
/// [`bit_round`] needs.
type Binary: Copy
+ std::ops::Not<Output = Self::Binary>
+ std::ops::Shr<u32, Output = Self::Binary>
+ std::ops::Add<Self::Binary, Output = Self::Binary>
+ std::ops::AddAssign<Self::Binary>
+ std::ops::BitAnd<Self::Binary, Output = Self::Binary>
+ std::ops::BitAndAssign<Self::Binary>;
/// Returns the raw bit pattern of the float.
fn to_binary(self) -> Self::Binary;
/// Reinterprets a raw bit pattern as a float.
fn from_binary(u: Self::Binary) -> Self;
}
/// Single precision: bit patterns are `u32` values with 23 stored mantissa
/// bits.
impl Float for f32 {
    type Binary = u32;

    const TY: AnyArrayDType = AnyArrayDType::F32;
    // `MANTISSA_DIGITS` counts the implicit leading bit, which is not stored.
    const MANITSSA_BITS: u32 = f32::MANTISSA_DIGITS - 1;
    // Shifting the all-ones pattern right leaves only the low mantissa bits set.
    const MANTISSA_MASK: u32 = u32::MAX >> (u32::BITS - Self::MANITSSA_BITS);
    const BINARY_ONE: u32 = 1;

    /// Returns the raw IEEE 754 bit pattern of the value.
    fn to_binary(self) -> u32 {
        f32::to_bits(self)
    }

    /// Reinterprets a raw bit pattern as an `f32`.
    fn from_binary(u: u32) -> Self {
        f32::from_bits(u)
    }
}
/// Double precision: bit patterns are `u64` values with 52 stored mantissa
/// bits.
impl Float for f64 {
    type Binary = u64;

    const TY: AnyArrayDType = AnyArrayDType::F64;
    // `MANTISSA_DIGITS` counts the implicit leading bit, which is not stored.
    const MANITSSA_BITS: u32 = f64::MANTISSA_DIGITS - 1;
    // Shifting the all-ones pattern right leaves only the low mantissa bits set.
    const MANTISSA_MASK: u64 = u64::MAX >> (u64::BITS - Self::MANITSSA_BITS);
    const BINARY_ONE: u64 = 1;

    /// Returns the raw IEEE 754 bit pattern of the value.
    fn to_binary(self) -> u64 {
        f64::to_bits(self)
    }

    /// Reinterprets a raw bit pattern as an `f64`.
    fn from_binary(u: u64) -> Self {
        f64::from_bits(u)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use ndarray::{Array1, ArrayView1};

    use super::*;

    /// With no mantissa bits kept, every value is rounded to a power of two
    /// (or zero); exact ties are decided by the parity of the exponent field.
    #[test]
    fn no_mantissa() {
        // (input, expected) pairs, checked for single precision ...
        let single_cases = [
            (0.0_f32, 0.0_f32),
            (1.0_f32, 1.0_f32),
            // tie: rounds up
            (1.5_f32, 2.0_f32),
            (2.0_f32, 2.0_f32),
            (2.5_f32, 2.0_f32),
            // tie: rounds down
            (3.0_f32, 2.0_f32),
            (3.5_f32, 4.0_f32),
            (4.0_f32, 4.0_f32),
            (5.0_f32, 4.0_f32),
            // tie: rounds up
            (6.0_f32, 8.0_f32),
            (7.0_f32, 8.0_f32),
            (8.0_f32, 8.0_f32),
        ];
        for (input, expected) in single_cases {
            let rounded = bit_round(ArrayView1::from(&[input]), 0).unwrap();
            assert_eq!(rounded, Array1::from_vec(vec![expected]));
        }

        // ... and the same pairs for double precision
        let double_cases = [
            (0.0_f64, 0.0_f64),
            (1.0_f64, 1.0_f64),
            // tie: rounds up
            (1.5_f64, 2.0_f64),
            (2.0_f64, 2.0_f64),
            (2.5_f64, 2.0_f64),
            // tie: rounds down
            (3.0_f64, 2.0_f64),
            (3.5_f64, 4.0_f64),
            (4.0_f64, 4.0_f64),
            (5.0_f64, 4.0_f64),
            // tie: rounds up
            (6.0_f64, 8.0_f64),
            (7.0_f64, 8.0_f64),
            (8.0_f64, 8.0_f64),
        ];
        for (input, expected) in double_cases {
            let rounded = bit_round(ArrayView1::from(&[input]), 0).unwrap();
            assert_eq!(rounded, Array1::from_vec(vec![expected]));
        }
    }

    /// Keeping the full mantissa must leave values untouched, even when every
    /// mantissa bit is set.
    #[test]
    #[allow(clippy::cast_possible_truncation)]
    fn full_mantissa() {
        /// Adds the full mantissa mask to the bit pattern of `base`; for the
        /// inputs used below (whose mantissa bits are all zero) this sets
        /// every mantissa bit.
        fn saturate_mantissa<T: Float>(base: T) -> T {
            let raw = T::to_binary(base);
            T::from_binary(raw + T::MANTISSA_MASK)
        }

        let single_keepbits = f32::MANITSSA_BITS as u8;
        for base in [0.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32] {
            let input = saturate_mantissa(base);
            let rounded = bit_round(ArrayView1::from(&[input]), single_keepbits).unwrap();
            assert_eq!(rounded, Array1::from_vec(vec![input]));
        }

        let double_keepbits = f64::MANITSSA_BITS as u8;
        for base in [0.0_f64, 1.0_f64, 2.0_f64, 3.0_f64, 4.0_f64] {
            let input = saturate_mantissa(base);
            let rounded = bit_round(ArrayView1::from(&[input]), double_keepbits).unwrap();
            assert_eq!(rounded, Array1::from_vec(vec![input]));
        }
    }
}