pyodide_webassembly_runtime_layer/
func.rs1use std::{
2 any::TypeId,
3 marker::PhantomData,
4 sync::{Arc, Weak},
5};
6
7use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyTuple, PyTypeInfo};
8use pyo3_error::PyErrChain;
9use wasm_runtime_layer::{
10 backend::{AsContext, AsContextMut, Value, WasmFunc, WasmStoreContext},
11 FuncType,
12};
13use wobbly::sync::Wobbly;
14
15use crate::{
16 conversion::{py_to_js_proxy, ToPy, ValueExt},
17 store::StoreContextMut,
18 Engine,
19};
20
21#[derive(Debug)]
26pub struct Func {
27 pyfunc: Py<PyAny>,
29 ty: FuncType,
31 user_state: Option<TypeId>,
33}
34
35impl Clone for Func {
36 fn clone(&self) -> Self {
37 Python::attach(|py| Self {
38 pyfunc: self.pyfunc.clone_ref(py),
39 ty: self.ty.clone(),
40 user_state: self.user_state,
41 })
42 }
43}
44
45impl WasmFunc<Engine> for Func {
46 fn new<T: 'static>(
47 mut ctx: impl AsContextMut<Engine, UserState = T>,
48 ty: FuncType,
49 func: impl 'static
50 + Send
51 + Sync
52 + Fn(StoreContextMut<T>, &[Value<Engine>], &mut [Value<Engine>]) -> anyhow::Result<()>,
53 ) -> Self {
54 Python::attach(|py| -> Result<Self, PyErr> {
55 #[cfg(feature = "tracing")]
56 tracing::debug!("Func::new");
57
58 let mut store: StoreContextMut<T> = ctx.as_context_mut();
59
60 let weak_store = store.as_weak_proof();
61
62 let user_state = non_static_type_id(store.data());
63 let ty_clone = ty.clone();
64
65 let func = Arc::new(move |args: Bound<PyTuple>| -> Result<Py<PyAny>, PyErr> {
66 let py = args.py();
67
68 let Some(mut strong_store) = Weak::upgrade(&weak_store) else {
69 return Err(PyRuntimeError::new_err(
70 "host func called after free of its associated store",
71 ));
72 };
73
74 let store = unsafe { StoreContextMut::from_proof_unchecked(&mut strong_store) };
80
81 let ty = &ty_clone;
82
83 let args = ty
84 .params()
85 .iter()
86 .zip(args.iter())
87 .map(|(ty, arg)| Value::from_py_typed(arg, *ty))
88 .collect::<Result<Vec<_>, _>>()?;
89 let mut results = vec![Value::I32(0); ty.results().len()];
90
91 #[cfg(feature = "tracing")]
92 let _span = tracing::debug_span!("call_host", ?args, ?ty).entered();
93
94 match func(store, &args, &mut results) {
95 Ok(()) => {
96 #[cfg(feature = "tracing")]
97 tracing::debug!(?results, "result");
98 },
99 Err(err) => {
100 #[cfg(feature = "tracing")]
101 tracing::error!("{err:?}");
102 return Err(PyErrChain::pyerr_from_err(py, err));
103 },
104 }
105
106 let results = match results.as_slice() {
107 [] => py.None(),
108 [res] => res.to_py(py),
109 results => PyTuple::new(py, results.iter().map(|res| res.to_py(py)))?
110 .into_any()
111 .unbind(),
112 };
113
114 Ok(results)
115 });
116
117 let func = Bound::new(
118 py,
119 PyHostFunc {
120 func: store.register_host_func(func),
121 #[cfg(feature = "tracing")]
122 ty: ty.clone(),
123 },
124 )?;
125 let func = py_to_js_proxy(func)?;
126
127 Ok(Self {
128 pyfunc: func.unbind(),
129 ty,
130 user_state: Some(user_state),
131 })
132 })
133 .expect("Func::new should not fail")
134 }
135
136 fn ty(&self, _ctx: impl AsContext<Engine>) -> FuncType {
137 self.ty.clone()
138 }
139
140 fn call<T>(
141 &self,
142 mut ctx: impl AsContextMut<Engine>,
143 args: &[Value<Engine>],
144 results: &mut [Value<Engine>],
145 ) -> anyhow::Result<()> {
146 Python::attach(|py| {
147 let store: StoreContextMut<_> = ctx.as_context_mut();
148
149 if let Some(user_state) = self.user_state {
150 assert_eq!(user_state, non_static_type_id(store.data()));
151 }
152
153 #[cfg(feature = "tracing")]
154 let _span = tracing::debug_span!("call_guest", ?args, ?self.ty).entered();
155
156 assert_eq!(self.ty.params().len(), args.len());
158 assert_eq!(self.ty.results().len(), results.len());
159
160 let args = args.iter().map(|arg| arg.to_py(py));
161 let args = PyTuple::new(py, args)?;
162
163 let res = self.pyfunc.bind(py).call1(args)?;
164
165 #[cfg(feature = "tracing")]
166 tracing::debug!(%res, ?self.ty);
167
168 match (self.ty.results(), results) {
169 ([], []) => (),
170 ([ty], [result]) => *result = Value::from_py_typed(res, *ty)?,
171 (tys, results) => {
172 let res: Bound<PyTuple> = PyTuple::type_object(py)
173 .call1((res,))?
174 .extract()
175 .map_err(PyErr::from)?;
176
177 assert_eq!(tys.len(), res.len());
179
180 for ((ty, result), value) in self
181 .ty
182 .results()
183 .iter()
184 .zip(results.iter_mut())
185 .zip(res.iter())
186 {
187 *result = Value::from_py_typed(value, *ty)?;
188 }
189 },
190 }
191
192 Ok(())
193 })
194 }
195}
196
197impl ToPy for Func {
198 fn to_py(&self, py: Python) -> Py<PyAny> {
199 self.pyfunc.clone_ref(py)
200 }
201}
202
203impl Func {
204 pub(crate) fn from_exported_function(func: Bound<PyAny>, ty: FuncType) -> anyhow::Result<Self> {
206 if !func.is_callable() {
207 anyhow::bail!("expected WebAssembly.Function but found {func:?} which is not callable");
208 }
209
210 #[cfg(feature = "tracing")]
211 tracing::debug!(%func, ?ty, "Func::from_exported_function");
212
213 Ok(Self {
214 pyfunc: func.unbind(),
215 ty,
216 user_state: None,
217 })
218 }
219}
220
221pub type PyHostFuncFn = dyn 'static + Send + Sync + Fn(Bound<PyTuple>) -> Result<Py<PyAny>, PyErr>;
222
223#[pyclass(frozen)]
224struct PyHostFunc {
225 func: Wobbly<PyHostFuncFn>,
226 #[cfg(feature = "tracing")]
227 ty: FuncType,
228}
229
230#[pymethods]
231impl PyHostFunc {
232 #[pyo3(signature = (*args))]
233 fn __call__(&self, args: Bound<PyTuple>) -> Result<Py<PyAny>, PyErr> {
234 #[cfg(feature = "tracing")]
235 let _span = tracing::debug_span!("call_trampoline", ?self.ty, args = %args).entered();
236
237 let Some(func) = self.func.upgrade() else {
238 return Err(PyRuntimeError::new_err(
239 "weak host func called after free of its associated store",
240 ));
241 };
242
243 func(args)
244 }
245}
246
/// Returns the [`TypeId`] of `T` without requiring `T: 'static`.
///
/// Used in this file to record the store's user-state type when a host
/// [`Func`] is created and to check it again on every `call`.
///
/// NOTE(review): lifetimes presumably do not influence the returned id, so
/// types differing only in lifetime parameters map to the same `TypeId`.
/// Confirm this is acceptable for the user-state check.
fn non_static_type_id<T: ?Sized>(_x: &T) -> TypeId {
    // Helper trait whose method may only be called when `Self: 'static`.
    // Inside the impl, that bound is what lets `TypeId::of::<T>()` type-check.
    trait NonStaticAny {
        fn get_type_id(&self) -> TypeId
        where
            Self: 'static;
    }

    impl<T: ?Sized> NonStaticAny for PhantomData<T> {
        fn get_type_id(&self) -> TypeId
        where
            Self: 'static,
        {
            TypeId::of::<T>()
        }
    }

    let phantom_data = PhantomData::<T>;
    // SAFETY: the transmute only widens the trait object's lifetime bound to
    // `'static` so that `get_type_id` (which requires `Self: 'static`) can be
    // called. The referent is a `PhantomData` (zero-sized, no data is read
    // through it), and the widened reference does not escape this expression.
    // NOTE(review): this mirrors the lifetime-erasure trick used by the
    // `typeid` crate; its soundness relies on `TypeId` ignoring lifetimes.
    NonStaticAny::get_type_id(unsafe {
        core::mem::transmute::<&dyn NonStaticAny, &(dyn NonStaticAny + 'static)>(&phantom_data)
    })
}