Support `datetime.timedelta` as a valid wait unit type (#342)
* Support `datetime.timedelta` as a valid wait unit type
Signed-off-by: Noam Bloom <nbloom5@bloomberg.net>
* Add datetime.timedelta support tests
Signed-off-by: Noam Bloom <nbloom5@bloomberg.net>
Co-authored-by: Noam Bloom <nbloom5@bloomberg.net>
Co-authored-by: Julien Danjou <julien@danjou.info>
diff --git a/releasenotes/notes/support-timedelta-wait-unit-type-5ba1e9fc0fe45523.yaml b/releasenotes/notes/support-timedelta-wait-unit-type-5ba1e9fc0fe45523.yaml
new file mode 100644
index 0000000..bc7e62d
--- /dev/null
+++ b/releasenotes/notes/support-timedelta-wait-unit-type-5ba1e9fc0fe45523.yaml
@@ -0,0 +1,3 @@
+---
+features:
+ - Add ``datetime.timedelta`` as accepted wait unit type.
diff --git a/tenacity/wait.py b/tenacity/wait.py
index 289705c..1d87672 100644
--- a/tenacity/wait.py
+++ b/tenacity/wait.py
@@ -17,12 +17,19 @@
import abc
import random
import typing
+from datetime import timedelta
from tenacity import _utils
if typing.TYPE_CHECKING:
from tenacity import RetryCallState
+wait_unit_type = typing.Union[int, float, timedelta]
+
+
+def to_seconds(wait_unit: wait_unit_type) -> float:
+ return float(wait_unit.total_seconds() if isinstance(wait_unit, timedelta) else wait_unit)
+
class wait_base(abc.ABC):
"""Abstract base class for wait strategies."""
@@ -44,8 +51,8 @@
class wait_fixed(wait_base):
"""Wait strategy that waits a fixed amount of time between each retry."""
- def __init__(self, wait: float) -> None:
- self.wait_fixed = wait
+ def __init__(self, wait: wait_unit_type) -> None:
+ self.wait_fixed = to_seconds(wait)
def __call__(self, retry_state: "RetryCallState") -> float:
return self.wait_fixed
@@ -61,9 +68,9 @@
class wait_random(wait_base):
"""Wait strategy that waits a random amount of time between min/max."""
- def __init__(self, min: typing.Union[int, float] = 0, max: typing.Union[int, float] = 1) -> None: # noqa
- self.wait_random_min = min
- self.wait_random_max = max
+ def __init__(self, min: wait_unit_type = 0, max: wait_unit_type = 1) -> None: # noqa
+ self.wait_random_min = to_seconds(min)
+ self.wait_random_max = to_seconds(max)
def __call__(self, retry_state: "RetryCallState") -> float:
return self.wait_random_min + (random.random() * (self.wait_random_max - self.wait_random_min))
@@ -113,13 +120,13 @@
def __init__(
self,
- start: typing.Union[int, float] = 0,
- increment: typing.Union[int, float] = 100,
- max: typing.Union[int, float] = _utils.MAX_WAIT, # noqa
+ start: wait_unit_type = 0,
+ increment: wait_unit_type = 100,
+ max: wait_unit_type = _utils.MAX_WAIT, # noqa
) -> None:
- self.start = start
- self.increment = increment
- self.max = max
+ self.start = to_seconds(start)
+ self.increment = to_seconds(increment)
+ self.max = to_seconds(max)
def __call__(self, retry_state: "RetryCallState") -> float:
result = self.start + (self.increment * (retry_state.attempt_number - 1))
@@ -142,13 +149,13 @@
def __init__(
self,
multiplier: typing.Union[int, float] = 1,
- max: typing.Union[int, float] = _utils.MAX_WAIT, # noqa
+ max: wait_unit_type = _utils.MAX_WAIT, # noqa
exp_base: typing.Union[int, float] = 2,
- min: typing.Union[int, float] = 0, # noqa
+ min: wait_unit_type = 0, # noqa
) -> None:
self.multiplier = multiplier
- self.min = min
- self.max = max
+ self.min = to_seconds(min)
+ self.max = to_seconds(max)
self.exp_base = exp_base
def __call__(self, retry_state: "RetryCallState") -> float:
diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py
index d9a4858..b6f6bbb 100644
--- a/tests/test_tenacity.py
+++ b/tests/test_tenacity.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import datetime
import logging
import re
import sys
@@ -29,7 +30,6 @@
import tenacity
from tenacity import RetryCallState, RetryError, Retrying, retry
-
_unset = object()
@@ -180,28 +180,34 @@
self.assertEqual(0, r.wait(make_retry_state(18, 9879)))
def test_fixed_sleep(self):
- r = Retrying(wait=tenacity.wait_fixed(1))
- self.assertEqual(1, r.wait(make_retry_state(12, 6546)))
+ for wait in (1, datetime.timedelta(seconds=1)):
+ with self.subTest():
+ r = Retrying(wait=tenacity.wait_fixed(wait))
+ self.assertEqual(1, r.wait(make_retry_state(12, 6546)))
def test_incrementing_sleep(self):
- r = Retrying(wait=tenacity.wait_incrementing(start=500, increment=100))
- self.assertEqual(500, r.wait(make_retry_state(1, 6546)))
- self.assertEqual(600, r.wait(make_retry_state(2, 6546)))
- self.assertEqual(700, r.wait(make_retry_state(3, 6546)))
+ for start, increment in ((500, 100), (datetime.timedelta(seconds=500), datetime.timedelta(seconds=100))):
+ with self.subTest():
+ r = Retrying(wait=tenacity.wait_incrementing(start=start, increment=increment))
+ self.assertEqual(500, r.wait(make_retry_state(1, 6546)))
+ self.assertEqual(600, r.wait(make_retry_state(2, 6546)))
+ self.assertEqual(700, r.wait(make_retry_state(3, 6546)))
def test_random_sleep(self):
- r = Retrying(wait=tenacity.wait_random(min=1, max=20))
- times = set()
- for x in range(1000):
- times.add(r.wait(make_retry_state(1, 6546)))
+ for min_, max_ in ((1, 20), (datetime.timedelta(seconds=1), datetime.timedelta(seconds=20))):
+ with self.subTest():
+ r = Retrying(wait=tenacity.wait_random(min=min_, max=max_))
+ times = set()
+ for _ in range(1000):
+ times.add(r.wait(make_retry_state(1, 6546)))
- # this is kind of non-deterministic...
- self.assertTrue(len(times) > 1)
- for t in times:
- self.assertTrue(t >= 1)
- self.assertTrue(t < 20)
+ # this is kind of non-deterministic...
+ self.assertTrue(len(times) > 1)
+ for t in times:
+ self.assertTrue(t >= 1)
+ self.assertTrue(t < 20)
- def test_random_sleep_without_min(self):
+ def test_random_sleep_withoutmin_(self):
r = Retrying(wait=tenacity.wait_random(max=2))
times = set()
times.add(r.wait(make_retry_state(1, 6546)))
@@ -274,18 +280,20 @@
self.assertEqual(r.wait(make_retry_state(8, 0)), 256)
self.assertEqual(r.wait(make_retry_state(20, 0)), 1048576)
- def test_exponential_with_min_wait_and_max_wait(self):
- r = Retrying(wait=tenacity.wait_exponential(min=10, max=100))
- self.assertEqual(r.wait(make_retry_state(1, 0)), 10)
- self.assertEqual(r.wait(make_retry_state(2, 0)), 10)
- self.assertEqual(r.wait(make_retry_state(3, 0)), 10)
- self.assertEqual(r.wait(make_retry_state(4, 0)), 10)
- self.assertEqual(r.wait(make_retry_state(5, 0)), 16)
- self.assertEqual(r.wait(make_retry_state(6, 0)), 32)
- self.assertEqual(r.wait(make_retry_state(7, 0)), 64)
- self.assertEqual(r.wait(make_retry_state(8, 0)), 100)
- self.assertEqual(r.wait(make_retry_state(9, 0)), 100)
- self.assertEqual(r.wait(make_retry_state(20, 0)), 100)
+ def test_exponential_with_min_wait_andmax__wait(self):
+ for min_, max_ in ((10, 100), (datetime.timedelta(seconds=10), datetime.timedelta(seconds=100))):
+ with self.subTest():
+ r = Retrying(wait=tenacity.wait_exponential(min=min_, max=max_))
+ self.assertEqual(r.wait(make_retry_state(1, 0)), 10)
+ self.assertEqual(r.wait(make_retry_state(2, 0)), 10)
+ self.assertEqual(r.wait(make_retry_state(3, 0)), 10)
+ self.assertEqual(r.wait(make_retry_state(4, 0)), 10)
+ self.assertEqual(r.wait(make_retry_state(5, 0)), 16)
+ self.assertEqual(r.wait(make_retry_state(6, 0)), 32)
+ self.assertEqual(r.wait(make_retry_state(7, 0)), 64)
+ self.assertEqual(r.wait(make_retry_state(8, 0)), 100)
+ self.assertEqual(r.wait(make_retry_state(9, 0)), 100)
+ self.assertEqual(r.wait(make_retry_state(20, 0)), 100)
def test_legacy_explicit_wait_type(self):
Retrying(wait="exponential_sleep")
@@ -335,7 +343,7 @@
)
)
# Test it a few time since it's random
- for i in range(1000):
+ for _ in range(1000):
w = r.wait(make_retry_state(1, 5))
self.assertLess(w, 9)
self.assertGreaterEqual(w, 6)