numcodecs_python/
schema.rs

1use std::{
2    borrow::Cow,
3    collections::{HashMap, hash_map::Entry},
4    io,
5    num::FpCategory,
6};
7
8use pyo3::{intern, prelude::*, sync::PyOnceLock};
9use pythonize::{PythonizeError, depythonize};
10use schemars::Schema;
11use serde_json::{Map, Value};
12use thiserror::Error;
13
14use crate::{PyCodecClass, export::RustCodec};
15
/// Looks up a Python object once and reuses it on every later evaluation.
///
/// `once!(py, "module", "a", "b")` expands to an expression of type
/// `Result<&Bound<PyAny>, PyErr>`. The first successful evaluation imports
/// `module`, follows the attribute path `a.b`, and stores the resulting object
/// in a `static` [`PyOnceLock`] that belongs to this one expansion; later
/// evaluations hand out the stored object.
macro_rules! once {
    ($py:ident, $module:literal $(, $path:literal)*) => {{
        fn cached(py: Python<'_>) -> Result<&Bound<'_, PyAny>, PyErr> {
            // every macro expansion defines its own function and thus its own
            //  static cell
            static CELL: PyOnceLock<Py<PyAny>> = PyOnceLock::new();

            let object = CELL.get_or_try_init(py, || -> Result<Py<PyAny>, PyErr> {
                let module = py.import(intern!(py, $module))?;

                Ok(module
                    $(.getattr(intern!(py, $path))?)*
                    .unbind())
            })?;

            Ok(object.bind(py))
        }

        cached($py)
    }};
}
31
32pub fn schema_from_codec_class(
33    py: Python,
34    class: &Bound<PyCodecClass>,
35) -> Result<Schema, SchemaError> {
36    if let Ok(schema) = class.getattr(intern!(py, RustCodec::SCHEMA_ATTRIBUTE)) {
37        return depythonize(&schema)
38            .map_err(|err| SchemaError::InvalidCachedJsonSchema { source: err });
39    }
40
41    let mut schema = Schema::default();
42
43    {
44        let schema = schema.ensure_object();
45
46        schema.insert(String::from("type"), Value::String(String::from("object")));
47
48        if let Ok(init) = class.getattr(intern!(py, "__init__")) {
49            let mut properties = Map::new();
50            let mut additional_properties = false;
51            let mut required = Vec::new();
52
53            let object_init = once!(py, "builtins", "object", "__init__")?;
54            let signature = once!(py, "inspect", "signature")?;
55            let empty_parameter = once!(py, "inspect", "Parameter", "empty")?;
56            let args_parameter = once!(py, "inspect", "Parameter", "VAR_POSITIONAL")?;
57            let kwargs_parameter = once!(py, "inspect", "Parameter", "VAR_KEYWORD")?;
58
59            for (i, param) in signature
60                .call1((&init,))?
61                .getattr(intern!(py, "parameters"))?
62                .call_method0(intern!(py, "items"))?
63                .try_iter()?
64                .enumerate()
65            {
66                let (name, param): (String, Bound<PyAny>) = param?.extract()?;
67
68                if i == 0 && name == "self" {
69                    continue;
70                }
71
72                let kind = param.getattr(intern!(py, "kind"))?;
73
74                if kind.eq(args_parameter)? && !init.eq(object_init)? {
75                    return Err(SchemaError::ArgsParameterInSignature);
76                }
77
78                if kind.eq(kwargs_parameter)? {
79                    additional_properties = true;
80                } else {
81                    let default = param.getattr(intern!(py, "default"))?;
82
83                    let mut parameter = Map::new();
84
85                    if default.eq(empty_parameter)? {
86                        required.push(Value::String(name.clone()));
87                    } else {
88                        let default = depythonize(&default).map_err(|err| {
89                            SchemaError::InvalidParameterDefault {
90                                name: name.clone(),
91                                source: err,
92                            }
93                        })?;
94                        parameter.insert(String::from("default"), default);
95                    }
96
97                    properties.insert(name, Value::Object(parameter));
98                }
99            }
100
101            schema.insert(
102                String::from("additionalProperties"),
103                Value::Bool(additional_properties),
104            );
105            schema.insert(String::from("properties"), Value::Object(properties));
106            schema.insert(String::from("required"), Value::Array(required));
107        } else {
108            schema.insert(String::from("additionalProperties"), Value::Bool(true));
109        }
110
111        if let Ok(doc) = class.getattr(intern!(py, "__doc__")) {
112            if !doc.is_none() {
113                let doc: String = doc
114                    .extract()
115                    .map_err(|err| SchemaError::InvalidClassDocs { source: err })?;
116                schema.insert(String::from("description"), Value::String(doc));
117            }
118        }
119
120        let name = class
121            .getattr(intern!(py, "__name__"))
122            .and_then(|name| name.extract())
123            .map_err(|err| SchemaError::InvalidClassName { source: err })?;
124        schema.insert(String::from("title"), Value::String(name));
125
126        schema.insert(
127            String::from("$schema"),
128            Value::String(String::from("https://json-schema.org/draft/2020-12/schema")),
129        );
130    }
131
132    Ok(schema)
133}
134
135pub fn docs_from_schema(schema: &Schema) -> Option<String> {
136    let parameters = parameters_from_schema(schema);
137    let schema = schema.as_object()?;
138
139    let mut docs = String::new();
140
141    if let Some(Value::String(description)) = schema.get("description") {
142        docs.push_str(description);
143        docs.push_str("\n\n");
144    }
145
146    if !parameters.named.is_empty() || parameters.additional {
147        docs.push_str("Parameters\n----------\n");
148    }
149
150    for parameter in &parameters.named {
151        docs.push_str(parameter.name);
152
153        docs.push_str(" : ...");
154
155        if !parameter.required {
156            docs.push_str(", optional");
157        }
158
159        if let Some(default) = parameter.default {
160            docs.push_str(", default = ");
161            JsonToPythonFormatter::push_to_string(&mut docs, default);
162        }
163
164        docs.push('\n');
165
166        if let Some(info) = &parameter.docs {
167            docs.push_str("    ");
168            docs.push_str(&info.replace('\n', "\n    "));
169        }
170
171        docs.push('\n');
172    }
173
174    if parameters.additional {
175        docs.push_str("**kwargs\n");
176        docs.push_str("    ");
177
178        if parameters.named.is_empty() {
179            docs.push_str("This codec takes *any* parameters.");
180        } else {
181            docs.push_str("This codec takes *any* additional parameters.");
182        }
183    } else if parameters.named.is_empty() {
184        docs.push_str("This codec does *not* take any parameters.");
185    }
186
187    docs.truncate(docs.trim_end().len());
188
189    Some(docs)
190}
191
192pub fn signature_from_schema(schema: &Schema) -> String {
193    let parameters = parameters_from_schema(schema);
194
195    let mut signature = String::new();
196    signature.push_str("self");
197
198    for parameter in parameters.named {
199        signature.push_str(", ");
200        signature.push_str(parameter.name);
201
202        if let Some(default) = parameter.default {
203            signature.push('=');
204            JsonToPythonFormatter::push_to_string(&mut signature, default);
205        } else if !parameter.required {
206            signature.push_str("=None");
207        }
208    }
209
210    if parameters.additional {
211        signature.push_str(", **kwargs");
212    }
213
214    signature
215}
216
217fn parameters_from_schema(schema: &Schema) -> Parameters<'_> {
218    // schema = true means that any parameters are allowed
219    if schema.as_bool() == Some(true) {
220        return Parameters {
221            named: Vec::new(),
222            additional: true,
223        };
224    }
225
226    // schema = false means that no config is valid
227    // we approximate that by saying that no parameters are allowed
228    let Some(schema) = schema.as_object() else {
229        return Parameters {
230            named: Vec::new(),
231            additional: false,
232        };
233    };
234
235    let mut parameters = Vec::new();
236
237    let required = match schema.get("required") {
238        Some(Value::Array(required)) => &**required,
239        _ => &[],
240    };
241
242    // extract the top-level parameters
243    if let Some(Value::Object(properties)) = schema.get("properties") {
244        for (name, parameter) in properties {
245            parameters.push(Parameter::new(name, parameter, required));
246        }
247    }
248
249    let mut additional = false;
250
251    extend_parameters_from_one_of_schema(schema, &mut parameters, &mut additional);
252
253    // iterate over allOf to handle flattened enums
254    if let Some(Value::Array(all)) = schema.get("allOf") {
255        for variant in all {
256            if let Some(variant) = variant.as_object() {
257                extend_parameters_from_one_of_schema(variant, &mut parameters, &mut additional);
258            }
259        }
260    }
261
262    // sort parameters by name and so that required parameters come first
263    parameters.sort_by_key(|p| (!p.required, p.name));
264
265    additional = match (
266        schema.get("additionalProperties"),
267        schema.get("unevaluatedProperties"),
268    ) {
269        (Some(Value::Bool(false)), None) => additional,
270        (None | Some(Value::Bool(false)), Some(Value::Bool(false))) => false,
271        _ => true,
272    };
273
274    Parameters {
275        named: parameters,
276        additional,
277    }
278}
279
280fn extend_parameters_from_one_of_schema<'a>(
281    schema: &'a Map<String, Value>,
282    parameters: &mut Vec<Parameter<'a>>,
283    additional: &mut bool,
284) {
285    // iterate over oneOf to handle top-level or flattened enums
286    if let Some(Value::Array(variants)) = schema.get("oneOf") {
287        let mut variant_parameters = HashMap::new();
288
289        for (generation, schema) in variants.iter().enumerate() {
290            // if any variant allows additional parameters, the top-level also
291            //  allows additional parameters
292            #[expect(clippy::unnested_or_patterns)]
293            if let Some(schema) = schema.as_object() {
294                *additional |= !matches!(
295                    (
296                        schema.get("additionalProperties"),
297                        schema.get("unevaluatedProperties")
298                    ),
299                    (Some(Value::Bool(false)), None)
300                        | (None, Some(Value::Bool(false)))
301                        | (Some(Value::Bool(false)), Some(Value::Bool(false)))
302                );
303            }
304
305            let required = match schema.get("required") {
306                Some(Value::Array(required)) => &**required,
307                _ => &[],
308            };
309            let variant_docs = match schema.get("description") {
310                Some(Value::String(docs)) => Some(docs),
311                _ => None,
312            };
313
314            // extract the per-variant parameters and check for tag parameters
315            if let Some(Value::Object(properties)) = schema.get("properties") {
316                for (name, parameter) in properties {
317                    match variant_parameters.entry(name) {
318                        Entry::Vacant(entry) => {
319                            entry.insert(VariantParameter::new(
320                                generation,
321                                name,
322                                parameter,
323                                required,
324                                variant_docs.map(|x| Cow::Borrowed(x.as_str())),
325                            ));
326                        }
327                        Entry::Occupied(mut entry) => {
328                            entry.get_mut().merge(
329                                generation,
330                                name,
331                                parameter,
332                                required,
333                                variant_docs.map(|x| Cow::Borrowed(x.as_str())),
334                            );
335                        }
336                    }
337                }
338            }
339
340            // ensure that only parameters in all variants are required or tags
341            for parameter in variant_parameters.values_mut() {
342                parameter.update_generation(generation);
343            }
344        }
345
346        // merge the variant parameters into the top-level parameters
347        parameters.extend(
348            variant_parameters
349                .into_values()
350                .map(VariantParameter::into_parameter),
351        );
352    }
353}
354
355#[derive(Debug, Error)]
356pub enum SchemaError {
357    #[error("codec class' cached config schema is invalid")]
358    InvalidCachedJsonSchema { source: PythonizeError },
359    #[error("extracting the codec signature failed")]
360    SignatureExtraction {
361        #[from]
362        source: PyErr,
363    },
364    #[error("codec's signature must not contain an `*args` parameter")]
365    ArgsParameterInSignature,
366    #[error("{name} parameter's default value is invalid")]
367    InvalidParameterDefault {
368        name: String,
369        source: PythonizeError,
370    },
371    #[error("codec class's `__doc__` must be a string")]
372    InvalidClassDocs { source: PyErr },
373    #[error("codec class must have a string `__name__`")]
374    InvalidClassName { source: PyErr },
375}
376
377struct Parameters<'a> {
378    named: Vec<Parameter<'a>>,
379    additional: bool,
380}
381
382struct Parameter<'a> {
383    name: &'a str,
384    required: bool,
385    default: Option<&'a Value>,
386    docs: Option<Cow<'a, str>>,
387}
388
389impl<'a> Parameter<'a> {
390    #[must_use]
391    pub fn new(name: &'a str, parameter: &'a Value, required: &[Value]) -> Self {
392        Self {
393            name,
394            required: required
395                .iter()
396                .any(|r| matches!(r, Value::String(n) if n == name)),
397            default: parameter.get("default"),
398            docs: match parameter.get("description") {
399                Some(Value::String(docs)) => Some(Cow::Borrowed(docs.as_str())),
400                _ => None,
401            },
402        }
403    }
404}
405
406struct VariantParameter<'a> {
407    generation: usize,
408    parameter: Parameter<'a>,
409    #[expect(clippy::type_complexity)]
410    tag_docs: Option<Vec<(&'a Value, Option<Cow<'a, str>>)>>,
411}
412
413impl<'a> VariantParameter<'a> {
414    #[must_use]
415    pub fn new(
416        generation: usize,
417        name: &'a str,
418        parameter: &'a Value,
419        required: &[Value],
420        variant_docs: Option<Cow<'a, str>>,
421    ) -> Self {
422        let r#const = parameter.get("const");
423
424        let mut parameter = Parameter::new(name, parameter, required);
425        parameter.required &= generation == 0;
426
427        let tag_docs = match r#const {
428            // a tag parameter must be introduced in the first generation
429            Some(r#const) if generation == 0 => {
430                let docs = parameter.docs.take().or(variant_docs);
431                Some(vec![(r#const, docs)])
432            }
433            _ => None,
434        };
435
436        Self {
437            generation,
438            parameter,
439            tag_docs,
440        }
441    }
442
443    pub fn merge(
444        &mut self,
445        generation: usize,
446        name: &'a str,
447        parameter: &'a Value,
448        required: &[Value],
449        variant_docs: Option<Cow<'a, str>>,
450    ) {
451        self.generation = generation;
452
453        let r#const = parameter.get("const");
454
455        let parameter = Parameter::new(name, parameter, required);
456
457        self.parameter.required &= parameter.required;
458        if self.parameter.default != parameter.default {
459            self.parameter.default = None;
460        }
461
462        if let Some(tag_docs) = &mut self.tag_docs {
463            // we're building docs for a tag-like parameter
464            if let Some(r#const) = r#const {
465                tag_docs.push((r#const, parameter.docs.or(variant_docs)));
466            } else {
467                // mixing tag and non-tag parameter => no docs
468                self.tag_docs = None;
469                self.parameter.docs = None;
470            }
471        } else {
472            // we're building docs for a normal parameter
473            if r#const.is_none() {
474                // we only accept always matching docs for normal parameters
475                if self.parameter.docs != parameter.docs {
476                    self.parameter.docs = None;
477                }
478            } else {
479                // mixing tag and non-tag parameter => no docs
480                self.tag_docs = None;
481            }
482        }
483    }
484
485    pub fn update_generation(&mut self, generation: usize) {
486        if self.generation < generation {
487            // required and tag parameters must appear in all generations
488            self.parameter.required = false;
489            self.tag_docs = None;
490        }
491    }
492
493    #[must_use]
494    pub fn into_parameter(mut self) -> Parameter<'a> {
495        if let Some(tag_docs) = self.tag_docs {
496            let mut docs = String::new();
497
498            for (tag, tag_docs) in tag_docs {
499                docs.push_str(" - ");
500                JsonToPythonFormatter::push_to_string(&mut docs, tag);
501                if let Some(tag_docs) = tag_docs {
502                    docs.push_str(": ");
503                    docs.push_str(&tag_docs.replace('\n', "\n    "));
504                }
505                docs.push_str("\n\n");
506            }
507
508            docs.truncate(docs.trim_end().len());
509
510            self.parameter.docs = Some(Cow::Owned(docs));
511        }
512
513        self.parameter
514    }
515}
516
517struct JsonToPythonFormatter;
518
519impl JsonToPythonFormatter {
520    fn push_to_string(buffer: &mut String, value: &Value) {
521        // Safety: serde_json::Serializer only produces valid UTF8 bytes
522        //  - JsonToPythonFormatter only writes valid UTF8 bytes
523        //  - serde_json::ser::CompactFormatter only writes valid UTF8 bytes
524        #[expect(unsafe_code)]
525        let mut ser = serde_json::Serializer::with_formatter(unsafe { buffer.as_mut_vec() }, Self);
526        #[allow(clippy::expect_used)]
527        serde::Serialize::serialize(value, &mut ser)
528            .expect("JSON value must not fail to serialize");
529    }
530}
531
532impl serde_json::ser::Formatter for JsonToPythonFormatter {
533    #[inline]
534    fn write_null<W: ?Sized + io::Write>(&mut self, writer: &mut W) -> io::Result<()> {
535        writer.write_all(b"None")
536    }
537
538    #[inline]
539    fn write_bool<W: ?Sized + io::Write>(&mut self, writer: &mut W, value: bool) -> io::Result<()> {
540        let s: &[u8] = if value { b"True" } else { b"False" };
541        writer.write_all(s)
542    }
543
544    #[inline]
545    fn write_f32<W: ?Sized + io::Write>(&mut self, writer: &mut W, value: f32) -> io::Result<()> {
546        let s: &[u8] = match (value.classify(), value.is_sign_negative()) {
547            (FpCategory::Nan, false) => b"float('nan')",
548            (FpCategory::Nan, true) => b"float('-nan')",
549            (FpCategory::Infinite, false) => b"float('inf')",
550            (FpCategory::Infinite, true) => b"float('-inf')",
551            (FpCategory::Zero | FpCategory::Subnormal | FpCategory::Normal, false | true) => {
552                return serde_json::ser::CompactFormatter.write_f32(writer, value);
553            }
554        };
555        writer.write_all(s)
556    }
557
558    #[inline]
559    fn write_f64<W: ?Sized + io::Write>(&mut self, writer: &mut W, value: f64) -> io::Result<()> {
560        let s: &[u8] = match (value.classify(), value.is_sign_negative()) {
561            (FpCategory::Nan, false) => b"float('nan')",
562            (FpCategory::Nan, true) => b"float('-nan')",
563            (FpCategory::Infinite, false) => b"float('inf')",
564            (FpCategory::Infinite, true) => b"float('-inf')",
565            (FpCategory::Zero | FpCategory::Subnormal | FpCategory::Normal, false | true) => {
566                return serde_json::ser::CompactFormatter.write_f64(writer, value);
567            }
568        };
569        writer.write_all(s)
570    }
571}
572
#[cfg(test)]
mod tests {
    use schemars::{JsonSchema, schema_for};

    use super::*;

    /// Pins the exact schema that is generated for `MyCodec`, which the other
    /// tests build on.
    #[test]
    fn schema() {
        let generated = schema_for!(MyCodec).to_value().to_string();

        assert_eq!(
            generated,
            r#"{"type":"object","properties":{"param":{"type":["integer","null"],"format":"int32","description":"An optional integer value."}},"unevaluatedProperties":false,"oneOf":[{"type":"object","description":"Mode a.\n\nIt gets another line.","properties":{"value":{"type":"boolean","description":"A boolean value. And some really, really, really, long first\nline that wraps around.\n\nWith multiple lines of comments."},"common":{"type":"string","description":"A common string value.\n\nSomething else here."},"mode":{"type":"string","const":"A"}},"required":["mode","value","common"]},{"type":"object","description":"Mode b.","properties":{"common":{"type":"string","description":"A common string value.\n\nSomething else here."},"mode":{"type":"string","const":"B"}},"required":["mode","common"]}],"description":"A codec that does something on encoding and decoding.\n\nWith multiple lines of comments.","title":"MyCodec","$schema":"https://json-schema.org/draft/2020-12/schema"}"#
        );
    }

    /// Checks the rendered docs, including the indentation of blank lines.
    #[test]
    fn docs() {
        let schema = schema_for!(MyCodec);

        // spelled out line by line so that whitespace-only lines are visible
        let expected = concat!(
            "A codec that does something on encoding and decoding.\n",
            "\n",
            "With multiple lines of comments.\n",
            "\n",
            "Parameters\n",
            "----------\n",
            "common : ...\n",
            "    A common string value.\n",
            "    \n",
            "    Something else here.\n",
            "mode : ...\n",
            "     - \"A\": Mode a.\n",
            "        \n",
            "        It gets another line.\n",
            "    \n",
            "     - \"B\": Mode b.\n",
            "param : ..., optional\n",
            "    An optional integer value.\n",
            "value : ..., optional\n",
            "    A boolean value. And some really, really, really, long first\n",
            "    line that wraps around.\n",
            "    \n",
            "    With multiple lines of comments.",
        );

        assert_eq!(docs_from_schema(&schema).as_deref(), Some(expected));
    }

    /// Checks that required parameters come first and optional ones default
    /// to `None`.
    #[test]
    fn signature() {
        let schema = schema_for!(MyCodec);

        assert_eq!(
            signature_from_schema(&schema),
            "self, common, mode, param=None, value=None",
        );
    }

    #[expect(dead_code)]
    #[derive(JsonSchema)]
    #[schemars(deny_unknown_fields)]
    /// A codec that does something on encoding and decoding.
    ///
    /// With multiple lines of comments.
    struct MyCodec {
        /// An optional integer value.
        #[schemars(default, skip_serializing_if = "Option::is_none")]
        param: Option<i32>,
        /// The flattened configuration.
        #[schemars(flatten)]
        config: Config,
    }

    #[expect(dead_code)]
    #[derive(JsonSchema)]
    #[schemars(tag = "mode")]
    #[schemars(deny_unknown_fields)]
    enum Config {
        /// Mode a.
        ///
        /// It gets another line.
        A {
            /// A boolean value. And some really, really, really, long first
            /// line that wraps around.
            ///
            /// With multiple lines of comments.
            value: bool,
            /// A common string value.
            ///
            /// Something else here.
            common: String,
        },
        /// Mode b.
        B {
            /// A common string value.
            ///
            /// Something else here.
            common: String,
        },
    }
}