1#![expect(clippy::multiple_crate_versions)] use std::{borrow::Cow, fmt};
23
24use ndarray::{Array, Array1, ArrayBase, ArrayD, ArrayViewMutD, Data, Dimension, ShapeError, Zip};
25use num_traits::{ConstOne, ConstZero, Float};
26use numcodecs::{
27 AnyArray, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, StaticCodec,
28 StaticCodecConfig, StaticCodecVersion,
29};
30use schemars::{JsonSchema, JsonSchema_repr};
31use serde::{de::DeserializeOwned, Deserialize, Serialize};
32use serde_repr::{Deserialize_repr, Serialize_repr};
33use thiserror::Error;
34use twofloat::TwoFloat;
35
/// Version marker of this codec's encoding format, expressed through the
/// const parameters `0, 1, 0` of `StaticCodecVersion`.
///
/// It is stored both in the codec configuration and in the in-band
/// compression header.
type LinearQuantizeCodecVersion = StaticCodecVersion<0, 1, 0>;
37
/// Lossy codec that linearly quantizes floating point data to unsigned
/// integers.
///
/// The shape, minimum, and maximum of the input data are serialized into a
/// header that is stored in-band at the start of the encoded array.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct LinearQuantizeCodec {
    /// Floating point dtype of the data that is encoded and decoded
    pub dtype: LinearQuantizeDType,
    /// Number of bits (1 to 64) used for each quantized value
    pub bits: LinearQuantizeBins,
    /// Version of the codec's encoding format, (de)serialized as `_version`
    /// and defaulted when absent
    #[serde(default, rename = "_version")]
    pub version: LinearQuantizeCodecVersion,
}
53
/// Floating point dtypes that the linear quantization codec can be
/// configured with.
///
/// Each variant serializes to its short name (`"f32"`, `"f64"`) and also
/// deserializes from the long alias (`"float32"`, `"float64"`).
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("enum" = ["f32", "float32", "f64", "float64"]))]
#[expect(missing_docs)]
pub enum LinearQuantizeDType {
    #[serde(rename = "f32", alias = "float32")]
    F32,
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
64
65impl fmt::Display for LinearQuantizeDType {
66 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
67 fmt.write_str(match self {
68 Self::F32 => "f32",
69 Self::F64 => "f64",
70 })
71 }
72}
73
/// Number of bits used for each quantized value.
///
/// The variants have the consecutive `u8` discriminants 1 to 64 and are
/// (de)serialized as that integer. When encoding, the codec picks the
/// smallest of `u8`, `u16`, `u32`, and `u64` that holds this many bits.
#[derive(Copy, Clone, Serialize_repr, Deserialize_repr, JsonSchema_repr)]
#[repr(u8)]
#[rustfmt::skip]
#[expect(missing_docs)]
pub enum LinearQuantizeBins {
    _1B1 = 1, _1B2, _1B3, _1B4, _1B5, _1B6, _1B7, _1B8,
    _1B9, _1B10, _1B11, _1B12, _1B13, _1B14, _1B15, _1B16,
    _1B17, _1B18, _1B19, _1B20, _1B21, _1B22, _1B23, _1B24,
    _1B25, _1B26, _1B27, _1B28, _1B29, _1B30, _1B31, _1B32,
    _1B33, _1B34, _1B35, _1B36, _1B37, _1B38, _1B39, _1B40,
    _1B41, _1B42, _1B43, _1B44, _1B45, _1B46, _1B47, _1B48,
    _1B49, _1B50, _1B51, _1B52, _1B53, _1B54, _1B55, _1B56,
    _1B57, _1B58, _1B59, _1B60, _1B61, _1B62, _1B63, _1B64,
}
93
94impl Codec for LinearQuantizeCodec {
95 type Error = LinearQuantizeCodecError;
96
97 #[expect(clippy::too_many_lines)]
98 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
99 let encoded = match (&data, self.dtype) {
100 (AnyCowArray::F32(data), LinearQuantizeDType::F32) => match self.bits as u8 {
101 bits @ ..=8 => AnyArray::U8(
102 Array1::from_vec(quantize(data, |x| {
103 let max = f32::from(u8::MAX >> (8 - bits));
104 let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
105 #[expect(unsafe_code)]
106 unsafe {
108 x.to_int_unchecked::<u8>()
109 }
110 })?)
111 .into_dyn(),
112 ),
113 bits @ 9..=16 => AnyArray::U16(
114 Array1::from_vec(quantize(data, |x| {
115 let max = f32::from(u16::MAX >> (16 - bits));
116 let x = x.mul_add(scale_for_bits::<f32>(bits), 0.5).clamp(0.0, max);
117 #[expect(unsafe_code)]
118 unsafe {
120 x.to_int_unchecked::<u16>()
121 }
122 })?)
123 .into_dyn(),
124 ),
125 bits @ 17..=32 => AnyArray::U32(
126 Array1::from_vec(quantize(data, |x| {
127 let max = f64::from(u32::MAX >> (32 - bits));
129 let x = f64::from(x)
130 .mul_add(scale_for_bits::<f64>(bits), 0.5)
131 .clamp(0.0, max);
132 #[expect(unsafe_code)]
133 unsafe {
135 x.to_int_unchecked::<u32>()
136 }
137 })?)
138 .into_dyn(),
139 ),
140 bits @ 33.. => AnyArray::U64(
141 Array1::from_vec(quantize(data, |x| {
142 let max = TwoFloat::from(u64::MAX >> (64 - bits));
144 let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
145 + TwoFloat::from(0.5))
146 .max(TwoFloat::from(0.0))
147 .min(max);
148 #[expect(unsafe_code)]
149 unsafe {
151 u64::try_from(x).unwrap_unchecked()
152 }
153 })?)
154 .into_dyn(),
155 ),
156 },
157 (AnyCowArray::F64(data), LinearQuantizeDType::F64) => match self.bits as u8 {
158 bits @ ..=8 => AnyArray::U8(
159 Array1::from_vec(quantize(data, |x| {
160 let max = f64::from(u8::MAX >> (8 - bits));
161 let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
162 #[expect(unsafe_code)]
163 unsafe {
165 x.to_int_unchecked::<u8>()
166 }
167 })?)
168 .into_dyn(),
169 ),
170 bits @ 9..=16 => AnyArray::U16(
171 Array1::from_vec(quantize(data, |x| {
172 let max = f64::from(u16::MAX >> (16 - bits));
173 let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
174 #[expect(unsafe_code)]
175 unsafe {
177 x.to_int_unchecked::<u16>()
178 }
179 })?)
180 .into_dyn(),
181 ),
182 bits @ 17..=32 => AnyArray::U32(
183 Array1::from_vec(quantize(data, |x| {
184 let max = f64::from(u32::MAX >> (32 - bits));
185 let x = x.mul_add(scale_for_bits::<f64>(bits), 0.5).clamp(0.0, max);
186 #[expect(unsafe_code)]
187 unsafe {
189 x.to_int_unchecked::<u32>()
190 }
191 })?)
192 .into_dyn(),
193 ),
194 bits @ 33.. => AnyArray::U64(
195 Array1::from_vec(quantize(data, |x| {
196 let max = TwoFloat::from(u64::MAX >> (64 - bits));
198 let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
199 + TwoFloat::from(0.5))
200 .max(TwoFloat::from(0.0))
201 .min(max);
202 #[expect(unsafe_code)]
203 unsafe {
205 u64::try_from(x).unwrap_unchecked()
206 }
207 })?)
208 .into_dyn(),
209 ),
210 },
211 (data, dtype) => {
212 return Err(LinearQuantizeCodecError::MismatchedEncodeDType {
213 configured: dtype,
214 provided: data.dtype(),
215 });
216 }
217 };
218
219 Ok(encoded)
220 }
221
222 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
223 #[expect(clippy::option_if_let_else)]
224 fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
225 array: &ArrayBase<S, D>,
226 ) -> Cow<[T]> {
227 if let Some(data) = array.as_slice() {
228 Cow::Borrowed(data)
229 } else {
230 Cow::Owned(array.iter().copied().collect())
231 }
232 }
233
234 if !matches!(encoded.shape(), [_]) {
235 return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
236 shape: encoded.shape().to_vec(),
237 });
238 }
239
240 let decoded = match (&encoded, self.dtype) {
241 (AnyCowArray::U8(encoded), LinearQuantizeDType::F32) => {
242 AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
243 f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
244 })?)
245 }
246 (AnyCowArray::U16(encoded), LinearQuantizeDType::F32) => {
247 AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
248 f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
249 })?)
250 }
251 (AnyCowArray::U32(encoded), LinearQuantizeDType::F32) => {
252 AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
253 let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
255 #[expect(clippy::cast_possible_truncation)]
256 let x = x as f32;
257 x
258 })?)
259 }
260 (AnyCowArray::U64(encoded), LinearQuantizeDType::F32) => {
261 AnyArray::F32(reconstruct(&as_standard_order(encoded), |x| {
262 let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
264 f32::from(x)
265 })?)
266 }
267 (AnyCowArray::U8(encoded), LinearQuantizeDType::F64) => {
268 AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
269 f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
270 })?)
271 }
272 (AnyCowArray::U16(encoded), LinearQuantizeDType::F64) => {
273 AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
274 f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
275 })?)
276 }
277 (AnyCowArray::U32(encoded), LinearQuantizeDType::F64) => {
278 AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
279 f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
280 })?)
281 }
282 (AnyCowArray::U64(encoded), LinearQuantizeDType::F64) => {
283 AnyArray::F64(reconstruct(&as_standard_order(encoded), |x| {
284 let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
286 f64::from(x)
287 })?)
288 }
289 (encoded, _dtype) => {
290 return Err(LinearQuantizeCodecError::InvalidEncodedDType {
291 dtype: encoded.dtype(),
292 })
293 }
294 };
295
296 Ok(decoded)
297 }
298
299 fn decode_into(
300 &self,
301 encoded: AnyArrayView,
302 decoded: AnyArrayViewMut,
303 ) -> Result<(), Self::Error> {
304 fn as_standard_order<T: Copy, S: Data<Elem = T>, D: Dimension>(
305 array: &ArrayBase<S, D>,
306 ) -> Cow<[T]> {
307 #[expect(clippy::option_if_let_else)]
308 if let Some(data) = array.as_slice() {
309 Cow::Borrowed(data)
310 } else {
311 Cow::Owned(array.iter().copied().collect())
312 }
313 }
314
315 if !matches!(encoded.shape(), [_]) {
316 return Err(LinearQuantizeCodecError::EncodedDataNotOneDimensional {
317 shape: encoded.shape().to_vec(),
318 });
319 }
320
321 match (decoded, self.dtype) {
322 (AnyArrayViewMut::F32(decoded), LinearQuantizeDType::F32) => {
323 match &encoded {
324 AnyArrayView::U8(encoded) => {
325 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
326 f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
327 })
328 }
329 AnyArrayView::U16(encoded) => {
330 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
331 f32::from(x) / scale_for_bits::<f32>(self.bits as u8)
332 })
333 }
334 AnyArrayView::U32(encoded) => {
335 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
336 let x = f64::from(x) / scale_for_bits::<f64>(self.bits as u8);
338 #[expect(clippy::cast_possible_truncation)]
339 let x = x as f32;
340 x
341 })
342 }
343 AnyArrayView::U64(encoded) => {
344 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
345 let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
347 f32::from(x)
348 })
349 }
350 encoded => {
351 return Err(LinearQuantizeCodecError::InvalidEncodedDType {
352 dtype: encoded.dtype(),
353 })
354 }
355 }
356 }
357 (AnyArrayViewMut::F64(decoded), LinearQuantizeDType::F64) => {
358 match &encoded {
359 AnyArrayView::U8(encoded) => {
360 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
361 f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
362 })
363 }
364 AnyArrayView::U16(encoded) => {
365 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
366 f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
367 })
368 }
369 AnyArrayView::U32(encoded) => {
370 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
371 f64::from(x) / scale_for_bits::<f64>(self.bits as u8)
372 })
373 }
374 AnyArrayView::U64(encoded) => {
375 reconstruct_into(&as_standard_order(encoded), decoded, |x| {
376 let x = TwoFloat::from(x) / scale_for_bits::<f64>(self.bits as u8);
378 f64::from(x)
379 })
380 }
381 encoded => {
382 return Err(LinearQuantizeCodecError::InvalidEncodedDType {
383 dtype: encoded.dtype(),
384 })
385 }
386 }
387 }
388 (decoded, dtype) => {
389 return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
390 configured: dtype,
391 provided: decoded.dtype(),
392 })
393 }
394 }?;
395
396 Ok(())
397 }
398}
399
impl StaticCodec for LinearQuantizeCodec {
    /// Identifier string of this codec.
    const CODEC_ID: &'static str = "linear-quantize.rs";

    /// The codec struct serves as its own configuration type.
    type Config<'de> = Self;

    /// Creates the codec from its configuration, which already is the codec.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Returns the codec's configuration, built from a borrow of `self`.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
413
/// Errors that may occur when applying the [`LinearQuantizeCodec`].
#[derive(Debug, Error)]
pub enum LinearQuantizeCodecError {
    /// The data passed to `encode` has a dtype that differs from the
    /// configured dtype
    #[error("LinearQuantize cannot encode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedEncodeDType {
        /// Dtype the codec is configured with
        configured: LinearQuantizeDType,
        /// Dtype of the provided data
        provided: AnyArrayDType,
    },
    /// The data to encode contains an infinite or NaN value
    #[error("LinearQuantize does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// Serializing the in-band header failed
    #[error("LinearQuantize failed to encode the header")]
    HeaderEncodeFailed {
        /// Opaque source error
        source: LinearQuantizeHeaderError,
    },
    /// The encoded array passed to `decode` or `decode_into` is not
    /// one-dimensional
    #[error("LinearQuantize can only decode one-dimensional arrays but received an array of shape {shape:?}")]
    EncodedDataNotOneDimensional {
        /// Shape of the provided encoded array
        shape: Vec<usize>,
    },
    /// Deserializing the in-band header failed
    #[error("LinearQuantize failed to decode the header")]
    HeaderDecodeFailed {
        /// Opaque source error
        source: LinearQuantizeHeaderError,
    },
    /// The shape stored in the header does not fit the decoded data
    #[error(
        "LinearQuantize decoded an invalid array shape header which does not fit the decoded data"
    )]
    DecodeInvalidShapeHeader {
        /// Source error
        #[from]
        source: ShapeError,
    },
    /// The encoded array has a dtype that cannot be decoded
    #[error("LinearQuantize cannot decode the provided dtype {dtype}")]
    InvalidEncodedDType {
        /// Dtype of the provided encoded array
        dtype: AnyArrayDType,
    },
    /// The output array passed to `decode_into` has a dtype that differs
    /// from the configured dtype
    #[error("LinearQuantize cannot decode the provided dtype {provided} which differs from the configured dtype {configured}")]
    MismatchedDecodeIntoDtype {
        /// Dtype the codec is configured with
        configured: LinearQuantizeDType,
        /// Dtype of the provided output array
        provided: AnyArrayDType,
    },
    /// The output array passed to `decode_into` has a shape that differs
    /// from the shape stored in the header
    #[error("LinearQuantize cannot decode the decoded array of shape {decoded:?} into the provided array of shape {provided:?}")]
    MismatchedDecodeIntoShape {
        /// Shape stored in the header of the encoded data
        decoded: Vec<usize>,
        /// Shape of the provided output array
        provided: Vec<usize>,
    },
}
484
/// Opaque error raised when (de)serializing the in-band header fails.
///
/// It wraps the underlying `postcard::Error` and displays it transparently.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct LinearQuantizeHeaderError(postcard::Error);
489
490pub fn quantize<
501 T: Float + ConstZero + ConstOne + Serialize,
502 Q: Unsigned,
503 S: Data<Elem = T>,
504 D: Dimension,
505>(
506 data: &ArrayBase<S, D>,
507 quantize: impl Fn(T) -> Q,
508) -> Result<Vec<Q>, LinearQuantizeCodecError> {
509 if !Zip::from(data).all(|x| x.is_finite()) {
510 return Err(LinearQuantizeCodecError::NonFiniteData);
511 }
512
513 let (minimum, maximum) = data.first().map_or((T::ZERO, T::ONE), |first| {
514 (
515 Zip::from(data).fold(*first, |a, b| a.min(*b)),
516 Zip::from(data).fold(*first, |a, b| a.max(*b)),
517 )
518 });
519
520 let header = postcard::to_extend(
521 &CompressionHeader {
522 shape: Cow::Borrowed(data.shape()),
523 minimum,
524 maximum,
525 version: StaticCodecVersion,
526 },
527 Vec::new(),
528 )
529 .map_err(|err| LinearQuantizeCodecError::HeaderEncodeFailed {
530 source: LinearQuantizeHeaderError(err),
531 })?;
532
533 let mut encoded: Vec<Q> = vec![Q::ZERO; header.len().div_ceil(std::mem::size_of::<Q>())];
534 #[expect(unsafe_code)]
535 unsafe {
537 std::ptr::copy_nonoverlapping(header.as_ptr(), encoded.as_mut_ptr().cast(), header.len());
538 }
539 encoded.reserve(data.len());
540
541 if maximum == minimum {
542 encoded.resize(encoded.len() + data.len(), quantize(T::ZERO));
543 } else {
544 encoded.extend(
545 data.iter()
546 .map(|x| quantize((*x - minimum) / (maximum - minimum))),
547 );
548 }
549
550 Ok(encoded)
551}
552
553pub fn reconstruct<T: Float + DeserializeOwned, Q: Unsigned>(
562 encoded: &[Q],
563 floatify: impl Fn(Q) -> T,
564) -> Result<ArrayD<T>, LinearQuantizeCodecError> {
565 #[expect(unsafe_code)]
566 let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
568 std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
569 })
570 .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
571 source: LinearQuantizeHeaderError(err),
572 })?;
573
574 let encoded = encoded
575 .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
576 .unwrap_or(&[]);
577
578 let decoded = encoded
579 .iter()
580 .map(|x| header.minimum + (floatify(*x) * (header.maximum - header.minimum)))
581 .map(|x| x.clamp(header.minimum, header.maximum))
582 .collect();
583
584 let decoded = Array::from_shape_vec(&*header.shape, decoded)?;
585
586 Ok(decoded)
587}
588
589pub fn reconstruct_into<T: Float + DeserializeOwned, Q: Unsigned>(
600 encoded: &[Q],
601 mut decoded: ArrayViewMutD<T>,
602 floatify: impl Fn(Q) -> T,
603) -> Result<(), LinearQuantizeCodecError> {
604 #[expect(unsafe_code)]
605 let (header, remaining) = postcard::take_from_bytes::<CompressionHeader<T>>(unsafe {
607 std::slice::from_raw_parts(encoded.as_ptr().cast(), std::mem::size_of_val(encoded))
608 })
609 .map_err(|err| LinearQuantizeCodecError::HeaderDecodeFailed {
610 source: LinearQuantizeHeaderError(err),
611 })?;
612
613 let encoded = encoded
614 .get(encoded.len() - (remaining.len() / std::mem::size_of::<Q>())..)
615 .unwrap_or(&[]);
616
617 if decoded.shape() != &*header.shape {
618 return Err(LinearQuantizeCodecError::MismatchedDecodeIntoShape {
619 decoded: header.shape.into_owned(),
620 provided: decoded.shape().to_vec(),
621 });
622 }
623
624 for (e, d) in encoded.iter().zip(decoded.iter_mut()) {
626 *d = (header.minimum + (floatify(*e) * (header.maximum - header.minimum)))
627 .clamp(header.minimum, header.maximum);
628 }
629
630 Ok(())
631}
632
633fn scale_for_bits<T: Float + From<u8> + ConstOne>(bits: u8) -> T {
635 <T as From<u8>>::from(bits).exp2() - T::ONE
636}
637
/// Unsigned integer types that quantized values can be stored in.
pub trait Unsigned: Copy {
    /// The value zero of this type
    const ZERO: Self;
}

/// Implements [`Unsigned`] with a `ZERO` of `0` for each listed type.
macro_rules! impl_unsigned {
    ($($ty:ty),* $(,)?) => {
        $(
            impl Unsigned for $ty {
                const ZERO: Self = 0;
            }
        )*
    };
}

impl_unsigned!(u8, u16, u32, u64);
659
/// Header that is serialized in-band at the start of the encoded data.
#[derive(Serialize, Deserialize)]
struct CompressionHeader<'a, T> {
    /// Shape of the original, unencoded array
    #[serde(borrow)]
    shape: Cow<'a, [usize]>,
    /// Smallest value of the original data
    minimum: T,
    /// Largest value of the original data
    maximum: T,
    /// Version marker of the encoding format
    version: LinearQuantizeCodecVersion,
}
668
#[cfg(test)]
mod tests {
    use ndarray::CowArray;

    use super::*;

    /// Encodes and decodes `data` with `codec` and asserts that every `f32`
    /// element is reproduced bit for bit.
    fn assert_exact_f32_roundtrip(
        codec: &LinearQuantizeCodec,
        data: &[f32],
    ) -> Result<(), LinearQuantizeCodecError> {
        let encoded = codec.encode(AnyCowArray::F32(CowArray::from(data).into_dyn()))?;
        let decoded = codec.decode(encoded.cow())?;

        let AnyArray::F32(decoded) = decoded else {
            return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                configured: LinearQuantizeDType::F32,
                provided: decoded.dtype(),
            });
        };

        for (original, roundtripped) in data.iter().zip(decoded.iter()) {
            assert_eq!(original.to_bits(), roundtripped.to_bits());
        }

        Ok(())
    }

    /// Encodes and decodes `data` with `codec` and asserts that every `f64`
    /// element is reproduced bit for bit.
    fn assert_exact_f64_roundtrip(
        codec: &LinearQuantizeCodec,
        data: &[f64],
    ) -> Result<(), LinearQuantizeCodecError> {
        let encoded = codec.encode(AnyCowArray::F64(CowArray::from(data).into_dyn()))?;
        let decoded = codec.decode(encoded.cow())?;

        let AnyArray::F64(decoded) = decoded else {
            return Err(LinearQuantizeCodecError::MismatchedDecodeIntoDtype {
                configured: LinearQuantizeDType::F64,
                provided: decoded.dtype(),
            });
        };

        for (original, roundtripped) in data.iter().zip(decoded.iter()) {
            assert_eq!(original.to_bits(), roundtripped.to_bits());
        }

        Ok(())
    }

    /// Integers up to `2^bits - 1`, converted losslessly from `u16`, must
    /// survive a roundtrip through `f32`.
    #[test]
    fn exact_roundtrip_f32_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=16 {
            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F32,
                // SAFETY: bits is in 1..=64, all of which are discriminants
                // of the `#[repr(u8)]` enum
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
                version: StaticCodecVersion,
            };

            // sample at most 2^8 evenly spaced values plus the largest one
            let largest = u16::MAX >> (16 - bits);
            let stride: usize = 1 << (bits.max(8) - 8);
            let data: Vec<f32> = (0..largest)
                .step_by(stride)
                .chain(std::iter::once(largest))
                .map(f32::from)
                .collect();

            assert_exact_f32_roundtrip(&codec, &data)?;
        }

        Ok(())
    }

    /// Integers up to `2^bits - 1`, cast from `u64`, must survive a
    /// roundtrip through `f32`.
    #[test]
    fn exact_roundtrip_f32_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F32,
                // SAFETY: bits is in 1..=64, all of which are discriminants
                // of the `#[repr(u8)]` enum
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
                version: StaticCodecVersion,
            };

            // sample at most 2^8 evenly spaced values plus the largest one
            let largest = u64::MAX >> (64 - bits);
            let stride: usize = 1 << (bits.max(8) - 8);
            #[expect(clippy::cast_precision_loss)]
            let data: Vec<f32> = (0..largest)
                .step_by(stride)
                .chain(std::iter::once(largest))
                .map(|x| x as f32)
                .collect();

            assert_exact_f32_roundtrip(&codec, &data)?;
        }

        Ok(())
    }

    /// Integers up to `2^bits - 1`, converted losslessly from `u32`, must
    /// survive a roundtrip through `f64`.
    #[test]
    fn exact_roundtrip_f64_from() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=32 {
            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F64,
                // SAFETY: bits is in 1..=64, all of which are discriminants
                // of the `#[repr(u8)]` enum
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
                version: StaticCodecVersion,
            };

            // sample at most 2^8 evenly spaced values plus the largest one
            let largest = u32::MAX >> (32 - bits);
            let stride: usize = 1 << (bits.max(8) - 8);
            let data: Vec<f64> = (0..largest)
                .step_by(stride)
                .chain(std::iter::once(largest))
                .map(f64::from)
                .collect();

            assert_exact_f64_roundtrip(&codec, &data)?;
        }

        Ok(())
    }

    /// Integers up to `2^bits - 1`, cast from `u64`, must survive a
    /// roundtrip through `f64`.
    #[test]
    fn exact_roundtrip_f64_as() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F64,
                // SAFETY: bits is in 1..=64, all of which are discriminants
                // of the `#[repr(u8)]` enum
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
                version: StaticCodecVersion,
            };

            // sample at most 2^8 evenly spaced values plus the largest one
            let largest = u64::MAX >> (64 - bits);
            let stride: usize = 1 << (bits.max(8) - 8);
            #[expect(clippy::cast_precision_loss)]
            let data: Vec<f64> = (0..largest)
                .step_by(stride)
                .chain(std::iter::once(largest))
                .map(|x| x as f64)
                .collect();

            assert_exact_f64_roundtrip(&codec, &data)?;
        }

        Ok(())
    }

    /// Constant data, where minimum and maximum coincide, must survive a
    /// roundtrip for every bit width.
    #[test]
    fn const_data_roundtrip() -> Result<(), LinearQuantizeCodecError> {
        for bits in 1..=64 {
            let data = [42.0, 42.0, 42.0, 42.0];

            let codec = LinearQuantizeCodec {
                dtype: LinearQuantizeDType::F64,
                // SAFETY: bits is in 1..=64, all of which are discriminants
                // of the `#[repr(u8)]` enum
                #[expect(unsafe_code)]
                bits: unsafe { std::mem::transmute::<u8, LinearQuantizeBins>(bits) },
                version: StaticCodecVersion,
            };

            assert_exact_f64_roundtrip(&codec, &data)?;
        }

        Ok(())
    }
}