feat: Add retry_if_exception_cause_type (#362)

Add a new retry_base class called `retry_if_exception_cause_type` that
checks that the cause of the raised exception is of a certain type.

Co-authored-by: Guillaume RISBOURG <guillaume.risbourg@mergify.com>
diff --git a/releasenotes/notes/add_retry_if_exception_cause_type-d16b918ace4ae0ad.yaml b/releasenotes/notes/add_retry_if_exception_cause_type-d16b918ace4ae0ad.yaml
new file mode 100644
index 0000000..8b5a420
--- /dev/null
+++ b/releasenotes/notes/add_retry_if_exception_cause_type-d16b918ace4ae0ad.yaml
@@ -0,0 +1,5 @@
+---
+features:
+  - |
+    Add a new `retry_base` class called `retry_if_exception_cause_type` that
+    checks, recursively, if any of the causes of the raised exception is of a certain type.
diff --git a/tenacity/__init__.py b/tenacity/__init__.py
index fd40376..008049a 100644
--- a/tenacity/__init__.py
+++ b/tenacity/__init__.py
@@ -33,6 +33,7 @@
 from .retry import retry_any  # noqa
 from .retry import retry_if_exception  # noqa
 from .retry import retry_if_exception_type  # noqa
+from .retry import retry_if_exception_cause_type  # noqa
 from .retry import retry_if_not_exception_type  # noqa
 from .retry import retry_if_not_result  # noqa
 from .retry import retry_if_result  # noqa
diff --git a/tenacity/retry.py b/tenacity/retry.py
index dd27117..1305d3f 100644
--- a/tenacity/retry.py
+++ b/tenacity/retry.py
@@ -117,6 +117,33 @@
         return self.predicate(retry_state.outcome.exception())
 
 
+class retry_if_exception_cause_type(retry_base):
+    """Retries if any of the causes of the raised exception is of one or more types.
+
+    The check on the type of the cause of the exception is done recursively (until finding
+    an exception in the chain that has no `__cause__`)
+    """
+
+    def __init__(
+        self,
+        exception_types: typing.Union[
+            typing.Type[BaseException],
+            typing.Tuple[typing.Type[BaseException], ...],
+        ] = Exception,
+    ) -> None:
+        self.exception_cause_types = exception_types
+
+    def __call__(self, retry_state: "RetryCallState") -> bool:
+        if retry_state.outcome.failed:
+            exc = retry_state.outcome.exception()
+            while exc is not None:
+                if isinstance(exc.__cause__, self.exception_cause_types):
+                    return True
+                exc = exc.__cause__
+
+        return False
+
+
 class retry_if_result(retry_base):
     """Retries if the result verifies a predicate."""
 
diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py
index b6f6bbb..2e5febd 100644
--- a/tests/test_tenacity.py
+++ b/tests/test_tenacity.py
@@ -676,6 +676,56 @@
         return True
 
 
+class NoNameErrorCauseAfterCount:
+    """Holds counter state for invoking a method several times in a row."""
+
+    def __init__(self, count):
+        self.counter = 0
+        self.count = count
+
+    def go2(self):
+        raise NameError("Hi there, I'm a NameError")
+
+    def go(self):
+        """Raise an IOError with a NameError as cause until after count threshold has been crossed.
+
+        Then return True.
+        """
+        if self.counter < self.count:
+            self.counter += 1
+            try:
+                self.go2()
+            except NameError as e:
+                raise IOError() from e
+
+        return True
+
+
+class NoIOErrorCauseAfterCount:
+    """Holds counter state for invoking a method several times in a row."""
+
+    def __init__(self, count):
+        self.counter = 0
+        self.count = count
+
+    def go2(self):
+        raise IOError("Hi there, I'm an IOError")
+
+    def go(self):
+        """Raise a NameError with an IOError as cause until after count threshold has been crossed.
+
+        Then return True.
+        """
+        if self.counter < self.count:
+            self.counter += 1
+            try:
+                self.go2()
+            except IOError as e:
+                raise NameError() from e
+
+        return True
+
+
 class NameErrorUntilCount:
     """Holds counter state for invoking a method several times in a row."""
 
@@ -783,6 +833,11 @@
     return thing.go()
 
 
+@retry(retry=tenacity.retry_if_exception_cause_type(NameError))
+def _retryable_test_with_exception_cause_type(thing):
+    return thing.go()
+
+
 @retry(retry=tenacity.retry_if_exception_type(IOError))
 def _retryable_test_with_exception_type_io(thing):
     return thing.go()
@@ -987,6 +1042,15 @@
             s = _retryable_test_if_not_exception_message_message.retry.statistics
             self.assertTrue(s["attempt_number"] == 1)
 
+    def test_retry_if_exception_cause_type(self):
+        self.assertTrue(_retryable_test_with_exception_cause_type(NoNameErrorCauseAfterCount(5)))
+
+        try:
+            _retryable_test_with_exception_cause_type(NoIOErrorCauseAfterCount(5))
+            self.fail("Expected exception without NameError as cause")
+        except NameError:
+            pass
+
     def test_defaults(self):
         self.assertTrue(_retryable_default(NoNameErrorAfterCount(5)))
         self.assertTrue(_retryable_default_f(NoNameErrorAfterCount(5)))