1use std::{
2 borrow::Cow,
3 collections::{HashMap, hash_map::Entry},
4 io,
5 num::FpCategory,
6};
7
8use pyo3::{intern, prelude::*, sync::PyOnceLock};
9use pythonize::{PythonizeError, depythonize};
10use schemars::Schema;
11use serde_json::{Map, Value};
12use thiserror::Error;
13
14use crate::{PyCodecClass, export::RustCodec};
15
/// Looks up a Python object by module and attribute path exactly once.
///
/// `once!(py, "module", "a", "b")` expands to an expression of type
/// `Result<&Bound<'_, PyAny>, PyErr>` that imports `module`, walks the
/// attributes `a` and then `b`, and stores the resulting object in a
/// `PyOnceLock` that is private to this macro invocation, so later
/// evaluations of the same call site reuse the stored object.
macro_rules! once {
    ($py:ident, $module:literal $(, $path:literal)*) => {{
        fn cached(py: Python<'_>) -> Result<&Bound<'_, PyAny>, PyErr> {
            // one cell per macro expansion, as every expansion defines its own `cached`
            static CELL: PyOnceLock<Py<PyAny>> = PyOnceLock::new();

            let object = CELL.get_or_try_init(py, || -> Result<Py<PyAny>, PyErr> {
                let resolved = py.import(intern!(py, $module))?;
                $(let resolved = resolved.getattr(intern!(py, $path))?;)*
                Ok(resolved.unbind())
            })?;

            Ok(object.bind(py))
        }

        cached($py)
    }};
}
31
32pub fn schema_from_codec_class(
33 py: Python,
34 class: &Bound<PyCodecClass>,
35) -> Result<Schema, SchemaError> {
36 if let Ok(schema) = class.getattr(intern!(py, RustCodec::SCHEMA_ATTRIBUTE)) {
37 return depythonize(&schema)
38 .map_err(|err| SchemaError::InvalidCachedJsonSchema { source: err });
39 }
40
41 let mut schema = Schema::default();
42
43 {
44 let schema = schema.ensure_object();
45
46 schema.insert(String::from("type"), Value::String(String::from("object")));
47
48 if let Ok(init) = class.getattr(intern!(py, "__init__")) {
49 let mut properties = Map::new();
50 let mut additional_properties = false;
51 let mut required = Vec::new();
52
53 let object_init = once!(py, "builtins", "object", "__init__")?;
54 let signature = once!(py, "inspect", "signature")?;
55 let empty_parameter = once!(py, "inspect", "Parameter", "empty")?;
56 let args_parameter = once!(py, "inspect", "Parameter", "VAR_POSITIONAL")?;
57 let kwargs_parameter = once!(py, "inspect", "Parameter", "VAR_KEYWORD")?;
58
59 for (i, param) in signature
60 .call1((&init,))?
61 .getattr(intern!(py, "parameters"))?
62 .call_method0(intern!(py, "items"))?
63 .try_iter()?
64 .enumerate()
65 {
66 let (name, param): (String, Bound<PyAny>) = param?.extract()?;
67
68 if i == 0 && name == "self" {
69 continue;
70 }
71
72 let kind = param.getattr(intern!(py, "kind"))?;
73
74 if kind.eq(args_parameter)? && !init.eq(object_init)? {
75 return Err(SchemaError::ArgsParameterInSignature);
76 }
77
78 if kind.eq(kwargs_parameter)? {
79 additional_properties = true;
80 } else {
81 let default = param.getattr(intern!(py, "default"))?;
82
83 let mut parameter = Map::new();
84
85 if default.eq(empty_parameter)? {
86 required.push(Value::String(name.clone()));
87 } else {
88 let default = depythonize(&default).map_err(|err| {
89 SchemaError::InvalidParameterDefault {
90 name: name.clone(),
91 source: err,
92 }
93 })?;
94 parameter.insert(String::from("default"), default);
95 }
96
97 properties.insert(name, Value::Object(parameter));
98 }
99 }
100
101 schema.insert(
102 String::from("additionalProperties"),
103 Value::Bool(additional_properties),
104 );
105 schema.insert(String::from("properties"), Value::Object(properties));
106 schema.insert(String::from("required"), Value::Array(required));
107 } else {
108 schema.insert(String::from("additionalProperties"), Value::Bool(true));
109 }
110
111 if let Ok(doc) = class.getattr(intern!(py, "__doc__")) {
112 if !doc.is_none() {
113 let doc: String = doc
114 .extract()
115 .map_err(|err| SchemaError::InvalidClassDocs { source: err })?;
116 schema.insert(String::from("description"), Value::String(doc));
117 }
118 }
119
120 let name = class
121 .getattr(intern!(py, "__name__"))
122 .and_then(|name| name.extract())
123 .map_err(|err| SchemaError::InvalidClassName { source: err })?;
124 schema.insert(String::from("title"), Value::String(name));
125
126 schema.insert(
127 String::from("$schema"),
128 Value::String(String::from("https://json-schema.org/draft/2020-12/schema")),
129 );
130 }
131
132 Ok(schema)
133}
134
135pub fn docs_from_schema(schema: &Schema) -> Option<String> {
136 let parameters = parameters_from_schema(schema);
137 let schema = schema.as_object()?;
138
139 let mut docs = String::new();
140
141 if let Some(Value::String(description)) = schema.get("description") {
142 docs.push_str(description);
143 docs.push_str("\n\n");
144 }
145
146 if !parameters.named.is_empty() || parameters.additional {
147 docs.push_str("Parameters\n----------\n");
148 }
149
150 for parameter in ¶meters.named {
151 docs.push_str(parameter.name);
152
153 docs.push_str(" : ...");
154
155 if !parameter.required {
156 docs.push_str(", optional");
157 }
158
159 if let Some(default) = parameter.default {
160 docs.push_str(", default = ");
161 JsonToPythonFormatter::push_to_string(&mut docs, default);
162 }
163
164 docs.push('\n');
165
166 if let Some(info) = ¶meter.docs {
167 docs.push_str(" ");
168 docs.push_str(&info.replace('\n', "\n "));
169 }
170
171 docs.push('\n');
172 }
173
174 if parameters.additional {
175 docs.push_str("**kwargs\n");
176 docs.push_str(" ");
177
178 if parameters.named.is_empty() {
179 docs.push_str("This codec takes *any* parameters.");
180 } else {
181 docs.push_str("This codec takes *any* additional parameters.");
182 }
183 } else if parameters.named.is_empty() {
184 docs.push_str("This codec does *not* take any parameters.");
185 }
186
187 docs.truncate(docs.trim_end().len());
188
189 Some(docs)
190}
191
192pub fn signature_from_schema(schema: &Schema) -> String {
193 let parameters = parameters_from_schema(schema);
194
195 let mut signature = String::new();
196 signature.push_str("self");
197
198 for parameter in parameters.named {
199 signature.push_str(", ");
200 signature.push_str(parameter.name);
201
202 if let Some(default) = parameter.default {
203 signature.push('=');
204 JsonToPythonFormatter::push_to_string(&mut signature, default);
205 } else if !parameter.required {
206 signature.push_str("=None");
207 }
208 }
209
210 if parameters.additional {
211 signature.push_str(", **kwargs");
212 }
213
214 signature
215}
216
217fn parameters_from_schema(schema: &Schema) -> Parameters<'_> {
218 if schema.as_bool() == Some(true) {
220 return Parameters {
221 named: Vec::new(),
222 additional: true,
223 };
224 }
225
226 let Some(schema) = schema.as_object() else {
229 return Parameters {
230 named: Vec::new(),
231 additional: false,
232 };
233 };
234
235 let mut parameters = Vec::new();
236
237 let required = match schema.get("required") {
238 Some(Value::Array(required)) => &**required,
239 _ => &[],
240 };
241
242 if let Some(Value::Object(properties)) = schema.get("properties") {
244 for (name, parameter) in properties {
245 parameters.push(Parameter::new(name, parameter, required));
246 }
247 }
248
249 let mut additional = false;
250
251 extend_parameters_from_one_of_schema(schema, &mut parameters, &mut additional);
252
253 if let Some(Value::Array(all)) = schema.get("allOf") {
255 for variant in all {
256 if let Some(variant) = variant.as_object() {
257 extend_parameters_from_one_of_schema(variant, &mut parameters, &mut additional);
258 }
259 }
260 }
261
262 parameters.sort_by_key(|p| (!p.required, p.name));
264
265 additional = match (
266 schema.get("additionalProperties"),
267 schema.get("unevaluatedProperties"),
268 ) {
269 (Some(Value::Bool(false)), None) => additional,
270 (None | Some(Value::Bool(false)), Some(Value::Bool(false))) => false,
271 _ => true,
272 };
273
274 Parameters {
275 named: parameters,
276 additional,
277 }
278}
279
280fn extend_parameters_from_one_of_schema<'a>(
281 schema: &'a Map<String, Value>,
282 parameters: &mut Vec<Parameter<'a>>,
283 additional: &mut bool,
284) {
285 if let Some(Value::Array(variants)) = schema.get("oneOf") {
287 let mut variant_parameters = HashMap::new();
288
289 for (generation, schema) in variants.iter().enumerate() {
290 #[expect(clippy::unnested_or_patterns)]
293 if let Some(schema) = schema.as_object() {
294 *additional |= !matches!(
295 (
296 schema.get("additionalProperties"),
297 schema.get("unevaluatedProperties")
298 ),
299 (Some(Value::Bool(false)), None)
300 | (None, Some(Value::Bool(false)))
301 | (Some(Value::Bool(false)), Some(Value::Bool(false)))
302 );
303 }
304
305 let required = match schema.get("required") {
306 Some(Value::Array(required)) => &**required,
307 _ => &[],
308 };
309 let variant_docs = match schema.get("description") {
310 Some(Value::String(docs)) => Some(docs),
311 _ => None,
312 };
313
314 if let Some(Value::Object(properties)) = schema.get("properties") {
316 for (name, parameter) in properties {
317 match variant_parameters.entry(name) {
318 Entry::Vacant(entry) => {
319 entry.insert(VariantParameter::new(
320 generation,
321 name,
322 parameter,
323 required,
324 variant_docs.map(|x| Cow::Borrowed(x.as_str())),
325 ));
326 }
327 Entry::Occupied(mut entry) => {
328 entry.get_mut().merge(
329 generation,
330 name,
331 parameter,
332 required,
333 variant_docs.map(|x| Cow::Borrowed(x.as_str())),
334 );
335 }
336 }
337 }
338 }
339
340 for parameter in variant_parameters.values_mut() {
342 parameter.update_generation(generation);
343 }
344 }
345
346 parameters.extend(
348 variant_parameters
349 .into_values()
350 .map(VariantParameter::into_parameter),
351 );
352 }
353}
354
/// Errors that can occur while deriving a JSON schema from a Python codec
/// class with [`schema_from_codec_class`].
#[derive(Debug, Error)]
pub enum SchemaError {
    /// The schema cached on the codec class could not be depythonized into
    /// a [`Schema`].
    #[error("codec class' cached config schema is invalid")]
    InvalidCachedJsonSchema { source: PythonizeError },
    /// A Python call made while inspecting the codec's signature raised.
    #[error("extracting the codec signature failed")]
    SignatureExtraction {
        #[from]
        source: PyErr,
    },
    /// The codec's `__init__` has an `*args` parameter.
    #[error("codec's signature must not contain an `*args` parameter")]
    ArgsParameterInSignature,
    /// The default value of the parameter called `name` could not be
    /// depythonized into a JSON value.
    #[error("{name} parameter's default value is invalid")]
    InvalidParameterDefault {
        name: String,
        source: PythonizeError,
    },
    /// The codec class's `__doc__` is neither `None` nor a string.
    #[error("codec class's `__doc__` must be a string")]
    InvalidClassDocs { source: PyErr },
    /// The codec class's `__name__` is missing or not a string.
    #[error("codec class must have a string `__name__`")]
    InvalidClassName { source: PyErr },
}
376
/// The parameters of a codec, as extracted from its configuration schema.
struct Parameters<'a> {
    /// The explicitly named parameters.
    named: Vec<Parameter<'a>>,
    /// Whether parameters beyond the named ones are accepted.
    additional: bool,
}
381
/// A single named codec parameter, borrowing from the schema it was read
/// from.
struct Parameter<'a> {
    /// The parameter's name, i.e. its property key in the schema.
    name: &'a str,
    /// Whether the parameter's name is listed in the schema's `required`.
    required: bool,
    /// The parameter's `default` value in the schema, if it has one.
    default: Option<&'a Value>,
    /// The parameter's documentation, if it has any.
    docs: Option<Cow<'a, str>>,
}
388
389impl<'a> Parameter<'a> {
390 #[must_use]
391 pub fn new(name: &'a str, parameter: &'a Value, required: &[Value]) -> Self {
392 Self {
393 name,
394 required: required
395 .iter()
396 .any(|r| matches!(r, Value::String(n) if n == name)),
397 default: parameter.get("default"),
398 docs: match parameter.get("description") {
399 Some(Value::String(docs)) => Some(Cow::Borrowed(docs.as_str())),
400 _ => None,
401 },
402 }
403 }
404}
405
/// A parameter that is being merged across the variants of a `oneOf`
/// schema.
struct VariantParameter<'a> {
    /// The index of the last variant in which this parameter was seen.
    generation: usize,
    /// The merged parameter.
    parameter: Parameter<'a>,
    /// The `const` value and optional documentation collected from each
    /// variant; `None` once the parameter no longer qualifies for
    /// per-value documentation.
    #[expect(clippy::type_complexity)]
    tag_docs: Option<Vec<(&'a Value, Option<Cow<'a, str>>)>>,
}
412
413impl<'a> VariantParameter<'a> {
414 #[must_use]
415 pub fn new(
416 generation: usize,
417 name: &'a str,
418 parameter: &'a Value,
419 required: &[Value],
420 variant_docs: Option<Cow<'a, str>>,
421 ) -> Self {
422 let r#const = parameter.get("const");
423
424 let mut parameter = Parameter::new(name, parameter, required);
425 parameter.required &= generation == 0;
426
427 let tag_docs = match r#const {
428 Some(r#const) if generation == 0 => {
430 let docs = parameter.docs.take().or(variant_docs);
431 Some(vec![(r#const, docs)])
432 }
433 _ => None,
434 };
435
436 Self {
437 generation,
438 parameter,
439 tag_docs,
440 }
441 }
442
443 pub fn merge(
444 &mut self,
445 generation: usize,
446 name: &'a str,
447 parameter: &'a Value,
448 required: &[Value],
449 variant_docs: Option<Cow<'a, str>>,
450 ) {
451 self.generation = generation;
452
453 let r#const = parameter.get("const");
454
455 let parameter = Parameter::new(name, parameter, required);
456
457 self.parameter.required &= parameter.required;
458 if self.parameter.default != parameter.default {
459 self.parameter.default = None;
460 }
461
462 if let Some(tag_docs) = &mut self.tag_docs {
463 if let Some(r#const) = r#const {
465 tag_docs.push((r#const, parameter.docs.or(variant_docs)));
466 } else {
467 self.tag_docs = None;
469 self.parameter.docs = None;
470 }
471 } else {
472 if r#const.is_none() {
474 if self.parameter.docs != parameter.docs {
476 self.parameter.docs = None;
477 }
478 } else {
479 self.tag_docs = None;
481 }
482 }
483 }
484
485 pub fn update_generation(&mut self, generation: usize) {
486 if self.generation < generation {
487 self.parameter.required = false;
489 self.tag_docs = None;
490 }
491 }
492
493 #[must_use]
494 pub fn into_parameter(mut self) -> Parameter<'a> {
495 if let Some(tag_docs) = self.tag_docs {
496 let mut docs = String::new();
497
498 for (tag, tag_docs) in tag_docs {
499 docs.push_str(" - ");
500 JsonToPythonFormatter::push_to_string(&mut docs, tag);
501 if let Some(tag_docs) = tag_docs {
502 docs.push_str(": ");
503 docs.push_str(&tag_docs.replace('\n', "\n "));
504 }
505 docs.push_str("\n\n");
506 }
507
508 docs.truncate(docs.trim_end().len());
509
510 self.parameter.docs = Some(Cow::Owned(docs));
511 }
512
513 self.parameter
514 }
515}
516
/// A `serde_json` formatter that writes `null`, booleans, and non-finite
/// floats in Python syntax.
struct JsonToPythonFormatter;
518
519impl JsonToPythonFormatter {
520 fn push_to_string(buffer: &mut String, value: &Value) {
521 #[expect(unsafe_code)]
525 let mut ser = serde_json::Serializer::with_formatter(unsafe { buffer.as_mut_vec() }, Self);
526 #[allow(clippy::expect_used)]
527 serde::Serialize::serialize(value, &mut ser)
528 .expect("JSON value must not fail to serialize");
529 }
530}
531
532impl serde_json::ser::Formatter for JsonToPythonFormatter {
533 #[inline]
534 fn write_null<W: ?Sized + io::Write>(&mut self, writer: &mut W) -> io::Result<()> {
535 writer.write_all(b"None")
536 }
537
538 #[inline]
539 fn write_bool<W: ?Sized + io::Write>(&mut self, writer: &mut W, value: bool) -> io::Result<()> {
540 let s: &[u8] = if value { b"True" } else { b"False" };
541 writer.write_all(s)
542 }
543
544 #[inline]
545 fn write_f32<W: ?Sized + io::Write>(&mut self, writer: &mut W, value: f32) -> io::Result<()> {
546 let s: &[u8] = match (value.classify(), value.is_sign_negative()) {
547 (FpCategory::Nan, false) => b"float('nan')",
548 (FpCategory::Nan, true) => b"float('-nan')",
549 (FpCategory::Infinite, false) => b"float('inf')",
550 (FpCategory::Infinite, true) => b"float('-inf')",
551 (FpCategory::Zero | FpCategory::Subnormal | FpCategory::Normal, false | true) => {
552 return serde_json::ser::CompactFormatter.write_f32(writer, value);
553 }
554 };
555 writer.write_all(s)
556 }
557
558 #[inline]
559 fn write_f64<W: ?Sized + io::Write>(&mut self, writer: &mut W, value: f64) -> io::Result<()> {
560 let s: &[u8] = match (value.classify(), value.is_sign_negative()) {
561 (FpCategory::Nan, false) => b"float('nan')",
562 (FpCategory::Nan, true) => b"float('-nan')",
563 (FpCategory::Infinite, false) => b"float('inf')",
564 (FpCategory::Infinite, true) => b"float('-inf')",
565 (FpCategory::Zero | FpCategory::Subnormal | FpCategory::Normal, false | true) => {
566 return serde_json::ser::CompactFormatter.write_f64(writer, value);
567 }
568 };
569 writer.write_all(s)
570 }
571}
572
#[cfg(test)]
mod tests {
    use schemars::{JsonSchema, schema_for};

    use super::*;

    // Pins the schema that schemars derives for `MyCodec`; the descriptions
    // in the expected JSON come from the doc comments on `MyCodec` and
    // `Config` below.
    #[test]
    fn schema() {
        assert_eq!(
            format!("{}", schema_for!(MyCodec).to_value()),
            r#"{"type":"object","properties":{"param":{"type":["integer","null"],"format":"int32","description":"An optional integer value."}},"unevaluatedProperties":false,"oneOf":[{"type":"object","description":"Mode a.\n\nIt gets another line.","properties":{"value":{"type":"boolean","description":"A boolean value. And some really, really, really, long first\nline that wraps around.\n\nWith multiple lines of comments."},"common":{"type":"string","description":"A common string value.\n\nSomething else here."},"mode":{"type":"string","const":"A"}},"required":["mode","value","common"]},{"type":"object","description":"Mode b.","properties":{"common":{"type":"string","description":"A common string value.\n\nSomething else here."},"mode":{"type":"string","const":"B"}},"required":["mode","common"]}],"description":"A codec that does something on encoding and decoding.\n\nWith multiple lines of comments.","title":"MyCodec","$schema":"https://json-schema.org/draft/2020-12/schema"}"#
        );
    }

    // Pins the docstring rendered from the schema of `MyCodec`.
    #[test]
    fn docs() {
        assert_eq!(
            docs_from_schema(&schema_for!(MyCodec)).as_deref(),
            Some(
                r#"A codec that does something on encoding and decoding.

With multiple lines of comments.

Parameters
----------
common : ...
 A common string value.

 Something else here.
mode : ...
 - "A": Mode a.

 It gets another line.

 - "B": Mode b.
param : ..., optional
 An optional integer value.
value : ..., optional
 A boolean value. And some really, really, really, long first
 line that wraps around.

 With multiple lines of comments."#
            )
        );
    }

    // Pins the signature rendered from the schema of `MyCodec`: required
    // parameters first, optional ones defaulting to `None`.
    #[test]
    fn signature() {
        assert_eq!(
            signature_from_schema(&schema_for!(MyCodec)),
            "self, common, mode, param=None, value=None",
        );
    }

    /// A codec that does something on encoding and decoding.
    ///
    /// With multiple lines of comments.
    #[expect(dead_code)]
    #[derive(JsonSchema)]
    #[schemars(deny_unknown_fields)]
    struct MyCodec {
        /// An optional integer value.
        #[schemars(default, skip_serializing_if = "Option::is_none")]
        param: Option<i32>,
        // the mode-specific fields are flattened into the codec's own schema
        #[schemars(flatten)]
        config: Config,
    }

    /// The mode-specific part of the configuration of `MyCodec`, tagged by
    /// its `mode` field.
    #[expect(dead_code)]
    #[derive(JsonSchema)]
    #[schemars(tag = "mode")]
    #[schemars(deny_unknown_fields)]
    enum Config {
        /// Mode a.
        ///
        /// It gets another line.
        A {
            /// A boolean value. And some really, really, really, long first
            /// line that wraps around.
            ///
            /// With multiple lines of comments.
            value: bool,
            /// A common string value.
            ///
            /// Something else here.
            common: String,
        },
        /// Mode b.
        B {
            /// A common string value.
            ///
            /// Something else here.
            common: String,
        },
    }
}