numcodecs_swizzle_reshape/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-swizzle-reshape
10//! [crates.io]: https://crates.io/crates/numcodecs-swizzle-reshape
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-swizzle-reshape
13//! [docs.rs]: https://docs.rs/numcodecs-swizzle-reshape/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_swizzle_reshape
17//!
18//! Array axis swizzle and reshape codec implementation for the [`numcodecs`]
19//! API.
20
21use std::{
22    borrow::Cow,
23    fmt::{self, Debug},
24};
25
26use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, IxDyn};
27use numcodecs::{
28    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
29    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
30};
31use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
32use serde::{
33    de::{MapAccess, Visitor},
34    ser::SerializeMap,
35    Deserialize, Deserializer, Serialize, Serializer,
36};
37use thiserror::Error;
38
39#[derive(Clone, Serialize, Deserialize, JsonSchema)]
40#[serde(deny_unknown_fields)]
41/// Codec to swizzle/swap the axes of an array and reshape it.
42///
43/// This codec does not store metadata about the original shape of the array.
44/// Since axes that have been combined during encoding cannot be split without
45/// further information, decoding may fail if an output array is not provided.
46///
47/// Swizzling axes is always supported since no additional information about the
48/// array's shape is required to reconstruct it.
49pub struct SwizzleReshapeCodec {
50    /// The permutation of the axes that is applied on encoding.
51    ///
52    /// The permutation is given as a list of axis groups, where each group
53    /// corresponds to one encoded output axis that may consist of several
54    /// decoded input axes. For instance, `[[0], [1, 2]]` flattens a three-
55    /// dimensional array into a two-dimensional one by combining the second and
56    /// third axes.
57    ///
58    /// The permutation also allows specifying a special catch-all remaining
59    /// axes marker:
60    /// - `[[0], {}]` moves the second axis to be the first and appends all
61    ///   other axes afterwards, i.e. the encoded array has the same number
62    ///   of axes as the input array
63    /// - `[[0], [{}]]` in contrast collapses all other axes into one, i.e.
64    ///   the encoded array is two-dimensional
65    pub axes: Vec<AxisGroup>,
66    /// The codec's encoding format version. Do not provide this parameter explicitly.
67    #[serde(default, rename = "_version")]
68    pub version: StaticCodecVersion<1, 0, 0>,
69}
70
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
// untagged: no variant name is written, so the variants are told apart by
//  their serialized form alone; `Group` is a sequence of `Axis` values and
//  `AllRest` is whatever `Rest` serializes as (an empty map)
#[serde(untagged)]
#[serde(deny_unknown_fields)]
/// An axis group, potentially from a merged combination of multiple input axes
pub enum AxisGroup {
    // every listed axis is merged into a single encoded axis; an empty list
    //  yields an additional encoded axis of length one
    /// A merged combination of zero, one, or multiple input axes
    Group(Vec<Axis>),
    // expands to one encoded axis per input axis that is not named anywhere
    //  else in the permutation
    /// All remaining axes, each in a separate single-axis group
    AllRest(Rest),
}
81
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
// untagged: an `Index` is written as a bare unsigned integer and a
//  `MergedRest` as whatever `Rest` serializes as (an empty map)
#[serde(untagged)]
#[serde(deny_unknown_fields)]
/// An axis or all remaining axes
pub enum Axis {
    // zero-based index into the axes of the array that is being encoded
    /// A single axis, as determined by its index
    Index(usize),
    // stands for every input axis that is not named anywhere else in the
    //  permutation; all of them are merged into the enclosing group
    /// All remaining axes, combined into one
    MergedRest(Rest),
}
92
93impl Codec for SwizzleReshapeCodec {
94    type Error = SwizzleReshapeCodecError;
95
96    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
97        match data {
98            AnyCowArray::U8(data) => Ok(AnyArray::U8(swizzle_reshape(data, &self.axes)?)),
99            AnyCowArray::U16(data) => Ok(AnyArray::U16(swizzle_reshape(data, &self.axes)?)),
100            AnyCowArray::U32(data) => Ok(AnyArray::U32(swizzle_reshape(data, &self.axes)?)),
101            AnyCowArray::U64(data) => Ok(AnyArray::U64(swizzle_reshape(data, &self.axes)?)),
102            AnyCowArray::I8(data) => Ok(AnyArray::I8(swizzle_reshape(data, &self.axes)?)),
103            AnyCowArray::I16(data) => Ok(AnyArray::I16(swizzle_reshape(data, &self.axes)?)),
104            AnyCowArray::I32(data) => Ok(AnyArray::I32(swizzle_reshape(data, &self.axes)?)),
105            AnyCowArray::I64(data) => Ok(AnyArray::I64(swizzle_reshape(data, &self.axes)?)),
106            AnyCowArray::F32(data) => Ok(AnyArray::F32(swizzle_reshape(data, &self.axes)?)),
107            AnyCowArray::F64(data) => Ok(AnyArray::F64(swizzle_reshape(data, &self.axes)?)),
108            data => Err(SwizzleReshapeCodecError::UnsupportedDtype(data.dtype())),
109        }
110    }
111
112    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
113        match encoded {
114            AnyCowArray::U8(encoded) => {
115                Ok(AnyArray::U8(undo_swizzle_reshape(encoded, &self.axes)?))
116            }
117            AnyCowArray::U16(encoded) => {
118                Ok(AnyArray::U16(undo_swizzle_reshape(encoded, &self.axes)?))
119            }
120            AnyCowArray::U32(encoded) => {
121                Ok(AnyArray::U32(undo_swizzle_reshape(encoded, &self.axes)?))
122            }
123            AnyCowArray::U64(encoded) => {
124                Ok(AnyArray::U64(undo_swizzle_reshape(encoded, &self.axes)?))
125            }
126            AnyCowArray::I8(encoded) => {
127                Ok(AnyArray::I8(undo_swizzle_reshape(encoded, &self.axes)?))
128            }
129            AnyCowArray::I16(encoded) => {
130                Ok(AnyArray::I16(undo_swizzle_reshape(encoded, &self.axes)?))
131            }
132            AnyCowArray::I32(encoded) => {
133                Ok(AnyArray::I32(undo_swizzle_reshape(encoded, &self.axes)?))
134            }
135            AnyCowArray::I64(encoded) => {
136                Ok(AnyArray::I64(undo_swizzle_reshape(encoded, &self.axes)?))
137            }
138            AnyCowArray::F32(encoded) => {
139                Ok(AnyArray::F32(undo_swizzle_reshape(encoded, &self.axes)?))
140            }
141            AnyCowArray::F64(encoded) => {
142                Ok(AnyArray::F64(undo_swizzle_reshape(encoded, &self.axes)?))
143            }
144            encoded => Err(SwizzleReshapeCodecError::UnsupportedDtype(encoded.dtype())),
145        }
146    }
147
148    fn decode_into(
149        &self,
150        encoded: AnyArrayView,
151        decoded: AnyArrayViewMut,
152    ) -> Result<(), Self::Error> {
153        match (encoded, decoded) {
154            (AnyArrayView::U8(encoded), AnyArrayViewMut::U8(decoded)) => {
155                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
156            }
157            (AnyArrayView::U16(encoded), AnyArrayViewMut::U16(decoded)) => {
158                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
159            }
160            (AnyArrayView::U32(encoded), AnyArrayViewMut::U32(decoded)) => {
161                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
162            }
163            (AnyArrayView::U64(encoded), AnyArrayViewMut::U64(decoded)) => {
164                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
165            }
166            (AnyArrayView::I8(encoded), AnyArrayViewMut::I8(decoded)) => {
167                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
168            }
169            (AnyArrayView::I16(encoded), AnyArrayViewMut::I16(decoded)) => {
170                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
171            }
172            (AnyArrayView::I32(encoded), AnyArrayViewMut::I32(decoded)) => {
173                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
174            }
175            (AnyArrayView::I64(encoded), AnyArrayViewMut::I64(decoded)) => {
176                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
177            }
178            (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
179                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
180            }
181            (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
182                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
183            }
184            (encoded, decoded) if encoded.dtype() != decoded.dtype() => {
185                Err(SwizzleReshapeCodecError::MismatchedDecodeIntoArray {
186                    source: AnyArrayAssignError::DTypeMismatch {
187                        src: encoded.dtype(),
188                        dst: decoded.dtype(),
189                    },
190                })
191            }
192            (encoded, _decoded) => Err(SwizzleReshapeCodecError::UnsupportedDtype(encoded.dtype())),
193        }
194    }
195}
196
impl StaticCodec for SwizzleReshapeCodec {
    // the identifier string of this codec
    const CODEC_ID: &'static str = "swizzle-reshape.rs";

    // the codec struct doubles as its own (de)serializable configuration
    type Config<'de> = Self;

    // since the configuration is the codec itself, it is returned unchanged
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    // wraps a borrow of this codec as its configuration
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
210
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`SwizzleReshapeCodec`].
pub enum SwizzleReshapeCodecError {
    /// [`SwizzleReshapeCodec`] does not support the dtype
    #[error("SwizzleReshape does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`SwizzleReshapeCodec`] cannot decode from an array with merged axes
    /// without receiving an output array to decode into
    ///
    /// This is raised for any axis group that is not exactly one explicit
    /// axis index, which includes empty groups and merged rest markers.
    #[error("SwizzleReshape cannot decode from an array with merged axes without receiving an output array to decode into")]
    CannotDecodeMergedAxes,
    /// [`SwizzleReshapeCodec`] cannot encode or decode with an invalid axis
    /// `index` for an array with `ndim` dimensions
    #[error("SwizzleReshape cannot encode or decode with an invalid axis {index} for an array with {ndim} dimensions")]
    InvalidAxisIndex {
        /// The out-of-bounds axis index
        index: usize,
        /// The number of dimensions of the array
        ndim: usize,
    },
    /// [`SwizzleReshapeCodec`] can only encode or decode with an axis
    /// permutation `axes` that contains every axis index of an array with
    /// `ndim` dimensions exactly once
    ///
    /// If the permutation contains a rest-axes marker, axes may be left
    /// unnamed, but no axis index may be named more than once.
    #[error("SwizzleReshape can only encode or decode with an axis permutation {axes:?} that contains every axis of an array with {ndim} dimensions index exactly once")]
    InvalidAxisPermutation {
        /// The invalid permutation of axes
        axes: Vec<AxisGroup>,
        /// The number of dimensions of the array
        ndim: usize,
    },
    /// [`SwizzleReshapeCodec`] cannot encode or decode with an axis permutation
    /// that contains multiple rest-axes markers
    ///
    /// Both kinds of markers, [`AxisGroup::AllRest`] and
    /// [`Axis::MergedRest`], count towards this limit.
    #[error("SwizzleReshape cannot encode or decode with an axis permutation that contains multiple rest-axes markers")]
    MultipleRestAxes,
    /// [`SwizzleReshapeCodec`] cannot decode into the provided array
    ///
    /// The source error reports whether the dtypes or the shapes of the
    /// encoded and the output array do not fit together.
    #[error("SwizzleReshape cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
252
253#[expect(clippy::missing_panics_doc)]
254/// Swizzle and reshape the input `data` array with the new `axes`.
255///
256/// # Errors
257///
258/// Errors with
259/// - [`SwizzleReshapeCodecError::InvalidAxisIndex`] if any axis is out of
260///   bounds
261/// - [`SwizzleReshapeCodecError::InvalidAxisPermutation`] if the `axes`
262///   permutation does not contain every axis index exactly once
263/// - [`SwizzleReshapeCodecError::MultipleRestAxes`] if the `axes` permutation
264///   contains more than one [`Rest`]-axes marker
265pub fn swizzle_reshape<T: Copy, S: Data<Elem = T>>(
266    data: ArrayBase<S, IxDyn>,
267    axes: &[AxisGroup],
268) -> Result<Array<T, IxDyn>, SwizzleReshapeCodecError> {
269    let SwizzleReshapeAxes {
270        permutation,
271        swizzled_shape,
272        new_shape,
273    } = validate_into_axes_shape(&data, axes)?;
274
275    let swizzled: ArrayBase<S, ndarray::Dim<ndarray::IxDynImpl>> = data.permuted_axes(permutation);
276    assert_eq!(swizzled.shape(), swizzled_shape, "incorrect swizzled shape");
277
278    #[expect(clippy::expect_used)] // only panics on an implementation bug
279    let reshaped = swizzled
280        .into_owned()
281        .into_shape_clone(new_shape)
282        .expect("new encoding shape should have the correct number of elements");
283
284    Ok(reshaped)
285}
286
287/// Reverts the swizzle and reshape of the `encoded` array with the `axes` and
288/// returns the original array.
289///
290/// Since the shape of the original array is not known, only permutations of
291/// axes are supported.
292///
293/// # Errors
294///
295/// Errors with
296/// - [`SwizzleReshapeCodecError::CannotDecodeMergedAxes`] if any axes were
297///   merged and thus cannot be split without further information
298/// - [`SwizzleReshapeCodecError::InvalidAxisIndex`] if any axis is out of
299///   bounds
300/// - [`SwizzleReshapeCodecError::InvalidAxisPermutation`] if the `axes`
301///   permutation does not contain every axis index exactly once
302/// - [`SwizzleReshapeCodecError::MultipleRestAxes`] if the `axes` permutation
303///   contains more than one [`Rest`]-axes marker
304pub fn undo_swizzle_reshape<T: Copy, S: Data<Elem = T>>(
305    encoded: ArrayBase<S, IxDyn>,
306    axes: &[AxisGroup],
307) -> Result<Array<T, IxDyn>, SwizzleReshapeCodecError> {
308    if !axes.iter().all(|axis| match axis {
309        AxisGroup::Group(axes) => matches!(axes.as_slice(), [Axis::Index(_)]),
310        AxisGroup::AllRest(Rest) => true,
311    }) {
312        return Err(SwizzleReshapeCodecError::CannotDecodeMergedAxes);
313    }
314
315    let SwizzleReshapeAxes { permutation, .. } = validate_into_axes_shape(&encoded, axes)?;
316
317    let mut inverse_permutation = vec![0; permutation.len()];
318    #[expect(clippy::indexing_slicing)] // all are guaranteed to be in range
319    for (i, p) in permutation.into_iter().enumerate() {
320        inverse_permutation[p] = i;
321    }
322
323    // since no axes were merged, no reshape is needed
324    let unshaped = encoded;
325    let unswizzled = unshaped.permuted_axes(inverse_permutation);
326
327    Ok(unswizzled.into_owned())
328}
329
330#[expect(clippy::missing_panics_doc)]
331#[expect(clippy::needless_pass_by_value)]
332/// Reverts the swizzle and reshape of the `encoded` array with the `axes` and
333/// outputs it into the `decoded` array.
334///
335/// # Errors
336///
337/// Errors with
338/// - [`SwizzleReshapeCodecError::InvalidAxisIndex`] if any axis is out of
339///   bounds
340/// - [`SwizzleReshapeCodecError::InvalidAxisPermutation`] if the `axes`
341///   permutation does not contain every axis index exactly once
342/// - [`SwizzleReshapeCodecError::MultipleRestAxes`] if the `axes` permutation
343///   contains more than one [`Rest`]-axes marker
344/// - [`SwizzleReshapeCodecError::MismatchedDecodeIntoArray`] if the `encoded`
345///   array's shape does not match the shape that swizzling and reshaping an
346///   array of the `decoded` array's shape would have produced
347pub fn undo_swizzle_reshape_into<T: Copy>(
348    encoded: ArrayView<T, IxDyn>,
349    mut decoded: ArrayViewMut<T, IxDyn>,
350    axes: &[AxisGroup],
351) -> Result<(), SwizzleReshapeCodecError> {
352    let SwizzleReshapeAxes {
353        permutation,
354        swizzled_shape,
355        new_shape,
356    } = validate_into_axes_shape(&decoded, axes)?;
357
358    if encoded.shape() != new_shape {
359        return Err(SwizzleReshapeCodecError::MismatchedDecodeIntoArray {
360            source: AnyArrayAssignError::ShapeMismatch {
361                src: encoded.shape().to_vec(),
362                dst: new_shape,
363            },
364        });
365    }
366
367    let mut inverse_permutation = vec![0; decoded.ndim()];
368    #[expect(clippy::indexing_slicing)] // all are guaranteed to be in range
369    for (i, p) in permutation.into_iter().enumerate() {
370        inverse_permutation[p] = i;
371    }
372
373    #[expect(clippy::expect_used)] // only panics on an implementation bug
374    let unshaped = encoded
375        .to_shape(swizzled_shape)
376        .expect("new decoding shape should have the correct number of elements");
377    let unswizzled = unshaped.permuted_axes(inverse_permutation);
378
379    decoded.assign(&unswizzled);
380
381    Ok(())
382}
383
/// The validated axis permutation and shapes that `validate_into_axes_shape`
/// derives from an array and a list of axis groups.
struct SwizzleReshapeAxes {
    /// The input axis indices in the order in which they are encoded, with any
    /// rest marker expanded to the axes that it catches
    permutation: Vec<usize>,
    /// The shape of the array after applying the `permutation`, but before
    /// merging any axis groups
    swizzled_shape: Vec<usize>,
    /// The shape of the array after applying the `permutation` and merging
    /// every axis group into one axis
    new_shape: Vec<usize>,
}
389
390fn validate_into_axes_shape<T, S: Data<Elem = T>>(
391    array: &ArrayBase<S, IxDyn>,
392    axes: &[AxisGroup],
393) -> Result<SwizzleReshapeAxes, SwizzleReshapeCodecError> {
394    // counts of each axis index, used to check for missing or duplicate axes,
395    //  and for knowing which axes are caught by the rest catch-all
396    let mut axis_index_counts = vec![0_usize; array.ndim()];
397
398    let mut has_rest = false;
399
400    // validate that all axis indices are in bounds and that there is at most
401    //  one catch-all remaining axes marker
402    for group in axes {
403        match group {
404            AxisGroup::Group(axes) => {
405                for axis in axes {
406                    match axis {
407                        Axis::Index(index) => {
408                            if let Some(c) = axis_index_counts.get_mut(*index) {
409                                *c += 1;
410                            } else {
411                                return Err(SwizzleReshapeCodecError::InvalidAxisIndex {
412                                    index: *index,
413                                    ndim: array.ndim(),
414                                });
415                            }
416                        }
417                        Axis::MergedRest(Rest) => {
418                            if std::mem::replace(&mut has_rest, true) {
419                                return Err(SwizzleReshapeCodecError::MultipleRestAxes);
420                            }
421                        }
422                    }
423                }
424            }
425            AxisGroup::AllRest(Rest) => {
426                if std::mem::replace(&mut has_rest, true) {
427                    return Err(SwizzleReshapeCodecError::MultipleRestAxes);
428                }
429            }
430        }
431    }
432
433    // check that each axis is mentioned
434    // - exactly once if no catch-all is used
435    // - at most once if a catch-all is used
436    if !axis_index_counts
437        .iter()
438        .all(|c| if has_rest { *c <= 1 } else { *c == 1 })
439    {
440        return Err(SwizzleReshapeCodecError::InvalidAxisPermutation {
441            axes: axes.to_vec(),
442            ndim: array.ndim(),
443        });
444    }
445
446    // the permutation to apply to the input axes
447    let mut axis_permutation = Vec::with_capacity(array.len());
448    // the shape of the already permuted intermediary array
449    let mut permuted_shape = Vec::with_capacity(array.len());
450    // the shape of the already permuted and grouped output array
451    let mut grouped_shape = Vec::with_capacity(axes.len());
452
453    for axis in axes {
454        match axis {
455            // a group merged all of its axes
456            // an empty group adds an additional axis of size 1
457            AxisGroup::Group(axes) => {
458                let mut new_len = 1;
459                for axis in axes {
460                    match axis {
461                        Axis::Index(index) => {
462                            axis_permutation.push(*index);
463                            permuted_shape.push(array.len_of(ndarray::Axis(*index)));
464                            new_len *= array.len_of(ndarray::Axis(*index));
465                        }
466                        Axis::MergedRest(Rest) => {
467                            for (index, count) in axis_index_counts.iter().enumerate() {
468                                if *count == 0 {
469                                    axis_permutation.push(index);
470                                    permuted_shape.push(array.len_of(ndarray::Axis(index)));
471                                    new_len *= array.len_of(ndarray::Axis(index));
472                                }
473                            }
474                        }
475                    }
476                }
477                grouped_shape.push(new_len);
478            }
479            AxisGroup::AllRest(Rest) => {
480                for (index, count) in axis_index_counts.iter().enumerate() {
481                    if *count == 0 {
482                        axis_permutation.push(index);
483                        permuted_shape.push(array.len_of(ndarray::Axis(index)));
484                        grouped_shape.push(array.len_of(ndarray::Axis(index)));
485                    }
486                }
487            }
488        }
489    }
490
491    Ok(SwizzleReshapeAxes {
492        permutation: axis_permutation,
493        swizzled_shape: permuted_shape,
494        new_shape: grouped_shape,
495    })
496}
497
#[derive(Copy, Clone, Debug)]
/// Marker to signify all remaining (not explicitly named) axes
///
/// The marker carries no data and is serialized as a map without any entries.
pub struct Rest;
501
502impl Serialize for Rest {
503    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
504        serializer.serialize_map(Some(0))?.end()
505    }
506}
507
508impl<'de> Deserialize<'de> for Rest {
509    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
510        struct RestVisitor;
511
512        impl<'de> Visitor<'de> for RestVisitor {
513            type Value = Rest;
514
515            fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
516                fmt.write_str("an empty map")
517            }
518
519            fn visit_map<A: MapAccess<'de>>(self, _map: A) -> Result<Self::Value, A::Error> {
520                Ok(Rest)
521            }
522        }
523
524        deserializer.deserialize_map(RestVisitor)
525    }
526}
527
528impl JsonSchema for Rest {
529    fn schema_name() -> Cow<'static, str> {
530        Cow::Borrowed("Rest")
531    }
532
533    fn schema_id() -> Cow<'static, str> {
534        Cow::Borrowed(concat!(module_path!(), "::", "Rest"))
535    }
536
537    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
538        json_schema!({
539            "type": "object",
540            "properties": {},
541            "additionalProperties": false,
542        })
543    }
544}
545
546#[cfg(test)]
547#[expect(clippy::expect_used)]
548mod tests {
549    use ndarray::array;
550
551    use super::*;
552
553    #[test]
554    fn identity() {
555        roundtrip(
556            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
557            &[AxisGroup::AllRest(Rest)],
558            &[2, 2, 2],
559        );
560
561        roundtrip(
562            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
563            &[
564                AxisGroup::Group(vec![Axis::Index(0)]),
565                AxisGroup::Group(vec![Axis::Index(1)]),
566                AxisGroup::AllRest(Rest),
567            ],
568            &[2, 2, 2],
569        );
570        roundtrip(
571            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
572            &[
573                AxisGroup::Group(vec![Axis::Index(0)]),
574                AxisGroup::AllRest(Rest),
575                AxisGroup::Group(vec![Axis::Index(2)]),
576            ],
577            &[2, 2, 2],
578        );
579        roundtrip(
580            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
581            &[
582                AxisGroup::AllRest(Rest),
583                AxisGroup::Group(vec![Axis::Index(1)]),
584                AxisGroup::Group(vec![Axis::Index(2)]),
585            ],
586            &[2, 2, 2],
587        );
588
589        roundtrip(
590            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
591            &[
592                AxisGroup::Group(vec![Axis::Index(0)]),
593                AxisGroup::Group(vec![Axis::Index(1)]),
594                AxisGroup::Group(vec![Axis::MergedRest(Rest)]),
595            ],
596            &[2, 2, 2],
597        );
598        roundtrip(
599            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
600            &[
601                AxisGroup::Group(vec![Axis::Index(0)]),
602                AxisGroup::Group(vec![Axis::MergedRest(Rest)]),
603                AxisGroup::Group(vec![Axis::Index(2)]),
604            ],
605            &[2, 2, 2],
606        );
607        roundtrip(
608            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
609            &[
610                AxisGroup::Group(vec![Axis::MergedRest(Rest)]),
611                AxisGroup::Group(vec![Axis::Index(1)]),
612                AxisGroup::Group(vec![Axis::Index(2)]),
613            ],
614            &[2, 2, 2],
615        );
616    }
617
618    #[test]
619    fn swizzle() {
620        roundtrip(
621            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
622            &[
623                AxisGroup::Group(vec![Axis::Index(0)]),
624                AxisGroup::Group(vec![Axis::Index(1)]),
625                AxisGroup::AllRest(Rest),
626            ],
627            &[2, 2, 2],
628        );
629        roundtrip(
630            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
631            &[
632                AxisGroup::Group(vec![Axis::Index(2)]),
633                AxisGroup::AllRest(Rest),
634                AxisGroup::Group(vec![Axis::Index(1)]),
635            ],
636            &[2, 2, 2],
637        );
638        roundtrip(
639            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
640            &[
641                AxisGroup::AllRest(Rest),
642                AxisGroup::Group(vec![Axis::Index(0)]),
643                AxisGroup::Group(vec![Axis::Index(1)]),
644            ],
645            &[2, 2, 2],
646        );
647
648        roundtrip(
649            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
650            &[
651                AxisGroup::Group(vec![Axis::Index(0)]),
652                AxisGroup::Group(vec![Axis::Index(1)]),
653                AxisGroup::Group(vec![Axis::MergedRest(Rest)]),
654            ],
655            &[2, 2, 2],
656        );
657        roundtrip(
658            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
659            &[
660                AxisGroup::Group(vec![Axis::Index(2)]),
661                AxisGroup::Group(vec![Axis::MergedRest(Rest)]),
662                AxisGroup::Group(vec![Axis::Index(1)]),
663            ],
664            &[2, 2, 2],
665        );
666        roundtrip(
667            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
668            &[
669                AxisGroup::Group(vec![Axis::MergedRest(Rest)]),
670                AxisGroup::Group(vec![Axis::Index(0)]),
671                AxisGroup::Group(vec![Axis::Index(1)]),
672            ],
673            &[2, 2, 2],
674        );
675
676        let mut i = 0;
677        roundtrip(
678            Array::from_shape_fn([3, 1440, 721, 1, 1], |_| {
679                i += 1;
680                i
681            })
682            .into_dyn(),
683            &[
684                AxisGroup::Group(vec![Axis::Index(4)]),
685                AxisGroup::Group(vec![Axis::Index(0)]),
686                AxisGroup::Group(vec![Axis::Index(3)]),
687                AxisGroup::Group(vec![Axis::Index(2)]),
688                AxisGroup::Group(vec![Axis::Index(1)]),
689            ],
690            &[1, 3, 1, 721, 1440],
691        );
692    }
693
694    #[test]
695    fn collapse() {
696        roundtrip(
697            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
698            &[AxisGroup::Group(vec![Axis::MergedRest(Rest)])],
699            &[8],
700        );
701
702        roundtrip(
703            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
704            &[AxisGroup::Group(vec![
705                Axis::Index(0),
706                Axis::Index(1),
707                Axis::Index(2),
708            ])],
709            &[8],
710        );
711        roundtrip(
712            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
713            &[AxisGroup::Group(vec![
714                Axis::Index(2),
715                Axis::Index(1),
716                Axis::Index(0),
717            ])],
718            &[8],
719        );
720        roundtrip(
721            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
722            &[AxisGroup::Group(vec![
723                Axis::Index(1),
724                Axis::MergedRest(Rest),
725            ])],
726            &[8],
727        );
728
729        roundtrip(
730            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
731            &[
732                AxisGroup::Group(vec![Axis::Index(0), Axis::Index(1)]),
733                AxisGroup::AllRest(Rest),
734            ],
735            &[4, 2],
736        );
737        roundtrip(
738            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
739            &[
740                AxisGroup::Group(vec![Axis::Index(2)]),
741                AxisGroup::Group(vec![Axis::Index(1), Axis::Index(0)]),
742            ],
743            &[2, 4],
744        );
745        roundtrip(
746            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
747            &[
748                AxisGroup::Group(vec![Axis::Index(1), Axis::MergedRest(Rest)]),
749                AxisGroup::Group(vec![Axis::Index(0), Axis::Index(2)]),
750            ],
751            &[2, 4],
752        );
753
754        roundtrip(
755            array![[1, 2], [3, 4], [5, 6], [7, 8]].into_dyn(),
756            &[AxisGroup::Group(vec![Axis::MergedRest(Rest)])],
757            &[8],
758        );
759
760        let mut i = 0;
761        roundtrip(
762            Array::from_shape_fn([3, 1440, 721, 1, 1], |_| {
763                i += 1;
764                i
765            })
766            .into_dyn(),
767            &[AxisGroup::Group(vec![
768                Axis::Index(4),
769                Axis::Index(0),
770                Axis::Index(3),
771                Axis::Index(2),
772                Axis::Index(1),
773            ])],
774            &[1 * 3 * 1 * 721 * 1440],
775        );
776    }
777
778    #[test]
779    fn extend() {
780        roundtrip(
781            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
782            &[
783                AxisGroup::Group(vec![]),
784                AxisGroup::Group(vec![Axis::Index(0)]),
785                AxisGroup::Group(vec![]),
786                AxisGroup::AllRest(Rest),
787                AxisGroup::Group(vec![]),
788                AxisGroup::Group(vec![Axis::Index(2)]),
789                AxisGroup::Group(vec![]),
790            ],
791            &[1, 2, 1, 2, 1, 2, 1],
792        );
793    }
794
795    #[expect(clippy::needless_pass_by_value)]
796    fn roundtrip(data: Array<i32, IxDyn>, axes: &[AxisGroup], swizzle_shape: &[usize]) {
797        let swizzled = swizzle_reshape(data.view(), axes).expect("swizzle should not fail");
798
799        assert_eq!(swizzled.shape(), swizzle_shape);
800
801        let mut unswizzled = Array::zeros(data.shape());
802        undo_swizzle_reshape_into(swizzled.view(), unswizzled.view_mut(), axes)
803            .expect("unswizzle into should not fail");
804
805        assert_eq!(data, unswizzled);
806
807        if axes.iter().any(|a| matches!(a, AxisGroup::Group(a) if a.len() != 1 || a.iter().any(|a| matches!(a, Axis::MergedRest(Rest))))) {
808            undo_swizzle_reshape(swizzled.view(), axes).expect_err("unswizzle should fail");
809        } else {
810            let unswizzled = undo_swizzle_reshape(swizzled.view(), axes).expect("unswizzle should not fail");
811            assert_eq!(data, unswizzled);
812        }
813    }
814}