numcodecs_swizzle_reshape/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-swizzle-reshape
10//! [crates.io]: https://crates.io/crates/numcodecs-swizzle-reshape
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-swizzle-reshape
13//! [docs.rs]: https://docs.rs/numcodecs-swizzle-reshape/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_swizzle_reshape
17//!
18//! Array axis swizzle and reshape codec implementation for the [`numcodecs`]
19//! API.
20
21use std::{
22    borrow::Cow,
23    fmt::{self, Debug},
24};
25
26use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, IxDyn};
27use numcodecs::{
28    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
29    Codec, StaticCodec, StaticCodecConfig,
30};
31use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
32use serde::{
33    de::{MapAccess, Visitor},
34    ser::SerializeMap,
35    Deserialize, Deserializer, Serialize, Serializer,
36};
37use thiserror::Error;
38
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
/// Codec to swizzle/swap the axes of an array and reshape it.
///
/// This codec does not store metadata about the original shape of the array.
/// Since axes that have been combined during encoding cannot be split without
/// further information, decoding may fail if an output array is not provided.
///
/// Swizzling axes is always supported since no additional information about the
/// array's shape is required to reconstruct it.
pub struct SwizzleReshapeCodec {
    /// The permutation of the axes that is applied on encoding.
    ///
    /// The permutation is given as a list of axis groups, where each group
    /// corresponds to one encoded output axis that may consist of several
    /// decoded input axes. Axes are referred to by their zero-based index.
    /// For instance, `[[0], [1, 2]]` flattens a three-dimensional array into
    /// a two-dimensional one by combining the second and third axes. An
    /// empty group `[]` inserts an additional axis of length one.
    ///
    /// The permutation also allows specifying a special catch-all remaining
    /// axes marker, which may be used at most once:
    /// - `[[1], {}]` moves the second axis to be the first and appends all
    ///   other axes afterwards, i.e. the encoded array has the same number
    ///   of axes as the input array
    /// - `[[1], [{}]]` in contrast collapses all other axes into one, i.e.
    ///   the encoded array is two-dimensional
    pub axes: Vec<AxisGroup>,
}
67
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[serde(deny_unknown_fields)]
/// An axis group, potentially from a merged combination of multiple input axes
///
/// Every [`AxisGroup::Group`] produces exactly one axis in the encoded array,
/// while [`AxisGroup::AllRest`] produces one encoded axis per remaining input
/// axis.
pub enum AxisGroup {
    /// A merged combination of zero, one, or multiple input axes
    ///
    /// The length of the encoded axis is the product of the lengths of all
    /// axes in the group, i.e. an empty group adds an axis of length one.
    Group(Vec<Axis>),
    /// All remaining axes, each in a separate single-axis group
    ///
    /// The remaining axes are all axes that are not explicitly named by index
    /// anywhere in the permutation, in order of increasing axis index.
    AllRest(Rest),
}
78
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[serde(deny_unknown_fields)]
/// An axis or all remaining axes
///
/// This is one member of an [`AxisGroup::Group`], i.e. one contribution to a
/// single merged axis of the encoded array.
pub enum Axis {
    /// A single axis, as determined by its index
    ///
    /// The index is zero-based and must be smaller than the number of
    /// dimensions of the array.
    Index(usize),
    /// All remaining axes, combined into one
    ///
    /// The remaining axes are all axes that are not explicitly named by index
    /// anywhere in the permutation, in order of increasing axis index.
    MergedRest(Rest),
}
89
90impl Codec for SwizzleReshapeCodec {
91    type Error = SwizzleReshapeCodecError;
92
93    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
94        match data {
95            AnyCowArray::U8(data) => Ok(AnyArray::U8(swizzle_reshape(data, &self.axes)?)),
96            AnyCowArray::U16(data) => Ok(AnyArray::U16(swizzle_reshape(data, &self.axes)?)),
97            AnyCowArray::U32(data) => Ok(AnyArray::U32(swizzle_reshape(data, &self.axes)?)),
98            AnyCowArray::U64(data) => Ok(AnyArray::U64(swizzle_reshape(data, &self.axes)?)),
99            AnyCowArray::I8(data) => Ok(AnyArray::I8(swizzle_reshape(data, &self.axes)?)),
100            AnyCowArray::I16(data) => Ok(AnyArray::I16(swizzle_reshape(data, &self.axes)?)),
101            AnyCowArray::I32(data) => Ok(AnyArray::I32(swizzle_reshape(data, &self.axes)?)),
102            AnyCowArray::I64(data) => Ok(AnyArray::I64(swizzle_reshape(data, &self.axes)?)),
103            AnyCowArray::F32(data) => Ok(AnyArray::F32(swizzle_reshape(data, &self.axes)?)),
104            AnyCowArray::F64(data) => Ok(AnyArray::F64(swizzle_reshape(data, &self.axes)?)),
105            data => Err(SwizzleReshapeCodecError::UnsupportedDtype(data.dtype())),
106        }
107    }
108
109    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
110        match encoded {
111            AnyCowArray::U8(encoded) => {
112                Ok(AnyArray::U8(undo_swizzle_reshape(encoded, &self.axes)?))
113            }
114            AnyCowArray::U16(encoded) => {
115                Ok(AnyArray::U16(undo_swizzle_reshape(encoded, &self.axes)?))
116            }
117            AnyCowArray::U32(encoded) => {
118                Ok(AnyArray::U32(undo_swizzle_reshape(encoded, &self.axes)?))
119            }
120            AnyCowArray::U64(encoded) => {
121                Ok(AnyArray::U64(undo_swizzle_reshape(encoded, &self.axes)?))
122            }
123            AnyCowArray::I8(encoded) => {
124                Ok(AnyArray::I8(undo_swizzle_reshape(encoded, &self.axes)?))
125            }
126            AnyCowArray::I16(encoded) => {
127                Ok(AnyArray::I16(undo_swizzle_reshape(encoded, &self.axes)?))
128            }
129            AnyCowArray::I32(encoded) => {
130                Ok(AnyArray::I32(undo_swizzle_reshape(encoded, &self.axes)?))
131            }
132            AnyCowArray::I64(encoded) => {
133                Ok(AnyArray::I64(undo_swizzle_reshape(encoded, &self.axes)?))
134            }
135            AnyCowArray::F32(encoded) => {
136                Ok(AnyArray::F32(undo_swizzle_reshape(encoded, &self.axes)?))
137            }
138            AnyCowArray::F64(encoded) => {
139                Ok(AnyArray::F64(undo_swizzle_reshape(encoded, &self.axes)?))
140            }
141            encoded => Err(SwizzleReshapeCodecError::UnsupportedDtype(encoded.dtype())),
142        }
143    }
144
145    fn decode_into(
146        &self,
147        encoded: AnyArrayView,
148        decoded: AnyArrayViewMut,
149    ) -> Result<(), Self::Error> {
150        match (encoded, decoded) {
151            (AnyArrayView::U8(encoded), AnyArrayViewMut::U8(decoded)) => {
152                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
153            }
154            (AnyArrayView::U16(encoded), AnyArrayViewMut::U16(decoded)) => {
155                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
156            }
157            (AnyArrayView::U32(encoded), AnyArrayViewMut::U32(decoded)) => {
158                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
159            }
160            (AnyArrayView::U64(encoded), AnyArrayViewMut::U64(decoded)) => {
161                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
162            }
163            (AnyArrayView::I8(encoded), AnyArrayViewMut::I8(decoded)) => {
164                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
165            }
166            (AnyArrayView::I16(encoded), AnyArrayViewMut::I16(decoded)) => {
167                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
168            }
169            (AnyArrayView::I32(encoded), AnyArrayViewMut::I32(decoded)) => {
170                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
171            }
172            (AnyArrayView::I64(encoded), AnyArrayViewMut::I64(decoded)) => {
173                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
174            }
175            (AnyArrayView::F32(encoded), AnyArrayViewMut::F32(decoded)) => {
176                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
177            }
178            (AnyArrayView::F64(encoded), AnyArrayViewMut::F64(decoded)) => {
179                undo_swizzle_reshape_into(encoded, decoded, &self.axes)
180            }
181            (encoded, decoded) if encoded.dtype() != decoded.dtype() => {
182                Err(SwizzleReshapeCodecError::MismatchedDecodeIntoArray {
183                    source: AnyArrayAssignError::DTypeMismatch {
184                        src: encoded.dtype(),
185                        dst: decoded.dtype(),
186                    },
187                })
188            }
189            (encoded, _decoded) => Err(SwizzleReshapeCodecError::UnsupportedDtype(encoded.dtype())),
190        }
191    }
192}
193
impl StaticCodec for SwizzleReshapeCodec {
    // the identifier under which this codec is known
    const CODEC_ID: &'static str = "swizzle-reshape";

    // the codec is its own configuration, see its (de)serialization derives
    type Config<'de> = Self;

    // since the configuration type is the codec itself, it is returned as-is
    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    // the configuration is created from a reference to this codec
    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}
207
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`SwizzleReshapeCodec`].
pub enum SwizzleReshapeCodecError {
    /// [`SwizzleReshapeCodec`] does not support the dtype
    #[error("SwizzleReshape does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`SwizzleReshapeCodec`] cannot decode from an array with merged axes
    /// without receiving an output array to decode into
    #[error("SwizzleReshape cannot decode from an array with merged axes without receiving an output array to decode into")]
    CannotDecodeMergedAxes,
    /// [`SwizzleReshapeCodec`] cannot encode or decode with an invalid axis
    /// `index` for an array with `ndim` dimensions
    #[error("SwizzleReshape cannot encode or decode with an invalid axis {index} for an array with {ndim} dimensions")]
    InvalidAxisIndex {
        /// The out-of-bounds axis index
        index: usize,
        /// The number of dimensions of the array
        ndim: usize,
    },
    /// [`SwizzleReshapeCodec`] can only encode or decode with an axis
    /// permutation `axes` that contains every axis index of an array with
    /// `ndim` dimensions exactly once
    ///
    /// If the permutation contains a rest-axes marker, every axis index may
    /// instead be named at most once.
    #[error("SwizzleReshape can only encode or decode with an axis permutation {axes:?} that contains every axis of an array with {ndim} dimensions index exactly once")]
    InvalidAxisPermutation {
        /// The invalid permutation of axes
        axes: Vec<AxisGroup>,
        /// The number of dimensions of the array
        ndim: usize,
    },
    /// [`SwizzleReshapeCodec`] cannot encode or decode with an axis permutation
    /// that contains multiple rest-axes markers
    #[error("SwizzleReshape cannot encode or decode with an axis permutation that contains multiple rest-axes markers")]
    MultipleRestAxes,
    /// [`SwizzleReshapeCodec`] cannot decode into the provided array
    ///
    /// This covers both mismatching dtypes and mismatching shapes.
    #[error("SwizzleReshape cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
249
250#[expect(clippy::missing_panics_doc)]
251/// Swizzle and reshape the input `data` array with the new `axes`.
252///
253/// # Errors
254///
255/// Errors with
256/// - [`SwizzleReshapeCodecError::InvalidAxisIndex`] if any axis is out of
257///   bounds
258/// - [`SwizzleReshapeCodecError::InvalidAxisPermutation`] if the `axes`
259///   permutation does not contain every axis index exactly once
260/// - [`SwizzleReshapeCodecError::MultipleRestAxes`] if the `axes` permutation
261///   contains more than one [`Rest`]-axes marker
262pub fn swizzle_reshape<T: Copy, S: Data<Elem = T>>(
263    data: ArrayBase<S, IxDyn>,
264    axes: &[AxisGroup],
265) -> Result<Array<T, IxDyn>, SwizzleReshapeCodecError> {
266    let SwizzleReshapeAxes {
267        permutation,
268        swizzled_shape,
269        new_shape,
270    } = validate_into_axes_shape(&data, axes)?;
271
272    let swizzled: ArrayBase<S, ndarray::Dim<ndarray::IxDynImpl>> = data.permuted_axes(permutation);
273    assert_eq!(swizzled.shape(), swizzled_shape, "incorrect swizzled shape");
274
275    #[expect(clippy::expect_used)] // only panics on an implementation bug
276    let reshaped = swizzled
277        .into_owned()
278        .into_shape_clone(new_shape)
279        .expect("new encoding shape should have the correct number of elements");
280
281    Ok(reshaped)
282}
283
284/// Reverts the swizzle and reshape of the `encoded` array with the `axes` and
285/// returns the original array.
286///
287/// Since the shape of the original array is not known, only permutations of
288/// axes are supported.
289///
290/// # Errors
291///
292/// Errors with
293/// - [`SwizzleReshapeCodecError::CannotDecodeMergedAxes`] if any axes were
294///   merged and thus cannot be split without further information
295/// - [`SwizzleReshapeCodecError::InvalidAxisIndex`] if any axis is out of
296///   bounds
297/// - [`SwizzleReshapeCodecError::InvalidAxisPermutation`] if the `axes`
298///   permutation does not contain every axis index exactly once
299/// - [`SwizzleReshapeCodecError::MultipleRestAxes`] if the `axes` permutation
300///   contains more than one [`Rest`]-axes marker
301pub fn undo_swizzle_reshape<T: Copy, S: Data<Elem = T>>(
302    encoded: ArrayBase<S, IxDyn>,
303    axes: &[AxisGroup],
304) -> Result<Array<T, IxDyn>, SwizzleReshapeCodecError> {
305    if !axes.iter().all(|axis| match axis {
306        AxisGroup::Group(axes) => matches!(axes.as_slice(), [Axis::Index(_)]),
307        AxisGroup::AllRest(Rest) => true,
308    }) {
309        return Err(SwizzleReshapeCodecError::CannotDecodeMergedAxes);
310    }
311
312    let SwizzleReshapeAxes { permutation, .. } = validate_into_axes_shape(&encoded, axes)?;
313
314    let mut inverse_permutation = vec![0; permutation.len()];
315    #[expect(clippy::indexing_slicing)] // all are guaranteed to be in range
316    for (i, p) in permutation.into_iter().enumerate() {
317        inverse_permutation[p] = i;
318    }
319
320    // since no axes were merged, no reshape is needed
321    let unshaped = encoded;
322    let unswizzled = unshaped.permuted_axes(inverse_permutation);
323
324    Ok(unswizzled.into_owned())
325}
326
327#[expect(clippy::missing_panics_doc)]
328#[expect(clippy::needless_pass_by_value)]
329/// Reverts the swizzle and reshape of the `encoded` array with the `axes` and
330/// outputs it into the `decoded` array.
331///
332/// # Errors
333///
334/// Errors with
335/// - [`SwizzleReshapeCodecError::InvalidAxisIndex`] if any axis is out of
336///   bounds
337/// - [`SwizzleReshapeCodecError::InvalidAxisPermutation`] if the `axes`
338///   permutation does not contain every axis index exactly once
339/// - [`SwizzleReshapeCodecError::MultipleRestAxes`] if the `axes` permutation
340///   contains more than one [`Rest`]-axes marker
341/// - [`SwizzleReshapeCodecError::MismatchedDecodeIntoArray`] if the `encoded`
342///   array's shape does not match the shape that swizzling and reshaping an
343///   array of the `decoded` array's shape would have produced
344pub fn undo_swizzle_reshape_into<T: Copy>(
345    encoded: ArrayView<T, IxDyn>,
346    mut decoded: ArrayViewMut<T, IxDyn>,
347    axes: &[AxisGroup],
348) -> Result<(), SwizzleReshapeCodecError> {
349    let SwizzleReshapeAxes {
350        permutation,
351        swizzled_shape,
352        new_shape,
353    } = validate_into_axes_shape(&decoded, axes)?;
354
355    if encoded.shape() != new_shape {
356        return Err(SwizzleReshapeCodecError::MismatchedDecodeIntoArray {
357            source: AnyArrayAssignError::ShapeMismatch {
358                src: encoded.shape().to_vec(),
359                dst: new_shape,
360            },
361        });
362    }
363
364    let mut inverse_permutation = vec![0; decoded.ndim()];
365    #[expect(clippy::indexing_slicing)] // all are guaranteed to be in range
366    for (i, p) in permutation.into_iter().enumerate() {
367        inverse_permutation[p] = i;
368    }
369
370    #[expect(clippy::expect_used)] // only panics on an implementation bug
371    let unshaped = encoded
372        .to_shape(swizzled_shape)
373        .expect("new decoding shape should have the correct number of elements");
374    let unswizzled = unshaped.permuted_axes(inverse_permutation);
375
376    decoded.assign(&unswizzled);
377
378    Ok(())
379}
380
/// The validated axes of a swizzle and reshape for one specific array, as
/// produced by [`validate_into_axes_shape`].
struct SwizzleReshapeAxes {
    /// The input axis index for every axis of the swizzled array, where all
    /// rest-axes markers have been resolved into explicit axis indices
    permutation: Vec<usize>,
    /// The shape of the array after only permuting its axes, before any axes
    /// are merged
    swizzled_shape: Vec<usize>,
    /// The shape of the array after permuting and merging its axes, i.e. the
    /// shape of the encoded array
    new_shape: Vec<usize>,
}
386
387fn validate_into_axes_shape<T, S: Data<Elem = T>>(
388    array: &ArrayBase<S, IxDyn>,
389    axes: &[AxisGroup],
390) -> Result<SwizzleReshapeAxes, SwizzleReshapeCodecError> {
391    // counts of each axis index, used to check for missing or duplicate axes,
392    //  and for knowing which axes are caught by the rest catch-all
393    let mut axis_index_counts = vec![0_usize; array.ndim()];
394
395    let mut has_rest = false;
396
397    // validate that all axis indices are in bounds and that there is at most
398    //  one catch-all remaining axes marker
399    for group in axes {
400        match group {
401            AxisGroup::Group(axes) => {
402                for axis in axes {
403                    match axis {
404                        Axis::Index(index) => {
405                            if let Some(c) = axis_index_counts.get_mut(*index) {
406                                *c += 1;
407                            } else {
408                                return Err(SwizzleReshapeCodecError::InvalidAxisIndex {
409                                    index: *index,
410                                    ndim: array.ndim(),
411                                });
412                            }
413                        }
414                        Axis::MergedRest(Rest) => {
415                            if std::mem::replace(&mut has_rest, true) {
416                                return Err(SwizzleReshapeCodecError::MultipleRestAxes);
417                            }
418                        }
419                    }
420                }
421            }
422            AxisGroup::AllRest(Rest) => {
423                if std::mem::replace(&mut has_rest, true) {
424                    return Err(SwizzleReshapeCodecError::MultipleRestAxes);
425                }
426            }
427        }
428    }
429
430    // check that each axis is mentioned
431    // - exactly once if no catch-all is used
432    // - at most once if a catch-all is used
433    if !axis_index_counts
434        .iter()
435        .all(|c| if has_rest { *c <= 1 } else { *c == 1 })
436    {
437        return Err(SwizzleReshapeCodecError::InvalidAxisPermutation {
438            axes: axes.to_vec(),
439            ndim: array.ndim(),
440        });
441    }
442
443    // the permutation to apply to the input axes
444    let mut axis_permutation = Vec::with_capacity(array.len());
445    // the shape of the already permuted intermediary array
446    let mut permuted_shape = Vec::with_capacity(array.len());
447    // the shape of the already permuted and grouped output array
448    let mut grouped_shape = Vec::with_capacity(axes.len());
449
450    for axis in axes {
451        match axis {
452            // a group merged all of its axes
453            // an empty group adds an additional axis of size 1
454            AxisGroup::Group(axes) => {
455                let mut new_len = 1;
456                for axis in axes {
457                    match axis {
458                        Axis::Index(index) => {
459                            axis_permutation.push(*index);
460                            permuted_shape.push(array.len_of(ndarray::Axis(*index)));
461                            new_len *= array.len_of(ndarray::Axis(*index));
462                        }
463                        Axis::MergedRest(Rest) => {
464                            for (index, count) in axis_index_counts.iter().enumerate() {
465                                if *count == 0 {
466                                    axis_permutation.push(index);
467                                    permuted_shape.push(array.len_of(ndarray::Axis(index)));
468                                    new_len *= array.len_of(ndarray::Axis(index));
469                                }
470                            }
471                        }
472                    }
473                }
474                grouped_shape.push(new_len);
475            }
476            AxisGroup::AllRest(Rest) => {
477                for (index, count) in axis_index_counts.iter().enumerate() {
478                    if *count == 0 {
479                        axis_permutation.push(index);
480                        permuted_shape.push(array.len_of(ndarray::Axis(index)));
481                        grouped_shape.push(array.len_of(ndarray::Axis(index)));
482                    }
483                }
484            }
485        }
486    }
487
488    Ok(SwizzleReshapeAxes {
489        permutation: axis_permutation,
490        swizzled_shape: permuted_shape,
491        new_shape: grouped_shape,
492    })
493}
494
#[derive(Copy, Clone, Debug)]
/// Marker to signify all remaining (not explicitly named) axes
///
/// The marker is serialized as an empty map, e.g. `{}` in JSON.
pub struct Rest;
498
499impl Serialize for Rest {
500    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
501        serializer.serialize_map(Some(0))?.end()
502    }
503}
504
505impl<'de> Deserialize<'de> for Rest {
506    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
507        struct RestVisitor;
508
509        impl<'de> Visitor<'de> for RestVisitor {
510            type Value = Rest;
511
512            fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
513                fmt.write_str("an empty map")
514            }
515
516            fn visit_map<A: MapAccess<'de>>(self, _map: A) -> Result<Self::Value, A::Error> {
517                Ok(Rest)
518            }
519        }
520
521        deserializer.deserialize_map(RestVisitor)
522    }
523}
524
525impl JsonSchema for Rest {
526    fn schema_name() -> Cow<'static, str> {
527        Cow::Borrowed("Rest")
528    }
529
530    fn schema_id() -> Cow<'static, str> {
531        Cow::Borrowed(concat!(module_path!(), "::", "Rest"))
532    }
533
534    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
535        json_schema!({
536            "type": "object",
537            "properties": {},
538            "additionalProperties": false,
539        })
540    }
541}
542
#[cfg(test)]
#[expect(clippy::expect_used)]
mod tests {
    use ndarray::array;

    use super::*;

    /// A small three-dimensional array of shape `[2, 2, 2]`.
    fn cube() -> Array<i32, IxDyn> {
        array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn()
    }

    /// A large array of shape `[3, 1440, 721, 1, 1]` with increasing values.
    fn counting() -> Array<i32, IxDyn> {
        let mut next = 0;

        Array::from_shape_fn([3, 1440, 721, 1, 1], |_| {
            next += 1;
            next
        })
        .into_dyn()
    }

    /// A group that consists of just the one axis with the `index`.
    fn single(index: usize) -> AxisGroup {
        AxisGroup::Group(vec![Axis::Index(index)])
    }

    /// A group that consists of just the merged remaining axes.
    fn merged_rest() -> AxisGroup {
        AxisGroup::Group(vec![Axis::MergedRest(Rest)])
    }

    /// All remaining axes, each as its own group.
    fn all_rest() -> AxisGroup {
        AxisGroup::AllRest(Rest)
    }

    #[test]
    fn identity() {
        roundtrip(cube(), &[all_rest()], &[2, 2, 2]);

        let permutations = [
            // the rest marker expands into separate axes
            [single(0), single(1), all_rest()],
            [single(0), all_rest(), single(2)],
            [all_rest(), single(1), single(2)],
            // the rest marker merges a single remaining axis
            [single(0), single(1), merged_rest()],
            [single(0), merged_rest(), single(2)],
            [merged_rest(), single(1), single(2)],
        ];

        for axes in &permutations {
            roundtrip(cube(), axes, &[2, 2, 2]);
        }
    }

    #[test]
    fn swizzle() {
        let permutations = [
            // the rest marker expands into separate axes
            [single(0), single(1), all_rest()],
            [single(2), all_rest(), single(1)],
            [all_rest(), single(0), single(1)],
            // the rest marker merges a single remaining axis
            [single(0), single(1), merged_rest()],
            [single(2), merged_rest(), single(1)],
            [merged_rest(), single(0), single(1)],
        ];

        for axes in &permutations {
            roundtrip(cube(), axes, &[2, 2, 2]);
        }

        roundtrip(
            counting(),
            &[single(4), single(0), single(3), single(2), single(1)],
            &[1, 3, 1, 721, 1440],
        );
    }

    #[test]
    fn collapse() {
        // all axes are collapsed into a single one
        let flattenings = [
            vec![Axis::MergedRest(Rest)],
            vec![Axis::Index(0), Axis::Index(1), Axis::Index(2)],
            vec![Axis::Index(2), Axis::Index(1), Axis::Index(0)],
            vec![Axis::Index(1), Axis::MergedRest(Rest)],
        ];

        for members in flattenings {
            roundtrip(cube(), &[AxisGroup::Group(members)], &[8]);
        }

        // only some axes are collapsed
        roundtrip(
            cube(),
            &[
                AxisGroup::Group(vec![Axis::Index(0), Axis::Index(1)]),
                all_rest(),
            ],
            &[4, 2],
        );
        roundtrip(
            cube(),
            &[
                single(2),
                AxisGroup::Group(vec![Axis::Index(1), Axis::Index(0)]),
            ],
            &[2, 4],
        );
        roundtrip(
            cube(),
            &[
                AxisGroup::Group(vec![Axis::Index(1), Axis::MergedRest(Rest)]),
                AxisGroup::Group(vec![Axis::Index(0), Axis::Index(2)]),
            ],
            &[2, 4],
        );

        roundtrip(
            array![[1, 2], [3, 4], [5, 6], [7, 8]].into_dyn(),
            &[merged_rest()],
            &[8],
        );

        roundtrip(
            counting(),
            &[AxisGroup::Group(vec![
                Axis::Index(4),
                Axis::Index(0),
                Axis::Index(3),
                Axis::Index(2),
                Axis::Index(1),
            ])],
            &[3 * 721 * 1440],
        );
    }

    #[test]
    fn extend() {
        // every empty group adds an axis of length one
        let empty = || AxisGroup::Group(Vec::new());

        roundtrip(
            cube(),
            &[
                empty(),
                single(0),
                empty(),
                all_rest(),
                empty(),
                single(2),
                empty(),
            ],
            &[1, 2, 1, 2, 1, 2, 1],
        );
    }

    /// Checks that swizzling the `data` with the `axes` produces an array of
    /// the `swizzle_shape` and that the swizzle can be reverted again.
    #[expect(clippy::needless_pass_by_value)]
    fn roundtrip(data: Array<i32, IxDyn>, axes: &[AxisGroup], swizzle_shape: &[usize]) {
        let swizzled = swizzle_reshape(data.view(), axes).expect("swizzle should not fail");
        assert_eq!(swizzled.shape(), swizzle_shape);

        // decoding into an array of the original shape always works
        let mut unswizzled = Array::zeros(data.shape());
        undo_swizzle_reshape_into(swizzled.view(), unswizzled.view_mut(), axes)
            .expect("unswizzle into should not fail");
        assert_eq!(data, unswizzled);

        // decoding without an output array only works for pure permutations
        let needs_original_shape = axes.iter().any(|group| match group {
            AxisGroup::Group(members) => {
                members.len() != 1
                    || members
                        .iter()
                        .any(|member| matches!(member, Axis::MergedRest(Rest)))
            }
            AxisGroup::AllRest(Rest) => false,
        });

        if needs_original_shape {
            undo_swizzle_reshape(swizzled.view(), axes).expect_err("unswizzle should fail");
        } else {
            let unswizzled =
                undo_swizzle_reshape(swizzled.view(), axes).expect("unswizzle should not fail");
            assert_eq!(data, unswizzled);
        }
    }
}
809        }
810    }
811}