numcodecs_linear_quantize/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-linear-quantize
10//! [crates.io]: https://crates.io/crates/numcodecs-linear-quantize
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-linear-quantize
13//! [docs.rs]: https://docs.rs/numcodecs-linear-quantize/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_linear_quantize
17//!
18//! Linear Quantization codec implementation for the [`numcodecs`] API.
19
20#![expect(clippy::multiple_crate_versions)] // FIXME: twofloat -> hexf -> syn 1.0
21
22use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, ArrayD, ArrayViewMutD, Data, Dimension, ShapeError, Zip};
25use num_traits::{ConstOne, ConstZero, Float};
26use numcodecs::{
27    AnyArray, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, StaticCodec,
28    StaticCodecConfig,
29};
30use schemars::{JsonSchema, JsonSchema_repr};
31use serde::{de::DeserializeOwned, Deserialize, Serialize};
32use serde_repr::{Deserialize_repr, Serialize_repr};
33use thiserror::Error;
34use twofloat::TwoFloat;
35
// NOTE: `schemars` turns the `///` doc comments on this struct and its fields
// into JSON schema descriptions, and serde (de)serializes the fields in
// declaration order, so both the doc text and the field order are part of the
// public configuration format. Change them deliberately.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
// unknown keys in a deserialized configuration are rejected, not ignored
#[serde(deny_unknown_fields)]
/// Lossy codec to reduce the precision of floating point data.
///
/// The data is quantized to unsigned integers of the best-fitting type.
/// The range and shape of the input data is stored in-band.
pub struct LinearQuantizeCodec {
    // `encode` and `decode_into` reject arrays whose dtype differs from this
    /// Dtype of the decoded data
    pub dtype: LinearQuantizeDType,
    // `encode` emits u8 for bits <= 8, u16 for bits <= 16, u32 for bits <= 32,
    // and u64 otherwise
    /// Binary precision of the encoded data where `$bits = \log_{2}(bins)$`
    pub bits: LinearQuantizeBins,
}
48
49/// Data types which the [`LinearQuantizeCodec`] can quantize
50#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema)]
51#[schemars(extend("enum" = ["f32", "float32", "f64", "float64"]))]
52#[expect(missing_docs)]
53pub enum LinearQuantizeDType {
54    #[serde(rename = "f32", alias = "float32")]
55    F32,
56    #[serde(rename = "f64", alias = "float64")]
57    F64,
58}
59
60impl fmt::Display for LinearQuantizeDType {
61    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
62        fmt.write_str(match self {
63            Self::F32 => "f32",
64            Self::F64 => "f64",
65        })
66    }
67}
68
69/// Number of bins for quantization, written in base-2 scientific notation.
70///
71/// The binary `#[repr(u8)]` value of each variant is equivalent to the binary
72/// logarithm of the number of bins, i.e. the binary precision or the number of
73/// bits used.
74#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, JsonSchema_repr)]
75#[repr(u8)]
76#[rustfmt::skip]
77#[expect(missing_docs)]
78pub enum LinearQuantizeBins {
79    _1B1 = 1, _1B2, _1B3, _1B4, _1B5, _1B6, _1B7, _1B8,
80    _1B9, _1B10, _1B11, _1B12, _1B13, _1B14, _1B15, _1B16,
81    _1B17, _1B18, _1B19, _1B20, _1B21, _1B22, _1B23, _1B24,
82    _1B25, _1B26, _1B27, _1B28, _1B29, _1B30, _1B31, _1B32,
83    _1B33, _1B34, _1B35, _1B36, _1B37, _1B38, _1B39, _1B40,
84    _1B41, _1B42, _1B43, _1B44, _1B45, _1B46, _1B47, _1B48,
85    _1B49, _1B50, _1B51, _1B52, _1B53, _1B54, _1B55, _1B56,
86    _1B57, _1B58, _1B59, _1B60, _1B61, _1B62, _1B63, _1B64,
87}
88
89impl Codec for LinearQuantizeCodec {
90    type Error = LinearQuantizeCodecError;
91
92    #[expect(clippy::too_many_lines)]
93    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
94        let encoded = match (&data, self.dtype) {
95            (AnyCowArray::F32(data), LinearQuantizeDType::F32) => match self.bits as u8 {
96                bits @ ..=8 => AnyArray::U8(
97                    Array1::from_vec(quantize(data, |x| {
98                        let max = f32::from(u8::MAX >> (8 - bits));
99                        let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
100                        #[expect(unsafe_code)]
101                        // Safety: x is clamped beforehand
102                        unsafe {
103                            x.to_int_unchecked::<u8>()
104                        }
105                    })?)
106                    .into_dyn(),
107                ),
108                bits @ 9..=16 => AnyArray::U16(
109                    Array1::from_vec(quantize(data, |x| {
110                        let max = f32::from(u16::MAX >> (16 - bits));
111                        let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
112                        #[expect(unsafe_code)]
113                        // Safety: x is clamped beforehand
114                        unsafe {
115                            x.to_int_unchecked::<u16>()
116                        }
117                    })?)
118                    .into_dyn(),
119                ),
120                bits @ 17..=32 => AnyArray::U32(
121                    Array1::from_vec(quantize(data, |x| {
122                        // we need to use f64 here to have sufficient precision
123                        let max = f64::from(u32::MAX >> (32 - bits));
124                        let x = f64::from(x)
125                            .mul_add(scale_for_bits::<f64>(bits), 0.5)
126                            .clamp(0.0, max);
127                        #[expect(unsafe_code)]
128                        // Safety: x is clamped beforehand
129                        unsafe {
130                            x.to_int_unchecked::<u32>()
131                        }
132                    })?)
133                    .into_dyn(),
134                ),
135                bits @ 33.. => AnyArray::U64(
136                    Array1::from_vec(quantize(data, |x| {
137                        // we need to use TwoFloat here to have sufficient precision
138                        let max = TwoFloat::from(u64::MAX >> (64 - bits));
139                        let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
140                            + TwoFloat::from(0.5))
141                        .max(TwoFloat::from(0.0))
142                        .min(max);
143                        #[expect(unsafe_code)]
144                        // Safety: x is clamped beforehand
145                        unsafe {
146                            u64::try_from(x).unwrap_unchecked()
147                        }
148                    })?)
149                    .into_dyn(),
150                ),
151            },
152            (AnyCowArray::F64(data), LinearQuantizeDType::F64) => match self.bits as u8 {
153                bits @ ..=8 => AnyArray::U8(
154                    Array1::from_vec(quantize(data, |x| {
155                        let max = f64::from(u8::MAX >> (8 - bits));
156                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
157                        #[expect(unsafe_code)]
158                        // Safety: x is clamped beforehand
159                        unsafe {
160                            x.to_int_unchecked::<u8>()
161                        }
162                    })?)
163                    .into_dyn(),
164                ),
165                bits @ 9..=16 => AnyArray::U16(
166                    Array1::from_vec(quantize(data, |x| {
167                        let max = f64::from(u16::MAX >> (16 - bits));
168                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
169                        #[expect(unsafe_code)]
170                        // Safety: x is clamped beforehand
171                        unsafe {
172                            x.to_int_unchecked::<u16>()
173                        }
174                    })?)
175                    .into_dyn(),
176                ),
177                bits @ 17..=32 => AnyArray::U32(
178                    Array1::from_vec(quantize(data, |x| {
179                        let max = f64::from(u32::MAX >> (32 - bits));
180                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
181                        #[expect(unsafe_code)]
182                        // Safety: x is clamped beforehand
183                        unsafe {
184                            x.to_int_unchecked::<u32>()
185                        }
186                    })?)
187                    .into_dyn(),
188                ),
189                bits @ 33.. => AnyArray::U64(
190                    Array1::from_vec(quantize(data, |x| {
191                        // we need to use TwoFloat here to have sufficient precision
192                        let max = TwoFloat::from(u64::MAX >> (64 - bits));
193                        let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
194                            + TwoFloat::from(0.5))
195                        .max(TwoFloat::from(0.0))
196                        .min(max);
197                        #[expect(unsafe_code)]
198                        // Safety: x is clamped beforehand
199                        unsafe {
200                            u64::try_from(x).unwrap_unchecked()
201                        }
202                    })?)
203                    .into_dyn(),
204                ),
205            },
206            (data, dtype) => {
207                return Err(LinearQuantizeCodecError::MismatchedEncodeDType {
208                    configured: dtype,
209                    provided: data.dtype(),
210                });
211            }
212        };
213
214        Ok(encoded)
215    }
216
217    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
218        #[expect(clippy::option_if_let_else)]
219        fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
220            array: &ArrayBase<S, D>,
221        ) -> Cow<[T]> {
222            if let Some(data) = array.as_slice() {
223                Cow::Borrowed(data)
224            } else {
225                Cow::Owned(array.iter().copied().collect())
226            }
227        }
228
229        if !matches!(encoded.shape(), [_]) {
230            return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
231                shape: encoded.shape().to_vec(),
232            });
233        }
234
235        let decoded = match (&encoded, self.dtype) {
236            (AnyCowArray::U8(encoded), LinearQuantizeDType::F32) => {
237                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
238                    f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
239                })?)
240            }
241            (AnyCowArray::U16(encoded), LinearQuantizeDType::F32) => {
242                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
243                    f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
244                })?)
245            }
246            (AnyCowArray::U32(encoded), LinearQuantizeDType::F32) => {
247                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
248                    // we need to use f64 here to have sufficient precision
249                    let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
250                    #[expect(clippy::cast_possible_truncation)]
251                    let x = x as f32;
252                    x
253                })?)
254            }
255            (AnyCowArray::U64(encoded), LinearQuantizeDType::F32) => {
256                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
257                    // we need to use TwoFloat here to have sufficient precision
258                    let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
259                    f32::from(x)
260                })?)
261            }
262            (AnyCowArray::U8(encoded), LinearQuantizeDType::F64) => {
263                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
264                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
265                })?)
266            }
267            (AnyCowArray::U16(encoded), LinearQuantizeDType::F64) => {
268                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
269                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
270                })?)
271            }
272            (AnyCowArray::U32(encoded), LinearQuantizeDType::F64) => {
273                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
274                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
275                })?)
276            }
277            (AnyCowArray::U64(encoded), LinearQuantizeDType::F64) => {
278                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
279                    // we need to use TwoFloat here to have sufficient precision
280                    let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
281                    f64::from(x)
282                })?)
283            }
284            (encoded, _dtype) => {
285                return Err(LinearQuantizeCodecError::InvalidEncodedDType {
286                    dtype: encoded.dtype(),
287                })
288            }
289        };
290
291        Ok(decoded)
292    }
293
294    fn decode_into(
295        &self,
296        encoded: AnyArrayView,
297        decoded: AnyArrayViewMut,
298    ) -> Result<(), Self::Error> {
299        fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
300            array: &ArrayBase<S, D>,
301        ) -> Cow<[T]> {
302            #[expect(clippy::option_if_let_else)]
303            if let Some(data) = array.as_slice() {
304                Cow::Borrowed(data)
305            } else {
306                Cow::Owned(array.iter().copied().collect())
307            }
308        }
309
310        if !matches!(encoded.shape(), [_]) {
311            return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
312                shape: encoded.shape().to_vec(),
313            });
314        }
315
316        match (decoded, self.dtype) {
317            (AnyArrayViewMut::F32(decoded), LinearQuantizeDType::F32) => {
318                match &encoded {
319                    AnyArrayView::U8(encoded) => {
320                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
321                            f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
322                        })
323                    }
324                    AnyArrayView::U16(encoded) => {
325                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
326                            f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
327                        })
328                    }
329                    AnyArrayView::U32(encoded) => {
330                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
331                            // we need to use f64 here to have sufficient precision
332                            let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
333                            #[expect(clippy::cast_possible_truncation)]
334                            let x = x as f32;
335                            x
336                        })
337                    }
338                    AnyArrayView::U64(encoded) => {
339                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
340                            // we need to use TwoFloat here to have sufficient precision
341                            let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
342                            f32::from(x)
343                        })
344                    }
345                    encoded => {
346                        return Err(LinearQuantizeCodecError::InvalidEncodedDType {
347                            dtype: encoded.dtype(),
348                        })
349                    }
350                }
351            }
352            (AnyArrayViewMut::F64(decoded), LinearQuantizeDType::F64) => {
353                match &encoded {
354                    AnyArrayView::U8(encoded) => {
355                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
356                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
357                        })
358                    }
359                    AnyArrayView::U16(encoded) => {
360                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
361                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
362                        })
363                    }
364                    AnyArrayView::U32(encoded) => {
365                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
366                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
367                        })
368                    }
369                    AnyArrayView::U64(encoded) => {
370                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
371                            // we need to use TwoFloat here to have sufficient precision
372                            let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
373                            f64::from(x)
374                        })
375                    }
376                    encoded => {
377                        return Err(LinearQuantizeCodecError::InvalidEncodedDType {
378                            dtype: encoded.dtype(),
379                        })
380                    }
381                }
382            }
383            (decoded, dtype) => {
384                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
385                    configured: dtype,
386                    provided: decoded.dtype(),
387                })
388            }
389        }?;
390
391        Ok(())
392    }
393}
394
395impl StaticCodec for LinearQuantizeCodec {
396    const CODEC_ID: &'static str = "linear-quantize";
397
398    type Config<'de> = Self;
399
400    fn from_config(config: Self::Config<'_>) -> Self {
401        config
402    }
403
404    fn get_config(&self) -> StaticCodecConfig<Self> {
405        StaticCodecConfig::from(self)
406    }
407}
408
409#[derive(Debug, Error)]
410/// Errors that may occur when applying the [`LinearQuantizeCodec`].
411pub enum LinearQuantizeCodecError {
412    /// [`LinearQuantizeCodec`] cannot encode the provided dtype which differs
413    /// from the configured dtype
414    #[error("LinearQuantize cannot encode the provided dtype {provided} which differs from the configured dtype {configured}")]
415    MismatchedEncodeDType {
416        /// Dtype of the `configured` `dtype`
417        configured: LinearQuantizeDType,
418        /// Dtype of the `provided` array from which the data is to be encoded
419        provided: AnyArrayDType,
420    },
421    /// [`LinearQuantizeCodec`] does not support non-finite (infinite or NaN) floating
422    /// point data
423    #[error("LinearQuantize does not support non-finite (infinite or NaN) floating point data")]
424    NonFiniteData,
425    /// [`LinearQuantizeCodec`] failed to encode the header
426    #[error("LinearQuantize failed to encode the header")]
427    HeaderEncodeFailed {
428        /// Opaque source error
429        source: LinearQuantizeHeaderError,
430    },
431    /// [`LinearQuantizeCodec`] can only decode one-dimensional arrays but
432    /// received an array of a different shape
433    #[error("LinearQuantize can only decode one-dimensional arrays but received an array of shape {shape:?}")]
434    EncodedDataNotOneDimensional {
435        /// The unexpected shape of the encoded array
436        shape: Vec<usize>,
437    },
438    /// [`LinearQuantizeCodec`] failed to decode the header
439    #[error("LinearQuantize failed to decode the header")]
440    HeaderDecodeFailed {
441        /// Opaque source error
442        source: LinearQuantizeHeaderError,
443    },
444    /// [`LinearQuantizeCodec`] decoded an invalid array shape header which does
445    /// not fit the decoded data
446    #[error(
447        "LinearQuantize decoded an invalid array shape header which does not fit the decoded data"
448    )]
449    DecodeInvalidShapeHeader {
450        /// Source error
451        #[from]
452        source: ShapeError,
453    },
454    /// [`LinearQuantizeCodec`] cannot decode the provided dtype
455    #[error("LinearQuantize cannot decode the provided dtype {dtype}")]
456    InvalidEncodedDType {
457        /// Dtype of the provided array from which the data is to be decoded
458        dtype: AnyArrayDType,
459    },
460    /// [`LinearQuantizeCodec`] cannot decode the provided dtype which differs
461    /// from the configured dtype
462    #[error("LinearQuantize cannot decode the provided dtype {provided} which differs from the configured dtype {configured}")]
463    MismatchedDecodeIntoDtype {
464        /// Dtype of the `configured` `dtype`
465        configured: LinearQuantizeDType,
466        /// Dtype of the `provided` array into which the data is to be decoded
467        provided: AnyArrayDType,
468    },
469    /// [`LinearQuantizeCodec`] cannot decode the decoded array into the provided
470    /// array of a different shape
471    #[error("LinearQuantize cannot decode the decoded array of shape {decoded:?} into the provided array of shape {provided:?}")]
472    MismatchedDecodeIntoShape {
473        /// Shape of the `decoded` data
474        decoded: Vec<usize>,
475        /// Shape of the `provided` array into which the data is to be decoded
476        provided: Vec<usize>,
477    },
478}
479
480#[derive(Debug, Error)]
481#[error(transparent)]
482/// Opaque error for when encoding or decoding the header fails
483pub struct LinearQuantizeHeaderError(postcard::Error);
484
485/// Linear-quantize the elements in the `data` array using the `quantize`
486/// closure.
487///
488/// # Errors
489///
490/// Errors with
491/// - [`LinearQuantizeCodecError::NonFiniteData`] if any data element is non-
492///   finite (infinite or NaN)
493/// - [`LinearQuantizeCodecError::HeaderEncodeFailed`] if encoding the header
494///   failed
495pub fn quantize<
496    T: Float + ConstZero + ConstOne + Serialize,
497    Q: Unsigned,
498    S: Data<Elem = T>,
499    D: Dimension,
500>(
501    data: &ArrayBase<S, D>,
502    quantize: impl Fn(T) -> Q,
503) -> Result<Vec<Q>, LinearQuantizeCodecError> {
504    if !Zip::from(data).all(|x| x.is_finite()) {
505        return Err(LinearQuantizeCodecError::NonFiniteData);
506    }
507
508    let (minimum, maximum) = data.first().map_or((T::ZERO, T::ONE), |first| {
509        (
510            Zip::from(data).fold(*first, |a, b| a.min(*b)),
511            Zip::from(data).fold(*first, |a, b| a.max(*b)),
512        )
513    });
514
515    let header = postcard::to_extend(
516        &CompressionHeader {
517            shape: Cow::Borrowed(data.shape()),
518            minimum,
519            maximum,
520        },
521        Vec::new(),
522    )
523    .map_err(|err| LinearQuantizeCodecError::HeaderEncodeFailed {
524        source: LinearQuantizeHeaderError(err),
525    })?;
526
527    let mut encoded: Vec<Q> = vec![Q::ZERO; header.len().div_ceil(std::mem::size_of::<Q>())];
528    #[expect(unsafe_code)]
529    // Safety: encoded is at least header.len() bytes long and properly aligned for Q
530    unsafe {
531        std::ptr::copy_nonoverlapping(header.as_ptr(), encoded.as_mut_ptr().cast(), header.len());
532    }
533    encoded.reserve(data.len());
534
535    if maximum == minimum {
536        encoded.resize(encoded.len() + data.len(), quantize(T::ZERO));
537    } else {
538        encoded.extend(
539            data.iter()
540                .map(|x| quantize((*x - minimum) / (maximum - minimum))),
541        );
542    }
543
544    Ok(encoded)
545}
546
547/// Reconstruct the linear-quantized `encoded` array using the `floatify`
548/// closure.
549///
550/// # Errors
551///
552/// Errors with
553/// - [`LinearQuantizeCodecError::HeaderDecodeFailed`] if decoding the header
554///   failed
555pub fn reconstruct<T: Float + DeserializeOwned, Q: Unsigned>(
556    encoded: &[Q],
557    floatify: impl Fn(Q) -> T,
558) -> Result<ArrayD<T>, LinearQuantizeCodecError> {
559    #[expect(unsafe_code)]
560    // Safety: data is data.len()*size_of::<Q> bytes long and properly aligned for Q
561    let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
562        std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
563    })
564    .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
565        source: LinearQuantizeHeaderError(err),
566    })?;
567
568    let encoded = encoded
569        .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
570        .unwrap_or(&[]);
571
572    let decoded = encoded
573        .iter()
574        .map(|x| header.minimum + (floatify(*x) * (header.maximum - header.minimum)))
575        .map(|x| x.clamp(header.minimum, header.maximum))
576        .collect();
577
578    let decoded = Array::from_shape_vec(&*header.shape, decoded)?;
579
580    Ok(decoded)
581}
582
583/// Reconstruct the linear-quantized `encoded` array using the `floatify`
584/// closure into the `decoded` array.
585///
586/// # Errors
587///
588/// Errors with
589/// - [`LinearQuantizeCodecError::HeaderDecodeFailed`] if decoding the header
590///   failed
591/// - [`LinearQuantizeCodecError::MismatchedDecodeIntoShape`] if the `decoded`
592///   array is of the wrong shape
593pub fn reconstruct_into<T: Float + DeserializeOwned, Q: Unsigned>(
594    encoded: &[Q],
595    mut decoded: ArrayViewMutD<T>,
596    floatify: impl Fn(Q) -> T,
597) -> Result<(), LinearQuantizeCodecError> {
598    #[expect(unsafe_code)]
599    // Safety: data is data.len()*size_of::<Q> bytes long and properly aligned for Q
600    let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
601        std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
602    })
603    .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
604        source: LinearQuantizeHeaderError(err),
605    })?;
606
607    let encoded = encoded
608        .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
609        .unwrap_or(&[]);
610
611    if decoded.shape() != &*header.shape {
612        return Err(LinearQuantizeCodecError::MismatchedDecodeIntoShape {
613            decoded: header.shape.into_owned(),
614            provided: decoded.shape().to_vec(),
615        });
616    }
617
618    // iteration must occur in synchronised (standard) order
619    for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
620        *d = (header.minimum + (floatify(*e) * (header.maximum - header.minimum)))
621            .clamp(header.minimum, header.maximum);
622    }
623
624    Ok(())
625}
626
627/// Returns `${2.0}^{bits} - 1.0$`
628fn scale_for_bits<T: Float + From<u8> + ConstOne>(bits: u8) -> T {
629    <T as From<u8>>::from(bits).exp2() - T::ONE
630}
631
/// Unsigned binary types.
///
/// These are the element types that [`quantize`] encodes into and that
/// [`reconstruct`] and [`reconstruct_into`] decode from.
///
/// NOTE(review): `quantize` copies raw header bytes into a `Vec<Self>`, and
/// the reconstruct functions view `&[Self]` as raw bytes. This is only sound
/// for plain integer types without padding or invalid bit patterns. The trait
/// is not sealed, so a downstream impl for another `Copy` type could break
/// that assumption; consider sealing it (a breaking change).
pub trait Unsigned: Copy {
    /// The all-zero value, `0x0`
    const ZERO: Self;
}
637
638impl Unsigned for u8 {
639    const ZERO: Self = 0;
640}
641
642impl Unsigned for u16 {
643    const ZERO: Self = 0;
644}
645
646impl Unsigned for u32 {
647    const ZERO: Self = 0;
648}
649
650impl Unsigned for u64 {
651    const ZERO: Self = 0;
652}
653
/// In-band header that [`quantize`] writes in front of the quantized data and
/// that [`reconstruct`] / [`reconstruct_into`] read back.
///
/// It is serialized with `postcard`, a non-self-describing format, so the
/// field order below is part of the encoded format and must not change.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a, T> {
    /// Shape of the original (decoded) array
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    /// Minimum element of the original array (`0` for an empty array)
    minimum: T,
    /// Maximum element of the original array (`1` for an empty array)
    maximum: T,
}
661
#[cfg(test)]
mod tests {
    use ndarray::CowArray;

    use super::*;

    /// Creates a codec for `dtype` data with `bits` bits of binary precision.
    ///
    /// # Panics
    ///
    /// Panics if `bits` is not in `1..=64`.
    fn make_codec(dtype: LinearQuantizeDType, bits: u8) -> LinearQuantizeCodec {
        assert!((1..=64).contains(&bits), "bits must be in 1..=64");

        LinearQuantizeCodec {
            dtype,
            #[expect(unsafe_code)]
            // Safety: the discriminants of LinearQuantizeBins are exactly 1..=64
            bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
        }
    }

    /// Encodes `data` with `codec`, then decodes it with both `decode` and
    /// `decode_into`. Asserts that both reproduce `data` bit-for-bit.
    ///
    /// The shapes are checked explicitly because a bare `zip` would silently
    /// ignore missing or extra elements.
    fn assert_exact_roundtrip_f32(
        codec: &LinearQuantizeCodec,
        data: &[f32],
    ) -> Result<(), LinearQuantizeCodecError> {
        let encoded = codec.encode(AnyCowArray::F32(CowArray::from(data).into_dyn()))?;
        let decoded = codec.decode(encoded.cow())?;

        let AnyArray::F32(decoded) = decoded else {
            return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                configured: LinearQuantizeDType::F32,
                provided: decoded.dtype(),
            });
        };

        assert_eq!(decoded.shape(), &[data.len()]);
        for (o, d) in data.iter().zip(decoded.iter()) {
            assert_eq!(o.to_bits(), d.to_bits());
        }

        let mut decoded_into = ArrayD::<f32>::zeros(decoded.shape());
        codec.decode_into(encoded.view(), AnyArrayViewMut::F32(decoded_into.view_mut()))?;

        assert_eq!(decoded_into.shape(), &[data.len()]);
        for (o, d) in data.iter().zip(decoded_into.iter()) {
            assert_eq!(o.to_bits(), d.to_bits());
        }

        Ok(())
    }

    /// Encodes `data` with `codec`, then decodes it with both `decode` and
    /// `decode_into`. Asserts that both reproduce `data` bit-for-bit.
    ///
    /// The shapes are checked explicitly because a bare `zip` would silently
    /// ignore missing or extra elements.
    fn assert_exact_roundtrip_f64(
        codec: &LinearQuantizeCodec,
        data: &[f64],
    ) -> Result<(), LinearQuantizeCodecError> {
        let encoded = codec.encode(AnyCowArray::F64(CowArray::from(data).into_dyn()))?;
        let decoded = codec.decode(encoded.cow())?;

        let AnyArray::F64(decoded) = decoded else {
            return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                configured: LinearQuantizeDType::F64,
                provided: decoded.dtype(),
            });
        };

        assert_eq!(decoded.shape(), &[data.len()]);
        for (o, d) in data.iter().zip(decoded.iter()) {
            assert_eq!(o.to_bits(), d.to_bits());
        }

        let mut decoded_into = ArrayD::<f64>::zeros(decoded.shape());
        codec.decode_into(encoded.view(), AnyArrayViewMut::F64(decoded_into.view_mut()))?;

        assert_eq!(decoded_into.shape(), &[data.len()]);
        for (o, d) in data.iter().zip(decoded_into.iter()) {
            assert_eq!(o.to_bits(), d.to_bits());
        }

        Ok(())
    }

    #[test]
    fn exact_roundtrip_f32_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=16 {
            let mut data: Vec<f32> = (0..(u16::MAX >> (16 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(f32::from)
                .collect();
            data.push(f32::from(u16::MAX >> (16 - bits)));

            assert_exact_roundtrip_f32(&make_codec(LinearQuantizeDType::F32, bits), &data)?;
        }

        Ok(())
    }

    #[test]
    fn exact_roundtrip_f32_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            #[expect(clippy::cast_precision_loss)]
            let mut data: Vec<f32> = (0..(u64::MAX >> (64 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(|x| x as f32)
                .collect();
            #[expect(clippy::cast_precision_loss)]
            data.push((u64::MAX >> (64 - bits)) as f32);

            assert_exact_roundtrip_f32(&make_codec(LinearQuantizeDType::F32, bits), &data)?;
        }

        Ok(())
    }

    #[test]
    fn exact_roundtrip_f64_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=32 {
            let mut data: Vec<f64> = (0..(u32::MAX >> (32 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(f64::from)
                .collect();
            data.push(f64::from(u32::MAX >> (32 - bits)));

            assert_exact_roundtrip_f64(&make_codec(LinearQuantizeDType::F64, bits), &data)?;
        }

        Ok(())
    }

    #[test]
    fn exact_roundtrip_f64_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            #[expect(clippy::cast_precision_loss)]
            let mut data: Vec<f64> = (0..(u64::MAX >> (64 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(|x| x as f64)
                .collect();
            #[expect(clippy::cast_precision_loss)]
            data.push((u64::MAX >> (64 - bits)) as f64);

            assert_exact_roundtrip_f64(&make_codec(LinearQuantizeDType::F64, bits), &data)?;
        }

        Ok(())
    }

    #[test]
    fn const_data_roundtrip() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            let data = [42.0, 42.0, 42.0, 42.0];

            assert_exact_roundtrip_f64(&make_codec(LinearQuantizeDType::F64, bits), &data)?;
        }

        Ok(())
    }
}