numcodecs_python/
adapter.rs

1use std::sync::Arc;
2
3use ndarray::{ArrayBase, DataMut, Dimension};
4use numcodecs::{
5    AnyArray, AnyArrayBase, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, DynCodec,
6    DynCodecType,
7};
8use numpy::{Element, PyArray, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods};
9use pyo3::{
10    exceptions::{PyTypeError, PyValueError},
11    intern,
12    prelude::*,
13    types::{IntoPyDict, PyDict, PyDictMethods},
14};
15use pythonize::{Depythonizer, Pythonizer};
16use schemars::Schema;
17use serde::{Deserializer, Serializer};
18use serde_transcode::transcode;
19
20use crate::{
21    export::{RustCodec, RustCodecType},
22    schema::schema_from_codec_class,
23    utils::numpy_asarray,
24    PyCodec, PyCodecClass, PyCodecClassMethods, PyCodecMethods, PyCodecRegistry,
25};
26
27/// Wrapper around [`PyCodec`]s to use the [`Codec`] API.
28pub struct PyCodecAdapter {
29    codec: Py<PyCodec>,
30    class: Py<PyCodecClass>,
31    codec_id: Arc<String>,
32    codec_config_schema: Arc<Schema>,
33}
34
35impl PyCodecAdapter {
36    /// Instantiate a codec from the [`PyCodecRegistry`] with a serialized
37    /// `config`uration.
38    ///
39    /// The config *must* include the `id` field with the
40    /// [`PyCodecClassMethods::codec_id`].
41    ///
42    /// # Errors
43    ///
44    /// Errors if no codec with a matching `id` has been registered, or if
45    /// constructing the codec fails.
46    pub fn from_registry_with_config<'de, D: Deserializer<'de>>(
47        config: D,
48    ) -> Result<Self, D::Error> {
49        Python::with_gil(|py| {
50            let config = transcode(config, Pythonizer::new(py))?;
51            let config: Bound<PyDict> = config.extract()?;
52
53            let codec = PyCodecRegistry::get_codec(config.as_borrowed())?;
54
55            Self::from_codec(codec)
56        })
57        .map_err(serde::de::Error::custom)
58    }
59
60    /// Wraps a [`PyCodec`] to use the [`Codec`] API.
61    ///
62    /// # Errors
63    ///
64    /// Errors if the `codec`'s class does not provide an identifier.
65    pub fn from_codec(codec: Bound<PyCodec>) -> Result<Self, PyErr> {
66        let class = codec.class();
67        let codec_id = class.codec_id()?;
68        let codec_config_schema = schema_from_codec_class(class.py(), &class).map_err(|err| {
69            PyTypeError::new_err(format!(
70                "failed to extract the {codec_id} codec config schema: {err}"
71            ))
72        })?;
73
74        Ok(Self {
75            codec: codec.unbind(),
76            class: class.unbind(),
77            codec_id: Arc::new(codec_id),
78            codec_config_schema: Arc::new(codec_config_schema),
79        })
80    }
81
82    /// Access the wrapped [`PyCodec`] to use its [`PyCodecMethods`] API.
83    #[must_use]
84    pub fn as_codec<'py>(&self, py: Python<'py>) -> &Bound<'py, PyCodec> {
85        self.codec.bind(py)
86    }
87
88    /// Unwrap the [`PyCodec`] to use its [`PyCodecMethods`] API.
89    #[must_use]
90    pub fn into_codec(self, py: Python) -> Bound<PyCodec> {
91        self.codec.into_bound(py)
92    }
93
94    /// Try to [`clone`][`Clone::clone`] this codec.
95    ///
96    /// # Errors
97    ///
98    /// Errors if extracting this codec's config or creating a new codec from
99    /// the config fails.
100    pub fn try_clone(&self, py: Python) -> Result<Self, PyErr> {
101        let config = self.codec.bind(py).get_config()?;
102
103        // removing the `id` field may fail if the config doesn't contain it
104        let _ = config.del_item(intern!(py, "id"));
105
106        let codec = self
107            .class
108            .bind(py)
109            .codec_from_config(config.as_borrowed())?;
110
111        Ok(Self {
112            codec: codec.unbind(),
113            class: self.class.clone_ref(py),
114            codec_id: self.codec_id.clone(),
115            codec_config_schema: self.codec_config_schema.clone(),
116        })
117    }
118
119    /// If `codec` represents an exported [`DynCodec`] `T`, i.e. its class was
120    /// initially created with [`crate::export_codec_class`], the `with` closure
121    /// provides access to the instance of type `T`.
122    ///
123    /// If `codec` is not an instance of `T`, the `with` closure is *not* run
124    /// and `None` is returned.
125    pub fn with_downcast<T: DynCodec, O>(
126        codec: &Bound<PyCodec>,
127        with: impl for<'a> FnOnce(&'a T) -> O,
128    ) -> Option<O> {
129        let Ok(codec) = codec.downcast::<RustCodec>() else {
130            return None;
131        };
132
133        let codec = codec.get().downcast()?;
134
135        Some(with(codec))
136    }
137}
138
139impl Codec for PyCodecAdapter {
140    type Error = PyErr;
141
142    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
143        Python::with_gil(|py| {
144            self.with_any_array_view_as_ndarray(py, &data.view(), |data| {
145                let encoded = self.codec.bind(py).encode(data.as_borrowed())?;
146
147                Self::any_array_from_ndarray_like(py, encoded.as_borrowed())
148            })
149        })
150    }
151
152    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
153        Python::with_gil(|py| {
154            self.with_any_array_view_as_ndarray(py, &encoded.view(), |encoded| {
155                let decoded = self.codec.bind(py).decode(encoded.as_borrowed(), None)?;
156
157                Self::any_array_from_ndarray_like(py, decoded.as_borrowed())
158            })
159        })
160    }
161
162    fn decode_into(
163        &self,
164        encoded: AnyArrayView,
165        mut decoded: AnyArrayViewMut,
166    ) -> Result<(), Self::Error> {
167        Python::with_gil(|py| {
168            let decoded_out = self.with_any_array_view_as_ndarray(py, &encoded, |encoded| {
169                self.with_any_array_view_mut_as_ndarray(py, &mut decoded, |decoded_in| {
170                    let decoded_out = self
171                        .codec
172                        .bind(py)
173                        .decode(encoded.as_borrowed(), Some(decoded_in.as_borrowed()))?;
174
175                    // Ideally, all codecs should just use the provided out array
176                    if decoded_out.is(decoded_in) {
177                        Ok(Ok(()))
178                    } else {
179                        Ok(Err(decoded_out.unbind()))
180                    }
181                })
182            })?;
183            let decoded_out = match decoded_out {
184                Ok(()) => return Ok(()),
185                Err(decoded_out) => decoded_out.into_bound(py),
186            };
187
188            // Otherwise, we force-copy the output into the decoded array
189            Self::copy_into_any_array_view_mut_from_ndarray_like(
190                py,
191                &mut decoded,
192                decoded_out.as_borrowed(),
193            )
194        })
195    }
196}
197
198impl PyCodecAdapter {
199    fn with_any_array_view_as_ndarray<T>(
200        &self,
201        py: Python,
202        view: &AnyArrayView,
203        with: impl for<'a> FnOnce(&'a Bound<PyAny>) -> Result<T, PyErr>,
204    ) -> Result<T, PyErr> {
205        let this = self.codec.bind(py).clone().into_any();
206
207        #[expect(unsafe_code)] // FIXME: we trust Python code to not store this array
208        let ndarray = unsafe {
209            match &view {
210                AnyArrayBase::U8(v) => PyArray::borrow_from_array(v, this).into_any(),
211                AnyArrayBase::U16(v) => PyArray::borrow_from_array(v, this).into_any(),
212                AnyArrayBase::U32(v) => PyArray::borrow_from_array(v, this).into_any(),
213                AnyArrayBase::U64(v) => PyArray::borrow_from_array(v, this).into_any(),
214                AnyArrayBase::I8(v) => PyArray::borrow_from_array(v, this).into_any(),
215                AnyArrayBase::I16(v) => PyArray::borrow_from_array(v, this).into_any(),
216                AnyArrayBase::I32(v) => PyArray::borrow_from_array(v, this).into_any(),
217                AnyArrayBase::I64(v) => PyArray::borrow_from_array(v, this).into_any(),
218                AnyArrayBase::F32(v) => PyArray::borrow_from_array(v, this).into_any(),
219                AnyArrayBase::F64(v) => PyArray::borrow_from_array(v, this).into_any(),
220                _ => {
221                    return Err(PyTypeError::new_err(format!(
222                        "unsupported type {} of read-only array view",
223                        view.dtype()
224                    )))
225                }
226            }
227        };
228
229        // create a fully-immutable view of the data that is safe to pass to Python
230        ndarray.call_method(
231            intern!(py, "setflags"),
232            (),
233            Some(&[(intern!(py, "write"), false)].into_py_dict(py)?),
234        )?;
235        let view = ndarray.call_method0(intern!(py, "view"))?;
236
237        with(&view)
238    }
239
240    fn with_any_array_view_mut_as_ndarray<T>(
241        &self,
242        py: Python,
243        view_mut: &mut AnyArrayViewMut,
244        with: impl for<'a> FnOnce(&'a Bound<PyAny>) -> Result<T, PyErr>,
245    ) -> Result<T, PyErr> {
246        let this = self.codec.bind(py).clone().into_any();
247
248        #[expect(unsafe_code)] // FIXME: we trust Python code to not store this array
249        let ndarray = unsafe {
250            match &view_mut {
251                AnyArrayBase::U8(v) => PyArray::borrow_from_array(v, this).into_any(),
252                AnyArrayBase::U16(v) => PyArray::borrow_from_array(v, this).into_any(),
253                AnyArrayBase::U32(v) => PyArray::borrow_from_array(v, this).into_any(),
254                AnyArrayBase::U64(v) => PyArray::borrow_from_array(v, this).into_any(),
255                AnyArrayBase::I8(v) => PyArray::borrow_from_array(v, this).into_any(),
256                AnyArrayBase::I16(v) => PyArray::borrow_from_array(v, this).into_any(),
257                AnyArrayBase::I32(v) => PyArray::borrow_from_array(v, this).into_any(),
258                AnyArrayBase::I64(v) => PyArray::borrow_from_array(v, this).into_any(),
259                AnyArrayBase::F32(v) => PyArray::borrow_from_array(v, this).into_any(),
260                AnyArrayBase::F64(v) => PyArray::borrow_from_array(v, this).into_any(),
261                _ => {
262                    return Err(PyTypeError::new_err(format!(
263                        "unsupported type {} of read-only array view",
264                        view_mut.dtype()
265                    )))
266                }
267            }
268        };
269
270        with(&ndarray)
271    }
272
273    fn any_array_from_ndarray_like(
274        py: Python,
275        array_like: Borrowed<PyAny>,
276    ) -> Result<AnyArray, PyErr> {
277        let ndarray = numpy_asarray(py, array_like)?;
278
279        let array = if let Ok(e) = ndarray.downcast::<PyArrayDyn<u8>>() {
280            AnyArrayBase::U8(e.try_readonly()?.to_owned_array())
281        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<u16>>() {
282            AnyArrayBase::U16(e.try_readonly()?.to_owned_array())
283        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<u32>>() {
284            AnyArrayBase::U32(e.try_readonly()?.to_owned_array())
285        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<u64>>() {
286            AnyArrayBase::U64(e.try_readonly()?.to_owned_array())
287        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i8>>() {
288            AnyArrayBase::I8(e.try_readonly()?.to_owned_array())
289        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i16>>() {
290            AnyArrayBase::I16(e.try_readonly()?.to_owned_array())
291        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i32>>() {
292            AnyArrayBase::I32(e.try_readonly()?.to_owned_array())
293        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i64>>() {
294            AnyArrayBase::I64(e.try_readonly()?.to_owned_array())
295        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<f32>>() {
296            AnyArrayBase::F32(e.try_readonly()?.to_owned_array())
297        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<f64>>() {
298            AnyArrayBase::F64(e.try_readonly()?.to_owned_array())
299        } else {
300            return Err(PyTypeError::new_err(format!(
301                "unsupported dtype {} of array-like",
302                ndarray.dtype()
303            )));
304        };
305
306        Ok(array)
307    }
308
309    fn copy_into_any_array_view_mut_from_ndarray_like(
310        py: Python,
311        view_mut: &mut AnyArrayViewMut,
312        array_like: Borrowed<PyAny>,
313    ) -> Result<(), PyErr> {
314        fn shape_checked_assign<
315            T: Copy + Element,
316            S2: DataMut<Elem = T>,
317            D1: Dimension,
318            D2: Dimension,
319        >(
320            src: &Bound<PyArray<T, D1>>,
321            dst: &mut ArrayBase<S2, D2>,
322        ) -> Result<(), PyErr> {
323            #[expect(clippy::unit_arg)]
324            if src.shape() == dst.shape() {
325                Ok(dst.assign(&src.try_readonly()?.as_array()))
326            } else {
327                Err(PyValueError::new_err(format!(
328                    "mismatching shape {:?} of array-like, expected {:?}",
329                    src.shape(),
330                    dst.shape(),
331                )))
332            }
333        }
334
335        let ndarray = numpy_asarray(py, array_like)?;
336
337        if let Ok(d) = ndarray.downcast::<PyArrayDyn<u8>>() {
338            if let AnyArrayBase::U8(ref mut view_mut) = view_mut {
339                return shape_checked_assign(d, view_mut);
340            }
341        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<u16>>() {
342            if let AnyArrayBase::U16(ref mut view_mut) = view_mut {
343                return shape_checked_assign(d, view_mut);
344            }
345        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<u32>>() {
346            if let AnyArrayBase::U32(ref mut view_mut) = view_mut {
347                return shape_checked_assign(d, view_mut);
348            }
349        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<u64>>() {
350            if let AnyArrayBase::U64(ref mut view_mut) = view_mut {
351                return shape_checked_assign(d, view_mut);
352            }
353        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i8>>() {
354            if let AnyArrayBase::I8(ref mut view_mut) = view_mut {
355                return shape_checked_assign(d, view_mut);
356            }
357        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i16>>() {
358            if let AnyArrayBase::I16(ref mut view_mut) = view_mut {
359                return shape_checked_assign(d, view_mut);
360            }
361        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i32>>() {
362            if let AnyArrayBase::I32(ref mut view_mut) = view_mut {
363                return shape_checked_assign(d, view_mut);
364            }
365        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i64>>() {
366            if let AnyArrayBase::I64(ref mut view_mut) = view_mut {
367                return shape_checked_assign(d, view_mut);
368            }
369        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<f32>>() {
370            if let AnyArrayBase::F32(ref mut view_mut) = view_mut {
371                return shape_checked_assign(d, view_mut);
372            }
373        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<f64>>() {
374            if let AnyArrayBase::F64(ref mut view_mut) = view_mut {
375                return shape_checked_assign(d, view_mut);
376            }
377        } else {
378            return Err(PyTypeError::new_err(format!(
379                "unsupported dtype {} of array-like",
380                ndarray.dtype()
381            )));
382        };
383
384        Err(PyTypeError::new_err(format!(
385            "mismatching dtype {} of array-like, expected {}",
386            ndarray.dtype(),
387            view_mut.dtype(),
388        )))
389    }
390}
391
392impl Clone for PyCodecAdapter {
393    fn clone(&self) -> Self {
394        #[expect(clippy::expect_used)] // clone is *not* fallible
395        Python::with_gil(|py| {
396            self.try_clone(py)
397                .expect("cloning a PyCodec should not fail")
398        })
399    }
400}
401
402impl DynCodec for PyCodecAdapter {
403    type Type = PyCodecClassAdapter;
404
405    fn ty(&self) -> Self::Type {
406        Python::with_gil(|py| PyCodecClassAdapter {
407            class: self.class.clone_ref(py),
408            codec_id: self.codec_id.clone(),
409            codec_config_schema: self.codec_config_schema.clone(),
410        })
411    }
412
413    fn get_config<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
414        Python::with_gil(|py| {
415            let config = self
416                .codec
417                .bind(py)
418                .get_config()
419                .map_err(serde::ser::Error::custom)?;
420
421            transcode(&mut Depythonizer::from_object(config.as_any()), serializer)
422        })
423    }
424}
425
426/// Wrapper around [`PyCodecClass`]es to use the [`DynCodecType`] API.
427pub struct PyCodecClassAdapter {
428    class: Py<PyCodecClass>,
429    codec_id: Arc<String>,
430    codec_config_schema: Arc<Schema>,
431}
432
433impl PyCodecClassAdapter {
434    /// Wraps a [`PyCodecClass`] to use the [`DynCodecType`] API.
435    ///
436    /// # Errors
437    ///
438    /// Errors if the codec `class` does not provide an identifier.
439    pub fn from_codec_class(class: Bound<PyCodecClass>) -> Result<Self, PyErr> {
440        let codec_id = class.codec_id()?;
441
442        let codec_config_schema = schema_from_codec_class(class.py(), &class).map_err(|err| {
443            PyTypeError::new_err(format!(
444                "failed to extract the {codec_id} codec config schema: {err}"
445            ))
446        })?;
447
448        Ok(Self {
449            class: class.unbind(),
450            codec_id: Arc::new(codec_id),
451            codec_config_schema: Arc::new(codec_config_schema),
452        })
453    }
454
455    /// Access the wrapped [`PyCodecClass`] to use its [`PyCodecClassMethods`]
456    /// API.
457    #[must_use]
458    pub fn as_codec_class<'py>(&self, py: Python<'py>) -> &Bound<'py, PyCodecClass> {
459        self.class.bind(py)
460    }
461
462    /// Unwrap the [`PyCodecClass`] to use its [`PyCodecClassMethods`] API.
463    #[must_use]
464    pub fn into_codec_class(self, py: Python) -> Bound<PyCodecClass> {
465        self.class.into_bound(py)
466    }
467
468    /// If `class` represents an exported [`DynCodecType`] `T`, i.e. it was
469    /// initially created with [`crate::export_codec_class`], the `with` closure
470    /// provides access to the instance of type `T`.
471    ///
472    /// If `class` is not an instance of `T`, the `with` closure is *not* run
473    /// and `None` is returned.
474    pub fn with_downcast<T: DynCodecType, O>(
475        class: &Bound<PyCodecClass>,
476        with: impl for<'a> FnOnce(&'a T) -> O,
477    ) -> Option<O> {
478        let Ok(ty) = class.getattr(intern!(class.py(), RustCodec::TYPE_ATTRIBUTE)) else {
479            return None;
480        };
481
482        let Ok(ty) = ty.downcast_into_exact::<RustCodecType>() else {
483            return None;
484        };
485
486        let ty: &T = ty.get().downcast()?;
487
488        Some(with(ty))
489    }
490}
491
492impl DynCodecType for PyCodecClassAdapter {
493    type Codec = PyCodecAdapter;
494
495    fn codec_id(&self) -> &str {
496        &self.codec_id
497    }
498
499    fn codec_config_schema(&self) -> Schema {
500        (*self.codec_config_schema).clone()
501    }
502
503    fn codec_from_config<'de, D: Deserializer<'de>>(
504        &self,
505        config: D,
506    ) -> Result<Self::Codec, D::Error> {
507        Python::with_gil(|py| {
508            let config =
509                transcode(config, Pythonizer::new(py)).map_err(serde::de::Error::custom)?;
510            let config: Bound<PyDict> = config.extract().map_err(serde::de::Error::custom)?;
511
512            let codec = self
513                .class
514                .bind(py)
515                .codec_from_config(config.as_borrowed())
516                .map_err(serde::de::Error::custom)?;
517
518            Ok(PyCodecAdapter {
519                codec: codec.unbind(),
520                class: self.class.clone_ref(py),
521                codec_id: self.codec_id.clone(),
522                codec_config_schema: self.codec_config_schema.clone(),
523            })
524        })
525    }
526}