1use std::{
22 borrow::Cow,
23 fmt::{self, Debug},
24};
25
26use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, IxDyn};
27use numcodecs::{
28 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
29 Codec, StaticCodec, StaticCodecConfig,
30};
31use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
32use serde::{
33 de::{MapAccess, Visitor},
34 ser::SerializeMap,
35 Deserialize, Deserializer, Serialize, Serializer,
36};
37use thiserror::Error;
38
// Codec that swizzles (permutes) the axes of an array and reshapes it by
// merging groups of axes.
//
// `encode` applies `swizzle_reshape` with `axes`; `decode` and `decode_into`
// undo it with the same `axes` via `undo_swizzle_reshape` and
// `undo_swizzle_reshape_into`.
//
// NOTE(review): these are plain `//` comments on purpose, since schemars'
// `JsonSchema` derive copies `///` doc comments into the generated schema's
// descriptions; promote them to `///` if that schema change is wanted.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct SwizzleReshapeCodec {
    // The axis groups applied on encoding, in the order of the encoded axes.
    // Every `AxisGroup::Group` produces exactly one encoded axis, while an
    // `AxisGroup::AllRest` marker produces one encoded axis per axis that is
    // not named by index anywhere in the list.
    pub axes: Vec<AxisGroup>,
}
67
// One entry of the codec's `axes` list.
//
// The enum is untagged: a `Group` is (de)serialized as a sequence of `Axis`
// values, `AllRest` as whatever `Rest` (de)serializes as (an empty map).
//
// NOTE(review): plain `//` comments are used because schemars' `JsonSchema`
// derive would copy `///` doc comments into the generated schema.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[serde(deny_unknown_fields)]
pub enum AxisGroup {
    // Axes that are merged into a single encoded axis whose length is the
    // product of the member axes' lengths; an empty group yields an encoded
    // axis of length 1.
    Group(Vec<Axis>),
    // All axes that are not named by index anywhere in the list, each kept as
    // its own encoded axis, in ascending axis-index order.
    AllRest(Rest),
}
78
// One member of an `AxisGroup::Group`.
//
// The enum is untagged: `Index` is (de)serialized as a bare unsigned integer,
// `MergedRest` as whatever `Rest` (de)serializes as (an empty map).
//
// NOTE(review): plain `//` comments are used because schemars' `JsonSchema`
// derive would copy `///` doc comments into the generated schema.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[serde(deny_unknown_fields)]
pub enum Axis {
    // A single axis of the unencoded array, identified by its index; indices
    // that are not smaller than the array's number of dimensions are rejected
    // with `SwizzleReshapeCodecError::InvalidAxisIndex`.
    Index(usize),
    // All axes that are not named by index anywhere in the list, merged into
    // the group that contains this marker, in ascending axis-index order.
    MergedRest(Rest),
}
89
90impl Codec for SwizzleReshapeCodec {
91 type Error = SwizzleReshapeCodecError;
92
93 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
94 match data {
95 AnyCowArray::U8(data) => Ok(AnyArray::U8(swizzle_reshape(data, &self.axes)?)),
96 AnyCowArray::U16(data) => Ok(AnyArray::U16(swizzle_reshape(data, &self.axes)?)),
97 AnyCowArray::U32(data) => Ok(AnyArray::U32(swizzle_reshape(data, &self.axes)?)),
98 AnyCowArray::U64(data) => Ok(AnyArray::U64(swizzle_reshape(data, &self.axes)?)),
99 AnyCowArray::I8(data) => Ok(AnyArray::I8(swizzle_reshape(data, &self.axes)?)),
100 AnyCowArray::I16(data) => Ok(AnyArray::I16(swizzle_reshape(data, &self.axes)?)),
101 AnyCowArray::I32(data) => Ok(AnyArray::I32(swizzle_reshape(data, &self.axes)?)),
102 AnyCowArray::I64(data) => Ok(AnyArray::I64(swizzle_reshape(data, &self.axes)?)),
103 AnyCowArray::F32(data) => Ok(AnyArray::F32(swizzle_reshape(data, &self.axes)?)),
104 AnyCowArray::F64(data) => Ok(AnyArray::F64(swizzle_reshape(data, &self.axes)?)),
105 data => Err(SwizzleReshapeCodecError::UnsupportedDtype(data.dtype())),
106 }
107 }
108
109 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
110 match encoded {
111 AnyCowArray::U8(encoded) => {
112 Ok(AnyArray::U8(undo_swizzle_reshape(encoded, &self.axes)?))
113 }
114 AnyCowArray::U16(encoded) => {
115 Ok(AnyArray::U16(undo_swizzle_reshape(encoded, &self.axes)?))
116 }
117 AnyCowArray::U32(encoded) => {
118 Ok(AnyArray::U32(undo_swizzle_reshape(encoded, &self.axes)?))
119 }
120 AnyCowArray::U64(encoded) => {
121 Ok(AnyArray::U64(undo_swizzle_reshape(encoded, &self.axes)?))
122 }
123 AnyCowArray::I8(encoded) => {
124 Ok(AnyArray::I8(undo_swizzle_reshape(encoded, &self.axes)?))
125 }
126 AnyCowArray::I16(encoded) => {
127 Ok(AnyArray::I16(undo_swizzle_reshape(encoded, &self.axes)?))
128 }
129 AnyCowArray::I32(encoded) => {
130 Ok(AnyArray::I32(undo_swizzle_reshape(encoded, &self.axes)?))
131 }
132 AnyCowArray::I64(encoded) => {
133 Ok(AnyArray::I64(undo_swizzle_reshape(encoded, &self.axes)?))
134 }
135 AnyCowArray::F32(encoded) => {
136 Ok(AnyArray::F32(undo_swizzle_reshape(encoded, &self.axes)?))
137 }
138 AnyCowArray::F64(encoded) => {
139 Ok(AnyArray::F64(undo_swizzle_reshape(encoded, &self.axes)?))
140 }
141 encoded => Err(SwizzleReshapeCodecError::UnsupportedDtype(encoded.dtype())),
142 }
143 }
144
145 fn decode_into(
146 &self,
147 encoded: AnyArrayView,
148 decoded: AnyArrayViewMut,
149 ) -> Result<(), Self::Error> {
150 match (encoded, decoded) {
151 (AnyArrayView::U8(encoded), AnyArrayViewMut::U8(decoded)) => {
152 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
153 }
154 (AnyArrayView::U16(encoded), AnyArrayViewMut::U16(decoded)) => {
155 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
156 }
157 (AnyArrayView::U32(encoded), AnyArrayViewMut::U32(decoded)) => {
158 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
159 }
160 (AnyArrayView::U64(encoded), AnyArrayViewMut::U64(decoded)) => {
161 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
162 }
163 (AnyArrayView::I8(encoded), AnyArrayViewMut::I8(decoded)) => {
164 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
165 }
166 (AnyArrayView::I16(encoded), AnyArrayViewMut::I16(decoded)) => {
167 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
168 }
169 (AnyArrayView::I32(encoded), AnyArrayViewMut::I32(decoded)) => {
170 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
171 }
172 (AnyArrayView::I64(encoded), AnyArrayViewMut::I64(decoded)) => {
173 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
174 }
175 (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
176 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
177 }
178 (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
179 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
180 }
181 (encoded, decoded) if encoded.dtype() != decoded.dtype() => {
182 Err(SwizzleReshapeCodecError::MismatchedDecodeIntoArray {
183 source: AnyArrayAssignError::DTypeMismatch {
184 src: encoded.dtype(),
185 dst: decoded.dtype(),
186 },
187 })
188 }
189 (encoded, _decoded) => Err(SwizzleReshapeCodecError::UnsupportedDtype(encoded.dtype())),
190 }
191 }
192}
193
impl StaticCodec for SwizzleReshapeCodec {
    /// String identifier of this codec.
    const CODEC_ID: &'static str = "swizzle-reshape";

    /// The codec is its own configuration type.
    type Config<'de> = Self;

    /// Returns `config` unchanged, since the configuration is the codec.
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    /// Wraps a reference to this codec as its configuration.
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
207
/// Errors that may occur when applying the [`SwizzleReshapeCodec`].
#[derive(Debug, Error)]
pub enum SwizzleReshapeCodecError {
    /// The array's dtype is not one of the integer or float dtypes that the
    /// codec dispatches on.
    #[error("SwizzleReshape does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`undo_swizzle_reshape`] was called with axes in which some group is
    /// not exactly one [`Axis::Index`], i.e. the original shape cannot be
    /// reconstructed without an output array.
    #[error("SwizzleReshape cannot decode from an array with merged axes without receiving an output array to decode into")]
    CannotDecodeMergedAxes,
    /// An [`Axis::Index`] is out of bounds for the array.
    #[error("SwizzleReshape cannot encode or decode with an invalid axis {index} for an array with {ndim} dimensions")]
    InvalidAxisIndex {
        /// The out-of-bounds axis index.
        index: usize,
        /// The number of dimensions of the array.
        ndim: usize,
    },
    /// Some axis index is named more than once, or, when no rest marker is
    /// present, some axis of the array is not named at all.
    #[error("SwizzleReshape can only encode or decode with an axis permutation {axes:?} that contains every axis of an array with {ndim} dimensions index exactly once")]
    InvalidAxisPermutation {
        /// The rejected axis groups.
        axes: Vec<AxisGroup>,
        /// The number of dimensions of the array.
        ndim: usize,
    },
    /// More than one rest marker ([`AxisGroup::AllRest`] or
    /// [`Axis::MergedRest`]) appears in the axis groups.
    #[error("SwizzleReshape cannot encode or decode with an axis permutation that contains multiple rest-axes markers")]
    MultipleRestAxes,
    /// The encoded array and the output array to decode into disagree in
    /// dtype or shape.
    #[error("SwizzleReshape cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The underlying dtype or shape mismatch.
        #[from]
        source: AnyArrayAssignError,
    },
}
249
250#[expect(clippy::missing_panics_doc)]
251pub fn swizzle_reshape<T: Copy, S: Data<Elem = T>>(
263 data: ArrayBase<S, IxDyn>,
264 axes: &[AxisGroup],
265) -> Result<Array<T, IxDyn>, SwizzleReshapeCodecError> {
266 let SwizzleReshapeAxes {
267 permutation,
268 swizzled_shape,
269 new_shape,
270 } = validate_into_axes_shape(&data, axes)?;
271
272 let swizzled: ArrayBase<S, ndarray::Dim<ndarray::IxDynImpl>> = data.permuted_axes(permutation);
273 assert_eq!(swizzled.shape(), swizzled_shape, "incorrect swizzled shape");
274
275 #[expect(clippy::expect_used)] let reshaped = swizzled
277 .into_owned()
278 .into_shape_clone(new_shape)
279 .expect("new encoding shape should have the correct number of elements");
280
281 Ok(reshaped)
282}
283
284pub fn undo_swizzle_reshape<T: Copy, S: Data<Elem = T>>(
302 encoded: ArrayBase<S, IxDyn>,
303 axes: &[AxisGroup],
304) -> Result<Array<T, IxDyn>, SwizzleReshapeCodecError> {
305 if !axes.iter().all(|axis| match axis {
306 AxisGroup::Group(axes) => matches!(axes.as_slice(), [Axis::Index(_)]),
307 AxisGroup::AllRest(Rest) => true,
308 }) {
309 return Err(SwizzleReshapeCodecError::CannotDecodeMergedAxes);
310 }
311
312 let SwizzleReshapeAxes { permutation, .. } = validate_into_axes_shape(&encoded, axes)?;
313
314 let mut inverse_permutation = vec![0; permutation.len()];
315 #[expect(clippy::indexing_slicing)] for (i, p) in permutation.into_iter().enumerate() {
317 inverse_permutation[p] = i;
318 }
319
320 let unshaped = encoded;
322 let unswizzled = unshaped.permuted_axes(inverse_permutation);
323
324 Ok(unswizzled.into_owned())
325}
326
327#[expect(clippy::missing_panics_doc)]
328#[expect(clippy::needless_pass_by_value)]
329pub fn undo_swizzle_reshape_into<T: Copy>(
345 encoded: ArrayView<T, IxDyn>,
346 mut decoded: ArrayViewMut<T, IxDyn>,
347 axes: &[AxisGroup],
348) -> Result<(), SwizzleReshapeCodecError> {
349 let SwizzleReshapeAxes {
350 permutation,
351 swizzled_shape,
352 new_shape,
353 } = validate_into_axes_shape(&decoded, axes)?;
354
355 if encoded.shape() != new_shape {
356 return Err(SwizzleReshapeCodecError::MismatchedDecodeIntoArray {
357 source: AnyArrayAssignError::ShapeMismatch {
358 src: encoded.shape().to_vec(),
359 dst: new_shape,
360 },
361 });
362 }
363
364 let mut inverse_permutation = vec![0; decoded.ndim()];
365 #[expect(clippy::indexing_slicing)] for (i, p) in permutation.into_iter().enumerate() {
367 inverse_permutation[p] = i;
368 }
369
370 #[expect(clippy::expect_used)] let unshaped = encoded
372 .to_shape(swizzled_shape)
373 .expect("new decoding shape should have the correct number of elements");
374 let unswizzled = unshaped.permuted_axes(inverse_permutation);
375
376 decoded.assign(&unswizzled);
377
378 Ok(())
379}
380
/// Result of validating axis groups against an array, as computed by
/// `validate_into_axes_shape`.
struct SwizzleReshapeAxes {
    /// For every axis of the swizzled array, the index of the axis of the
    /// validated array that it is taken from.
    permutation: Vec<usize>,
    /// Shape of the validated array after applying `permutation`, before any
    /// axes are merged.
    swizzled_shape: Vec<usize>,
    /// Shape after merging the axes of every group, i.e. the encoded shape.
    new_shape: Vec<usize>,
}
386
387fn validate_into_axes_shape<T, S: Data<Elem = T>>(
388 array: &ArrayBase<S, IxDyn>,
389 axes: &[AxisGroup],
390) -> Result<SwizzleReshapeAxes, SwizzleReshapeCodecError> {
391 let mut axis_index_counts = vec![0_usize; array.ndim()];
394
395 let mut has_rest = false;
396
397 for group in axes {
400 match group {
401 AxisGroup::Group(axes) => {
402 for axis in axes {
403 match axis {
404 Axis::Index(index) => {
405 if let Some(c) = axis_index_counts.get_mut(*index) {
406 *c += 1;
407 } else {
408 return Err(SwizzleReshapeCodecError::InvalidAxisIndex {
409 index: *index,
410 ndim: array.ndim(),
411 });
412 }
413 }
414 Axis::MergedRest(Rest) => {
415 if std::mem::replace(&mut has_rest, true) {
416 return Err(SwizzleReshapeCodecError::MultipleRestAxes);
417 }
418 }
419 }
420 }
421 }
422 AxisGroup::AllRest(Rest) => {
423 if std::mem::replace(&mut has_rest, true) {
424 return Err(SwizzleReshapeCodecError::MultipleRestAxes);
425 }
426 }
427 }
428 }
429
430 if !axis_index_counts
434 .iter()
435 .all(|c| if has_rest { *c <= 1 } else { *c == 1 })
436 {
437 return Err(SwizzleReshapeCodecError::InvalidAxisPermutation {
438 axes: axes.to_vec(),
439 ndim: array.ndim(),
440 });
441 }
442
443 let mut axis_permutation = Vec::with_capacity(array.len());
445 let mut permuted_shape = Vec::with_capacity(array.len());
447 let mut grouped_shape = Vec::with_capacity(axes.len());
449
450 for axis in axes {
451 match axis {
452 AxisGroup::Group(axes) => {
455 let mut new_len = 1;
456 for axis in axes {
457 match axis {
458 Axis::Index(index) => {
459 axis_permutation.push(*index);
460 permuted_shape.push(array.len_of(ndarray::Axis(*index)));
461 new_len *= array.len_of(ndarray::Axis(*index));
462 }
463 Axis::MergedRest(Rest) => {
464 for (index, count) in axis_index_counts.iter().enumerate() {
465 if *count == 0 {
466 axis_permutation.push(index);
467 permuted_shape.push(array.len_of(ndarray::Axis(index)));
468 new_len *= array.len_of(ndarray::Axis(index));
469 }
470 }
471 }
472 }
473 }
474 grouped_shape.push(new_len);
475 }
476 AxisGroup::AllRest(Rest) => {
477 for (index, count) in axis_index_counts.iter().enumerate() {
478 if *count == 0 {
479 axis_permutation.push(index);
480 permuted_shape.push(array.len_of(ndarray::Axis(index)));
481 grouped_shape.push(array.len_of(ndarray::Axis(index)));
482 }
483 }
484 }
485 }
486 }
487
488 Ok(SwizzleReshapeAxes {
489 permutation: axis_permutation,
490 swizzled_shape: permuted_shape,
491 new_shape: grouped_shape,
492 })
493}
494
/// Marker for all axes that are not named explicitly by index, used by
/// [`AxisGroup::AllRest`] and [`Axis::MergedRest`].
///
/// It is serialized as an empty map.
#[derive(Copy, Clone, Debug)]
pub struct Rest;
498
499impl Serialize for Rest {
500 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
501 serializer.serialize_map(Some(0))?.end()
502 }
503}
504
505impl<'de> Deserialize<'de> for Rest {
506 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
507 struct RestVisitor;
508
509 impl<'de> Visitor<'de> for RestVisitor {
510 type Value = Rest;
511
512 fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
513 fmt.write_str("an empty map")
514 }
515
516 fn visit_map<A: MapAccess<'de>>(self, _map: A) -> Result<Self::Value, A::Error> {
517 Ok(Rest)
518 }
519 }
520
521 deserializer.deserialize_map(RestVisitor)
522 }
523}
524
impl JsonSchema for Rest {
    /// The short schema name, `"Rest"`.
    fn schema_name() -> Cow<'static, str> {
        Cow::Borrowed("Rest")
    }

    /// The schema identifier: this module's path followed by `::Rest`.
    fn schema_id() -> Cow<'static, str> {
        Cow::Borrowed(concat!(module_path!(), "::", "Rest"))
    }

    /// Describes the marker as an object without any properties and without
    /// additional properties, i.e. an empty map.
    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
        json_schema!({
            "type": "object",
            "properties": {},
            "additionalProperties": false,
        })
    }
}
542
#[cfg(test)]
#[expect(clippy::expect_used)]
mod tests {
    use ndarray::array;

    use super::*;

    /// The 2x2x2 array with the values 1 to 8 used by most test cases.
    fn cube() -> Array<i32, IxDyn> {
        array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn()
    }

    /// A large 3x1440x721x1x1 array filled with a running counter.
    fn counting() -> Array<i32, IxDyn> {
        let mut next = 0;
        Array::from_shape_fn([3, 1440, 721, 1, 1], |_| {
            next += 1;
            next
        })
        .into_dyn()
    }

    /// A group that merges exactly the given axis indices, in order.
    fn group(indices: &[usize]) -> AxisGroup {
        AxisGroup::Group(indices.iter().copied().map(Axis::Index).collect())
    }

    /// A group that consists only of the merged-rest marker.
    fn merged_rest() -> AxisGroup {
        AxisGroup::Group(vec![Axis::MergedRest(Rest)])
    }

    #[test]
    fn identity() {
        roundtrip(cube(), &[AxisGroup::AllRest(Rest)], &[2, 2, 2]);

        roundtrip(
            cube(),
            &[group(&[0]), group(&[1]), AxisGroup::AllRest(Rest)],
            &[2, 2, 2],
        );
        roundtrip(
            cube(),
            &[group(&[0]), AxisGroup::AllRest(Rest), group(&[2])],
            &[2, 2, 2],
        );
        roundtrip(
            cube(),
            &[AxisGroup::AllRest(Rest), group(&[1]), group(&[2])],
            &[2, 2, 2],
        );

        roundtrip(
            cube(),
            &[group(&[0]), group(&[1]), merged_rest()],
            &[2, 2, 2],
        );
        roundtrip(
            cube(),
            &[group(&[0]), merged_rest(), group(&[2])],
            &[2, 2, 2],
        );
        roundtrip(
            cube(),
            &[merged_rest(), group(&[1]), group(&[2])],
            &[2, 2, 2],
        );
    }

    #[test]
    fn swizzle() {
        roundtrip(
            cube(),
            &[group(&[0]), group(&[1]), AxisGroup::AllRest(Rest)],
            &[2, 2, 2],
        );
        roundtrip(
            cube(),
            &[group(&[2]), AxisGroup::AllRest(Rest), group(&[1])],
            &[2, 2, 2],
        );
        roundtrip(
            cube(),
            &[AxisGroup::AllRest(Rest), group(&[0]), group(&[1])],
            &[2, 2, 2],
        );

        roundtrip(
            cube(),
            &[group(&[0]), group(&[1]), merged_rest()],
            &[2, 2, 2],
        );
        roundtrip(
            cube(),
            &[group(&[2]), merged_rest(), group(&[1])],
            &[2, 2, 2],
        );
        roundtrip(
            cube(),
            &[merged_rest(), group(&[0]), group(&[1])],
            &[2, 2, 2],
        );

        roundtrip(
            counting(),
            &[
                group(&[4]),
                group(&[0]),
                group(&[3]),
                group(&[2]),
                group(&[1]),
            ],
            &[1, 3, 1, 721, 1440],
        );
    }

    #[test]
    fn collapse() {
        roundtrip(cube(), &[merged_rest()], &[8]);

        roundtrip(cube(), &[group(&[0, 1, 2])], &[8]);
        roundtrip(cube(), &[group(&[2, 1, 0])], &[8]);
        roundtrip(
            cube(),
            &[AxisGroup::Group(vec![
                Axis::Index(1),
                Axis::MergedRest(Rest),
            ])],
            &[8],
        );

        roundtrip(
            cube(),
            &[group(&[0, 1]), AxisGroup::AllRest(Rest)],
            &[4, 2],
        );
        roundtrip(cube(), &[group(&[2]), group(&[1, 0])], &[2, 4]);
        roundtrip(
            cube(),
            &[
                AxisGroup::Group(vec![Axis::Index(1), Axis::MergedRest(Rest)]),
                group(&[0, 2]),
            ],
            &[2, 4],
        );

        roundtrip(
            array![[1, 2], [3, 4], [5, 6], [7, 8]].into_dyn(),
            &[merged_rest()],
            &[8],
        );

        roundtrip(
            counting(),
            &[group(&[4, 0, 3, 2, 1])],
            &[1 * 3 * 1 * 721 * 1440],
        );
    }

    #[test]
    fn extend() {
        roundtrip(
            cube(),
            &[
                group(&[]),
                group(&[0]),
                group(&[]),
                AxisGroup::AllRest(Rest),
                group(&[]),
                group(&[2]),
                group(&[]),
            ],
            &[1, 2, 1, 2, 1, 2, 1],
        );
    }

    /// Encodes `data` with `axes`, checks the encoded shape, and checks that
    /// decoding restores `data`. Decoding into a pre-shaped array must always
    /// succeed, while decoding without one must fail exactly when `axes`
    /// contains a group that is not a single axis index.
    #[expect(clippy::needless_pass_by_value)]
    fn roundtrip(data: Array<i32, IxDyn>, axes: &[AxisGroup], swizzle_shape: &[usize]) {
        let encoded = swizzle_reshape(data.view(), axes).expect("swizzle should not fail");

        assert_eq!(encoded.shape(), swizzle_shape);

        let mut decoded_into = Array::zeros(data.shape());
        undo_swizzle_reshape_into(encoded.view(), decoded_into.view_mut(), axes)
            .expect("unswizzle into should not fail");

        assert_eq!(data, decoded_into);

        let merges_axes = axes.iter().any(|group| match group {
            AxisGroup::Group(members) => {
                members.len() != 1
                    || members
                        .iter()
                        .any(|axis| matches!(axis, Axis::MergedRest(Rest)))
            }
            AxisGroup::AllRest(Rest) => false,
        });

        if merges_axes {
            undo_swizzle_reshape(encoded.view(), axes).expect_err("unswizzle should fail");
        } else {
            let decoded =
                undo_swizzle_reshape(encoded.view(), axes).expect("unswizzle should not fail");
            assert_eq!(data, decoded);
        }
    }
}