numcodecs_linear_quantize/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.85.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-linear-quantize
10//! [crates.io]: https://crates.io/crates/numcodecs-linear-quantize
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-linear-quantize
13//! [docs.rs]: https://docs.rs/numcodecs-linear-quantize/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_linear_quantize
17//!
18//! Linear Quantization codec implementation for the [`numcodecs`] API.
19
20#![expect(clippy::multiple_crate_versions)] // FIXME: twofloat -> hexf -> syn 1.0
21
22use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, ArrayD, ArrayViewMutD, Data, Dimension, ShapeError, Zip};
25use num_traits::{ConstOne, ConstZero, Float};
26use numcodecs::{
27    AnyArray, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, StaticCodec,
28    StaticCodecConfig, StaticCodecVersion,
29};
30use schemars::{JsonSchema, JsonSchema_repr};
31use serde::{Deserialize, Serialize, de::DeserializeOwned};
32use serde_repr::{Deserialize_repr, Serialize_repr};
33use thiserror::Error;
34use twofloat::TwoFloat;
35
/// Version marker type for the codec's encoding format, instantiated with the
/// const parameters `0, 1, 0`.
///
/// It is stored both in the codec configuration (the `_version` field) and in
/// the [`CompressionHeader`] that precedes the quantized values.
type LinearQuantizeCodecVersion = StaticCodecVersion<0, 1, 0>;
37
// Implementation notes (kept as plain comments so that the item and field
// documentation below stays untouched):
// - `Codec::encode` picks the encoded integer type from `bits`:
//   1..=8 -> u8, 9..=16 -> u16, 17..=32 -> u32, 33..=64 -> u64.
// - "in-band" means that `quantize` prepends a serialized header containing
//   the shape, minimum, maximum, and format version to the quantized values.
// - `deny_unknown_fields` makes deserialization reject unexpected parameters.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
/// Lossy codec to reduce the precision of floating point data.
///
/// The data is quantized to unsigned integers of the best-fitting type.
/// The range and shape of the input data is stored in-band.
pub struct LinearQuantizeCodec {
    /// Dtype of the decoded data
    pub dtype: LinearQuantizeDType,
    /// Binary precision of the encoded data where `$bits = \log_{2}(bins)$`
    pub bits: LinearQuantizeBins,
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    // (de)serialized under the name `_version`; falls back to the default
    // value when the parameter is absent
    #[serde(default, rename = "_version")]
    pub version: LinearQuantizeCodecVersion,
}
53
/// Data types which the [`LinearQuantizeCodec`] can quantize
// Each variant is serialized under its short name (`"f32"` / `"f64"`) and is
// additionally accepted under the long alias (`"float32"` / `"float64"`) when
// deserializing; the schema extension below lists all four spellings.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("enum" = ["f32", "float32", "f64", "float64"]))]
#[expect(missing_docs)]
pub enum LinearQuantizeDType {
    // 32-bit floating point data
    #[serde(rename = "f32", alias = "float32")]
    F32,
    // 64-bit floating point data
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
64
65impl fmt::Display for LinearQuantizeDType {
66    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
67        fmt.write_str(match self {
68            Self::F32 => "f32",
69            Self::F64 => "f64",
70        })
71    }
72}
73
/// Number of bins for quantization, written in base-2 scientific notation.
///
/// The binary `#[repr(u8)]` value of each variant is equivalent to the binary
/// logarithm of the number of bins, i.e. the binary precision or the number of
/// bits used.
// Only the first discriminant is written out explicitly; all following
// variants increment it implicitly, so `_1Bn` has the value `n` for every `n`
// in `1..=64`. The codec reads this value back with `self.bits as u8`.
// `rustfmt::skip` preserves the eight-variants-per-row table layout.
#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, JsonSchema_repr)]
#[repr(u8)]
#[rustfmt::skip]
#[expect(missing_docs)]
pub enum LinearQuantizeBins {
    _1B1 = 1, _1B2, _1B3, _1B4, _1B5, _1B6, _1B7, _1B8,
    _1B9, _1B10, _1B11, _1B12, _1B13, _1B14, _1B15, _1B16,
    _1B17, _1B18, _1B19, _1B20, _1B21, _1B22, _1B23, _1B24,
    _1B25, _1B26, _1B27, _1B28, _1B29, _1B30, _1B31, _1B32,
    _1B33, _1B34, _1B35, _1B36, _1B37, _1B38, _1B39, _1B40,
    _1B41, _1B42, _1B43, _1B44, _1B45, _1B46, _1B47, _1B48,
    _1B49, _1B50, _1B51, _1B52, _1B53, _1B54, _1B55, _1B56,
    _1B57, _1B58, _1B59, _1B60, _1B61, _1B62, _1B63, _1B64,
}
93
94impl Codec for LinearQuantizeCodec {
95    type Error = LinearQuantizeCodecError;
96
97    #[expect(clippy::too_many_lines)]
98    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
99        let encoded = match (&data, self.dtype) {
100            (AnyCowArray::F32(data), LinearQuantizeDType::F32) => match self.bits as u8 {
101                bits @ ..=8 => AnyArray::U8(
102                    Array1::from_vec(quantize(data, |x| {
103                        let max = f32::from(u8::MAX >> (8 - bits));
104                        let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
105                        #[expect(unsafe_code)]
106                        // Safety: x is clamped beforehand
107                        unsafe {
108                            x.to_int_unchecked::<u8>()
109                        }
110                    })?)
111                    .into_dyn(),
112                ),
113                bits @ 9..=16 => AnyArray::U16(
114                    Array1::from_vec(quantize(data, |x| {
115                        let max = f32::from(u16::MAX >> (16 - bits));
116                        let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
117                        #[expect(unsafe_code)]
118                        // Safety: x is clamped beforehand
119                        unsafe {
120                            x.to_int_unchecked::<u16>()
121                        }
122                    })?)
123                    .into_dyn(),
124                ),
125                bits @ 17..=32 => AnyArray::U32(
126                    Array1::from_vec(quantize(data, |x| {
127                        // we need to use f64 here to have sufficient precision
128                        let max = f64::from(u32::MAX >> (32 - bits));
129                        let x = f64::from(x)
130                            .mul_add(scale_for_bits::<f64>(bits), 0.5)
131                            .clamp(0.0, max);
132                        #[expect(unsafe_code)]
133                        // Safety: x is clamped beforehand
134                        unsafe {
135                            x.to_int_unchecked::<u32>()
136                        }
137                    })?)
138                    .into_dyn(),
139                ),
140                bits @ 33.. => AnyArray::U64(
141                    Array1::from_vec(quantize(data, |x| {
142                        // we need to use TwoFloat here to have sufficient precision
143                        let max = TwoFloat::from(u64::MAX >> (64 - bits));
144                        let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
145                            + TwoFloat::from(0.5))
146                        .max(TwoFloat::from(0.0))
147                        .min(max);
148                        #[expect(unsafe_code)]
149                        // Safety: x is clamped beforehand
150                        unsafe {
151                            u64::try_from(x).unwrap_unchecked()
152                        }
153                    })?)
154                    .into_dyn(),
155                ),
156            },
157            (AnyCowArray::F64(data), LinearQuantizeDType::F64) => match self.bits as u8 {
158                bits @ ..=8 => AnyArray::U8(
159                    Array1::from_vec(quantize(data, |x| {
160                        let max = f64::from(u8::MAX >> (8 - bits));
161                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
162                        #[expect(unsafe_code)]
163                        // Safety: x is clamped beforehand
164                        unsafe {
165                            x.to_int_unchecked::<u8>()
166                        }
167                    })?)
168                    .into_dyn(),
169                ),
170                bits @ 9..=16 => AnyArray::U16(
171                    Array1::from_vec(quantize(data, |x| {
172                        let max = f64::from(u16::MAX >> (16 - bits));
173                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
174                        #[expect(unsafe_code)]
175                        // Safety: x is clamped beforehand
176                        unsafe {
177                            x.to_int_unchecked::<u16>()
178                        }
179                    })?)
180                    .into_dyn(),
181                ),
182                bits @ 17..=32 => AnyArray::U32(
183                    Array1::from_vec(quantize(data, |x| {
184                        let max = f64::from(u32::MAX >> (32 - bits));
185                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
186                        #[expect(unsafe_code)]
187                        // Safety: x is clamped beforehand
188                        unsafe {
189                            x.to_int_unchecked::<u32>()
190                        }
191                    })?)
192                    .into_dyn(),
193                ),
194                bits @ 33.. => AnyArray::U64(
195                    Array1::from_vec(quantize(data, |x| {
196                        // we need to use TwoFloat here to have sufficient precision
197                        let max = TwoFloat::from(u64::MAX >> (64 - bits));
198                        let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
199                            + TwoFloat::from(0.5))
200                        .max(TwoFloat::from(0.0))
201                        .min(max);
202                        #[expect(unsafe_code)]
203                        // Safety: x is clamped beforehand
204                        unsafe {
205                            u64::try_from(x).unwrap_unchecked()
206                        }
207                    })?)
208                    .into_dyn(),
209                ),
210            },
211            (data, dtype) => {
212                return Err(LinearQuantizeCodecError::MismatchedEncodeDType {
213                    configured: dtype,
214                    provided: data.dtype(),
215                });
216            }
217        };
218
219        Ok(encoded)
220    }
221
222    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
223        #[expect(clippy::option_if_let_else)]
224        fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
225            array: &ArrayBase<S, D>,
226        ) -> Cow<[T]> {
227            if let Some(data) = array.as_slice() {
228                Cow::Borrowed(data)
229            } else {
230                Cow::Owned(array.iter().copied().collect())
231            }
232        }
233
234        if !matches!(encoded.shape(), [_]) {
235            return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
236                shape: encoded.shape().to_vec(),
237            });
238        }
239
240        let decoded = match (&encoded, self.dtype) {
241            (AnyCowArray::U8(encoded), LinearQuantizeDType::F32) => {
242                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
243                    f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
244                })?)
245            }
246            (AnyCowArray::U16(encoded), LinearQuantizeDType::F32) => {
247                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
248                    f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
249                })?)
250            }
251            (AnyCowArray::U32(encoded), LinearQuantizeDType::F32) => {
252                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
253                    // we need to use f64 here to have sufficient precision
254                    let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
255                    #[expect(clippy::cast_possible_truncation)]
256                    let x = x as f32;
257                    x
258                })?)
259            }
260            (AnyCowArray::U64(encoded), LinearQuantizeDType::F32) => {
261                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
262                    // we need to use TwoFloat here to have sufficient precision
263                    let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
264                    f32::from(x)
265                })?)
266            }
267            (AnyCowArray::U8(encoded), LinearQuantizeDType::F64) => {
268                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
269                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
270                })?)
271            }
272            (AnyCowArray::U16(encoded), LinearQuantizeDType::F64) => {
273                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
274                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
275                })?)
276            }
277            (AnyCowArray::U32(encoded), LinearQuantizeDType::F64) => {
278                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
279                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
280                })?)
281            }
282            (AnyCowArray::U64(encoded), LinearQuantizeDType::F64) => {
283                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
284                    // we need to use TwoFloat here to have sufficient precision
285                    let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
286                    f64::from(x)
287                })?)
288            }
289            (encoded, _dtype) => {
290                return Err(LinearQuantizeCodecError::InvalidEncodedDType {
291                    dtype: encoded.dtype(),
292                });
293            }
294        };
295
296        Ok(decoded)
297    }
298
299    fn decode_into(
300        &self,
301        encoded: AnyArrayView,
302        decoded: AnyArrayViewMut,
303    ) -> Result<(), Self::Error> {
304        fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
305            array: &ArrayBase<S, D>,
306        ) -> Cow<[T]> {
307            #[expect(clippy::option_if_let_else)]
308            if let Some(data) = array.as_slice() {
309                Cow::Borrowed(data)
310            } else {
311                Cow::Owned(array.iter().copied().collect())
312            }
313        }
314
315        if !matches!(encoded.shape(), [_]) {
316            return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
317                shape: encoded.shape().to_vec(),
318            });
319        }
320
321        match (decoded, self.dtype) {
322            (AnyArrayViewMut::F32(decoded), LinearQuantizeDType::F32) => {
323                match &encoded {
324                    AnyArrayView::U8(encoded) => {
325                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
326                            f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
327                        })
328                    }
329                    AnyArrayView::U16(encoded) => {
330                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
331                            f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
332                        })
333                    }
334                    AnyArrayView::U32(encoded) => {
335                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
336                            // we need to use f64 here to have sufficient precision
337                            let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
338                            #[expect(clippy::cast_possible_truncation)]
339                            let x = x as f32;
340                            x
341                        })
342                    }
343                    AnyArrayView::U64(encoded) => {
344                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
345                            // we need to use TwoFloat here to have sufficient precision
346                            let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
347                            f32::from(x)
348                        })
349                    }
350                    encoded => {
351                        return Err(LinearQuantizeCodecError::InvalidEncodedDType {
352                            dtype: encoded.dtype(),
353                        });
354                    }
355                }
356            }
357            (AnyArrayViewMut::F64(decoded), LinearQuantizeDType::F64) => {
358                match &encoded {
359                    AnyArrayView::U8(encoded) => {
360                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
361                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
362                        })
363                    }
364                    AnyArrayView::U16(encoded) => {
365                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
366                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
367                        })
368                    }
369                    AnyArrayView::U32(encoded) => {
370                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
371                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
372                        })
373                    }
374                    AnyArrayView::U64(encoded) => {
375                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
376                            // we need to use TwoFloat here to have sufficient precision
377                            let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
378                            f64::from(x)
379                        })
380                    }
381                    encoded => {
382                        return Err(LinearQuantizeCodecError::InvalidEncodedDType {
383                            dtype: encoded.dtype(),
384                        });
385                    }
386                }
387            }
388            (decoded, dtype) => {
389                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
390                    configured: dtype,
391                    provided: decoded.dtype(),
392                });
393            }
394        }?;
395
396        Ok(())
397    }
398}
399
/// Registers the [`LinearQuantizeCodec`] as a statically known codec whose
/// configuration is the codec struct itself.
impl StaticCodec for LinearQuantizeCodec {
    /// Identifier under which this codec is known
    const CODEC_ID: &'static str = "linear-quantize.rs";

    /// The codec is its own configuration type
    type Config<'de> = Self;

    /// Creates the codec from its configuration, which is returned unchanged.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Returns the configuration, built from a reference to this codec.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
413
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`LinearQuantizeCodec`].
///
/// The variants cover encoding (`MismatchedEncodeDType`, `NonFiniteData`,
/// `HeaderEncodeFailed`) as well as decoding (all remaining variants).
pub enum LinearQuantizeCodecError {
    /// [`LinearQuantizeCodec`] cannot encode the provided dtype which differs
    /// from the configured dtype
    #[error(
        "LinearQuantize cannot encode the provided dtype {provided} which differs from the configured dtype {configured}"
    )]
    MismatchedEncodeDType {
        /// The dtype that the codec is configured with
        configured: LinearQuantizeDType,
        /// The dtype of the provided array from which the data is to be
        /// encoded
        provided: AnyArrayDType,
    },
    /// [`LinearQuantizeCodec`] does not support non-finite (infinite or NaN) floating
    /// point data
    #[error("LinearQuantize does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// [`LinearQuantizeCodec`] failed to encode the header
    #[error("LinearQuantize failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: LinearQuantizeHeaderError,
    },
    /// [`LinearQuantizeCodec`] can only decode one-dimensional arrays but
    /// received an array of a different shape
    #[error(
        "LinearQuantize can only decode one-dimensional arrays but received an array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`LinearQuantizeCodec`] failed to decode the header
    #[error("LinearQuantize failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: LinearQuantizeHeaderError,
    },
    /// [`LinearQuantizeCodec`] decoded an invalid array shape header which does
    /// not fit the decoded data
    #[error(
        "LinearQuantize decoded an invalid array shape header which does not fit the decoded data"
    )]
    DecodeInvalidShapeHeader {
        /// Source error, which is converted automatically from a
        /// [`ShapeError`] by the `?` operator
        #[from]
        source: ShapeError,
    },
    /// [`LinearQuantizeCodec`] cannot decode the provided dtype
    #[error("LinearQuantize cannot decode the provided dtype {dtype}")]
    InvalidEncodedDType {
        /// The dtype of the provided array from which the data is to be
        /// decoded
        dtype: AnyArrayDType,
    },
    /// [`LinearQuantizeCodec`] cannot decode into the provided array whose
    /// dtype differs from the configured dtype
    #[error(
        "LinearQuantize cannot decode the provided dtype {provided} which differs from the configured dtype {configured}"
    )]
    MismatchedDecodeIntoDtype {
        /// The dtype that the codec is configured with
        configured: LinearQuantizeDType,
        /// The dtype of the provided array into which the data is to be
        /// decoded
        provided: AnyArrayDType,
    },
    /// [`LinearQuantizeCodec`] cannot decode the decoded array into the provided
    /// array of a different shape
    #[error(
        "LinearQuantize cannot decode the decoded array of shape {decoded:?} into the provided array of shape {provided:?}"
    )]
    MismatchedDecodeIntoShape {
        /// The shape of the decoded data, as stored in the header
        decoded: Vec<usize>,
        /// The shape of the provided array into which the data is to be
        /// decoded
        provided: Vec<usize>,
    },
}
492
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding the header fails
///
/// Wraps the private [`postcard::Error`] so that the serialization format
/// does not leak into the public API; the error is marked as transparent,
/// i.e. it is reported as the wrapped error.
pub struct LinearQuantizeHeaderError(postcard::Error);
497
498/// Linear-quantize the elements in the `data` array using the `quantize`
499/// closure.
500///
501/// # Errors
502///
503/// Errors with
504/// - [`LinearQuantizeCodecError::NonFiniteData`] if any data element is non-
505///   finite (infinite or NaN)
506/// - [`LinearQuantizeCodecError::HeaderEncodeFailed`] if encoding the header
507///   failed
508pub fn quantize<
509    T: Float + ConstZero + ConstOne + Serialize,
510    Q: Unsigned,
511    S: Data<Elem = T>,
512    D: Dimension,
513>(
514    data: &ArrayBase<S, D>,
515    quantize: impl Fn(T) -> Q,
516) -> Result<Vec<Q>, LinearQuantizeCodecError> {
517    if !Zip::from(data).all(|x| x.is_finite()) {
518        return Err(LinearQuantizeCodecError::NonFiniteData);
519    }
520
521    let (minimum, maximum) = data.first().map_or((T::ZERO, T::ONE), |first| {
522        (
523            Zip::from(data).fold(*first, |a, b| a.min(*b)),
524            Zip::from(data).fold(*first, |a, b| a.max(*b)),
525        )
526    });
527
528    let header = postcard::to_extend(
529        &CompressionHeader {
530            shape: Cow::Borrowed(data.shape()),
531            minimum,
532            maximum,
533            version: StaticCodecVersion,
534        },
535        Vec::new(),
536    )
537    .map_err(|err| LinearQuantizeCodecError::HeaderEncodeFailed {
538        source: LinearQuantizeHeaderError(err),
539    })?;
540
541    let mut encoded: Vec<Q> = vec![Q::ZERO; header.len().div_ceil(std::mem::size_of::<Q>())];
542    #[expect(unsafe_code)]
543    // Safety: encoded is at least header.len() bytes long and properly aligned for Q
544    unsafe {
545        std::ptr::copy_nonoverlapping(header.as_ptr(), encoded.as_mut_ptr().cast(), header.len());
546    }
547    encoded.reserve(data.len());
548
549    if maximum == minimum {
550        encoded.resize(encoded.len() + data.len(), quantize(T::ZERO));
551    } else {
552        encoded.extend(
553            data.iter()
554                .map(|x| quantize((*x - minimum) / (maximum - minimum))),
555        );
556    }
557
558    Ok(encoded)
559}
560
561/// Reconstruct the linear-quantized `encoded` array using the `floatify`
562/// closure.
563///
564/// # Errors
565///
566/// Errors with
567/// - [`LinearQuantizeCodecError::HeaderDecodeFailed`] if decoding the header
568///   failed
569pub fn reconstruct<T: Float + DeserializeOwned, Q: Unsigned>(
570    encoded: &[Q],
571    floatify: impl Fn(Q) -> T,
572) -> Result<ArrayD<T>, LinearQuantizeCodecError> {
573    #[expect(unsafe_code)]
574    // Safety: data is data.len()*size_of::<Q> bytes long and properly aligned for Q
575    let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
576        std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
577    })
578    .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
579        source: LinearQuantizeHeaderError(err),
580    })?;
581
582    let encoded = encoded
583        .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
584        .unwrap_or(&[]);
585
586    let decoded = encoded
587        .iter()
588        .map(|x| header.minimum + (floatify(*x) * (header.maximum - header.minimum)))
589        .map(|x| x.clamp(header.minimum, header.maximum))
590        .collect();
591
592    let decoded = Array::from_shape_vec(&*header.shape, decoded)?;
593
594    Ok(decoded)
595}
596
597/// Reconstruct the linear-quantized `encoded` array using the `floatify`
598/// closure into the `decoded` array.
599///
600/// # Errors
601///
602/// Errors with
603/// - [`LinearQuantizeCodecError::HeaderDecodeFailed`] if decoding the header
604///   failed
605/// - [`LinearQuantizeCodecError::MismatchedDecodeIntoShape`] if the `decoded`
606///   array is of the wrong shape
607pub fn reconstruct_into<T: Float + DeserializeOwned, Q: Unsigned>(
608    encoded: &[Q],
609    mut decoded: ArrayViewMutD<T>,
610    floatify: impl Fn(Q) -> T,
611) -> Result<(), LinearQuantizeCodecError> {
612    #[expect(unsafe_code)]
613    // Safety: data is data.len()*size_of::<Q> bytes long and properly aligned for Q
614    let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
615        std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
616    })
617    .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
618        source: LinearQuantizeHeaderError(err),
619    })?;
620
621    let encoded = encoded
622        .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
623        .unwrap_or(&[]);
624
625    if decoded.shape() != &*header.shape {
626        return Err(LinearQuantizeCodecError::MismatchedDecodeIntoShape {
627            decoded: header.shape.into_owned(),
628            provided: decoded.shape().to_vec(),
629        });
630    }
631
632    // iteration must occur in synchronised (standard) order
633    for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
634        *d = (header.minimum + (floatify(*e) * (header.maximum - header.minimum)))
635            .clamp(header.minimum, header.maximum);
636    }
637
638    Ok(())
639}
640
641/// Returns `${2.0}^{bits} - 1.0$`
642fn scale_for_bits<T: Float + From<u8> + ConstOne>(bits: u8) -> T {
643    <T as From<u8>>::from(bits).exp2() - T::ONE
644}
645
/// Unsigned binary types.
pub trait Unsigned: Copy {
    /// `0x0`
    const ZERO: Self;
}

/// Implements [`Unsigned`] for the listed primitive unsigned integer types.
macro_rules! impl_unsigned {
    ($($ty:ty),+ $(,)?) => {
        $(
            impl Unsigned for $ty {
                const ZERO: Self = 0;
            }
        )+
    };
}

impl_unsigned!(u8, u16, u32, u64);
667
/// Header that is serialized in front of the quantized values.
///
/// It is written by [`quantize`] and read back by [`reconstruct`] and
/// [`reconstruct_into`] to restore the shape and value range of the data.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a, T> {
    /// Shape of the original array
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    /// Smallest value of the original data (zero for empty data)
    minimum: T,
    /// Largest value of the original data (one for empty data)
    maximum: T,
    /// Encoding format version marker
    version: LinearQuantizeCodecVersion,
}
676
#[cfg(test)]
mod tests {
    use ndarray::CowArray;

    use super::*;

    /// Creates a codec for the `dtype` that quantizes to `bits` bits.
    fn codec_for(dtype: LinearQuantizeDType, bits: u8) -> LinearQuantizeCodec {
        assert!((1..=64).contains(&bits), "bits must be in 1..=64");

        LinearQuantizeCodec {
            dtype,
            #[expect(unsafe_code)]
            // Safety: every value in 1..=64 is a valid `LinearQuantizeBins` variant
            bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
            version: StaticCodecVersion,
        }
    }

    /// Encodes and decodes the `f32` `data` with `bits` bits and asserts that
    /// every element is reproduced bit-for-bit.
    fn assert_exact_roundtrip_f32(bits: u8, data: &[f32]) -> Result<(), LinearQuantizeCodecError> {
        let codec = codec_for(LinearQuantizeDType::F32, bits);

        let encoded = codec.encode(AnyCowArray::F32(CowArray::from(data).into_dyn()))?;
        let decoded = match codec.decode(encoded.cow())? {
            AnyArray::F32(decoded) => decoded,
            other => {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F32,
                    provided: other.dtype(),
                });
            }
        };

        for (original, roundtripped) in data.iter().zip(decoded.iter()) {
            assert_eq!(original.to_bits(), roundtripped.to_bits());
        }

        Ok(())
    }

    /// Encodes and decodes the `f64` `data` with `bits` bits and asserts that
    /// every element is reproduced bit-for-bit.
    fn assert_exact_roundtrip_f64(bits: u8, data: &[f64]) -> Result<(), LinearQuantizeCodecError> {
        let codec = codec_for(LinearQuantizeDType::F64, bits);

        let encoded = codec.encode(AnyCowArray::F64(CowArray::from(data).into_dyn()))?;
        let decoded = match codec.decode(encoded.cow())? {
            AnyArray::F64(decoded) => decoded,
            other => {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F64,
                    provided: other.dtype(),
                });
            }
        };

        for (original, roundtripped) in data.iter().zip(decoded.iter()) {
            assert_eq!(original.to_bits(), roundtripped.to_bits());
        }

        Ok(())
    }

    #[test]
    fn exact_roundtrip_f32_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=16_u8 {
            // sample the integers in [0, 2^bits - 1], always including the maximum
            let largest = u16::MAX >> (16 - bits);
            let data: Vec<f32> = (0..largest)
                .step_by(1 << (bits.max(8) - 8))
                .chain(std::iter::once(largest))
                .map(f32::from)
                .collect();

            assert_exact_roundtrip_f32(bits, &data)?;
        }

        Ok(())
    }

    #[test]
    fn exact_roundtrip_f32_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64_u8 {
            // sample the integers in [0, 2^bits - 1], always including the maximum
            let largest = u64::MAX >> (64 - bits);
            #[expect(clippy::cast_precision_loss)]
            let data: Vec<f32> = (0..largest)
                .step_by(1 << (bits.max(8) - 8))
                .chain(std::iter::once(largest))
                .map(|x| x as f32)
                .collect();

            assert_exact_roundtrip_f32(bits, &data)?;
        }

        Ok(())
    }

    #[test]
    fn exact_roundtrip_f64_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=32_u8 {
            // sample the integers in [0, 2^bits - 1], always including the maximum
            let largest = u32::MAX >> (32 - bits);
            let data: Vec<f64> = (0..largest)
                .step_by(1 << (bits.max(8) - 8))
                .chain(std::iter::once(largest))
                .map(f64::from)
                .collect();

            assert_exact_roundtrip_f64(bits, &data)?;
        }

        Ok(())
    }

    #[test]
    fn exact_roundtrip_f64_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64_u8 {
            // sample the integers in [0, 2^bits - 1], always including the maximum
            let largest = u64::MAX >> (64 - bits);
            #[expect(clippy::cast_precision_loss)]
            let data: Vec<f64> = (0..largest)
                .step_by(1 << (bits.max(8) - 8))
                .chain(std::iter::once(largest))
                .map(|x| x as f64)
                .collect();

            assert_exact_roundtrip_f64(bits, &data)?;
        }

        Ok(())
    }

    #[test]
    fn const_data_roundtrip() -> Result<(), LinearQuantizeCodecError> {
        // constant data has an empty value range
        let data = [42.0_f64; 4];

        for bits in 1..=64_u8 {
            assert_exact_roundtrip_f64(bits, &data)?;
        }

        Ok(())
    }
}