numcodecs_linear_quantize/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-linear-quantize
10//! [crates.io]: https://crates.io/crates/numcodecs-linear-quantize
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-linear-quantize
13//! [docs.rs]: https://docs.rs/numcodecs-linear-quantize/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_linear_quantize
17//!
18//! Linear Quantization codec implementation for the [`numcodecs`] API.
19
20#![expect(clippy::multiple_crate_versions)] // FIXME: twofloat -> hexf -> syn 1.0
21
22use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, ArrayD, ArrayViewMutD, Data, Dimension, ShapeError, Zip};
25use num_traits::{ConstOne, ConstZero, Float};
26use numcodecs::{
27    AnyArray, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, StaticCodec,
28    StaticCodecConfig, StaticCodecVersion,
29};
30use schemars::{JsonSchema, JsonSchema_repr};
31use serde::{de::DeserializeOwned, Deserialize, Serialize};
32use serde_repr::{Deserialize_repr, Serialize_repr};
33use thiserror::Error;
34use twofloat::TwoFloat;
35
// Version marker type that is shared by the codec configuration
// (`LinearQuantizeCodec::version`) and the in-band `CompressionHeader::version`.
type LinearQuantizeCodecVersion = StaticCodecVersion<0, 1, 0>;
37
// NOTE(review): the `///` doc comments on this type are presumably also picked
// up by the `JsonSchema` derive as schema descriptions — confirm before
// rewording them. They are therefore kept verbatim here.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
/// Lossy codec to reduce the precision of floating point data.
///
/// The data is quantized to unsigned integers of the best-fitting type.
/// The range and shape of the input data is stored in-band.
pub struct LinearQuantizeCodec {
    // `encode` rejects arrays whose dtype differs from this one, and `decode`
    // always produces an array of this dtype.
    /// Dtype of the decoded data
    pub dtype: LinearQuantizeDType,
    // The number of bits selects the unsigned integer type of the encoded
    // array: up to 8 -> u8, up to 16 -> u16, up to 32 -> u32, otherwise u64.
    /// Binary precision of the encoded data where `$bits = \log_{2}(bins)$`
    pub bits: LinearQuantizeBins,
    // Serialised under the key `_version`; falls back to the default when the
    // key is absent from the configuration.
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    #[serde(default, rename = "_version")]
    pub version: LinearQuantizeCodecVersion,
}
53
/// Data types which the [`LinearQuantizeCodec`] can quantize
// Each variant (de)serialises under its short name (`f32` / `f64`) and
// additionally accepts the long alias (`float32` / `float64`) on input.
// The schema is extended to list all four accepted spellings.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("enum" = ["f32", "float32", "f64", "float64"]))]
#[expect(missing_docs)]
pub enum LinearQuantizeDType {
    #[serde(rename = "f32", alias = "float32")]
    F32,
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
64
65impl fmt::Display for LinearQuantizeDType {
66    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
67        fmt.write_str(match self {
68            Self::F32 => "f32",
69            Self::F64 => "f64",
70        })
71    }
72}
73
/// Number of bins for quantization, written in base-2 scientific notation.
///
/// The binary `#[repr(u8)]` value of each variant is equivalent to the binary
/// logarithm of the number of bins, i.e. the binary precision or the number of
/// bits used.
// Only the first variant has an explicit discriminant; all following variants
// count up implicitly, so `_1B<n> as u8 == n` for every n in 1..=64. The codec
// relies on this when it casts `self.bits as u8`, so the variant order must
// not be changed.
#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, JsonSchema_repr)]
#[repr(u8)]
#[rustfmt::skip]
#[expect(missing_docs)]
pub enum LinearQuantizeBins {
    _1B1 = 1, _1B2, _1B3, _1B4, _1B5, _1B6, _1B7, _1B8,
    _1B9, _1B10, _1B11, _1B12, _1B13, _1B14, _1B15, _1B16,
    _1B17, _1B18, _1B19, _1B20, _1B21, _1B22, _1B23, _1B24,
    _1B25, _1B26, _1B27, _1B28, _1B29, _1B30, _1B31, _1B32,
    _1B33, _1B34, _1B35, _1B36, _1B37, _1B38, _1B39, _1B40,
    _1B41, _1B42, _1B43, _1B44, _1B45, _1B46, _1B47, _1B48,
    _1B49, _1B50, _1B51, _1B52, _1B53, _1B54, _1B55, _1B56,
    _1B57, _1B58, _1B59, _1B60, _1B61, _1B62, _1B63, _1B64,
}
93
94impl Codec for LinearQuantizeCodec {
95    type Error = LinearQuantizeCodecError;
96
    /// Quantizes the floating point `data` into a one-dimensional array of
    /// unsigned integers, which starts with the in-band header written by
    /// [`quantize`].
    ///
    /// The unsigned integer type is chosen from the configured number of
    /// bits: up to 8 bits use `u8`, up to 16 bits use `u16`, up to 32 bits use
    /// `u32`, and everything above uses `u64`.
    ///
    /// # Errors
    ///
    /// Errors with
    /// - [`LinearQuantizeCodecError::MismatchedEncodeDType`] if the dtype of
    ///   `data` differs from the configured dtype
    /// - any error returned by [`quantize`]
    #[expect(clippy::too_many_lines)]
    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        // Every closure below receives a value that `quantize` has normalised
        // by the data range, scales it by `2^bits - 1`, adds 0.5 so that the
        // truncating float-to-int conversion rounds to nearest, and clamps the
        // result to the largest bin index `max = 2^bits - 1`.
        //
        // NOTE(review): the unchecked conversions are only sound if `quantize`
        // never passes NaN to these closures, since `clamp` returns NaN for a
        // NaN input — confirm that `quantize` upholds this.
        let encoded = match (&data, self.dtype) {
            (AnyCowArray::F32(data), LinearQuantizeDType::F32) => match self.bits as u8 {
                // f32 -> u8, computed in f32
                bits @ ..=8 => AnyArray::U8(
                    Array1::from_vec(quantize(data, |x| {
                        let max = f32::from(u8::MAX >> (8 - bits));
                        let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
                        #[expect(unsafe_code)]
                        // Safety: x is clamped beforehand
                        unsafe {
                            x.to_int_unchecked::<u8>()
                        }
                    })?)
                    .into_dyn(),
                ),
                // f32 -> u16, computed in f32
                bits @ 9..=16 => AnyArray::U16(
                    Array1::from_vec(quantize(data, |x| {
                        let max = f32::from(u16::MAX >> (16 - bits));
                        let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
                        #[expect(unsafe_code)]
                        // Safety: x is clamped beforehand
                        unsafe {
                            x.to_int_unchecked::<u16>()
                        }
                    })?)
                    .into_dyn(),
                ),
                // f32 -> u32, computed in f64
                bits @ 17..=32 => AnyArray::U32(
                    Array1::from_vec(quantize(data, |x| {
                        // we need to use f64 here to have sufficient precision
                        let max = f64::from(u32::MAX >> (32 - bits));
                        let x = f64::from(x)
                            .mul_add(scale_for_bits::<f64>(bits), 0.5)
                            .clamp(0.0, max);
                        #[expect(unsafe_code)]
                        // Safety: x is clamped beforehand
                        unsafe {
                            x.to_int_unchecked::<u32>()
                        }
                    })?)
                    .into_dyn(),
                ),
                // f32 -> u64, computed in TwoFloat
                bits @ 33.. => AnyArray::U64(
                    Array1::from_vec(quantize(data, |x| {
                        // we need to use TwoFloat here to have sufficient precision
                        let max = TwoFloat::from(u64::MAX >> (64 - bits));
                        let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
                            + TwoFloat::from(0.5))
                        .max(TwoFloat::from(0.0))
                        .min(max);
                        #[expect(unsafe_code)]
                        // Safety: x is clamped beforehand
                        unsafe {
                            u64::try_from(x).unwrap_unchecked()
                        }
                    })?)
                    .into_dyn(),
                ),
            },
            (AnyCowArray::F64(data), LinearQuantizeDType::F64) => match self.bits as u8 {
                // f64 -> u8, computed in f64
                bits @ ..=8 => AnyArray::U8(
                    Array1::from_vec(quantize(data, |x| {
                        let max = f64::from(u8::MAX >> (8 - bits));
                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
                        #[expect(unsafe_code)]
                        // Safety: x is clamped beforehand
                        unsafe {
                            x.to_int_unchecked::<u8>()
                        }
                    })?)
                    .into_dyn(),
                ),
                // f64 -> u16, computed in f64
                bits @ 9..=16 => AnyArray::U16(
                    Array1::from_vec(quantize(data, |x| {
                        let max = f64::from(u16::MAX >> (16 - bits));
                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
                        #[expect(unsafe_code)]
                        // Safety: x is clamped beforehand
                        unsafe {
                            x.to_int_unchecked::<u16>()
                        }
                    })?)
                    .into_dyn(),
                ),
                // f64 -> u32, computed in f64
                bits @ 17..=32 => AnyArray::U32(
                    Array1::from_vec(quantize(data, |x| {
                        let max = f64::from(u32::MAX >> (32 - bits));
                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
                        #[expect(unsafe_code)]
                        // Safety: x is clamped beforehand
                        unsafe {
                            x.to_int_unchecked::<u32>()
                        }
                    })?)
                    .into_dyn(),
                ),
                // f64 -> u64, computed in TwoFloat
                bits @ 33.. => AnyArray::U64(
                    Array1::from_vec(quantize(data, |x| {
                        // we need to use TwoFloat here to have sufficient precision
                        let max = TwoFloat::from(u64::MAX >> (64 - bits));
                        let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
                            + TwoFloat::from(0.5))
                        .max(TwoFloat::from(0.0))
                        .min(max);
                        #[expect(unsafe_code)]
                        // Safety: x is clamped beforehand
                        unsafe {
                            u64::try_from(x).unwrap_unchecked()
                        }
                    })?)
                    .into_dyn(),
                ),
            },
            // any other combination means that the array dtype and the
            // configured dtype disagree (or the array is not a supported float)
            (data, dtype) => {
                return Err(LinearQuantizeCodecError::MismatchedEncodeDType {
                    configured: dtype,
                    provided: data.dtype(),
                });
            }
        };

        Ok(encoded)
    }
221
222    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
223        #[expect(clippy::option_if_let_else)]
224        fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
225            array: &ArrayBase<S, D>,
226        ) -> Cow<[T]> {
227            if let Some(data) = array.as_slice() {
228                Cow::Borrowed(data)
229            } else {
230                Cow::Owned(array.iter().copied().collect())
231            }
232        }
233
234        if !matches!(encoded.shape(), [_]) {
235            return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
236                shape: encoded.shape().to_vec(),
237            });
238        }
239
240        let decoded = match (&encoded, self.dtype) {
241            (AnyCowArray::U8(encoded), LinearQuantizeDType::F32) => {
242                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
243                    f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
244                })?)
245            }
246            (AnyCowArray::U16(encoded), LinearQuantizeDType::F32) => {
247                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
248                    f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
249                })?)
250            }
251            (AnyCowArray::U32(encoded), LinearQuantizeDType::F32) => {
252                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
253                    // we need to use f64 here to have sufficient precision
254                    let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
255                    #[expect(clippy::cast_possible_truncation)]
256                    let x = x as f32;
257                    x
258                })?)
259            }
260            (AnyCowArray::U64(encoded), LinearQuantizeDType::F32) => {
261                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
262                    // we need to use TwoFloat here to have sufficient precision
263                    let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
264                    f32::from(x)
265                })?)
266            }
267            (AnyCowArray::U8(encoded), LinearQuantizeDType::F64) => {
268                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
269                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
270                })?)
271            }
272            (AnyCowArray::U16(encoded), LinearQuantizeDType::F64) => {
273                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
274                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
275                })?)
276            }
277            (AnyCowArray::U32(encoded), LinearQuantizeDType::F64) => {
278                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
279                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
280                })?)
281            }
282            (AnyCowArray::U64(encoded), LinearQuantizeDType::F64) => {
283                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
284                    // we need to use TwoFloat here to have sufficient precision
285                    let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
286                    f64::from(x)
287                })?)
288            }
289            (encoded, _dtype) => {
290                return Err(LinearQuantizeCodecError::InvalidEncodedDType {
291                    dtype: encoded.dtype(),
292                })
293            }
294        };
295
296        Ok(decoded)
297    }
298
299    fn decode_into(
300        &self,
301        encoded: AnyArrayView,
302        decoded: AnyArrayViewMut,
303    ) -> Result<(), Self::Error> {
304        fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
305            array: &ArrayBase<S, D>,
306        ) -> Cow<[T]> {
307            #[expect(clippy::option_if_let_else)]
308            if let Some(data) = array.as_slice() {
309                Cow::Borrowed(data)
310            } else {
311                Cow::Owned(array.iter().copied().collect())
312            }
313        }
314
315        if !matches!(encoded.shape(), [_]) {
316            return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
317                shape: encoded.shape().to_vec(),
318            });
319        }
320
321        match (decoded, self.dtype) {
322            (AnyArrayViewMut::F32(decoded), LinearQuantizeDType::F32) => {
323                match &encoded {
324                    AnyArrayView::U8(encoded) => {
325                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
326                            f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
327                        })
328                    }
329                    AnyArrayView::U16(encoded) => {
330                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
331                            f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
332                        })
333                    }
334                    AnyArrayView::U32(encoded) => {
335                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
336                            // we need to use f64 here to have sufficient precision
337                            let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
338                            #[expect(clippy::cast_possible_truncation)]
339                            let x = x as f32;
340                            x
341                        })
342                    }
343                    AnyArrayView::U64(encoded) => {
344                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
345                            // we need to use TwoFloat here to have sufficient precision
346                            let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
347                            f32::from(x)
348                        })
349                    }
350                    encoded => {
351                        return Err(LinearQuantizeCodecError::InvalidEncodedDType {
352                            dtype: encoded.dtype(),
353                        })
354                    }
355                }
356            }
357            (AnyArrayViewMut::F64(decoded), LinearQuantizeDType::F64) => {
358                match &encoded {
359                    AnyArrayView::U8(encoded) => {
360                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
361                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
362                        })
363                    }
364                    AnyArrayView::U16(encoded) => {
365                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
366                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
367                        })
368                    }
369                    AnyArrayView::U32(encoded) => {
370                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
371                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
372                        })
373                    }
374                    AnyArrayView::U64(encoded) => {
375                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
376                            // we need to use TwoFloat here to have sufficient precision
377                            let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
378                            f64::from(x)
379                        })
380                    }
381                    encoded => {
382                        return Err(LinearQuantizeCodecError::InvalidEncodedDType {
383                            dtype: encoded.dtype(),
384                        })
385                    }
386                }
387            }
388            (decoded, dtype) => {
389                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
390                    configured: dtype,
391                    provided: decoded.dtype(),
392                })
393            }
394        }?;
395
396        Ok(())
397    }
398}
399
// The codec is its own configuration type, so converting from and to the
// configuration requires no translation.
impl StaticCodec for LinearQuantizeCodec {
    const CODEC_ID: &'static str = "linear-quantize.rs";

    // the codec struct doubles as its (de)serialisable configuration
    type Config<'de> = Self;

    // the configuration already is the codec, so it is returned unchanged
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    // builds the configuration from a reference to this codec
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
413
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`LinearQuantizeCodec`].
pub enum LinearQuantizeCodecError {
    /// [`LinearQuantizeCodec`] cannot encode the provided dtype which differs
    /// from the configured dtype
    #[error("LinearQuantize cannot encode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedEncodeDType {
        /// Dtype of the `configured` `dtype`
        configured: LinearQuantizeDType,
        /// Dtype of the `provided` array from which the data is to be encoded
        provided: AnyArrayDType,
    },
    /// [`LinearQuantizeCodec`] does not support non-finite (infinite or NaN) floating
    /// point data
    #[error("LinearQuantize does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// [`LinearQuantizeCodec`] failed to encode the header
    #[error("LinearQuantize failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: LinearQuantizeHeaderError,
    },
    /// [`LinearQuantizeCodec`] can only decode one-dimensional arrays but
    /// received an array of a different shape
    #[error("LinearQuantize can only decode one-dimensional arrays but received an array of shape {shape:?}")]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`LinearQuantizeCodec`] failed to decode the header
    #[error("LinearQuantize failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: LinearQuantizeHeaderError,
    },
    /// [`LinearQuantizeCodec`] decoded an invalid array shape header which does
    /// not fit the decoded data
    #[error(
        "LinearQuantize decoded an invalid array shape header which does not fit the decoded data"
    )]
    DecodeInvalidShapeHeader {
        /// Source error, which `?` converts from a [`ShapeError`] automatically
        #[from]
        source: ShapeError,
    },
    /// [`LinearQuantizeCodec`] cannot decode the provided encoded array since
    /// its dtype is not one of the unsigned integer dtypes produced by encoding
    #[error("LinearQuantize cannot decode the provided dtype {dtype}")]
    InvalidEncodedDType {
        /// Dtype of the provided array from which the data is to be decoded
        dtype: AnyArrayDType,
    },
    /// [`LinearQuantizeCodec`] cannot decode into the provided output array
    /// since its dtype differs from the configured dtype
    #[error("LinearQuantize cannot decode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedDecodeIntoDtype {
        /// Dtype of the `configured` `dtype`
        configured: LinearQuantizeDType,
        /// Dtype of the `provided` array into which the data is to be decoded
        provided: AnyArrayDType,
    },
    /// [`LinearQuantizeCodec`] cannot decode the decoded array into the provided
    /// array of a different shape
    #[error("LinearQuantize cannot decode the decoded array of shape {decoded:?} into the provided array of shape {provided:?}")]
    MismatchedDecodeIntoShape {
        /// Shape of the `decoded` data, as stored in the in-band header
        decoded: Vec<usize>,
        /// Shape of the `provided` array into which the data is to be decoded
        provided: Vec<usize>,
    },
}
484
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when encoding or decoding the header fails
// The wrapped `postcard::Error` is a private field, so the serialisation
// library does not leak into the public API; `#[error(transparent)]` forwards
// the error's display and source to the wrapped error.
pub struct LinearQuantizeHeaderError(postcard::Error);
489
490/// Linear-quantize the elements in the `data` array using the `quantize`
491/// closure.
492///
493/// # Errors
494///
495/// Errors with
496/// - [`LinearQuantizeCodecError::NonFiniteData`] if any data element is non-
497///   finite (infinite or NaN)
498/// - [`LinearQuantizeCodecError::HeaderEncodeFailed`] if encoding the header
499///   failed
500pub fn quantize<
501    T: Float + ConstZero + ConstOne + Serialize,
502    Q: Unsigned,
503    S: Data<Elem = T>,
504    D: Dimension,
505>(
506    data: &ArrayBase<S, D>,
507    quantize: impl Fn(T) -> Q,
508) -> Result<Vec<Q>, LinearQuantizeCodecError> {
509    if !Zip::from(data).all(|x| x.is_finite()) {
510        return Err(LinearQuantizeCodecError::NonFiniteData);
511    }
512
513    let (minimum, maximum) = data.first().map_or((T::ZERO, T::ONE), |first| {
514        (
515            Zip::from(data).fold(*first, |a, b| a.min(*b)),
516            Zip::from(data).fold(*first, |a, b| a.max(*b)),
517        )
518    });
519
520    let header = postcard::to_extend(
521        &CompressionHeader {
522            shape: Cow::Borrowed(data.shape()),
523            minimum,
524            maximum,
525            version: StaticCodecVersion,
526        },
527        Vec::new(),
528    )
529    .map_err(|err| LinearQuantizeCodecError::HeaderEncodeFailed {
530        source: LinearQuantizeHeaderError(err),
531    })?;
532
533    let mut encoded: Vec<Q> = vec![Q::ZERO; header.len().div_ceil(std::mem::size_of::<Q>())];
534    #[expect(unsafe_code)]
535    // Safety: encoded is at least header.len() bytes long and properly aligned for Q
536    unsafe {
537        std::ptr::copy_nonoverlapping(header.as_ptr(), encoded.as_mut_ptr().cast(), header.len());
538    }
539    encoded.reserve(data.len());
540
541    if maximum == minimum {
542        encoded.resize(encoded.len() + data.len(), quantize(T::ZERO));
543    } else {
544        encoded.extend(
545            data.iter()
546                .map(|x| quantize((*x - minimum) / (maximum - minimum))),
547        );
548    }
549
550    Ok(encoded)
551}
552
553/// Reconstruct the linear-quantized `encoded` array using the `floatify`
554/// closure.
555///
556/// # Errors
557///
558/// Errors with
559/// - [`LinearQuantizeCodecError::HeaderDecodeFailed`] if decoding the header
560///   failed
561pub fn reconstruct<T: Float + DeserializeOwned, Q: Unsigned>(
562    encoded: &[Q],
563    floatify: impl Fn(Q) -> T,
564) -> Result<ArrayD<T>, LinearQuantizeCodecError> {
565    #[expect(unsafe_code)]
566    // Safety: data is data.len()*size_of::<Q> bytes long and properly aligned for Q
567    let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
568        std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
569    })
570    .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
571        source: LinearQuantizeHeaderError(err),
572    })?;
573
574    let encoded = encoded
575        .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
576        .unwrap_or(&[]);
577
578    let decoded = encoded
579        .iter()
580        .map(|x| header.minimum + (floatify(*x) * (header.maximum - header.minimum)))
581        .map(|x| x.clamp(header.minimum, header.maximum))
582        .collect();
583
584    let decoded = Array::from_shape_vec(&*header.shape, decoded)?;
585
586    Ok(decoded)
587}
588
589/// Reconstruct the linear-quantized `encoded` array using the `floatify`
590/// closure into the `decoded` array.
591///
592/// # Errors
593///
594/// Errors with
595/// - [`LinearQuantizeCodecError::HeaderDecodeFailed`] if decoding the header
596///   failed
597/// - [`LinearQuantizeCodecError::MismatchedDecodeIntoShape`] if the `decoded`
598///   array is of the wrong shape
599pub fn reconstruct_into<T: Float + DeserializeOwned, Q: Unsigned>(
600    encoded: &[Q],
601    mut decoded: ArrayViewMutD<T>,
602    floatify: impl Fn(Q) -> T,
603) -> Result<(), LinearQuantizeCodecError> {
604    #[expect(unsafe_code)]
605    // Safety: data is data.len()*size_of::<Q> bytes long and properly aligned for Q
606    let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
607        std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
608    })
609    .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
610        source: LinearQuantizeHeaderError(err),
611    })?;
612
613    let encoded = encoded
614        .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
615        .unwrap_or(&[]);
616
617    if decoded.shape() != &*header.shape {
618        return Err(LinearQuantizeCodecError::MismatchedDecodeIntoShape {
619            decoded: header.shape.into_owned(),
620            provided: decoded.shape().to_vec(),
621        });
622    }
623
624    // iteration must occur in synchronised (standard) order
625    for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
626        *d = (header.minimum + (floatify(*e) * (header.maximum - header.minimum)))
627            .clamp(header.minimum, header.maximum);
628    }
629
630    Ok(())
631}
632
633/// Returns `${2.0}^{bits} - 1.0$`
634fn scale_for_bits<T: Float + From<u8> + ConstOne>(bits: u8) -> T {
635    <T as From<u8>>::from(bits).exp2() - T::ONE
636}
637
/// Unsigned integer types that quantized data can be stored as.
pub trait Unsigned: Copy {
    /// The zero value `0x0`
    const ZERO: Self;
}

// Implements `Unsigned` for every listed primitive integer type, using the
// literal `0` as its `ZERO` constant.
macro_rules! impl_unsigned {
    ($($int:ty),+ $(,)?) => {
        $(
            impl Unsigned for $int {
                const ZERO: Self = 0;
            }
        )+
    };
}

impl_unsigned!(u8, u16, u32, u64);
659
/// In-band header that `quantize` serialises in front of the quantized data
/// and that `reconstruct` / `reconstruct_into` read back.
// NOTE(review): the header is serialised with postcard, which presumably
// writes the fields in declaration order without field names — treat the
// field order and types as part of the encoded format.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a, T> {
    /// Shape of the original data array
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    /// Minimum of the original data, which the quantized value zero maps to
    minimum: T,
    /// Maximum of the original data, which the largest quantized value maps to
    maximum: T,
    /// Version marker of the encoding format
    version: LinearQuantizeCodecVersion,
}
668
// Roundtrip tests: each test encodes and decodes data with every tested
// binary precision and requires the decoded values to be bit-identical to the
// original values.
#[cfg(test)]
mod tests {
    use ndarray::CowArray;

    use super::*;

    /// Roundtrips `f32` data for 1 to 16 bits, where the data consists of the
    /// integers from zero up to `2^bits - 1` (subsampled to at most about 256
    /// values), which are converted losslessly with `f32::from`.
    #[test]
    fn exact_roundtrip_f32_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=16 {
            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F32,
                // bits is in 1..=16, which are valid `#[repr(u8)]`
                // discriminants of `LinearQuantizeBins`
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
                version: StaticCodecVersion,
            };

            // the step size keeps the number of samples small for large bits
            let mut data: Vec<f32> = (0..(u16::MAX >> (16 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(f32::from)
                .collect();
            // the range above is exclusive, so the maximum is added explicitly
            data.push(f32::from(u16::MAX >> (16 - bits)));

            let encoded = codec.encode(AnyCowArray::F32(CowArray::from(&data).into_dyn()))?;
            let decoded = codec.decode(encoded.cow())?;

            let AnyArray::F32(decoded) = decoded else {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F32,
                    provided: decoded.dtype(),
                });
            };

            for (o, d) in data.iter().zip(decoded.iter()) {
                assert_eq!(o.to_bits(), d.to_bits());
            }
        }

        Ok(())
    }

    /// Roundtrips `f32` data for 1 to 64 bits, where the data consists of the
    /// integers from zero up to `2^bits - 1` (subsampled), which are converted
    /// with a potentially lossy `as f32` cast before encoding.
    #[test]
    fn exact_roundtrip_f32_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F32,
                // bits is in 1..=64, which are valid `#[repr(u8)]`
                // discriminants of `LinearQuantizeBins`
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
                version: StaticCodecVersion,
            };

            // the step size keeps the number of samples small for large bits
            #[expect(clippy::cast_precision_loss)]
            let mut data: Vec<f32> = (0..(u64::MAX >> (64 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(|x| x as f32)
                .collect();
            // the range above is exclusive, so the maximum is added explicitly
            #[expect(clippy::cast_precision_loss)]
            data.push((u64::MAX >> (64 - bits)) as f32);

            let encoded = codec.encode(AnyCowArray::F32(CowArray::from(&data).into_dyn()))?;
            let decoded = codec.decode(encoded.cow())?;

            let AnyArray::F32(decoded) = decoded else {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F32,
                    provided: decoded.dtype(),
                });
            };

            for (o, d) in data.iter().zip(decoded.iter()) {
                assert_eq!(o.to_bits(), d.to_bits());
            }
        }

        Ok(())
    }

    /// Roundtrips `f64` data for 1 to 32 bits, where the data consists of the
    /// integers from zero up to `2^bits - 1` (subsampled), which are converted
    /// losslessly with `f64::from`.
    #[test]
    fn exact_roundtrip_f64_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=32 {
            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F64,
                // bits is in 1..=32, which are valid `#[repr(u8)]`
                // discriminants of `LinearQuantizeBins`
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
                version: StaticCodecVersion,
            };

            // the step size keeps the number of samples small for large bits
            let mut data: Vec<f64> = (0..(u32::MAX >> (32 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(f64::from)
                .collect();
            // the range above is exclusive, so the maximum is added explicitly
            data.push(f64::from(u32::MAX >> (32 - bits)));

            let encoded = codec.encode(AnyCowArray::F64(CowArray::from(&data).into_dyn()))?;
            let decoded = codec.decode(encoded.cow())?;

            let AnyArray::F64(decoded) = decoded else {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F64,
                    provided: decoded.dtype(),
                });
            };

            for (o, d) in data.iter().zip(decoded.iter()) {
                assert_eq!(o.to_bits(), d.to_bits());
            }
        }

        Ok(())
    }

    /// Roundtrips `f64` data for 1 to 64 bits, where the data consists of the
    /// integers from zero up to `2^bits - 1` (subsampled), which are converted
    /// with a potentially lossy `as f64` cast before encoding.
    #[test]
    fn exact_roundtrip_f64_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F64,
                // bits is in 1..=64, which are valid `#[repr(u8)]`
                // discriminants of `LinearQuantizeBins`
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
                version: StaticCodecVersion,
            };

            // the step size keeps the number of samples small for large bits
            #[expect(clippy::cast_precision_loss)]
            let mut data: Vec<f64> = (0..(u64::MAX >> (64 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(|x| x as f64)
                .collect();
            // the range above is exclusive, so the maximum is added explicitly
            #[expect(clippy::cast_precision_loss)]
            data.push((u64::MAX >> (64 - bits)) as f64);

            let encoded = codec.encode(AnyCowArray::F64(CowArray::from(&data).into_dyn()))?;
            let decoded = codec.decode(encoded.cow())?;

            let AnyArray::F64(decoded) = decoded else {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F64,
                    provided: decoded.dtype(),
                });
            };

            for (o, d) in data.iter().zip(decoded.iter()) {
                assert_eq!(o.to_bits(), d.to_bits());
            }
        }

        Ok(())
    }

    /// Roundtrips constant `f64` data, for which the minimum and maximum
    /// coincide, for 1 to 64 bits.
    #[test]
    fn const_data_roundtrip() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            let data = [42.0, 42.0, 42.0, 42.0];

            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F64,
                // bits is in 1..=64, which are valid `#[repr(u8)]`
                // discriminants of `LinearQuantizeBins`
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
                version: StaticCodecVersion,
            };

            let encoded = codec.encode(AnyCowArray::F64(CowArray::from(&data).into_dyn()))?;
            let decoded = codec.decode(encoded.cow())?;

            let AnyArray::F64(decoded) = decoded else {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F64,
                    provided: decoded.dtype(),
                });
            };

            for (o, d) in data.iter().zip(decoded.iter()) {
                assert_eq!(o.to_bits(), d.to_bits());
            }
        }

        Ok(())
    }
}