numcodecs_swizzle_reshape/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.85.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-swizzle-reshape
10//! [crates.io]: https://crates.io/crates/numcodecs-swizzle-reshape
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-swizzle-reshape
13//! [docs.rs]: https://docs.rs/numcodecs-swizzle-reshape/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_swizzle_reshape
17//!
18//! Array axis swizzle and reshape codec implementation for the [`numcodecs`]
19//! API.
20
21use std::{
22    borrow::Cow,
23    fmt::{self, Debug},
24};
25
26use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, IxDyn};
27use numcodecs::{
28    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
29    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
30};
31use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
32use serde::{
33    Deserialize, Deserializer, Serialize, Serializer,
34    de::{MapAccess, Visitor},
35    ser::SerializeMap,
36};
37use thiserror::Error;
38
39#[derive(Clone, Serialize, Deserialize, JsonSchema)]
40#[serde(deny_unknown_fields)]
41/// Codec to swizzle/swap the axes of an array and reshape it.
42///
43/// This codec does not store metadata about the original shape of the array.
44/// Since axes that have been combined during encoding cannot be split without
45/// further information, decoding may fail if an output array is not provided.
46///
47/// Swizzling axes is always supported since no additional information about the
48/// array's shape is required to reconstruct it.
49pub struct SwizzleReshapeCodec {
50    /// The permutation of the axes that is applied on encoding.
51    ///
52    /// The permutation is given as a list of axis groups, where each group
53    /// corresponds to one encoded output axis that may consist of several
54    /// decoded input axes. For instance, `[[0], [1, 2]]` flattens a three-
55    /// dimensional array into a two-dimensional one by combining the second and
56    /// third axes.
57    ///
58    /// The permutation also allows specifying a special catch-all remaining
59    /// axes marker:
60    /// - `[[0], {}]` moves the second axis to be the first and appends all
61    ///   other axes afterwards, i.e. the encoded array has the same number
62    ///   of axes as the input array
63    /// - `[[0], [{}]]` in contrast collapses all other axes into one, i.e.
64    ///   the encoded array is two-dimensional
65    pub axes: Vec<AxisGroup>,
66    /// The codec's encoding format version. Do not provide this parameter explicitly.
67    #[serde(default, rename = "_version")]
68    pub version: StaticCodecVersion<1, 0, 0>,
69}
70
71#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
72#[serde(untagged)]
73#[serde(deny_unknown_fields)]
74/// An axis group, potentially from a merged combination of multiple input axes
75pub enum AxisGroup {
76    /// A merged combination of zero, one, or multiple input axes
77    Group(Vec<Axis>),
78    /// All remaining axes, each in a separate single-axis group
79    AllRest(Rest),
80}
81
82#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
83#[serde(untagged)]
84#[serde(deny_unknown_fields)]
85/// An axis or all remaining axes
86pub enum Axis {
87    /// A single axis, as determined by its index
88    Index(usize),
89    /// All remaining axes, combined into one
90    MergedRest(Rest),
91}
92
93impl Codec for SwizzleReshapeCodec {
94    type Error = SwizzleReshapeCodecError;
95
96    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
97        match data {
98            AnyCowArray::U8(data) => Ok(AnyArray::U8(swizzle_reshape(data, &self.axes)?)),
99            AnyCowArray::U16(data) => Ok(AnyArray::U16(swizzle_reshape(data, &self.axes)?)),
100            AnyCowArray::U32(data) => Ok(AnyArray::U32(swizzle_reshape(data, &self.axes)?)),
101            AnyCowArray::U64(data) => Ok(AnyArray::U64(swizzle_reshape(data, &self.axes)?)),
102            AnyCowArray::I8(data) => Ok(AnyArray::I8(swizzle_reshape(data, &self.axes)?)),
103            AnyCowArray::I16(data) => Ok(AnyArray::I16(swizzle_reshape(data, &self.axes)?)),
104            AnyCowArray::I32(data) => Ok(AnyArray::I32(swizzle_reshape(data, &self.axes)?)),
105            AnyCowArray::I64(data) => Ok(AnyArray::I64(swizzle_reshape(data, &self.axes)?)),
106            AnyCowArray::F32(data) => Ok(AnyArray::F32(swizzle_reshape(data, &self.axes)?)),
107            AnyCowArray::F64(data) => Ok(AnyArray::F64(swizzle_reshape(data, &self.axes)?)),
108            data => Err(SwizzleReshapeCodecError::UnsupportedDtype(data.dtype())),
109        }
110    }
111
112    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
113        match encoded {
114            AnyCowArray::U8(encoded) => {
115                Ok(AnyArray::U8(undo_swizzle_reshape(encoded, &self.axes)?))
116            }
117            AnyCowArray::U16(encoded) => {
118                Ok(AnyArray::U16(undo_swizzle_reshape(encoded, &self.axes)?))
119            }
120            AnyCowArray::U32(encoded) => {
121                Ok(AnyArray::U32(undo_swizzle_reshape(encoded, &self.axes)?))
122            }
123            AnyCowArray::U64(encoded) => {
124                Ok(AnyArray::U64(undo_swizzle_reshape(encoded, &self.axes)?))
125            }
126            AnyCowArray::I8(encoded) => {
127                Ok(AnyArray::I8(undo_swizzle_reshape(encoded, &self.axes)?))
128            }
129            AnyCowArray::I16(encoded) => {
130                Ok(AnyArray::I16(undo_swizzle_reshape(encoded, &self.axes)?))
131            }
132            AnyCowArray::I32(encoded) => {
133                Ok(AnyArray::I32(undo_swizzle_reshape(encoded, &self.axes)?))
134            }
135            AnyCowArray::I64(encoded) => {
136                Ok(AnyArray::I64(undo_swizzle_reshape(encoded, &self.axes)?))
137            }
138            AnyCowArray::F32(encoded) => {
139                Ok(AnyArray::F32(undo_swizzle_reshape(encoded, &self.axes)?))
140            }
141            AnyCowArray::F64(encoded) => {
142                Ok(AnyArray::F64(undo_swizzle_reshape(encoded, &self.axes)?))
143            }
144            encoded => Err(SwizzleReshapeCodecError::UnsupportedDtype(encoded.dtype())),
145        }
146    }
147
148    fn decode_into(
149        &self,
150        encoded: AnyArrayView,
151        decoded: AnyArrayViewMut,
152    ) -> Result<(), Self::Error> {
153        match (encoded, decoded) {
154            (AnyArrayView::U8(encoded), AnyArrayViewMut::U8(decoded)) => {
155                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
156            }
157            (AnyArrayView::U16(encoded), AnyArrayViewMut::U16(decoded)) => {
158                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
159            }
160            (AnyArrayView::U32(encoded), AnyArrayViewMut::U32(decoded)) => {
161                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
162            }
163            (AnyArrayView::U64(encoded), AnyArrayViewMut::U64(decoded)) => {
164                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
165            }
166            (AnyArrayView::I8(encoded), AnyArrayViewMut::I8(decoded)) => {
167                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
168            }
169            (AnyArrayView::I16(encoded), AnyArrayViewMut::I16(decoded)) => {
170                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
171            }
172            (AnyArrayView::I32(encoded), AnyArrayViewMut::I32(decoded)) => {
173                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
174            }
175            (AnyArrayView::I64(encoded), AnyArrayViewMut::I64(decoded)) => {
176                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
177            }
178            (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
179                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
180            }
181            (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
182                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
183            }
184            (encoded, decoded) if encoded.dtype() != decoded.dtype() => {
185                Err(SwizzleReshapeCodecError::MismatchedDecodeIntoArray {
186                    source: AnyArrayAssignError::DTypeMismatch {
187                        src: encoded.dtype(),
188                        dst: decoded.dtype(),
189                    },
190                })
191            }
192            (encoded, _decoded) => Err(SwizzleReshapeCodecError::UnsupportedDtype(encoded.dtype())),
193        }
194    }
195}
196
197impl StaticCodec for SwizzleReshapeCodec {
198    const CODEC_ID: &'static str = "swizzle-reshape.rs";
199
200    type Config<'de> = Self;
201
202    fn from_config(config: Self::Config<'_>) -> Self {
203        config
204    }
205
206    fn get_config(&self) -> StaticCodecConfig<Self> {
207        StaticCodecConfig::from(self)
208    }
209}
210
211#[derive(Debug, Error)]
212/// Errors that may occur when applying the [`SwizzleReshapeCodec`].
213pub enum SwizzleReshapeCodecError {
214    /// [`SwizzleReshapeCodec`] does not support the dtype
215    #[error("SwizzleReshape does not support the dtype {0}")]
216    UnsupportedDtype(AnyArrayDType),
217    /// [`SwizzleReshapeCodec`] cannot decode from an array with merged axes
218    /// without receiving an output array to decode into
219    #[error(
220        "SwizzleReshape cannot decode from an array with merged axes without receiving an output array to decode into"
221    )]
222    CannotDecodeMergedAxes,
223    /// [`SwizzleReshapeCodec`] cannot encode or decode with an invalid axis
224    /// `index` for an array with `ndim` dimensions
225    #[error(
226        "SwizzleReshape cannot encode or decode with an invalid axis {index} for an array with {ndim} dimensions"
227    )]
228    InvalidAxisIndex {
229        /// The out-of-bounds axis index
230        index: usize,
231        /// The number of dimensions of the array
232        ndim: usize,
233    },
234    /// [`SwizzleReshapeCodec`] can only encode or decode with an axis
235    /// permutation `axes` that contains every axis of an array with `ndim`
236    /// dimensions index exactly once
237    #[error(
238        "SwizzleReshape can only encode or decode with an axis permutation {axes:?} that contains every axis of an array with {ndim} dimensions index exactly once"
239    )]
240    InvalidAxisPermutation {
241        /// The invalid permutation of axes
242        axes: Vec<AxisGroup>,
243        /// The number of dimensions of the array
244        ndim: usize,
245    },
246    /// [`SwizzleReshapeCodec`] cannot encode or decode with an axis permutation
247    /// that contains multiple rest-axes markers
248    #[error(
249        "SwizzleReshape cannot encode or decode with an axis permutation that contains multiple rest-axes markers"
250    )]
251    MultipleRestAxes,
252    /// [`SwizzleReshapeCodec`] cannot decode into the provided array
253    #[error("SwizzleReshape cannot decode into the provided array")]
254    MismatchedDecodeIntoArray {
255        /// The source of the error
256        #[from]
257        source: AnyArrayAssignError,
258    },
259}
260
261#[expect(clippy::missing_panics_doc)]
262/// Swizzle and reshape the input `data` array with the new `axes`.
263///
264/// # Errors
265///
266/// Errors with
267/// - [`SwizzleReshapeCodecError::InvalidAxisIndex`] if any axis is out of
268///   bounds
269/// - [`SwizzleReshapeCodecError::InvalidAxisPermutation`] if the `axes`
270///   permutation does not contain every axis index exactly once
271/// - [`SwizzleReshapeCodecError::MultipleRestAxes`] if the `axes` permutation
272///   contains more than one [`Rest`]-axes marker
273pub fn swizzle_reshape<T: Copy, S: Data<Elem = T>>(
274    data: ArrayBase<S, IxDyn>,
275    axes: &[AxisGroup],
276) -> Result<Array<T, IxDyn>, SwizzleReshapeCodecError> {
277    let SwizzleReshapeAxes {
278        permutation,
279        swizzled_shape,
280        new_shape,
281    } = validate_into_axes_shape(&data, axes)?;
282
283    let swizzled: ArrayBase<S, ndarray::Dim<ndarray::IxDynImpl>> = data.permuted_axes(permutation);
284    assert_eq!(swizzled.shape(), swizzled_shape, "incorrect swizzled shape");
285
286    #[expect(clippy::expect_used)] // only panics on an implementation bug
287    let reshaped = swizzled
288        .into_owned()
289        .into_shape_clone(new_shape)
290        .expect("new encoding shape should have the correct number of elements");
291
292    Ok(reshaped)
293}
294
295/// Reverts the swizzle and reshape of the `encoded` array with the `axes` and
296/// returns the original array.
297///
298/// Since the shape of the original array is not known, only permutations of
299/// axes are supported.
300///
301/// # Errors
302///
303/// Errors with
304/// - [`SwizzleReshapeCodecError::CannotDecodeMergedAxes`] if any axes were
305///   merged and thus cannot be split without further information
306/// - [`SwizzleReshapeCodecError::InvalidAxisIndex`] if any axis is out of
307///   bounds
308/// - [`SwizzleReshapeCodecError::InvalidAxisPermutation`] if the `axes`
309///   permutation does not contain every axis index exactly once
310/// - [`SwizzleReshapeCodecError::MultipleRestAxes`] if the `axes` permutation
311///   contains more than one [`Rest`]-axes marker
312pub fn undo_swizzle_reshape<T: Copy, S: Data<Elem = T>>(
313    encoded: ArrayBase<S, IxDyn>,
314    axes: &[AxisGroup],
315) -> Result<Array<T, IxDyn>, SwizzleReshapeCodecError> {
316    if !axes.iter().all(|axis| match axis {
317        AxisGroup::Group(axes) => matches!(axes.as_slice(), [Axis::Index(_)]),
318        AxisGroup::AllRest(Rest) => true,
319    }) {
320        return Err(SwizzleReshapeCodecError::CannotDecodeMergedAxes);
321    }
322
323    let SwizzleReshapeAxes { permutation, .. } = validate_into_axes_shape(&encoded, axes)?;
324
325    let mut inverse_permutation = vec![0; permutation.len()];
326    #[expect(clippy::indexing_slicing)] // all are guaranteed to be in range
327    for (i, p) in permutation.into_iter().enumerate() {
328        inverse_permutation[p] = i;
329    }
330
331    // since no axes were merged, no reshape is needed
332    let unshaped = encoded;
333    let unswizzled = unshaped.permuted_axes(inverse_permutation);
334
335    Ok(unswizzled.into_owned())
336}
337
338#[expect(clippy::missing_panics_doc)]
339#[expect(clippy::needless_pass_by_value)]
340/// Reverts the swizzle and reshape of the `encoded` array with the `axes` and
341/// outputs it into the `decoded` array.
342///
343/// # Errors
344///
345/// Errors with
346/// - [`SwizzleReshapeCodecError::InvalidAxisIndex`] if any axis is out of
347///   bounds
348/// - [`SwizzleReshapeCodecError::InvalidAxisPermutation`] if the `axes`
349///   permutation does not contain every axis index exactly once
350/// - [`SwizzleReshapeCodecError::MultipleRestAxes`] if the `axes` permutation
351///   contains more than one [`Rest`]-axes marker
352/// - [`SwizzleReshapeCodecError::MismatchedDecodeIntoArray`] if the `encoded`
353///   array's shape does not match the shape that swizzling and reshaping an
354///   array of the `decoded` array's shape would have produced
355pub fn undo_swizzle_reshape_into<T: Copy>(
356    encoded: ArrayView<T, IxDyn>,
357    mut decoded: ArrayViewMut<T, IxDyn>,
358    axes: &[AxisGroup],
359) -> Result<(), SwizzleReshapeCodecError> {
360    let SwizzleReshapeAxes {
361        permutation,
362        swizzled_shape,
363        new_shape,
364    } = validate_into_axes_shape(&decoded, axes)?;
365
366    if encoded.shape() != new_shape {
367        return Err(SwizzleReshapeCodecError::MismatchedDecodeIntoArray {
368            source: AnyArrayAssignError::ShapeMismatch {
369                src: encoded.shape().to_vec(),
370                dst: new_shape,
371            },
372        });
373    }
374
375    let mut inverse_permutation = vec![0; decoded.ndim()];
376    #[expect(clippy::indexing_slicing)] // all are guaranteed to be in range
377    for (i, p) in permutation.into_iter().enumerate() {
378        inverse_permutation[p] = i;
379    }
380
381    #[expect(clippy::expect_used)] // only panics on an implementation bug
382    let unshaped = encoded
383        .to_shape(swizzled_shape)
384        .expect("new decoding shape should have the correct number of elements");
385    let unswizzled = unshaped.permuted_axes(inverse_permutation);
386
387    decoded.assign(&unswizzled);
388
389    Ok(())
390}
391
392struct SwizzleReshapeAxes {
393    permutation: Vec<usize>,
394    swizzled_shape: Vec<usize>,
395    new_shape: Vec<usize>,
396}
397
398fn validate_into_axes_shape<T, S: Data<Elem = T>>(
399    array: &ArrayBase<S, IxDyn>,
400    axes: &[AxisGroup],
401) -> Result<SwizzleReshapeAxes, SwizzleReshapeCodecError> {
402    // counts of each axis index, used to check for missing or duplicate axes,
403    //  and for knowing which axes are caught by the rest catch-all
404    let mut axis_index_counts = vec![0_usize; array.ndim()];
405
406    let mut has_rest = false;
407
408    // validate that all axis indices are in bounds and that there is at most
409    //  one catch-all remaining axes marker
410    for group in axes {
411        match group {
412            AxisGroup::Group(axes) => {
413                for axis in axes {
414                    match axis {
415                        Axis::Index(index) => {
416                            if let Some(c) = axis_index_counts.get_mut(*index) {
417                                *c += 1;
418                            } else {
419                                return Err(SwizzleReshapeCodecError::InvalidAxisIndex {
420                                    index: *index,
421                                    ndim: array.ndim(),
422                                });
423                            }
424                        }
425                        Axis::MergedRest(Rest) => {
426                            if std::mem::replace(&mut has_rest, true) {
427                                return Err(SwizzleReshapeCodecError::MultipleRestAxes);
428                            }
429                        }
430                    }
431                }
432            }
433            AxisGroup::AllRest(Rest) => {
434                if std::mem::replace(&mut has_rest, true) {
435                    return Err(SwizzleReshapeCodecError::MultipleRestAxes);
436                }
437            }
438        }
439    }
440
441    // check that each axis is mentioned
442    // - exactly once if no catch-all is used
443    // - at most once if a catch-all is used
444    if !axis_index_counts
445        .iter()
446        .all(|c| if has_rest { *c <= 1 } else { *c == 1 })
447    {
448        return Err(SwizzleReshapeCodecError::InvalidAxisPermutation {
449            axes: axes.to_vec(),
450            ndim: array.ndim(),
451        });
452    }
453
454    // the permutation to apply to the input axes
455    let mut axis_permutation = Vec::with_capacity(array.len());
456    // the shape of the already permuted intermediary array
457    let mut permuted_shape = Vec::with_capacity(array.len());
458    // the shape of the already permuted and grouped output array
459    let mut grouped_shape = Vec::with_capacity(axes.len());
460
461    for axis in axes {
462        match axis {
463            // a group merged all of its axes
464            // an empty group adds an additional axis of size 1
465            AxisGroup::Group(axes) => {
466                let mut new_len = 1;
467                for axis in axes {
468                    match axis {
469                        Axis::Index(index) => {
470                            axis_permutation.push(*index);
471                            permuted_shape.push(array.len_of(ndarray::Axis(*index)));
472                            new_len *= array.len_of(ndarray::Axis(*index));
473                        }
474                        Axis::MergedRest(Rest) => {
475                            for (index, count) in axis_index_counts.iter().enumerate() {
476                                if *count == 0 {
477                                    axis_permutation.push(index);
478                                    permuted_shape.push(array.len_of(ndarray::Axis(index)));
479                                    new_len *= array.len_of(ndarray::Axis(index));
480                                }
481                            }
482                        }
483                    }
484                }
485                grouped_shape.push(new_len);
486            }
487            AxisGroup::AllRest(Rest) => {
488                for (index, count) in axis_index_counts.iter().enumerate() {
489                    if *count == 0 {
490                        axis_permutation.push(index);
491                        permuted_shape.push(array.len_of(ndarray::Axis(index)));
492                        grouped_shape.push(array.len_of(ndarray::Axis(index)));
493                    }
494                }
495            }
496        }
497    }
498
499    Ok(SwizzleReshapeAxes {
500        permutation: axis_permutation,
501        swizzled_shape: permuted_shape,
502        new_shape: grouped_shape,
503    })
504}
505
#[derive(Clone, Copy, Debug)]
/// Catch-all marker for every axis that is not named explicitly by its index
pub struct Rest;
509
510impl Serialize for Rest {
511    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
512        serializer.serialize_map(Some(0))?.end()
513    }
514}
515
516impl<'de> Deserialize<'de> for Rest {
517    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
518        struct RestVisitor;
519
520        impl<'de> Visitor<'de> for RestVisitor {
521            type Value = Rest;
522
523            fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
524                fmt.write_str("an empty map")
525            }
526
527            fn visit_map<A: MapAccess<'de>>(self, _map: A) -> Result<Self::Value, A::Error> {
528                Ok(Rest)
529            }
530        }
531
532        deserializer.deserialize_map(RestVisitor)
533    }
534}
535
536impl JsonSchema for Rest {
537    fn schema_name() -> Cow<'static, str> {
538        Cow::Borrowed("Rest")
539    }
540
541    fn schema_id() -> Cow<'static, str> {
542        Cow::Borrowed(concat!(module_path!(), "::", "Rest"))
543    }
544
545    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
546        json_schema!({
547            "type": "object",
548            "properties": {},
549            "additionalProperties": false,
550        })
551    }
552}
553
#[cfg(test)]
#[expect(clippy::expect_used)]
mod tests {
    use ndarray::array;

    use super::*;

    #[test]
    fn identity() {
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[AxisGroup::AllRest(Rest)],
            &[2, 2, 2],
        );

        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::Index(0)]),
                AxisGroup::Group(vec![Axis::Index(1)]),
                AxisGroup::AllRest(Rest),
            ],
            &[2, 2, 2],
        );
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::Index(0)]),
                AxisGroup::AllRest(Rest),
                AxisGroup::Group(vec![Axis::Index(2)]),
            ],
            &[2, 2, 2],
        );
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::AllRest(Rest),
                AxisGroup::Group(vec![Axis::Index(1)]),
                AxisGroup::Group(vec![Axis::Index(2)]),
            ],
            &[2, 2, 2],
        );

        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::Index(0)]),
                AxisGroup::Group(vec![Axis::Index(1)]),
                AxisGroup::Group(vec![Axis::MergedRest(Rest)]),
            ],
            &[2, 2, 2],
        );
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::Index(0)]),
                AxisGroup::Group(vec![Axis::MergedRest(Rest)]),
                AxisGroup::Group(vec![Axis::Index(2)]),
            ],
            &[2, 2, 2],
        );
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::MergedRest(Rest)]),
                AxisGroup::Group(vec![Axis::Index(1)]),
                AxisGroup::Group(vec![Axis::Index(2)]),
            ],
            &[2, 2, 2],
        );
    }

    #[test]
    fn swizzle() {
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::Index(0)]),
                AxisGroup::Group(vec![Axis::Index(1)]),
                AxisGroup::AllRest(Rest),
            ],
            &[2, 2, 2],
        );
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::Index(2)]),
                AxisGroup::AllRest(Rest),
                AxisGroup::Group(vec![Axis::Index(1)]),
            ],
            &[2, 2, 2],
        );
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::AllRest(Rest),
                AxisGroup::Group(vec![Axis::Index(0)]),
                AxisGroup::Group(vec![Axis::Index(1)]),
            ],
            &[2, 2, 2],
        );

        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::Index(0)]),
                AxisGroup::Group(vec![Axis::Index(1)]),
                AxisGroup::Group(vec![Axis::MergedRest(Rest)]),
            ],
            &[2, 2, 2],
        );
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::Index(2)]),
                AxisGroup::Group(vec![Axis::MergedRest(Rest)]),
                AxisGroup::Group(vec![Axis::Index(1)]),
            ],
            &[2, 2, 2],
        );
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::MergedRest(Rest)]),
                AxisGroup::Group(vec![Axis::Index(0)]),
                AxisGroup::Group(vec![Axis::Index(1)]),
            ],
            &[2, 2, 2],
        );

        let mut i = 0;
        roundtrip(
            Array::from_shape_fn([3, 1440, 721, 1, 1], |_| {
                i += 1;
                i
            })
            .into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::Index(4)]),
                AxisGroup::Group(vec![Axis::Index(0)]),
                AxisGroup::Group(vec![Axis::Index(3)]),
                AxisGroup::Group(vec![Axis::Index(2)]),
                AxisGroup::Group(vec![Axis::Index(1)]),
            ],
            &[1, 3, 1, 721, 1440],
        );
    }

    #[test]
    fn collapse() {
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[AxisGroup::Group(vec![Axis::MergedRest(Rest)])],
            &[8],
        );

        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[AxisGroup::Group(vec![
                Axis::Index(0),
                Axis::Index(1),
                Axis::Index(2),
            ])],
            &[8],
        );
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[AxisGroup::Group(vec![
                Axis::Index(2),
                Axis::Index(1),
                Axis::Index(0),
            ])],
            &[8],
        );
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[AxisGroup::Group(vec![
                Axis::Index(1),
                Axis::MergedRest(Rest),
            ])],
            &[8],
        );

        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::Index(0), Axis::Index(1)]),
                AxisGroup::AllRest(Rest),
            ],
            &[4, 2],
        );
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::Index(2)]),
                AxisGroup::Group(vec![Axis::Index(1), Axis::Index(0)]),
            ],
            &[2, 4],
        );
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![Axis::Index(1), Axis::MergedRest(Rest)]),
                AxisGroup::Group(vec![Axis::Index(0), Axis::Index(2)]),
            ],
            &[2, 4],
        );

        roundtrip(
            array![[1, 2], [3, 4], [5, 6], [7, 8]].into_dyn(),
            &[AxisGroup::Group(vec![Axis::MergedRest(Rest)])],
            &[8],
        );

        let mut i = 0;
        roundtrip(
            Array::from_shape_fn([3, 1440, 721, 1, 1], |_| {
                i += 1;
                i
            })
            .into_dyn(),
            &[AxisGroup::Group(vec![
                Axis::Index(4),
                Axis::Index(0),
                Axis::Index(3),
                Axis::Index(2),
                Axis::Index(1),
            ])],
            &[1 * 3 * 1 * 721 * 1440],
        );
    }

    #[test]
    fn extend() {
        roundtrip(
            array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
            &[
                AxisGroup::Group(vec![]),
                AxisGroup::Group(vec![Axis::Index(0)]),
                AxisGroup::Group(vec![]),
                AxisGroup::AllRest(Rest),
                AxisGroup::Group(vec![]),
                AxisGroup::Group(vec![Axis::Index(2)]),
                AxisGroup::Group(vec![]),
            ],
            &[1, 2, 1, 2, 1, 2, 1],
        );
    }

    #[test]
    fn invalid_axes() {
        let data = array![[1, 2], [3, 4]].into_dyn();

        // the axis index 2 is out of bounds for a two-dimensional array
        assert!(matches!(
            swizzle_reshape(data.view(), &[AxisGroup::Group(vec![Axis::Index(2)])]),
            Err(SwizzleReshapeCodecError::InvalidAxisIndex { index: 2, ndim: 2 })
        ));

        // the second axis is never named and there is no catch-all marker
        assert!(matches!(
            swizzle_reshape(data.view(), &[AxisGroup::Group(vec![Axis::Index(0)])]),
            Err(SwizzleReshapeCodecError::InvalidAxisPermutation { ndim: 2, .. })
        ));

        // the first axis is named twice, which a catch-all marker cannot fix
        assert!(matches!(
            swizzle_reshape(
                data.view(),
                &[
                    AxisGroup::Group(vec![Axis::Index(0)]),
                    AxisGroup::Group(vec![Axis::Index(0)]),
                    AxisGroup::AllRest(Rest),
                ]
            ),
            Err(SwizzleReshapeCodecError::InvalidAxisPermutation { ndim: 2, .. })
        ));

        // at most one catch-all marker is allowed
        assert!(matches!(
            swizzle_reshape(
                data.view(),
                &[
                    AxisGroup::AllRest(Rest),
                    AxisGroup::Group(vec![Axis::MergedRest(Rest)]),
                ]
            ),
            Err(SwizzleReshapeCodecError::MultipleRestAxes)
        ));
    }

    #[test]
    fn mismatched_decode_into() {
        let encoded = array![1, 2, 3].into_dyn();
        let mut decoded = array![[0, 0], [0, 0]].into_dyn();

        // merging both axes of a 2x2 array produces four elements, not three
        assert!(matches!(
            undo_swizzle_reshape_into(
                encoded.view(),
                decoded.view_mut(),
                &[AxisGroup::Group(vec![Axis::Index(0), Axis::Index(1)])],
            ),
            Err(SwizzleReshapeCodecError::MismatchedDecodeIntoArray {
                source: AnyArrayAssignError::ShapeMismatch { .. }
            })
        ));
    }

    /// Encodes `data` with the `axes` and checks the `swizzle_shape` of the
    /// encoded array. It then checks that decoding into an output array
    /// restores `data`, and that decoding without an output array either
    /// restores `data` (pure permutation) or fails (merged/inserted axes).
    #[expect(clippy::needless_pass_by_value)]
    fn roundtrip(data: Array<i32, IxDyn>, axes: &[AxisGroup], swizzle_shape: &[usize]) {
        let swizzled = swizzle_reshape(data.view(), axes).expect("swizzle should not fail");

        assert_eq!(swizzled.shape(), swizzle_shape);

        let mut unswizzled = Array::zeros(data.shape());
        undo_swizzle_reshape_into(swizzled.view(), unswizzled.view_mut(), axes)
            .expect("unswizzle into should not fail");

        assert_eq!(data, unswizzled);

        // decoding without an output array only works if every group is a
        //  single axis index or the all-rest catch-all
        let is_pure_permutation = axes.iter().all(|group| match group {
            AxisGroup::Group(members) => matches!(members.as_slice(), [Axis::Index(_)]),
            AxisGroup::AllRest(Rest) => true,
        });

        if is_pure_permutation {
            let unswizzled =
                undo_swizzle_reshape(swizzled.view(), axes).expect("unswizzle should not fail");
            assert_eq!(data, unswizzled);
        } else {
            undo_swizzle_reshape(swizzled.view(), axes).expect_err("unswizzle should fail");
        }
    }
}