1#![expect(clippy::multiple_crate_versions)] use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, ArrayD, ArrayViewMutD, Data, Dimension, ShapeError, Zip};
25use num_traits::{ConstOne, ConstZero, Float};
26use numcodecs::{
27 AnyArray, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, StaticCodec,
28 StaticCodecConfig,
29};
30use schemars::{JsonSchema, JsonSchema_repr};
31use serde::{de::DeserializeOwned, Deserialize, Serialize};
32use serde_repr::{Deserialize_repr, Serialize_repr};
33use thiserror::Error;
34use twofloat::TwoFloat;
35
/// Lossy codec that linearly quantizes floating point data into unsigned
/// integers of a configurable bit width.
///
/// Encoding normalises the data to `[0, 1]` using its minimum and maximum and
/// maps it onto `2^bits - 1` evenly spaced steps. The result is a
/// one-dimensional unsigned integer array that starts with a header (shape,
/// minimum, maximum) followed by one quantized value per input element.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct LinearQuantizeCodec {
    /// Floating point dtype of the data that is encoded, which is also the
    /// dtype that decoding produces.
    pub dtype: LinearQuantizeDType,
    /// Number of bits used per quantized value; this also selects the
    /// smallest of `u8`/`u16`/`u32`/`u64` that is used for the encoded array.
    pub bits: LinearQuantizeBins,
}
48
/// Floating point dtypes that the [`LinearQuantizeCodec`] can be configured
/// to quantize.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("enum" = ["f32", "float32", "f64", "float64"]))]
#[expect(missing_docs)]
pub enum LinearQuantizeDType {
    // Serialized as "f32"; "float32" is additionally accepted when
    // deserializing.
    #[serde(rename = "f32", alias = "float32")]
    F32,
    // Serialized as "f64"; "float64" is additionally accepted when
    // deserializing.
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
59
60impl fmt::Display for LinearQuantizeDType {
61 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
62 fmt.write_str(match self {
63 Self::F32 => "f32",
64 Self::F64 => "f64",
65 })
66 }
67}
68
/// Number of bits (1 to 64) that each quantized value occupies.
///
/// The discriminant of each variant is the bit count itself, and the type is
/// (de)serialized as that `u8` number via the `*_repr` derives.
#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, JsonSchema_repr)]
#[repr(u8)]
#[rustfmt::skip]
#[expect(missing_docs)]
pub enum LinearQuantizeBins {
    // The first variant is pinned to 1; all following discriminants count up
    // implicitly, so `_1B<n>` has the discriminant `n`.
    _1B1 = 1, _1B2, _1B3, _1B4, _1B5, _1B6, _1B7, _1B8,
    _1B9, _1B10, _1B11, _1B12, _1B13, _1B14, _1B15, _1B16,
    _1B17, _1B18, _1B19, _1B20, _1B21, _1B22, _1B23, _1B24,
    _1B25, _1B26, _1B27, _1B28, _1B29, _1B30, _1B31, _1B32,
    _1B33, _1B34, _1B35, _1B36, _1B37, _1B38, _1B39, _1B40,
    _1B41, _1B42, _1B43, _1B44, _1B45, _1B46, _1B47, _1B48,
    _1B49, _1B50, _1B51, _1B52, _1B53, _1B54, _1B55, _1B56,
    _1B57, _1B58, _1B59, _1B60, _1B61, _1B62, _1B63, _1B64,
}
88
impl Codec for LinearQuantizeCodec {
    type Error = LinearQuantizeCodecError;

    /// Quantizes `data` into a one-dimensional unsigned integer array.
    ///
    /// The integer width of the result is the smallest of 8/16/32/64 bits
    /// that can hold `self.bits` bits. Each closure passed to [`quantize`]
    /// scales its input by `2^bits - 1`, adds `0.5` so that the truncating
    /// integer conversion rounds to nearest, and clamps to the valid range.
    ///
    /// # Errors
    ///
    /// Returns [`LinearQuantizeCodecError::MismatchedEncodeDType`] if the
    /// dtype of `data` differs from the configured dtype, and forwards any
    /// error from [`quantize`].
    #[expect(clippy::too_many_lines)]
    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        let encoded = match (&data, self.dtype) {
            (AnyCowArray::F32(data), LinearQuantizeDType::F32) => match self.bits as u8 {
                bits @ ..=8 => AnyArray::U8(
                    Array1::from_vec(quantize(data, |x| {
                        // largest value representable in `bits` bits
                        let max = f32::from(u8::MAX >> (8 - bits));
                        let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
                        #[expect(unsafe_code)]
                        unsafe {
                            // The clamp bounds x to [0, max], which fits in u8.
                            // NOTE(review): `clamp` propagates NaN, so this
                            // relies on `quantize` never handing a NaN to the
                            // closure - confirm that invariant holds.
                            x.to_int_unchecked::<u8>()
                        }
                    })?)
                    .into_dyn(),
                ),
                bits @ 9..=16 => AnyArray::U16(
                    Array1::from_vec(quantize(data, |x| {
                        let max = f32::from(u16::MAX >> (16 - bits));
                        let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
                        #[expect(unsafe_code)]
                        unsafe {
                            // bounded to [0, max] by the clamp (same NaN
                            // caveat as in the u8 branch)
                            x.to_int_unchecked::<u16>()
                        }
                    })?)
                    .into_dyn(),
                ),
                bits @ 17..=32 => AnyArray::U32(
                    Array1::from_vec(quantize(data, |x| {
                        let max = f64::from(u32::MAX >> (32 - bits));
                        // the computation is widened to f64 here, presumably
                        // because f32 cannot represent all values up to
                        // 2^32 - 1 exactly
                        let x = f64::from(x)
                            .mul_add(scale_for_bits::<f64>(bits), 0.5)
                            .clamp(0.0, max);
                        #[expect(unsafe_code)]
                        unsafe {
                            // bounded to [0, max] by the clamp (same NaN
                            // caveat as in the u8 branch)
                            x.to_int_unchecked::<u32>()
                        }
                    })?)
                    .into_dyn(),
                ),
                bits @ 33.. => AnyArray::U64(
                    Array1::from_vec(quantize(data, |x| {
                        let max = TwoFloat::from(u64::MAX >> (64 - bits));
                        // the computation uses TwoFloat (double-double)
                        // arithmetic, presumably because f64 cannot represent
                        // all values up to 2^64 - 1 exactly
                        let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
                            + TwoFloat::from(0.5))
                        .max(TwoFloat::from(0.0))
                        .min(max);
                        #[expect(unsafe_code)]
                        unsafe {
                            // x was limited to [0, max] via max/min above, so
                            // the conversion is expected to succeed.
                            // NOTE(review): verify how TwoFloat's max/min and
                            // TryFrom treat NaN inputs.
                            u64::try_from(x).unwrap_unchecked()
                        }
                    })?)
                    .into_dyn(),
                ),
            },
            (AnyCowArray::F64(data), LinearQuantizeDType::F64) => match self.bits as u8 {
                bits @ ..=8 => AnyArray::U8(
                    Array1::from_vec(quantize(data, |x| {
                        let max = f64::from(u8::MAX >> (8 - bits));
                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
                        #[expect(unsafe_code)]
                        unsafe {
                            // bounded to [0, max] by the clamp (same NaN
                            // caveat as in the f32 branches)
                            x.to_int_unchecked::<u8>()
                        }
                    })?)
                    .into_dyn(),
                ),
                bits @ 9..=16 => AnyArray::U16(
                    Array1::from_vec(quantize(data, |x| {
                        let max = f64::from(u16::MAX >> (16 - bits));
                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
                        #[expect(unsafe_code)]
                        unsafe {
                            // bounded to [0, max] by the clamp (same NaN
                            // caveat as in the f32 branches)
                            x.to_int_unchecked::<u16>()
                        }
                    })?)
                    .into_dyn(),
                ),
                bits @ 17..=32 => AnyArray::U32(
                    Array1::from_vec(quantize(data, |x| {
                        let max = f64::from(u32::MAX >> (32 - bits));
                        let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
                        #[expect(unsafe_code)]
                        unsafe {
                            // bounded to [0, max] by the clamp (same NaN
                            // caveat as in the f32 branches)
                            x.to_int_unchecked::<u32>()
                        }
                    })?)
                    .into_dyn(),
                ),
                bits @ 33.. => AnyArray::U64(
                    Array1::from_vec(quantize(data, |x| {
                        let max = TwoFloat::from(u64::MAX >> (64 - bits));
                        let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
                            + TwoFloat::from(0.5))
                        .max(TwoFloat::from(0.0))
                        .min(max);
                        #[expect(unsafe_code)]
                        unsafe {
                            // x was limited to [0, max] via max/min above
                            // (same caveat as in the f32 u64 branch)
                            u64::try_from(x).unwrap_unchecked()
                        }
                    })?)
                    .into_dyn(),
                ),
            },
            // any other combination means that the input dtype does not match
            // the configured dtype
            (data, dtype) => {
                return Err(LinearQuantizeCodecError::MismatchedEncodeDType {
                    configured: dtype,
                    provided: data.dtype(),
                });
            }
        };

        Ok(encoded)
    }

    /// Reconstructs floating point data of the configured dtype from a
    /// one-dimensional unsigned integer array produced by `encode`.
    ///
    /// # Errors
    ///
    /// Returns [`LinearQuantizeCodecError::EncodedDataNotOneDimensional`] if
    /// `encoded` is not one-dimensional,
    /// [`LinearQuantizeCodecError::InvalidEncodedDType`] if `encoded` is not
    /// a `u8`/`u16`/`u32`/`u64` array, and forwards any error from
    /// [`reconstruct`].
    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        // Borrows the elements if the array is contiguous in standard order,
        // and otherwise copies them out in logical iteration order.
        #[expect(clippy::option_if_let_else)]
        fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
            array: &ArrayBase<S, D>,
        ) -> Cow<[T]> {
            if let Some(data) = array.as_slice() {
                Cow::Borrowed(data)
            } else {
                Cow::Owned(array.iter().copied().collect())
            }
        }

        if !matches!(encoded.shape(), [_]) {
            return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
                shape: encoded.shape().to_vec(),
            });
        }

        // Each closure maps a quantized integer back into [0, 1] by dividing
        // by `2^bits - 1`, using the same intermediate precision as `encode`
        // uses for that integer width.
        let decoded = match (&encoded, self.dtype) {
            (AnyCowArray::U8(encoded), LinearQuantizeDType::F32) => {
                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
                    f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
                })?)
            }
            (AnyCowArray::U16(encoded), LinearQuantizeDType::F32) => {
                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
                    f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
                })?)
            }
            (AnyCowArray::U32(encoded), LinearQuantizeDType::F32) => {
                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
                    let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
                    // the division happens in f64 and is only then narrowed
                    #[expect(clippy::cast_possible_truncation)]
                    let x = x as f32;
                    x
                })?)
            }
            (AnyCowArray::U64(encoded), LinearQuantizeDType::F32) => {
                AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
                    let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
                    // the division happens in TwoFloat and is only then
                    // narrowed
                    f32::from(x)
                })?)
            }
            (AnyCowArray::U8(encoded), LinearQuantizeDType::F64) => {
                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
                })?)
            }
            (AnyCowArray::U16(encoded), LinearQuantizeDType::F64) => {
                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
                })?)
            }
            (AnyCowArray::U32(encoded), LinearQuantizeDType::F64) => {
                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
                    f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
                })?)
            }
            (AnyCowArray::U64(encoded), LinearQuantizeDType::F64) => {
                AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
                    let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
                    // the division happens in TwoFloat and is only then
                    // narrowed
                    f64::from(x)
                })?)
            }
            // only unsigned integer arrays can have been produced by `encode`
            (encoded, _dtype) => {
                return Err(LinearQuantizeCodecError::InvalidEncodedDType {
                    dtype: encoded.dtype(),
                })
            }
        };

        Ok(decoded)
    }

    /// Reconstructs floating point data from `encoded` directly into the
    /// provided `decoded` array view.
    ///
    /// # Errors
    ///
    /// Returns [`LinearQuantizeCodecError::EncodedDataNotOneDimensional`] if
    /// `encoded` is not one-dimensional,
    /// [`LinearQuantizeCodecError::MismatchedDecodeIntoDtype`] if the dtype
    /// of `decoded` differs from the configured dtype,
    /// [`LinearQuantizeCodecError::InvalidEncodedDType`] if `encoded` is not
    /// a `u8`/`u16`/`u32`/`u64` array, and forwards any error from
    /// [`reconstruct_into`].
    fn decode_into(
        &self,
        encoded: AnyArrayView,
        decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        // Borrows the elements if the array is contiguous in standard order,
        // and otherwise copies them out in logical iteration order.
        fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
            array: &ArrayBase<S, D>,
        ) -> Cow<[T]> {
            #[expect(clippy::option_if_let_else)]
            if let Some(data) = array.as_slice() {
                Cow::Borrowed(data)
            } else {
                Cow::Owned(array.iter().copied().collect())
            }
        }

        if !matches!(encoded.shape(), [_]) {
            return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
                shape: encoded.shape().to_vec(),
            });
        }

        // The outer match checks the output dtype against the configuration,
        // the inner matches dispatch on the integer width of the encoded
        // data. The whole match evaluates to the `Result` of
        // `reconstruct_into`, which is propagated by the trailing `?`.
        match (decoded, self.dtype) {
            (AnyArrayViewMut::F32(decoded), LinearQuantizeDType::F32) => {
                match &encoded {
                    AnyArrayView::U8(encoded) => {
                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
                            f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
                        })
                    }
                    AnyArrayView::U16(encoded) => {
                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
                            f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
                        })
                    }
                    AnyArrayView::U32(encoded) => {
                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
                            let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
                            // the division happens in f64 and is only then
                            // narrowed
                            #[expect(clippy::cast_possible_truncation)]
                            let x = x as f32;
                            x
                        })
                    }
                    AnyArrayView::U64(encoded) => {
                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
                            let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
                            // the division happens in TwoFloat and is only
                            // then narrowed
                            f32::from(x)
                        })
                    }
                    encoded => {
                        return Err(LinearQuantizeCodecError::InvalidEncodedDType {
                            dtype: encoded.dtype(),
                        })
                    }
                }
            }
            (AnyArrayViewMut::F64(decoded), LinearQuantizeDType::F64) => {
                match &encoded {
                    AnyArrayView::U8(encoded) => {
                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
                        })
                    }
                    AnyArrayView::U16(encoded) => {
                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
                        })
                    }
                    AnyArrayView::U32(encoded) => {
                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
                            f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
                        })
                    }
                    AnyArrayView::U64(encoded) => {
                        reconstruct_into(&as_standard_order(encoded), decoded, |x| {
                            let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
                            // the division happens in TwoFloat and is only
                            // then narrowed
                            f64::from(x)
                        })
                    }
                    encoded => {
                        return Err(LinearQuantizeCodecError::InvalidEncodedDType {
                            dtype: encoded.dtype(),
                        })
                    }
                }
            }
            // the output array's dtype does not match the configured dtype
            (decoded, dtype) => {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: dtype,
                    provided: decoded.dtype(),
                })
            }
        }?;

        Ok(())
    }
}
394
395impl StaticCodec for LinearQuantizeCodec {
396 const CODEC_ID: &'static str = "linear-quantize";
397
398 type Config<'de> = Self;
399
400 fn from_config(config: Self::Config<'_>) -> Self {
401 config
402 }
403
404 fn get_config(&self) -> StaticCodecConfig<Self> {
405 StaticCodecConfig::from(self)
406 }
407}
408
/// Errors that may occur when applying the [`LinearQuantizeCodec`].
#[derive(Debug, Error)]
pub enum LinearQuantizeCodecError {
    /// The dtype of the data passed to `encode` differs from the configured
    /// dtype.
    #[error("LinearQuantize cannot encode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedEncodeDType {
        /// Dtype the codec was configured with
        configured: LinearQuantizeDType,
        /// Dtype of the data that was provided
        provided: AnyArrayDType,
    },
    /// The data to encode contains an infinite or NaN value.
    #[error("LinearQuantize does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// Serializing the header (shape, minimum, maximum) failed.
    #[error("LinearQuantize failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: LinearQuantizeHeaderError,
    },
    /// The encoded array handed to the decoder is not one-dimensional.
    #[error("LinearQuantize can only decode one-dimensional arrays but received an array of shape {shape:?}")]
    EncodedDataNotOneDimensional {
        /// Shape of the encoded array that was provided
        shape: Vec<usize>,
    },
    /// Deserializing the header from the encoded data failed.
    #[error("LinearQuantize failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: LinearQuantizeHeaderError,
    },
    /// The shape stored in the header does not match the number of decoded
    /// elements.
    #[error(
        "LinearQuantize decoded an invalid array shape header which does not fit the decoded data"
    )]
    DecodeInvalidShapeHeader {
        /// Source error reported by ndarray
        #[from]
        source: ShapeError,
    },
    /// The encoded array is not of an unsigned integer dtype that the codec
    /// produces.
    #[error("LinearQuantize cannot decode the provided dtype {dtype}")]
    InvalidEncodedDType {
        /// Dtype of the encoded array that was provided
        dtype: AnyArrayDType,
    },
    /// The dtype of the output array passed to `decode_into` differs from the
    /// configured dtype.
    #[error("LinearQuantize cannot decode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedDecodeIntoDtype {
        /// Dtype the codec was configured with
        configured: LinearQuantizeDType,
        /// Dtype of the output array that was provided
        provided: AnyArrayDType,
    },
    /// The shape of the output array passed to `decode_into` differs from the
    /// shape stored in the header.
    #[error("LinearQuantize cannot decode the decoded array of shape {decoded:?} into the provided array of shape {provided:?}")]
    MismatchedDecodeIntoShape {
        /// Shape of the decoded data, as stored in the header
        decoded: Vec<usize>,
        /// Shape of the output array that was provided
        provided: Vec<usize>,
    },
}
479
/// Opaque error for when encoding or decoding the header fails.
///
/// Wraps the underlying `postcard` error in a private field and forwards its
/// `Display` and `source` implementations unchanged (`error(transparent)`).
#[derive(Debug, Error)]
#[error(transparent)]
pub struct LinearQuantizeHeaderError(postcard::Error);
484
485pub fn quantize<
496 T: Float + ConstZero + ConstOne + Serialize,
497 Q: Unsigned,
498 S: Data<Elem = T>,
499 D: Dimension,
500>(
501 data: &ArrayBase<S, D>,
502 quantize: impl Fn(T) -> Q,
503) -> Result<Vec<Q>, LinearQuantizeCodecError> {
504 if !Zip::from(data).all(|x| x.is_finite()) {
505 return Err(LinearQuantizeCodecError::NonFiniteData);
506 }
507
508 let (minimum, maximum) = data.first().map_or((T::ZERO, T::ONE), |first| {
509 (
510 Zip::from(data).fold(*first, |a, b| a.min(*b)),
511 Zip::from(data).fold(*first, |a, b| a.max(*b)),
512 )
513 });
514
515 let header = postcard::to_extend(
516 &CompressionHeader {
517 shape: Cow::Borrowed(data.shape()),
518 minimum,
519 maximum,
520 },
521 Vec::new(),
522 )
523 .map_err(|err| LinearQuantizeCodecError::HeaderEncodeFailed {
524 source: LinearQuantizeHeaderError(err),
525 })?;
526
527 let mut encoded: Vec<Q> = vec![Q::ZERO; header.len().div_ceil(std::mem::size_of::<Q>())];
528 #[expect(unsafe_code)]
529 unsafe {
531 std::ptr::copy_nonoverlapping(header.as_ptr(), encoded.as_mut_ptr().cast(), header.len());
532 }
533 encoded.reserve(data.len());
534
535 if maximum == minimum {
536 encoded.resize(encoded.len() + data.len(), quantize(T::ZERO));
537 } else {
538 encoded.extend(
539 data.iter()
540 .map(|x| quantize((*x - minimum) / (maximum - minimum))),
541 );
542 }
543
544 Ok(encoded)
545}
546
547pub fn reconstruct<T: Float + DeserializeOwned, Q: Unsigned>(
556 encoded: &[Q],
557 floatify: impl Fn(Q) -> T,
558) -> Result<ArrayD<T>, LinearQuantizeCodecError> {
559 #[expect(unsafe_code)]
560 let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
562 std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
563 })
564 .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
565 source: LinearQuantizeHeaderError(err),
566 })?;
567
568 let encoded = encoded
569 .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
570 .unwrap_or(&[]);
571
572 let decoded = encoded
573 .iter()
574 .map(|x| header.minimum + (floatify(*x) * (header.maximum - header.minimum)))
575 .map(|x| x.clamp(header.minimum, header.maximum))
576 .collect();
577
578 let decoded = Array::from_shape_vec(&*header.shape, decoded)?;
579
580 Ok(decoded)
581}
582
583pub fn reconstruct_into<T: Float + DeserializeOwned, Q: Unsigned>(
594 encoded: &[Q],
595 mut decoded: ArrayViewMutD<T>,
596 floatify: impl Fn(Q) -> T,
597) -> Result<(), LinearQuantizeCodecError> {
598 #[expect(unsafe_code)]
599 let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
601 std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
602 })
603 .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
604 source: LinearQuantizeHeaderError(err),
605 })?;
606
607 let encoded = encoded
608 .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
609 .unwrap_or(&[]);
610
611 if decoded.shape() != &*header.shape {
612 return Err(LinearQuantizeCodecError::MismatchedDecodeIntoShape {
613 decoded: header.shape.into_owned(),
614 provided: decoded.shape().to_vec(),
615 });
616 }
617
618 for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
620 *d = (header.minimum + (floatify(*e) * (header.maximum - header.minimum)))
621 .clamp(header.minimum, header.maximum);
622 }
623
624 Ok(())
625}
626
627fn scale_for_bits<T: Float + From<u8> + ConstOne>(bits: u8) -> T {
629 <T as From<u8>>::from(bits).exp2() - T::ONE
630}
631
/// Unsigned integer types that quantized values can be stored in.
pub trait Unsigned: Copy {
    /// The value `0` of the implementing type.
    const ZERO: Self;
}

// Implements `Unsigned` for each listed primitive unsigned integer type,
// whose zero value is always the literal `0`.
macro_rules! impl_unsigned {
    ($($int:ty),+ $(,)?) => {
        $(
            impl Unsigned for $int {
                const ZERO: Self = 0;
            }
        )+
    };
}

impl_unsigned!(u8, u16, u32, u64);
653
// Header that `quantize` serializes (with postcard) in front of the quantized
// values and that `reconstruct` / `reconstruct_into` read back.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a, T> {
    // shape of the original array, so that decoding can restore it
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    // smallest value of the original data, reconstructed from quantized 0
    minimum: T,
    // largest value of the original data, reconstructed from the largest
    // quantized value
    maximum: T,
}
661
#[cfg(test)]
mod tests {
    use ndarray::CowArray;

    use super::*;

    // Integer-valued f32 data covering [0, 2^bits - 1] must roundtrip
    // bit-exactly for 1 to 16 bits. The integers are converted with the
    // lossless `f32::from`.
    #[test]
    fn exact_roundtrip_f32_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=16 {
            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F32,
                // `bits` is in 1..=16, each of which is a valid discriminant
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
            };

            // sub-sample the integer range so that roughly 256 values are
            // tested at most, and always include the maximum
            let mut data: Vec<f32> = (0..(u16::MAX >> (16 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(f32::from)
                .collect();
            data.push(f32::from(u16::MAX >> (16 - bits)));

            let encoded = codec.encode(AnyCowArray::F32(CowArray::from(&data).into_dyn()))?;
            let decoded = codec.decode(encoded.cow())?;

            let AnyArray::F32(decoded) = decoded else {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F32,
                    provided: decoded.dtype(),
                });
            };

            // compare the raw bit patterns to require exact equality
            for (o, d) in data.iter().zip(decoded.iter()) {
                assert_eq!(o.to_bits(), d.to_bits());
            }
        }

        Ok(())
    }

    // Same as above but for all bit widths 1 to 64, where the integers are
    // converted with a (possibly rounding) `as f32` cast.
    #[test]
    fn exact_roundtrip_f32_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F32,
                // `bits` is in 1..=64, each of which is a valid discriminant
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
            };

            #[expect(clippy::cast_precision_loss)]
            let mut data: Vec<f32> = (0..(u64::MAX >> (64 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(|x| x as f32)
                .collect();
            #[expect(clippy::cast_precision_loss)]
            data.push((u64::MAX >> (64 - bits)) as f32);

            let encoded = codec.encode(AnyCowArray::F32(CowArray::from(&data).into_dyn()))?;
            let decoded = codec.decode(encoded.cow())?;

            let AnyArray::F32(decoded) = decoded else {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F32,
                    provided: decoded.dtype(),
                });
            };

            for (o, d) in data.iter().zip(decoded.iter()) {
                assert_eq!(o.to_bits(), d.to_bits());
            }
        }

        Ok(())
    }

    // Integer-valued f64 data covering [0, 2^bits - 1] must roundtrip
    // bit-exactly for 1 to 32 bits. The integers are converted with the
    // lossless `f64::from`.
    #[test]
    fn exact_roundtrip_f64_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=32 {
            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F64,
                // `bits` is in 1..=32, each of which is a valid discriminant
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
            };

            let mut data: Vec<f64> = (0..(u32::MAX >> (32 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(f64::from)
                .collect();
            data.push(f64::from(u32::MAX >> (32 - bits)));

            let encoded = codec.encode(AnyCowArray::F64(CowArray::from(&data).into_dyn()))?;
            let decoded = codec.decode(encoded.cow())?;

            let AnyArray::F64(decoded) = decoded else {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F64,
                    provided: decoded.dtype(),
                });
            };

            for (o, d) in data.iter().zip(decoded.iter()) {
                assert_eq!(o.to_bits(), d.to_bits());
            }
        }

        Ok(())
    }

    // Same as above but for all bit widths 1 to 64, where the integers are
    // converted with a (possibly rounding) `as f64` cast.
    #[test]
    fn exact_roundtrip_f64_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F64,
                // `bits` is in 1..=64, each of which is a valid discriminant
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
            };

            #[expect(clippy::cast_precision_loss)]
            let mut data: Vec<f64> = (0..(u64::MAX >> (64 - bits)))
                .step_by(1 << (bits.max(8) - 8))
                .map(|x| x as f64)
                .collect();
            #[expect(clippy::cast_precision_loss)]
            data.push((u64::MAX >> (64 - bits)) as f64);

            let encoded = codec.encode(AnyCowArray::F64(CowArray::from(&data).into_dyn()))?;
            let decoded = codec.decode(encoded.cow())?;

            let AnyArray::F64(decoded) = decoded else {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F64,
                    provided: decoded.dtype(),
                });
            };

            for (o, d) in data.iter().zip(decoded.iter()) {
                assert_eq!(o.to_bits(), d.to_bits());
            }
        }

        Ok(())
    }

    // Constant data (minimum == maximum) must roundtrip bit-exactly for all
    // bit widths, which exercises the special case in `quantize`.
    #[test]
    fn const_data_roundtrip() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            let data = [42.0, 42.0, 42.0, 42.0];

            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F64,
                // `bits` is in 1..=64, each of which is a valid discriminant
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
            };

            let encoded = codec.encode(AnyCowArray::F64(CowArray::from(&data).into_dyn()))?;
            let decoded = codec.decode(encoded.cow())?;

            let AnyArray::F64(decoded) = decoded else {
                return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                    configured: LinearQuantizeDType::F64,
                    provided: decoded.dtype(),
                });
            };

            for (o, d) in data.iter().zip(decoded.iter()) {
                assert_eq!(o.to_bits(), d.to_bits());
            }
        }

        Ok(())
    }
}