1use std::{borrow::Cow, error::Error, fmt, io};
21
22use pyo3::{exceptions::PyException, intern, prelude::*, sync::PyOnceLock, types::IntoPyDict};
23
/// A Python exception ([`PyErr`]) together with the chain of exceptions that
/// caused it, exposed to Rust through [`Error::source`].
pub struct PyErrChain {
    /// The exception at this level of the chain.
    err: PyErr,
    /// The next exception in the chain (the cause of `err`), if any.
    cause: Option<Box<Self>>,
}
38
39impl PyErrChain {
40    #[must_use]
56    #[inline]
57    pub fn new<T: Into<Box<dyn Error + 'static>>>(py: Python, err: T) -> Self {
58        Self::from_pyerr(py, Self::pyerr_from_err(py, err))
59    }
60
61    #[must_use]
77    #[inline]
78    pub fn new_with_translator<
79        E: Into<Box<dyn Error + 'static>>,
80        T: AnyErrorToPyErr,
81        M: MapErrorToPyErr,
82    >(
83        py: Python,
84        err: E,
85    ) -> Self {
86        Self::from_pyerr(py, Self::pyerr_from_err_with_translator::<E, T, M>(py, err))
87    }
88
89    #[must_use]
103    #[inline]
104    pub fn pyerr_from_err<T: Into<Box<dyn Error + 'static>>>(py: Python, err: T) -> PyErr {
105        Self::pyerr_from_err_with_translator::<T, ErrorNoPyErr, DowncastToPyErr>(py, err)
106    }
107
108    #[must_use]
121    pub fn pyerr_from_err_with_translator<
122        E: Into<Box<dyn Error + 'static>>,
123        T: AnyErrorToPyErr,
124        M: MapErrorToPyErr,
125    >(
126        py: Python,
127        err: E,
128    ) -> PyErr {
129        let err: Box<dyn Error + 'static> = err.into();
130
131        let err = match M::try_map(py, err, |err: Box<Self>| err.into_pyerr()) {
132            Ok(err) => return err,
133            Err(err) => err,
134        };
135        let err = match M::try_map(py, err, |err: Box<PyErr>| *err) {
136            Ok(err) => return err,
137            Err(err) => err,
138        };
139
140        let mut chain = Vec::new();
141
142        let mut source = err.source();
143        let mut cause = None;
144
145        while let Some(err) = source.take() {
146            if let Some(err) = M::try_map_ref(py, err, |err: &Self| err.as_pyerr().clone_ref(py)) {
147                cause = err.cause(py);
148                chain.push(err);
149                break;
150            }
151            if let Some(err) = M::try_map_ref(py, err, |err: &PyErr| err.clone_ref(py)) {
152                cause = err.cause(py);
153                chain.push(err);
154                break;
155            }
156
157            source = err.source();
158
159            #[allow(clippy::option_if_let_else)]
160            chain.push(match T::try_from_err_ref::<M>(py, err) {
161                Some(err) => err,
162                None => PyException::new_err(format!("{err}")),
163            });
164        }
165
166        while let Some(err) = chain.pop() {
167            err.set_cause(py, cause.take());
168            cause = Some(err);
169        }
170
171        let err = match T::try_from_err::<M>(py, err) {
172            Ok(err) => err,
173            Err(err) => PyException::new_err(format!("{err}")),
174        };
175        err.set_cause(py, cause);
176
177        err
178    }
179
180    #[must_use]
183    pub fn from_pyerr(py: Python, err: PyErr) -> Self {
184        let mut chain = Vec::new();
185
186        let mut cause = err.cause(py);
187
188        while let Some(err) = cause.take() {
189            cause = err.cause(py);
190            chain.push(Self { err, cause: None });
191        }
192
193        let mut cause = None;
194
195        while let Some(mut err) = chain.pop() {
196            err.cause = cause.take();
197            cause = Some(Box::new(err));
198        }
199
200        Self { err, cause }
201    }
202
203    #[must_use]
205    pub fn into_pyerr(self) -> PyErr {
206        self.err
207    }
208
209    #[must_use]
215    pub const fn as_pyerr(&self) -> &PyErr {
216        &self.err
217    }
218
219    #[must_use]
225    pub fn cause(&self) -> Option<&PyErr> {
226        self.cause.as_deref().map(Self::as_pyerr)
227    }
228
229    #[must_use]
237    pub fn clone_ref(&self, py: Python) -> Self {
238        Self {
239            err: self.err.clone_ref(py),
240            cause: self
241                .cause
242                .as_ref()
243                .map(|cause| Box::new(cause.clone_ref(py))),
244        }
245    }
246}
247
248impl fmt::Debug for PyErrChain {
249    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
250        Python::attach(|py| {
251            let traceback = self.err.traceback(py).map(|tb| {
252                tb.format()
253                    .map_or(Cow::Borrowed("<traceback str() failed>"), |tb| {
254                        Cow::Owned(tb)
255                    })
256            });
257
258            fmt.debug_struct("PyErrChain")
259                .field("type", &self.err.get_type(py))
260                .field("value", self.err.value(py))
261                .field("traceback", &traceback)
262                .field("cause", &self.cause)
263                .finish()
264        })
265    }
266}
267
268impl fmt::Display for PyErrChain {
269    #[inline]
270    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
271        fmt::Display::fmt(&self.err, fmt)
272    }
273}
274
275impl Error for PyErrChain {
276    fn source(&self) -> Option<&(dyn Error + 'static)> {
277        self.cause.as_deref().map(|cause| cause as &dyn Error)
278    }
279}
280
281impl From<PyErr> for PyErrChain {
282    fn from(err: PyErr) -> Self {
283        Python::attach(|py| Self::from_pyerr(py, err))
284    }
285}
286
287impl From<PyErrChain> for PyErr {
288    fn from(err: PyErrChain) -> Self {
289        err.into_pyerr()
290    }
291}
292
/// Translates specific Rust error types into [`PyErr`]s.
///
/// An implementation recognises the error types it knows how to translate,
/// using the [`MapErrorToPyErr`] parameter to perform the type check. It hands
/// back every other error so the caller can fall back to a generic conversion.
pub trait AnyErrorToPyErr {
    /// Tries to translate the owned error `err` into a [`PyErr`].
    ///
    /// # Errors
    ///
    /// Should return `err` unchanged when this translator does not handle it.
    fn try_from_err<T: MapErrorToPyErr>(
        py: Python,
        err: Box<dyn Error + 'static>,
    ) -> Result<PyErr, Box<dyn Error + 'static>>;

    /// Tries to translate the borrowed error `err` into a [`PyErr`].
    ///
    /// Should return `None` when this translator does not handle `err`.
    fn try_from_err_ref<T: MapErrorToPyErr>(
        py: Python,
        err: &(dyn Error + 'static),
    ) -> Option<PyErr>;
}
327
/// Checks whether a type-erased Rust error is of a concrete type `T` and, if
/// so, converts it into a [`PyErr`] with a caller-supplied `map` function.
///
/// [`DowncastToPyErr`] implements this with the standard-library downcasts.
pub trait MapErrorToPyErr {
    /// Tries to map the owned error `err`, viewed as a `T`, with `map`.
    ///
    /// # Errors
    ///
    /// Should return `err` unchanged when it is not a `T`.
    fn try_map<T: Error + 'static>(
        py: Python,
        err: Box<dyn Error + 'static>,
        map: impl FnOnce(Box<T>) -> PyErr,
    ) -> Result<PyErr, Box<dyn Error + 'static>>;

    /// Same as [`MapErrorToPyErr::try_map`], for `Send + Sync` errors (such as
    /// the payload taken out of an [`io::Error`]).
    ///
    /// # Errors
    ///
    /// Should return `err` unchanged when it is not a `T`.
    fn try_map_send_sync<T: Error + 'static>(
        py: Python,
        err: Box<dyn Error + Send + Sync + 'static>,
        map: impl FnOnce(Box<T>) -> PyErr,
    ) -> Result<PyErr, Box<dyn Error + Send + Sync + 'static>>;

    /// Tries to map the borrowed error `err`, viewed as a `T`, with `map`.
    ///
    /// Should return `None` when `err` is not a `T`.
    fn try_map_ref<T: Error + 'static>(
        py: Python,
        err: &(dyn Error + 'static),
        map: impl FnOnce(&T) -> PyErr,
    ) -> Option<PyErr>;
}
375
376pub struct ErrorNoPyErr;
379
380impl AnyErrorToPyErr for ErrorNoPyErr {
381    #[inline]
382    fn try_from_err<T: MapErrorToPyErr>(
383        _py: Python,
384        err: Box<dyn Error + 'static>,
385    ) -> Result<PyErr, Box<dyn Error + 'static>> {
386        Err(err)
387    }
388
389    #[inline]
390    fn try_from_err_ref<T: MapErrorToPyErr>(
391        _py: Python,
392        _err: &(dyn Error + 'static),
393    ) -> Option<PyErr> {
394        None
395    }
396}
397
398pub struct IoErrorToPyErr;
400
401impl AnyErrorToPyErr for IoErrorToPyErr {
402    fn try_from_err<T: MapErrorToPyErr>(
403        py: Python,
404        err: Box<dyn Error + 'static>,
405    ) -> Result<PyErr, Box<dyn Error + 'static>> {
406        T::try_map(py, err, |err: Box<io::Error>| {
407            let kind = err.kind();
408
409            if err.get_ref().is_some() {
410                #[allow(clippy::unwrap_used)] let err = err.into_inner().unwrap();
412
413                let err = match T::try_map_send_sync(py, err, |err: Box<PyErr>| *err) {
414                    Ok(err) => return err,
415                    Err(err) => err,
416                };
417
418                let err =
419                    match T::try_map_send_sync(py, err, |err: Box<PyErrChain>| err.into_pyerr()) {
420                        Ok(err) => return err,
421                        Err(err) => err,
422                    };
423
424                return PyErr::from(io::Error::new(kind, err));
425            }
426
427            PyErr::from(*err)
428        })
429    }
430
431    fn try_from_err_ref<T: MapErrorToPyErr>(
432        py: Python,
433        err: &(dyn Error + 'static),
434    ) -> Option<PyErr> {
435        T::try_map_ref(py, err, |err: &io::Error| {
436            if let Some(err) = err.get_ref() {
437                if let Some(err) = T::try_map_ref(py, err, |err: &PyErr| err.clone_ref(py)) {
438                    return err;
439                }
440
441                if let Some(err) =
442                    T::try_map_ref(py, err, |err: &PyErrChain| err.as_pyerr().clone_ref(py))
443                {
444                    return err;
445                }
446            }
447
448            PyErr::from(io::Error::new(err.kind(), format!("{err}")))
449        })
450    }
451}
452
453pub struct DowncastToPyErr;
456
457impl MapErrorToPyErr for DowncastToPyErr {
458    fn try_map<T: Error + 'static>(
459        _py: Python,
460        err: Box<dyn Error + 'static>,
461        map: impl FnOnce(Box<T>) -> PyErr,
462    ) -> Result<PyErr, Box<dyn Error + 'static>> {
463        err.downcast().map(map)
464    }
465
466    fn try_map_send_sync<T: Error + 'static>(
467        _py: Python,
468        err: Box<dyn Error + Send + Sync + 'static>,
469        map: impl FnOnce(Box<T>) -> PyErr,
470    ) -> Result<PyErr, Box<dyn Error + Send + Sync + 'static>> {
471        err.downcast().map(map)
472    }
473
474    fn try_map_ref<T: Error + 'static>(
475        _py: Python,
476        err: &(dyn Error + 'static),
477        map: impl FnOnce(&T) -> PyErr,
478    ) -> Option<PyErr> {
479        err.downcast_ref().map(map)
480    }
481}
482
483#[allow(clippy::missing_panics_doc)]
484#[must_use]
490pub fn err_with_location(py: Python, err: PyErr, file: &str, line: u32, column: u32) -> PyErr {
491    const RAISE: &str = "raise err";
492
493    static COMPILE: PyOnceLock<Py<PyAny>> = PyOnceLock::new();
494    static EXEC: PyOnceLock<Py<PyAny>> = PyOnceLock::new();
495
496    let _ = column;
497
498    #[allow(clippy::expect_used)] let compile = COMPILE
500        .import(py, "builtins", "compile")
501        .expect("Python does not provide a compile() function");
502    #[allow(clippy::expect_used)] let exec = EXEC
504        .import(py, "builtins", "exec")
505        .expect("Python does not provide an exec() function");
506
507    let mut code = String::with_capacity((line as usize) + RAISE.len());
508    for _ in 1..line {
509        code.push('\n');
510    }
511    code.push_str(RAISE);
512
513    #[allow(clippy::expect_used)] let code = compile
515        .call1((code, file, intern!(py, "exec")))
516        .expect("failed to compile PyErr location helper");
517    #[allow(clippy::expect_used)] let globals = [(intern!(py, "err"), err)]
519        .into_py_dict(py)
520        .expect("failed to create a dict(err=...)");
521
522    #[allow(clippy::expect_used)] let err = exec.call1((code, globals)).expect_err("raise must raise");
524    err
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// A Python `raise ... from ...` chain is exposed through `Error::source`.
    #[test]
    fn python_cause() {
        Python::attach(|py| {
            // Raises LookupError("top"), caused by IndexError("middle"), which
            // is in turn caused by Exception("source").
            let err = py
                .run(
                    &std::ffi::CString::new(
                        r#"
try:
    try:
        raise Exception("source")
    except Exception as err:
        raise IndexError("middle") from err
except Exception as err:
    raise LookupError("top") from err
"#,
                    )
                    .unwrap(),
                    None,
                    None,
                )
                .expect_err("raise must raise");

            let err = PyErrChain::new(py, err);
            assert_eq!(format!("{err}"), "LookupError: top");

            let err = err.source().expect("must have source");
            assert_eq!(format!("{err}"), "IndexError: middle");

            let err = err.source().expect("must have source");
            assert_eq!(format!("{err}"), "Exception: source");

            assert!(err.source().is_none());
        })
    }

    /// A Rust `Error::source` chain becomes a chain of Python causes, with
    /// each untranslated error turned into a plain `Exception`.
    #[test]
    fn rust_source() {
        // Minimal error type with an optional boxed source, used to build a
        // three-level chain.
        #[derive(Debug)]
        struct MyErr {
            msg: &'static str,
            source: Option<Box<Self>>,
        }

        impl fmt::Display for MyErr {
            fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
                fmt.write_str(self.msg)
            }
        }

        impl Error for MyErr {
            fn source(&self) -> Option<&(dyn Error + 'static)> {
                match &self.source {
                    None => None,
                    Some(source) => Some(&**source as &dyn Error),
                }
            }
        }

        Python::attach(|py| {
            let err = PyErrChain::new(
                py,
                MyErr {
                    msg: "top",
                    source: Some(Box::new(MyErr {
                        msg: "middle",
                        source: Some(Box::new(MyErr {
                            msg: "source",
                            source: None,
                        })),
                    })),
                },
            );

            // The Rust-side chain has exactly three levels.
            let source = err.source().expect("must have source");
            let source = source.source().expect("must have source");
            assert!(source.source().is_none());

            // The same chain is visible on the Python side via `cause`.
            let err = PyErr::from(err);
            assert_eq!(format!("{err}"), "Exception: top");

            let err = err.cause(py).expect("must have cause");
            assert_eq!(format!("{err}"), "Exception: middle");

            let err = err.cause(py).expect("must have cause");
            assert_eq!(format!("{err}"), "Exception: source");

            assert!(err.cause(py).is_none());
        })
    }

    /// Each `err_with_location` call adds one traceback frame for the given
    /// file and line, and leaves the cause untouched.
    #[test]
    fn err_location() {
        Python::attach(|py| {
            let err = err_with_location(py, PyException::new_err("oh no"), "foo.rs", 27, 15);

            assert_eq!(format!("{err}"), "Exception: oh no");
            assert_eq!(
                err.traceback(py)
                    .expect("must have traceback")
                    .format()
                    .expect("traceback must be formattable"),
                r#"Traceback (most recent call last):
  File "foo.rs", line 27, in <module>
"#,
            );
            assert!(err.cause(py).is_none());

            // A second location is prepended to the existing traceback.
            let err = err_with_location(py, err, "bar.rs", 24, 18);

            // Wrap the located error as the cause of a new top-level error.
            let top = PyException::new_err("oh yes");
            top.set_cause(py, Some(err));
            let err = err_with_location(py, top, "baz.rs", 41, 1);

            assert_eq!(format!("{err}"), "Exception: oh yes");
            assert_eq!(
                err.traceback(py)
                    .expect("must have traceback")
                    .format()
                    .expect("traceback must be formattable"),
                r#"Traceback (most recent call last):
  File "baz.rs", line 41, in <module>
"#,
            );

            let cause = err.cause(py).expect("must have a cause");

            assert_eq!(format!("{cause}"), "Exception: oh no");
            assert_eq!(
                cause
                    .traceback(py)
                    .expect("must have traceback")
                    .format()
                    .expect("traceback must be formattable"),
                r#"Traceback (most recent call last):
  File "bar.rs", line 24, in <module>
  File "foo.rs", line 27, in <module>
"#,
            );
            assert!(cause.cause(py).is_none());
        })
    }

    /// `anyhow` context layers are exposed as a chain of sources.
    #[test]
    fn anyhow() {
        Python::attach(|py| {
            let err = anyhow::anyhow!("source").context("middle").context("top");

            let err = PyErrChain::new(py, err);
            assert_eq!(format!("{err}"), "Exception: top");

            let err = err.source().expect("must have source");
            assert_eq!(format!("{err}"), "Exception: middle");

            let err = err.source().expect("must have source");
            assert_eq!(format!("{err}"), "Exception: source");

            assert!(err.source().is_none());
        })
    }
}