1use std::{borrow::Cow, fmt, mem::ManuallyDrop, ptr};
2
3use ndarray::{
4 ArrayBase, ArrayD, CowRepr, Data, DataMut, Dimension, IxDyn, OwnedArcRepr, OwnedRepr, RawData,
5 RawDataClone, RawDataSubst, ViewRepr,
6};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10
/// A reference-counted (`ArcArray`-backed) array of any supported dtype.
pub type AnyArcArray = AnyArrayBase<OwnedArcRepr<()>>;
/// An owned array of any supported dtype.
pub type AnyArray = AnyArrayBase<OwnedRepr<()>>;
/// A read-only view into an array of any supported dtype.
pub type AnyArrayView<'a> = AnyArrayBase<ViewRepr<&'a ()>>;
/// A read-write view into an array of any supported dtype.
pub type AnyArrayViewMut<'a> = AnyArrayBase<ViewRepr<&'a mut ()>>;
/// A copy-on-write array of any supported dtype, which either borrows or owns
/// its data.
pub type AnyCowArray<'a> = AnyArrayBase<CowRepr<'a, ()>>;
21
/// An n-dimensional array whose element type (dtype) is only known at runtime.
///
/// Each variant wraps a dynamically-dimensioned [`ArrayBase`] with one of the
/// supported element types. The type parameter `T` selects the storage kind
/// (owned, shared, view, copy-on-write, ...) for all variants at once through
/// the [`AnyRawData`] family, e.g. `OwnedRepr<()>` for [`AnyArray`].
#[non_exhaustive]
#[expect(missing_docs)]
pub enum AnyArrayBase<T: AnyRawData> {
    U8(ArrayBase<T::U8, IxDyn>),
    U16(ArrayBase<T::U16, IxDyn>),
    U32(ArrayBase<T::U32, IxDyn>),
    U64(ArrayBase<T::U64, IxDyn>),
    I8(ArrayBase<T::I8, IxDyn>),
    I16(ArrayBase<T::I16, IxDyn>),
    I32(ArrayBase<T::I32, IxDyn>),
    I64(ArrayBase<T::I64, IxDyn>),
    F32(ArrayBase<T::F32, IxDyn>),
    F64(ArrayBase<T::F64, IxDyn>),
}
37
38impl<T: AnyRawData> AnyArrayBase<T> {
39 pub fn len(&self) -> usize {
41 match self {
42 Self::U8(a) => a.len(),
43 Self::U16(a) => a.len(),
44 Self::U32(a) => a.len(),
45 Self::U64(a) => a.len(),
46 Self::I8(a) => a.len(),
47 Self::I16(a) => a.len(),
48 Self::I32(a) => a.len(),
49 Self::I64(a) => a.len(),
50 Self::F32(a) => a.len(),
51 Self::F64(a) => a.len(),
52 }
53 }
54
55 pub fn is_empty(&self) -> bool {
57 match self {
58 Self::U8(a) => a.is_empty(),
59 Self::U16(a) => a.is_empty(),
60 Self::U32(a) => a.is_empty(),
61 Self::U64(a) => a.is_empty(),
62 Self::I8(a) => a.is_empty(),
63 Self::I16(a) => a.is_empty(),
64 Self::I32(a) => a.is_empty(),
65 Self::I64(a) => a.is_empty(),
66 Self::F32(a) => a.is_empty(),
67 Self::F64(a) => a.is_empty(),
68 }
69 }
70
71 pub const fn dtype(&self) -> AnyArrayDType {
73 match self {
74 Self::U8(_) => AnyArrayDType::U8,
75 Self::U16(_) => AnyArrayDType::U16,
76 Self::U32(_) => AnyArrayDType::U32,
77 Self::U64(_) => AnyArrayDType::U64,
78 Self::I8(_) => AnyArrayDType::I8,
79 Self::I16(_) => AnyArrayDType::I16,
80 Self::I32(_) => AnyArrayDType::I32,
81 Self::I64(_) => AnyArrayDType::I64,
82 Self::F32(_) => AnyArrayDType::F32,
83 Self::F64(_) => AnyArrayDType::F64,
84 }
85 }
86
87 pub fn shape(&self) -> &[usize] {
89 match self {
90 Self::U8(a) => a.shape(),
91 Self::U16(a) => a.shape(),
92 Self::U32(a) => a.shape(),
93 Self::U64(a) => a.shape(),
94 Self::I8(a) => a.shape(),
95 Self::I16(a) => a.shape(),
96 Self::I32(a) => a.shape(),
97 Self::I64(a) => a.shape(),
98 Self::F32(a) => a.shape(),
99 Self::F64(a) => a.shape(),
100 }
101 }
102
103 pub fn strides(&self) -> &[isize] {
105 match self {
106 Self::U8(a) => a.strides(),
107 Self::U16(a) => a.strides(),
108 Self::U32(a) => a.strides(),
109 Self::U64(a) => a.strides(),
110 Self::I8(a) => a.strides(),
111 Self::I16(a) => a.strides(),
112 Self::I32(a) => a.strides(),
113 Self::I64(a) => a.strides(),
114 Self::F32(a) => a.strides(),
115 Self::F64(a) => a.strides(),
116 }
117 }
118
119 #[must_use]
120 pub const fn as_typed<U: ArrayDType>(&self) -> Option<&ArrayBase<U::RawData<T>, IxDyn>> {
123 #[expect(unsafe_code)]
124 match (self, U::DTYPE) {
127 (Self::U8(a), AnyArrayDType::U8) => Some(unsafe { &*ptr::from_ref(a).cast() }),
128 (Self::U16(a), AnyArrayDType::U16) => Some(unsafe { &*ptr::from_ref(a).cast() }),
129 (Self::U32(a), AnyArrayDType::U32) => Some(unsafe { &*ptr::from_ref(a).cast() }),
130 (Self::U64(a), AnyArrayDType::U64) => Some(unsafe { &*ptr::from_ref(a).cast() }),
131 (Self::I8(a), AnyArrayDType::I8) => Some(unsafe { &*ptr::from_ref(a).cast() }),
132 (Self::I16(a), AnyArrayDType::I16) => Some(unsafe { &*ptr::from_ref(a).cast() }),
133 (Self::I32(a), AnyArrayDType::I32) => Some(unsafe { &*ptr::from_ref(a).cast() }),
134 (Self::I64(a), AnyArrayDType::I64) => Some(unsafe { &*ptr::from_ref(a).cast() }),
135 (Self::F32(a), AnyArrayDType::F32) => Some(unsafe { &*ptr::from_ref(a).cast() }),
136 (Self::F64(a), AnyArrayDType::F64) => Some(unsafe { &*ptr::from_ref(a).cast() }),
137 (_self, _dtype) => None,
138 }
139 }
140
141 #[must_use]
142 pub fn as_typed_mut<U: ArrayDType>(&mut self) -> Option<&mut ArrayBase<U::RawData<T>, IxDyn>> {
145 #[expect(unsafe_code)]
146 match (self, U::DTYPE) {
149 (Self::U8(a), AnyArrayDType::U8) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
150 (Self::U16(a), AnyArrayDType::U16) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
151 (Self::U32(a), AnyArrayDType::U32) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
152 (Self::U64(a), AnyArrayDType::U64) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
153 (Self::I8(a), AnyArrayDType::I8) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
154 (Self::I16(a), AnyArrayDType::I16) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
155 (Self::I32(a), AnyArrayDType::I32) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
156 (Self::I64(a), AnyArrayDType::I64) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
157 (Self::F32(a), AnyArrayDType::F32) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
158 (Self::F64(a), AnyArrayDType::F64) => Some(unsafe { &mut *ptr::from_mut(a).cast() }),
159 (_self, _dtype) => None,
160 }
161 }
162}
163
164impl<T: AnyRawData> AnyArrayBase<T>
165where
166 T::U8: Data,
167 T::U16: Data,
168 T::U32: Data,
169 T::U64: Data,
170 T::I8: Data,
171 T::I16: Data,
172 T::I32: Data,
173 T::I64: Data,
174 T::F32: Data,
175 T::F64: Data,
176{
177 #[must_use]
178 pub fn view(&self) -> AnyArrayView<'_> {
180 match self {
181 Self::U8(a) => AnyArrayView::U8(a.view()),
182 Self::U16(a) => AnyArrayView::U16(a.view()),
183 Self::U32(a) => AnyArrayView::U32(a.view()),
184 Self::U64(a) => AnyArrayView::U64(a.view()),
185 Self::I8(a) => AnyArrayView::I8(a.view()),
186 Self::I16(a) => AnyArrayView::I16(a.view()),
187 Self::I32(a) => AnyArrayView::I32(a.view()),
188 Self::I64(a) => AnyArrayView::I64(a.view()),
189 Self::F32(a) => AnyArrayView::F32(a.view()),
190 Self::F64(a) => AnyArrayView::F64(a.view()),
191 }
192 }
193
194 #[must_use]
195 pub fn cow(&self) -> AnyCowArray<'_> {
197 match self {
198 Self::U8(a) => AnyCowArray::U8(a.into()),
199 Self::U16(a) => AnyCowArray::U16(a.into()),
200 Self::U32(a) => AnyCowArray::U32(a.into()),
201 Self::U64(a) => AnyCowArray::U64(a.into()),
202 Self::I8(a) => AnyCowArray::I8(a.into()),
203 Self::I16(a) => AnyCowArray::I16(a.into()),
204 Self::I32(a) => AnyCowArray::I32(a.into()),
205 Self::I64(a) => AnyCowArray::I64(a.into()),
206 Self::F32(a) => AnyCowArray::F32(a.into()),
207 Self::F64(a) => AnyCowArray::F64(a.into()),
208 }
209 }
210
211 #[must_use]
212 pub fn into_owned(self) -> AnyArray {
215 match self {
216 Self::U8(a) => AnyArray::U8(a.into_owned()),
217 Self::U16(a) => AnyArray::U16(a.into_owned()),
218 Self::U32(a) => AnyArray::U32(a.into_owned()),
219 Self::U64(a) => AnyArray::U64(a.into_owned()),
220 Self::I8(a) => AnyArray::I8(a.into_owned()),
221 Self::I16(a) => AnyArray::I16(a.into_owned()),
222 Self::I32(a) => AnyArray::I32(a.into_owned()),
223 Self::I64(a) => AnyArray::I64(a.into_owned()),
224 Self::F32(a) => AnyArray::F32(a.into_owned()),
225 Self::F64(a) => AnyArray::F64(a.into_owned()),
226 }
227 }
228
229 #[must_use]
230 pub fn as_bytes(&self) -> Cow<'_, [u8]> {
238 fn array_into_bytes<T: Copy, S: Data<Elem = T>>(x: &ArrayBase<S, IxDyn>) -> Cow<'_, [u8]> {
239 #[expect(clippy::option_if_let_else)]
240 if let Some(x) = x.as_slice() {
241 #[expect(unsafe_code)]
242 Cow::Borrowed(unsafe {
246 std::slice::from_raw_parts(x.as_ptr().cast::<u8>(), std::mem::size_of_val(x))
247 })
248 } else {
249 let x = x.into_iter().copied().collect::<Vec<T>>();
250 let mut x = ManuallyDrop::new(x);
251 let (ptr, len, capacity) = (x.as_mut_ptr(), x.len(), x.capacity());
252 #[expect(unsafe_code)]
253 let x = unsafe {
258 Vec::from_raw_parts(
259 ptr.cast::<u8>(),
260 len * std::mem::size_of::<T>(),
261 capacity * std::mem::size_of::<T>(),
262 )
263 };
264 Cow::Owned(x)
265 }
266 }
267
268 match self {
269 Self::U8(a) => array_into_bytes(a),
270 Self::U16(a) => array_into_bytes(a),
271 Self::U32(a) => array_into_bytes(a),
272 Self::U64(a) => array_into_bytes(a),
273 Self::I8(a) => array_into_bytes(a),
274 Self::I16(a) => array_into_bytes(a),
275 Self::I32(a) => array_into_bytes(a),
276 Self::I64(a) => array_into_bytes(a),
277 Self::F32(a) => array_into_bytes(a),
278 Self::F64(a) => array_into_bytes(a),
279 }
280 }
281}
282
283impl<T: AnyRawData> AnyArrayBase<T>
284where
285 T::U8: DataMut,
286 T::U16: DataMut,
287 T::U32: DataMut,
288 T::U64: DataMut,
289 T::I8: DataMut,
290 T::I16: DataMut,
291 T::I32: DataMut,
292 T::I64: DataMut,
293 T::F32: DataMut,
294 T::F64: DataMut,
295{
296 #[must_use]
297 pub fn view_mut(&mut self) -> AnyArrayViewMut<'_> {
299 match self {
300 Self::U8(a) => AnyArrayViewMut::U8(a.view_mut()),
301 Self::U16(a) => AnyArrayViewMut::U16(a.view_mut()),
302 Self::U32(a) => AnyArrayViewMut::U32(a.view_mut()),
303 Self::U64(a) => AnyArrayViewMut::U64(a.view_mut()),
304 Self::I8(a) => AnyArrayViewMut::I8(a.view_mut()),
305 Self::I16(a) => AnyArrayViewMut::I16(a.view_mut()),
306 Self::I32(a) => AnyArrayViewMut::I32(a.view_mut()),
307 Self::I64(a) => AnyArrayViewMut::I64(a.view_mut()),
308 Self::F32(a) => AnyArrayViewMut::F32(a.view_mut()),
309 Self::F64(a) => AnyArrayViewMut::F64(a.view_mut()),
310 }
311 }
312
313 #[must_use]
314 pub fn with_bytes_mut<O>(&mut self, with: impl FnOnce(&mut [u8]) -> O) -> O {
323 fn array_with_bytes_mut<T: ArrayDType, S: DataMut<Elem = T>, O>(
324 x: &mut ArrayBase<S, IxDyn>,
325 with: impl FnOnce(&mut [u8]) -> O,
326 ) -> O {
327 #[expect(unsafe_code)]
328 x.with_slice_mut(|x| {
332 with(unsafe {
333 std::slice::from_raw_parts_mut(
334 x.as_mut_ptr().cast::<u8>(),
335 std::mem::size_of_val(x),
336 )
337 })
338 })
339 }
340
341 match self {
342 Self::U8(a) => array_with_bytes_mut(a, with),
343 Self::U16(a) => array_with_bytes_mut(a, with),
344 Self::U32(a) => array_with_bytes_mut(a, with),
345 Self::U64(a) => array_with_bytes_mut(a, with),
346 Self::I8(a) => array_with_bytes_mut(a, with),
347 Self::I16(a) => array_with_bytes_mut(a, with),
348 Self::I32(a) => array_with_bytes_mut(a, with),
349 Self::I64(a) => array_with_bytes_mut(a, with),
350 Self::F32(a) => array_with_bytes_mut(a, with),
351 Self::F64(a) => array_with_bytes_mut(a, with),
352 }
353 }
354
355 pub fn assign<U: AnyRawData>(
365 &mut self,
366 src: &AnyArrayBase<U>,
367 ) -> Result<(), AnyArrayAssignError>
368 where
369 U::U8: Data,
370 U::U16: Data,
371 U::U32: Data,
372 U::U64: Data,
373 U::I8: Data,
374 U::I16: Data,
375 U::I32: Data,
376 U::I64: Data,
377 U::F32: Data,
378 U::F64: Data,
379 {
380 fn shape_checked_assign<
381 T: Copy,
382 S1: Data<Elem = T>,
383 S2: DataMut<Elem = T>,
384 D1: Dimension,
385 D2: Dimension,
386 >(
387 src: &ArrayBase<S1, D1>,
388 dst: &mut ArrayBase<S2, D2>,
389 ) -> Result<(), AnyArrayAssignError> {
390 #[expect(clippy::unit_arg)]
391 if src.shape() == dst.shape() {
392 Ok(dst.assign(src))
393 } else {
394 Err(AnyArrayAssignError::ShapeMismatch {
395 src: src.shape().to_vec(),
396 dst: dst.shape().to_vec(),
397 })
398 }
399 }
400
401 match (src, self) {
402 (AnyArrayBase::U8(src), Self::U8(dst)) => shape_checked_assign(src, dst),
403 (AnyArrayBase::U16(src), Self::U16(dst)) => shape_checked_assign(src, dst),
404 (AnyArrayBase::U32(src), Self::U32(dst)) => shape_checked_assign(src, dst),
405 (AnyArrayBase::U64(src), Self::U64(dst)) => shape_checked_assign(src, dst),
406 (AnyArrayBase::I8(src), Self::I8(dst)) => shape_checked_assign(src, dst),
407 (AnyArrayBase::I16(src), Self::I16(dst)) => shape_checked_assign(src, dst),
408 (AnyArrayBase::I32(src), Self::I32(dst)) => shape_checked_assign(src, dst),
409 (AnyArrayBase::I64(src), Self::I64(dst)) => shape_checked_assign(src, dst),
410 (AnyArrayBase::F32(src), Self::F32(dst)) => shape_checked_assign(src, dst),
411 (AnyArrayBase::F64(src), Self::F64(dst)) => shape_checked_assign(src, dst),
412 (src, dst) => Err(AnyArrayAssignError::DTypeMismatch {
413 src: src.dtype(),
414 dst: dst.dtype(),
415 }),
416 }
417 }
418}
419
420impl AnyArray {
421 #[must_use]
422 pub fn zeros(dtype: AnyArrayDType, shape: &[usize]) -> Self {
424 match dtype {
425 AnyArrayDType::U8 => Self::U8(ArrayD::zeros(shape)),
426 AnyArrayDType::U16 => Self::U16(ArrayD::zeros(shape)),
427 AnyArrayDType::U32 => Self::U32(ArrayD::zeros(shape)),
428 AnyArrayDType::U64 => Self::U64(ArrayD::zeros(shape)),
429 AnyArrayDType::I8 => Self::I8(ArrayD::zeros(shape)),
430 AnyArrayDType::I16 => Self::I16(ArrayD::zeros(shape)),
431 AnyArrayDType::I32 => Self::I32(ArrayD::zeros(shape)),
432 AnyArrayDType::I64 => Self::I64(ArrayD::zeros(shape)),
433 AnyArrayDType::F32 => Self::F32(ArrayD::zeros(shape)),
434 AnyArrayDType::F64 => Self::F64(ArrayD::zeros(shape)),
435 }
436 }
437
438 pub fn with_zeros_bytes<T>(
445 dtype: AnyArrayDType,
446 shape: &[usize],
447 with: impl FnOnce(&mut [u8]) -> T,
448 ) -> (Self, T) {
449 fn standard_array_as_bytes_mut<T: Copy>(x: &mut ArrayD<T>) -> &mut [u8] {
450 #[expect(unsafe_code)]
451 unsafe {
457 std::slice::from_raw_parts_mut(
458 x.as_mut_ptr().cast::<u8>(),
459 x.len() * std::mem::size_of::<T>(),
460 )
461 }
462 }
463
464 let mut array = Self::zeros(dtype, shape);
465
466 let result = match &mut array {
467 Self::U8(a) => with(standard_array_as_bytes_mut(a)),
468 Self::U16(a) => with(standard_array_as_bytes_mut(a)),
469 Self::U32(a) => with(standard_array_as_bytes_mut(a)),
470 Self::U64(a) => with(standard_array_as_bytes_mut(a)),
471 Self::I8(a) => with(standard_array_as_bytes_mut(a)),
472 Self::I16(a) => with(standard_array_as_bytes_mut(a)),
473 Self::I32(a) => with(standard_array_as_bytes_mut(a)),
474 Self::I64(a) => with(standard_array_as_bytes_mut(a)),
475 Self::F32(a) => with(standard_array_as_bytes_mut(a)),
476 Self::F64(a) => with(standard_array_as_bytes_mut(a)),
477 };
478
479 (array, result)
480 }
481
482 #[must_use]
483 pub fn into_cow(self) -> AnyCowArray<'static> {
485 match self {
486 Self::U8(array) => AnyCowArray::U8(array.into()),
487 Self::U16(array) => AnyCowArray::U16(array.into()),
488 Self::U32(array) => AnyCowArray::U32(array.into()),
489 Self::U64(array) => AnyCowArray::U64(array.into()),
490 Self::I8(array) => AnyCowArray::I8(array.into()),
491 Self::I16(array) => AnyCowArray::I16(array.into()),
492 Self::I32(array) => AnyCowArray::I32(array.into()),
493 Self::I64(array) => AnyCowArray::I64(array.into()),
494 Self::F32(array) => AnyCowArray::F32(array.into()),
495 Self::F64(array) => AnyCowArray::F64(array.into()),
496 }
497 }
498}
499
500impl<T: AnyRawData> Clone for AnyArrayBase<T>
501where
502 T::U8: RawDataClone,
503 T::U16: RawDataClone,
504 T::U32: RawDataClone,
505 T::U64: RawDataClone,
506 T::I8: RawDataClone,
507 T::I16: RawDataClone,
508 T::I32: RawDataClone,
509 T::I64: RawDataClone,
510 T::F32: RawDataClone,
511 T::F64: RawDataClone,
512{
513 fn clone(&self) -> Self {
514 match self {
515 Self::U8(a) => Self::U8(a.clone()),
516 Self::U16(a) => Self::U16(a.clone()),
517 Self::U32(a) => Self::U32(a.clone()),
518 Self::U64(a) => Self::U64(a.clone()),
519 Self::I8(a) => Self::I8(a.clone()),
520 Self::I16(a) => Self::I16(a.clone()),
521 Self::I32(a) => Self::I32(a.clone()),
522 Self::I64(a) => Self::I64(a.clone()),
523 Self::F32(a) => Self::F32(a.clone()),
524 Self::F64(a) => Self::F64(a.clone()),
525 }
526 }
527}
528
529impl<T: AnyRawData> fmt::Debug for AnyArrayBase<T>
530where
531 T::U8: Data,
532 T::U16: Data,
533 T::U32: Data,
534 T::U64: Data,
535 T::I8: Data,
536 T::I16: Data,
537 T::I32: Data,
538 T::I64: Data,
539 T::F32: Data,
540 T::F64: Data,
541{
542 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
543 match self {
544 Self::U8(a) => fmt.debug_tuple("U8").field(a).finish(),
545 Self::U16(a) => fmt.debug_tuple("U16").field(a).finish(),
546 Self::U32(a) => fmt.debug_tuple("U32").field(a).finish(),
547 Self::U64(a) => fmt.debug_tuple("U64").field(a).finish(),
548 Self::I8(a) => fmt.debug_tuple("I8").field(a).finish(),
549 Self::I16(a) => fmt.debug_tuple("I16").field(a).finish(),
550 Self::I32(a) => fmt.debug_tuple("I32").field(a).finish(),
551 Self::I64(a) => fmt.debug_tuple("I64").field(a).finish(),
552 Self::F32(a) => fmt.debug_tuple("F32").field(a).finish(),
553 Self::F64(a) => fmt.debug_tuple("F64").field(a).finish(),
554 }
555 }
556}
557
558impl<T: AnyRawData> PartialEq for AnyArrayBase<T>
559where
560 T::U8: Data,
561 T::U16: Data,
562 T::U32: Data,
563 T::U64: Data,
564 T::I8: Data,
565 T::I16: Data,
566 T::I32: Data,
567 T::I64: Data,
568 T::F32: Data,
569 T::F64: Data,
570{
571 fn eq(&self, other: &Self) -> bool {
572 match (self, other) {
573 (Self::U8(l), Self::U8(r)) => l == r,
574 (Self::U16(l), Self::U16(r)) => l == r,
575 (Self::U32(l), Self::U32(r)) => l == r,
576 (Self::U64(l), Self::U64(r)) => l == r,
577 (Self::I8(l), Self::I8(r)) => l == r,
578 (Self::I16(l), Self::I16(r)) => l == r,
579 (Self::I32(l), Self::I32(r)) => l == r,
580 (Self::I64(l), Self::I64(r)) => l == r,
581 (Self::F32(l), Self::F32(r)) => l == r,
582 (Self::F64(l), Self::F64(r)) => l == r,
583 _ => false,
584 }
585 }
586}
587
/// Extension methods for mutable `ndarray` arrays whose elements are an
/// [`ArrayDType`].
///
/// This trait is sealed; it is implemented for every [`ArrayBase`] whose
/// storage implements [`DataMut`].
pub trait ArrayDataMutExt<T: ArrayDType>: sealed::SealedArrayDataMutExt {
    /// Calls `with` on a mutable slice of the array's elements, in logical
    /// (row-major) order, and returns its result.
    ///
    /// If the array is in standard layout, `with` receives the array's own
    /// storage. Otherwise the elements are copied into a temporary buffer that
    /// is passed to `with`, and its (possibly modified) contents are written
    /// back into the array afterwards.
    #[must_use]
    fn with_slice_mut<O>(&mut self, with: impl FnOnce(&mut [T]) -> O) -> O;
}
605
606impl<T: ArrayDType, S: DataMut<Elem = T>, D: Dimension> ArrayDataMutExt<T> for ArrayBase<S, D> {
607 fn with_slice_mut<O>(&mut self, with: impl FnOnce(&mut [T]) -> O) -> O {
608 if let Some(slice) = self.as_slice_mut() {
609 with(slice)
610 } else {
611 let mut vec: Vec<T> = self.into_iter().map(|x| *x).collect::<Vec<T>>();
612
613 let result = with(vec.as_mut_slice());
614
615 self.iter_mut().zip(vec).for_each(|(x, x_new)| *x = x_new);
616 result
617 }
618 }
619}
620
// Seal `ArrayDataMutExt`: the private supertrait is only implemented here, for
// mutable `ndarray` arrays of `ArrayDType` elements, so downstream crates
// cannot add their own implementations.
impl<T: ArrayDType, S: DataMut<Elem = T>, D: Dimension> sealed::SealedArrayDataMutExt
    for ArrayBase<S, D>
{
}
625
/// A family of `ndarray` [`RawData`] storage types, one per supported element
/// type, which all share the same storage kind.
///
/// It is implemented for every storage type that implements [`RawDataSubst`]
/// for all ten supported element types, such as the `()`-element
/// representations (`OwnedRepr<()>`, `ViewRepr<&'a ()>`, ...) that
/// parameterise [`AnyArrayBase`].
#[expect(missing_docs)]
pub trait AnyRawData {
    type U8: RawData<Elem = u8>;
    type U16: RawData<Elem = u16>;
    type U32: RawData<Elem = u32>;
    type U64: RawData<Elem = u64>;
    type I8: RawData<Elem = i8>;
    type I16: RawData<Elem = i16>;
    type I32: RawData<Elem = i32>;
    type I64: RawData<Elem = i64>;
    type F32: RawData<Elem = f32>;
    type F64: RawData<Elem = f64>;
}
640
641impl<
642 T: RawDataSubst<u8>
643 + RawDataSubst<u16>
644 + RawDataSubst<u32>
645 + RawDataSubst<u64>
646 + RawDataSubst<i8>
647 + RawDataSubst<i16>
648 + RawDataSubst<i32>
649 + RawDataSubst<i64>
650 + RawDataSubst<f32>
651 + RawDataSubst<f64>,
652> AnyRawData for T
653{
654 type U8 = <T as RawDataSubst<u8>>::Output;
655 type U16 = <T as RawDataSubst<u16>>::Output;
656 type U32 = <T as RawDataSubst<u32>>::Output;
657 type U64 = <T as RawDataSubst<u64>>::Output;
658 type I8 = <T as RawDataSubst<i8>>::Output;
659 type I16 = <T as RawDataSubst<i16>>::Output;
660 type I32 = <T as RawDataSubst<i32>>::Output;
661 type I64 = <T as RawDataSubst<i64>>::Output;
662 type F32 = <T as RawDataSubst<f32>>::Output;
663 type F64 = <T as RawDataSubst<f64>>::Output;
664}
665
/// The runtime element type (dtype) of an [`AnyArrayBase`].
///
/// Serialises as the short lowercase name (e.g. `"u8"`, `"f64"`) and also
/// deserialises from the long alias (e.g. `"uint8"`, `"float64"`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
// The JSON schema enumerates both the canonical names and their aliases.
#[schemars(extend("enum" = [
    "u8", "uint8",
    "u16", "uint16",
    "u32", "uint32",
    "u64", "uint64",
    "i8", "int8",
    "i16", "int16",
    "i32", "int32",
    "i64", "int64",
    "f32", "float32",
    "f64", "float64"
]))]
#[non_exhaustive]
#[expect(missing_docs)]
pub enum AnyArrayDType {
    #[serde(rename = "u8", alias = "uint8")]
    U8,
    #[serde(rename = "u16", alias = "uint16")]
    U16,
    #[serde(rename = "u32", alias = "uint32")]
    U32,
    #[serde(rename = "u64", alias = "uint64")]
    U64,
    #[serde(rename = "i8", alias = "int8")]
    I8,
    #[serde(rename = "i16", alias = "int16")]
    I16,
    #[serde(rename = "i32", alias = "int32")]
    I32,
    #[serde(rename = "i64", alias = "int64")]
    I64,
    #[serde(rename = "f32", alias = "float32")]
    F32,
    #[serde(rename = "f64", alias = "float64")]
    F64,
}
704
705impl AnyArrayDType {
706 #[must_use]
707 pub const fn of<T: ArrayDType>() -> Self {
709 T::DTYPE
710 }
711
712 #[must_use]
713 pub const fn to_binary(self) -> Self {
721 match self {
722 Self::U8 | Self::I8 => Self::U8,
723 Self::U16 | Self::I16 => Self::U16,
724 Self::U32 | Self::I32 | Self::F32 => Self::U32,
725 Self::U64 | Self::I64 | Self::F64 => Self::U64,
726 }
727 }
728
729 #[must_use]
730 pub const fn size(self) -> usize {
732 match self {
733 Self::U8 => std::mem::size_of::<u8>(),
734 Self::U16 => std::mem::size_of::<u16>(),
735 Self::U32 => std::mem::size_of::<u32>(),
736 Self::U64 => std::mem::size_of::<u64>(),
737 Self::I8 => std::mem::size_of::<i8>(),
738 Self::I16 => std::mem::size_of::<i16>(),
739 Self::I32 => std::mem::size_of::<i32>(),
740 Self::I64 => std::mem::size_of::<i64>(),
741 Self::F32 => std::mem::size_of::<f32>(),
742 Self::F64 => std::mem::size_of::<f64>(),
743 }
744 }
745}
746
747impl fmt::Display for AnyArrayDType {
748 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
749 fmt.write_str(match self {
750 Self::U8 => "u8",
751 Self::U16 => "u16",
752 Self::U32 => "u32",
753 Self::U64 => "u64",
754 Self::I8 => "i8",
755 Self::I16 => "i16",
756 Self::I32 => "i32",
757 Self::I64 => "i64",
758 Self::F32 => "f32",
759 Self::F64 => "f64",
760 })
761 }
762}
763
/// Element types that can be stored in an [`AnyArrayBase`].
///
/// This trait is sealed; it is implemented (by the `array_dtype!` macro) only
/// for `u8`, `u16`, `u32`, `u64`, `i8`, `i16`, `i32`, `i64`, `f32` and `f64`.
pub trait ArrayDType: sealed::SealedArrayDType {
    /// The [`AnyArrayDType`] variant that corresponds to this element type.
    const DTYPE: AnyArrayDType;

    /// The [`RawData`] storage for this element type within the storage family
    /// `T`, i.e. the associated type of [`AnyRawData`] that matches [`Self::DTYPE`].
    type RawData<T: AnyRawData>: RawData<Elem = Self>;
}
772
773macro_rules! array_dtype {
774 ($($dtype:ident($ty:ty)),*) => {
775 $(
776 impl sealed::SealedArrayDType for $ty {}
777
778 impl ArrayDType for $ty {
779 const DTYPE: AnyArrayDType = AnyArrayDType::$dtype;
780
781 type RawData<T: AnyRawData> = T::$dtype;
782 }
783 )*
784 };
785}
786
787array_dtype! {
788 U8(u8), U16(u16), U32(u32), U64(u64),
789 I8(i8), I16(i16), I32(i32), I64(i64),
790 F32(f32), F64(f64)
791}
792
/// Error returned by [`AnyArrayBase::assign`] when the source array cannot be
/// copied into the destination array.
#[derive(Debug, Error)]
pub enum AnyArrayAssignError {
    /// The source and destination arrays have different dtypes.
    #[error("cannot assign from mismatching {src} array to {dst}")]
    DTypeMismatch {
        /// The dtype of the source array.
        src: AnyArrayDType,
        /// The dtype of the destination array.
        dst: AnyArrayDType,
    },
    /// The source and destination arrays have the same dtype but different
    /// shapes.
    #[error("cannot assign from array of shape {src:?} to one of shape {dst:?}")]
    ShapeMismatch {
        /// The shape of the source array.
        src: Vec<usize>,
        /// The shape of the destination array.
        dst: Vec<usize>,
    },
}
813
// Private module holding the supertraits that seal the public `ArrayDType` and
// `ArrayDataMutExt` traits: since these traits cannot be named outside this
// crate, no other crate can implement the public traits.
mod sealed {
    // Supertrait of `ArrayDType`. It is implemented by the `array_dtype!`
    // macro, and the `Copy` bound lets element values be copied freely.
    pub trait SealedArrayDType: Copy {}

    // Supertrait of `ArrayDataMutExt`, implemented for mutable ndarray arrays.
    pub trait SealedArrayDataMutExt {}
}