1use std::{any::Any, ffi::CString};
2
3use ndarray::{ArrayViewD, ArrayViewMutD, CowArray};
4use numcodecs::{
5 AnyArray, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, DynCodec, DynCodecType,
6};
7use numpy::{
8 IxDyn, PyArray, PyArrayDescrMethods, PyArrayDyn, PyArrayMethods, PyUntypedArrayMethods,
9};
10use pyo3::{
11 exceptions::PyTypeError,
12 intern,
13 prelude::*,
14 types::{IntoPyDict, PyDict, PyString, PyType},
15 PyTypeInfo,
16};
17use pyo3_error::PyErrChain;
18use pythonize::{pythonize, Depythonizer, Pythonizer};
19
20use crate::{
21 schema::{docs_from_schema, signature_from_schema},
22 utils::numpy_asarray,
23 PyCodec, PyCodecClass, PyCodecClassAdapter, PyCodecRegistry,
24};
25
26pub fn export_codec_class<'py, T: DynCodecType>(
34 py: Python<'py>,
35 ty: T,
36 module: Borrowed<'_, 'py, PyModule>,
37) -> Result<Bound<'py, PyCodecClass>, PyErr> {
38 let codec_id = String::from(ty.codec_id());
39 let codec_class_name = convert_case::Casing::to_case(&codec_id, convert_case::Case::Pascal);
40
41 let codec_class: Bound<PyCodecClass> =
42 if let Some(adapter) = (&ty as &dyn Any).downcast_ref::<PyCodecClassAdapter>() {
44 adapter.as_codec_class(py).clone()
45 } else {
46 let codec_config_schema = ty.codec_config_schema();
47
48 let codec_class_bases = (
49 RustCodec::type_object(py),
50 PyCodec::type_object(py),
51 );
52
53 let codec_class_namespace = [
54 (intern!(py, "__module__"), module.name()?.into_any()),
55 (
56 intern!(py, "__doc__"),
57 docs_from_schema(&codec_config_schema).into_pyobject(py)?,
58 ),
59 (
60 intern!(py, RustCodec::TYPE_ATTRIBUTE),
61 Bound::new(py, RustCodecType { ty: Box::new(ty) })?.into_any(),
62 ),
63 (
64 intern!(py, "codec_id"),
65 PyString::new(py, &codec_id).into_any(),
66 ),
67 (
68 intern!(py, RustCodec::SCHEMA_ATTRIBUTE),
69 pythonize(py, &codec_config_schema)?,
70 ),
71 (
72 intern!(py, "__init__"),
73 py.eval(&CString::new(format!(
74 "lambda {}: None",
75 signature_from_schema(&codec_config_schema),
76 ))?, None, None)?,
77 ),
78 ]
79 .into_py_dict(py)?;
80
81 PyType::type_object(py)
82 .call1((&codec_class_name, codec_class_bases, codec_class_namespace))?
83 .extract()?
84 };
85
86 module.add(codec_class_name.as_str(), &codec_class)?;
87
88 PyCodecRegistry::register_codec(codec_class.as_borrowed(), None)?;
89
90 Ok(codec_class)
91}
92
93#[expect(clippy::redundant_pub_crate)]
94#[pyclass(frozen, module = "numcodecs._rust", name = "_RustCodecType")]
95pub(crate) struct RustCodecType {
97 ty: Box<dyn 'static + Send + Sync + AnyCodecType>,
98}
99
100impl RustCodecType {
101 pub fn downcast<T: DynCodecType>(&self) -> Option<&T> {
102 self.ty.as_any().downcast_ref()
103 }
104}
105
106trait AnyCodec {
107 fn encode(&self, py: Python, data: AnyCowArray) -> Result<AnyArray, PyErr>;
108
109 fn decode(&self, py: Python, encoded: AnyCowArray) -> Result<AnyArray, PyErr>;
110
111 fn decode_into(
112 &self,
113 py: Python,
114 encoded: AnyArrayView,
115 decoded: AnyArrayViewMut,
116 ) -> Result<(), PyErr>;
117
118 fn get_config<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyDict>, PyErr>;
119
120 fn as_any(&self) -> &dyn Any;
121}
122
123impl<T: DynCodec> AnyCodec for T {
124 fn encode(&self, py: Python, data: AnyCowArray) -> Result<AnyArray, PyErr> {
125 <T as Codec>::encode(self, data).map_err(|err| PyErrChain::pyerr_from_err(py, err))
126 }
127
128 fn decode(&self, py: Python, encoded: AnyCowArray) -> Result<AnyArray, PyErr> {
129 <T as Codec>::decode(self, encoded).map_err(|err| PyErrChain::pyerr_from_err(py, err))
130 }
131
132 fn decode_into(
133 &self,
134 py: Python,
135 encoded: AnyArrayView,
136 decoded: AnyArrayViewMut,
137 ) -> Result<(), PyErr> {
138 <T as Codec>::decode_into(self, encoded, decoded)
139 .map_err(|err| PyErrChain::pyerr_from_err(py, err))
140 }
141
142 fn get_config<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyDict>, PyErr> {
143 <T as DynCodec>::get_config(self, Pythonizer::new(py))?.extract()
144 }
145
146 fn as_any(&self) -> &dyn Any {
147 self
148 }
149}
150
151trait AnyCodecType {
152 fn codec_from_config<'py>(
153 &self,
154 config: Bound<'py, PyDict>,
155 ) -> Result<Box<dyn 'static + Send + Sync + AnyCodec>, PyErr>;
156
157 fn as_any(&self) -> &dyn Any;
158}
159
160impl<T: DynCodecType> AnyCodecType for T {
161 fn codec_from_config<'py>(
162 &self,
163 config: Bound<'py, PyDict>,
164 ) -> Result<Box<dyn 'static + Send + Sync + AnyCodec>, PyErr> {
165 match <T as DynCodecType>::codec_from_config(
166 self,
167 &mut Depythonizer::from_object(config.as_any()),
168 ) {
169 Ok(codec) => Ok(Box::new(codec)),
170 Err(err) => Err(err.into()),
171 }
172 }
173
174 fn as_any(&self) -> &dyn Any {
175 self
176 }
177}
178
179#[expect(clippy::redundant_pub_crate)]
180#[pyclass(subclass, frozen, module = "numcodecs._rust")]
181pub(crate) struct RustCodec {
185 cls_module: String,
186 cls_name: String,
187 codec: Box<dyn 'static + Send + Sync + AnyCodec>,
188}
189
190impl RustCodec {
191 pub const SCHEMA_ATTRIBUTE: &'static str = "__schema__";
192 pub const TYPE_ATTRIBUTE: &'static str = "_ty";
193
194 pub fn downcast<T: DynCodec>(&self) -> Option<&T> {
195 self.codec.as_any().downcast_ref()
196 }
197}
198
199#[pymethods]
200impl RustCodec {
201 #[new]
202 #[classmethod]
203 #[pyo3(signature = (**kwargs))]
204 fn new<'py>(
205 cls: &Bound<'py, PyType>,
206 py: Python<'py>,
207 kwargs: Option<Bound<'py, PyDict>>,
208 ) -> Result<Self, PyErr> {
209 let cls: &Bound<PyCodecClass> = cls.downcast()?;
210 let cls_module: String = cls.getattr(intern!(py, "__module__"))?.extract()?;
211 let cls_name: String = cls.getattr(intern!(py, "__name__"))?.extract()?;
212
213 let ty: Bound<RustCodecType> = cls
214 .getattr(intern!(py, RustCodec::TYPE_ATTRIBUTE))
215 .map_err(|_| {
216 PyTypeError::new_err(format!(
217 "{cls_module}.{cls_name} is not linked to a Rust codec type"
218 ))
219 })?
220 .extract()?;
221 let ty: PyRef<RustCodecType> = ty.try_borrow()?;
222
223 let codec = ty
224 .ty
225 .codec_from_config(kwargs.unwrap_or_else(|| PyDict::new(py)))?;
226
227 Ok(Self {
228 cls_module,
229 cls_name,
230 codec,
231 })
232 }
233
234 fn encode<'py>(
248 &self,
249 py: Python<'py>,
250 buf: &Bound<'py, PyAny>,
251 ) -> Result<Bound<'py, PyAny>, PyErr> {
252 self.process(
253 py,
254 buf.as_borrowed(),
255 AnyCodec::encode,
256 &format!("{}.{}::encode", self.cls_module, self.cls_name),
257 )
258 }
259
260 #[pyo3(signature = (buf, out=None))]
261 fn decode<'py>(
278 &self,
279 py: Python<'py>,
280 buf: &Bound<'py, PyAny>,
281 out: Option<Bound<'py, PyAny>>,
282 ) -> Result<Bound<'py, PyAny>, PyErr> {
283 let class_method = &format!("{}.{}::decode", self.cls_module, self.cls_name);
284 if let Some(out) = out {
285 self.process_into(
286 py,
287 buf.as_borrowed(),
288 out.as_borrowed(),
289 AnyCodec::decode_into,
290 class_method,
291 )?;
292 Ok(out)
293 } else {
294 self.process(py, buf.as_borrowed(), AnyCodec::decode, class_method)
295 }
296 }
297
298 fn get_config<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyDict>, PyErr> {
308 self.codec.get_config(py)
309 }
310
311 #[classmethod]
312 fn from_config<'py>(
324 cls: &Bound<'py, PyType>,
325 config: &Bound<'py, PyDict>,
326 ) -> Result<Bound<'py, PyCodec>, PyErr> {
327 let cls: Bound<PyCodecClass> = cls.extract()?;
328
329 cls.call((), Some(config))?.extract()
331 }
332
333 fn __repr__(this: PyRef<Self>, py: Python) -> Result<String, PyErr> {
334 let config = this.get_config(py)?;
335 let Ok(py_this) = this.into_pyobject(py);
336
337 let mut repr = py_this.get_type().name()?.to_cow()?.into_owned();
338 repr.push('(');
339
340 let mut first = true;
341
342 for (name, value) in config.iter() {
343 let name: String = name.extract()?;
344
345 if name == "id" {
346 continue;
348 }
349
350 let value_repr: String = value.repr()?.extract()?;
351
352 if !first {
353 repr.push_str(", ");
354 }
355 first = false;
356
357 repr.push_str(&name);
358 repr.push('=');
359 repr.push_str(&value_repr);
360 }
361
362 repr.push(')');
363
364 Ok(repr)
365 }
366}
367
368impl RustCodec {
369 fn process<'py>(
370 &self,
371 py: Python<'py>,
372 buf: Borrowed<'_, 'py, PyAny>,
373 process: impl FnOnce(
374 &(dyn 'static + Send + Sync + AnyCodec),
375 Python,
376 AnyCowArray,
377 ) -> Result<AnyArray, PyErr>,
378 class_method: &str,
379 ) -> Result<Bound<'py, PyAny>, PyErr> {
380 Self::with_pyarraylike_as_cow(py, buf, class_method, |data| {
381 let processed = process(&*self.codec, py, data)?;
382 Self::any_array_into_pyarray(py, processed, class_method)
383 })
384 }
385
386 fn process_into<'py>(
387 &self,
388 py: Python<'py>,
389 buf: Borrowed<'_, 'py, PyAny>,
390 out: Borrowed<'_, 'py, PyAny>,
391 process: impl FnOnce(
392 &(dyn 'static + Send + Sync + AnyCodec),
393 Python,
394 AnyArrayView,
395 AnyArrayViewMut,
396 ) -> Result<(), PyErr>,
397 class_method: &str,
398 ) -> Result<(), PyErr> {
399 Self::with_pyarraylike_as_view(py, buf, class_method, |data| {
400 Self::with_pyarraylike_as_view_mut(py, out, class_method, |data_out| {
401 process(&*self.codec, py, data, data_out)
402 })
403 })
404 }
405
406 fn with_pyarraylike_as_cow<'py, O>(
407 py: Python<'py>,
408 buf: Borrowed<'_, 'py, PyAny>,
409 class_method: &str,
410 with: impl for<'a> FnOnce(AnyCowArray<'a>) -> Result<O, PyErr>,
411 ) -> Result<O, PyErr> {
412 fn with_pyarraylike_as_cow_inner<T: numpy::Element, O>(
413 data: Borrowed<PyArrayDyn<T>>,
414 with: impl for<'a> FnOnce(CowArray<'a, T, IxDyn>) -> Result<O, PyErr>,
415 ) -> Result<O, PyErr> {
416 let readonly_data = data.try_readonly()?;
417 with(readonly_data.as_array().into())
418 }
419
420 let data = numpy_asarray(py, buf)?;
421 let dtype = data.dtype();
422
423 if dtype.is_equiv_to(&numpy::dtype::<u8>(py)) {
424 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<u8>>()?.into(), |a| {
425 with(AnyCowArray::U8(a))
426 })
427 } else if dtype.is_equiv_to(&numpy::dtype::<u16>(py)) {
428 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<u16>>()?.into(), |a| {
429 with(AnyCowArray::U16(a))
430 })
431 } else if dtype.is_equiv_to(&numpy::dtype::<u32>(py)) {
432 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<u32>>()?.into(), |a| {
433 with(AnyCowArray::U32(a))
434 })
435 } else if dtype.is_equiv_to(&numpy::dtype::<u64>(py)) {
436 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<u64>>()?.into(), |a| {
437 with(AnyCowArray::U64(a))
438 })
439 } else if dtype.is_equiv_to(&numpy::dtype::<i8>(py)) {
440 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<i8>>()?.into(), |a| {
441 with(AnyCowArray::I8(a))
442 })
443 } else if dtype.is_equiv_to(&numpy::dtype::<i16>(py)) {
444 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<i16>>()?.into(), |a| {
445 with(AnyCowArray::I16(a))
446 })
447 } else if dtype.is_equiv_to(&numpy::dtype::<i32>(py)) {
448 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<i32>>()?.into(), |a| {
449 with(AnyCowArray::I32(a))
450 })
451 } else if dtype.is_equiv_to(&numpy::dtype::<i64>(py)) {
452 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<i64>>()?.into(), |a| {
453 with(AnyCowArray::I64(a))
454 })
455 } else if dtype.is_equiv_to(&numpy::dtype::<f32>(py)) {
456 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<f32>>()?.into(), |a| {
457 with(AnyCowArray::F32(a))
458 })
459 } else if dtype.is_equiv_to(&numpy::dtype::<f64>(py)) {
460 with_pyarraylike_as_cow_inner(data.downcast::<PyArrayDyn<f64>>()?.into(), |a| {
461 with(AnyCowArray::F64(a))
462 })
463 } else {
464 Err(PyTypeError::new_err(format!(
465 "{class_method} received buffer of unsupported dtype `{dtype}`",
466 )))
467 }
468 }
469
470 fn with_pyarraylike_as_view<'py, O>(
471 py: Python<'py>,
472 buf: Borrowed<'_, 'py, PyAny>,
473 class_method: &str,
474 with: impl for<'a> FnOnce(AnyArrayView<'a>) -> Result<O, PyErr>,
475 ) -> Result<O, PyErr> {
476 fn with_pyarraylike_as_view_inner<T: numpy::Element, O>(
477 data: Borrowed<PyArrayDyn<T>>,
478 with: impl for<'a> FnOnce(ArrayViewD<'a, T>) -> Result<O, PyErr>,
479 ) -> Result<O, PyErr> {
480 let readonly_data = data.try_readonly()?;
481 with(readonly_data.as_array())
482 }
483
484 let data = numpy_asarray(py, buf)?;
485 let dtype = data.dtype();
486
487 if dtype.is_equiv_to(&numpy::dtype::<u8>(py)) {
488 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<u8>>()?.into(), |a| {
489 with(AnyArrayView::U8(a))
490 })
491 } else if dtype.is_equiv_to(&numpy::dtype::<u16>(py)) {
492 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<u16>>()?.into(), |a| {
493 with(AnyArrayView::U16(a))
494 })
495 } else if dtype.is_equiv_to(&numpy::dtype::<u32>(py)) {
496 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<u32>>()?.into(), |a| {
497 with(AnyArrayView::U32(a))
498 })
499 } else if dtype.is_equiv_to(&numpy::dtype::<u64>(py)) {
500 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<u64>>()?.into(), |a| {
501 with(AnyArrayView::U64(a))
502 })
503 } else if dtype.is_equiv_to(&numpy::dtype::<i8>(py)) {
504 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<i8>>()?.into(), |a| {
505 with(AnyArrayView::I8(a))
506 })
507 } else if dtype.is_equiv_to(&numpy::dtype::<i16>(py)) {
508 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<i16>>()?.into(), |a| {
509 with(AnyArrayView::I16(a))
510 })
511 } else if dtype.is_equiv_to(&numpy::dtype::<i32>(py)) {
512 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<i32>>()?.into(), |a| {
513 with(AnyArrayView::I32(a))
514 })
515 } else if dtype.is_equiv_to(&numpy::dtype::<i64>(py)) {
516 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<i64>>()?.into(), |a| {
517 with(AnyArrayView::I64(a))
518 })
519 } else if dtype.is_equiv_to(&numpy::dtype::<f32>(py)) {
520 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<f32>>()?.into(), |a| {
521 with(AnyArrayView::F32(a))
522 })
523 } else if dtype.is_equiv_to(&numpy::dtype::<f64>(py)) {
524 with_pyarraylike_as_view_inner(data.downcast::<PyArrayDyn<f64>>()?.into(), |a| {
525 with(AnyArrayView::F64(a))
526 })
527 } else {
528 Err(PyTypeError::new_err(format!(
529 "{class_method} received buffer of unsupported dtype `{dtype}`",
530 )))
531 }
532 }
533
534 fn with_pyarraylike_as_view_mut<'py, O>(
535 py: Python<'py>,
536 buf: Borrowed<'_, 'py, PyAny>,
537 class_method: &str,
538 with: impl for<'a> FnOnce(AnyArrayViewMut<'a>) -> Result<O, PyErr>,
539 ) -> Result<O, PyErr> {
540 fn with_pyarraylike_as_view_mut_inner<T: numpy::Element, O>(
541 data: Borrowed<PyArrayDyn<T>>,
542 with: impl for<'a> FnOnce(ArrayViewMutD<'a, T>) -> Result<O, PyErr>,
543 ) -> Result<O, PyErr> {
544 let mut readwrite_data = data.try_readwrite()?;
545 with(readwrite_data.as_array_mut())
546 }
547
548 let data = numpy_asarray(py, buf)?;
549 let dtype = data.dtype();
550
551 if dtype.is_equiv_to(&numpy::dtype::<u8>(py)) {
552 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<u8>>()?.into(), |a| {
553 with(AnyArrayViewMut::U8(a))
554 })
555 } else if dtype.is_equiv_to(&numpy::dtype::<u16>(py)) {
556 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<u16>>()?.into(), |a| {
557 with(AnyArrayViewMut::U16(a))
558 })
559 } else if dtype.is_equiv_to(&numpy::dtype::<u32>(py)) {
560 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<u32>>()?.into(), |a| {
561 with(AnyArrayViewMut::U32(a))
562 })
563 } else if dtype.is_equiv_to(&numpy::dtype::<u64>(py)) {
564 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<u64>>()?.into(), |a| {
565 with(AnyArrayViewMut::U64(a))
566 })
567 } else if dtype.is_equiv_to(&numpy::dtype::<i8>(py)) {
568 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<i8>>()?.into(), |a| {
569 with(AnyArrayViewMut::I8(a))
570 })
571 } else if dtype.is_equiv_to(&numpy::dtype::<i16>(py)) {
572 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<i16>>()?.into(), |a| {
573 with(AnyArrayViewMut::I16(a))
574 })
575 } else if dtype.is_equiv_to(&numpy::dtype::<i32>(py)) {
576 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<i32>>()?.into(), |a| {
577 with(AnyArrayViewMut::I32(a))
578 })
579 } else if dtype.is_equiv_to(&numpy::dtype::<i64>(py)) {
580 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<i64>>()?.into(), |a| {
581 with(AnyArrayViewMut::I64(a))
582 })
583 } else if dtype.is_equiv_to(&numpy::dtype::<f32>(py)) {
584 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<f32>>()?.into(), |a| {
585 with(AnyArrayViewMut::F32(a))
586 })
587 } else if dtype.is_equiv_to(&numpy::dtype::<f64>(py)) {
588 with_pyarraylike_as_view_mut_inner(data.downcast::<PyArrayDyn<f64>>()?.into(), |a| {
589 with(AnyArrayViewMut::F64(a))
590 })
591 } else {
592 Err(PyTypeError::new_err(format!(
593 "{class_method} received buffer of unsupported dtype `{dtype}`",
594 )))
595 }
596 }
597
598 fn any_array_into_pyarray<'py>(
599 py: Python<'py>,
600 array: AnyArray,
601 class_method: &str,
602 ) -> Result<Bound<'py, PyAny>, PyErr> {
603 match array {
604 AnyArray::U8(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
605 AnyArray::U16(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
606 AnyArray::U32(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
607 AnyArray::U64(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
608 AnyArray::I8(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
609 AnyArray::I16(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
610 AnyArray::I32(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
611 AnyArray::I64(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
612 AnyArray::F32(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
613 AnyArray::F64(a) => Ok(PyArray::from_owned_array(py, a).into_any()),
614 array => Err(PyTypeError::new_err(format!(
615 "{class_method} returned unsupported dtype `{}`",
616 array.dtype(),
617 ))),
618 }
619 }
620}