1use std::{borrow::Cow, error::Error, fmt, io};
21
22use pyo3::{exceptions::PyException, intern, prelude::*, sync::GILOnceCell, types::IntoPyDict};
23
/// A Python exception ([`PyErr`]) together with its chain of causes.
///
/// `err` is the outermost exception. `cause` is the next link of the chain
/// (built from the exception's `__cause__` by [`PyErrChain::from_pyerr`]),
/// with each link owning the one after it. The chain is exposed to Rust
/// through [`Error::source`].
pub struct PyErrChain {
    // The outermost Python exception of this link
    err: PyErr,
    // The next (inner) link of the chain, if any
    cause: Option<Box<Self>>,
}
38
39impl PyErrChain {
40 #[must_use]
56 #[inline]
57 pub fn new<T: Into<Box<dyn Error + 'static>>>(py: Python, err: T) -> Self {
58 Self::from_pyerr(py, Self::pyerr_from_err(py, err))
59 }
60
61 #[must_use]
77 #[inline]
78 pub fn new_with_translator<
79 E: Into<Box<dyn Error + 'static>>,
80 T: AnyErrorToPyErr,
81 M: MapErrorToPyErr,
82 >(
83 py: Python,
84 err: E,
85 ) -> Self {
86 Self::from_pyerr(py, Self::pyerr_from_err_with_translator::<E, T, M>(py, err))
87 }
88
89 #[must_use]
103 #[inline]
104 pub fn pyerr_from_err<T: Into<Box<dyn Error + 'static>>>(py: Python, err: T) -> PyErr {
105 Self::pyerr_from_err_with_translator::<T, ErrorNoPyErr, DowncastToPyErr>(py, err)
106 }
107
108 #[must_use]
121 pub fn pyerr_from_err_with_translator<
122 E: Into<Box<dyn Error + 'static>>,
123 T: AnyErrorToPyErr,
124 M: MapErrorToPyErr,
125 >(
126 py: Python,
127 err: E,
128 ) -> PyErr {
129 let err: Box<dyn Error + 'static> = err.into();
130
131 let err = match M::try_map(py, err, |err: Box<Self>| err.into_pyerr()) {
132 Ok(err) => return err,
133 Err(err) => err,
134 };
135 let err = match M::try_map(py, err, |err: Box<PyErr>| *err) {
136 Ok(err) => return err,
137 Err(err) => err,
138 };
139
140 let mut chain = Vec::new();
141
142 let mut source = err.source();
143 let mut cause = None;
144
145 while let Some(err) = source.take() {
146 if let Some(err) = M::try_map_ref(py, err, |err: &Self| err.as_pyerr().clone_ref(py)) {
147 cause = err.cause(py);
148 chain.push(err);
149 break;
150 }
151 if let Some(err) = M::try_map_ref(py, err, |err: &PyErr| err.clone_ref(py)) {
152 cause = err.cause(py);
153 chain.push(err);
154 break;
155 }
156
157 source = err.source();
158
159 #[allow(clippy::option_if_let_else)]
160 chain.push(match T::try_from_err_ref::<M>(py, err) {
161 Some(err) => err,
162 None => PyException::new_err(format!("{err}")),
163 });
164 }
165
166 while let Some(err) = chain.pop() {
167 err.set_cause(py, cause.take());
168 cause = Some(err);
169 }
170
171 let err = match T::try_from_err::<M>(py, err) {
172 Ok(err) => err,
173 Err(err) => PyException::new_err(format!("{err}")),
174 };
175 err.set_cause(py, cause);
176
177 err
178 }
179
180 #[must_use]
183 pub fn from_pyerr(py: Python, err: PyErr) -> Self {
184 let mut chain = Vec::new();
185
186 let mut cause = err.cause(py);
187
188 while let Some(err) = cause.take() {
189 cause = err.cause(py);
190 chain.push(Self { err, cause: None });
191 }
192
193 let mut cause = None;
194
195 while let Some(mut err) = chain.pop() {
196 err.cause = cause.take();
197 cause = Some(Box::new(err));
198 }
199
200 Self { err, cause }
201 }
202
203 #[must_use]
205 pub fn into_pyerr(self) -> PyErr {
206 self.err
207 }
208
209 #[must_use]
215 pub const fn as_pyerr(&self) -> &PyErr {
216 &self.err
217 }
218
219 #[must_use]
225 pub fn cause(&self) -> Option<&PyErr> {
226 self.cause.as_deref().map(Self::as_pyerr)
227 }
228
229 #[must_use]
237 pub fn clone_ref(&self, py: Python) -> Self {
238 Self {
239 err: self.err.clone_ref(py),
240 cause: self
241 .cause
242 .as_ref()
243 .map(|cause| Box::new(cause.clone_ref(py))),
244 }
245 }
246}
247
248impl fmt::Debug for PyErrChain {
249 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
250 Python::with_gil(|py| {
251 let traceback = self.err.traceback(py).map(|tb| {
252 tb.format()
253 .map_or(Cow::Borrowed("<traceback str() failed>"), |tb| {
254 Cow::Owned(tb)
255 })
256 });
257
258 fmt.debug_struct("PyErrChain")
259 .field("type", &self.err.get_type(py))
260 .field("value", self.err.value(py))
261 .field("traceback", &traceback)
262 .field("cause", &self.cause)
263 .finish()
264 })
265 }
266}
267
268impl fmt::Display for PyErrChain {
269 #[inline]
270 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
271 fmt::Display::fmt(&self.err, fmt)
272 }
273}
274
275impl Error for PyErrChain {
276 fn source(&self) -> Option<&(dyn Error + 'static)> {
277 self.cause.as_deref().map(|cause| cause as &dyn Error)
278 }
279}
280
281impl From<PyErr> for PyErrChain {
282 fn from(err: PyErr) -> Self {
283 Python::with_gil(|py| Self::from_pyerr(py, err))
284 }
285}
286
287impl From<PyErrChain> for PyErr {
288 fn from(err: PyErrChain) -> Self {
289 err.into_pyerr()
290 }
291}
292
/// Translator that converts arbitrary Rust errors into [`PyErr`]s.
///
/// The implementations in this file are [`ErrorNoPyErr`], which never
/// converts, and [`IoErrorToPyErr`], which converts [`io::Error`]s.
pub trait AnyErrorToPyErr {
    /// Try to convert the owned error `err` into a [`PyErr`].
    ///
    /// `T` is the [`MapErrorToPyErr`] used to inspect `err` and, where
    /// applicable, errors nested inside it.
    ///
    /// # Errors
    ///
    /// Returns `Err` with an error when this translator does not handle
    /// `err`. [`ErrorNoPyErr`] hands back `err` itself.
    fn try_from_err<T: MapErrorToPyErr>(
        py: Python,
        err: Box<dyn Error + 'static>,
    ) -> Result<PyErr, Box<dyn Error + 'static>>;

    /// Try to convert the borrowed error `err` into a [`PyErr`].
    ///
    /// Returns `None` when this translator does not handle `err`. `T` is
    /// the [`MapErrorToPyErr`] used to inspect `err`.
    fn try_from_err_ref<T: MapErrorToPyErr>(
        py: Python,
        err: &(dyn Error + 'static),
    ) -> Option<PyErr>;
}
327
/// Checks whether a type-erased error is of a concrete type `T` and, if so,
/// maps it into a [`PyErr`].
///
/// [`DowncastToPyErr`] implements this with [`Error`] downcasting.
pub trait MapErrorToPyErr {
    /// If `err` is a `T`, convert it into a [`PyErr`] with `map`.
    ///
    /// # Errors
    ///
    /// Returns an error back when `err` is not a `T`. Callers in this file
    /// continue working with that returned error, so implementations are
    /// expected to return `err` itself (as [`DowncastToPyErr`] does).
    fn try_map<T: Error + 'static>(
        py: Python,
        err: Box<dyn Error + 'static>,
        map: impl FnOnce(Box<T>) -> PyErr,
    ) -> Result<PyErr, Box<dyn Error + 'static>>;

    /// Like [`MapErrorToPyErr::try_map`], but for a `Send + Sync` error.
    ///
    /// # Errors
    ///
    /// Returns an error back when `err` is not a `T`.
    fn try_map_send_sync<T: Error + 'static>(
        py: Python,
        err: Box<dyn Error + Send + Sync + 'static>,
        map: impl FnOnce(Box<T>) -> PyErr,
    ) -> Result<PyErr, Box<dyn Error + Send + Sync + 'static>>;

    /// If the borrowed `err` is a `T`, convert it into a [`PyErr`] with `map`.
    ///
    /// Returns `None` otherwise.
    fn try_map_ref<T: Error + 'static>(
        py: Python,
        err: &(dyn Error + 'static),
        map: impl FnOnce(&T) -> PyErr,
    ) -> Option<PyErr>;
}
375
376pub struct ErrorNoPyErr;
379
380impl AnyErrorToPyErr for ErrorNoPyErr {
381 #[inline]
382 fn try_from_err<T: MapErrorToPyErr>(
383 _py: Python,
384 err: Box<dyn Error + 'static>,
385 ) -> Result<PyErr, Box<dyn Error + 'static>> {
386 Err(err)
387 }
388
389 #[inline]
390 fn try_from_err_ref<T: MapErrorToPyErr>(
391 _py: Python,
392 _err: &(dyn Error + 'static),
393 ) -> Option<PyErr> {
394 None
395 }
396}
397
398pub struct IoErrorToPyErr;
400
401impl AnyErrorToPyErr for IoErrorToPyErr {
402 fn try_from_err<T: MapErrorToPyErr>(
403 py: Python,
404 err: Box<dyn Error + 'static>,
405 ) -> Result<PyErr, Box<dyn Error + 'static>> {
406 T::try_map(py, err, |err: Box<io::Error>| {
407 let kind = err.kind();
408
409 if err.get_ref().is_some() {
410 #[allow(clippy::unwrap_used)] let err = err.into_inner().unwrap();
412
413 let err = match T::try_map_send_sync(py, err, |err: Box<PyErr>| *err) {
414 Ok(err) => return err,
415 Err(err) => err,
416 };
417
418 let err =
419 match T::try_map_send_sync(py, err, |err: Box<PyErrChain>| err.into_pyerr()) {
420 Ok(err) => return err,
421 Err(err) => err,
422 };
423
424 return PyErr::from(io::Error::new(kind, err));
425 }
426
427 PyErr::from(*err)
428 })
429 }
430
431 fn try_from_err_ref<T: MapErrorToPyErr>(
432 py: Python,
433 err: &(dyn Error + 'static),
434 ) -> Option<PyErr> {
435 T::try_map_ref(py, err, |err: &io::Error| {
436 if let Some(err) = err.get_ref() {
437 if let Some(err) = T::try_map_ref(py, err, |err: &PyErr| err.clone_ref(py)) {
438 return err;
439 }
440
441 if let Some(err) =
442 T::try_map_ref(py, err, |err: &PyErrChain| err.as_pyerr().clone_ref(py))
443 {
444 return err;
445 }
446 }
447
448 PyErr::from(io::Error::new(err.kind(), format!("{err}")))
449 })
450 }
451}
452
453pub struct DowncastToPyErr;
456
457impl MapErrorToPyErr for DowncastToPyErr {
458 fn try_map<T: Error + 'static>(
459 _py: Python,
460 err: Box<dyn Error + 'static>,
461 map: impl FnOnce(Box<T>) -> PyErr,
462 ) -> Result<PyErr, Box<dyn Error + 'static>> {
463 err.downcast().map(map)
464 }
465
466 fn try_map_send_sync<T: Error + 'static>(
467 _py: Python,
468 err: Box<dyn Error + Send + Sync + 'static>,
469 map: impl FnOnce(Box<T>) -> PyErr,
470 ) -> Result<PyErr, Box<dyn Error + Send + Sync + 'static>> {
471 err.downcast().map(map)
472 }
473
474 fn try_map_ref<T: Error + 'static>(
475 _py: Python,
476 err: &(dyn Error + 'static),
477 map: impl FnOnce(&T) -> PyErr,
478 ) -> Option<PyErr> {
479 err.downcast_ref().map(map)
480 }
481}
482
483#[allow(clippy::missing_panics_doc)]
484#[must_use]
490pub fn err_with_location(py: Python, err: PyErr, file: &str, line: u32, column: u32) -> PyErr {
491 const RAISE: &str = "raise err";
492
493 static COMPILE: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
494 static EXEC: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
495
496 let _ = column;
497
498 #[allow(clippy::expect_used)] let compile = COMPILE
500 .import(py, "builtins", "compile")
501 .expect("Python does not provide a compile() function");
502 #[allow(clippy::expect_used)] let exec = EXEC
504 .import(py, "builtins", "exec")
505 .expect("Python does not provide an exec() function");
506
507 let mut code = String::with_capacity((line as usize) + RAISE.len());
508 for _ in 1..line {
509 code.push('\n');
510 }
511 code.push_str(RAISE);
512
513 #[allow(clippy::expect_used)] let code = compile
515 .call1((code, file, intern!(py, "exec")))
516 .expect("failed to compile PyErr location helper");
517 #[allow(clippy::expect_used)] let globals = [(intern!(py, "err"), err)]
519 .into_py_dict(py)
520 .expect("failed to create a dict(err=...)");
521
522 #[allow(clippy::expect_used)] let err = exec.call1((code, globals)).expect_err("raise must raise");
524 err
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// An exception raised with `raise ... from ...` exposes each
    /// `__cause__` as the next `Error::source` of the `PyErrChain`.
    #[test]
    fn python_cause() {
        Python::with_gil(|py| {
            let err = py
                .run(
                    &std::ffi::CString::new(
                        r#"
try:
    try:
        raise Exception("source")
    except Exception as err:
        raise IndexError("middle") from err
except Exception as err:
    raise LookupError("top") from err
"#,
                    )
                    .unwrap(),
                    None,
                    None,
                )
                .expect_err("raise must raise");

            let err = PyErrChain::new(py, err);
            assert_eq!(format!("{err}"), "LookupError: top");

            let err = err.source().expect("must have source");
            assert_eq!(format!("{err}"), "IndexError: middle");

            let err = err.source().expect("must have source");
            assert_eq!(format!("{err}"), "Exception: source");

            assert!(err.source().is_none());
        })
    }

    /// A Rust error's `source()` chain is kept in the `PyErrChain`. It is
    /// also turned into `__cause__` links between generic `Exception`s.
    #[test]
    fn rust_source() {
        // Minimal error type with an optional nested source
        #[derive(Debug)]
        struct MyErr {
            msg: &'static str,
            source: Option<Box<Self>>,
        }

        impl fmt::Display for MyErr {
            fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
                fmt.write_str(self.msg)
            }
        }

        impl Error for MyErr {
            fn source(&self) -> Option<&(dyn Error + 'static)> {
                match &self.source {
                    None => None,
                    Some(source) => Some(&**source as &dyn Error),
                }
            }
        }

        Python::with_gil(|py| {
            let err = PyErrChain::new(
                py,
                MyErr {
                    msg: "top",
                    source: Some(Box::new(MyErr {
                        msg: "middle",
                        source: Some(Box::new(MyErr {
                            msg: "source",
                            source: None,
                        })),
                    })),
                },
            );

            // Rust side: three links
            let source = err.source().expect("must have source");
            let source = source.source().expect("must have source");
            assert!(source.source().is_none());

            // Python side: the same three links via `__cause__`
            let err = PyErr::from(err);
            assert_eq!(format!("{err}"), "Exception: top");

            let err = err.cause(py).expect("must have cause");
            assert_eq!(format!("{err}"), "Exception: middle");

            let err = err.cause(py).expect("must have cause");
            assert_eq!(format!("{err}"), "Exception: source");

            assert!(err.cause(py).is_none());
        })
    }

    /// `err_with_location` adds a traceback frame for the given file and
    /// line. Repeated calls stack further frames onto the same exception.
    #[test]
    fn err_location() {
        Python::with_gil(|py| {
            let err = err_with_location(py, PyException::new_err("oh no"), "foo.rs", 27, 15);

            // The message is unchanged and a single frame was added
            assert_eq!(format!("{err}"), "Exception: oh no");
            assert_eq!(
                err.traceback(py)
                    .expect("must have traceback")
                    .format()
                    .expect("traceback must be formattable"),
                r#"Traceback (most recent call last):
  File "bar.rs", line 24, in <module>
"#
                .replace("bar.rs\", line 24", "foo.rs\", line 27"),
            );
            assert!(err.cause(py).is_none());

            // Re-raising the same exception adds a second frame
            let err = err_with_location(py, err, "bar.rs", 24, 18);

            // Attach it as the cause of a new exception raised elsewhere
            let top = PyException::new_err("oh yes");
            top.set_cause(py, Some(err));
            let err = err_with_location(py, top, "baz.rs", 41, 1);

            // The new exception only has its own frame
            assert_eq!(format!("{err}"), "Exception: oh yes");
            assert_eq!(
                err.traceback(py)
                    .expect("must have traceback")
                    .format()
                    .expect("traceback must be formattable"),
                r#"Traceback (most recent call last):
  File "baz.rs", line 41, in <module>
"#,
            );

            let cause = err.cause(py).expect("must have a cause");

            // The cause keeps both of its frames, most recent first
            assert_eq!(format!("{cause}"), "Exception: oh no");
            assert_eq!(
                cause
                    .traceback(py)
                    .expect("must have traceback")
                    .format()
                    .expect("traceback must be formattable"),
                r#"Traceback (most recent call last):
  File "bar.rs", line 24, in <module>
  File "foo.rs", line 27, in <module>
"#,
            );
            assert!(cause.cause(py).is_none());
        })
    }

    /// `anyhow` context layers become separate links of the chain
    #[test]
    fn anyhow() {
        Python::with_gil(|py| {
            let err = anyhow::anyhow!("source").context("middle").context("top");

            let err = PyErrChain::new(py, err);
            assert_eq!(format!("{err}"), "Exception: top");

            let err = err.source().expect("must have source");
            assert_eq!(format!("{err}"), "Exception: middle");

            let err = err.source().expect("must have source");
            assert_eq!(format!("{err}"), "Exception: source");

            assert!(err.source().is_none());
        })
    }
}