numcodecs_python/
schema.rs

1use std::{
2    borrow::Cow,
3    collections::{hash_map::Entry, HashMap},
4};
5
6use pyo3::{intern, prelude::*, sync::GILOnceCell};
7use pythonize::{depythonize, PythonizeError};
8use schemars::Schema;
9use serde_json::{Map, Value};
10use thiserror::Error;
11
12use crate::{export::RustCodec, PyCodecClass};
13
/// Looks up a Python object by module name and attribute path and caches it.
///
/// `once!(py, "module", "a", "b")` expands to a call that imports `module`,
/// reads the attribute chain `.a.b`, and stores the result in a `static`
/// [`GILOnceCell`] that is private to this expansion site. The expression
/// evaluates to a `Result<&Bound<PyAny>, PyErr>`.
macro_rules! once {
    ($py:ident, $module:literal $(, $path:literal)*) => {{
        fn lookup(py: Python) -> Result<&Bound<PyAny>, PyErr> {
            static CACHE: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

            let cached = CACHE.get_or_try_init(py, || -> Result<Py<PyAny>, PyErr> {
                let object = py.import(intern!(py, $module))?;
                $(let object = object.getattr(intern!(py, $path))?;)*
                Ok(object.unbind())
            })?;

            Ok(cached.bind(py))
        }

        lookup($py)
    }};
}
29
30pub fn schema_from_codec_class(
31    py: Python,
32    class: &Bound<PyCodecClass>,
33) -> Result<Schema, SchemaError> {
34    if let Ok(schema) = class.getattr(intern!(py, RustCodec::SCHEMA_ATTRIBUTE)) {
35        return depythonize(&schema)
36            .map_err(|err| SchemaError::InvalidCachedJsonSchema { source: err });
37    }
38
39    let mut schema = Schema::default();
40
41    {
42        let schema = schema.ensure_object();
43
44        schema.insert(String::from("type"), Value::String(String::from("object")));
45
46        if let Ok(init) = class.getattr(intern!(py, "__init__")) {
47            let mut properties = Map::new();
48            let mut additional_properties = false;
49            let mut required = Vec::new();
50
51            let object_init = once!(py, "builtins", "object", "__init__")?;
52            let signature = once!(py, "inspect", "signature")?;
53            let empty_parameter = once!(py, "inspect", "Parameter", "empty")?;
54            let args_parameter = once!(py, "inspect", "Parameter", "VAR_POSITIONAL")?;
55            let kwargs_parameter = once!(py, "inspect", "Parameter", "VAR_KEYWORD")?;
56
57            for (i, param) in signature
58                .call1((&init,))?
59                .getattr(intern!(py, "parameters"))?
60                .call_method0(intern!(py, "items"))?
61                .try_iter()?
62                .enumerate()
63            {
64                let (name, param): (String, Bound<PyAny>) = param?.extract()?;
65
66                if i == 0 && name == "self" {
67                    continue;
68                }
69
70                let kind = param.getattr(intern!(py, "kind"))?;
71
72                if kind.eq(args_parameter)? && !init.eq(object_init)? {
73                    return Err(SchemaError::ArgsParameterInSignature);
74                }
75
76                if kind.eq(kwargs_parameter)? {
77                    additional_properties = true;
78                } else {
79                    let default = param.getattr(intern!(py, "default"))?;
80
81                    let mut parameter = Map::new();
82
83                    if default.eq(empty_parameter)? {
84                        required.push(Value::String(name.clone()));
85                    } else {
86                        let default = depythonize(&default).map_err(|err| {
87                            SchemaError::InvalidParameterDefault {
88                                name: name.clone(),
89                                source: err,
90                            }
91                        })?;
92                        parameter.insert(String::from("default"), default);
93                    }
94
95                    properties.insert(name, Value::Object(parameter));
96                }
97            }
98
99            schema.insert(
100                String::from("additionalProperties"),
101                Value::Bool(additional_properties),
102            );
103            schema.insert(String::from("properties"), Value::Object(properties));
104            schema.insert(String::from("required"), Value::Array(required));
105        } else {
106            schema.insert(String::from("additionalProperties"), Value::Bool(true));
107        }
108
109        if let Ok(doc) = class.getattr(intern!(py, "__doc__")) {
110            if !doc.is_none() {
111                let doc: String = doc
112                    .extract()
113                    .map_err(|err| SchemaError::InvalidClassDocs { source: err })?;
114                schema.insert(String::from("description"), Value::String(doc));
115            }
116        }
117
118        let name = class
119            .getattr(intern!(py, "__name__"))
120            .and_then(|name| name.extract())
121            .map_err(|err| SchemaError::InvalidClassName { source: err })?;
122        schema.insert(String::from("title"), Value::String(name));
123
124        schema.insert(
125            String::from("$schema"),
126            Value::String(String::from("https://json-schema.org/draft/2020-12/schema")),
127        );
128    }
129
130    Ok(schema)
131}
132
133pub fn docs_from_schema(schema: &Schema) -> Option<String> {
134    let parameters = parameters_from_schema(schema);
135    let schema = schema.as_object()?;
136
137    let mut docs = String::new();
138
139    if let Some(Value::String(description)) = schema.get("description") {
140        docs.push_str(&derust_doc_comment(description));
141        docs.push_str("\n\n");
142    }
143
144    if !parameters.named.is_empty() || parameters.additional {
145        docs.push_str("Parameters\n----------\n");
146    }
147
148    for parameter in &parameters.named {
149        docs.push_str(parameter.name);
150
151        docs.push_str(" : ...");
152
153        if !parameter.required {
154            docs.push_str(", optional");
155        }
156
157        #[expect(clippy::format_push_string)] // FIXME
158        if let Some(default) = parameter.default {
159            docs.push_str(", default = ");
160            docs.push_str(&format!("{default}"));
161        }
162
163        docs.push('\n');
164
165        if let Some(info) = &parameter.docs {
166            docs.push_str("    ");
167            docs.push_str(&info.replace('\n', "\n    "));
168        }
169
170        docs.push('\n');
171    }
172
173    if parameters.additional {
174        docs.push_str("**kwargs\n");
175        docs.push_str("    ");
176
177        if parameters.named.is_empty() {
178            docs.push_str("This codec takes *any* parameters.");
179        } else {
180            docs.push_str("This codec takes *any* additional parameters.");
181        }
182    } else if parameters.named.is_empty() {
183        docs.push_str("This codec does *not* take any parameters.");
184    }
185
186    docs.truncate(docs.trim_end().len());
187
188    Some(docs)
189}
190
191pub fn signature_from_schema(schema: &Schema) -> String {
192    let parameters = parameters_from_schema(schema);
193
194    let mut signature = String::new();
195    signature.push_str("self");
196
197    for parameter in parameters.named {
198        signature.push_str(", ");
199        signature.push_str(parameter.name);
200
201        #[expect(clippy::format_push_string)] // FIXME
202        if let Some(default) = parameter.default {
203            signature.push('=');
204            signature.push_str(&format!("{default}"));
205        } else if !parameter.required {
206            signature.push_str("=None");
207        }
208    }
209
210    if parameters.additional {
211        signature.push_str(", **kwargs");
212    }
213
214    signature
215}
216
217fn parameters_from_schema(schema: &Schema) -> Parameters {
218    // schema = true means that any parameters are allowed
219    if schema.as_bool() == Some(true) {
220        return Parameters {
221            named: Vec::new(),
222            additional: true,
223        };
224    }
225
226    // schema = false means that no config is valid
227    // we approximate that by saying that no parameters are allowed
228    let Some(schema) = schema.as_object() else {
229        return Parameters {
230            named: Vec::new(),
231            additional: false,
232        };
233    };
234
235    let mut parameters = Vec::new();
236
237    let required = match schema.get("required") {
238        Some(Value::Array(required)) => &**required,
239        _ => &[],
240    };
241
242    // extract the top-level parameters
243    if let Some(Value::Object(properties)) = schema.get("properties") {
244        for (name, parameter) in properties {
245            parameters.push(Parameter::new(name, parameter, required));
246        }
247    }
248
249    let mut additional = false;
250
251    extend_parameters_from_one_of_schema(schema, &mut parameters, &mut additional);
252
253    // iterate over allOf to handle flattened enums
254    if let Some(Value::Array(all)) = schema.get("allOf") {
255        for variant in all {
256            if let Some(variant) = variant.as_object() {
257                extend_parameters_from_one_of_schema(variant, &mut parameters, &mut additional);
258            }
259        }
260    }
261
262    // sort parameters by name and so that required parameters come first
263    parameters.sort_by_key(|p| (!p.required, p.name));
264
265    additional = match (
266        schema.get("additionalProperties"),
267        schema.get("unevaluatedProperties"),
268    ) {
269        (Some(Value::Bool(false)), None) => additional,
270        (None | Some(Value::Bool(false)), Some(Value::Bool(false))) => false,
271        _ => true,
272    };
273
274    Parameters {
275        named: parameters,
276        additional,
277    }
278}
279
280fn extend_parameters_from_one_of_schema<'a>(
281    schema: &'a Map<String, Value>,
282    parameters: &mut Vec<Parameter<'a>>,
283    additional: &mut bool,
284) {
285    // iterate over oneOf to handle top-level or flattened enums
286    if let Some(Value::Array(variants)) = schema.get("oneOf") {
287        let mut variant_parameters = HashMap::new();
288
289        for (generation, schema) in variants.iter().enumerate() {
290            // if any variant allows additional parameters, the top-level also
291            //  allows additional parameters
292            #[expect(clippy::unnested_or_patterns)]
293            if let Some(schema) = schema.as_object() {
294                *additional |= !matches!(
295                    (
296                        schema.get("additionalProperties"),
297                        schema.get("unevaluatedProperties")
298                    ),
299                    (Some(Value::Bool(false)), None)
300                        | (None, Some(Value::Bool(false)))
301                        | (Some(Value::Bool(false)), Some(Value::Bool(false)))
302                );
303            }
304
305            let required = match schema.get("required") {
306                Some(Value::Array(required)) => &**required,
307                _ => &[],
308            };
309            let variant_docs = match schema.get("description") {
310                Some(Value::String(docs)) => Some(derust_doc_comment(docs)),
311                _ => None,
312            };
313
314            // extract the per-variant parameters and check for tag parameters
315            if let Some(Value::Object(properties)) = schema.get("properties") {
316                for (name, parameter) in properties {
317                    match variant_parameters.entry(name) {
318                        Entry::Vacant(entry) => {
319                            entry.insert(VariantParameter::new(
320                                generation,
321                                name,
322                                parameter,
323                                required,
324                                variant_docs.clone(),
325                            ));
326                        }
327                        Entry::Occupied(mut entry) => {
328                            entry.get_mut().merge(
329                                generation,
330                                name,
331                                parameter,
332                                required,
333                                variant_docs.clone(),
334                            );
335                        }
336                    }
337                }
338            }
339
340            // ensure that only parameters in all variants are required or tags
341            for parameter in variant_parameters.values_mut() {
342                parameter.update_generation(generation);
343            }
344        }
345
346        // merge the variant parameters into the top-level parameters
347        parameters.extend(
348            variant_parameters
349                .into_values()
350                .map(VariantParameter::into_parameter),
351        );
352    }
353}
354
/// Strips the single leading space that follows every line break in `docs`.
///
/// The text is returned unchanged (and borrowed) if
/// - it has leading or trailing whitespace,
/// - it consists of a single line, so there is nothing to strip, or
/// - any line after the first is non-blank and does not start with a space.
///
/// Only when a rewrite is actually needed is a new string allocated.
fn derust_doc_comment(docs: &str) -> Cow<'_, str> {
    if docs.trim() != docs {
        return Cow::Borrowed(docs);
    }

    // single-line docs have no continuation lines to de-indent
    if !docs.contains('\n') {
        return Cow::Borrowed(docs);
    }

    let continuation_is_indented = docs
        .split('\n')
        .skip(1)
        .all(|line| line.trim().is_empty() || line.starts_with(' '));

    if !continuation_is_indented {
        return Cow::Borrowed(docs);
    }

    Cow::Owned(docs.replace("\n ", "\n"))
}
370
371#[derive(Debug, Error)]
372pub enum SchemaError {
373    #[error("codec class' cached config schema is invalid")]
374    InvalidCachedJsonSchema { source: PythonizeError },
375    #[error("extracting the codec signature failed")]
376    SignatureExtraction {
377        #[from]
378        source: PyErr,
379    },
380    #[error("codec's signature must not contain an `*args` parameter")]
381    ArgsParameterInSignature,
382    #[error("{name} parameter's default value is invalid")]
383    InvalidParameterDefault {
384        name: String,
385        source: PythonizeError,
386    },
387    #[error("codec class's `__doc__` must be a string")]
388    InvalidClassDocs { source: PyErr },
389    #[error("codec class must have a string `__name__`")]
390    InvalidClassName { source: PyErr },
391}
392
393struct Parameters<'a> {
394    named: Vec<Parameter<'a>>,
395    additional: bool,
396}
397
398struct Parameter<'a> {
399    name: &'a str,
400    required: bool,
401    default: Option<&'a Value>,
402    docs: Option<Cow<'a, str>>,
403}
404
405impl<'a> Parameter<'a> {
406    #[must_use]
407    pub fn new(name: &'a str, parameter: &'a Value, required: &[Value]) -> Self {
408        Self {
409            name,
410            required: required
411                .iter()
412                .any(|r| matches!(r, Value::String(n) if n == name)),
413            default: parameter.get("default"),
414            docs: match parameter.get("description") {
415                Some(Value::String(docs)) => Some(derust_doc_comment(docs)),
416                _ => None,
417            },
418        }
419    }
420}
421
422struct VariantParameter<'a> {
423    generation: usize,
424    parameter: Parameter<'a>,
425    #[expect(clippy::type_complexity)]
426    tag_docs: Option<Vec<(&'a Value, Option<Cow<'a, str>>)>>,
427}
428
429impl<'a> VariantParameter<'a> {
430    #[must_use]
431    pub fn new(
432        generation: usize,
433        name: &'a str,
434        parameter: &'a Value,
435        required: &[Value],
436        variant_docs: Option<Cow<'a, str>>,
437    ) -> Self {
438        let r#const = parameter.get("const");
439
440        let mut parameter = Parameter::new(name, parameter, required);
441        parameter.required &= generation == 0;
442
443        let tag_docs = match r#const {
444            // a tag parameter must be introduced in the first generation
445            Some(r#const) if generation == 0 => {
446                let docs = parameter.docs.take().or(variant_docs);
447                Some(vec![(r#const, docs)])
448            }
449            _ => None,
450        };
451
452        Self {
453            generation,
454            parameter,
455            tag_docs,
456        }
457    }
458
459    pub fn merge(
460        &mut self,
461        generation: usize,
462        name: &'a str,
463        parameter: &'a Value,
464        required: &[Value],
465        variant_docs: Option<Cow<'a, str>>,
466    ) {
467        self.generation = generation;
468
469        let r#const = parameter.get("const");
470
471        let parameter = Parameter::new(name, parameter, required);
472
473        self.parameter.required &= parameter.required;
474        if self.parameter.default != parameter.default {
475            self.parameter.default = None;
476        }
477
478        if let Some(tag_docs) = &mut self.tag_docs {
479            // we're building docs for a tag-like parameter
480            if let Some(r#const) = r#const {
481                tag_docs.push((r#const, parameter.docs.or(variant_docs)));
482            } else {
483                // mixing tag and non-tag parameter => no docs
484                self.tag_docs = None;
485                self.parameter.docs = None;
486            }
487        } else {
488            // we're building docs for a normal parameter
489            if r#const.is_none() {
490                // we only accept always matching docs for normal parameters
491                if self.parameter.docs != parameter.docs {
492                    self.parameter.docs = None;
493                }
494            } else {
495                // mixing tag and non-tag parameter => no docs
496                self.tag_docs = None;
497            }
498        }
499    }
500
501    pub fn update_generation(&mut self, generation: usize) {
502        if self.generation < generation {
503            // required and tag parameters must appear in all generations
504            self.parameter.required = false;
505            self.tag_docs = None;
506        }
507    }
508
509    #[must_use]
510    pub fn into_parameter(mut self) -> Parameter<'a> {
511        if let Some(tag_docs) = self.tag_docs {
512            let mut docs = String::new();
513
514            #[expect(clippy::format_push_string)] // FIXME
515            for (tag, tag_docs) in tag_docs {
516                docs.push_str(" - ");
517                docs.push_str(&format!("{tag}"));
518                if let Some(tag_docs) = tag_docs {
519                    docs.push_str(": ");
520                    docs.push_str(&tag_docs.replace('\n', "\n    "));
521                }
522                docs.push_str("\n\n");
523            }
524
525            docs.truncate(docs.trim_end().len());
526
527            self.parameter.docs = Some(Cow::Owned(docs));
528        }
529
530        self.parameter
531    }
532}
533
#[cfg(test)]
mod tests {
    use schemars::{schema_for, JsonSchema};

    use super::*;

    /// Pins the schema that schemars generates for the `MyCodec` fixture.
    #[test]
    fn schema() {
        let rendered = schema_for!(MyCodec).to_value().to_string();

        assert_eq!(
            rendered,
            r#"{"type":"object","properties":{"param":{"type":["integer","null"],"format":"int32","description":"An optional integer value."}},"unevaluatedProperties":false,"oneOf":[{"type":"object","description":"Mode a.\n\n It gets another line.","properties":{"value":{"type":"boolean","description":"A boolean value. And some really, really, really, long first\n line that wraps around.\n\n With multiple lines of comments."},"common":{"type":"string","description":"A common string value.\n\n Something else here."},"mode":{"type":"string","const":"A"}},"required":["mode","value","common"]},{"type":"object","description":"Mode b.","properties":{"common":{"type":"string","description":"A common string value.\n\n Something else here."},"mode":{"type":"string","const":"B"}},"required":["mode","common"]}],"description":"A codec that does something on encoding and decoding.\n\n With multiple lines of comments.","title":"MyCodec","$schema":"https://json-schema.org/draft/2020-12/schema"}"#
        );
    }

    /// Pins the docstring rendered for the `MyCodec` fixture.
    #[test]
    fn docs() {
        // spelled out line by line so that the whitespace-only lines, which
        // the renderer emits for blank doc lines, stay visible
        let expected = concat!(
            "A codec that does something on encoding and decoding.\n",
            "\n",
            "With multiple lines of comments.\n",
            "\n",
            "Parameters\n",
            "----------\n",
            "common : ...\n",
            "    A common string value.\n",
            "    \n",
            "    Something else here.\n",
            "mode : ...\n",
            "     - \"A\": Mode a.\n",
            "        \n",
            "        It gets another line.\n",
            "    \n",
            "     - \"B\": Mode b.\n",
            "param : ..., optional\n",
            "    An optional integer value.\n",
            "value : ..., optional\n",
            "    A boolean value. And some really, really, really, long first\n",
            "    line that wraps around.\n",
            "    \n",
            "    With multiple lines of comments.",
        );

        let rendered = docs_from_schema(&schema_for!(MyCodec));

        assert_eq!(rendered.as_deref(), Some(expected));
    }

    /// Pins the signature rendered for the `MyCodec` fixture.
    #[test]
    fn signature() {
        let rendered = signature_from_schema(&schema_for!(MyCodec));

        assert_eq!(rendered, "self, common, mode, param=None, value=None");
    }

    #[expect(dead_code)]
    #[derive(JsonSchema)]
    #[schemars(deny_unknown_fields)]
    /// A codec that does something on encoding and decoding.
    ///
    /// With multiple lines of comments.
    struct MyCodec {
        /// An optional integer value.
        #[schemars(default, skip_serializing_if = "Option::is_none")]
        param: Option<i32>,
        /// The flattened configuration.
        #[schemars(flatten)]
        config: Config,
    }

    #[expect(dead_code)]
    #[derive(JsonSchema)]
    #[schemars(tag = "mode")]
    #[schemars(deny_unknown_fields)]
    enum Config {
        /// Mode a.
        ///
        /// It gets another line.
        A {
            /// A boolean value. And some really, really, really, long first
            /// line that wraps around.
            ///
            /// With multiple lines of comments.
            value: bool,
            /// A common string value.
            ///
            /// Something else here.
            common: String,
        },
        /// Mode b.
        B {
            /// A common string value.
            ///
            /// Something else here.
            common: String,
        },
    }
}