pyodide_webassembly_runtime_layer/
func.rs

1use std::{
2    any::TypeId,
3    marker::PhantomData,
4    sync::{Arc, Weak},
5};
6
7use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyTuple, PyTypeInfo};
8use pyo3_error::PyErrChain;
9use wasm_runtime_layer::{
10    backend::{AsContext, AsContextMut, Value, WasmFunc, WasmStoreContext},
11    FuncType,
12};
13use wobbly::sync::Wobbly;
14
15use crate::{
16    conversion::{py_to_js_proxy, ToPy, ValueExt},
17    store::StoreContextMut,
18    Engine,
19};
20
21/// A bound function, which may be an export from a WASM [`Instance`] or a host
22/// function.
23///
24/// [`Instance`]: crate::instance::Instance
25#[derive(Debug)]
26pub struct Func {
27    /// The inner function
28    func: Py<PyAny>,
29    /// The function signature
30    ty: FuncType,
31    /// The user state type of the context
32    user_state: Option<TypeId>,
33}
34
35impl Clone for Func {
36    fn clone(&self) -> Self {
37        Python::with_gil(|py| Self {
38            func: self.func.clone_ref(py),
39            ty: self.ty.clone(),
40            user_state: self.user_state,
41        })
42    }
43}
44
45impl WasmFunc<Engine> for Func {
46    fn new<T>(
47        mut ctx: impl AsContextMut<Engine, UserState = T>,
48        ty: FuncType,
49        func: impl 'static
50            + Send
51            + Sync
52            + Fn(StoreContextMut<T>, &[Value<Engine>], &mut [Value<Engine>]) -> anyhow::Result<()>,
53    ) -> Self {
54        Python::with_gil(|py| -> Result<Self, PyErr> {
55            #[cfg(feature = "tracing")]
56            tracing::debug!("Func::new");
57
58            let mut store: StoreContextMut<T> = ctx.as_context_mut();
59
60            let weak_store = store.as_weak_proof();
61
62            let user_state = non_static_type_id(store.data());
63            let ty_clone = ty.clone();
64
65            let func = Arc::new(move |args: Bound<PyTuple>| -> Result<Py<PyAny>, PyErr> {
66                let py = args.py();
67
68                let Some(mut strong_store) = Weak::upgrade(&weak_store) else {
69                    return Err(PyRuntimeError::new_err(
70                        "host func called after free of its associated store",
71                    ));
72                };
73
74                // Safety:
75                //
76                // - The proof is constructed from a mutable store context
77                // - Calling a host function (from the host or from WASM) provides that call
78                //   with a mutable reborrow of the store context
79                let store = unsafe { StoreContextMut::from_proof_unchecked(&mut strong_store) };
80
81                let ty = &ty_clone;
82
83                let args = ty
84                    .params()
85                    .iter()
86                    .zip(args.iter())
87                    .map(|(ty, arg)| Value::from_py_typed(arg, *ty))
88                    .collect::<Result<Vec<_>, _>>()?;
89                let mut results = vec![Value::I32(0); ty.results().len()];
90
91                #[cfg(feature = "tracing")]
92                let _span = tracing::debug_span!("call_host", ?args, ?ty).entered();
93
94                match func(store, &args, &mut results) {
95                    Ok(()) => {
96                        #[cfg(feature = "tracing")]
97                        tracing::debug!(?results, "result");
98                    },
99                    Err(err) => {
100                        #[cfg(feature = "tracing")]
101                        tracing::error!("{err:?}");
102                        return Err(PyErrChain::pyerr_from_err(py, err));
103                    },
104                }
105
106                let results = match results.as_slice() {
107                    [] => py.None(),
108                    [res] => res.to_py(py),
109                    results => PyTuple::new(py, results.iter().map(|res| res.to_py(py)))?
110                        .into_any()
111                        .unbind(),
112                };
113
114                Ok(results)
115            });
116
117            let func = Bound::new(
118                py,
119                PyHostFunc {
120                    func: store.register_host_func(func),
121                    #[cfg(feature = "tracing")]
122                    ty: ty.clone(),
123                },
124            )?;
125            let func = py_to_js_proxy(func)?;
126
127            Ok(Self {
128                func: func.unbind(),
129                ty,
130                user_state: Some(user_state),
131            })
132        })
133        .expect("Func::new should not fail")
134    }
135
136    fn ty(&self, _ctx: impl AsContext<Engine>) -> FuncType {
137        self.ty.clone()
138    }
139
140    fn call<T>(
141        &self,
142        mut ctx: impl AsContextMut<Engine>,
143        args: &[Value<Engine>],
144        results: &mut [Value<Engine>],
145    ) -> anyhow::Result<()> {
146        Python::with_gil(|py| {
147            let store: StoreContextMut<_> = ctx.as_context_mut();
148
149            if let Some(user_state) = self.user_state {
150                assert_eq!(user_state, non_static_type_id(store.data()));
151            }
152
153            #[cfg(feature = "tracing")]
154            let _span = tracing::debug_span!("call_guest", ?args, ?self.ty).entered();
155
156            // https://webassembly.github.io/spec/js-api/#exported-function-exotic-objects
157            assert_eq!(self.ty.params().len(), args.len());
158            assert_eq!(self.ty.results().len(), results.len());
159
160            let args = args.iter().map(|arg| arg.to_py(py));
161            let args = PyTuple::new(py, args)?;
162
163            let res = self.func.bind(py).call1(args)?;
164
165            #[cfg(feature = "tracing")]
166            tracing::debug!(%res, ?self.ty);
167
168            match (self.ty.results(), results) {
169                ([], []) => (),
170                ([ty], [result]) => *result = Value::from_py_typed(res, *ty)?,
171                (tys, results) => {
172                    let res: Bound<PyTuple> = PyTuple::type_object(py).call1((res,))?.extract()?;
173
174                    // https://webassembly.github.io/spec/js-api/#exported-function-exotic-objects
175                    assert_eq!(tys.len(), res.len());
176
177                    for ((ty, result), value) in self
178                        .ty
179                        .results()
180                        .iter()
181                        .zip(results.iter_mut())
182                        .zip(res.iter())
183                    {
184                        *result = Value::from_py_typed(value, *ty)?;
185                    }
186                },
187            }
188
189            Ok(())
190        })
191    }
192}
193
194impl ToPy for Func {
195    fn to_py(&self, py: Python) -> Py<PyAny> {
196        self.func.clone_ref(py)
197    }
198}
199
200impl Func {
201    /// Creates a new function from a Python value
202    pub(crate) fn from_exported_function(func: Bound<PyAny>, ty: FuncType) -> anyhow::Result<Self> {
203        if !func.is_callable() {
204            anyhow::bail!("expected WebAssembly.Function but found {func:?} which is not callable");
205        }
206
207        #[cfg(feature = "tracing")]
208        tracing::debug!(%func, ?ty, "Func::from_exported_function");
209
210        Ok(Self {
211            func: func.unbind(),
212            ty,
213            user_state: None,
214        })
215    }
216}
217
218pub type PyHostFuncFn = dyn 'static + Send + Sync + Fn(Bound<PyTuple>) -> Result<Py<PyAny>, PyErr>;
219
220#[pyclass(frozen)]
221struct PyHostFunc {
222    func: Wobbly<PyHostFuncFn>,
223    #[cfg(feature = "tracing")]
224    ty: FuncType,
225}
226
227#[pymethods]
228impl PyHostFunc {
229    #[pyo3(signature = (*args))]
230    fn __call__(&self, args: Bound<PyTuple>) -> Result<Py<PyAny>, PyErr> {
231        #[cfg(feature = "tracing")]
232        let _span = tracing::debug_span!("call_trampoline", ?self.ty, args = %args).entered();
233
234        let Some(func) = self.func.upgrade() else {
235            return Err(PyRuntimeError::new_err(
236                "weak host func called after free of its associated store",
237            ));
238        };
239
240        func(args)
241    }
242}
243
// Courtesy of David Tolnay:
// https://github.com/rust-lang/rust/issues/41875#issuecomment-317292888
/// Returns the [`TypeId`] of `T` without requiring `T: 'static`.
///
/// The value behind `_x` is never read; the reference only selects `T`.
fn non_static_type_id<T: ?Sized>(_x: &T) -> TypeId {
    /// Helper trait whose method is only callable on `'static` receivers.
    trait LifetimeErasedId {
        fn type_id_of_param(&self) -> TypeId
        where
            Self: 'static;
    }

    impl<U: ?Sized> LifetimeErasedId for PhantomData<U> {
        fn type_id_of_param(&self) -> TypeId
        where
            Self: 'static,
        {
            TypeId::of::<U>()
        }
    }

    let marker = PhantomData::<T>;
    let short_lived: &dyn LifetimeErasedId = &marker;

    // SAFETY: the transmute only replaces the lifetime bound of the trait
    // object with `'static`; the reference itself is unchanged. The method
    // that is called afterwards never reads through `self`, it only names the
    // type parameter of the zero-sized `PhantomData`.
    let pretend_static = unsafe {
        core::mem::transmute::<&dyn LifetimeErasedId, &(dyn LifetimeErasedId + 'static)>(
            short_lived,
        )
    };

    pretend_static.type_id_of_param()
}