numcodecs_python/
schema.rs

1use std::{
2    borrow::Cow,
3    collections::{HashMap, hash_map::Entry},
4};
5
6use pyo3::{intern, prelude::*, sync::GILOnceCell};
7use pythonize::{PythonizeError, depythonize};
8use schemars::Schema;
9use serde_json::{Map, Value};
10use thiserror::Error;
11
12use crate::{PyCodecClass, export::RustCodec};
13
/// Imports `$module`, walks the attribute chain given by the `$path`
/// literals, and caches the resulting Python object so that the lookup is
/// only performed once per expansion site of the macro.
///
/// Evaluates to a `Result<&Bound<PyAny>, PyErr>` for the GIL token `$py`.
macro_rules! once {
    ($py:ident, $module:literal $(, $path:literal)*) => {{
        fn lookup(py: Python) -> PyResult<&Bound<PyAny>> {
            // one cache cell per macro expansion, shared by all later calls
            static CELL: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

            let cached = CELL.get_or_try_init(py, || -> PyResult<Py<PyAny>> {
                let object = py.import(intern!(py, $module))?;
                $(let object = object.getattr(intern!(py, $path))?;)*
                Ok(object.unbind())
            })?;

            Ok(cached.bind(py))
        }

        lookup($py)
    }};
}
29
30pub fn schema_from_codec_class(
31    py: Python,
32    class: &Bound<PyCodecClass>,
33) -> Result<Schema, SchemaError> {
34    if let Ok(schema) = class.getattr(intern!(py, RustCodec::SCHEMA_ATTRIBUTE)) {
35        return depythonize(&schema)
36            .map_err(|err| SchemaError::InvalidCachedJsonSchema { source: err });
37    }
38
39    let mut schema = Schema::default();
40
41    {
42        let schema = schema.ensure_object();
43
44        schema.insert(String::from("type"), Value::String(String::from("object")));
45
46        if let Ok(init) = class.getattr(intern!(py, "__init__")) {
47            let mut properties = Map::new();
48            let mut additional_properties = false;
49            let mut required = Vec::new();
50
51            let object_init = once!(py, "builtins", "object", "__init__")?;
52            let signature = once!(py, "inspect", "signature")?;
53            let empty_parameter = once!(py, "inspect", "Parameter", "empty")?;
54            let args_parameter = once!(py, "inspect", "Parameter", "VAR_POSITIONAL")?;
55            let kwargs_parameter = once!(py, "inspect", "Parameter", "VAR_KEYWORD")?;
56
57            for (i, param) in signature
58                .call1((&init,))?
59                .getattr(intern!(py, "parameters"))?
60                .call_method0(intern!(py, "items"))?
61                .try_iter()?
62                .enumerate()
63            {
64                let (name, param): (String, Bound<PyAny>) = param?.extract()?;
65
66                if i == 0 && name == "self" {
67                    continue;
68                }
69
70                let kind = param.getattr(intern!(py, "kind"))?;
71
72                if kind.eq(args_parameter)? && !init.eq(object_init)? {
73                    return Err(SchemaError::ArgsParameterInSignature);
74                }
75
76                if kind.eq(kwargs_parameter)? {
77                    additional_properties = true;
78                } else {
79                    let default = param.getattr(intern!(py, "default"))?;
80
81                    let mut parameter = Map::new();
82
83                    if default.eq(empty_parameter)? {
84                        required.push(Value::String(name.clone()));
85                    } else {
86                        let default = depythonize(&default).map_err(|err| {
87                            SchemaError::InvalidParameterDefault {
88                                name: name.clone(),
89                                source: err,
90                            }
91                        })?;
92                        parameter.insert(String::from("default"), default);
93                    }
94
95                    properties.insert(name, Value::Object(parameter));
96                }
97            }
98
99            schema.insert(
100                String::from("additionalProperties"),
101                Value::Bool(additional_properties),
102            );
103            schema.insert(String::from("properties"), Value::Object(properties));
104            schema.insert(String::from("required"), Value::Array(required));
105        } else {
106            schema.insert(String::from("additionalProperties"), Value::Bool(true));
107        }
108
109        if let Ok(doc) = class.getattr(intern!(py, "__doc__")) {
110            if !doc.is_none() {
111                let doc: String = doc
112                    .extract()
113                    .map_err(|err| SchemaError::InvalidClassDocs { source: err })?;
114                schema.insert(String::from("description"), Value::String(doc));
115            }
116        }
117
118        let name = class
119            .getattr(intern!(py, "__name__"))
120            .and_then(|name| name.extract())
121            .map_err(|err| SchemaError::InvalidClassName { source: err })?;
122        schema.insert(String::from("title"), Value::String(name));
123
124        schema.insert(
125            String::from("$schema"),
126            Value::String(String::from("https://json-schema.org/draft/2020-12/schema")),
127        );
128    }
129
130    Ok(schema)
131}
132
133pub fn docs_from_schema(schema: &Schema) -> Option<String> {
134    let parameters = parameters_from_schema(schema);
135    let schema = schema.as_object()?;
136
137    let mut docs = String::new();
138
139    if let Some(Value::String(description)) = schema.get("description") {
140        docs.push_str(description);
141        docs.push_str("\n\n");
142    }
143
144    if !parameters.named.is_empty() || parameters.additional {
145        docs.push_str("Parameters\n----------\n");
146    }
147
148    for parameter in &parameters.named {
149        docs.push_str(parameter.name);
150
151        docs.push_str(" : ...");
152
153        if !parameter.required {
154            docs.push_str(", optional");
155        }
156
157        #[expect(clippy::format_push_string)] // FIXME
158        if let Some(default) = parameter.default {
159            docs.push_str(", default = ");
160            docs.push_str(&format!("{default}"));
161        }
162
163        docs.push('\n');
164
165        if let Some(info) = &parameter.docs {
166            docs.push_str("    ");
167            docs.push_str(&info.replace('\n', "\n    "));
168        }
169
170        docs.push('\n');
171    }
172
173    if parameters.additional {
174        docs.push_str("**kwargs\n");
175        docs.push_str("    ");
176
177        if parameters.named.is_empty() {
178            docs.push_str("This codec takes *any* parameters.");
179        } else {
180            docs.push_str("This codec takes *any* additional parameters.");
181        }
182    } else if parameters.named.is_empty() {
183        docs.push_str("This codec does *not* take any parameters.");
184    }
185
186    docs.truncate(docs.trim_end().len());
187
188    Some(docs)
189}
190
191pub fn signature_from_schema(schema: &Schema) -> String {
192    let parameters = parameters_from_schema(schema);
193
194    let mut signature = String::new();
195    signature.push_str("self");
196
197    for parameter in parameters.named {
198        signature.push_str(", ");
199        signature.push_str(parameter.name);
200
201        #[expect(clippy::format_push_string)] // FIXME
202        if let Some(default) = parameter.default {
203            signature.push('=');
204            signature.push_str(&format!("{default}"));
205        } else if !parameter.required {
206            signature.push_str("=None");
207        }
208    }
209
210    if parameters.additional {
211        signature.push_str(", **kwargs");
212    }
213
214    signature
215}
216
217fn parameters_from_schema(schema: &Schema) -> Parameters {
218    // schema = true means that any parameters are allowed
219    if schema.as_bool() == Some(true) {
220        return Parameters {
221            named: Vec::new(),
222            additional: true,
223        };
224    }
225
226    // schema = false means that no config is valid
227    // we approximate that by saying that no parameters are allowed
228    let Some(schema) = schema.as_object() else {
229        return Parameters {
230            named: Vec::new(),
231            additional: false,
232        };
233    };
234
235    let mut parameters = Vec::new();
236
237    let required = match schema.get("required") {
238        Some(Value::Array(required)) => &**required,
239        _ => &[],
240    };
241
242    // extract the top-level parameters
243    if let Some(Value::Object(properties)) = schema.get("properties") {
244        for (name, parameter) in properties {
245            parameters.push(Parameter::new(name, parameter, required));
246        }
247    }
248
249    let mut additional = false;
250
251    extend_parameters_from_one_of_schema(schema, &mut parameters, &mut additional);
252
253    // iterate over allOf to handle flattened enums
254    if let Some(Value::Array(all)) = schema.get("allOf") {
255        for variant in all {
256            if let Some(variant) = variant.as_object() {
257                extend_parameters_from_one_of_schema(variant, &mut parameters, &mut additional);
258            }
259        }
260    }
261
262    // sort parameters by name and so that required parameters come first
263    parameters.sort_by_key(|p| (!p.required, p.name));
264
265    additional = match (
266        schema.get("additionalProperties"),
267        schema.get("unevaluatedProperties"),
268    ) {
269        (Some(Value::Bool(false)), None) => additional,
270        (None | Some(Value::Bool(false)), Some(Value::Bool(false))) => false,
271        _ => true,
272    };
273
274    Parameters {
275        named: parameters,
276        additional,
277    }
278}
279
280fn extend_parameters_from_one_of_schema<'a>(
281    schema: &'a Map<String, Value>,
282    parameters: &mut Vec<Parameter<'a>>,
283    additional: &mut bool,
284) {
285    // iterate over oneOf to handle top-level or flattened enums
286    if let Some(Value::Array(variants)) = schema.get("oneOf") {
287        let mut variant_parameters = HashMap::new();
288
289        for (generation, schema) in variants.iter().enumerate() {
290            // if any variant allows additional parameters, the top-level also
291            //  allows additional parameters
292            #[expect(clippy::unnested_or_patterns)]
293            if let Some(schema) = schema.as_object() {
294                *additional |= !matches!(
295                    (
296                        schema.get("additionalProperties"),
297                        schema.get("unevaluatedProperties")
298                    ),
299                    (Some(Value::Bool(false)), None)
300                        | (None, Some(Value::Bool(false)))
301                        | (Some(Value::Bool(false)), Some(Value::Bool(false)))
302                );
303            }
304
305            let required = match schema.get("required") {
306                Some(Value::Array(required)) => &**required,
307                _ => &[],
308            };
309            let variant_docs = match schema.get("description") {
310                Some(Value::String(docs)) => Some(docs),
311                _ => None,
312            };
313
314            // extract the per-variant parameters and check for tag parameters
315            if let Some(Value::Object(properties)) = schema.get("properties") {
316                for (name, parameter) in properties {
317                    match variant_parameters.entry(name) {
318                        Entry::Vacant(entry) => {
319                            entry.insert(VariantParameter::new(
320                                generation,
321                                name,
322                                parameter,
323                                required,
324                                variant_docs.map(|x| Cow::Borrowed(x.as_str())),
325                            ));
326                        }
327                        Entry::Occupied(mut entry) => {
328                            entry.get_mut().merge(
329                                generation,
330                                name,
331                                parameter,
332                                required,
333                                variant_docs.map(|x| Cow::Borrowed(x.as_str())),
334                            );
335                        }
336                    }
337                }
338            }
339
340            // ensure that only parameters in all variants are required or tags
341            for parameter in variant_parameters.values_mut() {
342                parameter.update_generation(generation);
343            }
344        }
345
346        // merge the variant parameters into the top-level parameters
347        parameters.extend(
348            variant_parameters
349                .into_values()
350                .map(VariantParameter::into_parameter),
351        );
352    }
353}
354
/// Errors that can occur when producing the JSON schema for a codec class.
#[derive(Debug, Error)]
pub enum SchemaError {
    /// The schema that is cached on the codec class could not be
    /// deserialised.
    #[error("codec class' cached config schema is invalid")]
    InvalidCachedJsonSchema { source: PythonizeError },
    /// A Python exception was raised while inspecting the codec's signature.
    #[error("extracting the codec signature failed")]
    SignatureExtraction {
        #[from]
        source: PyErr,
    },
    /// The codec's signature contains an `*args` parameter.
    #[error("codec's signature must not contain an `*args` parameter")]
    ArgsParameterInSignature,
    /// The default value of the parameter `name` could not be converted.
    #[error("{name} parameter's default value is invalid")]
    InvalidParameterDefault {
        name: String,
        source: PythonizeError,
    },
    /// The codec class's `__doc__` could not be extracted as a string.
    #[error("codec class's `__doc__` must be a string")]
    InvalidClassDocs { source: PyErr },
    /// The codec class's `__name__` is missing or could not be extracted as
    /// a string.
    #[error("codec class must have a string `__name__`")]
    InvalidClassName { source: PyErr },
}
376
/// The parameters of a codec, as extracted from its configuration schema.
struct Parameters<'a> {
    /// The parameters that are known by name.
    named: Vec<Parameter<'a>>,
    /// Whether parameters beyond the named ones are accepted.
    additional: bool,
}
381
/// A single named parameter of a codec, borrowing from the schema it was
/// extracted from.
struct Parameter<'a> {
    /// The name of the parameter.
    name: &'a str,
    /// Whether the parameter must be provided.
    required: bool,
    /// The parameter's `default` value from the schema, if any.
    default: Option<&'a Value>,
    /// The documentation for the parameter, if any.
    docs: Option<Cow<'a, str>>,
}
388
389impl<'a> Parameter<'a> {
390    #[must_use]
391    pub fn new(name: &'a str, parameter: &'a Value, required: &[Value]) -> Self {
392        Self {
393            name,
394            required: required
395                .iter()
396                .any(|r| matches!(r, Value::String(n) if n == name)),
397            default: parameter.get("default"),
398            docs: match parameter.get("description") {
399                Some(Value::String(docs)) => Some(Cow::Borrowed(docs.as_str())),
400                _ => None,
401            },
402        }
403    }
404}
405
/// A parameter that is collected from the variants of a `oneOf` schema and
/// merged across all variants that declare it.
struct VariantParameter<'a> {
    /// The index of the latest variant in which the parameter was seen.
    generation: usize,
    /// The parameter, as merged across the variants seen so far.
    parameter: Parameter<'a>,
    /// The `(const value, docs)` pairs per variant while the parameter is
    /// tag-like, i.e. a `const` property in every variant so far, or `None`
    /// if it is not.
    #[expect(clippy::type_complexity)]
    tag_docs: Option<Vec<(&'a Value, Option<Cow<'a, str>>)>>,
}
412
413impl<'a> VariantParameter<'a> {
414    #[must_use]
415    pub fn new(
416        generation: usize,
417        name: &'a str,
418        parameter: &'a Value,
419        required: &[Value],
420        variant_docs: Option<Cow<'a, str>>,
421    ) -> Self {
422        let r#const = parameter.get("const");
423
424        let mut parameter = Parameter::new(name, parameter, required);
425        parameter.required &= generation == 0;
426
427        let tag_docs = match r#const {
428            // a tag parameter must be introduced in the first generation
429            Some(r#const) if generation == 0 => {
430                let docs = parameter.docs.take().or(variant_docs);
431                Some(vec![(r#const, docs)])
432            }
433            _ => None,
434        };
435
436        Self {
437            generation,
438            parameter,
439            tag_docs,
440        }
441    }
442
443    pub fn merge(
444        &mut self,
445        generation: usize,
446        name: &'a str,
447        parameter: &'a Value,
448        required: &[Value],
449        variant_docs: Option<Cow<'a, str>>,
450    ) {
451        self.generation = generation;
452
453        let r#const = parameter.get("const");
454
455        let parameter = Parameter::new(name, parameter, required);
456
457        self.parameter.required &= parameter.required;
458        if self.parameter.default != parameter.default {
459            self.parameter.default = None;
460        }
461
462        if let Some(tag_docs) = &mut self.tag_docs {
463            // we're building docs for a tag-like parameter
464            if let Some(r#const) = r#const {
465                tag_docs.push((r#const, parameter.docs.or(variant_docs)));
466            } else {
467                // mixing tag and non-tag parameter => no docs
468                self.tag_docs = None;
469                self.parameter.docs = None;
470            }
471        } else {
472            // we're building docs for a normal parameter
473            if r#const.is_none() {
474                // we only accept always matching docs for normal parameters
475                if self.parameter.docs != parameter.docs {
476                    self.parameter.docs = None;
477                }
478            } else {
479                // mixing tag and non-tag parameter => no docs
480                self.tag_docs = None;
481            }
482        }
483    }
484
485    pub fn update_generation(&mut self, generation: usize) {
486        if self.generation < generation {
487            // required and tag parameters must appear in all generations
488            self.parameter.required = false;
489            self.tag_docs = None;
490        }
491    }
492
493    #[must_use]
494    pub fn into_parameter(mut self) -> Parameter<'a> {
495        if let Some(tag_docs) = self.tag_docs {
496            let mut docs = String::new();
497
498            #[expect(clippy::format_push_string)] // FIXME
499            for (tag, tag_docs) in tag_docs {
500                docs.push_str(" - ");
501                docs.push_str(&format!("{tag}"));
502                if let Some(tag_docs) = tag_docs {
503                    docs.push_str(": ");
504                    docs.push_str(&tag_docs.replace('\n', "\n    "));
505                }
506                docs.push_str("\n\n");
507            }
508
509            docs.truncate(docs.trim_end().len());
510
511            self.parameter.docs = Some(Cow::Owned(docs));
512        }
513
514        self.parameter
515    }
516}
517
// Tests for the schema-to-docs and schema-to-signature helpers, based on the
// schema that schemars derives for the `MyCodec` fixture below.
#[cfg(test)]
mod tests {
    use schemars::{JsonSchema, schema_for};

    use super::*;

    // Pins the JSON text of the schema that is derived for `MyCodec`, which
    // the other tests use as their input.
    #[test]
    fn schema() {
        assert_eq!(
            format!("{}", schema_for!(MyCodec).to_value()),
            r#"{"type":"object","properties":{"param":{"type":["integer","null"],"format":"int32","description":"An optional integer value."}},"unevaluatedProperties":false,"oneOf":[{"type":"object","description":"Mode a.\n\nIt gets another line.","properties":{"value":{"type":"boolean","description":"A boolean value. And some really, really, really, long first\nline that wraps around.\n\nWith multiple lines of comments."},"common":{"type":"string","description":"A common string value.\n\nSomething else here."},"mode":{"type":"string","const":"A"}},"required":["mode","value","common"]},{"type":"object","description":"Mode b.","properties":{"common":{"type":"string","description":"A common string value.\n\nSomething else here."},"mode":{"type":"string","const":"B"}},"required":["mode","common"]}],"description":"A codec that does something on encoding and decoding.\n\nWith multiple lines of comments.","title":"MyCodec","$schema":"https://json-schema.org/draft/2020-12/schema"}"#
        );
    }

    // Checks the rendered docs, including the indentation-only lines that
    // result from indenting the blank lines in multi-line descriptions.
    #[test]
    fn docs() {
        assert_eq!(
            docs_from_schema(&schema_for!(MyCodec)).as_deref(),
            Some(
                r#"A codec that does something on encoding and decoding.

With multiple lines of comments.

Parameters
----------
common : ...
    A common string value.
    
    Something else here.
mode : ...
     - "A": Mode a.
        
        It gets another line.
    
     - "B": Mode b.
param : ..., optional
    An optional integer value.
value : ..., optional
    A boolean value. And some really, really, really, long first
    line that wraps around.
    
    With multiple lines of comments."#
            )
        );
    }

    // Checks the rendered signature: required parameters come first, and
    // optional parameters without a default are rendered as `=None`.
    #[test]
    fn signature() {
        assert_eq!(
            signature_from_schema(&schema_for!(MyCodec)),
            "self, common, mode, param=None, value=None",
        );
    }

    // Fixture: a codec configuration with an optional parameter and a
    // flattened, internally tagged enum. Its doc comments are part of the
    // expected schema above.
    #[expect(dead_code)]
    #[derive(JsonSchema)]
    #[schemars(deny_unknown_fields)]
    /// A codec that does something on encoding and decoding.
    ///
    /// With multiple lines of comments.
    struct MyCodec {
        /// An optional integer value.
        #[schemars(default, skip_serializing_if = "Option::is_none")]
        param: Option<i32>,
        /// The flattened configuration.
        #[schemars(flatten)]
        config: Config,
    }

    // Fixture: the enum that is flattened into `MyCodec`, tagged by `mode`.
    // `common` is part of both variants, `value` only of variant `A`.
    #[expect(dead_code)]
    #[derive(JsonSchema)]
    #[schemars(tag = "mode")]
    #[schemars(deny_unknown_fields)]
    enum Config {
        /// Mode a.
        ///
        /// It gets another line.
        A {
            /// A boolean value. And some really, really, really, long first
            /// line that wraps around.
            ///
            /// With multiple lines of comments.
            value: bool,
            /// A common string value.
            ///
            /// Something else here.
            common: String,
        },
        /// Mode b.
        B {
            /// A common string value.
            ///
            /// Something else here.
            common: String,
        },
    }
}