numcodecs_python/
schema.rs

1use std::{
2    borrow::Cow,
3    collections::{hash_map::Entry, HashMap},
4};
5
6use pyo3::{intern, prelude::*, sync::GILOnceCell};
7use pythonize::{depythonize, PythonizeError};
8use schemars::Schema;
9use serde_json::{Map, Value};
10use thiserror::Error;
11
12use crate::{export::RustCodec, PyCodecClass};
13
/// Looks up a Python object by module name and attribute path, caching the
/// result.
///
/// `once!(py, "module", "a", "b")` imports `module`, reads `.a.b` from it, and
/// stores the object in a [`GILOnceCell`] that is private to the macro call
/// site. The expression evaluates to `Result<&Bound<PyAny>, PyErr>`; a failed
/// lookup is returned to the caller.
macro_rules! once {
    ($py:ident, $module:literal $(, $path:literal)*) => {{
        fn once(py: Python) -> Result<&Bound<PyAny>, PyErr> {
            // one cell per expansion of the macro
            static ONCE: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

            let cached = ONCE.get_or_try_init(py, || -> Result<Py<PyAny>, PyErr> {
                let object = py.import(intern!(py, $module))?;
                // walk the attribute path one segment at a time
                $(let object = object.getattr(intern!(py, $path))?;)*
                Ok(object.unbind())
            })?;

            Ok(cached.bind(py))
        }

        once($py)
    }};
}
29
30pub fn schema_from_codec_class(
31    py: Python,
32    class: &Bound<PyCodecClass>,
33) -> Result<Schema, SchemaError> {
34    if let Ok(schema) = class.getattr(intern!(py, RustCodec::SCHEMA_ATTRIBUTE)) {
35        return depythonize(&schema)
36            .map_err(|err| SchemaError::InvalidCachedJsonSchema { source: err });
37    }
38
39    let mut schema = Schema::default();
40
41    {
42        let schema = schema.ensure_object();
43
44        schema.insert(String::from("type"), Value::String(String::from("object")));
45
46        if let Ok(init) = class.getattr(intern!(py, "__init__")) {
47            let mut properties = Map::new();
48            let mut additional_properties = false;
49            let mut required = Vec::new();
50
51            let object_init = once!(py, "builtins", "object", "__init__")?;
52            let signature = once!(py, "inspect", "signature")?;
53            let empty_parameter = once!(py, "inspect", "Parameter", "empty")?;
54            let args_parameter = once!(py, "inspect", "Parameter", "VAR_POSITIONAL")?;
55            let kwargs_parameter = once!(py, "inspect", "Parameter", "VAR_KEYWORD")?;
56
57            for (i, param) in signature
58                .call1((&init,))?
59                .getattr(intern!(py, "parameters"))?
60                .call_method0(intern!(py, "items"))?
61                .try_iter()?
62                .enumerate()
63            {
64                let (name, param): (String, Bound<PyAny>) = param?.extract()?;
65
66                if i == 0 && name == "self" {
67                    continue;
68                }
69
70                let kind = param.getattr(intern!(py, "kind"))?;
71
72                if kind.eq(args_parameter)? && !init.eq(object_init)? {
73                    return Err(SchemaError::ArgsParameterInSignature);
74                }
75
76                if kind.eq(kwargs_parameter)? {
77                    additional_properties = true;
78                } else {
79                    let default = param.getattr(intern!(py, "default"))?;
80
81                    let mut parameter = Map::new();
82
83                    if default.eq(empty_parameter)? {
84                        required.push(Value::String(name.clone()));
85                    } else {
86                        let default = depythonize(&default).map_err(|err| {
87                            SchemaError::InvalidParameterDefault {
88                                name: name.clone(),
89                                source: err,
90                            }
91                        })?;
92                        parameter.insert(String::from("default"), default);
93                    }
94
95                    properties.insert(name, Value::Object(parameter));
96                }
97            }
98
99            schema.insert(
100                String::from("additionalProperties"),
101                Value::Bool(additional_properties),
102            );
103            schema.insert(String::from("properties"), Value::Object(properties));
104            schema.insert(String::from("required"), Value::Array(required));
105        } else {
106            schema.insert(String::from("additionalProperties"), Value::Bool(true));
107        }
108
109        if let Ok(doc) = class.getattr(intern!(py, "__doc__")) {
110            if !doc.is_none() {
111                let doc: String = doc
112                    .extract()
113                    .map_err(|err| SchemaError::InvalidClassDocs { source: err })?;
114                schema.insert(String::from("description"), Value::String(doc));
115            }
116        }
117
118        let name = class
119            .getattr(intern!(py, "__name__"))
120            .and_then(|name| name.extract())
121            .map_err(|err| SchemaError::InvalidClassName { source: err })?;
122        schema.insert(String::from("title"), Value::String(name));
123
124        schema.insert(
125            String::from("$schema"),
126            Value::String(String::from("https://json-schema.org/draft/2020-12/schema")),
127        );
128    }
129
130    Ok(schema)
131}
132
133pub fn docs_from_schema(schema: &Schema) -> Option<String> {
134    let parameters = parameters_from_schema(schema);
135    let schema = schema.as_object()?;
136
137    let mut docs = String::new();
138
139    if let Some(Value::String(description)) = schema.get("description") {
140        docs.push_str(&derust_doc_comment(description));
141        docs.push_str("\n\n");
142    }
143
144    if !parameters.named.is_empty() || parameters.additional {
145        docs.push_str("Parameters\n----------\n");
146    }
147
148    for parameter in &parameters.named {
149        docs.push_str(parameter.name);
150
151        docs.push_str(" : ...");
152
153        if !parameter.required {
154            docs.push_str(", optional");
155        }
156
157        if let Some(default) = parameter.default {
158            docs.push_str(", default = ");
159            docs.push_str(&format!("{default}"));
160        }
161
162        docs.push('\n');
163
164        if let Some(info) = &parameter.docs {
165            docs.push_str("    ");
166            docs.push_str(&info.replace('\n', "\n    "));
167        }
168
169        docs.push('\n');
170    }
171
172    if parameters.additional {
173        docs.push_str("**kwargs\n");
174        docs.push_str("    ");
175
176        if parameters.named.is_empty() {
177            docs.push_str("This codec takes *any* parameters.");
178        } else {
179            docs.push_str("This codec takes *any* additional parameters.");
180        }
181    } else if parameters.named.is_empty() {
182        docs.push_str("This codec does *not* take any parameters.");
183    }
184
185    docs.truncate(docs.trim_end().len());
186
187    Some(docs)
188}
189
190pub fn signature_from_schema(schema: &Schema) -> String {
191    let parameters = parameters_from_schema(schema);
192
193    let mut signature = String::new();
194    signature.push_str("self");
195
196    for parameter in parameters.named {
197        signature.push_str(", ");
198        signature.push_str(parameter.name);
199
200        if let Some(default) = parameter.default {
201            signature.push('=');
202            signature.push_str(&format!("{default}"));
203        } else if !parameter.required {
204            signature.push_str("=None");
205        }
206    }
207
208    if parameters.additional {
209        signature.push_str(", **kwargs");
210    }
211
212    signature
213}
214
215fn parameters_from_schema(schema: &Schema) -> Parameters {
216    // schema = true means that any parameters are allowed
217    if schema.as_bool() == Some(true) {
218        return Parameters {
219            named: Vec::new(),
220            additional: true,
221        };
222    }
223
224    // schema = false means that no config is valid
225    // we approximate that by saying that no parameters are allowed
226    let Some(schema) = schema.as_object() else {
227        return Parameters {
228            named: Vec::new(),
229            additional: false,
230        };
231    };
232
233    let mut parameters = Vec::new();
234
235    let required = match schema.get("required") {
236        Some(Value::Array(required)) => &**required,
237        _ => &[],
238    };
239
240    // extract the top-level parameters
241    if let Some(Value::Object(properties)) = schema.get("properties") {
242        for (name, parameter) in properties {
243            parameters.push(Parameter::new(name, parameter, required));
244        }
245    }
246
247    let mut additional = false;
248
249    extend_parameters_from_one_of_schema(schema, &mut parameters, &mut additional);
250
251    // iterate over allOf to handle flattened enums
252    if let Some(Value::Array(all)) = schema.get("allOf") {
253        for variant in all {
254            if let Some(variant) = variant.as_object() {
255                extend_parameters_from_one_of_schema(variant, &mut parameters, &mut additional);
256            }
257        }
258    }
259
260    // sort parameters by name and so that required parameters come first
261    parameters.sort_by_key(|p| (!p.required, p.name));
262
263    additional = match (
264        schema.get("additionalProperties"),
265        schema.get("unevaluatedProperties"),
266    ) {
267        (Some(Value::Bool(false)), None) => additional,
268        (None | Some(Value::Bool(false)), Some(Value::Bool(false))) => false,
269        _ => true,
270    };
271
272    Parameters {
273        named: parameters,
274        additional,
275    }
276}
277
278fn extend_parameters_from_one_of_schema<'a>(
279    schema: &'a Map<String, Value>,
280    parameters: &mut Vec<Parameter<'a>>,
281    additional: &mut bool,
282) {
283    // iterate over oneOf to handle top-level or flattened enums
284    if let Some(Value::Array(variants)) = schema.get("oneOf") {
285        let mut variant_parameters = HashMap::new();
286
287        for (generation, schema) in variants.iter().enumerate() {
288            // if any variant allows additional parameters, the top-level also
289            //  allows additional parameters
290            #[expect(clippy::unnested_or_patterns)]
291            if let Some(schema) = schema.as_object() {
292                *additional |= !matches!(
293                    (
294                        schema.get("additionalProperties"),
295                        schema.get("unevaluatedProperties")
296                    ),
297                    (Some(Value::Bool(false)), None)
298                        | (None, Some(Value::Bool(false)))
299                        | (Some(Value::Bool(false)), Some(Value::Bool(false)))
300                );
301            }
302
303            let required = match schema.get("required") {
304                Some(Value::Array(required)) => &**required,
305                _ => &[],
306            };
307            let variant_docs = match schema.get("description") {
308                Some(Value::String(docs)) => Some(derust_doc_comment(docs)),
309                _ => None,
310            };
311
312            // extract the per-variant parameters and check for tag parameters
313            if let Some(Value::Object(properties)) = schema.get("properties") {
314                for (name, parameter) in properties {
315                    match variant_parameters.entry(name) {
316                        Entry::Vacant(entry) => {
317                            entry.insert(VariantParameter::new(
318                                generation,
319                                name,
320                                parameter,
321                                required,
322                                variant_docs.clone(),
323                            ));
324                        }
325                        Entry::Occupied(mut entry) => {
326                            entry.get_mut().merge(
327                                generation,
328                                name,
329                                parameter,
330                                required,
331                                variant_docs.clone(),
332                            );
333                        }
334                    }
335                }
336            }
337
338            // ensure that only parameters in all variants are required or tags
339            for parameter in variant_parameters.values_mut() {
340                parameter.update_generation(generation);
341            }
342        }
343
344        // merge the variant parameters into the top-level parameters
345        parameters.extend(
346            variant_parameters
347                .into_values()
348                .map(VariantParameter::into_parameter),
349        );
350    }
351}
352
/// Removes the one-space indentation that Rust doc comments leave at the
/// start of every continuation line.
///
/// The text is returned unchanged (and borrowed) if
/// - it has leading or trailing whitespace,
/// - any line after the first is non-blank but does not start with a space, or
/// - there is no `"\n "` sequence to strip, e.g. for single-line docs.
///
/// Otherwise, every `"\n "` is replaced with `"\n"` in a newly allocated
/// string.
fn derust_doc_comment(docs: &str) -> Cow<'_, str> {
    // only already-trimmed text is treated as a Rust doc comment
    if docs.trim() != docs {
        return Cow::Borrowed(docs);
    }

    // nothing to strip: avoid the allocation that `replace` would make
    if !docs.contains("\n ") {
        return Cow::Borrowed(docs);
    }

    let continuation_lines_are_indented = docs
        .split('\n')
        .skip(1)
        .all(|line| line.trim().is_empty() || line.starts_with(' '));

    if !continuation_lines_are_indented {
        return Cow::Borrowed(docs);
    }

    Cow::Owned(docs.replace("\n ", "\n"))
}
368
/// Errors that can occur while deriving a JSON schema from a codec class.
369#[derive(Debug, Error)]
370pub enum SchemaError {
    /// The schema cached on the codec class could not be deserialised.
371    #[error("codec class' cached config schema is invalid")]
372    InvalidCachedJsonSchema { source: PythonizeError },
    /// A Python call failed while inspecting the codec's signature; any
    /// `PyErr` converts into this variant.
373    #[error("extracting the codec signature failed")]
374    SignatureExtraction {
375        #[from]
376        source: PyErr,
377    },
    /// The codec's signature has an `*args` parameter.
378    #[error("codec's signature must not contain an `*args` parameter")]
379    ArgsParameterInSignature,
    /// The default value of the parameter `name` could not be converted.
380    #[error("{name} parameter's default value is invalid")]
381    InvalidParameterDefault {
382        name: String,
383        source: PythonizeError,
384    },
    /// The codec class's `__doc__` could not be extracted as a string.
385    #[error("codec class's `__doc__` must be a string")]
386    InvalidClassDocs { source: PyErr },
    /// The codec class's `__name__` could not be read as a string.
387    #[error("codec class must have a string `__name__`")]
388    InvalidClassName { source: PyErr },
389}
390
/// The parameters that were extracted from a config schema.
391struct Parameters<'a> {
    // the parameters that are known by name
392    named: Vec<Parameter<'a>>,
    // whether further parameters are accepted (rendered as `**kwargs`)
393    additional: bool,
394}
395
396struct Parameter<'a> {
397    name: &'a str,
398    required: bool,
399    default: Option<&'a Value>,
400    docs: Option<Cow<'a, str>>,
401}
402
403impl<'a> Parameter<'a> {
404    #[must_use]
405    pub fn new(name: &'a str, parameter: &'a Value, required: &[Value]) -> Self {
406        Self {
407            name,
408            required: required
409                .iter()
410                .any(|r| matches!(r, Value::String(n) if n == name)),
411            default: parameter.get("default"),
412            docs: match parameter.get("description") {
413                Some(Value::String(docs)) => Some(derust_doc_comment(docs)),
414                _ => None,
415            },
416        }
417    }
418}
419
420struct VariantParameter<'a> {
421    generation: usize,
422    parameter: Parameter<'a>,
423    #[expect(clippy::type_complexity)]
424    tag_docs: Option<Vec<(&'a Value, Option<Cow<'a, str>>)>>,
425}
426
427impl<'a> VariantParameter<'a> {
428    #[must_use]
429    pub fn new(
430        generation: usize,
431        name: &'a str,
432        parameter: &'a Value,
433        required: &[Value],
434        variant_docs: Option<Cow<'a, str>>,
435    ) -> Self {
436        let r#const = parameter.get("const");
437
438        let mut parameter = Parameter::new(name, parameter, required);
439        parameter.required &= generation == 0;
440
441        let tag_docs = match r#const {
442            // a tag parameter must be introduced in the first generation
443            Some(r#const) if generation == 0 => {
444                let docs = parameter.docs.take().or(variant_docs);
445                Some(vec![(r#const, docs)])
446            }
447            _ => None,
448        };
449
450        Self {
451            generation,
452            parameter,
453            tag_docs,
454        }
455    }
456
457    pub fn merge(
458        &mut self,
459        generation: usize,
460        name: &'a str,
461        parameter: &'a Value,
462        required: &[Value],
463        variant_docs: Option<Cow<'a, str>>,
464    ) {
465        self.generation = generation;
466
467        let r#const = parameter.get("const");
468
469        let parameter = Parameter::new(name, parameter, required);
470
471        self.parameter.required &= parameter.required;
472        if self.parameter.default != parameter.default {
473            self.parameter.default = None;
474        }
475
476        if let Some(tag_docs) = &mut self.tag_docs {
477            // we're building docs for a tag-like parameter
478            if let Some(r#const) = r#const {
479                tag_docs.push((r#const, parameter.docs.or(variant_docs)));
480            } else {
481                // mixing tag and non-tag parameter => no docs
482                self.tag_docs = None;
483                self.parameter.docs = None;
484            }
485        } else {
486            // we're building docs for a normal parameter
487            if r#const.is_none() {
488                // we only accept always matching docs for normal parameters
489                if self.parameter.docs != parameter.docs {
490                    self.parameter.docs = None;
491                }
492            } else {
493                // mixing tag and non-tag parameter => no docs
494                self.tag_docs = None;
495            }
496        }
497    }
498
499    pub fn update_generation(&mut self, generation: usize) {
500        if self.generation < generation {
501            // required and tag parameters must appear in all generations
502            self.parameter.required = false;
503            self.tag_docs = None;
504        }
505    }
506
507    #[must_use]
508    pub fn into_parameter(mut self) -> Parameter<'a> {
509        if let Some(tag_docs) = self.tag_docs {
510            let mut docs = String::new();
511
512            for (tag, tag_docs) in tag_docs {
513                docs.push_str(" - ");
514                docs.push_str(&format!("{tag}"));
515                if let Some(tag_docs) = tag_docs {
516                    docs.push_str(": ");
517                    docs.push_str(&tag_docs.replace('\n', "\n    "));
518                }
519                docs.push_str("\n\n");
520            }
521
522            docs.truncate(docs.trim_end().len());
523
524            self.parameter.docs = Some(Cow::Owned(docs));
525        }
526
527        self.parameter
528    }
529}
530
// Tests for the schema helpers, based on the schema that `schemars` derives
// for the example `MyCodec` type defined at the end of this module.
531#[cfg(test)]
532mod tests {
533    use schemars::{schema_for, JsonSchema};
534
535    use super::*;
536
    // Pins the JSON schema that is generated for `MyCodec`, which the other
    // two tests use as their input.
537    #[test]
538    fn schema() {
539        assert_eq!(
540            format!("{}", schema_for!(MyCodec).to_value()),
541            r#"{"type":"object","properties":{"param":{"type":["integer","null"],"format":"int32","description":"An optional integer value."}},"unevaluatedProperties":false,"oneOf":[{"type":"object","description":"Mode a.\n\n It gets another line.","properties":{"value":{"type":"boolean","description":"A boolean value. And some really, really, really, long first\n line that wraps around.\n\n With multiple lines of comments."},"common":{"type":"string","description":"A common string value.\n\n Something else here."},"mode":{"type":"string","const":"A"}},"required":["mode","value","common"]},{"type":"object","description":"Mode b.","properties":{"common":{"type":"string","description":"A common string value.\n\n Something else here."},"mode":{"type":"string","const":"B"}},"required":["mode","common"]}],"description":"A codec that does something on encoding and decoding.\n\n With multiple lines of comments.","title":"MyCodec","$schema":"https://json-schema.org/draft/2020-12/schema"}"#
542        );
543    }
544
    // Checks the documentation rendered by `docs_from_schema`, including the
    // merged `mode` tag parameter and the optional parameters.
545    #[test]
546    fn docs() {
547        assert_eq!(
548            docs_from_schema(&schema_for!(MyCodec)).as_deref(),
549            Some(
550                r#"A codec that does something on encoding and decoding.
551
552With multiple lines of comments.
553
554Parameters
555----------
556common : ...
557    A common string value.
558    
559    Something else here.
560mode : ...
561     - "A": Mode a.
562        
563        It gets another line.
564    
565     - "B": Mode b.
566param : ..., optional
567    An optional integer value.
568value : ..., optional
569    A boolean value. And some really, really, really, long first
570    line that wraps around.
571    
572    With multiple lines of comments."#
573            )
574        );
575    }
576
    // Checks the signature rendered by `signature_from_schema`: required
    // parameters come first, optional ones get a `=None` default.
577    #[test]
578    fn signature() {
579        assert_eq!(
580            signature_from_schema(&schema_for!(MyCodec)),
581            "self, common, mode, param=None, value=None",
582        );
583    }
584
    // Example codec config: one optional field plus a flattened tagged enum.
585    #[expect(dead_code)]
586    #[derive(JsonSchema)]
587    #[schemars(deny_unknown_fields)]
588    /// A codec that does something on encoding and decoding.
589    ///
590    /// With multiple lines of comments.
591    struct MyCodec {
592        /// An optional integer value.
593        #[schemars(default, skip_serializing_if = "Option::is_none")]
594        param: Option<i32>,
595        /// The flattened configuration.
596        #[schemars(flatten)]
597        config: Config,
598    }
599
    // Internally tagged enum (tag = "mode"); `common` appears in both
    // variants, `value` only in variant `A`.
600    #[expect(dead_code)]
601    #[derive(JsonSchema)]
602    #[schemars(tag = "mode")]
603    #[schemars(deny_unknown_fields)]
604    enum Config {
605        /// Mode a.
606        ///
607        /// It gets another line.
608        A {
609            /// A boolean value. And some really, really, really, long first
610            /// line that wraps around.
611            ///
612            /// With multiple lines of comments.
613            value: bool,
614            /// A common string value.
615            ///
616            /// Something else here.
617            common: String,
618        },
619        /// Mode b.
620        B {
621            /// A common string value.
622            ///
623            /// Something else here.
624            common: String,
625        },
626    }
627}