numcodecs_wasm_host/
codec.rs

1use std::sync::{Arc, OnceLock};
2
3use ndarray::{ArrayBase, ArrayView, Data, Dimension};
4use numcodecs::{AnyArray, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray};
5use schemars::Schema;
6use serde::Serializer;
7use wasm_component_layer::{
8    AsContextMut, Enum, EnumType, Func, Instance, List, ListType, Record, RecordType, ResourceOwn,
9    Value, ValueType, Variant, VariantCase, VariantType,
10};
11
12use crate::{
13    component::WasmCodecComponent,
14    error::{CodecError, RuntimeError},
15    wit::guest_error_from_wasm,
16};
17
18/// Codec instantiated inside a WebAssembly component.
19///
20/// `WasmCodec` does not implement the [`Codec`][numcodecs::Codec],
21/// [`DynCodec`][numcodecs::DynCodec], [`Clone`], or [`Drop`] traits itself so
22/// that it can expose un-opinionated bindings. However, it provides methods
23/// that can be used to implement these traits on a wrapper.
24pub struct WasmCodec {
25    // codec
26    pub(crate) resource: ResourceOwn,
27    // precomputed properties
28    pub(crate) codec_id: Arc<str>,
29    pub(crate) codec_config_schema: Arc<Schema>,
30    // wit functions
31    // FIXME: make typed instead
32    pub(crate) from_config: Func,
33    pub(crate) encode: Func,
34    pub(crate) decode: Func,
35    pub(crate) decode_into: Func,
36    pub(crate) get_config: Func,
37    // wasm component instance
38    pub(crate) instance: Instance,
39}
40
41/// Methods for implementing the [`Codec`][numcodecs::Codec] trait
42impl WasmCodec {
43    #[expect(clippy::needless_pass_by_value)]
44    /// Encodes the `data` and returns the result.
45    ///
46    /// The `ctx` must refer to the same store in which the component was
47    /// instantiated.
48    ///
49    /// # Errors
50    ///
51    /// Errors with a
52    /// - [`CodecError`] if encoding the buffer fails.
53    /// - [`RuntimeError`] if interacting with the component fails.
54    pub fn encode(
55        &self,
56        ctx: impl AsContextMut,
57        data: AnyCowArray,
58    ) -> Result<Result<AnyArray, CodecError>, RuntimeError> {
59        self.process(
60            ctx,
61            data.view(),
62            None,
63            |ctx, arguments, results| self.encode.call(ctx, arguments, results),
64            |encoded| Ok(encoded.into_owned()),
65        )
66    }
67
68    #[expect(clippy::needless_pass_by_value)]
69    /// Decodes the `encoded` data and returns the result.
70    ///
71    /// The `ctx` must refer to the same store in which the component was
72    /// instantiated.
73    ///
74    /// # Errors
75    ///
76    /// Errors with a
77    /// - [`CodecError`] if decoding the buffer fails.
78    /// - [`RuntimeError`] if interacting with the component fails.
79    pub fn decode(
80        &self,
81        ctx: impl AsContextMut,
82        encoded: AnyCowArray,
83    ) -> Result<Result<AnyArray, CodecError>, RuntimeError> {
84        self.process(
85            ctx,
86            encoded.view(),
87            None,
88            |ctx, arguments, results| self.decode.call(ctx, arguments, results),
89            |decoded| Ok(decoded.into_owned()),
90        )
91    }
92
93    /// Decodes the `encoded` data and writes the result into the provided
94    /// `decoded` output.
95    ///
96    /// The output must have the correct type and shape.
97    ///
98    /// The `ctx` must refer to the same store in which the component was
99    /// instantiated.
100    ///
101    /// # Errors
102    ///
103    /// Errors with a
104    /// - [`CodecError`] if decoding the buffer fails.
105    /// - [`RuntimeError`] if interacting with the component fails.
106    pub fn decode_into(
107        &self,
108        ctx: impl AsContextMut,
109        encoded: AnyArrayView,
110        mut decoded: AnyArrayViewMut,
111    ) -> Result<Result<(), CodecError>, RuntimeError> {
112        self.process(
113            ctx,
114            encoded,
115            #[expect(clippy::unnecessary_to_owned)] // we need the lifetime extension
116            Some((decoded.dtype(), &decoded.shape().to_vec())),
117            |ctx, arguments, results| self.decode_into.call(ctx, arguments, results),
118            |decoded_in| {
119                decoded
120                    .assign(&decoded_in)
121                    .map_err(anyhow::Error::new)
122                    .map_err(RuntimeError::from)
123            },
124        )
125    }
126}
127
128/// Methods for implementing the [`DynCodec`][numcodecs::DynCodec] trait
129impl WasmCodec {
130    /// Returns the component object for this codec.
131    #[must_use]
132    pub fn ty(&self) -> WasmCodecComponent {
133        WasmCodecComponent {
134            codec_id: self.codec_id.clone(),
135            codec_config_schema: self.codec_config_schema.clone(),
136            from_config: self.from_config.clone(),
137            encode: self.encode.clone(),
138            decode: self.decode.clone(),
139            decode_into: self.decode_into.clone(),
140            get_config: self.get_config.clone(),
141            instance: self.instance.clone(),
142        }
143    }
144
145    /// Serializes the configuration parameters for this codec.
146    ///
147    /// The `ctx` must refer to the same store in which the component was
148    /// instantiated.
149    ///
150    /// # Errors
151    ///
152    /// Errors if serializing the codec configuration or interacting with the
153    /// component fails.
154    pub fn get_config<S: Serializer>(
155        &self,
156        mut ctx: impl AsContextMut,
157        serializer: S,
158    ) -> Result<S::Ok, S::Error> {
159        let resource = self
160            .resource
161            .borrow(&mut ctx)
162            .map_err(serde::ser::Error::custom)?;
163
164        let arg = Value::Borrow(resource);
165        let mut result = Value::U8(0);
166
167        self.get_config
168            .call(
169                &mut ctx,
170                std::slice::from_ref(&arg),
171                std::slice::from_mut(&mut result),
172            )
173            .map_err(serde::ser::Error::custom)?;
174
175        let config = match result {
176            Value::Result(result) => match &*result {
177                Ok(Some(Value::String(config))) => config.clone(),
178                Err(err) => match guest_error_from_wasm(err.as_ref()) {
179                    Ok(err) => return Err(serde::ser::Error::custom(err)),
180                    Err(err) => return Err(serde::ser::Error::custom(err)),
181                },
182                result => {
183                    return Err(serde::ser::Error::custom(format!(
184                        "unexpected get-config result value {result:?}"
185                    )))
186                }
187            },
188            value => {
189                return Err(serde::ser::Error::custom(format!(
190                    "unexpected get-config result value {value:?}"
191                )))
192            }
193        };
194
195        serde_transcode::transcode(&mut serde_json::Deserializer::from_str(&config), serializer)
196    }
197}
198
199/// Methods for implementing the [`Clone`] trait
200impl WasmCodec {
201    /// Try cloning the codec by recreating it from its configuration.
202    ///
203    /// The `ctx` must refer to the same store in which the component was
204    /// instantiated.
205    ///
206    /// # Errors
207    ///
208    /// Errors if serializing the codec configuration, constructing the new
209    /// codec, or interacting with the component fails.
210    pub fn try_clone(&self, mut ctx: impl AsContextMut) -> Result<Self, serde_json::Error> {
211        let mut config = self.get_config(&mut ctx, serde_json::value::Serializer)?;
212
213        if let Some(config) = config.as_object_mut() {
214            config.remove("id");
215        }
216
217        let codec: Self = self.ty().codec_from_config(ctx, config)?;
218
219        Ok(codec)
220    }
221
222    /// Try cloning the codec into a different context by recreating it from
223    /// its configuration.
224    ///
225    /// The `ctx_from` must refer to the same store in which the component was
226    /// instantiated.
227    ///
228    /// # Errors
229    ///
230    /// Errors if serializing the codec configuration, constructing the new
231    /// codec, or interacting with the component fails.
232    pub fn try_clone_into(
233        &self,
234        ctx_from: impl AsContextMut,
235        ctx_into: impl AsContextMut,
236    ) -> Result<Self, serde_json::Error> {
237        let mut config = self.get_config(ctx_from, serde_json::value::Serializer)?;
238
239        if let Some(config) = config.as_object_mut() {
240            config.remove("id");
241        }
242
243        let codec: Self = self.ty().codec_from_config(ctx_into, config)?;
244
245        Ok(codec)
246    }
247}
248
249/// Methods for implementing the [`Drop`] trait
250impl WasmCodec {
251    /// Try dropping the codec.
252    ///
253    /// The `ctx` must refer to the same store in which the component was
254    /// instantiated.
255    ///
256    /// # Errors
257    ///
258    /// Errors if the codec's resource is borrowed or has already been dropped.
259    pub fn try_drop(&self, ctx: impl AsContextMut) -> Result<(), RuntimeError> {
260        self.resource.drop(ctx).map_err(RuntimeError::from)
261    }
262}
263
264impl WasmCodec {
265    fn process<O, C: AsContextMut>(
266        &self,
267        mut ctx: C,
268        data: AnyArrayView,
269        output_prototype: Option<(AnyArrayDType, &[usize])>,
270        process: impl FnOnce(&mut C, &[Value], &mut [Value]) -> anyhow::Result<()>,
271        with_result: impl for<'a> FnOnce(AnyArrayView<'a>) -> Result<O, RuntimeError>,
272    ) -> Result<Result<O, CodecError>, RuntimeError> {
273        let resource = self.resource.borrow(&mut ctx)?;
274
275        let array = Self::array_into_wasm(data)?;
276
277        let output_prototype = output_prototype
278            .map(|(dtype, shape)| Self::array_prototype_into_wasm(dtype, shape))
279            .transpose()?;
280
281        let mut result = Value::U8(0);
282
283        process(
284            &mut ctx,
285            &match output_prototype {
286                None => vec![Value::Borrow(resource), Value::Record(array)],
287                Some(output) => vec![
288                    Value::Borrow(resource),
289                    Value::Record(array),
290                    Value::Record(output),
291                ],
292            },
293            std::slice::from_mut(&mut result),
294        )?;
295
296        match result {
297            Value::Result(result) => match &*result {
298                Ok(Some(Value::Record(record))) if &record.ty() == Self::any_array_ty() => {
299                    Self::with_array_view_from_wasm_record(record, |array| {
300                        Ok(Ok(with_result(array)?))
301                    })
302                }
303                Err(err) => guest_error_from_wasm(err.as_ref()).map(Err),
304                result => Err(RuntimeError::from(anyhow::Error::msg(format!(
305                    "unexpected process result value {result:?}"
306                )))),
307            },
308            value => Err(RuntimeError::from(anyhow::Error::msg(format!(
309                "unexpected process result value {value:?}"
310            )))),
311        }
312    }
313
314    fn any_array_data_ty() -> &'static VariantType {
315        static ANY_ARRAY_DATA_TY: OnceLock<VariantType> = OnceLock::new();
316
317        #[expect(clippy::expect_used)]
318        // FIXME: use OnceLock::get_or_try_init,
319        //        blocked on https://github.com/rust-lang/rust/issues/109737
320        ANY_ARRAY_DATA_TY.get_or_init(|| {
321            VariantType::new(
322                None,
323                [
324                    VariantCase::new("u8", Some(ValueType::List(ListType::new(ValueType::U8)))),
325                    VariantCase::new("u16", Some(ValueType::List(ListType::new(ValueType::U16)))),
326                    VariantCase::new("u32", Some(ValueType::List(ListType::new(ValueType::U32)))),
327                    VariantCase::new("u64", Some(ValueType::List(ListType::new(ValueType::U64)))),
328                    VariantCase::new("i8", Some(ValueType::List(ListType::new(ValueType::S8)))),
329                    VariantCase::new("i16", Some(ValueType::List(ListType::new(ValueType::S16)))),
330                    VariantCase::new("i32", Some(ValueType::List(ListType::new(ValueType::S32)))),
331                    VariantCase::new("i64", Some(ValueType::List(ListType::new(ValueType::S64)))),
332                    VariantCase::new("f32", Some(ValueType::List(ListType::new(ValueType::F32)))),
333                    VariantCase::new("f64", Some(ValueType::List(ListType::new(ValueType::F64)))),
334                ],
335            )
336            .expect("constructing the any-array-data variant type must not fail")
337        })
338    }
339
340    fn any_array_ty() -> &'static RecordType {
341        static ANY_ARRAY_TY: OnceLock<RecordType> = OnceLock::new();
342
343        #[expect(clippy::expect_used)]
344        // FIXME: use OnceLock::get_or_try_init,
345        //        blocked on https://github.com/rust-lang/rust/issues/109737
346        ANY_ARRAY_TY.get_or_init(|| {
347            RecordType::new(
348                None,
349                [
350                    (
351                        "data",
352                        ValueType::Variant(Self::any_array_data_ty().clone()),
353                    ),
354                    ("shape", ValueType::List(ListType::new(ValueType::U32))),
355                ],
356            )
357            .expect("constructing the any-array record type must not fail")
358        })
359    }
360
361    #[expect(clippy::needless_pass_by_value)]
362    fn array_into_wasm(array: AnyArrayView) -> Result<Record, RuntimeError> {
363        fn list_from_standard_layout<'a, T: 'static + Copy, S: Data<Elem = T>, D: Dimension>(
364            array: &'a ArrayBase<S, D>,
365        ) -> List
366        where
367            List: From<&'a [T]> + From<Arc<[T]>>,
368        {
369            #[expect(clippy::option_if_let_else)]
370            if let Some(slice) = array.as_slice() {
371                List::from(slice)
372            } else {
373                List::from(Arc::from(array.iter().copied().collect::<Vec<T>>()))
374            }
375        }
376
377        let any_array_data_ty = Self::any_array_data_ty().clone();
378
379        let data = match &array {
380            AnyArrayView::U8(array) => Variant::new(
381                any_array_data_ty,
382                0,
383                Some(Value::List(list_from_standard_layout(array))),
384            ),
385            AnyArrayView::U16(array) => Variant::new(
386                any_array_data_ty,
387                1,
388                Some(Value::List(list_from_standard_layout(array))),
389            ),
390            AnyArrayView::U32(array) => Variant::new(
391                any_array_data_ty,
392                2,
393                Some(Value::List(list_from_standard_layout(array))),
394            ),
395            AnyArrayView::U64(array) => Variant::new(
396                any_array_data_ty,
397                3,
398                Some(Value::List(list_from_standard_layout(array))),
399            ),
400            AnyArrayView::I8(array) => Variant::new(
401                any_array_data_ty,
402                4,
403                Some(Value::List(list_from_standard_layout(array))),
404            ),
405            AnyArrayView::I16(array) => Variant::new(
406                any_array_data_ty,
407                5,
408                Some(Value::List(list_from_standard_layout(array))),
409            ),
410            AnyArrayView::I32(array) => Variant::new(
411                any_array_data_ty,
412                6,
413                Some(Value::List(list_from_standard_layout(array))),
414            ),
415            AnyArrayView::I64(array) => Variant::new(
416                any_array_data_ty,
417                7,
418                Some(Value::List(list_from_standard_layout(array))),
419            ),
420            AnyArrayView::F32(array) => Variant::new(
421                any_array_data_ty,
422                8,
423                Some(Value::List(list_from_standard_layout(array))),
424            ),
425            AnyArrayView::F64(array) => Variant::new(
426                any_array_data_ty,
427                9,
428                Some(Value::List(list_from_standard_layout(array))),
429            ),
430            array => Err(anyhow::Error::msg(format!(
431                "unknown array dtype type {}",
432                array.dtype()
433            ))),
434        }?;
435
436        let shape = array
437            .shape()
438            .iter()
439            .map(|s| u32::try_from(*s))
440            .collect::<Result<Vec<_>, _>>()
441            .map_err(anyhow::Error::new)?;
442        let shape = List::from(Arc::from(shape));
443
444        Record::new(
445            Self::any_array_ty().clone(),
446            [
447                ("data", Value::Variant(data)),
448                ("shape", Value::List(shape)),
449            ],
450        )
451        .map_err(RuntimeError::from)
452    }
453
454    fn any_array_dtype_ty() -> &'static EnumType {
455        static ANY_ARRAY_DTYPE_TY: OnceLock<EnumType> = OnceLock::new();
456
457        #[expect(clippy::expect_used)]
458        // FIXME: use OnceLock::get_or_try_init,
459        //        blocked on https://github.com/rust-lang/rust/issues/109737
460        ANY_ARRAY_DTYPE_TY.get_or_init(|| {
461            EnumType::new(
462                None,
463                [
464                    "u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64", "f32", "f64",
465                ],
466            )
467            .expect("constructing the any-array-dtype enum type must not fail")
468        })
469    }
470
471    fn any_array_prototype_ty() -> &'static RecordType {
472        static ANY_ARRAY_PROTOTYPE_TY: OnceLock<RecordType> = OnceLock::new();
473
474        #[expect(clippy::expect_used)]
475        // FIXME: use OnceLock::get_or_try_init,
476        //        blocked on https://github.com/rust-lang/rust/issues/109737
477        ANY_ARRAY_PROTOTYPE_TY.get_or_init(|| {
478            RecordType::new(
479                None,
480                [
481                    ("dtype", ValueType::Enum(Self::any_array_dtype_ty().clone())),
482                    ("shape", ValueType::List(ListType::new(ValueType::U32))),
483                ],
484            )
485            .expect("constructing the any-array-prototype record type must not fail")
486        })
487    }
488
489    fn array_prototype_into_wasm(
490        dtype: AnyArrayDType,
491        shape: &[usize],
492    ) -> Result<Record, RuntimeError> {
493        let any_array_dtype_ty = Self::any_array_dtype_ty().clone();
494
495        let dtype = match dtype {
496            AnyArrayDType::U8 => Enum::new(any_array_dtype_ty, 0),
497            AnyArrayDType::U16 => Enum::new(any_array_dtype_ty, 1),
498            AnyArrayDType::U32 => Enum::new(any_array_dtype_ty, 2),
499            AnyArrayDType::U64 => Enum::new(any_array_dtype_ty, 3),
500            AnyArrayDType::I8 => Enum::new(any_array_dtype_ty, 4),
501            AnyArrayDType::I16 => Enum::new(any_array_dtype_ty, 5),
502            AnyArrayDType::I32 => Enum::new(any_array_dtype_ty, 6),
503            AnyArrayDType::I64 => Enum::new(any_array_dtype_ty, 7),
504            AnyArrayDType::F32 => Enum::new(any_array_dtype_ty, 8),
505            AnyArrayDType::F64 => Enum::new(any_array_dtype_ty, 9),
506            dtype => Err(anyhow::Error::msg(format!(
507                "unknown array dtype type {dtype}"
508            ))),
509        }?;
510
511        let shape = shape
512            .iter()
513            .map(|s| u32::try_from(*s))
514            .collect::<Result<Vec<_>, _>>()
515            .map_err(anyhow::Error::new)?;
516        let shape = List::from(Arc::from(shape));
517
518        Record::new(
519            Self::any_array_prototype_ty().clone(),
520            [("dtype", Value::Enum(dtype)), ("shape", Value::List(shape))],
521        )
522        .map_err(RuntimeError::from)
523    }
524
525    fn with_array_view_from_wasm_record<O>(
526        record: &Record,
527        with: impl for<'a> FnOnce(AnyArrayView<'a>) -> Result<O, RuntimeError>,
528    ) -> Result<O, RuntimeError> {
529        let Some(Value::List(shape)) = record.field("shape") else {
530            return Err(RuntimeError::from(anyhow::Error::msg(format!(
531                "process result record {record:?} is missing shape field"
532            ))));
533        };
534        let shape = shape
535            .typed::<u32>()?
536            .iter()
537            .copied()
538            .map(usize::try_from)
539            .collect::<Result<Vec<_>, _>>()
540            .map_err(anyhow::Error::new)?;
541
542        let Some(Value::Variant(data)) = record.field("data") else {
543            return Err(RuntimeError::from(anyhow::Error::msg(format!(
544                "process result record {record:?} is missing data field"
545            ))));
546        };
547        let Some(Value::List(values)) = data.value() else {
548            return Err(RuntimeError::from(anyhow::Error::msg(format!(
549                "process result buffer has an invalid variant type {:?}",
550                data.value().map(|v| v.ty())
551            ))));
552        };
553
554        let array = match data.discriminant() {
555            0 => AnyArrayView::U8(
556                ArrayView::from_shape(shape.as_slice(), values.typed()?)
557                    .map_err(anyhow::Error::new)?,
558            ),
559            1 => AnyArrayView::U16(
560                ArrayView::from_shape(shape.as_slice(), values.typed()?)
561                    .map_err(anyhow::Error::new)?,
562            ),
563            2 => AnyArrayView::U32(
564                ArrayView::from_shape(shape.as_slice(), values.typed()?)
565                    .map_err(anyhow::Error::new)?,
566            ),
567            3 => AnyArrayView::U64(
568                ArrayView::from_shape(shape.as_slice(), values.typed()?)
569                    .map_err(anyhow::Error::new)?,
570            ),
571            4 => AnyArrayView::I8(
572                ArrayView::from_shape(shape.as_slice(), values.typed()?)
573                    .map_err(anyhow::Error::new)?,
574            ),
575            5 => AnyArrayView::I16(
576                ArrayView::from_shape(shape.as_slice(), values.typed()?)
577                    .map_err(anyhow::Error::new)?,
578            ),
579            6 => AnyArrayView::I32(
580                ArrayView::from_shape(shape.as_slice(), values.typed()?)
581                    .map_err(anyhow::Error::new)?,
582            ),
583            7 => AnyArrayView::I64(
584                ArrayView::from_shape(shape.as_slice(), values.typed()?)
585                    .map_err(anyhow::Error::new)?,
586            ),
587            8 => AnyArrayView::F32(
588                ArrayView::from_shape(shape.as_slice(), values.typed()?)
589                    .map_err(anyhow::Error::new)?,
590            ),
591            9 => AnyArrayView::F64(
592                ArrayView::from_shape(shape.as_slice(), values.typed()?)
593                    .map_err(anyhow::Error::new)?,
594            ),
595            discriminant => {
596                return Err(RuntimeError::from(anyhow::Error::msg(format!(
597                    "process result buffer has an invalid variant [{discriminant}]:{:?}",
598                    data.value().map(|v| v.ty())
599                ))))
600            }
601        };
602
603        with(array)
604    }
605}