pyodide_webassembly_runtime_layer/
func.rs

1use std::{
2    any::TypeId,
3    marker::PhantomData,
4    sync::{Arc, Weak},
5};
6
7use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyTuple, PyTypeInfo};
8use pyo3_error::PyErrChain;
9use wasm_runtime_layer::{
10    backend::{AsContext, AsContextMut, Value, WasmFunc, WasmStoreContext},
11    FuncType,
12};
13use wobbly::sync::Wobbly;
14
15use crate::{
16    conversion::{py_to_js_proxy, ToPy, ValueExt},
17    store::StoreContextMut,
18    Engine,
19};
20
21/// A bound function, which may be an export from a WASM [`Instance`] or a host
22/// function.
23///
24/// [`Instance`]: crate::instance::Instance
25#[derive(Debug)]
26pub struct Func {
27    /// The inner function
28    pyfunc: Py<PyAny>,
29    /// The function signature
30    ty: FuncType,
31    /// The user state type of the context
32    user_state: Option<TypeId>,
33}
34
35impl Clone for Func {
36    fn clone(&self) -> Self {
37        Python::attach(|py| Self {
38            pyfunc: self.pyfunc.clone_ref(py),
39            ty: self.ty.clone(),
40            user_state: self.user_state,
41        })
42    }
43}
44
45impl WasmFunc<Engine> for Func {
46    fn new<T: 'static>(
47        mut ctx: impl AsContextMut<Engine, UserState = T>,
48        ty: FuncType,
49        func: impl 'static
50            + Send
51            + Sync
52            + Fn(StoreContextMut<T>, &[Value<Engine>], &mut [Value<Engine>]) -> anyhow::Result<()>,
53    ) -> Self {
54        Python::attach(|py| -> Result<Self, PyErr> {
55            #[cfg(feature = "tracing")]
56            tracing::debug!("Func::new");
57
58            let mut store: StoreContextMut<T> = ctx.as_context_mut();
59
60            let weak_store = store.as_weak_proof();
61
62            let user_state = non_static_type_id(store.data());
63            let ty_clone = ty.clone();
64
65            let func = Arc::new(move |args: Bound<PyTuple>| -> Result<Py<PyAny>, PyErr> {
66                let py = args.py();
67
68                let Some(mut strong_store) = Weak::upgrade(&weak_store) else {
69                    return Err(PyRuntimeError::new_err(
70                        "host func called after free of its associated store",
71                    ));
72                };
73
74                // Safety:
75                //
76                // - The proof is constructed from a mutable store context
77                // - Calling a host function (from the host or from WASM) provides that call
78                //   with a mutable reborrow of the store context
79                let store = unsafe { StoreContextMut::from_proof_unchecked(&mut strong_store) };
80
81                let ty = &ty_clone;
82
83                let args = ty
84                    .params()
85                    .iter()
86                    .zip(args.iter())
87                    .map(|(ty, arg)| Value::from_py_typed(arg, *ty))
88                    .collect::<Result<Vec<_>, _>>()?;
89                let mut results = vec![Value::I32(0); ty.results().len()];
90
91                #[cfg(feature = "tracing")]
92                let _span = tracing::debug_span!("call_host", ?args, ?ty).entered();
93
94                match func(store, &args, &mut results) {
95                    Ok(()) => {
96                        #[cfg(feature = "tracing")]
97                        tracing::debug!(?results, "result");
98                    },
99                    Err(err) => {
100                        #[cfg(feature = "tracing")]
101                        tracing::error!("{err:?}");
102                        return Err(PyErrChain::pyerr_from_err(py, err));
103                    },
104                }
105
106                let results = match results.as_slice() {
107                    [] => py.None(),
108                    [res] => res.to_py(py),
109                    results => PyTuple::new(py, results.iter().map(|res| res.to_py(py)))?
110                        .into_any()
111                        .unbind(),
112                };
113
114                Ok(results)
115            });
116
117            let func = Bound::new(
118                py,
119                PyHostFunc {
120                    func: store.register_host_func(func),
121                    #[cfg(feature = "tracing")]
122                    ty: ty.clone(),
123                },
124            )?;
125            let func = py_to_js_proxy(func)?;
126
127            Ok(Self {
128                pyfunc: func.unbind(),
129                ty,
130                user_state: Some(user_state),
131            })
132        })
133        .expect("Func::new should not fail")
134    }
135
136    fn ty(&self, _ctx: impl AsContext<Engine>) -> FuncType {
137        self.ty.clone()
138    }
139
140    fn call<T>(
141        &self,
142        mut ctx: impl AsContextMut<Engine>,
143        args: &[Value<Engine>],
144        results: &mut [Value<Engine>],
145    ) -> anyhow::Result<()> {
146        Python::attach(|py| {
147            let store: StoreContextMut<_> = ctx.as_context_mut();
148
149            if let Some(user_state) = self.user_state {
150                assert_eq!(user_state, non_static_type_id(store.data()));
151            }
152
153            #[cfg(feature = "tracing")]
154            let _span = tracing::debug_span!("call_guest", ?args, ?self.ty).entered();
155
156            // https://webassembly.github.io/spec/js-api/#exported-function-exotic-objects
157            assert_eq!(self.ty.params().len(), args.len());
158            assert_eq!(self.ty.results().len(), results.len());
159
160            let args = args.iter().map(|arg| arg.to_py(py));
161            let args = PyTuple::new(py, args)?;
162
163            let res = self.pyfunc.bind(py).call1(args)?;
164
165            #[cfg(feature = "tracing")]
166            tracing::debug!(%res, ?self.ty);
167
168            match (self.ty.results(), results) {
169                ([], []) => (),
170                ([ty], [result]) => *result = Value::from_py_typed(res, *ty)?,
171                (tys, results) => {
172                    let res: Bound<PyTuple> = PyTuple::type_object(py)
173                        .call1((res,))?
174                        .extract()
175                        .map_err(PyErr::from)?;
176
177                    // https://webassembly.github.io/spec/js-api/#exported-function-exotic-objects
178                    assert_eq!(tys.len(), res.len());
179
180                    for ((ty, result), value) in self
181                        .ty
182                        .results()
183                        .iter()
184                        .zip(results.iter_mut())
185                        .zip(res.iter())
186                    {
187                        *result = Value::from_py_typed(value, *ty)?;
188                    }
189                },
190            }
191
192            Ok(())
193        })
194    }
195}
196
197impl ToPy for Func {
198    fn to_py(&self, py: Python) -> Py<PyAny> {
199        self.pyfunc.clone_ref(py)
200    }
201}
202
203impl Func {
204    /// Creates a new function from a Python value
205    pub(crate) fn from_exported_function(func: Bound<PyAny>, ty: FuncType) -> anyhow::Result<Self> {
206        if !func.is_callable() {
207            anyhow::bail!("expected WebAssembly.Function but found {func:?} which is not callable");
208        }
209
210        #[cfg(feature = "tracing")]
211        tracing::debug!(%func, ?ty, "Func::from_exported_function");
212
213        Ok(Self {
214            pyfunc: func.unbind(),
215            ty,
216            user_state: None,
217        })
218    }
219}
220
221pub type PyHostFuncFn = dyn 'static + Send + Sync + Fn(Bound<PyTuple>) -> Result<Py<PyAny>, PyErr>;
222
223#[pyclass(frozen)]
224struct PyHostFunc {
225    func: Wobbly<PyHostFuncFn>,
226    #[cfg(feature = "tracing")]
227    ty: FuncType,
228}
229
230#[pymethods]
231impl PyHostFunc {
232    #[pyo3(signature = (*args))]
233    fn __call__(&self, args: Bound<PyTuple>) -> Result<Py<PyAny>, PyErr> {
234        #[cfg(feature = "tracing")]
235        let _span = tracing::debug_span!("call_trampoline", ?self.ty, args = %args).entered();
236
237        let Some(func) = self.func.upgrade() else {
238            return Err(PyRuntimeError::new_err(
239                "weak host func called after free of its associated store",
240            ));
241        };
242
243        func(args)
244    }
245}
246
// Courtesy of David Tolnay:
// https://github.com/rust-lang/rust/issues/41875#issuecomment-317292888
/// Returns the [`TypeId`] of `T` without requiring `T: 'static`.
///
/// The argument is only used to infer `T`; its value is never read.
fn non_static_type_id<T: ?Sized>(_x: &T) -> TypeId {
    /// Object-safe detour around the `'static` bound of [`TypeId::of`].
    trait ErasedTypeId {
        fn type_id_of(&self) -> TypeId
        where
            Self: 'static;
    }

    impl<U: ?Sized> ErasedTypeId for PhantomData<U> {
        fn type_id_of(&self) -> TypeId
        where
            Self: 'static,
        {
            TypeId::of::<U>()
        }
    }

    let marker = PhantomData::<T>;
    let short_lived: &dyn ErasedTypeId = &marker;

    // SAFETY: the transmute only widens the lifetime bound of the trait
    // object. The object is a zero-sized `PhantomData`, and the single method
    // called through the widened reference returns a `TypeId` without
    // touching any borrowed data, so nothing is accessed beyond its lifetime.
    let widened = unsafe {
        core::mem::transmute::<&dyn ErasedTypeId, &(dyn ErasedTypeId + 'static)>(short_lived)
    };

    widened.type_id_of()
}