numcodecs_wasm/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
use numcodecs_python::{PyCodec, PyCodecAdapter, PyCodecClass};
use numcodecs_wasm_host_reproducible::{ReproducibleWasmCodec, ReproducibleWasmCodecType};
use pyo3::{exceptions::PyTypeError, prelude::*};

mod engine;

use engine::{default_engine, Engine};

/// Python extension module `_wasm`.
///
/// On import, this installs a [`pyo3_log`] logger so that Rust `log` records are forwarded to
/// Python's `logging` module. It then registers the `_create_codec_class` and
/// `_read_codec_instruction_counter` functions on `module`.
///
/// # Errors
///
/// Returns an error if the logger cannot be created or installed, or if either function cannot
/// be wrapped or added to `module`.
#[pymodule]
#[pyo3(name = "_wasm")]
fn wasm<'py>(py: Python<'py>, module: &Bound<'py, PyModule>) -> Result<(), PyErr> {
    // `Caching::Nothing` turns off pyo3-log's caching of Python loggers/levels, so each record
    // is checked against the Python logging configuration at the time it is emitted.
    pyo3_log::Logger::new(py, pyo3_log::Caching::Nothing)?
        .install()
        .map_err(|err| pyo3_error::PyErrChain::new(py, err))?;

    let create_fn = wrap_pyfunction!(create_codec_class, module)?;
    module.add_function(create_fn)?;

    let counter_fn = wrap_pyfunction!(read_codec_instruction_counter, module)?;
    module.add_function(counter_fn)?;

    Ok(())
}

/// Builds a Python codec class from a WebAssembly codec binary and exports it into `module`.
///
/// `wasm` holds the raw bytes that are handed to [`ReproducibleWasmCodecType::new`], together
/// with the engine returned by `default_engine`.
///
/// # Errors
///
/// Returns an error if:
/// - the engine cannot be obtained;
/// - the codec type cannot be constructed from `wasm` (the error is wrapped in a
///   `pyo3_error::PyErrChain`);
/// - exporting the codec class into `module` fails.
#[pyfunction]
#[pyo3(name = "_create_codec_class")]
fn create_codec_class<'py>(
    py: Python<'py>,
    module: &Bound<'py, PyModule>,
    wasm: Vec<u8>,
) -> Result<Bound<'py, PyCodecClass>, PyErr> {
    let wasm_engine = default_engine(py)?;

    let ty = ReproducibleWasmCodecType::new(wasm_engine, wasm)
        .map_err(|err| pyo3_error::PyErrChain::new(py, err))?;

    numcodecs_python::export_codec_class(py, ty, module.as_borrowed())
}

/// Returns the instruction counter of a WebAssembly-backed codec as a plain `u64`.
///
/// # Errors
///
/// - Returns a Python `TypeError` if `codec` does not wrap a `ReproducibleWasmCodec<Engine>`.
/// - Returns the codec's own error, wrapped in a `pyo3_error::PyErrChain`, if reading its
///   instruction counter fails.
#[pyfunction]
#[pyo3(name = "_read_codec_instruction_counter")]
fn read_codec_instruction_counter<'py>(
    py: Python<'py>,
    codec: &Bound<'py, PyCodec>,
) -> Result<u64, PyErr> {
    // `None` means the downcast failed (not a wasm codec).
    // `Some(result)` carries the outcome of reading the counter.
    let outcome =
        PyCodecAdapter::with_downcast(codec, |wasm_codec: &ReproducibleWasmCodec<Engine>| {
            wasm_codec
                .instruction_counter()
                .map_err(|err| pyo3_error::PyErrChain::new(py, err))
        });

    match outcome {
        Some(Ok(counter)) => Ok(counter.0),
        Some(Err(err)) => Err(err.into()),
        None => Err(PyTypeError::new_err(
            "`codec` is not a wasm codec, only wasm codecs have instruction counts",
        )),
    }
}