1use std::{
22 borrow::Cow,
23 fmt::{self, Debug},
24};
25
26use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, IxDyn};
27use numcodecs::{
28 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
29 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
30};
31use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
32use serde::{
33 de::{MapAccess, Visitor},
34 ser::SerializeMap,
35 Deserialize, Deserializer, Serialize, Serializer,
36};
37use thiserror::Error;
38
/// Codec that swizzles (permutes) the axes of an array and reshapes it by
/// merging groups of axes into one axis or by inserting new length-1 axes.
///
/// The codec does not store the original shape of the array. An encoding
/// that merged or inserted axes can therefore only be decoded when an output
/// array is provided (see [`Codec::decode_into`]). Pure axis permutations can
/// always be decoded.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct SwizzleReshapeCodec {
    /// The axis groups that describe the encoded array's axes, in order.
    ///
    /// - Each [`AxisGroup::Group`] becomes one encoded axis. It merges the
    ///   listed input axes, and an empty group inserts a new axis of length 1.
    /// - An [`AxisGroup::AllRest`] marker expands into all input axes that are
    ///   not listed explicitly. Each of them is kept as a separate encoded axis.
    ///
    /// Every input axis must be covered exactly once, and at most one
    /// rest-axes marker may be used.
    pub axes: Vec<AxisGroup>,
    /// The codec's encoding format version. It is serialized as `_version`
    /// and takes its default value when absent.
    #[serde(default, rename = "_version")]
    pub version: StaticCodecVersion<1, 0, 0>,
}
70
/// Describes one axis of the encoded array, or a marker that expands into
/// several of them.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[serde(deny_unknown_fields)]
pub enum AxisGroup {
    /// Merge the listed axes, in order, into a single encoded axis whose
    /// length is the product of their lengths (serialized as a list)
    Group(Vec<Axis>),
    /// Keep all axes that are not listed elsewhere as separate encoded axes,
    /// in ascending order (serialized as an empty map, `{}`)
    AllRest(Rest),
}
81
/// An axis reference inside an [`AxisGroup::Group`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[serde(deny_unknown_fields)]
pub enum Axis {
    /// The input axis with this index (serialized as an integer)
    Index(usize),
    /// All input axes that are not listed elsewhere, in ascending order,
    /// merged into the enclosing group (serialized as an empty map, `{}`)
    MergedRest(Rest),
}
92
93impl Codec for SwizzleReshapeCodec {
94 type Error = SwizzleReshapeCodecError;
95
96 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
97 match data {
98 AnyCowArray::U8(data) => Ok(AnyArray::U8(swizzle_reshape(data, &self.axes)?)),
99 AnyCowArray::U16(data) => Ok(AnyArray::U16(swizzle_reshape(data, &self.axes)?)),
100 AnyCowArray::U32(data) => Ok(AnyArray::U32(swizzle_reshape(data, &self.axes)?)),
101 AnyCowArray::U64(data) => Ok(AnyArray::U64(swizzle_reshape(data, &self.axes)?)),
102 AnyCowArray::I8(data) => Ok(AnyArray::I8(swizzle_reshape(data, &self.axes)?)),
103 AnyCowArray::I16(data) => Ok(AnyArray::I16(swizzle_reshape(data, &self.axes)?)),
104 AnyCowArray::I32(data) => Ok(AnyArray::I32(swizzle_reshape(data, &self.axes)?)),
105 AnyCowArray::I64(data) => Ok(AnyArray::I64(swizzle_reshape(data, &self.axes)?)),
106 AnyCowArray::F32(data) => Ok(AnyArray::F32(swizzle_reshape(data, &self.axes)?)),
107 AnyCowArray::F64(data) => Ok(AnyArray::F64(swizzle_reshape(data, &self.axes)?)),
108 data => Err(SwizzleReshapeCodecError::UnsupportedDtype(data.dtype())),
109 }
110 }
111
112 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
113 match encoded {
114 AnyCowArray::U8(encoded) => {
115 Ok(AnyArray::U8(undo_swizzle_reshape(encoded, &self.axes)?))
116 }
117 AnyCowArray::U16(encoded) => {
118 Ok(AnyArray::U16(undo_swizzle_reshape(encoded, &self.axes)?))
119 }
120 AnyCowArray::U32(encoded) => {
121 Ok(AnyArray::U32(undo_swizzle_reshape(encoded, &self.axes)?))
122 }
123 AnyCowArray::U64(encoded) => {
124 Ok(AnyArray::U64(undo_swizzle_reshape(encoded, &self.axes)?))
125 }
126 AnyCowArray::I8(encoded) => {
127 Ok(AnyArray::I8(undo_swizzle_reshape(encoded, &self.axes)?))
128 }
129 AnyCowArray::I16(encoded) => {
130 Ok(AnyArray::I16(undo_swizzle_reshape(encoded, &self.axes)?))
131 }
132 AnyCowArray::I32(encoded) => {
133 Ok(AnyArray::I32(undo_swizzle_reshape(encoded, &self.axes)?))
134 }
135 AnyCowArray::I64(encoded) => {
136 Ok(AnyArray::I64(undo_swizzle_reshape(encoded, &self.axes)?))
137 }
138 AnyCowArray::F32(encoded) => {
139 Ok(AnyArray::F32(undo_swizzle_reshape(encoded, &self.axes)?))
140 }
141 AnyCowArray::F64(encoded) => {
142 Ok(AnyArray::F64(undo_swizzle_reshape(encoded, &self.axes)?))
143 }
144 encoded => Err(SwizzleReshapeCodecError::UnsupportedDtype(encoded.dtype())),
145 }
146 }
147
148 fn decode_into(
149 &self,
150 encoded: AnyArrayView,
151 decoded: AnyArrayViewMut,
152 ) -> Result<(), Self::Error> {
153 match (encoded, decoded) {
154 (AnyArrayView::U8(encoded), AnyArrayViewMut::U8(decoded)) => {
155 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
156 }
157 (AnyArrayView::U16(encoded), AnyArrayViewMut::U16(decoded)) => {
158 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
159 }
160 (AnyArrayView::U32(encoded), AnyArrayViewMut::U32(decoded)) => {
161 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
162 }
163 (AnyArrayView::U64(encoded), AnyArrayViewMut::U64(decoded)) => {
164 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
165 }
166 (AnyArrayView::I8(encoded), AnyArrayViewMut::I8(decoded)) => {
167 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
168 }
169 (AnyArrayView::I16(encoded), AnyArrayViewMut::I16(decoded)) => {
170 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
171 }
172 (AnyArrayView::I32(encoded), AnyArrayViewMut::I32(decoded)) => {
173 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
174 }
175 (AnyArrayView::I64(encoded), AnyArrayViewMut::I64(decoded)) => {
176 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
177 }
178 (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
179 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
180 }
181 (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
182 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
183 }
184 (encoded, decoded) if encoded.dtype() != decoded.dtype() => {
185 Err(SwizzleReshapeCodecError::MismatchedDecodeIntoArray {
186 source: AnyArrayAssignError::DTypeMismatch {
187 src: encoded.dtype(),
188 dst: decoded.dtype(),
189 },
190 })
191 }
192 (encoded, _decoded) => Err(SwizzleReshapeCodecError::UnsupportedDtype(encoded.dtype())),
193 }
194 }
195}
196
197impl StaticCodec for SwizzleReshapeCodec {
198 const CODEC_ID: &'static str = "swizzle-reshape.rs";
199
200 type Config<'de> = Self;
201
202 fn from_config(config: Self::Config<'_>) -> Self {
203 config
204 }
205
206 fn get_config(&self) -> StaticCodecConfig<Self> {
207 StaticCodecConfig::from(self)
208 }
209}
210
/// Errors that may occur when applying the [`SwizzleReshapeCodec`].
#[derive(Debug, Error)]
pub enum SwizzleReshapeCodecError {
    /// [`SwizzleReshapeCodec`] does not support the dtype
    #[error("SwizzleReshape does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`SwizzleReshapeCodec`] cannot decode an encoding in which axes were
    /// merged (or inserted) unless it is given an output array to decode
    /// into, because the original shape is not stored
    #[error("SwizzleReshape cannot decode from an array with merged axes without receiving an output array to decode into")]
    CannotDecodeMergedAxes,
    /// [`SwizzleReshapeCodec`] was configured with an axis index that is out
    /// of bounds for the array
    #[error("SwizzleReshape cannot encode or decode with an invalid axis {index} for an array with {ndim} dimensions")]
    InvalidAxisIndex {
        /// The out-of-bounds axis index
        index: usize,
        /// The number of dimensions of the array
        ndim: usize,
    },
    /// [`SwizzleReshapeCodec`] was configured with axis groups that do not
    /// reference every axis of the array exactly once (or at most once when a
    /// rest-axes marker picks up the remaining axes)
    #[error("SwizzleReshape can only encode or decode with an axis permutation {axes:?} that contains every axis of an array with {ndim} dimensions index exactly once")]
    InvalidAxisPermutation {
        /// The invalid axis groups
        axes: Vec<AxisGroup>,
        /// The number of dimensions of the array
        ndim: usize,
    },
    /// [`SwizzleReshapeCodec`] was configured with more than one rest-axes
    /// marker
    #[error("SwizzleReshape cannot encode or decode with an axis permutation that contains multiple rest-axes markers")]
    MultipleRestAxes,
    /// [`SwizzleReshapeCodec`] cannot decode into the provided array, e.g.
    /// because its dtype or its shape does not match
    #[error("SwizzleReshape cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
252
253#[expect(clippy::missing_panics_doc)]
254pub fn swizzle_reshape<T: Copy, S: Data<Elem = T>>(
266 data: ArrayBase<S, IxDyn>,
267 axes: &[AxisGroup],
268) -> Result<Array<T, IxDyn>, SwizzleReshapeCodecError> {
269 let SwizzleReshapeAxes {
270 permutation,
271 swizzled_shape,
272 new_shape,
273 } = validate_into_axes_shape(&data, axes)?;
274
275 let swizzled: ArrayBase<S, ndarray::Dim<ndarray::IxDynImpl>> = data.permuted_axes(permutation);
276 assert_eq!(swizzled.shape(), swizzled_shape, "incorrect swizzled shape");
277
278 #[expect(clippy::expect_used)] let reshaped = swizzled
280 .into_owned()
281 .into_shape_clone(new_shape)
282 .expect("new encoding shape should have the correct number of elements");
283
284 Ok(reshaped)
285}
286
287pub fn undo_swizzle_reshape<T: Copy, S: Data<Elem = T>>(
305 encoded: ArrayBase<S, IxDyn>,
306 axes: &[AxisGroup],
307) -> Result<Array<T, IxDyn>, SwizzleReshapeCodecError> {
308 if !axes.iter().all(|axis| match axis {
309 AxisGroup::Group(axes) => matches!(axes.as_slice(), [Axis::Index(_)]),
310 AxisGroup::AllRest(Rest) => true,
311 }) {
312 return Err(SwizzleReshapeCodecError::CannotDecodeMergedAxes);
313 }
314
315 let SwizzleReshapeAxes { permutation, .. } = validate_into_axes_shape(&encoded, axes)?;
316
317 let mut inverse_permutation = vec![0; permutation.len()];
318 #[expect(clippy::indexing_slicing)] for (i, p) in permutation.into_iter().enumerate() {
320 inverse_permutation[p] = i;
321 }
322
323 let unshaped = encoded;
325 let unswizzled = unshaped.permuted_axes(inverse_permutation);
326
327 Ok(unswizzled.into_owned())
328}
329
330#[expect(clippy::missing_panics_doc)]
331#[expect(clippy::needless_pass_by_value)]
332pub fn undo_swizzle_reshape_into<T: Copy>(
348 encoded: ArrayView<T, IxDyn>,
349 mut decoded: ArrayViewMut<T, IxDyn>,
350 axes: &[AxisGroup],
351) -> Result<(), SwizzleReshapeCodecError> {
352 let SwizzleReshapeAxes {
353 permutation,
354 swizzled_shape,
355 new_shape,
356 } = validate_into_axes_shape(&decoded, axes)?;
357
358 if encoded.shape() != new_shape {
359 return Err(SwizzleReshapeCodecError::MismatchedDecodeIntoArray {
360 source: AnyArrayAssignError::ShapeMismatch {
361 src: encoded.shape().to_vec(),
362 dst: new_shape,
363 },
364 });
365 }
366
367 let mut inverse_permutation = vec![0; decoded.ndim()];
368 #[expect(clippy::indexing_slicing)] for (i, p) in permutation.into_iter().enumerate() {
370 inverse_permutation[p] = i;
371 }
372
373 #[expect(clippy::expect_used)] let unshaped = encoded
375 .to_shape(swizzled_shape)
376 .expect("new decoding shape should have the correct number of elements");
377 let unswizzled = unshaped.permuted_axes(inverse_permutation);
378
379 decoded.assign(&unswizzled);
380
381 Ok(())
382}
383
/// Axis permutation and shapes computed by [`validate_into_axes_shape`].
struct SwizzleReshapeAxes {
    /// The input axis indices, in the order in which they appear in the
    /// swizzled array
    permutation: Vec<usize>,
    /// The shape of the input array after permuting its axes by `permutation`
    swizzled_shape: Vec<usize>,
    /// The shape of the encoded array, with one entry per encoded axis
    new_shape: Vec<usize>,
}
389
390fn validate_into_axes_shape<T, S: Data<Elem = T>>(
391 array: &ArrayBase<S, IxDyn>,
392 axes: &[AxisGroup],
393) -> Result<SwizzleReshapeAxes, SwizzleReshapeCodecError> {
394 let mut axis_index_counts = vec![0_usize; array.ndim()];
397
398 let mut has_rest = false;
399
400 for group in axes {
403 match group {
404 AxisGroup::Group(axes) => {
405 for axis in axes {
406 match axis {
407 Axis::Index(index) => {
408 if let Some(c) = axis_index_counts.get_mut(*index) {
409 *c += 1;
410 } else {
411 return Err(SwizzleReshapeCodecError::InvalidAxisIndex {
412 index: *index,
413 ndim: array.ndim(),
414 });
415 }
416 }
417 Axis::MergedRest(Rest) => {
418 if std::mem::replace(&mut has_rest, true) {
419 return Err(SwizzleReshapeCodecError::MultipleRestAxes);
420 }
421 }
422 }
423 }
424 }
425 AxisGroup::AllRest(Rest) => {
426 if std::mem::replace(&mut has_rest, true) {
427 return Err(SwizzleReshapeCodecError::MultipleRestAxes);
428 }
429 }
430 }
431 }
432
433 if !axis_index_counts
437 .iter()
438 .all(|c| if has_rest { *c <= 1 } else { *c == 1 })
439 {
440 return Err(SwizzleReshapeCodecError::InvalidAxisPermutation {
441 axes: axes.to_vec(),
442 ndim: array.ndim(),
443 });
444 }
445
446 let mut axis_permutation = Vec::with_capacity(array.len());
448 let mut permuted_shape = Vec::with_capacity(array.len());
450 let mut grouped_shape = Vec::with_capacity(axes.len());
452
453 for axis in axes {
454 match axis {
455 AxisGroup::Group(axes) => {
458 let mut new_len = 1;
459 for axis in axes {
460 match axis {
461 Axis::Index(index) => {
462 axis_permutation.push(*index);
463 permuted_shape.push(array.len_of(ndarray::Axis(*index)));
464 new_len *= array.len_of(ndarray::Axis(*index));
465 }
466 Axis::MergedRest(Rest) => {
467 for (index, count) in axis_index_counts.iter().enumerate() {
468 if *count == 0 {
469 axis_permutation.push(index);
470 permuted_shape.push(array.len_of(ndarray::Axis(index)));
471 new_len *= array.len_of(ndarray::Axis(index));
472 }
473 }
474 }
475 }
476 }
477 grouped_shape.push(new_len);
478 }
479 AxisGroup::AllRest(Rest) => {
480 for (index, count) in axis_index_counts.iter().enumerate() {
481 if *count == 0 {
482 axis_permutation.push(index);
483 permuted_shape.push(array.len_of(ndarray::Axis(index)));
484 grouped_shape.push(array.len_of(ndarray::Axis(index)));
485 }
486 }
487 }
488 }
489 }
490
491 Ok(SwizzleReshapeAxes {
492 permutation: axis_permutation,
493 swizzled_shape: permuted_shape,
494 new_shape: grouped_shape,
495 })
496}
497
/// Marker for all axes that are not explicitly listed in the axis groups.
///
/// It is (de)serialized as an empty map, e.g. `{}` in JSON.
#[derive(Copy, Clone, Debug)]
pub struct Rest;
501
502impl Serialize for Rest {
503 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
504 serializer.serialize_map(Some(0))?.end()
505 }
506}
507
508impl<'de> Deserialize<'de> for Rest {
509 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
510 struct RestVisitor;
511
512 impl<'de> Visitor<'de> for RestVisitor {
513 type Value = Rest;
514
515 fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
516 fmt.write_str("an empty map")
517 }
518
519 fn visit_map<A: MapAccess<'de>>(self, _map: A) -> Result<Self::Value, A::Error> {
520 Ok(Rest)
521 }
522 }
523
524 deserializer.deserialize_map(RestVisitor)
525 }
526}
527
528impl JsonSchema for Rest {
529 fn schema_name() -> Cow<'static, str> {
530 Cow::Borrowed("Rest")
531 }
532
533 fn schema_id() -> Cow<'static, str> {
534 Cow::Borrowed(concat!(module_path!(), "::", "Rest"))
535 }
536
537 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
538 json_schema!({
539 "type": "object",
540 "properties": {},
541 "additionalProperties": false,
542 })
543 }
544}
545
#[cfg(test)]
#[expect(clippy::expect_used)]
mod tests {
    use ndarray::array;

    use super::*;

    /// The 2x2x2 array used by most round-trip cases.
    fn cube() -> Array<i32, IxDyn> {
        array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn()
    }

    /// A large 5D array filled with 1, 2, 3, ... in the order in which
    /// `Array::from_shape_fn` visits its elements.
    fn counting() -> Array<i32, IxDyn> {
        let mut counter = 0;
        Array::from_shape_fn([3, 1440, 721, 1, 1], |_| {
            counter += 1;
            counter
        })
        .into_dyn()
    }

    /// A group that contains only the axis with the given `index`.
    fn single(index: usize) -> AxisGroup {
        AxisGroup::Group(vec![Axis::Index(index)])
    }

    /// A group that merges the listed axes, in order (empty: a new axis).
    fn merged(indices: &[usize]) -> AxisGroup {
        AxisGroup::Group(indices.iter().copied().map(Axis::Index).collect())
    }

    /// A group that merges all remaining axes into one.
    fn merged_rest() -> AxisGroup {
        AxisGroup::Group(vec![Axis::MergedRest(Rest)])
    }

    /// All remaining axes, each kept as a separate axis.
    fn all_rest() -> AxisGroup {
        AxisGroup::AllRest(Rest)
    }

    #[test]
    fn identity() {
        roundtrip(cube(), &[all_rest()], &[2, 2, 2]);

        roundtrip(cube(), &[single(0), single(1), all_rest()], &[2, 2, 2]);
        roundtrip(cube(), &[single(0), all_rest(), single(2)], &[2, 2, 2]);
        roundtrip(cube(), &[all_rest(), single(1), single(2)], &[2, 2, 2]);

        roundtrip(cube(), &[single(0), single(1), merged_rest()], &[2, 2, 2]);
        roundtrip(cube(), &[single(0), merged_rest(), single(2)], &[2, 2, 2]);
        roundtrip(cube(), &[merged_rest(), single(1), single(2)], &[2, 2, 2]);
    }

    #[test]
    fn swizzle() {
        roundtrip(cube(), &[single(0), single(1), all_rest()], &[2, 2, 2]);
        roundtrip(cube(), &[single(2), all_rest(), single(1)], &[2, 2, 2]);
        roundtrip(cube(), &[all_rest(), single(0), single(1)], &[2, 2, 2]);

        roundtrip(cube(), &[single(0), single(1), merged_rest()], &[2, 2, 2]);
        roundtrip(cube(), &[single(2), merged_rest(), single(1)], &[2, 2, 2]);
        roundtrip(cube(), &[merged_rest(), single(0), single(1)], &[2, 2, 2]);

        roundtrip(
            counting(),
            &[single(4), single(0), single(3), single(2), single(1)],
            &[1, 3, 1, 721, 1440],
        );
    }

    #[test]
    fn collapse() {
        roundtrip(cube(), &[merged_rest()], &[8]);

        roundtrip(cube(), &[merged(&[0, 1, 2])], &[8]);
        roundtrip(cube(), &[merged(&[2, 1, 0])], &[8]);
        roundtrip(
            cube(),
            &[AxisGroup::Group(vec![Axis::Index(1), Axis::MergedRest(Rest)])],
            &[8],
        );

        roundtrip(cube(), &[merged(&[0, 1]), all_rest()], &[4, 2]);
        roundtrip(cube(), &[single(2), merged(&[1, 0])], &[2, 4]);
        roundtrip(
            cube(),
            &[
                AxisGroup::Group(vec![Axis::Index(1), Axis::MergedRest(Rest)]),
                merged(&[0, 2]),
            ],
            &[2, 4],
        );

        roundtrip(
            array![[1, 2], [3, 4], [5, 6], [7, 8]].into_dyn(),
            &[merged_rest()],
            &[8],
        );

        roundtrip(
            counting(),
            &[merged(&[4, 0, 3, 2, 1])],
            &[1 * 3 * 1 * 721 * 1440],
        );
    }

    #[test]
    fn extend() {
        roundtrip(
            cube(),
            &[
                merged(&[]),
                single(0),
                merged(&[]),
                all_rest(),
                merged(&[]),
                single(2),
                merged(&[]),
            ],
            &[1, 2, 1, 2, 1, 2, 1],
        );
    }

    /// Encodes `data`, checks the encoded shape, and decodes it both into an
    /// output array and (when possible) without one.
    #[expect(clippy::needless_pass_by_value)]
    fn roundtrip(data: Array<i32, IxDyn>, axes: &[AxisGroup], swizzle_shape: &[usize]) {
        let encoded = swizzle_reshape(data.view(), axes).expect("swizzle should not fail");
        assert_eq!(encoded.shape(), swizzle_shape);

        let mut decoded = Array::zeros(data.shape());
        undo_swizzle_reshape_into(encoded.view(), decoded.view_mut(), axes)
            .expect("unswizzle into should not fail");
        assert_eq!(data, decoded);

        // Decoding without an output array only works if every group is
        // exactly one explicit axis index.
        let merges_axes = axes.iter().any(|group| match group {
            AxisGroup::Group(group_axes) => {
                group_axes.len() != 1
                    || group_axes
                        .iter()
                        .any(|axis| matches!(axis, Axis::MergedRest(Rest)))
            }
            AxisGroup::AllRest(Rest) => false,
        });

        if merges_axes {
            undo_swizzle_reshape(encoded.view(), axes).expect_err("unswizzle should fail");
        } else {
            let decoded =
                undo_swizzle_reshape(encoded.view(), axes).expect("unswizzle should not fail");
            assert_eq!(data, decoded);
        }
    }
}