1use std::sync::Arc;
2
3use ndarray::{ArrayBase, DataMut, Dimension};
4use numcodecs::{
5 AnyArray, AnyArrayBase, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, DynCodec,
6 DynCodecType,
7};
8use numpy::{Element, PyArray, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods};
9use pyo3::{
10 exceptions::{PyTypeError, PyValueError},
11 intern,
12 marker::Ungil,
13 prelude::*,
14 types::{IntoPyDict, PyDict, PyDictMethods},
15};
16use pythonize::{Depythonizer, Pythonizer};
17use schemars::Schema;
18use serde::{Deserializer, Serializer};
19use serde_transcode::transcode;
20
21use crate::{
22 export::{RustCodec, RustCodecType},
23 schema::schema_from_codec_class,
24 utils::numpy_asarray,
25 PyCodec, PyCodecClass, PyCodecClassMethods, PyCodecMethods, PyCodecRegistry,
26};
27
/// Adapter that exposes a Python codec (a [`PyCodec`] instance) through the
/// Rust [`Codec`] and [`DynCodec`] traits.
pub struct PyCodecAdapter {
    /// The wrapped Python codec instance.
    codec: Py<PyCodec>,
    /// The Python class of `codec`.
    class: Py<PyCodecClass>,
    /// The codec id of `class`, cached when the adapter is created and shared
    /// (via `Arc`) with clones and with the codec type returned by `ty`.
    codec_id: Arc<String>,
    /// The codec config schema of `class`, cached when the adapter is created
    /// and shared (via `Arc`) with clones and with the codec type.
    codec_config_schema: Arc<Schema>,
}
35
36impl PyCodecAdapter {
37 pub fn from_registry_with_config<'de, D: Deserializer<'de>>(
48 config: D,
49 ) -> Result<Self, D::Error> {
50 Python::with_gil(|py| {
51 let config = transcode(config, Pythonizer::new(py))?;
52 let config: Bound<PyDict> = config.extract()?;
53
54 let codec = PyCodecRegistry::get_codec(config.as_borrowed())?;
55
56 Self::from_codec(codec)
57 })
58 .map_err(serde::de::Error::custom)
59 }
60
61 pub fn from_codec(codec: Bound<PyCodec>) -> Result<Self, PyErr> {
67 let class = codec.class();
68 let codec_id = class.codec_id()?;
69 let codec_config_schema = schema_from_codec_class(class.py(), &class).map_err(|err| {
70 PyTypeError::new_err(format!(
71 "failed to extract the {codec_id} codec config schema: {err}"
72 ))
73 })?;
74
75 Ok(Self {
76 codec: codec.unbind(),
77 class: class.unbind(),
78 codec_id: Arc::new(codec_id),
79 codec_config_schema: Arc::new(codec_config_schema),
80 })
81 }
82
83 #[must_use]
85 pub fn as_codec<'py>(&self, py: Python<'py>) -> &Bound<'py, PyCodec> {
86 self.codec.bind(py)
87 }
88
89 #[must_use]
91 pub fn into_codec(self, py: Python) -> Bound<PyCodec> {
92 self.codec.into_bound(py)
93 }
94
95 pub fn try_clone(&self, py: Python) -> Result<Self, PyErr> {
102 let config = self.codec.bind(py).get_config()?;
103
104 let _ = config.del_item(intern!(py, "id"));
106
107 let codec = self
108 .class
109 .bind(py)
110 .codec_from_config(config.as_borrowed())?;
111
112 Ok(Self {
113 codec: codec.unbind(),
114 class: self.class.clone_ref(py),
115 codec_id: self.codec_id.clone(),
116 codec_config_schema: self.codec_config_schema.clone(),
117 })
118 }
119
120 pub fn with_downcast<T: DynCodec, O: Ungil>(
127 py: Python,
128 codec: &Bound<PyCodec>,
129 with: impl Send + Ungil + for<'a> FnOnce(&'a T) -> O,
130 ) -> Option<O> {
131 let Ok(codec) = codec.downcast::<RustCodec>() else {
132 return None;
133 };
134
135 let codec = codec.get().downcast()?;
136
137 Some(py.allow_threads(|| with(codec)))
140 }
141}
142
143impl Codec for PyCodecAdapter {
144 type Error = PyErr;
145
146 fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
147 Python::with_gil(|py| {
148 self.with_any_array_view_as_ndarray(py, &data.view(), |data| {
149 let encoded = self.codec.bind(py).encode(data.as_borrowed())?;
150
151 Self::any_array_from_ndarray_like(py, encoded.as_borrowed())
152 })
153 })
154 }
155
156 fn decode(&self, encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
157 Python::with_gil(|py| {
158 self.with_any_array_view_as_ndarray(py, &encoded.view(), |encoded| {
159 let decoded = self.codec.bind(py).decode(encoded.as_borrowed(), None)?;
160
161 Self::any_array_from_ndarray_like(py, decoded.as_borrowed())
162 })
163 })
164 }
165
166 fn decode_into(
167 &self,
168 encoded: AnyArrayView,
169 mut decoded: AnyArrayViewMut,
170 ) -> Result<(), Self::Error> {
171 Python::with_gil(|py| {
172 let decoded_out = self.with_any_array_view_as_ndarray(py, &encoded, |encoded| {
173 self.with_any_array_view_mut_as_ndarray(py, &mut decoded, |decoded_in| {
174 let decoded_out = self
175 .codec
176 .bind(py)
177 .decode(encoded.as_borrowed(), Some(decoded_in.as_borrowed()))?;
178
179 if decoded_out.is(decoded_in) {
181 Ok(Ok(()))
182 } else {
183 Ok(Err(decoded_out.unbind()))
184 }
185 })
186 })?;
187 let decoded_out = match decoded_out {
188 Ok(()) => return Ok(()),
189 Err(decoded_out) => decoded_out.into_bound(py),
190 };
191
192 Self::copy_into_any_array_view_mut_from_ndarray_like(
194 py,
195 &mut decoded,
196 decoded_out.as_borrowed(),
197 )
198 })
199 }
200}
201
202impl PyCodecAdapter {
203 fn with_any_array_view_as_ndarray<T>(
204 &self,
205 py: Python,
206 view: &AnyArrayView,
207 with: impl for<'a> FnOnce(&'a Bound<PyAny>) -> Result<T, PyErr>,
208 ) -> Result<T, PyErr> {
209 let this = self.codec.bind(py).clone().into_any();
210
211 #[expect(unsafe_code)] let ndarray = unsafe {
213 match &view {
214 AnyArrayBase::U8(v) => PyArray::borrow_from_array(v, this).into_any(),
215 AnyArrayBase::U16(v) => PyArray::borrow_from_array(v, this).into_any(),
216 AnyArrayBase::U32(v) => PyArray::borrow_from_array(v, this).into_any(),
217 AnyArrayBase::U64(v) => PyArray::borrow_from_array(v, this).into_any(),
218 AnyArrayBase::I8(v) => PyArray::borrow_from_array(v, this).into_any(),
219 AnyArrayBase::I16(v) => PyArray::borrow_from_array(v, this).into_any(),
220 AnyArrayBase::I32(v) => PyArray::borrow_from_array(v, this).into_any(),
221 AnyArrayBase::I64(v) => PyArray::borrow_from_array(v, this).into_any(),
222 AnyArrayBase::F32(v) => PyArray::borrow_from_array(v, this).into_any(),
223 AnyArrayBase::F64(v) => PyArray::borrow_from_array(v, this).into_any(),
224 _ => {
225 return Err(PyTypeError::new_err(format!(
226 "unsupported type {} of read-only array view",
227 view.dtype()
228 )))
229 }
230 }
231 };
232
233 ndarray.call_method(
235 intern!(py, "setflags"),
236 (),
237 Some(&[(intern!(py, "write"), false)].into_py_dict(py)?),
238 )?;
239 let view = ndarray.call_method0(intern!(py, "view"))?;
240
241 with(&view)
242 }
243
244 fn with_any_array_view_mut_as_ndarray<T>(
245 &self,
246 py: Python,
247 view_mut: &mut AnyArrayViewMut,
248 with: impl for<'a> FnOnce(&'a Bound<PyAny>) -> Result<T, PyErr>,
249 ) -> Result<T, PyErr> {
250 let this = self.codec.bind(py).clone().into_any();
251
252 #[expect(unsafe_code)] let ndarray = unsafe {
254 match &view_mut {
255 AnyArrayBase::U8(v) => PyArray::borrow_from_array(v, this).into_any(),
256 AnyArrayBase::U16(v) => PyArray::borrow_from_array(v, this).into_any(),
257 AnyArrayBase::U32(v) => PyArray::borrow_from_array(v, this).into_any(),
258 AnyArrayBase::U64(v) => PyArray::borrow_from_array(v, this).into_any(),
259 AnyArrayBase::I8(v) => PyArray::borrow_from_array(v, this).into_any(),
260 AnyArrayBase::I16(v) => PyArray::borrow_from_array(v, this).into_any(),
261 AnyArrayBase::I32(v) => PyArray::borrow_from_array(v, this).into_any(),
262 AnyArrayBase::I64(v) => PyArray::borrow_from_array(v, this).into_any(),
263 AnyArrayBase::F32(v) => PyArray::borrow_from_array(v, this).into_any(),
264 AnyArrayBase::F64(v) => PyArray::borrow_from_array(v, this).into_any(),
265 _ => {
266 return Err(PyTypeError::new_err(format!(
267 "unsupported type {} of read-only array view",
268 view_mut.dtype()
269 )))
270 }
271 }
272 };
273
274 with(&ndarray)
275 }
276
277 fn any_array_from_ndarray_like(
278 py: Python,
279 array_like: Borrowed<PyAny>,
280 ) -> Result<AnyArray, PyErr> {
281 let ndarray = numpy_asarray(py, array_like)?;
282
283 let array = if let Ok(e) = ndarray.downcast::<PyArrayDyn<u8>>() {
284 AnyArrayBase::U8(e.try_readonly()?.to_owned_array())
285 } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<u16>>() {
286 AnyArrayBase::U16(e.try_readonly()?.to_owned_array())
287 } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<u32>>() {
288 AnyArrayBase::U32(e.try_readonly()?.to_owned_array())
289 } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<u64>>() {
290 AnyArrayBase::U64(e.try_readonly()?.to_owned_array())
291 } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i8>>() {
292 AnyArrayBase::I8(e.try_readonly()?.to_owned_array())
293 } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i16>>() {
294 AnyArrayBase::I16(e.try_readonly()?.to_owned_array())
295 } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i32>>() {
296 AnyArrayBase::I32(e.try_readonly()?.to_owned_array())
297 } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<i64>>() {
298 AnyArrayBase::I64(e.try_readonly()?.to_owned_array())
299 } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<f32>>() {
300 AnyArrayBase::F32(e.try_readonly()?.to_owned_array())
301 } else if let Ok(e) = ndarray.downcast::<PyArrayDyn<f64>>() {
302 AnyArrayBase::F64(e.try_readonly()?.to_owned_array())
303 } else {
304 return Err(PyTypeError::new_err(format!(
305 "unsupported dtype {} of array-like",
306 ndarray.dtype()
307 )));
308 };
309
310 Ok(array)
311 }
312
313 fn copy_into_any_array_view_mut_from_ndarray_like(
314 py: Python,
315 view_mut: &mut AnyArrayViewMut,
316 array_like: Borrowed<PyAny>,
317 ) -> Result<(), PyErr> {
318 fn shape_checked_assign<
319 T: Copy + Element,
320 S2: DataMut<Elem = T>,
321 D1: Dimension,
322 D2: Dimension,
323 >(
324 src: &Bound<PyArray<T, D1>>,
325 dst: &mut ArrayBase<S2, D2>,
326 ) -> Result<(), PyErr> {
327 #[expect(clippy::unit_arg)]
328 if src.shape() == dst.shape() {
329 Ok(dst.assign(&src.try_readonly()?.as_array()))
330 } else {
331 Err(PyValueError::new_err(format!(
332 "mismatching shape {:?} of array-like, expected {:?}",
333 src.shape(),
334 dst.shape(),
335 )))
336 }
337 }
338
339 let ndarray = numpy_asarray(py, array_like)?;
340
341 if let Ok(d) = ndarray.downcast::<PyArrayDyn<u8>>() {
342 if let AnyArrayBase::U8(ref mut view_mut) = view_mut {
343 return shape_checked_assign(d, view_mut);
344 }
345 } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<u16>>() {
346 if let AnyArrayBase::U16(ref mut view_mut) = view_mut {
347 return shape_checked_assign(d, view_mut);
348 }
349 } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<u32>>() {
350 if let AnyArrayBase::U32(ref mut view_mut) = view_mut {
351 return shape_checked_assign(d, view_mut);
352 }
353 } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<u64>>() {
354 if let AnyArrayBase::U64(ref mut view_mut) = view_mut {
355 return shape_checked_assign(d, view_mut);
356 }
357 } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i8>>() {
358 if let AnyArrayBase::I8(ref mut view_mut) = view_mut {
359 return shape_checked_assign(d, view_mut);
360 }
361 } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i16>>() {
362 if let AnyArrayBase::I16(ref mut view_mut) = view_mut {
363 return shape_checked_assign(d, view_mut);
364 }
365 } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i32>>() {
366 if let AnyArrayBase::I32(ref mut view_mut) = view_mut {
367 return shape_checked_assign(d, view_mut);
368 }
369 } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<i64>>() {
370 if let AnyArrayBase::I64(ref mut view_mut) = view_mut {
371 return shape_checked_assign(d, view_mut);
372 }
373 } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<f32>>() {
374 if let AnyArrayBase::F32(ref mut view_mut) = view_mut {
375 return shape_checked_assign(d, view_mut);
376 }
377 } else if let Ok(d) = ndarray.downcast::<PyArrayDyn<f64>>() {
378 if let AnyArrayBase::F64(ref mut view_mut) = view_mut {
379 return shape_checked_assign(d, view_mut);
380 }
381 } else {
382 return Err(PyTypeError::new_err(format!(
383 "unsupported dtype {} of array-like",
384 ndarray.dtype()
385 )));
386 }
387
388 Err(PyTypeError::new_err(format!(
389 "mismatching dtype {} of array-like, expected {}",
390 ndarray.dtype(),
391 view_mut.dtype(),
392 )))
393 }
394}
395
396impl Clone for PyCodecAdapter {
397 fn clone(&self) -> Self {
398 #[expect(clippy::expect_used)] Python::with_gil(|py| {
400 self.try_clone(py)
401 .expect("cloning a PyCodec should not fail")
402 })
403 }
404}
405
406impl DynCodec for PyCodecAdapter {
407 type Type = PyCodecClassAdapter;
408
409 fn ty(&self) -> Self::Type {
410 Python::with_gil(|py| PyCodecClassAdapter {
411 class: self.class.clone_ref(py),
412 codec_id: self.codec_id.clone(),
413 codec_config_schema: self.codec_config_schema.clone(),
414 })
415 }
416
417 fn get_config<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
418 Python::with_gil(|py| {
419 let config = self
420 .codec
421 .bind(py)
422 .get_config()
423 .map_err(serde::ser::Error::custom)?;
424
425 transcode(&mut Depythonizer::from_object(config.as_any()), serializer)
426 })
427 }
428}
429
/// Adapter that exposes a Python codec class (a [`PyCodecClass`]) through the
/// Rust [`DynCodecType`] trait.
pub struct PyCodecClassAdapter {
    /// The wrapped Python codec class.
    class: Py<PyCodecClass>,
    /// The codec id of `class`, cached when the adapter is created and shared
    /// (via `Arc`) with the codecs it creates.
    codec_id: Arc<String>,
    /// The codec config schema of `class`, cached when the adapter is created
    /// and shared (via `Arc`) with the codecs it creates.
    codec_config_schema: Arc<Schema>,
}
436
437impl PyCodecClassAdapter {
438 pub fn from_codec_class(class: Bound<PyCodecClass>) -> Result<Self, PyErr> {
444 let codec_id = class.codec_id()?;
445
446 let codec_config_schema = schema_from_codec_class(class.py(), &class).map_err(|err| {
447 PyTypeError::new_err(format!(
448 "failed to extract the {codec_id} codec config schema: {err}"
449 ))
450 })?;
451
452 Ok(Self {
453 class: class.unbind(),
454 codec_id: Arc::new(codec_id),
455 codec_config_schema: Arc::new(codec_config_schema),
456 })
457 }
458
459 #[must_use]
462 pub fn as_codec_class<'py>(&self, py: Python<'py>) -> &Bound<'py, PyCodecClass> {
463 self.class.bind(py)
464 }
465
466 #[must_use]
468 pub fn into_codec_class(self, py: Python) -> Bound<PyCodecClass> {
469 self.class.into_bound(py)
470 }
471
472 pub fn with_downcast<T: DynCodecType, O: Ungil>(
479 py: Python,
480 class: &Bound<PyCodecClass>,
481 with: impl Send + Ungil + for<'a> FnOnce(&'a T) -> O,
482 ) -> Option<O> {
483 let Ok(ty) = class.getattr(intern!(class.py(), RustCodec::TYPE_ATTRIBUTE)) else {
484 return None;
485 };
486
487 let Ok(ty) = ty.downcast_into_exact::<RustCodecType>() else {
488 return None;
489 };
490
491 let ty: &T = ty.get().downcast()?;
492
493 Some(py.allow_threads(|| with(ty)))
496 }
497}
498
499impl DynCodecType for PyCodecClassAdapter {
500 type Codec = PyCodecAdapter;
501
502 fn codec_id(&self) -> &str {
503 &self.codec_id
504 }
505
506 fn codec_config_schema(&self) -> Schema {
507 (*self.codec_config_schema).clone()
508 }
509
510 fn codec_from_config<'de, D: Deserializer<'de>>(
511 &self,
512 config: D,
513 ) -> Result<Self::Codec, D::Error> {
514 Python::with_gil(|py| {
515 let config =
516 transcode(config, Pythonizer::new(py)).map_err(serde::de::Error::custom)?;
517 let config: Bound<PyDict> = config.extract().map_err(serde::de::Error::custom)?;
518
519 let codec = self
520 .class
521 .bind(py)
522 .codec_from_config(config.as_borrowed())
523 .map_err(serde::de::Error::custom)?;
524
525 Ok(PyCodecAdapter {
526 codec: codec.unbind(),
527 class: self.class.clone_ref(py),
528 codec_id: self.codec_id.clone(),
529 codec_config_schema: self.codec_config_schema.clone(),
530 })
531 })
532 }
533}