feat: Add retry_if_exception_cause_type (#362)
Add a new retry_base class called `retry_if_exception_cause_type` that
checks that the cause of the raised exception is of a certain type.
Co-authored-by: Guillaume RISBOURG <guillaume.risbourg@mergify.com>
diff --git a/releasenotes/notes/add_retry_if_exception_cause_type-d16b918ace4ae0ad.yaml b/releasenotes/notes/add_retry_if_exception_cause_type-d16b918ace4ae0ad.yaml
new file mode 100644
index 0000000..8b5a420
--- /dev/null
+++ b/releasenotes/notes/add_retry_if_exception_cause_type-d16b918ace4ae0ad.yaml
@@ -0,0 +1,5 @@
+---
+features:
+ - |
+ Add a new `retry_base` class called `retry_if_exception_cause_type` that
+ checks, recursively, if any of the causes of the raised exception is of a certain type.
diff --git a/tenacity/__init__.py b/tenacity/__init__.py
index fd40376..008049a 100644
--- a/tenacity/__init__.py
+++ b/tenacity/__init__.py
@@ -33,6 +33,7 @@
from .retry import retry_any # noqa
from .retry import retry_if_exception # noqa
from .retry import retry_if_exception_type # noqa
+from .retry import retry_if_exception_cause_type # noqa
from .retry import retry_if_not_exception_type # noqa
from .retry import retry_if_not_result # noqa
from .retry import retry_if_result # noqa
diff --git a/tenacity/retry.py b/tenacity/retry.py
index dd27117..1305d3f 100644
--- a/tenacity/retry.py
+++ b/tenacity/retry.py
@@ -117,6 +117,33 @@
return self.predicate(retry_state.outcome.exception())
+class retry_if_exception_cause_type(retry_base):
+ """Retries if any of the causes of the raised exception is of one or more types.
+
+ The check on the type of the cause of the exception is done recursively (until finding
+ an exception in the chain that has no `__cause__`)
+ """
+
+ def __init__(
+ self,
+ exception_types: typing.Union[
+ typing.Type[BaseException],
+ typing.Tuple[typing.Type[BaseException], ...],
+ ] = Exception,
+ ) -> None:
+ self.exception_cause_types = exception_types
+
+ def __call__(self, retry_state: "RetryCallState") -> bool:
+ if retry_state.outcome.failed:
+ exc = retry_state.outcome.exception()
+ while exc is not None:
+ if isinstance(exc.__cause__, self.exception_cause_types):
+ return True
+ exc = exc.__cause__
+
+ return False
+
+
class retry_if_result(retry_base):
"""Retries if the result verifies a predicate."""
diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py
index b6f6bbb..2e5febd 100644
--- a/tests/test_tenacity.py
+++ b/tests/test_tenacity.py
@@ -676,6 +676,56 @@
return True
+class NoNameErrorCauseAfterCount:
+ """Holds counter state for invoking a method several times in a row."""
+
+ def __init__(self, count):
+ self.counter = 0
+ self.count = count
+
+ def go2(self):
+ raise NameError("Hi there, I'm a NameError")
+
+ def go(self):
+ """Raise an IOError with a NameError as cause until after count threshold has been crossed.
+
+ Then return True.
+ """
+ if self.counter < self.count:
+ self.counter += 1
+ try:
+ self.go2()
+ except NameError as e:
+ raise IOError() from e
+
+ return True
+
+
+class NoIOErrorCauseAfterCount:
+ """Holds counter state for invoking a method several times in a row."""
+
+ def __init__(self, count):
+ self.counter = 0
+ self.count = count
+
+ def go2(self):
+ raise IOError("Hi there, I'm an IOError")
+
+ def go(self):
+ """Raise a NameError with an IOError as cause until after count threshold has been crossed.
+
+ Then return True.
+ """
+ if self.counter < self.count:
+ self.counter += 1
+ try:
+ self.go2()
+ except IOError as e:
+ raise NameError() from e
+
+ return True
+
+
class NameErrorUntilCount:
"""Holds counter state for invoking a method several times in a row."""
@@ -783,6 +833,11 @@
return thing.go()
+@retry(retry=tenacity.retry_if_exception_cause_type(NameError))
+def _retryable_test_with_exception_cause_type(thing):
+ return thing.go()
+
+
@retry(retry=tenacity.retry_if_exception_type(IOError))
def _retryable_test_with_exception_type_io(thing):
return thing.go()
@@ -987,6 +1042,15 @@
s = _retryable_test_if_not_exception_message_message.retry.statistics
self.assertTrue(s["attempt_number"] == 1)
+ def test_retry_if_exception_cause_type(self):
+ self.assertTrue(_retryable_test_with_exception_cause_type(NoNameErrorCauseAfterCount(5)))
+
+ try:
+ _retryable_test_with_exception_cause_type(NoIOErrorCauseAfterCount(5))
+ self.fail("Expected exception without NameError as cause")
+ except NameError:
+ pass
+
def test_defaults(self):
self.assertTrue(_retryable_default(NoNameErrorAfterCount(5)))
self.assertTrue(_retryable_default_f(NoNameErrorAfterCount(5)))