1use std::{
22 borrow::Cow,
23 fmt::{self, Debug},
24};
25
26use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, IxDyn};
27use numcodecs::{
28 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
29 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
30};
31use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
32use serde::{
33 Deserialize, Deserializer, Serialize, Serializer,
34 de::{MapAccess, Visitor},
35 ser::SerializeMap,
36};
37use thiserror::Error;
38
/// Codec to swizzle (permute) the axes of an array and then reshape it.
///
/// On encoding, the axes of the array are reordered and grouped as described
/// by `axes`, and each [`AxisGroup::Group`] is merged into a single axis of the
/// encoded array. Decoding into an output array of the original shape reverts
/// this. Decoding without an output array is only supported when no axes are
/// merged or inserted, i.e. when every group contains exactly one axis index.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct SwizzleReshapeCodec {
    /// The axis groups of the encoded array, in order.
    ///
    /// Each axis of the input array may be referenced by index at most once.
    /// At most one rest marker ([`AxisGroup::AllRest`] or [`Axis::MergedRest`])
    /// may be used; it stands for all axes that are not referenced by index.
    /// Without a rest marker, every axis must be referenced exactly once.
    pub axes: Vec<AxisGroup>,
    /// The version of the codec's configuration format, (de)serialized under
    /// the `_version` key and falling back to its default when absent.
    #[serde(default, rename = "_version")]
    pub version: StaticCodecVersion<1, 0, 0>,
}
70
/// A group of input axes that becomes one axis of the encoded array, or (for
/// [`AxisGroup::AllRest`]) a marker for all remaining axes, each kept separate.
///
/// (De)serialized untagged: a list of [`Axis`] values for [`AxisGroup::Group`],
/// or an empty map (`{}` in JSON) for [`AxisGroup::AllRest`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[serde(deny_unknown_fields)]
pub enum AxisGroup {
    /// The listed axes, in order, are merged into a single axis whose length
    /// is the product of their lengths. An empty group inserts a new axis of
    /// length 1.
    Group(Vec<Axis>),
    /// Every axis that is not referenced by index in any group is kept as its
    /// own axis, in increasing axis order.
    AllRest(Rest),
}
81
/// A member of an [`AxisGroup::Group`]: either a single input axis or a marker
/// for all remaining axes.
///
/// (De)serialized untagged: an integer axis index for [`Axis::Index`], or an
/// empty map (`{}` in JSON) for [`Axis::MergedRest`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[serde(deny_unknown_fields)]
pub enum Axis {
    /// The axis with this index in the input (decoded) array.
    Index(usize),
    /// Every axis that is not referenced by index in any group, in increasing
    /// axis order, merged into the enclosing group.
    MergedRest(Rest),
}
92
93impl Codec for SwizzleReshapeCodec {
94 type Error = SwizzleReshapeCodecError;
95
96 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
97 match data {
98 AnyCowArray::U8(data) => Ok(AnyArray::U8(swizzle_reshape(data, &self.axes)?)),
99 AnyCowArray::U16(data) => Ok(AnyArray::U16(swizzle_reshape(data, &self.axes)?)),
100 AnyCowArray::U32(data) => Ok(AnyArray::U32(swizzle_reshape(data, &self.axes)?)),
101 AnyCowArray::U64(data) => Ok(AnyArray::U64(swizzle_reshape(data, &self.axes)?)),
102 AnyCowArray::I8(data) => Ok(AnyArray::I8(swizzle_reshape(data, &self.axes)?)),
103 AnyCowArray::I16(data) => Ok(AnyArray::I16(swizzle_reshape(data, &self.axes)?)),
104 AnyCowArray::I32(data) => Ok(AnyArray::I32(swizzle_reshape(data, &self.axes)?)),
105 AnyCowArray::I64(data) => Ok(AnyArray::I64(swizzle_reshape(data, &self.axes)?)),
106 AnyCowArray::F32(data) => Ok(AnyArray::F32(swizzle_reshape(data, &self.axes)?)),
107 AnyCowArray::F64(data) => Ok(AnyArray::F64(swizzle_reshape(data, &self.axes)?)),
108 data => Err(SwizzleReshapeCodecError::UnsupportedDtype(data.dtype())),
109 }
110 }
111
112 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
113 match encoded {
114 AnyCowArray::U8(encoded) => {
115 Ok(AnyArray::U8(undo_swizzle_reshape(encoded, &self.axes)?))
116 }
117 AnyCowArray::U16(encoded) => {
118 Ok(AnyArray::U16(undo_swizzle_reshape(encoded, &self.axes)?))
119 }
120 AnyCowArray::U32(encoded) => {
121 Ok(AnyArray::U32(undo_swizzle_reshape(encoded, &self.axes)?))
122 }
123 AnyCowArray::U64(encoded) => {
124 Ok(AnyArray::U64(undo_swizzle_reshape(encoded, &self.axes)?))
125 }
126 AnyCowArray::I8(encoded) => {
127 Ok(AnyArray::I8(undo_swizzle_reshape(encoded, &self.axes)?))
128 }
129 AnyCowArray::I16(encoded) => {
130 Ok(AnyArray::I16(undo_swizzle_reshape(encoded, &self.axes)?))
131 }
132 AnyCowArray::I32(encoded) => {
133 Ok(AnyArray::I32(undo_swizzle_reshape(encoded, &self.axes)?))
134 }
135 AnyCowArray::I64(encoded) => {
136 Ok(AnyArray::I64(undo_swizzle_reshape(encoded, &self.axes)?))
137 }
138 AnyCowArray::F32(encoded) => {
139 Ok(AnyArray::F32(undo_swizzle_reshape(encoded, &self.axes)?))
140 }
141 AnyCowArray::F64(encoded) => {
142 Ok(AnyArray::F64(undo_swizzle_reshape(encoded, &self.axes)?))
143 }
144 encoded => Err(SwizzleReshapeCodecError::UnsupportedDtype(encoded.dtype())),
145 }
146 }
147
148 fn decode_into(
149 &self,
150 encoded: AnyArrayView,
151 decoded: AnyArrayViewMut,
152 ) -> Result<(), Self::Error> {
153 match (encoded, decoded) {
154 (AnyArrayView::U8(encoded), AnyArrayViewMut::U8(decoded)) => {
155 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
156 }
157 (AnyArrayView::U16(encoded), AnyArrayViewMut::U16(decoded)) => {
158 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
159 }
160 (AnyArrayView::U32(encoded), AnyArrayViewMut::U32(decoded)) => {
161 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
162 }
163 (AnyArrayView::U64(encoded), AnyArrayViewMut::U64(decoded)) => {
164 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
165 }
166 (AnyArrayView::I8(encoded), AnyArrayViewMut::I8(decoded)) => {
167 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
168 }
169 (AnyArrayView::I16(encoded), AnyArrayViewMut::I16(decoded)) => {
170 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
171 }
172 (AnyArrayView::I32(encoded), AnyArrayViewMut::I32(decoded)) => {
173 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
174 }
175 (AnyArrayView::I64(encoded), AnyArrayViewMut::I64(decoded)) => {
176 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
177 }
178 (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
179 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
180 }
181 (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
182 undo_swizzle_reshape_into(encoded, decoded, &self.axes)
183 }
184 (encoded, decoded) if encoded.dtype() != decoded.dtype() => {
185 Err(SwizzleReshapeCodecError::MismatchedDecodeIntoArray {
186 source: AnyArrayAssignError::DTypeMismatch {
187 src: encoded.dtype(),
188 dst: decoded.dtype(),
189 },
190 })
191 }
192 (encoded, _decoded) => Err(SwizzleReshapeCodecError::UnsupportedDtype(encoded.dtype())),
193 }
194 }
195}
196
197impl StaticCodec for SwizzleReshapeCodec {
198 const CODEC_ID: &'static str = "swizzle-reshape.rs";
199
200 type Config<'de> = Self;
201
202 fn from_config(config: Self::Config<'_>) -> Self {
203 config
204 }
205
206 fn get_config(&self) -> StaticCodecConfig<Self> {
207 StaticCodecConfig::from(self)
208 }
209}
210
/// Errors that may occur when applying the [`SwizzleReshapeCodec`].
#[derive(Debug, Error)]
pub enum SwizzleReshapeCodecError {
    /// The codec does not support the array's dtype.
    #[error("SwizzleReshape does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// Decoding without an output array was requested, but the axis groups
    /// merge or insert axes, whose original lengths cannot be recovered from
    /// the encoded array alone.
    #[error(
        "SwizzleReshape cannot decode from an array with merged axes without receiving an output array to decode into"
    )]
    CannotDecodeMergedAxes,
    /// An axis group references an axis index that is out of bounds.
    #[error(
        "SwizzleReshape cannot encode or decode with an invalid axis {index} for an array with {ndim} dimensions"
    )]
    InvalidAxisIndex {
        /// The out-of-bounds axis index
        index: usize,
        /// The number of dimensions of the array
        ndim: usize,
    },
    /// The axis groups reference an axis more than once or, when they contain
    /// no rest marker, do not reference every axis.
    #[error(
        "SwizzleReshape can only encode or decode with an axis permutation {axes:?} that contains every axis of an array with {ndim} dimensions index exactly once"
    )]
    InvalidAxisPermutation {
        /// The invalid axis groups
        axes: Vec<AxisGroup>,
        /// The number of dimensions of the array
        ndim: usize,
    },
    /// The axis groups contain more than one rest marker
    /// ([`AxisGroup::AllRest`] or [`Axis::MergedRest`]).
    #[error(
        "SwizzleReshape cannot encode or decode with an axis permutation that contains multiple rest-axes markers"
    )]
    MultipleRestAxes,
    /// Decoding into the provided output array failed, e.g. because its dtype
    /// or shape does not match the decoded data.
    #[error("SwizzleReshape cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
260
261#[expect(clippy::missing_panics_doc)]
262pub fn swizzle_reshape<T: Copy, S: Data<Elem = T>>(
274 data: ArrayBase<S, IxDyn>,
275 axes: &[AxisGroup],
276) -> Result<Array<T, IxDyn>, SwizzleReshapeCodecError> {
277 let SwizzleReshapeAxes {
278 permutation,
279 swizzled_shape,
280 new_shape,
281 } = validate_into_axes_shape(&data, axes)?;
282
283 let swizzled: ArrayBase<S, ndarray::Dim<ndarray::IxDynImpl>> = data.permuted_axes(permutation);
284 assert_eq!(swizzled.shape(), swizzled_shape, "incorrect swizzled shape");
285
286 #[expect(clippy::expect_used)] let reshaped = swizzled
288 .into_owned()
289 .into_shape_clone(new_shape)
290 .expect("new encoding shape should have the correct number of elements");
291
292 Ok(reshaped)
293}
294
295pub fn undo_swizzle_reshape<T: Copy, S: Data<Elem = T>>(
313 encoded: ArrayBase<S, IxDyn>,
314 axes: &[AxisGroup],
315) -> Result<Array<T, IxDyn>, SwizzleReshapeCodecError> {
316 if !axes.iter().all(|axis| match axis {
317 AxisGroup::Group(axes) => matches!(axes.as_slice(), [Axis::Index(_)]),
318 AxisGroup::AllRest(Rest) => true,
319 }) {
320 return Err(SwizzleReshapeCodecError::CannotDecodeMergedAxes);
321 }
322
323 let SwizzleReshapeAxes { permutation, .. } = validate_into_axes_shape(&encoded, axes)?;
324
325 let mut inverse_permutation = vec![0; permutation.len()];
326 #[expect(clippy::indexing_slicing)] for (i, p) in permutation.into_iter().enumerate() {
328 inverse_permutation[p] = i;
329 }
330
331 let unshaped = encoded;
333 let unswizzled = unshaped.permuted_axes(inverse_permutation);
334
335 Ok(unswizzled.into_owned())
336}
337
338#[expect(clippy::missing_panics_doc)]
339#[expect(clippy::needless_pass_by_value)]
340pub fn undo_swizzle_reshape_into<T: Copy>(
356 encoded: ArrayView<T, IxDyn>,
357 mut decoded: ArrayViewMut<T, IxDyn>,
358 axes: &[AxisGroup],
359) -> Result<(), SwizzleReshapeCodecError> {
360 let SwizzleReshapeAxes {
361 permutation,
362 swizzled_shape,
363 new_shape,
364 } = validate_into_axes_shape(&decoded, axes)?;
365
366 if encoded.shape() != new_shape {
367 return Err(SwizzleReshapeCodecError::MismatchedDecodeIntoArray {
368 source: AnyArrayAssignError::ShapeMismatch {
369 src: encoded.shape().to_vec(),
370 dst: new_shape,
371 },
372 });
373 }
374
375 let mut inverse_permutation = vec![0; decoded.ndim()];
376 #[expect(clippy::indexing_slicing)] for (i, p) in permutation.into_iter().enumerate() {
378 inverse_permutation[p] = i;
379 }
380
381 #[expect(clippy::expect_used)] let unshaped = encoded
383 .to_shape(swizzled_shape)
384 .expect("new decoding shape should have the correct number of elements");
385 let unswizzled = unshaped.permuted_axes(inverse_permutation);
386
387 decoded.assign(&unswizzled);
388
389 Ok(())
390}
391
/// The validated layout of a swizzle-reshape, as computed by
/// `validate_into_axes_shape` for a given array and axis groups.
struct SwizzleReshapeAxes {
    /// `permutation[i]` is the axis of the (decoded) array that becomes axis
    /// `i` of the swizzled, i.e. permuted but not yet reshaped, array.
    permutation: Vec<usize>,
    /// The shape of the array after its axes have been permuted.
    swizzled_shape: Vec<usize>,
    /// The shape of the encoded array, i.e. the swizzled shape after each
    /// group has been merged into one axis.
    new_shape: Vec<usize>,
}
397
398fn validate_into_axes_shape<T, S: Data<Elem = T>>(
399 array: &ArrayBase<S, IxDyn>,
400 axes: &[AxisGroup],
401) -> Result<SwizzleReshapeAxes, SwizzleReshapeCodecError> {
402 let mut axis_index_counts = vec![0_usize; array.ndim()];
405
406 let mut has_rest = false;
407
408 for group in axes {
411 match group {
412 AxisGroup::Group(axes) => {
413 for axis in axes {
414 match axis {
415 Axis::Index(index) => {
416 if let Some(c) = axis_index_counts.get_mut(*index) {
417 *c += 1;
418 } else {
419 return Err(SwizzleReshapeCodecError::InvalidAxisIndex {
420 index: *index,
421 ndim: array.ndim(),
422 });
423 }
424 }
425 Axis::MergedRest(Rest) => {
426 if std::mem::replace(&mut has_rest, true) {
427 return Err(SwizzleReshapeCodecError::MultipleRestAxes);
428 }
429 }
430 }
431 }
432 }
433 AxisGroup::AllRest(Rest) => {
434 if std::mem::replace(&mut has_rest, true) {
435 return Err(SwizzleReshapeCodecError::MultipleRestAxes);
436 }
437 }
438 }
439 }
440
441 if !axis_index_counts
445 .iter()
446 .all(|c| if has_rest { *c <= 1 } else { *c == 1 })
447 {
448 return Err(SwizzleReshapeCodecError::InvalidAxisPermutation {
449 axes: axes.to_vec(),
450 ndim: array.ndim(),
451 });
452 }
453
454 let mut axis_permutation = Vec::with_capacity(array.len());
456 let mut permuted_shape = Vec::with_capacity(array.len());
458 let mut grouped_shape = Vec::with_capacity(axes.len());
460
461 for axis in axes {
462 match axis {
463 AxisGroup::Group(axes) => {
466 let mut new_len = 1;
467 for axis in axes {
468 match axis {
469 Axis::Index(index) => {
470 axis_permutation.push(*index);
471 permuted_shape.push(array.len_of(ndarray::Axis(*index)));
472 new_len *= array.len_of(ndarray::Axis(*index));
473 }
474 Axis::MergedRest(Rest) => {
475 for (index, count) in axis_index_counts.iter().enumerate() {
476 if *count == 0 {
477 axis_permutation.push(index);
478 permuted_shape.push(array.len_of(ndarray::Axis(index)));
479 new_len *= array.len_of(ndarray::Axis(index));
480 }
481 }
482 }
483 }
484 }
485 grouped_shape.push(new_len);
486 }
487 AxisGroup::AllRest(Rest) => {
488 for (index, count) in axis_index_counts.iter().enumerate() {
489 if *count == 0 {
490 axis_permutation.push(index);
491 permuted_shape.push(array.len_of(ndarray::Axis(index)));
492 grouped_shape.push(array.len_of(ndarray::Axis(index)));
493 }
494 }
495 }
496 }
497 }
498
499 Ok(SwizzleReshapeAxes {
500 permutation: axis_permutation,
501 swizzled_shape: permuted_shape,
502 new_shape: grouped_shape,
503 })
504}
505
/// Marker for all axes that are not referenced by index in the axis groups.
///
/// It is (de)serialized as an empty map (`{}` in JSON).
#[derive(Copy, Clone, Debug)]
pub struct Rest;
509
510impl Serialize for Rest {
511 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
512 serializer.serialize_map(Some(0))?.end()
513 }
514}
515
516impl<'de> Deserialize<'de> for Rest {
517 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
518 struct RestVisitor;
519
520 impl<'de> Visitor<'de> for RestVisitor {
521 type Value = Rest;
522
523 fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
524 fmt.write_str("an empty map")
525 }
526
527 fn visit_map<A: MapAccess<'de>>(self, _map: A) -> Result<Self::Value, A::Error> {
528 Ok(Rest)
529 }
530 }
531
532 deserializer.deserialize_map(RestVisitor)
533 }
534}
535
536impl JsonSchema for Rest {
537 fn schema_name() -> Cow<'static, str> {
538 Cow::Borrowed("Rest")
539 }
540
541 fn schema_id() -> Cow<'static, str> {
542 Cow::Borrowed(concat!(module_path!(), "::", "Rest"))
543 }
544
545 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
546 json_schema!({
547 "type": "object",
548 "properties": {},
549 "additionalProperties": false,
550 })
551 }
552}
553
#[cfg(test)]
#[expect(clippy::expect_used)]
mod tests {
    use ndarray::array;

    use super::*;

    /// The 2x2x2 array used by most test cases.
    fn cube() -> Array<i32, IxDyn> {
        array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn()
    }

    /// A larger array of the given shape whose elements are filled with
    /// distinct, increasing counter values.
    fn counting(shape: [usize; 5]) -> Array<i32, IxDyn> {
        let mut next = 0;
        Array::from_shape_fn(shape, |_| {
            next += 1;
            next
        })
        .into_dyn()
    }

    /// A group that merges the axes with the given indices, in that order.
    fn group(indices: &[usize]) -> AxisGroup {
        AxisGroup::Group(indices.iter().copied().map(Axis::Index).collect())
    }

    /// A group that merges all remaining axes into one.
    fn merged_rest() -> AxisGroup {
        AxisGroup::Group(vec![Axis::MergedRest(Rest)])
    }

    /// All remaining axes, each kept as its own axis.
    fn all_rest() -> AxisGroup {
        AxisGroup::AllRest(Rest)
    }

    #[test]
    fn identity() {
        roundtrip(cube(), &[all_rest()], &[2, 2, 2]);

        roundtrip(cube(), &[group(&[0]), group(&[1]), all_rest()], &[2, 2, 2]);
        roundtrip(cube(), &[group(&[0]), all_rest(), group(&[2])], &[2, 2, 2]);
        roundtrip(cube(), &[all_rest(), group(&[1]), group(&[2])], &[2, 2, 2]);

        roundtrip(cube(), &[group(&[0]), group(&[1]), merged_rest()], &[2, 2, 2]);
        roundtrip(cube(), &[group(&[0]), merged_rest(), group(&[2])], &[2, 2, 2]);
        roundtrip(cube(), &[merged_rest(), group(&[1]), group(&[2])], &[2, 2, 2]);
    }

    #[test]
    fn swizzle() {
        roundtrip(cube(), &[group(&[0]), group(&[1]), all_rest()], &[2, 2, 2]);
        roundtrip(cube(), &[group(&[2]), all_rest(), group(&[1])], &[2, 2, 2]);
        roundtrip(cube(), &[all_rest(), group(&[0]), group(&[1])], &[2, 2, 2]);

        roundtrip(cube(), &[group(&[0]), group(&[1]), merged_rest()], &[2, 2, 2]);
        roundtrip(cube(), &[group(&[2]), merged_rest(), group(&[1])], &[2, 2, 2]);
        roundtrip(cube(), &[merged_rest(), group(&[0]), group(&[1])], &[2, 2, 2]);

        roundtrip(
            counting([3, 1440, 721, 1, 1]),
            &[
                group(&[4]),
                group(&[0]),
                group(&[3]),
                group(&[2]),
                group(&[1]),
            ],
            &[1, 3, 1, 721, 1440],
        );
    }

    #[test]
    fn collapse() {
        roundtrip(cube(), &[merged_rest()], &[8]);

        roundtrip(cube(), &[group(&[0, 1, 2])], &[8]);
        roundtrip(cube(), &[group(&[2, 1, 0])], &[8]);
        roundtrip(
            cube(),
            &[AxisGroup::Group(vec![
                Axis::Index(1),
                Axis::MergedRest(Rest),
            ])],
            &[8],
        );

        roundtrip(cube(), &[group(&[0, 1]), all_rest()], &[4, 2]);
        roundtrip(cube(), &[group(&[2]), group(&[1, 0])], &[2, 4]);
        roundtrip(
            cube(),
            &[
                AxisGroup::Group(vec![Axis::Index(1), Axis::MergedRest(Rest)]),
                group(&[0, 2]),
            ],
            &[2, 4],
        );

        roundtrip(
            array![[1, 2], [3, 4], [5, 6], [7, 8]].into_dyn(),
            &[merged_rest()],
            &[8],
        );

        roundtrip(
            counting([3, 1440, 721, 1, 1]),
            &[group(&[4, 0, 3, 2, 1])],
            &[3 * 721 * 1440],
        );
    }

    #[test]
    fn extend() {
        roundtrip(
            cube(),
            &[
                group(&[]),
                group(&[0]),
                group(&[]),
                all_rest(),
                group(&[]),
                group(&[2]),
                group(&[]),
            ],
            &[1, 2, 1, 2, 1, 2, 1],
        );
    }

    /// Encodes `data`, checks the encoded shape, and decodes it back both into
    /// an output array and (where supported) without one.
    #[expect(clippy::needless_pass_by_value)]
    fn roundtrip(data: Array<i32, IxDyn>, axes: &[AxisGroup], swizzle_shape: &[usize]) {
        let encoded = swizzle_reshape(data.view(), axes).expect("swizzle should not fail");
        assert_eq!(encoded.shape(), swizzle_shape);

        let mut decoded = Array::zeros(data.shape());
        undo_swizzle_reshape_into(encoded.view(), decoded.view_mut(), axes)
            .expect("unswizzle into should not fail");
        assert_eq!(data, decoded);

        // decoding without an output array only works for plain permutations,
        // where every group holds exactly one axis index
        let is_plain_permutation = axes.iter().all(|group| match group {
            AxisGroup::Group(members) => matches!(members.as_slice(), [Axis::Index(_)]),
            AxisGroup::AllRest(Rest) => true,
        });

        if is_plain_permutation {
            let decoded =
                undo_swizzle_reshape(encoded.view(), axes).expect("unswizzle should not fail");
            assert_eq!(data, decoded);
        } else {
            undo_swizzle_reshape(encoded.view(), axes).expect_err("unswizzle should fail");
        }
    }
}