1use std::sync::{Arc, OnceLock};
2
3use ndarray::{ArrayBase, ArrayView, Data, Dimension};
4use numcodecs::{AnyArray, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray};
5use schemars::Schema;
6use serde::Serializer;
7use wasm_component_layer::{
8 AsContextMut, Enum, EnumType, Func, Instance, List, ListType, Record, RecordType, ResourceOwn,
9 Value, ValueType, Variant, VariantCase, VariantType,
10};
11
12use crate::{
13 component::WasmCodecComponent,
14 error::{CodecError, RuntimeError},
15 wit::guest_error_from_wasm,
16};
17
/// A codec instance that lives inside an instantiated WebAssembly component.
///
/// Bundles the owned guest codec resource with the component exports that
/// operate on it, so that encoding, decoding, reading the configuration,
/// cloning, and dropping can be driven from the host.
pub struct WasmCodec {
    /// Owned handle to the guest-side codec resource. It is borrowed for each
    /// export call and dropped via `try_drop`.
    pub(crate) resource: ResourceOwn,
    /// The codec's identifier, forwarded to the component type built by `ty`.
    pub(crate) codec_id: Arc<str>,
    /// `schemars` schema for the codec's configuration, forwarded by `ty`.
    pub(crate) codec_config_schema: Arc<Schema>,
    /// Export that is only forwarded to the component type by `ty`; presumably
    /// it constructs a codec from a configuration (not called here) — confirm.
    pub(crate) from_config: Func,
    /// Export called by `encode` with the borrowed resource and the input array.
    pub(crate) encode: Func,
    /// Export called by `decode` with the borrowed resource and the encoded array.
    pub(crate) decode: Func,
    /// Export called by `decode_into`, which additionally receives the expected
    /// output dtype and shape.
    pub(crate) decode_into: Func,
    /// Export called by `get_config`; its successful result is parsed as a JSON
    /// string.
    pub(crate) get_config: Func,
    /// The component instance, forwarded to the component type by `ty`.
    pub(crate) instance: Instance,
}
40
41impl WasmCodec {
43 #[expect(clippy::needless_pass_by_value)]
44 pub fn encode(
55 &self,
56 ctx: impl AsContextMut,
57 data: AnyCowArray,
58 ) -> Result<Result<AnyArray, CodecError>, RuntimeError> {
59 self.process(
60 ctx,
61 data.view(),
62 None,
63 |ctx, arguments, results| self.encode.call(ctx, arguments, results),
64 |encoded| Ok(encoded.into_owned()),
65 )
66 }
67
68 #[expect(clippy::needless_pass_by_value)]
69 pub fn decode(
80 &self,
81 ctx: impl AsContextMut,
82 encoded: AnyCowArray,
83 ) -> Result<Result<AnyArray, CodecError>, RuntimeError> {
84 self.process(
85 ctx,
86 encoded.view(),
87 None,
88 |ctx, arguments, results| self.decode.call(ctx, arguments, results),
89 |decoded| Ok(decoded.into_owned()),
90 )
91 }
92
93 pub fn decode_into(
107 &self,
108 ctx: impl AsContextMut,
109 encoded: AnyArrayView,
110 mut decoded: AnyArrayViewMut,
111 ) -> Result<Result<(), CodecError>, RuntimeError> {
112 self.process(
113 ctx,
114 encoded,
115 #[expect(clippy::unnecessary_to_owned)] Some((decoded.dtype(), &decoded.shape().to_vec())),
117 |ctx, arguments, results| self.decode_into.call(ctx, arguments, results),
118 |decoded_in| {
119 decoded
120 .assign(&decoded_in)
121 .map_err(anyhow::Error::new)
122 .map_err(RuntimeError::from)
123 },
124 )
125 }
126}
127
128impl WasmCodec {
130 #[must_use]
132 pub fn ty(&self) -> WasmCodecComponent {
133 WasmCodecComponent {
134 codec_id: self.codec_id.clone(),
135 codec_config_schema: self.codec_config_schema.clone(),
136 from_config: self.from_config.clone(),
137 encode: self.encode.clone(),
138 decode: self.decode.clone(),
139 decode_into: self.decode_into.clone(),
140 get_config: self.get_config.clone(),
141 instance: self.instance.clone(),
142 }
143 }
144
145 pub fn get_config<S: Serializer>(
155 &self,
156 mut ctx: impl AsContextMut,
157 serializer: S,
158 ) -> Result<S::Ok, S::Error> {
159 let resource = self
160 .resource
161 .borrow(&mut ctx)
162 .map_err(serde::ser::Error::custom)?;
163
164 let arg = Value::Borrow(resource);
165 let mut result = Value::U8(0);
166
167 self.get_config
168 .call(
169 &mut ctx,
170 std::slice::from_ref(&arg),
171 std::slice::from_mut(&mut result),
172 )
173 .map_err(serde::ser::Error::custom)?;
174
175 let config = match result {
176 Value::Result(result) => match &*result {
177 Ok(Some(Value::String(config))) => config.clone(),
178 Err(err) => match guest_error_from_wasm(err.as_ref()) {
179 Ok(err) => return Err(serde::ser::Error::custom(err)),
180 Err(err) => return Err(serde::ser::Error::custom(err)),
181 },
182 result => {
183 return Err(serde::ser::Error::custom(format!(
184 "unexpected get-config result value {result:?}"
185 )))
186 }
187 },
188 value => {
189 return Err(serde::ser::Error::custom(format!(
190 "unexpected get-config result value {value:?}"
191 )))
192 }
193 };
194
195 serde_transcode::transcode(&mut serde_json::Deserializer::from_str(&config), serializer)
196 }
197}
198
199impl WasmCodec {
201 pub fn try_clone(&self, mut ctx: impl AsContextMut) -> Result<Self, serde_json::Error> {
211 let mut config = self.get_config(&mut ctx, serde_json::value::Serializer)?;
212
213 if let Some(config) = config.as_object_mut() {
214 config.remove("id");
215 }
216
217 let codec: Self = self.ty().codec_from_config(ctx, config)?;
218
219 Ok(codec)
220 }
221
222 pub fn try_clone_into(
233 &self,
234 ctx_from: impl AsContextMut,
235 ctx_into: impl AsContextMut,
236 ) -> Result<Self, serde_json::Error> {
237 let mut config = self.get_config(ctx_from, serde_json::value::Serializer)?;
238
239 if let Some(config) = config.as_object_mut() {
240 config.remove("id");
241 }
242
243 let codec: Self = self.ty().codec_from_config(ctx_into, config)?;
244
245 Ok(codec)
246 }
247}
248
249impl WasmCodec {
251 pub fn try_drop(&self, ctx: impl AsContextMut) -> Result<(), RuntimeError> {
260 self.resource.drop(ctx).map_err(RuntimeError::from)
261 }
262}
263
264impl WasmCodec {
265 fn process<O, C: AsContextMut>(
266 &self,
267 mut ctx: C,
268 data: AnyArrayView,
269 output_prototype: Option<(AnyArrayDType, &[usize])>,
270 process: impl FnOnce(&mut C, &[Value], &mut [Value]) -> anyhow::Result<()>,
271 with_result: impl for<'a> FnOnce(AnyArrayView<'a>) -> Result<O, RuntimeError>,
272 ) -> Result<Result<O, CodecError>, RuntimeError> {
273 let resource = self.resource.borrow(&mut ctx)?;
274
275 let array = Self::array_into_wasm(data)?;
276
277 let output_prototype = output_prototype
278 .map(|(dtype, shape)| Self::array_prototype_into_wasm(dtype, shape))
279 .transpose()?;
280
281 let mut result = Value::U8(0);
282
283 process(
284 &mut ctx,
285 &match output_prototype {
286 None => vec![Value::Borrow(resource), Value::Record(array)],
287 Some(output) => vec![
288 Value::Borrow(resource),
289 Value::Record(array),
290 Value::Record(output),
291 ],
292 },
293 std::slice::from_mut(&mut result),
294 )?;
295
296 match result {
297 Value::Result(result) => match &*result {
298 Ok(Some(Value::Record(record))) if &record.ty() == Self::any_array_ty() => {
299 Self::with_array_view_from_wasm_record(record, |array| {
300 Ok(Ok(with_result(array)?))
301 })
302 }
303 Err(err) => guest_error_from_wasm(err.as_ref()).map(Err),
304 result => Err(RuntimeError::from(anyhow::Error::msg(format!(
305 "unexpected process result value {result:?}"
306 )))),
307 },
308 value => Err(RuntimeError::from(anyhow::Error::msg(format!(
309 "unexpected process result value {value:?}"
310 )))),
311 }
312 }
313
314 fn any_array_data_ty() -> &'static VariantType {
315 static ANY_ARRAY_DATA_TY: OnceLock<VariantType> = OnceLock::new();
316
317 #[expect(clippy::expect_used)]
318 ANY_ARRAY_DATA_TY.get_or_init(|| {
321 VariantType::new(
322 None,
323 [
324 VariantCase::new("u8", Some(ValueType::List(ListType::new(ValueType::U8)))),
325 VariantCase::new("u16", Some(ValueType::List(ListType::new(ValueType::U16)))),
326 VariantCase::new("u32", Some(ValueType::List(ListType::new(ValueType::U32)))),
327 VariantCase::new("u64", Some(ValueType::List(ListType::new(ValueType::U64)))),
328 VariantCase::new("i8", Some(ValueType::List(ListType::new(ValueType::S8)))),
329 VariantCase::new("i16", Some(ValueType::List(ListType::new(ValueType::S16)))),
330 VariantCase::new("i32", Some(ValueType::List(ListType::new(ValueType::S32)))),
331 VariantCase::new("i64", Some(ValueType::List(ListType::new(ValueType::S64)))),
332 VariantCase::new("f32", Some(ValueType::List(ListType::new(ValueType::F32)))),
333 VariantCase::new("f64", Some(ValueType::List(ListType::new(ValueType::F64)))),
334 ],
335 )
336 .expect("constructing the any-array-data variant type must not fail")
337 })
338 }
339
340 fn any_array_ty() -> &'static RecordType {
341 static ANY_ARRAY_TY: OnceLock<RecordType> = OnceLock::new();
342
343 #[expect(clippy::expect_used)]
344 ANY_ARRAY_TY.get_or_init(|| {
347 RecordType::new(
348 None,
349 [
350 (
351 "data",
352 ValueType::Variant(Self::any_array_data_ty().clone()),
353 ),
354 ("shape", ValueType::List(ListType::new(ValueType::U32))),
355 ],
356 )
357 .expect("constructing the any-array record type must not fail")
358 })
359 }
360
361 #[expect(clippy::needless_pass_by_value)]
362 fn array_into_wasm(array: AnyArrayView) -> Result<Record, RuntimeError> {
363 fn list_from_standard_layout<'a, T: 'static + Copy, S: Data<Elem = T>, D: Dimension>(
364 array: &'a ArrayBase<S, D>,
365 ) -> List
366 where
367 List: From<&'a [T]> + From<Arc<[T]>>,
368 {
369 #[expect(clippy::option_if_let_else)]
370 if let Some(slice) = array.as_slice() {
371 List::from(slice)
372 } else {
373 List::from(Arc::from(array.iter().copied().collect::<Vec<T>>()))
374 }
375 }
376
377 let any_array_data_ty = Self::any_array_data_ty().clone();
378
379 let data = match &array {
380 AnyArrayView::U8(array) => Variant::new(
381 any_array_data_ty,
382 0,
383 Some(Value::List(list_from_standard_layout(array))),
384 ),
385 AnyArrayView::U16(array) => Variant::new(
386 any_array_data_ty,
387 1,
388 Some(Value::List(list_from_standard_layout(array))),
389 ),
390 AnyArrayView::U32(array) => Variant::new(
391 any_array_data_ty,
392 2,
393 Some(Value::List(list_from_standard_layout(array))),
394 ),
395 AnyArrayView::U64(array) => Variant::new(
396 any_array_data_ty,
397 3,
398 Some(Value::List(list_from_standard_layout(array))),
399 ),
400 AnyArrayView::I8(array) => Variant::new(
401 any_array_data_ty,
402 4,
403 Some(Value::List(list_from_standard_layout(array))),
404 ),
405 AnyArrayView::I16(array) => Variant::new(
406 any_array_data_ty,
407 5,
408 Some(Value::List(list_from_standard_layout(array))),
409 ),
410 AnyArrayView::I32(array) => Variant::new(
411 any_array_data_ty,
412 6,
413 Some(Value::List(list_from_standard_layout(array))),
414 ),
415 AnyArrayView::I64(array) => Variant::new(
416 any_array_data_ty,
417 7,
418 Some(Value::List(list_from_standard_layout(array))),
419 ),
420 AnyArrayView::F32(array) => Variant::new(
421 any_array_data_ty,
422 8,
423 Some(Value::List(list_from_standard_layout(array))),
424 ),
425 AnyArrayView::F64(array) => Variant::new(
426 any_array_data_ty,
427 9,
428 Some(Value::List(list_from_standard_layout(array))),
429 ),
430 array => Err(anyhow::Error::msg(format!(
431 "unknown array dtype type {}",
432 array.dtype()
433 ))),
434 }?;
435
436 let shape = array
437 .shape()
438 .iter()
439 .map(|s| u32::try_from(*s))
440 .collect::<Result<Vec<_>, _>>()
441 .map_err(anyhow::Error::new)?;
442 let shape = List::from(Arc::from(shape));
443
444 Record::new(
445 Self::any_array_ty().clone(),
446 [
447 ("data", Value::Variant(data)),
448 ("shape", Value::List(shape)),
449 ],
450 )
451 .map_err(RuntimeError::from)
452 }
453
454 fn any_array_dtype_ty() -> &'static EnumType {
455 static ANY_ARRAY_DTYPE_TY: OnceLock<EnumType> = OnceLock::new();
456
457 #[expect(clippy::expect_used)]
458 ANY_ARRAY_DTYPE_TY.get_or_init(|| {
461 EnumType::new(
462 None,
463 [
464 "u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64", "f32", "f64",
465 ],
466 )
467 .expect("constructing the any-array-dtype enum type must not fail")
468 })
469 }
470
471 fn any_array_prototype_ty() -> &'static RecordType {
472 static ANY_ARRAY_PROTOTYPE_TY: OnceLock<RecordType> = OnceLock::new();
473
474 #[expect(clippy::expect_used)]
475 ANY_ARRAY_PROTOTYPE_TY.get_or_init(|| {
478 RecordType::new(
479 None,
480 [
481 ("dtype", ValueType::Enum(Self::any_array_dtype_ty().clone())),
482 ("shape", ValueType::List(ListType::new(ValueType::U32))),
483 ],
484 )
485 .expect("constructing the any-array-prototype record type must not fail")
486 })
487 }
488
489 fn array_prototype_into_wasm(
490 dtype: AnyArrayDType,
491 shape: &[usize],
492 ) -> Result<Record, RuntimeError> {
493 let any_array_dtype_ty = Self::any_array_dtype_ty().clone();
494
495 let dtype = match dtype {
496 AnyArrayDType::U8 => Enum::new(any_array_dtype_ty, 0),
497 AnyArrayDType::U16 => Enum::new(any_array_dtype_ty, 1),
498 AnyArrayDType::U32 => Enum::new(any_array_dtype_ty, 2),
499 AnyArrayDType::U64 => Enum::new(any_array_dtype_ty, 3),
500 AnyArrayDType::I8 => Enum::new(any_array_dtype_ty, 4),
501 AnyArrayDType::I16 => Enum::new(any_array_dtype_ty, 5),
502 AnyArrayDType::I32 => Enum::new(any_array_dtype_ty, 6),
503 AnyArrayDType::I64 => Enum::new(any_array_dtype_ty, 7),
504 AnyArrayDType::F32 => Enum::new(any_array_dtype_ty, 8),
505 AnyArrayDType::F64 => Enum::new(any_array_dtype_ty, 9),
506 dtype => Err(anyhow::Error::msg(format!(
507 "unknown array dtype type {dtype}"
508 ))),
509 }?;
510
511 let shape = shape
512 .iter()
513 .map(|s| u32::try_from(*s))
514 .collect::<Result<Vec<_>, _>>()
515 .map_err(anyhow::Error::new)?;
516 let shape = List::from(Arc::from(shape));
517
518 Record::new(
519 Self::any_array_prototype_ty().clone(),
520 [("dtype", Value::Enum(dtype)), ("shape", Value::List(shape))],
521 )
522 .map_err(RuntimeError::from)
523 }
524
525 fn with_array_view_from_wasm_record<O>(
526 record: &Record,
527 with: impl for<'a> FnOnce(AnyArrayView<'a>) -> Result<O, RuntimeError>,
528 ) -> Result<O, RuntimeError> {
529 let Some(Value::List(shape)) = record.field("shape") else {
530 return Err(RuntimeError::from(anyhow::Error::msg(format!(
531 "process result record {record:?} is missing shape field"
532 ))));
533 };
534 let shape = shape
535 .typed::<u32>()?
536 .iter()
537 .copied()
538 .map(usize::try_from)
539 .collect::<Result<Vec<_>, _>>()
540 .map_err(anyhow::Error::new)?;
541
542 let Some(Value::Variant(data)) = record.field("data") else {
543 return Err(RuntimeError::from(anyhow::Error::msg(format!(
544 "process result record {record:?} is missing data field"
545 ))));
546 };
547 let Some(Value::List(values)) = data.value() else {
548 return Err(RuntimeError::from(anyhow::Error::msg(format!(
549 "process result buffer has an invalid variant type {:?}",
550 data.value().map(|v| v.ty())
551 ))));
552 };
553
554 let array = match data.discriminant() {
555 0 => AnyArrayView::U8(
556 ArrayView::from_shape(shape.as_slice(), values.typed()?)
557 .map_err(anyhow::Error::new)?,
558 ),
559 1 => AnyArrayView::U16(
560 ArrayView::from_shape(shape.as_slice(), values.typed()?)
561 .map_err(anyhow::Error::new)?,
562 ),
563 2 => AnyArrayView::U32(
564 ArrayView::from_shape(shape.as_slice(), values.typed()?)
565 .map_err(anyhow::Error::new)?,
566 ),
567 3 => AnyArrayView::U64(
568 ArrayView::from_shape(shape.as_slice(), values.typed()?)
569 .map_err(anyhow::Error::new)?,
570 ),
571 4 => AnyArrayView::I8(
572 ArrayView::from_shape(shape.as_slice(), values.typed()?)
573 .map_err(anyhow::Error::new)?,
574 ),
575 5 => AnyArrayView::I16(
576 ArrayView::from_shape(shape.as_slice(), values.typed()?)
577 .map_err(anyhow::Error::new)?,
578 ),
579 6 => AnyArrayView::I32(
580 ArrayView::from_shape(shape.as_slice(), values.typed()?)
581 .map_err(anyhow::Error::new)?,
582 ),
583 7 => AnyArrayView::I64(
584 ArrayView::from_shape(shape.as_slice(), values.typed()?)
585 .map_err(anyhow::Error::new)?,
586 ),
587 8 => AnyArrayView::F32(
588 ArrayView::from_shape(shape.as_slice(), values.typed()?)
589 .map_err(anyhow::Error::new)?,
590 ),
591 9 => AnyArrayView::F64(
592 ArrayView::from_shape(shape.as_slice(), values.typed()?)
593 .map_err(anyhow::Error::new)?,
594 ),
595 discriminant => {
596 return Err(RuntimeError::from(anyhow::Error::msg(format!(
597 "process result buffer has an invalid variant [{discriminant}]:{:?}",
598 data.value().map(|v| v.ty())
599 ))))
600 }
601 };
602
603 with(array)
604 }
605}