1use std::{borrow::Cow, fmt, mem::ManuallyDrop, ptr};
2
3use ndarray::{
4 ArrayBase, ArrayD, CowRepr, Data, DataMut, Dimension, IxDyn, OwnedArcRepr, OwnedRepr, RawData,
5 RawDataClone, RawDataSubst, ViewRepr,
6};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10
11pub type AnyArcArray = AnyArrayBase<OwnedArcRepr<()>>;
13pub type AnyArray = AnyArrayBase<OwnedRepr<()>>;
15pub type AnyArrayView<'a> = AnyArrayBase<ViewRepr<&'a ()>>;
17pub type AnyArrayViewMut<'a> = AnyArrayBase<ViewRepr<&'a mut ()>>;
19pub type AnyCowArray<'a> = AnyArrayBase<CowRepr<'a, ()>>;
21
22#[non_exhaustive]
24#[expect(missing_docs)]
25pub enum AnyArrayBase<T: AnyRawData> {
26 U8(ArrayBase<T::U8, IxDyn>),
27 U16(ArrayBase<T::U16, IxDyn>),
28 U32(ArrayBase<T::U32, IxDyn>),
29 U64(ArrayBase<T::U64, IxDyn>),
30 I8(ArrayBase<T::I8, IxDyn>),
31 I16(ArrayBase<T::I16, IxDyn>),
32 I32(ArrayBase<T::I32, IxDyn>),
33 I64(ArrayBase<T::I64, IxDyn>),
34 F32(ArrayBase<T::F32, IxDyn>),
35 F64(ArrayBase<T::F64, IxDyn>),
36}
37
38impl<T: AnyRawData> AnyArrayBase<T> {
39 pub fn len(&self) -> usize {
41 match self {
42 Self::U8(a) => a.len(),
43 Self::U16(a) => a.len(),
44 Self::U32(a) => a.len(),
45 Self::U64(a) => a.len(),
46 Self::I8(a) => a.len(),
47 Self::I16(a) => a.len(),
48 Self::I32(a) => a.len(),
49 Self::I64(a) => a.len(),
50 Self::F32(a) => a.len(),
51 Self::F64(a) => a.len(),
52 }
53 }
54
55 pub fn is_empty(&self) -> bool {
57 match self {
58 Self::U8(a) => a.is_empty(),
59 Self::U16(a) => a.is_empty(),
60 Self::U32(a) => a.is_empty(),
61 Self::U64(a) => a.is_empty(),
62 Self::I8(a) => a.is_empty(),
63 Self::I16(a) => a.is_empty(),
64 Self::I32(a) => a.is_empty(),
65 Self::I64(a) => a.is_empty(),
66 Self::F32(a) => a.is_empty(),
67 Self::F64(a) => a.is_empty(),
68 }
69 }
70
71 pub const fn dtype(&self) -> AnyArrayDType {
73 match self {
74 Self::U8(_) => AnyArrayDType::U8,
75 Self::U16(_) => AnyArrayDType::U16,
76 Self::U32(_) => AnyArrayDType::U32,
77 Self::U64(_) => AnyArrayDType::U64,
78 Self::I8(_) => AnyArrayDType::I8,
79 Self::I16(_) => AnyArrayDType::I16,
80 Self::I32(_) => AnyArrayDType::I32,
81 Self::I64(_) => AnyArrayDType::I64,
82 Self::F32(_) => AnyArrayDType::F32,
83 Self::F64(_) => AnyArrayDType::F64,
84 }
85 }
86
87 pub fn shape(&self) -> &[usize] {
89 match self {
90 Self::U8(a) => a.shape(),
91 Self::U16(a) => a.shape(),
92 Self::U32(a) => a.shape(),
93 Self::U64(a) => a.shape(),
94 Self::I8(a) => a.shape(),
95 Self::I16(a) => a.shape(),
96 Self::I32(a) => a.shape(),
97 Self::I64(a) => a.shape(),
98 Self::F32(a) => a.shape(),
99 Self::F64(a) => a.shape(),
100 }
101 }
102
103 pub fn strides(&self) -> &[isize] {
105 match self {
106 Self::U8(a) => a.strides(),
107 Self::U16(a) => a.strides(),
108 Self::U32(a) => a.strides(),
109 Self::U64(a) => a.strides(),
110 Self::I8(a) => a.strides(),
111 Self::I16(a) => a.strides(),
112 Self::I32(a) => a.strides(),
113 Self::I64(a) => a.strides(),
114 Self::F32(a) => a.strides(),
115 Self::F64(a) => a.strides(),
116 }
117 }
118
119 #[must_use]
120 pub const fn as_typed<U: ArrayDType>(&self) -> Option<&ArrayBase<U::RawData<T>, IxDyn>> {
123 #[expect(unsafe_code)]
124 match (self, U::DTYPE) {
127 (Self::U8(a), AnyArrayDType::U8) => Some(unsafe { &*ptr::from_ref(a).cast() }),
128 (Self::U16(a), AnyArrayDType::U16) => Some(unsafe { &*ptr::from_ref(a).cast() }),
129 (Self::U32(a), AnyArrayDType::U32) => Some(unsafe { &*ptr::from_ref(a).cast() }),
130 (Self::U64(a), AnyArrayDType::U64) => Some(unsafe { &*ptr::from_ref(a).cast() }),
131 (Self::I8(a), AnyArrayDType::I8) => Some(unsafe { &*ptr::from_ref(a).cast() }),
132 (Self::I16(a), AnyArrayDType::I16) => Some(unsafe { &*ptr::from_ref(a).cast() }),
133 (Self::I32(a), AnyArrayDType::I32) => Some(unsafe { &*ptr::from_ref(a).cast() }),
134 (Self::I64(a), AnyArrayDType::I64) => Some(unsafe { &*ptr::from_ref(a).cast() }),
135 (Self::F32(a), AnyArrayDType::F32) => Some(unsafe { &*ptr::from_ref(a).cast() }),
136 (Self::F64(a), AnyArrayDType::F64) => Some(unsafe { &*ptr::from_ref(a).cast() }),
137 (_self, _dtype) => None,
138 }
139 }
140
141 #[must_use]
142 pub fn as_typed_mut<U: ArrayDType>(&mut self) -> Option<&mut ArrayBase<U::RawData<T>, IxDyn>> {
145 #[expect(unsafe_code)]
146 match (self, U::DTYPE) {
149 (Self::U8(a), AnyArrayDType::U8) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
150 (Self::U16(a), AnyArrayDType::U16) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
151 (Self::U32(a), AnyArrayDType::U32) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
152 (Self::U64(a), AnyArrayDType::U64) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
153 (Self::I8(a), AnyArrayDType::I8) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
154 (Self::I16(a), AnyArrayDType::I16) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
155 (Self::I32(a), AnyArrayDType::I32) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
156 (Self::I64(a), AnyArrayDType::I64) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
157 (Self::F32(a), AnyArrayDType::F32) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
158 (Self::F64(a), AnyArrayDType::F64) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
159 (_self, _dtype) => None,
160 }
161 }
162}
163
164impl<T: AnyRawData> AnyArrayBase<T>
165where
166 T::U8: Data,
167 T::U16: Data,
168 T::U32: Data,
169 T::U64: Data,
170 T::I8: Data,
171 T::I16: Data,
172 T::I32: Data,
173 T::I64: Data,
174 T::F32: Data,
175 T::F64: Data,
176{
177 #[must_use]
178 pub fn view(&self) -> AnyArrayView {
180 match self {
181 Self::U8(a) => AnyArrayView::U8(a.view()),
182 Self::U16(a) => AnyArrayView::U16(a.view()),
183 Self::U32(a) => AnyArrayView::U32(a.view()),
184 Self::U64(a) => AnyArrayView::U64(a.view()),
185 Self::I8(a) => AnyArrayView::I8(a.view()),
186 Self::I16(a) => AnyArrayView::I16(a.view()),
187 Self::I32(a) => AnyArrayView::I32(a.view()),
188 Self::I64(a) => AnyArrayView::I64(a.view()),
189 Self::F32(a) => AnyArrayView::F32(a.view()),
190 Self::F64(a) => AnyArrayView::F64(a.view()),
191 }
192 }
193
194 #[must_use]
195 pub fn cow(&self) -> AnyCowArray {
197 match self {
198 Self::U8(a) => AnyCowArray::U8(a.into()),
199 Self::U16(a) => AnyCowArray::U16(a.into()),
200 Self::U32(a) => AnyCowArray::U32(a.into()),
201 Self::U64(a) => AnyCowArray::U64(a.into()),
202 Self::I8(a) => AnyCowArray::I8(a.into()),
203 Self::I16(a) => AnyCowArray::I16(a.into()),
204 Self::I32(a) => AnyCowArray::I32(a.into()),
205 Self::I64(a) => AnyCowArray::I64(a.into()),
206 Self::F32(a) => AnyCowArray::F32(a.into()),
207 Self::F64(a) => AnyCowArray::F64(a.into()),
208 }
209 }
210
211 #[must_use]
212 pub fn into_owned(self) -> AnyArray {
215 match self {
216 Self::U8(a) => AnyArray::U8(a.into_owned()),
217 Self::U16(a) => AnyArray::U16(a.into_owned()),
218 Self::U32(a) => AnyArray::U32(a.into_owned()),
219 Self::U64(a) => AnyArray::U64(a.into_owned()),
220 Self::I8(a) => AnyArray::I8(a.into_owned()),
221 Self::I16(a) => AnyArray::I16(a.into_owned()),
222 Self::I32(a) => AnyArray::I32(a.into_owned()),
223 Self::I64(a) => AnyArray::I64(a.into_owned()),
224 Self::F32(a) => AnyArray::F32(a.into_owned()),
225 Self::F64(a) => AnyArray::F64(a.into_owned()),
226 }
227 }
228
229 #[must_use]
230 pub fn as_bytes(&self) -> Cow<[u8]> {
238 fn array_into_bytes<T: Copy, S: Data<Elem = T>>(x: &ArrayBase<S, IxDyn>) -> Cow<[u8]> {
239 #[expect(clippy::option_if_let_else)]
240 if let Some(x) = x.as_slice() {
241 #[expect(unsafe_code)]
242 Cow::Borrowed(unsafe {
246 std::slice::from_raw_parts(x.as_ptr().cast::<u8>(), std::mem::size_of_val(x))
247 })
248 } else {
249 let x = x.into_iter().copied().collect::<Vec<T>>();
250 let mut x = ManuallyDrop::new(x);
251 let (ptr, len, capacity) = (x.as_mut_ptr(), x.len(), x.capacity());
252 #[expect(unsafe_code)]
253 let x = unsafe {
258 Vec::from_raw_parts(
259 ptr.cast::<u8>(),
260 len * std::mem::size_of::<T>(),
261 capacity * std::mem::size_of::<T>(),
262 )
263 };
264 Cow::Owned(x)
265 }
266 }
267
268 match self {
269 Self::U8(a) => array_into_bytes(a),
270 Self::U16(a) => array_into_bytes(a),
271 Self::U32(a) => array_into_bytes(a),
272 Self::U64(a) => array_into_bytes(a),
273 Self::I8(a) => array_into_bytes(a),
274 Self::I16(a) => array_into_bytes(a),
275 Self::I32(a) => array_into_bytes(a),
276 Self::I64(a) => array_into_bytes(a),
277 Self::F32(a) => array_into_bytes(a),
278 Self::F64(a) => array_into_bytes(a),
279 }
280 }
281}
282
283impl<T: AnyRawData> AnyArrayBase<T>
284where
285 T::U8: DataMut,
286 T::U16: DataMut,
287 T::U32: DataMut,
288 T::U64: DataMut,
289 T::I8: DataMut,
290 T::I16: DataMut,
291 T::I32: DataMut,
292 T::I64: DataMut,
293 T::F32: DataMut,
294 T::F64: DataMut,
295{
296 #[must_use]
297 pub fn view_mut(&mut self) -> AnyArrayViewMut {
299 match self {
300 Self::U8(a) => AnyArrayViewMut::U8(a.view_mut()),
301 Self::U16(a) => AnyArrayViewMut::U16(a.view_mut()),
302 Self::U32(a) => AnyArrayViewMut::U32(a.view_mut()),
303 Self::U64(a) => AnyArrayViewMut::U64(a.view_mut()),
304 Self::I8(a) => AnyArrayViewMut::I8(a.view_mut()),
305 Self::I16(a) => AnyArrayViewMut::I16(a.view_mut()),
306 Self::I32(a) => AnyArrayViewMut::I32(a.view_mut()),
307 Self::I64(a) => AnyArrayViewMut::I64(a.view_mut()),
308 Self::F32(a) => AnyArrayViewMut::F32(a.view_mut()),
309 Self::F64(a) => AnyArrayViewMut::F64(a.view_mut()),
310 }
311 }
312
313 #[must_use]
314 pub fn with_bytes_mut<O>(&mut self, with: impl FnOnce(&mut [u8]) -> O) -> O {
323 fn array_with_bytes_mut<T: Copy, S: DataMut<Elem = T>, O>(
324 x: &mut ArrayBase<S, IxDyn>,
325 with: impl FnOnce(&mut [u8]) -> O,
326 ) -> O {
327 if let Some(x) = x.as_slice_mut() {
328 #[expect(unsafe_code)]
329 with(unsafe {
333 std::slice::from_raw_parts_mut(
334 x.as_mut_ptr().cast::<u8>(),
335 std::mem::size_of_val(x),
336 )
337 })
338 } else {
339 let mut x_vec: Vec<T> = x.into_iter().map(|x| *x).collect::<Vec<T>>();
340
341 #[expect(unsafe_code)]
342 let result = with(unsafe {
346 std::slice::from_raw_parts_mut(
347 x_vec.as_mut_ptr().cast::<u8>(),
348 std::mem::size_of_val(x_vec.as_slice()),
349 )
350 });
351
352 x.iter_mut().zip(x_vec).for_each(|(x, x_new)| *x = x_new);
353 result
354 }
355 }
356
357 match self {
358 Self::U8(a) => array_with_bytes_mut(a, with),
359 Self::U16(a) => array_with_bytes_mut(a, with),
360 Self::U32(a) => array_with_bytes_mut(a, with),
361 Self::U64(a) => array_with_bytes_mut(a, with),
362 Self::I8(a) => array_with_bytes_mut(a, with),
363 Self::I16(a) => array_with_bytes_mut(a, with),
364 Self::I32(a) => array_with_bytes_mut(a, with),
365 Self::I64(a) => array_with_bytes_mut(a, with),
366 Self::F32(a) => array_with_bytes_mut(a, with),
367 Self::F64(a) => array_with_bytes_mut(a, with),
368 }
369 }
370
371 pub fn assign<U: AnyRawData>(
381 &mut self,
382 src: &AnyArrayBase<U>,
383 ) -> Result<(), AnyArrayAssignError>
384 where
385 U::U8: Data,
386 U::U16: Data,
387 U::U32: Data,
388 U::U64: Data,
389 U::I8: Data,
390 U::I16: Data,
391 U::I32: Data,
392 U::I64: Data,
393 U::F32: Data,
394 U::F64: Data,
395 {
396 fn shape_checked_assign<
397 T: Copy,
398 S1: Data<Elem = T>,
399 S2: DataMut<Elem = T>,
400 D1: Dimension,
401 D2: Dimension,
402 >(
403 src: &ArrayBase<S1, D1>,
404 dst: &mut ArrayBase<S2, D2>,
405 ) -> Result<(), AnyArrayAssignError> {
406 #[expect(clippy::unit_arg)]
407 if src.shape() == dst.shape() {
408 Ok(dst.assign(src))
409 } else {
410 Err(AnyArrayAssignError::ShapeMismatch {
411 src: src.shape().to_vec(),
412 dst: dst.shape().to_vec(),
413 })
414 }
415 }
416
417 match (src, self) {
418 (AnyArrayBase::U8(src), Self::U8(dst)) => shape_checked_assign(src, dst),
419 (AnyArrayBase::U16(src), Self::U16(dst)) => shape_checked_assign(src, dst),
420 (AnyArrayBase::U32(src), Self::U32(dst)) => shape_checked_assign(src, dst),
421 (AnyArrayBase::U64(src), Self::U64(dst)) => shape_checked_assign(src, dst),
422 (AnyArrayBase::I8(src), Self::I8(dst)) => shape_checked_assign(src, dst),
423 (AnyArrayBase::I16(src), Self::I16(dst)) => shape_checked_assign(src, dst),
424 (AnyArrayBase::I32(src), Self::I32(dst)) => shape_checked_assign(src, dst),
425 (AnyArrayBase::I64(src), Self::I64(dst)) => shape_checked_assign(src, dst),
426 (AnyArrayBase::F32(src), Self::F32(dst)) => shape_checked_assign(src, dst),
427 (AnyArrayBase::F64(src), Self::F64(dst)) => shape_checked_assign(src, dst),
428 (src, dst) => Err(AnyArrayAssignError::DTypeMismatch {
429 src: src.dtype(),
430 dst: dst.dtype(),
431 }),
432 }
433 }
434}
435
436impl AnyArray {
437 #[must_use]
438 pub fn zeros(dtype: AnyArrayDType, shape: &[usize]) -> Self {
440 match dtype {
441 AnyArrayDType::U8 => Self::U8(ArrayD::zeros(shape)),
442 AnyArrayDType::U16 => Self::U16(ArrayD::zeros(shape)),
443 AnyArrayDType::U32 => Self::U32(ArrayD::zeros(shape)),
444 AnyArrayDType::U64 => Self::U64(ArrayD::zeros(shape)),
445 AnyArrayDType::I8 => Self::I8(ArrayD::zeros(shape)),
446 AnyArrayDType::I16 => Self::I16(ArrayD::zeros(shape)),
447 AnyArrayDType::I32 => Self::I32(ArrayD::zeros(shape)),
448 AnyArrayDType::I64 => Self::I64(ArrayD::zeros(shape)),
449 AnyArrayDType::F32 => Self::F32(ArrayD::zeros(shape)),
450 AnyArrayDType::F64 => Self::F64(ArrayD::zeros(shape)),
451 }
452 }
453
454 pub fn with_zeros_bytes<T>(
461 dtype: AnyArrayDType,
462 shape: &[usize],
463 with: impl FnOnce(&mut [u8]) -> T,
464 ) -> (Self, T) {
465 fn standard_array_as_bytes_mut<T: Copy>(x: &mut ArrayD<T>) -> &mut [u8] {
466 #[expect(unsafe_code)]
467 unsafe {
473 std::slice::from_raw_parts_mut(
474 x.as_mut_ptr().cast::<u8>(),
475 x.len() * std::mem::size_of::<T>(),
476 )
477 }
478 }
479
480 let mut array = Self::zeros(dtype, shape);
481
482 let result = match &mut array {
483 Self::U8(a) => with(standard_array_as_bytes_mut(a)),
484 Self::U16(a) => with(standard_array_as_bytes_mut(a)),
485 Self::U32(a) => with(standard_array_as_bytes_mut(a)),
486 Self::U64(a) => with(standard_array_as_bytes_mut(a)),
487 Self::I8(a) => with(standard_array_as_bytes_mut(a)),
488 Self::I16(a) => with(standard_array_as_bytes_mut(a)),
489 Self::I32(a) => with(standard_array_as_bytes_mut(a)),
490 Self::I64(a) => with(standard_array_as_bytes_mut(a)),
491 Self::F32(a) => with(standard_array_as_bytes_mut(a)),
492 Self::F64(a) => with(standard_array_as_bytes_mut(a)),
493 };
494
495 (array, result)
496 }
497
498 #[must_use]
499 pub fn into_cow(self) -> AnyCowArray<'static> {
501 match self {
502 Self::U8(array) => AnyCowArray::U8(array.into()),
503 Self::U16(array) => AnyCowArray::U16(array.into()),
504 Self::U32(array) => AnyCowArray::U32(array.into()),
505 Self::U64(array) => AnyCowArray::U64(array.into()),
506 Self::I8(array) => AnyCowArray::I8(array.into()),
507 Self::I16(array) => AnyCowArray::I16(array.into()),
508 Self::I32(array) => AnyCowArray::I32(array.into()),
509 Self::I64(array) => AnyCowArray::I64(array.into()),
510 Self::F32(array) => AnyCowArray::F32(array.into()),
511 Self::F64(array) => AnyCowArray::F64(array.into()),
512 }
513 }
514}
515
516impl<T: AnyRawData> Clone for AnyArrayBase<T>
517where
518 T::U8: RawDataClone,
519 T::U16: RawDataClone,
520 T::U32: RawDataClone,
521 T::U64: RawDataClone,
522 T::I8: RawDataClone,
523 T::I16: RawDataClone,
524 T::I32: RawDataClone,
525 T::I64: RawDataClone,
526 T::F32: RawDataClone,
527 T::F64: RawDataClone,
528{
529 fn clone(&self) -> Self {
530 match self {
531 Self::U8(a) => Self::U8(a.clone()),
532 Self::U16(a) => Self::U16(a.clone()),
533 Self::U32(a) => Self::U32(a.clone()),
534 Self::U64(a) => Self::U64(a.clone()),
535 Self::I8(a) => Self::I8(a.clone()),
536 Self::I16(a) => Self::I16(a.clone()),
537 Self::I32(a) => Self::I32(a.clone()),
538 Self::I64(a) => Self::I64(a.clone()),
539 Self::F32(a) => Self::F32(a.clone()),
540 Self::F64(a) => Self::F64(a.clone()),
541 }
542 }
543}
544
545impl<T: AnyRawData> fmt::Debug for AnyArrayBase<T>
546where
547 T::U8: Data,
548 T::U16: Data,
549 T::U32: Data,
550 T::U64: Data,
551 T::I8: Data,
552 T::I16: Data,
553 T::I32: Data,
554 T::I64: Data,
555 T::F32: Data,
556 T::F64: Data,
557{
558 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
559 match self {
560 Self::U8(a) => fmt.debug_tuple("U8").field(a).finish(),
561 Self::U16(a) => fmt.debug_tuple("U16").field(a).finish(),
562 Self::U32(a) => fmt.debug_tuple("U32").field(a).finish(),
563 Self::U64(a) => fmt.debug_tuple("U64").field(a).finish(),
564 Self::I8(a) => fmt.debug_tuple("I8").field(a).finish(),
565 Self::I16(a) => fmt.debug_tuple("I16").field(a).finish(),
566 Self::I32(a) => fmt.debug_tuple("I32").field(a).finish(),
567 Self::I64(a) => fmt.debug_tuple("I64").field(a).finish(),
568 Self::F32(a) => fmt.debug_tuple("F32").field(a).finish(),
569 Self::F64(a) => fmt.debug_tuple("F64").field(a).finish(),
570 }
571 }
572}
573
574impl<T: AnyRawData> PartialEq for AnyArrayBase<T>
575where
576 T::U8: Data,
577 T::U16: Data,
578 T::U32: Data,
579 T::U64: Data,
580 T::I8: Data,
581 T::I16: Data,
582 T::I32: Data,
583 T::I64: Data,
584 T::F32: Data,
585 T::F64: Data,
586{
587 fn eq(&self, other: &Self) -> bool {
588 match (self, other) {
589 (Self::U8(l), Self::U8(r)) => l == r,
590 (Self::U16(l), Self::U16(r)) => l == r,
591 (Self::U32(l), Self::U32(r)) => l == r,
592 (Self::U64(l), Self::U64(r)) => l == r,
593 (Self::I8(l), Self::I8(r)) => l == r,
594 (Self::I16(l), Self::I16(r)) => l == r,
595 (Self::I32(l), Self::I32(r)) => l == r,
596 (Self::I64(l), Self::I64(r)) => l == r,
597 (Self::F32(l), Self::F32(r)) => l == r,
598 (Self::F64(l), Self::F64(r)) => l == r,
599 _ => false,
600 }
601 }
602}
603
604#[expect(missing_docs)]
606pub trait AnyRawData {
607 type U8: RawData<Elem = u8>;
608 type U16: RawData<Elem = u16>;
609 type U32: RawData<Elem = u32>;
610 type U64: RawData<Elem = u64>;
611 type I8: RawData<Elem = i8>;
612 type I16: RawData<Elem = i16>;
613 type I32: RawData<Elem = i32>;
614 type I64: RawData<Elem = i64>;
615 type F32: RawData<Elem = f32>;
616 type F64: RawData<Elem = f64>;
617}
618
619impl<
620 T: RawDataSubst<u8>
621 + RawDataSubst<u16>
622 + RawDataSubst<u32>
623 + RawDataSubst<u64>
624 + RawDataSubst<i8>
625 + RawDataSubst<i16>
626 + RawDataSubst<i32>
627 + RawDataSubst<i64>
628 + RawDataSubst<f32>
629 + RawDataSubst<f64>,
630 > AnyRawData for T
631{
632 type U8 = <T as RawDataSubst<u8>>::Output;
633 type U16 = <T as RawDataSubst<u16>>::Output;
634 type U32 = <T as RawDataSubst<u32>>::Output;
635 type U64 = <T as RawDataSubst<u64>>::Output;
636 type I8 = <T as RawDataSubst<i8>>::Output;
637 type I16 = <T as RawDataSubst<i16>>::Output;
638 type I32 = <T as RawDataSubst<i32>>::Output;
639 type I64 = <T as RawDataSubst<i64>>::Output;
640 type F32 = <T as RawDataSubst<f32>>::Output;
641 type F64 = <T as RawDataSubst<f64>>::Output;
642}
643
// Runtime tag for the element type (dtype) of an `AnyArrayBase`.
//
// serde serialises each variant under its short lowercase name (`u8`, `i16`,
// `f32`, ...). When deserialising, the long names (`uint8`, `int16`,
// `float32`, ...) are also accepted as aliases. The schemars `extend`
// attribute lists both spellings in the JSON schema's `enum`, so the schema
// matches what deserialisation accepts.
//
// NOTE(review): plain `//` comments are used on purpose. schemars' JsonSchema
// derive usually turns `///` doc comments into schema descriptions, so adding
// them could change the generated schema. Confirm before converting.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[schemars(extend("enum" = [
    "u8", "uint8",
    "u16", "uint16",
    "u32", "uint32",
    "u64", "uint64",
    "i8", "int8",
    "i16", "int16",
    "i32", "int32",
    "i64", "int64",
    "f32", "float32",
    "f64", "float64"
]))]
#[non_exhaustive]
#[expect(missing_docs)]
pub enum AnyArrayDType {
    // Unsigned integers.
    #[serde(rename = "u8", alias = "uint8")]
    U8,
    #[serde(rename = "u16", alias = "uint16")]
    U16,
    #[serde(rename = "u32", alias = "uint32")]
    U32,
    #[serde(rename = "u64", alias = "uint64")]
    U64,
    // Signed integers.
    #[serde(rename = "i8", alias = "int8")]
    I8,
    #[serde(rename = "i16", alias = "int16")]
    I16,
    #[serde(rename = "i32", alias = "int32")]
    I32,
    #[serde(rename = "i64", alias = "int64")]
    I64,
    // Floating-point numbers.
    #[serde(rename = "f32", alias = "float32")]
    F32,
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
682
683impl AnyArrayDType {
684 #[must_use]
685 pub const fn of<T: ArrayDType>() -> Self {
687 T::DTYPE
688 }
689
690 #[must_use]
691 pub const fn to_binary(self) -> Self {
699 match self {
700 Self::U8 | Self::I8 => Self::U8,
701 Self::U16 | Self::I16 => Self::U16,
702 Self::U32 | Self::I32 | Self::F32 => Self::U32,
703 Self::U64 | Self::I64 | Self::F64 => Self::U64,
704 }
705 }
706
707 #[must_use]
708 pub const fn size(self) -> usize {
710 match self {
711 Self::U8 => std::mem::size_of::<u8>(),
712 Self::U16 => std::mem::size_of::<u16>(),
713 Self::U32 => std::mem::size_of::<u32>(),
714 Self::U64 => std::mem::size_of::<u64>(),
715 Self::I8 => std::mem::size_of::<i8>(),
716 Self::I16 => std::mem::size_of::<i16>(),
717 Self::I32 => std::mem::size_of::<i32>(),
718 Self::I64 => std::mem::size_of::<i64>(),
719 Self::F32 => std::mem::size_of::<f32>(),
720 Self::F64 => std::mem::size_of::<f64>(),
721 }
722 }
723}
724
725impl fmt::Display for AnyArrayDType {
726 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
727 fmt.write_str(match self {
728 Self::U8 => "u8",
729 Self::U16 => "u16",
730 Self::U32 => "u32",
731 Self::U64 => "u64",
732 Self::I8 => "i8",
733 Self::I16 => "i16",
734 Self::I32 => "i32",
735 Self::I64 => "i64",
736 Self::F32 => "f32",
737 Self::F64 => "f64",
738 })
739 }
740}
741
742pub trait ArrayDType: crate::sealed::Sealed {
744 const DTYPE: AnyArrayDType;
746
747 type RawData<T: AnyRawData>: RawData<Elem = Self>;
749}
750
751macro_rules! array_dtype {
752 ($($dtype:ident($ty:ty)),*) => {
753 $(
754 impl crate::sealed::Sealed for $ty {}
755
756 impl ArrayDType for $ty {
757 const DTYPE: AnyArrayDType = AnyArrayDType::$dtype;
758
759 type RawData<T: AnyRawData> = T::$dtype;
760 }
761 )*
762 };
763}
764
765array_dtype! {
766 U8(u8), U16(u16), U32(u32), U64(u64),
767 I8(i8), I16(i16), I32(i32), I64(i64),
768 F32(f32), F64(f64)
769}
770
771#[derive(Debug, Error)]
772pub enum AnyArrayAssignError {
774 #[error("cannot assign from mismatching {src} array to {dst}")]
776 DTypeMismatch {
777 src: AnyArrayDType,
779 dst: AnyArrayDType,
781 },
782 #[error("cannot assign from array of shape {src:?} to one of shape {dst:?}")]
784 ShapeMismatch {
785 src: Vec<usize>,
787 dst: Vec<usize>,
789 },
790}