pyodide_webassembly_runtime_layer/
func.rs1use std::{
2 any::TypeId,
3 marker::PhantomData,
4 sync::{Arc, Weak},
5};
6
7use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyTuple, PyTypeInfo};
8use pyo3_error::PyErrChain;
9use wasm_runtime_layer::{
10 backend::{AsContext, AsContextMut, Value, WasmFunc, WasmStoreContext},
11 FuncType,
12};
13use wobbly::sync::Wobbly;
14
15use crate::{
16 conversion::{py_to_js_proxy, ToPy, ValueExt},
17 store::StoreContextMut,
18 Engine,
19};
20
21#[derive(Debug)]
26pub struct Func {
27 func: Py<PyAny>,
29 ty: FuncType,
31 user_state: Option<TypeId>,
33}
34
35impl Clone for Func {
36 fn clone(&self) -> Self {
37 Python::with_gil(|py| Self {
38 func: self.func.clone_ref(py),
39 ty: self.ty.clone(),
40 user_state: self.user_state,
41 })
42 }
43}
44
45impl WasmFunc<Engine> for Func {
46 fn new<T>(
47 mut ctx: impl AsContextMut<Engine, UserState = T>,
48 ty: FuncType,
49 func: impl 'static
50 + Send
51 + Sync
52 + Fn(StoreContextMut<T>, &[Value<Engine>], &mut [Value<Engine>]) -> anyhow::Result<()>,
53 ) -> Self {
54 Python::with_gil(|py| -> Result<Self, PyErr> {
55 #[cfg(feature = "tracing")]
56 tracing::debug!("Func::new");
57
58 let mut store: StoreContextMut<T> = ctx.as_context_mut();
59
60 let weak_store = store.as_weak_proof();
61
62 let user_state = non_static_type_id(store.data());
63 let ty_clone = ty.clone();
64
65 let func = Arc::new(move |args: Bound<PyTuple>| -> Result<Py<PyAny>, PyErr> {
66 let py = args.py();
67
68 let Some(mut strong_store) = Weak::upgrade(&weak_store) else {
69 return Err(PyRuntimeError::new_err(
70 "host func called after free of its associated store",
71 ));
72 };
73
74 let store = unsafe { StoreContextMut::from_proof_unchecked(&mut strong_store) };
80
81 let ty = &ty_clone;
82
83 let args = ty
84 .params()
85 .iter()
86 .zip(args.iter())
87 .map(|(ty, arg)| Value::from_py_typed(arg, *ty))
88 .collect::<Result<Vec<_>, _>>()?;
89 let mut results = vec![Value::I32(0); ty.results().len()];
90
91 #[cfg(feature = "tracing")]
92 let _span = tracing::debug_span!("call_host", ?args, ?ty).entered();
93
94 match func(store, &args, &mut results) {
95 Ok(()) => {
96 #[cfg(feature = "tracing")]
97 tracing::debug!(?results, "result");
98 },
99 Err(err) => {
100 #[cfg(feature = "tracing")]
101 tracing::error!("{err:?}");
102 return Err(PyErrChain::pyerr_from_err(py, err));
103 },
104 }
105
106 let results = match results.as_slice() {
107 [] => py.None(),
108 [res] => res.to_py(py),
109 results => PyTuple::new(py, results.iter().map(|res| res.to_py(py)))?
110 .into_any()
111 .unbind(),
112 };
113
114 Ok(results)
115 });
116
117 let func = Bound::new(
118 py,
119 PyHostFunc {
120 func: store.register_host_func(func),
121 #[cfg(feature = "tracing")]
122 ty: ty.clone(),
123 },
124 )?;
125 let func = py_to_js_proxy(func)?;
126
127 Ok(Self {
128 func: func.unbind(),
129 ty,
130 user_state: Some(user_state),
131 })
132 })
133 .expect("Func::new should not fail")
134 }
135
136 fn ty(&self, _ctx: impl AsContext<Engine>) -> FuncType {
137 self.ty.clone()
138 }
139
140 fn call<T>(
141 &self,
142 mut ctx: impl AsContextMut<Engine>,
143 args: &[Value<Engine>],
144 results: &mut [Value<Engine>],
145 ) -> anyhow::Result<()> {
146 Python::with_gil(|py| {
147 let store: StoreContextMut<_> = ctx.as_context_mut();
148
149 if let Some(user_state) = self.user_state {
150 assert_eq!(user_state, non_static_type_id(store.data()));
151 }
152
153 #[cfg(feature = "tracing")]
154 let _span = tracing::debug_span!("call_guest", ?args, ?self.ty).entered();
155
156 assert_eq!(self.ty.params().len(), args.len());
158 assert_eq!(self.ty.results().len(), results.len());
159
160 let args = args.iter().map(|arg| arg.to_py(py));
161 let args = PyTuple::new(py, args)?;
162
163 let res = self.func.bind(py).call1(args)?;
164
165 #[cfg(feature = "tracing")]
166 tracing::debug!(%res, ?self.ty);
167
168 match (self.ty.results(), results) {
169 ([], []) => (),
170 ([ty], [result]) => *result = Value::from_py_typed(res, *ty)?,
171 (tys, results) => {
172 let res: Bound<PyTuple> = PyTuple::type_object(py).call1((res,))?.extract()?;
173
174 assert_eq!(tys.len(), res.len());
176
177 for ((ty, result), value) in self
178 .ty
179 .results()
180 .iter()
181 .zip(results.iter_mut())
182 .zip(res.iter())
183 {
184 *result = Value::from_py_typed(value, *ty)?;
185 }
186 },
187 }
188
189 Ok(())
190 })
191 }
192}
193
194impl ToPy for Func {
195 fn to_py(&self, py: Python) -> Py<PyAny> {
196 self.func.clone_ref(py)
197 }
198}
199
200impl Func {
201 pub(crate) fn from_exported_function(func: Bound<PyAny>, ty: FuncType) -> anyhow::Result<Self> {
203 if !func.is_callable() {
204 anyhow::bail!("expected WebAssembly.Function but found {func:?} which is not callable");
205 }
206
207 #[cfg(feature = "tracing")]
208 tracing::debug!(%func, ?ty, "Func::from_exported_function");
209
210 Ok(Self {
211 func: func.unbind(),
212 ty,
213 user_state: None,
214 })
215 }
216}
217
218pub type PyHostFuncFn = dyn 'static + Send + Sync + Fn(Bound<PyTuple>) -> Result<Py<PyAny>, PyErr>;
219
220#[pyclass(frozen)]
221struct PyHostFunc {
222 func: Wobbly<PyHostFuncFn>,
223 #[cfg(feature = "tracing")]
224 ty: FuncType,
225}
226
227#[pymethods]
228impl PyHostFunc {
229 #[pyo3(signature = (*args))]
230 fn __call__(&self, args: Bound<PyTuple>) -> Result<Py<PyAny>, PyErr> {
231 #[cfg(feature = "tracing")]
232 let _span = tracing::debug_span!("call_trampoline", ?self.ty, args = %args).entered();
233
234 let Some(func) = self.func.upgrade() else {
235 return Err(PyRuntimeError::new_err(
236 "weak host func called after free of its associated store",
237 ));
238 };
239
240 func(args)
241 }
242}
243
/// Returns the [`TypeId`] of `T` without requiring `T: 'static`.
///
/// [`TypeId::of`] is only available for `'static` types. This helper routes
/// the call through a trait object whose `'static` bound is asserted with a
/// transmute, so it can also be used for types that contain borrows. The
/// argument is only used to infer `T`; its value is never read.
fn non_static_type_id<T: ?Sized>(_x: &T) -> TypeId {
    trait NonStaticAny {
        fn get_type_id(&self) -> TypeId
        where
            Self: 'static;
    }

    impl<T: ?Sized> NonStaticAny for PhantomData<T> {
        fn get_type_id(&self) -> TypeId
        where
            Self: 'static,
        {
            TypeId::of::<T>()
        }
    }

    let phantom_data = PhantomData::<T>;
    // SAFETY: the transmute only changes the lifetime bound of the trait
    // object, so source and target references have the same layout. The
    // method called through it never touches any data (the receiver is a
    // zero-sized `PhantomData`) and only evaluates `TypeId::of::<T>()`, so no
    // borrow can be used beyond its real lifetime.
    NonStaticAny::get_type_id(unsafe {
        core::mem::transmute::<&dyn NonStaticAny, &(dyn NonStaticAny + 'static)>(&phantom_data)
    })
}