numcodecs_python/
adapter.rs

1use std::sync::Arc;
2
3use ndarray::{ArrayBase, DataMut, Dimension};
4use numcodecs::{
5    AnyArray, AnyArrayBase, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, DynCodec,
6    DynCodecType,
7};
8use numpy::{Element, PyArray, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods};
9use pyo3::{
10    exceptions::{PyTypeError, PyValueError},
11    intern,
12    marker::Ungil,
13    prelude::*,
14    types::{IntoPyDict, PyDict, PyDictMethods},
15};
16use pythonize::{Depythonizer, Pythonizer};
17use schemars::Schema;
18use serde::{Deserializer, Serializer};
19use serde_transcode::transcode;
20
21use crate::{
22    export::{RustCodec, RustCodecType},
23    schema::schema_from_codec_class,
24    utils::numpy_asarray,
25    PyCodec, PyCodecClass, PyCodecClassMethods, PyCodecMethods, PyCodecRegistry,
26};
27
28/// Wrapper around [`PyCodec`]s to use the [`Codec`] API.
29pub struct PyCodecAdapter {
30    codec: Py<PyCodec>,
31    class: Py<PyCodecClass>,
32    codec_id: Arc<String>,
33    codec_config_schema: Arc<Schema>,
34}
35
36impl PyCodecAdapter {
37    /// Instantiate a codec from the [`PyCodecRegistry`] with a serialized
38    /// `config`uration.
39    ///
40    /// The config *must* include the `id` field with the
41    /// [`PyCodecClassMethods::codec_id`].
42    ///
43    /// # Errors
44    ///
45    /// Errors if no codec with a matching `id` has been registered, or if
46    /// constructing the codec fails.
47    pub fn from_registry_with_config<'de, D: Deserializer<'de>>(
48        config: D,
49    ) -> Result<Self, D::Error> {
50        Python::with_gil(|py| {
51            let config = transcode(config, Pythonizer::new(py))?;
52            let config: Bound<PyDict> = config.extract()?;
53
54            let codec = PyCodecRegistry::get_codec(config.as_borrowed())?;
55
56            Self::from_codec(codec)
57        })
58        .map_err(serde::de::Error::custom)
59    }
60
61    /// Wraps a [`PyCodec`] to use the [`Codec`] API.
62    ///
63    /// # Errors
64    ///
65    /// Errors if the `codec`'s class does not provide an identifier.
66    pub fn from_codec(codec: Bound<PyCodec>) -> Result<Self, PyErr> {
67        let class = codec.class();
68        let codec_id = class.codec_id()?;
69        let codec_config_schema = schema_from_codec_class(class.py(), &class).map_err(|err| {
70            PyTypeError::new_err(format!(
71                "failed to extract the {codec_id} codec config schema: {err}"
72            ))
73        })?;
74
75        Ok(Self {
76            codec: codec.unbind(),
77            class: class.unbind(),
78            codec_id: Arc::new(codec_id),
79            codec_config_schema: Arc::new(codec_config_schema),
80        })
81    }
82
83    /// Access the wrapped [`PyCodec`] to use its [`PyCodecMethods`] API.
84    #[must_use]
85    pub fn as_codec<'py>(&self, py: Python<'py>) -> &Bound<'py, PyCodec> {
86        self.codec.bind(py)
87    }
88
89    /// Unwrap the [`PyCodec`] to use its [`PyCodecMethods`] API.
90    #[must_use]
91    pub fn into_codec(self, py: Python) -> Bound<PyCodec> {
92        self.codec.into_bound(py)
93    }
94
95    /// Try to [`clone`][`Clone::clone`] this codec.
96    ///
97    /// # Errors
98    ///
99    /// Errors if extracting this codec's config or creating a new codec from
100    /// the config fails.
101    pub fn try_clone(&self, py: Python) -> Result<Self, PyErr> {
102        let config = self.codec.bind(py).get_config()?;
103
104        // removing the `id` field may fail if the config doesn't contain it
105        let _ = config.del_item(intern!(py, "id"));
106
107        let codec = self
108            .class
109            .bind(py)
110            .codec_from_config(config.as_borrowed())?;
111
112        Ok(Self {
113            codec: codec.unbind(),
114            class: self.class.clone_ref(py),
115            codec_id: self.codec_id.clone(),
116            codec_config_schema: self.codec_config_schema.clone(),
117        })
118    }
119
120    /// If `codec` represents an exported [`DynCodec`] `T`, i.e. its class was
121    /// initially created with [`crate::export_codec_class`], the `with` closure
122    /// provides access to the instance of type `T`.
123    ///
124    /// If `codec` is not an instance of `T`, the `with` closure is *not* run
125    /// and `None` is returned.
126    pub fn with_downcast<T: DynCodec, O: Ungil>(
127        py: Python,
128        codec: &Bound<PyCodec>,
129        with: impl Send + Ungil + for<'a> FnOnce(&'a T) -> O,
130    ) -> Option<O> {
131        let Ok(codec) = codec.downcast::<RustCodec>() else {
132            return None;
133        };
134
135        let codec = codec.get().downcast()?;
136
137        // The `with` closure contains arbitrary Rust code and may block,
138        // which we cannot allow while holding the GIL
139        Some(py.allow_threads(|| with(codec)))
140    }
141}
142
143impl Codec for PyCodecAdapter {
144    type Error = PyErr;
145
146    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
147        Python::with_gil(|py| {
148            self.with_any_array_view_as_ndarray(py, &data.view(), |data| {
149                let encoded = self.codec.bind(py).encode(data.as_borrowed())?;
150
151                Self::any_array_from_ndarray_like(py, encoded.as_borrowed())
152            })
153        })
154    }
155
156    fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
157        Python::with_gil(|py| {
158            self.with_any_array_view_as_ndarray(py, &encoded.view(), |encoded| {
159                let decoded = self.codec.bind(py).decode(encoded.as_borrowed(), None)?;
160
161                Self::any_array_from_ndarray_like(py, decoded.as_borrowed())
162            })
163        })
164    }
165
166    fn decode_into(
167        &self,
168        encoded: AnyArrayView,
169        mut decoded: AnyArrayViewMut,
170    ) -> Result<(), Self::Error> {
171        Python::with_gil(|py| {
172            let decoded_out = self.with_any_array_view_as_ndarray(py, &encoded, |encoded| {
173                self.with_any_array_view_mut_as_ndarray(py, &mut decoded, |decoded_in| {
174                    let decoded_out = self
175                        .codec
176                        .bind(py)
177                        .decode(encoded.as_borrowed(), Some(decoded_in.as_borrowed()))?;
178
179                    // Ideally, all codecs should just use the provided out array
180                    if decoded_out.is(decoded_in) {
181                        Ok(Ok(()))
182                    } else {
183                        Ok(Err(decoded_out.unbind()))
184                    }
185                })
186            })?;
187            let decoded_out = match decoded_out {
188                Ok(()) => return Ok(()),
189                Err(decoded_out) => decoded_out.into_bound(py),
190            };
191
192            // Otherwise, we force-copy the output into the decoded array
193            Self::copy_into_any_array_view_mut_from_ndarray_like(
194                py,
195                &mut decoded,
196                decoded_out.as_borrowed(),
197            )
198        })
199    }
200}
201
202impl PyCodecAdapter {
203    fn with_any_array_view_as_ndarray<T>(
204        &self,
205        py: Python,
206        view: &AnyArrayView,
207        with: impl for<'a> FnOnce(&'a Bound<PyAny>) -> Result<T, PyErr>,
208    ) -> Result<T, PyErr> {
209        let this = self.codec.bind(py).clone().into_any();
210
211        #[expect(unsafe_code)] // FIXME: we trust Python code to not store this array
212        let ndarray = unsafe {
213            match &view {
214                AnyArrayBase::U8(v) => PyArray::borrow_from_array(v, this).into_any(),
215                AnyArrayBase::U16(v) => PyArray::borrow_from_array(v, this).into_any(),
216                AnyArrayBase::U32(v) => PyArray::borrow_from_array(v, this).into_any(),
217                AnyArrayBase::U64(v) => PyArray::borrow_from_array(v, this).into_any(),
218                AnyArrayBase::I8(v) => PyArray::borrow_from_array(v, this).into_any(),
219                AnyArrayBase::I16(v) => PyArray::borrow_from_array(v, this).into_any(),
220                AnyArrayBase::I32(v) => PyArray::borrow_from_array(v, this).into_any(),
221                AnyArrayBase::I64(v) => PyArray::borrow_from_array(v, this).into_any(),
222                AnyArrayBase::F32(v) => PyArray::borrow_from_array(v, this).into_any(),
223                AnyArrayBase::F64(v) => PyArray::borrow_from_array(v, this).into_any(),
224                _ => {
225                    return Err(PyTypeError::new_err(format!(
226                        "unsupported type {} of read-only array view",
227                        view.dtype()
228                    )))
229                }
230            }
231        };
232
233        // create a fully-immutable view of the data that is safe to pass to Python
234        ndarray.call_method(
235            intern!(py, "setflags"),
236            (),
237            Some(&[(intern!(py, "write"), false)].into_py_dict(py)?),
238        )?;
239        let view = ndarray.call_method0(intern!(py, "view"))?;
240
241        with(&view)
242    }
243
244    fn with_any_array_view_mut_as_ndarray<T>(
245        &self,
246        py: Python,
247        view_mut: &mut AnyArrayViewMut,
248        with: impl for<'a> FnOnce(&'a Bound<PyAny>) -> Result<T, PyErr>,
249    ) -> Result<T, PyErr> {
250        let this = self.codec.bind(py).clone().into_any();
251
252        #[expect(unsafe_code)] // FIXME: we trust Python code to not store this array
253        let ndarray = unsafe {
254            match &view_mut {
255                AnyArrayBase::U8(v) => PyArray::borrow_from_array(v, this).into_any(),
256                AnyArrayBase::U16(v) => PyArray::borrow_from_array(v, this).into_any(),
257                AnyArrayBase::U32(v) => PyArray::borrow_from_array(v, this).into_any(),
258                AnyArrayBase::U64(v) => PyArray::borrow_from_array(v, this).into_any(),
259                AnyArrayBase::I8(v) => PyArray::borrow_from_array(v, this).into_any(),
260                AnyArrayBase::I16(v) => PyArray::borrow_from_array(v, this).into_any(),
261                AnyArrayBase::I32(v) => PyArray::borrow_from_array(v, this).into_any(),
262                AnyArrayBase::I64(v) => PyArray::borrow_from_array(v, this).into_any(),
263                AnyArrayBase::F32(v) => PyArray::borrow_from_array(v, this).into_any(),
264                AnyArrayBase::F64(v) => PyArray::borrow_from_array(v, this).into_any(),
265                _ => {
266                    return Err(PyTypeError::new_err(format!(
267                        "unsupported type {} of read-only array view",
268                        view_mut.dtype()
269                    )))
270                }
271            }
272        };
273
274        with(&ndarray)
275    }
276
277    fn any_array_from_ndarray_like(
278        py: Python,
279        array_like: Borrowed<PyAny>,
280    ) -> Result<AnyArray, PyErr> {
281        let ndarray = numpy_asarray(py, array_like)?;
282
283        let array = if let Ok(e) = ndarray.downcast::<PyArrayDyn<u8>>() {
284            AnyArrayBase::U8(e.try_readonly()?.to_owned_array())
285        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<u16>>() {
286            AnyArrayBase::U16(e.try_readonly()?.to_owned_array())
287        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<u32>>() {
288            AnyArrayBase::U32(e.try_readonly()?.to_owned_array())
289        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<u64>>() {
290            AnyArrayBase::U64(e.try_readonly()?.to_owned_array())
291        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i8>>() {
292            AnyArrayBase::I8(e.try_readonly()?.to_owned_array())
293        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i16>>() {
294            AnyArrayBase::I16(e.try_readonly()?.to_owned_array())
295        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i32>>() {
296            AnyArrayBase::I32(e.try_readonly()?.to_owned_array())
297        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i64>>() {
298            AnyArrayBase::I64(e.try_readonly()?.to_owned_array())
299        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<f32>>() {
300            AnyArrayBase::F32(e.try_readonly()?.to_owned_array())
301        } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<f64>>() {
302            AnyArrayBase::F64(e.try_readonly()?.to_owned_array())
303        } else {
304            return Err(PyTypeError::new_err(format!(
305                "unsupported dtype {} of array-like",
306                ndarray.dtype()
307            )));
308        };
309
310        Ok(array)
311    }
312
313    fn copy_into_any_array_view_mut_from_ndarray_like(
314        py: Python,
315        view_mut: &mut AnyArrayViewMut,
316        array_like: Borrowed<PyAny>,
317    ) -> Result<(), PyErr> {
318        fn shape_checked_assign<
319            T: Copy + Element,
320            S2: DataMut<Elem = T>,
321            D1: Dimension,
322            D2: Dimension,
323        >(
324            src: &Bound<PyArray<T, D1>>,
325            dst: &mut ArrayBase<S2, D2>,
326        ) -> Result<(), PyErr> {
327            #[expect(clippy::unit_arg)]
328            if src.shape() == dst.shape() {
329                Ok(dst.assign(&src.try_readonly()?.as_array()))
330            } else {
331                Err(PyValueError::new_err(format!(
332                    "mismatching shape {:?} of array-like, expected {:?}",
333                    src.shape(),
334                    dst.shape(),
335                )))
336            }
337        }
338
339        let ndarray = numpy_asarray(py, array_like)?;
340
341        if let Ok(d) = ndarray.downcast::<PyArrayDyn<u8>>() {
342            if let AnyArrayBase::U8(ref mut view_mut) = view_mut {
343                return shape_checked_assign(d, view_mut);
344            }
345        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<u16>>() {
346            if let AnyArrayBase::U16(ref mut view_mut) = view_mut {
347                return shape_checked_assign(d, view_mut);
348            }
349        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<u32>>() {
350            if let AnyArrayBase::U32(ref mut view_mut) = view_mut {
351                return shape_checked_assign(d, view_mut);
352            }
353        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<u64>>() {
354            if let AnyArrayBase::U64(ref mut view_mut) = view_mut {
355                return shape_checked_assign(d, view_mut);
356            }
357        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i8>>() {
358            if let AnyArrayBase::I8(ref mut view_mut) = view_mut {
359                return shape_checked_assign(d, view_mut);
360            }
361        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i16>>() {
362            if let AnyArrayBase::I16(ref mut view_mut) = view_mut {
363                return shape_checked_assign(d, view_mut);
364            }
365        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i32>>() {
366            if let AnyArrayBase::I32(ref mut view_mut) = view_mut {
367                return shape_checked_assign(d, view_mut);
368            }
369        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i64>>() {
370            if let AnyArrayBase::I64(ref mut view_mut) = view_mut {
371                return shape_checked_assign(d, view_mut);
372            }
373        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<f32>>() {
374            if let AnyArrayBase::F32(ref mut view_mut) = view_mut {
375                return shape_checked_assign(d, view_mut);
376            }
377        } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<f64>>() {
378            if let AnyArrayBase::F64(ref mut view_mut) = view_mut {
379                return shape_checked_assign(d, view_mut);
380            }
381        } else {
382            return Err(PyTypeError::new_err(format!(
383                "unsupported dtype {} of array-like",
384                ndarray.dtype()
385            )));
386        }
387
388        Err(PyTypeError::new_err(format!(
389            "mismatching dtype {} of array-like, expected {}",
390            ndarray.dtype(),
391            view_mut.dtype(),
392        )))
393    }
394}
395
396impl Clone for PyCodecAdapter {
397    fn clone(&self) -> Self {
398        #[expect(clippy::expect_used)] // clone is *not* fallible
399        Python::with_gil(|py| {
400            self.try_clone(py)
401                .expect("cloning a PyCodec should not fail")
402        })
403    }
404}
405
406impl DynCodec for PyCodecAdapter {
407    type Type = PyCodecClassAdapter;
408
409    fn ty(&self) -> Self::Type {
410        Python::with_gil(|py| PyCodecClassAdapter {
411            class: self.class.clone_ref(py),
412            codec_id: self.codec_id.clone(),
413            codec_config_schema: self.codec_config_schema.clone(),
414        })
415    }
416
417    fn get_config<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
418        Python::with_gil(|py| {
419            let config = self
420                .codec
421                .bind(py)
422                .get_config()
423                .map_err(serde::ser::Error::custom)?;
424
425            transcode(&mut Depythonizer::from_object(config.as_any()), serializer)
426        })
427    }
428}
429
430/// Wrapper around [`PyCodecClass`]es to use the [`DynCodecType`] API.
431pub struct PyCodecClassAdapter {
432    class: Py<PyCodecClass>,
433    codec_id: Arc<String>,
434    codec_config_schema: Arc<Schema>,
435}
436
437impl PyCodecClassAdapter {
438    /// Wraps a [`PyCodecClass`] to use the [`DynCodecType`] API.
439    ///
440    /// # Errors
441    ///
442    /// Errors if the codec `class` does not provide an identifier.
443    pub fn from_codec_class(class: Bound<PyCodecClass>) -> Result<Self, PyErr> {
444        let codec_id = class.codec_id()?;
445
446        let codec_config_schema = schema_from_codec_class(class.py(), &class).map_err(|err| {
447            PyTypeError::new_err(format!(
448                "failed to extract the {codec_id} codec config schema: {err}"
449            ))
450        })?;
451
452        Ok(Self {
453            class: class.unbind(),
454            codec_id: Arc::new(codec_id),
455            codec_config_schema: Arc::new(codec_config_schema),
456        })
457    }
458
459    /// Access the wrapped [`PyCodecClass`] to use its [`PyCodecClassMethods`]
460    /// API.
461    #[must_use]
462    pub fn as_codec_class<'py>(&self, py: Python<'py>) -> &Bound<'py, PyCodecClass> {
463        self.class.bind(py)
464    }
465
466    /// Unwrap the [`PyCodecClass`] to use its [`PyCodecClassMethods`] API.
467    #[must_use]
468    pub fn into_codec_class(self, py: Python) -> Bound<PyCodecClass> {
469        self.class.into_bound(py)
470    }
471
472    /// If `class` represents an exported [`DynCodecType`] `T`, i.e. it was
473    /// initially created with [`crate::export_codec_class`], the `with` closure
474    /// provides access to the instance of type `T`.
475    ///
476    /// If `class` is not an instance of `T`, the `with` closure is *not* run
477    /// and `None` is returned.
478    pub fn with_downcast<T: DynCodecType, O: Ungil>(
479        py: Python,
480        class: &Bound<PyCodecClass>,
481        with: impl Send + Ungil + for<'a> FnOnce(&'a T) -> O,
482    ) -> Option<O> {
483        let Ok(ty) = class.getattr(intern!(class.py(), RustCodec::TYPE_ATTRIBUTE)) else {
484            return None;
485        };
486
487        let Ok(ty) = ty.downcast_into_exact::<RustCodecType>() else {
488            return None;
489        };
490
491        let ty: &T = ty.get().downcast()?;
492
493        // The `with` closure contains arbitrary Rust code and may block,
494        // which we cannot allow while holding the GIL
495        Some(py.allow_threads(|| with(ty)))
496    }
497}
498
499impl DynCodecType for PyCodecClassAdapter {
500    type Codec = PyCodecAdapter;
501
502    fn codec_id(&self) -> &str {
503        &self.codec_id
504    }
505
506    fn codec_config_schema(&self) -> Schema {
507        (*self.codec_config_schema).clone()
508    }
509
510    fn codec_from_config<'de, D: Deserializer<'de>>(
511        &self,
512        config: D,
513    ) -> Result<Self::Codec, D::Error> {
514        Python::with_gil(|py| {
515            let config =
516                transcode(config, Pythonizer::new(py)).map_err(serde::de::Error::custom)?;
517            let config: Bound<PyDict> = config.extract().map_err(serde::de::Error::custom)?;
518
519            let codec = self
520                .class
521                .bind(py)
522                .codec_from_config(config.as_borrowed())
523                .map_err(serde::de::Error::custom)?;
524
525            Ok(PyCodecAdapter {
526                codec: codec.unbind(),
527                class: self.class.clone_ref(py),
528                codec_id: self.codec_id.clone(),
529                codec_config_schema: self.codec_config_schema.clone(),
530            })
531        })
532    }
533}