#![allow(clippy::multiple_crate_versions)] use std::{borrow::Cow, fmt};
use ndarray::{Array, Array1, ArrayBase, ArrayD, ArrayViewMutD, Data, Dimension, ShapeError, Zip};
use num_traits::{ConstOne, ConstZero, Float};
use numcodecs::{
AnyArray, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, StaticCodec,
use schemars::{JsonSchema, JsonSchema_repr};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use thiserror::Error;
use twofloat::TwoFloat;
/// Codec that linearly quantizes floating point data to unsigned integers.
///
/// The struct doubles as the codec's configuration: it derives serde
/// (de)serialization and a JSON schema.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
pub struct LinearQuantizeCodec {
/// Floating point dtype the codec is configured for; `encode` and
/// `decode_into` reject floating point arrays of any other dtype.
pub dtype: LinearQuantizeDType,
/// Number of bits per quantized value; `encode` uses it to select the
/// unsigned integer type (`u8`, `u16`, `u32` or `u64`) of the encoded array.
pub bits: LinearQuantizeBins,
/// Floating point dtypes that the linear-quantize codec can be configured for.
///
/// The serde attributes below name the variants `"f32"` and `"f64"` and also
/// accept the aliases `"float32"` and `"float64"`; the schema extension lists
/// all four spellings.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("enum" = ["f32", "float32", "f64", "float64"]))]
pub enum LinearQuantizeDType {
// Single-precision variant: serialized as "f32", "float32" is accepted too.
#[serde(rename = "f32", alias = "float32")]
// Double-precision variant: serialized as "f64", "float64" is accepted too.
#[serde(rename = "f64", alias = "float64")]
// Displays the dtype using its short name.
impl fmt::Display for LinearQuantizeDType {
/// Writes `f32` or `f64` to the formatter, depending on the variant.
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
// Map each variant to its static name and emit it with a single write.
fmt.write_str(match self {
Self::F32 => "f32",
Self::F64 => "f64",
/// Number of bits used for each quantized value, from 1 to 64.
///
/// The first variant has the explicit discriminant 1 and every following
/// variant increments it by one, so `_1B<n>` has the numeric value `n`.
/// The `*_repr` derives (de)serialize the enum as that number.
#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, JsonSchema_repr)]
pub enum LinearQuantizeBins {
// 1 to 8 bits: `encode` stores these in a `u8` array.
_1B1 = 1, _1B2, _1B3, _1B4, _1B5, _1B6, _1B7, _1B8,
// 9 to 16 bits: `encode` stores these in a `u16` array.
_1B9, _1B10, _1B11, _1B12, _1B13, _1B14, _1B15, _1B16,
// 17 to 32 bits: `encode` stores these in a `u32` array.
_1B17, _1B18, _1B19, _1B20, _1B21, _1B22, _1B23, _1B24,
_1B25, _1B26, _1B27, _1B28, _1B29, _1B30, _1B31, _1B32,
// 33 to 64 bits: `encode` stores these in a `u64` array.
_1B33, _1B34, _1B35, _1B36, _1B37, _1B38, _1B39, _1B40,
_1B41, _1B42, _1B43, _1B44, _1B45, _1B46, _1B47, _1B48,
_1B49, _1B50, _1B51, _1B52, _1B53, _1B54, _1B55, _1B56,
_1B57, _1B58, _1B59, _1B60, _1B61, _1B62, _1B63, _1B64,
impl Codec for LinearQuantizeCodec {
type Error = LinearQuantizeCodecError;
/// Quantizes `data` into a one-dimensional unsigned integer array.
///
/// The element type of the result follows the configured bit width: up to
/// 8 bits produce `u8`, 9 to 16 bits `u16`, 17 to 32 bits `u32`, and 33 or
/// more bits `u64`. The per-element closures below receive values that
/// `quantize` has already normalised with the data's minimum and maximum.
///
/// # Errors
///
/// Returns [`LinearQuantizeCodecError::MismatchedEncodeDType`] if the dtype
/// of `data` does not match the configured dtype.
fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
// Dispatch on the (input dtype, configured dtype) pair; only matching
// pairs are encodable.
let encoded = match (&data, self.dtype) {
// f32 input: pick the narrowest unsigned type that holds `bits` bits.
(AnyCowArray::F32(data), LinearQuantizeDType::F32) => match self.bits as u8 {
bits @ ..=8 => AnyArray::U8(
Array1::from_vec(quantize(data, |x| {
// Largest value representable in `bits` bits.
let max = f32::from(u8::MAX >> (8 - bits));
// Scale by 2^bits - 1, add 0.5, and clamp to [0, max].
// NOTE(review): the +0.5 presumably makes the (not visible)
// float-to-int conversion round to nearest — confirm.
let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
unsafe {
bits @ 9..=16 => AnyArray::U16(
Array1::from_vec(quantize(data, |x| {
// Largest value representable in `bits` bits.
let max = f32::from(u16::MAX >> (16 - bits));
let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
unsafe {
bits @ 17..=32 => AnyArray::U32(
Array1::from_vec(quantize(data, |x| {
// From 17 bits on, the f32 input is widened and the arithmetic is
// carried out in f64.
let max = f64::from(u32::MAX >> (32 - bits));
let x = f64::from(x)
.mul_add(scale_for_bits::<f64>(bits), 0.5)
.clamp(0.0, max);
unsafe {
bits @ 33.. => AnyArray::U64(
Array1::from_vec(quantize(data, |x| {
// From 33 bits on, the scaling is done in `TwoFloat` arithmetic.
let max = TwoFloat::from(u64::MAX >> (64 - bits));
let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
+ TwoFloat::from(0.5))
unsafe {
// f64 input: same structure as above, with f64 arithmetic up to 32 bits.
(AnyCowArray::F64(data), LinearQuantizeDType::F64) => match self.bits as u8 {
bits @ ..=8 => AnyArray::U8(
Array1::from_vec(quantize(data, |x| {
let max = f64::from(u8::MAX >> (8 - bits));
let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
unsafe {
bits @ 9..=16 => AnyArray::U16(
Array1::from_vec(quantize(data, |x| {
let max = f64::from(u16::MAX >> (16 - bits));
let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
unsafe {
bits @ 17..=32 => AnyArray::U32(
Array1::from_vec(quantize(data, |x| {
let max = f64::from(u32::MAX >> (32 - bits));
let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
unsafe {
bits @ 33.. => AnyArray::U64(
Array1::from_vec(quantize(data, |x| {
// From 33 bits on, the scaling is done in `TwoFloat` arithmetic.
let max = TwoFloat::from(u64::MAX >> (64 - bits));
let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
+ TwoFloat::from(0.5))
unsafe {
// Any other combination means the input dtype differs from the
// configured one.
(data, dtype) => {
return Err(LinearQuantizeCodecError::MismatchedEncodeDType {
configured: dtype,
provided: data.dtype(),
/// Reconstructs a floating point array of the configured dtype from the
/// quantized `encoded` array.
///
/// # Errors
///
/// Returns [`LinearQuantizeCodecError::EncodedDataNotOneDimensional`] if
/// `encoded` is not one-dimensional, and
/// [`LinearQuantizeCodecError::InvalidEncodedDType`] if its dtype is not one
/// of `u8`, `u16`, `u32` or `u64`.
fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
// Helper that exposes the array's elements as a slice.
fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
array: &ArrayBase<S, D>,
) -> Cow<[T]> {
// `as_slice` succeeds only for contiguous, standard-layout arrays.
if let Some(data) = array.as_slice() {
// NOTE(review): the fallback body is not visible here; given the `Cow`
// return type it presumably produces an owned copy — confirm.
} else {
// The encoded stream is a flat sequence, so only 1-D input is accepted.
if !matches!(encoded.shape(), [_]) {
return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
shape: encoded.shape().to_vec(),
// Dispatch on (encoded integer type, configured float type). Each closure
// divides the integer by 2^bits - 1 before `reconstruct` maps it back
// using the header's minimum and maximum.
let decoded = match (&encoded, self.dtype) {
(AnyCowArray::U8(encoded), LinearQuantizeDType::F32) => {
AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
(AnyCowArray::U16(encoded), LinearQuantizeDType::F32) => {
AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
(AnyCowArray::U32(encoded), LinearQuantizeDType::F32) => {
AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
// u32 values are divided in f64 and only then narrowed to f32.
let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
let x = x as f32;
(AnyCowArray::U64(encoded), LinearQuantizeDType::F32) => {
AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
// u64 values are divided in `TwoFloat` arithmetic.
let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
(AnyCowArray::U8(encoded), LinearQuantizeDType::F64) => {
AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
(AnyCowArray::U16(encoded), LinearQuantizeDType::F64) => {
AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
(AnyCowArray::U32(encoded), LinearQuantizeDType::F64) => {
AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
(AnyCowArray::U64(encoded), LinearQuantizeDType::F64) => {
AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
// u64 values are divided in `TwoFloat` arithmetic.
let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
// Any other encoded dtype cannot have been produced by this codec.
(encoded, _dtype) => {
return Err(LinearQuantizeCodecError::InvalidEncodedDType {
dtype: encoded.dtype(),
/// Reconstructs the quantized `encoded` array directly into the caller's
/// `decoded` output view.
///
/// # Errors
///
/// Returns [`LinearQuantizeCodecError::EncodedDataNotOneDimensional`] if
/// `encoded` is not one-dimensional,
/// [`LinearQuantizeCodecError::InvalidEncodedDType`] if its dtype is not one
/// of `u8`, `u16`, `u32` or `u64`, and
/// [`LinearQuantizeCodecError::MismatchedDecodeIntoDtype`] if the dtype of
/// `decoded` does not match the configured dtype.
fn decode_into(
encoded: AnyArrayView,
decoded: AnyArrayViewMut,
) -> Result<(), Self::Error> {
// Helper that exposes the array's elements as a slice.
fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
array: &ArrayBase<S, D>,
) -> Cow<[T]> {
// `as_slice` succeeds only for contiguous, standard-layout arrays.
if let Some(data) = array.as_slice() {
// NOTE(review): the fallback body is not visible here; given the `Cow`
// return type it presumably produces an owned copy — confirm.
} else {
// The encoded stream is a flat sequence, so only 1-D input is accepted.
if !matches!(encoded.shape(), [_]) {
return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
shape: encoded.shape().to_vec(),
// First match the output dtype against the configuration, then dispatch
// on the encoded integer type.
match (decoded, self.dtype) {
(AnyArrayViewMut::F32(decoded), LinearQuantizeDType::F32) => {
match &encoded {
AnyArrayView::U8(encoded) => {
reconstruct_into(&as_standard_order(encoded), decoded, |x| {
f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
AnyArrayView::U16(encoded) => {
reconstruct_into(&as_standard_order(encoded), decoded, |x| {
f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
AnyArrayView::U32(encoded) => {
reconstruct_into(&as_standard_order(encoded), decoded, |x| {
// u32 values are divided in f64 and only then narrowed to f32.
let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
let x = x as f32;
AnyArrayView::U64(encoded) => {
reconstruct_into(&as_standard_order(encoded), decoded, |x| {
// u64 values are divided in `TwoFloat` arithmetic.
let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
// Any other encoded dtype cannot have been produced by this codec.
encoded => {
return Err(LinearQuantizeCodecError::InvalidEncodedDType {
dtype: encoded.dtype(),
(AnyArrayViewMut::F64(decoded), LinearQuantizeDType::F64) => {
match &encoded {
AnyArrayView::U8(encoded) => {
reconstruct_into(&as_standard_order(encoded), decoded, |x| {
f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
AnyArrayView::U16(encoded) => {
reconstruct_into(&as_standard_order(encoded), decoded, |x| {
f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
AnyArrayView::U32(encoded) => {
reconstruct_into(&as_standard_order(encoded), decoded, |x| {
f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
AnyArrayView::U64(encoded) => {
reconstruct_into(&as_standard_order(encoded), decoded, |x| {
// u64 values are divided in `TwoFloat` arithmetic.
let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
// Any other encoded dtype cannot have been produced by this codec.
encoded => {
return Err(LinearQuantizeCodecError::InvalidEncodedDType {
dtype: encoded.dtype(),
// The output view's dtype differs from the configured dtype.
(decoded, dtype) => {
return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
configured: dtype,
provided: decoded.dtype(),
// Static codec metadata: a fixed identifier, with the codec struct itself
// serving as the configuration type.
impl StaticCodec for LinearQuantizeCodec {
// Identifier string of this codec.
const CODEC_ID: &'static str = "linear-quantize";
// The codec is its own configuration.
type Config<'de> = Self;
// Builds the codec from its configuration (body not visible in this excerpt).
fn from_config(config: Self::Config<'_>) -> Self {
// Returns the codec's configuration (body not visible in this excerpt).
fn get_config(&self) -> StaticCodecConfig<Self> {
/// Errors that can occur while encoding or decoding with the
/// linear-quantize codec.
#[derive(Debug, Error)]
pub enum LinearQuantizeCodecError {
// `encode` was given an array whose dtype is not the configured dtype.
#[error("LinearQuantize cannot encode the provided dtype {provided} which differs from the configured dtype {configured}")]
MismatchedEncodeDType {
configured: LinearQuantizeDType,
provided: AnyArrayDType,
// The input contained an infinite or NaN value; `quantize` raises this as
// `NonFiniteData`.
#[error("LinearQuantize does not support non-finite (infinite or NaN) floating point data")]
// Serializing the header in `quantize` failed.
#[error("LinearQuantize failed to encode the header")]
HeaderEncodeFailed {
source: LinearQuantizeHeaderError,
// The encoded array handed to `decode` / `decode_into` was not 1-D.
#[error("LinearQuantize can only decode one-dimensional arrays but received an array of shape {shape:?}")]
EncodedDataNotOneDimensional {
shape: Vec<usize>,
// Deserializing the header in `reconstruct` / `reconstruct_into` failed.
#[error("LinearQuantize failed to decode the header")]
HeaderDecodeFailed {
source: LinearQuantizeHeaderError,
// The shape stored in the header does not fit the decoded data.
"LinearQuantize decoded an invalid array shape header which does not fit the decoded data"
DecodeInvalidShapeHeader {
source: ShapeError,
// The encoded array's dtype is not one of the supported unsigned integers.
#[error("LinearQuantize cannot decode the provided dtype {dtype}")]
InvalidEncodedDType {
dtype: AnyArrayDType,
// The output array handed to `decode_into` has a dtype other than the
// configured dtype.
#[error("LinearQuantize cannot decode the provided dtype {provided} which differs from the configured dtype {configured}")]
MismatchedDecodeIntoDtype {
configured: LinearQuantizeDType,
provided: AnyArrayDType,
// The output array handed to `reconstruct_into` has a shape other than the
// one stored in the header.
#[error("LinearQuantize cannot decode the decoded array of shape {decoded:?} into the provided array of shape {provided:?}")]
MismatchedDecodeIntoShape {
decoded: Vec<usize>,
provided: Vec<usize>,
/// Opaque error wrapping the `postcard::Error` raised while encoding or
/// decoding the header; the inner error is a private field.
#[derive(Debug, Error)]
pub struct LinearQuantizeHeaderError(postcard::Error);
/// Serializes a header and linearly quantizes `data` into a flat vector of
/// `Q` values.
///
/// The output begins with a postcard-encoded [`CompressionHeader`] stored in
/// as many `Q` elements as are needed to hold its bytes, followed by one
/// quantized value per input element. Each element is normalised as
/// `(x - minimum) / (maximum - minimum)` before being passed to the
/// `quantize` closure, which turns it into a `Q`.
///
/// # Errors
///
/// Returns [`LinearQuantizeCodecError::NonFiniteData`] if any element is
/// infinite or NaN, and [`LinearQuantizeCodecError::HeaderEncodeFailed`] if
/// the header cannot be serialized.
pub fn quantize<
T: Float + ConstZero + ConstOne + Serialize,
Q: Unsigned,
S: Data<Elem = T>,
D: Dimension,
data: &ArrayBase<S, D>,
quantize: impl Fn(T) -> Q,
) -> Result<Vec<Q>, LinearQuantizeCodecError> {
// Reject infinities and NaNs up front.
if !Zip::from(data).all(|x| x.is_finite()) {
return Err(LinearQuantizeCodecError::NonFiniteData);
// Find the data range; an empty array falls back to the range (0, 1).
let (minimum, maximum) = data.first().map_or((T::ZERO, T::ONE), |first| {
Zip::from(data).fold(*first, |a, b| a.min(*b)),
Zip::from(data).fold(*first, |a, b| a.max(*b)),
// Serialize the header, which records the array shape and data range.
let header = postcard::to_extend(
&CompressionHeader {
shape: Cow::Borrowed(data.shape()),
.map_err(|err| LinearQuantizeCodecError::HeaderEncodeFailed {
source: LinearQuantizeHeaderError(err),
// Reserve zeroed `Q` elements for the header, rounding the byte length up
// to a whole number of elements.
let mut encoded: Vec<Q> = vec![Q::ZERO; header.len().div_ceil(std::mem::size_of::<Q>())];
// SAFETY: `encoded` was just allocated with at least `header.len()` bytes
// of storage, and `header` and `encoded` are distinct allocations, so the
// byte ranges are valid and do not overlap.
unsafe {
std::ptr::copy_nonoverlapping(header.as_ptr(), encoded.as_mut_ptr().cast(), header.len());
// Constant data: the normalisation below would divide by zero, so every
// element is stored as the quantized form of zero instead.
if maximum == minimum {
encoded.resize(encoded.len() + data.len(), quantize(T::ZERO));
} else {
// Normalise each element with the data range, then quantize it.
.map(|x| quantize((*x - minimum) / (maximum - minimum))),
/// Decodes a header-prefixed sequence of quantized values back into a
/// floating point array.
///
/// The header is read from the front of `encoded` (viewed as bytes); the
/// remaining elements are converted with `floatify` and mapped back with
/// the header's `minimum` and `maximum`. The result takes the shape stored
/// in the header.
///
/// # Errors
///
/// Returns [`LinearQuantizeCodecError::HeaderDecodeFailed`] if the header
/// cannot be deserialized. A shape that does not fit the decoded data is
/// propagated with `?` from `Array::from_shape_vec`.
pub fn reconstruct<T: Float + DeserializeOwned, Q: Unsigned>(
encoded: &[Q],
floatify: impl Fn(Q) -> T,
) -> Result<ArrayD<T>, LinearQuantizeCodecError> {
// View the `Q` slice as raw bytes and parse the header from its front;
// `remaining` holds the bytes that follow the header.
let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
.map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
source: LinearQuantizeHeaderError(err),
// Keep only the trailing payload elements. The integer division discards
// any leftover header bytes that do not fill a whole `Q` element.
let encoded = encoded
.get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
let decoded = encoded
// Map each value back with the header's range.
.map(|x| header.minimum + (floatify(*x) * (header.maximum - header.minimum)))
// Keep the result inside the original data range.
.map(|x| x.clamp(header.minimum, header.maximum))
// Restore the original shape recorded in the header.
let decoded = Array::from_shape_vec(&*header.shape, decoded)?;
/// Decodes a header-prefixed sequence of quantized values into the
/// caller-provided `decoded` view.
///
/// The header is read from the front of `encoded` (viewed as bytes); the
/// remaining elements are converted with `floatify`, mapped back with the
/// header's `minimum` and `maximum`, and written into `decoded`.
///
/// # Errors
///
/// Returns [`LinearQuantizeCodecError::HeaderDecodeFailed`] if the header
/// cannot be deserialized, and
/// [`LinearQuantizeCodecError::MismatchedDecodeIntoShape`] if the shape of
/// `decoded` differs from the shape stored in the header.
pub fn reconstruct_into<T: Float + DeserializeOwned, Q: Unsigned>(
encoded: &[Q],
mut decoded: ArrayViewMutD<T>,
floatify: impl Fn(Q) -> T,
) -> Result<(), LinearQuantizeCodecError> {
// View the `Q` slice as raw bytes and parse the header from its front;
// `remaining` holds the bytes that follow the header.
let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
.map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
source: LinearQuantizeHeaderError(err),
// Keep only the trailing payload elements. The integer division discards
// any leftover header bytes that do not fill a whole `Q` element.
let encoded = encoded
.get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
// The output view must have exactly the shape recorded in the header.
if decoded.shape() != &*header.shape {
return Err(LinearQuantizeCodecError::MismatchedDecodeIntoShape {
decoded: header.shape.into_owned(),
provided: decoded.shape().to_vec(),
// Write the values pairwise, in the iteration order of `decoded`.
for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
// Map back with the header's range and keep the result inside it.
*d = (header.minimum + (floatify(*e) * (header.maximum - header.minimum)))
.clamp(header.minimum, header.maximum);
/// Returns `2^bits - 1` as a floating point value of type `T`, computed as
/// `exp2(bits) - 1`.
///
/// `encode` multiplies by this factor and the decode paths divide by it.
fn scale_for_bits<T: Float + From<u8> + ConstOne>(bits: u8) -> T {
<T as From<u8>>::from(bits).exp2() - T::ONE
/// Integer types that can hold quantized values; implemented in this file
/// for `u8`, `u16`, `u32` and `u64`.
pub trait Unsigned: Copy {
/// The zero value of the type; `quantize` uses it to zero-initialise the
/// header storage.
const ZERO: Self;
// `Unsigned` for the four unsigned integer widths that `encode` can produce;
// each one supplies the literal `0` as its zero value.
impl Unsigned for u8 {
const ZERO: Self = 0;
impl Unsigned for u16 {
const ZERO: Self = 0;
impl Unsigned for u32 {
const ZERO: Self = 0;
impl Unsigned for u64 {
const ZERO: Self = 0;
/// Header that `quantize` serializes in front of the quantized values and
/// that `reconstruct` / `reconstruct_into` read back.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a, T> {
/// Shape of the original array; borrowed from the input when encoding.
shape: Cow<'a, [usize]>,
/// Lower bound of the data range used when mapping values back.
minimum: T,
/// Upper bound of the data range used when mapping values back.
maximum: T,
mod tests {
use ndarray::CowArray;
use super::*;
// Round-trips integer-valued f32 data through the codec for every bit width
// from 1 to 16 and checks that decoding reproduces the input bit-for-bit.
fn exact_roundtrip_f32_from() -> Result<(), LinearQuantizeCodecError> {
for bits in 1..=16 {
let codec = LinearQuantizeCodec {
dtype: LinearQuantizeDType::F32,
// `bits` is in 1..=16, which are valid `LinearQuantizeBins` discriminants.
// NOTE(review): assumes the enum is `#[repr(u8)]`; the repr attribute is
// not visible in this excerpt — confirm.
bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
// Sample the integers below 2^bits - 1; above 8 bits the stride of
// 2^(bits - 8) keeps the number of samples at about 256.
let mut data: Vec<f32> = (0..(u16::MAX >> (16 - bits)))
.step_by(1 << (bits.max(8) - 8))
// Always include the largest `bits`-bit value as the final sample.
data.push(f32::from(u16::MAX >> (16 - bits)));
let encoded = codec.encode(AnyCowArray::F32(CowArray::from(&data).into_dyn()))?;
let decoded = codec.decode(encoded.cow())?;
// The decoded array must come back with the configured f32 dtype.
let AnyArray::F32(decoded) = decoded else {
return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
configured: LinearQuantizeDType::F32,
provided: decoded.dtype(),
// Compare raw bit patterns so that the check is exact.
for (o, d) in data.iter().zip(decoded.iter()) {
assert_eq!(o.to_bits(), d.to_bits());
// Round-trips f32 data obtained by `as`-casting integers through the codec
// for every bit width from 1 to 64 and checks for a bit-exact result.
fn exact_roundtrip_f32_as() -> Result<(), LinearQuantizeCodecError> {
for bits in 1..=64 {
let codec = LinearQuantizeCodec {
dtype: LinearQuantizeDType::F32,
// `bits` is in 1..=64, which are valid `LinearQuantizeBins` discriminants.
// NOTE(review): assumes the enum is `#[repr(u8)]`; the repr attribute is
// not visible in this excerpt — confirm.
bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
// Sample the integers below 2^bits - 1; above 8 bits the stride of
// 2^(bits - 8) keeps the number of samples at about 256.
let mut data: Vec<f32> = (0..(u64::MAX >> (64 - bits)))
.step_by(1 << (bits.max(8) - 8))
// `as` converts each u64 sample to f32.
.map(|x| x as f32)
// Always include the largest `bits`-bit value as the final sample.
data.push((u64::MAX >> (64 - bits)) as f32);
let encoded = codec.encode(AnyCowArray::F32(CowArray::from(&data).into_dyn()))?;
let decoded = codec.decode(encoded.cow())?;
// The decoded array must come back with the configured f32 dtype.
let AnyArray::F32(decoded) = decoded else {
return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
configured: LinearQuantizeDType::F32,
provided: decoded.dtype(),
// Compare raw bit patterns so that the check is exact.
for (o, d) in data.iter().zip(decoded.iter()) {
assert_eq!(o.to_bits(), d.to_bits());
// Round-trips integer-valued f64 data through the codec for every bit width
// from 1 to 32 and checks that decoding reproduces the input bit-for-bit.
fn exact_roundtrip_f64_from() -> Result<(), LinearQuantizeCodecError> {
for bits in 1..=32 {
let codec = LinearQuantizeCodec {
dtype: LinearQuantizeDType::F64,
// `bits` is in 1..=32, which are valid `LinearQuantizeBins` discriminants.
// NOTE(review): assumes the enum is `#[repr(u8)]`; the repr attribute is
// not visible in this excerpt — confirm.
bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
// Sample the integers below 2^bits - 1; above 8 bits the stride of
// 2^(bits - 8) keeps the number of samples at about 256.
let mut data: Vec<f64> = (0..(u32::MAX >> (32 - bits)))
.step_by(1 << (bits.max(8) - 8))
// Always include the largest `bits`-bit value as the final sample.
data.push(f64::from(u32::MAX >> (32 - bits)));
let encoded = codec.encode(AnyCowArray::F64(CowArray::from(&data).into_dyn()))?;
let decoded = codec.decode(encoded.cow())?;
// The decoded array must come back with the configured f64 dtype.
let AnyArray::F64(decoded) = decoded else {
return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
configured: LinearQuantizeDType::F64,
provided: decoded.dtype(),
// Compare raw bit patterns so that the check is exact.
for (o, d) in data.iter().zip(decoded.iter()) {
assert_eq!(o.to_bits(), d.to_bits());
// Round-trips f64 data obtained by `as`-casting integers through the codec
// for every bit width from 1 to 64 and checks for a bit-exact result.
fn exact_roundtrip_f64_as() -> Result<(), LinearQuantizeCodecError> {
for bits in 1..=64 {
let codec = LinearQuantizeCodec {
dtype: LinearQuantizeDType::F64,
// `bits` is in 1..=64, which are valid `LinearQuantizeBins` discriminants.
// NOTE(review): assumes the enum is `#[repr(u8)]`; the repr attribute is
// not visible in this excerpt — confirm.
bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
// Sample the integers below 2^bits - 1; above 8 bits the stride of
// 2^(bits - 8) keeps the number of samples at about 256.
let mut data: Vec<f64> = (0..(u64::MAX >> (64 - bits)))
.step_by(1 << (bits.max(8) - 8))
// `as` converts each u64 sample to f64.
.map(|x| x as f64)
// Always include the largest `bits`-bit value as the final sample.
data.push((u64::MAX >> (64 - bits)) as f64);
let encoded = codec.encode(AnyCowArray::F64(CowArray::from(&data).into_dyn()))?;
let decoded = codec.decode(encoded.cow())?;
// The decoded array must come back with the configured f64 dtype.
let AnyArray::F64(decoded) = decoded else {
return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
configured: LinearQuantizeDType::F64,
provided: decoded.dtype(),
// Compare raw bit patterns so that the check is exact.
for (o, d) in data.iter().zip(decoded.iter()) {
assert_eq!(o.to_bits(), d.to_bits());
// Round-trips an array whose elements are all equal for every bit width from
// 1 to 64; this exercises the `maximum == minimum` branch of `quantize`.
fn const_data_roundtrip() -> Result<(), LinearQuantizeCodecError> {
for bits in 1..=64 {
// Four identical values, so the data range has zero width.
let data = [42.0, 42.0, 42.0, 42.0];
let codec = LinearQuantizeCodec {
dtype: LinearQuantizeDType::F64,
// `bits` is in 1..=64, which are valid `LinearQuantizeBins` discriminants.
// NOTE(review): assumes the enum is `#[repr(u8)]`; the repr attribute is
// not visible in this excerpt — confirm.
bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
let encoded = codec.encode(AnyCowArray::F64(CowArray::from(&data).into_dyn()))?;
let decoded = codec.decode(encoded.cow())?;
// The decoded array must come back with the configured f64 dtype.
let AnyArray::F64(decoded) = decoded else {
return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
configured: LinearQuantizeDType::F64,
provided: decoded.dtype(),
// Compare raw bit patterns so that the check is exact.
for (o, d) in data.iter().zip(decoded.iter()) {
assert_eq!(o.to_bits(), d.to_bits());