feat: accept `datetime.timedelta` instances as argument to `stop_after_delay` (#371)
Rather than just accepting seconds as a float, accept `timedelta` instances, like the `wait` methods.
diff --git a/releasenotes/notes/timedelta-for-stop-ef6bf71b88ce9988.yaml b/releasenotes/notes/timedelta-for-stop-ef6bf71b88ce9988.yaml
new file mode 100644
index 0000000..9f792f0
--- /dev/null
+++ b/releasenotes/notes/timedelta-for-stop-ef6bf71b88ce9988.yaml
@@ -0,0 +1,4 @@
+---
+features:
+ - |
+ - accept ``datetime.timedelta`` instances as argument to ``tenacity.stop.stop_after_delay``
diff --git a/tenacity/_utils.py b/tenacity/_utils.py
index d5c4c9d..f14ff32 100644
--- a/tenacity/_utils.py
+++ b/tenacity/_utils.py
@@ -16,6 +16,7 @@
import sys
import typing
+from datetime import timedelta
# sys.maxsize:
@@ -66,3 +67,10 @@
except AttributeError:
pass
return ".".join(segments)
+
+
+time_unit_type = typing.Union[int, float, timedelta]
+
+
+def to_seconds(time_unit: time_unit_type) -> float:
+ return float(time_unit.total_seconds() if isinstance(time_unit, timedelta) else time_unit)
diff --git a/tenacity/stop.py b/tenacity/stop.py
index 3507224..bb48c81 100644
--- a/tenacity/stop.py
+++ b/tenacity/stop.py
@@ -16,6 +16,8 @@
import abc
import typing
+from tenacity import _utils
+
if typing.TYPE_CHECKING:
import threading
@@ -89,8 +91,8 @@
class stop_after_delay(stop_base):
"""Stop when the time from the first attempt >= limit."""
- def __init__(self, max_delay: float) -> None:
- self.max_delay = max_delay
+ def __init__(self, max_delay: _utils.time_unit_type) -> None:
+ self.max_delay = _utils.to_seconds(max_delay)
def __call__(self, retry_state: "RetryCallState") -> bool:
return retry_state.seconds_since_start >= self.max_delay
diff --git a/tenacity/wait.py b/tenacity/wait.py
index 1d87672..01c94a4 100644
--- a/tenacity/wait.py
+++ b/tenacity/wait.py
@@ -17,19 +17,12 @@
import abc
import random
import typing
-from datetime import timedelta
from tenacity import _utils
if typing.TYPE_CHECKING:
from tenacity import RetryCallState
-wait_unit_type = typing.Union[int, float, timedelta]
-
-
-def to_seconds(wait_unit: wait_unit_type) -> float:
- return float(wait_unit.total_seconds() if isinstance(wait_unit, timedelta) else wait_unit)
-
class wait_base(abc.ABC):
"""Abstract base class for wait strategies."""
@@ -51,8 +44,8 @@
class wait_fixed(wait_base):
"""Wait strategy that waits a fixed amount of time between each retry."""
- def __init__(self, wait: wait_unit_type) -> None:
- self.wait_fixed = to_seconds(wait)
+ def __init__(self, wait: _utils.time_unit_type) -> None:
+ self.wait_fixed = _utils.to_seconds(wait)
def __call__(self, retry_state: "RetryCallState") -> float:
return self.wait_fixed
@@ -68,9 +61,9 @@
class wait_random(wait_base):
"""Wait strategy that waits a random amount of time between min/max."""
- def __init__(self, min: wait_unit_type = 0, max: wait_unit_type = 1) -> None: # noqa
- self.wait_random_min = to_seconds(min)
- self.wait_random_max = to_seconds(max)
+ def __init__(self, min: _utils.time_unit_type = 0, max: _utils.time_unit_type = 1) -> None: # noqa
+ self.wait_random_min = _utils.to_seconds(min)
+ self.wait_random_max = _utils.to_seconds(max)
def __call__(self, retry_state: "RetryCallState") -> float:
return self.wait_random_min + (random.random() * (self.wait_random_max - self.wait_random_min))
@@ -120,13 +113,13 @@
def __init__(
self,
- start: wait_unit_type = 0,
- increment: wait_unit_type = 100,
- max: wait_unit_type = _utils.MAX_WAIT, # noqa
+ start: _utils.time_unit_type = 0,
+ increment: _utils.time_unit_type = 100,
+ max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa
) -> None:
- self.start = to_seconds(start)
- self.increment = to_seconds(increment)
- self.max = to_seconds(max)
+ self.start = _utils.to_seconds(start)
+ self.increment = _utils.to_seconds(increment)
+ self.max = _utils.to_seconds(max)
def __call__(self, retry_state: "RetryCallState") -> float:
result = self.start + (self.increment * (retry_state.attempt_number - 1))
@@ -149,13 +142,13 @@
def __init__(
self,
multiplier: typing.Union[int, float] = 1,
- max: wait_unit_type = _utils.MAX_WAIT, # noqa
+ max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa
exp_base: typing.Union[int, float] = 2,
- min: wait_unit_type = 0, # noqa
+ min: _utils.time_unit_type = 0, # noqa
) -> None:
self.multiplier = multiplier
- self.min = to_seconds(min)
- self.max = to_seconds(max)
+ self.min = _utils.to_seconds(min)
+ self.max = _utils.to_seconds(max)
self.exp_base = exp_base
def __call__(self, retry_state: "RetryCallState") -> float:
diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py
index 2e5febd..82806a6 100644
--- a/tests/test_tenacity.py
+++ b/tests/test_tenacity.py
@@ -155,10 +155,12 @@
self.assertTrue(r.stop(make_retry_state(4, 6546)))
def test_stop_after_delay(self):
- r = Retrying(stop=tenacity.stop_after_delay(1))
- self.assertFalse(r.stop(make_retry_state(2, 0.999)))
- self.assertTrue(r.stop(make_retry_state(2, 1)))
- self.assertTrue(r.stop(make_retry_state(2, 1.001)))
+ for delay in (1, datetime.timedelta(seconds=1)):
+ with self.subTest():
+ r = Retrying(stop=tenacity.stop_after_delay(delay))
+ self.assertFalse(r.stop(make_retry_state(2, 0.999)))
+ self.assertTrue(r.stop(make_retry_state(2, 1)))
+ self.assertTrue(r.stop(make_retry_state(2, 1.001)))
def test_legacy_explicit_stop_type(self):
Retrying(stop="stop_after_attempt")