1#![expect(clippy::multiple_crate_versions)] use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, ArrayD, ArrayViewMutD, Data, Dimension, ShapeError, Zip};
25use num_traits::{ConstOne, ConstZero, Float};
26use numcodecs::{
27 AnyArray, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, StaticCodec,
28 StaticCodecConfig, StaticCodecVersion,
29};
30use schemars::{JsonSchema, JsonSchema_repr};
31use serde::{Deserialize, Serialize, de::DeserializeOwned};
32use serde_repr::{Deserialize_repr, Serialize_repr};
33use thiserror::Error;
34use twofloat::TwoFloat;
35
/// Version marker of this codec's encoding format.
///
/// It is stored both in the codec configuration (as `_version`) and in every
/// serialised `CompressionHeader`. The three const parameters are presumably
/// the major, minor and patch components (0.1.0) — confirm against
/// `StaticCodecVersion`.
type LinearQuantizeCodecVersion = StaticCodecVersion<0, 1, 0>;
37
/// Lossy codec that linearly quantizes floating point data into unsigned
/// integers.
///
/// On encoding, the data is rescaled by its minimum and maximum and stored
/// in the smallest of `u8`, `u16`, `u32` or `u64` that holds `bits` bits. The
/// shape, minimum and maximum of the data are stored in-band in a header at
/// the start of the encoded array.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct LinearQuantizeCodec {
    /// Floating point dtype of the data that is encoded and of the data that
    /// is produced by decoding
    pub dtype: LinearQuantizeDType,
    /// Number of bits (1 to 64) that each quantized value occupies; the
    /// quantized values range from `0` to `2^bits - 1`
    pub bits: LinearQuantizeBins,
    /// Version of the codec's encoding format, (de)serialised as `_version`
    /// and filled in with its default when absent
    #[serde(default, rename = "_version")]
    pub version: LinearQuantizeCodecVersion,
}
53
/// Floating point dtypes that the linear quantize codec can be configured
/// for.
///
/// Each variant is serialised as `"f32"` / `"f64"`; the aliases `"float32"` /
/// `"float64"` are additionally accepted when deserialising.
// The variants are intentionally left undocumented: the `missing_docs`
// expectation below would otherwise be unfulfilled.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("enum" = ["f32", "float32", "f64", "float64"]))]
#[expect(missing_docs)]
pub enum LinearQuantizeDType {
    #[serde(rename = "f32", alias = "float32")]
    F32,
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
64
65impl fmt::Display for LinearQuantizeDType {
66 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
67 fmt.write_str(match self {
68 Self::F32 => "f32",
69 Self::F64 => "f64",
70 })
71 }
72}
73
/// Number of bits used for each quantized value.
///
/// The `#[repr(u8)]` discriminant of the variant `_1B{n}` is `n`, for `n` from
/// 1 to 64, and the codec uses that discriminant directly as the number of
/// bits. The variant is (de)serialised as this integer.
///
/// NOTE(review): the variant names presumably read as base-2 scientific
/// notation, i.e. `_1B8` stands for `1 * 2^8` bins — confirm.
// The variants are intentionally left undocumented: the `missing_docs`
// expectation below would otherwise be unfulfilled.
#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, JsonSchema_repr)]
#[repr(u8)]
#[rustfmt::skip]
#[expect(missing_docs)]
pub enum LinearQuantizeBins {
    _1B1 = 1, _1B2, _1B3, _1B4, _1B5, _1B6, _1B7, _1B8,
    _1B9, _1B10, _1B11, _1B12, _1B13, _1B14, _1B15, _1B16,
    _1B17, _1B18, _1B19, _1B20, _1B21, _1B22, _1B23, _1B24,
    _1B25, _1B26, _1B27, _1B28, _1B29, _1B30, _1B31, _1B32,
    _1B33, _1B34, _1B35, _1B36, _1B37, _1B38, _1B39, _1B40,
    _1B41, _1B42, _1B43, _1B44, _1B45, _1B46, _1B47, _1B48,
    _1B49, _1B50, _1B51, _1B52, _1B53, _1B54, _1B55, _1B56,
    _1B57, _1B58, _1B59, _1B60, _1B61, _1B62, _1B63, _1B64,
}
93
94impl Codec for LinearQuantizeCodec {
95 type Error = LinearQuantizeCodecError;
96
97 #[expect(clippy::too_many_lines)]
98 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
99 let encoded = match (&data, self.dtype) {
100 (AnyCowArray::F32(data), LinearQuantizeDType::F32) => match self.bits as u8 {
101 bits @ ..=8 => AnyArray::U8(
102 Array1::from_vec(quantize(data, |x| {
103 let max = f32::from(u8::MAX >> (8 - bits));
104 let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
105 #[expect(unsafe_code)]
106 unsafe {
108 x.to_int_unchecked::<u8>()
109 }
110 })?)
111 .into_dyn(),
112 ),
113 bits @ 9..=16 => AnyArray::U16(
114 Array1::from_vec(quantize(data, |x| {
115 let max = f32::from(u16::MAX >> (16 - bits));
116 let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
117 #[expect(unsafe_code)]
118 unsafe {
120 x.to_int_unchecked::<u16>()
121 }
122 })?)
123 .into_dyn(),
124 ),
125 bits @ 17..=32 => AnyArray::U32(
126 Array1::from_vec(quantize(data, |x| {
127 let max = f64::from(u32::MAX >> (32 - bits));
129 let x = f64::from(x)
130 .mul_add(scale_for_bits::<f64>(bits), 0.5)
131 .clamp(0.0, max);
132 #[expect(unsafe_code)]
133 unsafe {
135 x.to_int_unchecked::<u32>()
136 }
137 })?)
138 .into_dyn(),
139 ),
140 bits @ 33.. => AnyArray::U64(
141 Array1::from_vec(quantize(data, |x| {
142 let max = TwoFloat::from(u64::MAX >> (64 - bits));
144 let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
145 + TwoFloat::from(0.5))
146 .max(TwoFloat::from(0.0))
147 .min(max);
148 #[expect(unsafe_code)]
149 unsafe {
151 u64::try_from(x).unwrap_unchecked()
152 }
153 })?)
154 .into_dyn(),
155 ),
156 },
157 (AnyCowArray::F64(data), LinearQuantizeDType::F64) => match self.bits as u8 {
158 bits @ ..=8 => AnyArray::U8(
159 Array1::from_vec(quantize(data, |x| {
160 let max = f64::from(u8::MAX >> (8 - bits));
161 let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
162 #[expect(unsafe_code)]
163 unsafe {
165 x.to_int_unchecked::<u8>()
166 }
167 })?)
168 .into_dyn(),
169 ),
170 bits @ 9..=16 => AnyArray::U16(
171 Array1::from_vec(quantize(data, |x| {
172 let max = f64::from(u16::MAX >> (16 - bits));
173 let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
174 #[expect(unsafe_code)]
175 unsafe {
177 x.to_int_unchecked::<u16>()
178 }
179 })?)
180 .into_dyn(),
181 ),
182 bits @ 17..=32 => AnyArray::U32(
183 Array1::from_vec(quantize(data, |x| {
184 let max = f64::from(u32::MAX >> (32 - bits));
185 let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
186 #[expect(unsafe_code)]
187 unsafe {
189 x.to_int_unchecked::<u32>()
190 }
191 })?)
192 .into_dyn(),
193 ),
194 bits @ 33.. => AnyArray::U64(
195 Array1::from_vec(quantize(data, |x| {
196 let max = TwoFloat::from(u64::MAX >> (64 - bits));
198 let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
199 + TwoFloat::from(0.5))
200 .max(TwoFloat::from(0.0))
201 .min(max);
202 #[expect(unsafe_code)]
203 unsafe {
205 u64::try_from(x).unwrap_unchecked()
206 }
207 })?)
208 .into_dyn(),
209 ),
210 },
211 (data, dtype) => {
212 return Err(LinearQuantizeCodecError::MismatchedEncodeDType {
213 configured: dtype,
214 provided: data.dtype(),
215 });
216 }
217 };
218
219 Ok(encoded)
220 }
221
222 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
223 #[expect(clippy::option_if_let_else)]
224 fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
225 array: &ArrayBase<S, D>,
226 ) -> Cow<[T]> {
227 if let Some(data) = array.as_slice() {
228 Cow::Borrowed(data)
229 } else {
230 Cow::Owned(array.iter().copied().collect())
231 }
232 }
233
234 if !matches!(encoded.shape(), [_]) {
235 return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
236 shape: encoded.shape().to_vec(),
237 });
238 }
239
240 let decoded = match (&encoded, self.dtype) {
241 (AnyCowArray::U8(encoded), LinearQuantizeDType::F32) => {
242 AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
243 f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
244 })?)
245 }
246 (AnyCowArray::U16(encoded), LinearQuantizeDType::F32) => {
247 AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
248 f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
249 })?)
250 }
251 (AnyCowArray::U32(encoded), LinearQuantizeDType::F32) => {
252 AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
253 let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
255 #[expect(clippy::cast_possible_truncation)]
256 let x = x as f32;
257 x
258 })?)
259 }
260 (AnyCowArray::U64(encoded), LinearQuantizeDType::F32) => {
261 AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
262 let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
264 f32::from(x)
265 })?)
266 }
267 (AnyCowArray::U8(encoded), LinearQuantizeDType::F64) => {
268 AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
269 f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
270 })?)
271 }
272 (AnyCowArray::U16(encoded), LinearQuantizeDType::F64) => {
273 AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
274 f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
275 })?)
276 }
277 (AnyCowArray::U32(encoded), LinearQuantizeDType::F64) => {
278 AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
279 f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
280 })?)
281 }
282 (AnyCowArray::U64(encoded), LinearQuantizeDType::F64) => {
283 AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
284 let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
286 f64::from(x)
287 })?)
288 }
289 (encoded, _dtype) => {
290 return Err(LinearQuantizeCodecError::InvalidEncodedDType {
291 dtype: encoded.dtype(),
292 });
293 }
294 };
295
296 Ok(decoded)
297 }
298
299 fn decode_into(
300 &self,
301 encoded: AnyArrayView,
302 decoded: AnyArrayViewMut,
303 ) -> Result<(), Self::Error> {
304 fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
305 array: &ArrayBase<S, D>,
306 ) -> Cow<[T]> {
307 #[expect(clippy::option_if_let_else)]
308 if let Some(data) = array.as_slice() {
309 Cow::Borrowed(data)
310 } else {
311 Cow::Owned(array.iter().copied().collect())
312 }
313 }
314
315 if !matches!(encoded.shape(), [_]) {
316 return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
317 shape: encoded.shape().to_vec(),
318 });
319 }
320
321 match (decoded, self.dtype) {
322 (AnyArrayViewMut::F32(decoded), LinearQuantizeDType::F32) => {
323 match &encoded {
324 AnyArrayView::U8(encoded) => {
325 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
326 f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
327 })
328 }
329 AnyArrayView::U16(encoded) => {
330 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
331 f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
332 })
333 }
334 AnyArrayView::U32(encoded) => {
335 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
336 let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
338 #[expect(clippy::cast_possible_truncation)]
339 let x = x as f32;
340 x
341 })
342 }
343 AnyArrayView::U64(encoded) => {
344 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
345 let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
347 f32::from(x)
348 })
349 }
350 encoded => {
351 return Err(LinearQuantizeCodecError::InvalidEncodedDType {
352 dtype: encoded.dtype(),
353 });
354 }
355 }
356 }
357 (AnyArrayViewMut::F64(decoded), LinearQuantizeDType::F64) => {
358 match &encoded {
359 AnyArrayView::U8(encoded) => {
360 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
361 f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
362 })
363 }
364 AnyArrayView::U16(encoded) => {
365 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
366 f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
367 })
368 }
369 AnyArrayView::U32(encoded) => {
370 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
371 f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
372 })
373 }
374 AnyArrayView::U64(encoded) => {
375 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
376 let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
378 f64::from(x)
379 })
380 }
381 encoded => {
382 return Err(LinearQuantizeCodecError::InvalidEncodedDType {
383 dtype: encoded.dtype(),
384 });
385 }
386 }
387 }
388 (decoded, dtype) => {
389 return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
390 configured: dtype,
391 provided: decoded.dtype(),
392 });
393 }
394 }?;
395
396 Ok(())
397 }
398}
399
impl StaticCodec for LinearQuantizeCodec {
    /// Identifier under which this codec is registered.
    const CODEC_ID: &'static str = "linear-quantize.rs";

    /// The codec is its own configuration type.
    type Config<'de> = Self;

    /// Creates the codec from its configuration, which is the codec itself.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Returns the configuration of this codec, built from a borrow of `self`.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
413
/// Errors that may occur when applying the [`LinearQuantizeCodec`].
#[derive(Debug, Error)]
pub enum LinearQuantizeCodecError {
    /// [`LinearQuantizeCodec`] cannot encode the provided dtype which differs
    /// from the configured dtype
    #[error(
        "LinearQuantize cannot encode the provided dtype {provided} which differs from the configured dtype {configured}"
    )]
    MismatchedEncodeDType {
        /// Dtype of the `configured` `dtype`
        configured: LinearQuantizeDType,
        /// Dtype of the `provided` array from which the data is to be encoded
        provided: AnyArrayDType,
    },
    /// [`LinearQuantizeCodec`] does not support non-finite (infinite or NaN)
    /// floating point data
    #[error("LinearQuantize does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// [`LinearQuantizeCodec`] failed to encode the header
    #[error("LinearQuantize failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: LinearQuantizeHeaderError,
    },
    /// [`LinearQuantizeCodec`] can only decode one-dimensional arrays but
    /// received an array of a different shape
    #[error(
        "LinearQuantize can only decode one-dimensional arrays but received an array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`LinearQuantizeCodec`] failed to decode the header
    #[error("LinearQuantize failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: LinearQuantizeHeaderError,
    },
    /// [`LinearQuantizeCodec`] decoded an invalid array shape header which
    /// does not fit the decoded data
    #[error(
        "LinearQuantize decoded an invalid array shape header which does not fit the decoded data"
    )]
    DecodeInvalidShapeHeader {
        /// Source error
        #[from]
        source: ShapeError,
    },
    /// [`LinearQuantizeCodec`] cannot decode the provided dtype
    #[error("LinearQuantize cannot decode the provided dtype {dtype}")]
    InvalidEncodedDType {
        /// Dtype of the provided array from which the data is to be decoded
        dtype: AnyArrayDType,
    },
    /// [`LinearQuantizeCodec`] cannot decode into the provided dtype which
    /// differs from the configured dtype
    #[error(
        "LinearQuantize cannot decode the provided dtype {provided} which differs from the configured dtype {configured}"
    )]
    MismatchedDecodeIntoDtype {
        /// Dtype of the `configured` `dtype`
        configured: LinearQuantizeDType,
        /// Dtype of the `provided` array into which the data is to be decoded
        provided: AnyArrayDType,
    },
    /// [`LinearQuantizeCodec`] cannot decode the decoded array into the
    /// provided array of a different shape
    #[error(
        "LinearQuantize cannot decode the decoded array of shape {decoded:?} into the provided array of shape {provided:?}"
    )]
    MismatchedDecodeIntoShape {
        /// Shape of the `decoded` data, as stored in the header
        decoded: Vec<usize>,
        /// Shape of the `provided` array into which the data is to be decoded
        provided: Vec<usize>,
    },
}
492
/// Opaque error for when encoding or decoding the header fails.
///
/// It wraps the underlying `postcard` error without exposing it and forwards
/// its display and source transparently.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct LinearQuantizeHeaderError(postcard::Error);
497
498pub fn quantize<
509 T: Float + ConstZero + ConstOne + Serialize,
510 Q: Unsigned,
511 S: Data<Elem = T>,
512 D: Dimension,
513>(
514 data: &ArrayBase<S, D>,
515 quantize: impl Fn(T) -> Q,
516) -> Result<Vec<Q>, LinearQuantizeCodecError> {
517 if !Zip::from(data).all(|x| x.is_finite()) {
518 return Err(LinearQuantizeCodecError::NonFiniteData);
519 }
520
521 let (minimum, maximum) = data.first().map_or((T::ZERO, T::ONE), |first| {
522 (
523 Zip::from(data).fold(*first, |a, b| a.min(*b)),
524 Zip::from(data).fold(*first, |a, b| a.max(*b)),
525 )
526 });
527
528 let header = postcard::to_extend(
529 &CompressionHeader {
530 shape: Cow::Borrowed(data.shape()),
531 minimum,
532 maximum,
533 version: StaticCodecVersion,
534 },
535 Vec::new(),
536 )
537 .map_err(|err| LinearQuantizeCodecError::HeaderEncodeFailed {
538 source: LinearQuantizeHeaderError(err),
539 })?;
540
541 let mut encoded: Vec<Q> = vec![Q::ZERO; header.len().div_ceil(std::mem::size_of::<Q>())];
542 #[expect(unsafe_code)]
543 unsafe {
545 std::ptr::copy_nonoverlapping(header.as_ptr(), encoded.as_mut_ptr().cast(), header.len());
546 }
547 encoded.reserve(data.len());
548
549 if maximum == minimum {
550 encoded.resize(encoded.len() + data.len(), quantize(T::ZERO));
551 } else {
552 encoded.extend(
553 data.iter()
554 .map(|x| quantize((*x - minimum) / (maximum - minimum))),
555 );
556 }
557
558 Ok(encoded)
559}
560
561pub fn reconstruct<T: Float + DeserializeOwned, Q: Unsigned>(
570 encoded: &[Q],
571 floatify: impl Fn(Q) -> T,
572) -> Result<ArrayD<T>, LinearQuantizeCodecError> {
573 #[expect(unsafe_code)]
574 let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
576 std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
577 })
578 .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
579 source: LinearQuantizeHeaderError(err),
580 })?;
581
582 let encoded = encoded
583 .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
584 .unwrap_or(&[]);
585
586 let decoded = encoded
587 .iter()
588 .map(|x| header.minimum + (floatify(*x) * (header.maximum - header.minimum)))
589 .map(|x| x.clamp(header.minimum, header.maximum))
590 .collect();
591
592 let decoded = Array::from_shape_vec(&*header.shape, decoded)?;
593
594 Ok(decoded)
595}
596
597pub fn reconstruct_into<T: Float + DeserializeOwned, Q: Unsigned>(
608 encoded: &[Q],
609 mut decoded: ArrayViewMutD<T>,
610 floatify: impl Fn(Q) -> T,
611) -> Result<(), LinearQuantizeCodecError> {
612 #[expect(unsafe_code)]
613 let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
615 std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
616 })
617 .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
618 source: LinearQuantizeHeaderError(err),
619 })?;
620
621 let encoded = encoded
622 .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
623 .unwrap_or(&[]);
624
625 if decoded.shape() != &*header.shape {
626 return Err(LinearQuantizeCodecError::MismatchedDecodeIntoShape {
627 decoded: header.shape.into_owned(),
628 provided: decoded.shape().to_vec(),
629 });
630 }
631
632 for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
634 *d = (header.minimum + (floatify(*e) * (header.maximum - header.minimum)))
635 .clamp(header.minimum, header.maximum);
636 }
637
638 Ok(())
639}
640
641fn scale_for_bits<T: Float + From<u8> + ConstOne>(bits: u8) -> T {
643 <T as From<u8>>::from(bits).exp2() - T::ONE
644}
645
/// Unsigned integer types that quantized values can be stored in.
pub trait Unsigned: Copy {
    /// The zero value of this type, `0`
    const ZERO: Self;
}

// Implements `Unsigned` with a zero of `0` for every listed integer type.
macro_rules! impl_unsigned {
    ($($int:ty),+ $(,)?) => {
        $(
            impl Unsigned for $int {
                const ZERO: Self = 0;
            }
        )+
    };
}

impl_unsigned!(u8, u16, u32, u64);
667
/// In-band header that is serialised with `postcard` at the start of the
/// encoded data.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a, T> {
    /// Shape of the original array; borrowed from the array when encoding
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    /// Minimum value of the original data, decoded as the quantized value `0`
    minimum: T,
    /// Maximum value of the original data, decoded as the largest quantized
    /// value
    maximum: T,
    /// Version of the encoding format that the data was encoded with
    version: LinearQuantizeCodecVersion,
}
676
#[cfg(test)]
mod tests {
    use ndarray::CowArray;

    use super::*;

    /// Builds a codec for `dtype` that quantizes to `bits` bits.
    ///
    /// `bits` must be in `1..=64`, i.e. a valid [`LinearQuantizeBins`]
    /// discriminant.
    fn codec_for(dtype: LinearQuantizeDType, bits: u8) -> LinearQuantizeCodec {
        LinearQuantizeCodec {
            dtype,
            // SAFETY: the tests only pass bits in 1..=64, all of which are
            // valid discriminants of the `#[repr(u8)]` enum
            #[expect(unsafe_code)]
            bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
            version: StaticCodecVersion,
        }
    }

    /// Encodes and decodes `original` with `bits` bits and asserts that every
    /// decoded `f32` is bit-identical to its original.
    fn assert_bit_exact_f32(bits: u8, original: &[f32]) -> Result<(), LinearQuantizeCodecError> {
        let codec = codec_for(LinearQuantizeDType::F32, bits);

        let encoded = codec.encode(AnyCowArray::F32(CowArray::from(original).into_dyn()))?;
        let decoded = codec.decode(encoded.cow())?;

        let AnyArray::F32(decoded) = decoded else {
            return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                configured: LinearQuantizeDType::F32,
                provided: decoded.dtype(),
            });
        };

        for (expected, actual) in original.iter().zip(decoded.iter()) {
            assert_eq!(expected.to_bits(), actual.to_bits());
        }

        Ok(())
    }

    /// Encodes and decodes `original` with `bits` bits and asserts that every
    /// decoded `f64` is bit-identical to its original.
    fn assert_bit_exact_f64(bits: u8, original: &[f64]) -> Result<(), LinearQuantizeCodecError> {
        let codec = codec_for(LinearQuantizeDType::F64, bits);

        let encoded = codec.encode(AnyCowArray::F64(CowArray::from(original).into_dyn()))?;
        let decoded = codec.decode(encoded.cow())?;

        let AnyArray::F64(decoded) = decoded else {
            return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                configured: LinearQuantizeDType::F64,
                provided: decoded.dtype(),
            });
        };

        for (expected, actual) in original.iter().zip(decoded.iter()) {
            assert_eq!(expected.to_bits(), actual.to_bits());
        }

        Ok(())
    }

    // Integers that are losslessly converted to f32 roundtrip exactly
    #[test]
    fn exact_roundtrip_f32_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=16 {
            // Sample at most ~256 evenly spaced integers, plus the maximum
            let mut data: Vec<f32> = (0..(u16::MAX >> (16 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(f32::from)
                .collect();
            data.push(f32::from(u16::MAX >> (16 - bits)));

            assert_bit_exact_f32(bits, &data)?;
        }

        Ok(())
    }

    // Integers that are cast (with rounding) to f32 roundtrip exactly
    #[test]
    fn exact_roundtrip_f32_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            // Sample at most ~256 evenly spaced integers, plus the maximum
            #[expect(clippy::cast_precision_loss)]
            let mut data: Vec<f32> = (0..(u64::MAX >> (64 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(|x| x as f32)
                .collect();
            #[expect(clippy::cast_precision_loss)]
            data.push((u64::MAX >> (64 - bits)) as f32);

            assert_bit_exact_f32(bits, &data)?;
        }

        Ok(())
    }

    // Integers that are losslessly converted to f64 roundtrip exactly
    #[test]
    fn exact_roundtrip_f64_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=32 {
            // Sample at most ~256 evenly spaced integers, plus the maximum
            let mut data: Vec<f64> = (0..(u32::MAX >> (32 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(f64::from)
                .collect();
            data.push(f64::from(u32::MAX >> (32 - bits)));

            assert_bit_exact_f64(bits, &data)?;
        }

        Ok(())
    }

    // Integers that are cast (with rounding) to f64 roundtrip exactly
    #[test]
    fn exact_roundtrip_f64_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            // Sample at most ~256 evenly spaced integers, plus the maximum
            #[expect(clippy::cast_precision_loss)]
            let mut data: Vec<f64> = (0..(u64::MAX >> (64 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(|x| x as f64)
                .collect();
            #[expect(clippy::cast_precision_loss)]
            data.push((u64::MAX >> (64 - bits)) as f64);

            assert_bit_exact_f64(bits, &data)?;
        }

        Ok(())
    }

    // Constant data, which has an empty value range, roundtrips exactly
    #[test]
    fn const_data_roundtrip() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            let data = [42.0, 42.0, 42.0, 42.0];

            assert_bit_exact_f64(bits, &data)?;
        }

        Ok(())
    }
}