# blob: 871652998a7248da5d0ab0b4416c7d720967d19b [file] [log] [blame]
# mypy: disable-error-code="no-untyped-def,no-untyped-call"
# Copyright 2016 Étienne Bersac
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import inspect
import unittest
from functools import wraps
try:
import trio
except ImportError:
have_trio = False
else:
have_trio = True
import pytest
import tenacity
from tenacity import AsyncRetrying, RetryError
from tenacity import asyncio as tasyncio
from tenacity import retry, retry_if_exception, retry_if_result, stop_after_attempt
from tenacity.wait import wait_fixed
from .test_tenacity import NoIOErrorAfterCount, current_time_ms
def asynctest(callable_):
    """Wrap a coroutine function so a synchronous test runner can call it.

    The returned wrapper accepts the same positional and keyword arguments as
    ``callable_``, drives the resulting coroutine to completion on an event
    loop and returns its result (or propagates its exception).

    Each invocation runs on its own freshly created event loop which is
    closed afterwards. ``asyncio.get_event_loop()`` is deliberately avoided:
    calling it without a running loop is deprecated and fails on recent
    Python versions, and the implicitly created loop was never closed.
    """

    @wraps(callable_)
    def wrapper(*a, **kw):
        # A private loop keeps the global event-loop policy untouched, so
        # other test modules relying on it are not affected.
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(callable_(*a, **kw))
        finally:
            # Always release the loop's resources, even if the test failed.
            loop.close()

    return wrapper
async def _async_function(thing):
await asyncio.sleep(0.00001)
return thing.go()
@retry
async def _retryable_coroutine(thing):
    """Coroutine wrapped with a bare ``@retry``.

    Yields to the event loop briefly, then returns the result of
    ``thing.go()``.
    """
    await asyncio.sleep(0.00001)
    outcome = thing.go()
    return outcome
@retry(stop=stop_after_attempt(2))
async def _retryable_coroutine_with_2_attempts(thing):
    """Coroutine wrapped with ``@retry`` configured to stop after two attempts.

    Yields to the event loop briefly and calls ``thing.go()``; the value
    returned by ``go()`` is discarded.
    """
    await asyncio.sleep(0.00001)
    thing.go()
    return None
class TestAsyncio(unittest.TestCase):
    """Tests for the ``retry`` decorator and ``AsyncRetrying`` with asyncio coroutines."""

    @asynctest
    async def test_retry(self):
        """A bare ``@retry`` coroutine keeps retrying until ``go()`` succeeds."""
        thing = NoIOErrorAfterCount(5)
        await _retryable_coroutine(thing)
        assert thing.counter == thing.count

    @asynctest
    async def test_iscoroutinefunction(self):
        """The decorated callable is still recognised as a coroutine function."""
        assert asyncio.iscoroutinefunction(_retryable_coroutine)
        assert inspect.iscoroutinefunction(_retryable_coroutine)

    @asynctest
    async def test_retry_using_async_retying(self):
        """``AsyncRetrying`` can be called directly with a coroutine function."""
        thing = NoIOErrorAfterCount(5)
        retrying = AsyncRetrying()
        await retrying(_async_function, thing)
        assert thing.counter == thing.count

    @asynctest
    async def test_stop_after_attempt(self):
        """``stop_after_attempt(2)`` gives up with ``RetryError`` after two calls."""
        thing = NoIOErrorAfterCount(2)
        # pytest.raises makes the test fail if RetryError is *not* raised;
        # a try/except with the assertion inside the handler would pass
        # vacuously in that case.
        with pytest.raises(RetryError):
            await _retryable_coroutine_with_2_attempts(thing)
        assert thing.counter == 2

    def test_repr(self):
        """``repr()`` of an ``AsyncRetrying`` instance does not raise."""
        repr(tasyncio.AsyncRetrying())

    def test_retry_attributes(self):
        """The decorated coroutine exposes ``retry`` and ``retry_with``."""
        assert hasattr(_retryable_coroutine, "retry")
        assert hasattr(_retryable_coroutine, "retry_with")

    def test_retry_preserves_argument_defaults(self):
        """``AsyncRetrying.wraps`` keeps ``__defaults__`` and ``__kwdefaults__``."""

        async def function_with_defaults(a=1):
            return a

        async def function_with_kwdefaults(*, a=1):
            return a

        retrying = AsyncRetrying(
            wait=tenacity.wait_fixed(0.01), stop=tenacity.stop_after_attempt(3)
        )
        wrapped_defaults_function = retrying.wraps(function_with_defaults)
        wrapped_kwdefaults_function = retrying.wraps(function_with_kwdefaults)

        self.assertEqual(
            function_with_defaults.__defaults__, wrapped_defaults_function.__defaults__
        )
        self.assertEqual(
            function_with_kwdefaults.__kwdefaults__,
            wrapped_kwdefaults_function.__kwdefaults__,
        )

    @asynctest
    async def test_attempt_number_is_correct_for_interleaved_coroutines(self):
        """Attempt numbers are tracked per call when two retried coroutines interleave."""
        attempts = []

        def after(retry_state):
            # Record which object was retried together with its attempt number.
            attempts.append((retry_state.args[0], retry_state.attempt_number))

        thing1 = NoIOErrorAfterCount(3)
        thing2 = NoIOErrorAfterCount(3)

        await asyncio.gather(
            _retryable_coroutine.retry_with(after=after)(thing1),  # type: ignore[attr-defined]
            _retryable_coroutine.retry_with(after=after)(thing2),  # type: ignore[attr-defined]
        )

        # There's no waiting on retry, only a wait in the coroutine, so the
        # executions should be interleaved.
        even_thing_attempts = attempts[::2]
        things, attempt_nos1 = zip(*even_thing_attempts)
        assert len(set(things)) == 1
        assert list(attempt_nos1) == [1, 2, 3]

        odd_thing_attempts = attempts[1::2]
        things, attempt_nos2 = zip(*odd_thing_attempts)
        assert len(set(things)) == 1
        assert list(attempt_nos2) == [1, 2, 3]
@unittest.skipIf(not have_trio, "trio not installed")
class TestTrio(unittest.TestCase):
    """Checks the ``retry`` decorator under trio; skipped when trio is missing."""

    def test_trio_basic(self):
        """A retried coroutine run via ``trio.run`` retries until ``go()`` succeeds."""
        flaky = NoIOErrorAfterCount(5)

        @retry
        async def trio_function():
            await trio.sleep(0.00001)
            return flaky.go()

        trio.run(trio_function)

        assert flaky.counter == flaky.count
class TestContextManager(unittest.TestCase):
    """Tests for driving ``AsyncRetrying`` with ``async for`` and ``with attempt``."""

    @asynctest
    async def test_do_max_attempts(self):
        """The loop body runs ``stop_after_attempt`` times, then ``RetryError`` is raised."""
        attempts = 0
        retrying = tasyncio.AsyncRetrying(stop=stop_after_attempt(3))
        # pytest.raises fails the test if RetryError never surfaces.
        with pytest.raises(RetryError):
            async for attempt in retrying:
                with attempt:
                    attempts += 1
                    raise Exception

        assert attempts == 3

    @asynctest
    async def test_reraise(self):
        """With ``reraise=True`` the original exception is raised, not ``RetryError``."""

        class CustomError(Exception):
            pass

        with pytest.raises(CustomError):
            async for attempt in tasyncio.AsyncRetrying(
                stop=stop_after_attempt(1), reraise=True
            ):
                with attempt:
                    raise CustomError()

    @asynctest
    async def test_sleeps(self):
        """No wait is performed once the final attempt has failed."""
        start = current_time_ms()
        with pytest.raises(RetryError):
            async for attempt in tasyncio.AsyncRetrying(
                stop=stop_after_attempt(1), wait=wait_fixed(1)
            ):
                with attempt:
                    raise Exception()

        t = current_time_ms() - start
        # NOTE(review): current_time_ms is assumed to return milliseconds, as
        # its name says. The configured wait is 1 s (1000 ms); staying below
        # it shows the loop did not sleep. A bound of 1.1 would demand the
        # whole loop finish in about one millisecond, which is flaky.
        self.assertLess(t, 1000)

    @asynctest
    async def test_retry_with_result(self):
        """A synchronous ``retry_if_result`` predicate drives the async loop."""

        async def test():
            attempts = 0

            # mypy doesn't have great lambda support
            def lt_3(x: float) -> bool:
                return x < 3

            async for attempt in tasyncio.AsyncRetrying(retry=retry_if_result(lt_3)):
                with attempt:
                    attempts += 1
                # Publish the attempt count as the result the predicate sees.
                attempt.retry_state.set_result(attempts)

            return attempts

        result = await test()

        self.assertEqual(3, result)

    @asynctest
    async def test_retry_with_async_result(self):
        """An ``async`` predicate works with ``tasyncio.retry_if_result``."""

        async def test():
            attempts = 0

            async def lt_3(x: float) -> bool:
                return x < 3

            async for attempt in tasyncio.AsyncRetrying(
                retry=tasyncio.retry_if_result(lt_3)
            ):
                with attempt:
                    attempts += 1

                assert attempt.retry_state.outcome  # help mypy
                if not attempt.retry_state.outcome.failed:
                    attempt.retry_state.set_result(attempts)

            return attempts

        result = await test()

        self.assertEqual(3, result)

    @asynctest
    async def test_retry_with_async_exc(self):
        """An ``async`` predicate works with ``tasyncio.retry_if_exception``."""

        async def test():
            attempts = 0

            class CustomException(Exception):
                pass

            async def is_exc(e: BaseException) -> bool:
                return isinstance(e, CustomException)

            async for attempt in tasyncio.AsyncRetrying(
                retry=tasyncio.retry_if_exception(is_exc)
            ):
                with attempt:
                    attempts += 1
                    if attempts < 3:
                        raise CustomException()

                assert attempt.retry_state.outcome  # help mypy
                if not attempt.retry_state.outcome.failed:
                    attempt.retry_state.set_result(attempts)

            return attempts

        result = await test()

        self.assertEqual(3, result)

    @asynctest
    async def test_retry_with_async_result_or(self):
        """``async_strategy | sync_strategy`` retries when either one matches."""

        async def test():
            attempts = 0

            async def lt_3(x: float) -> bool:
                return x < 3

            class CustomException(Exception):
                pass

            def is_exc(e: BaseException) -> bool:
                return isinstance(e, CustomException)

            retry_strategy = tasyncio.retry_if_result(lt_3) | retry_if_exception(is_exc)
            async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy):
                with attempt:
                    attempts += 1
                    # Only the third attempt raises.
                    if 2 < attempts < 4:
                        raise CustomException()

                assert attempt.retry_state.outcome  # help mypy
                if not attempt.retry_state.outcome.failed:
                    attempt.retry_state.set_result(attempts)

            return attempts

        result = await test()

        self.assertEqual(4, result)

    @asynctest
    async def test_retry_with_async_result_ror(self):
        """``sync_strategy | async_strategy`` (reflected OR) behaves the same way."""

        async def test():
            attempts = 0

            def lt_3(x: float) -> bool:
                return x < 3

            class CustomException(Exception):
                pass

            async def is_exc(e: BaseException) -> bool:
                return isinstance(e, CustomException)

            retry_strategy = retry_if_result(lt_3) | tasyncio.retry_if_exception(is_exc)
            async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy):
                with attempt:
                    attempts += 1
                    # Only the third attempt raises.
                    if 2 < attempts < 4:
                        raise CustomException()

                assert attempt.retry_state.outcome  # help mypy
                if not attempt.retry_state.outcome.failed:
                    attempt.retry_state.set_result(attempts)

            return attempts

        result = await test()

        self.assertEqual(4, result)

    @asynctest
    async def test_retry_with_async_result_and(self):
        """``async_strategy & sync_strategy`` retries only while both match."""

        async def test():
            attempts = 0

            async def lt_3(x: float) -> bool:
                return x < 3

            def gt_0(x: float) -> bool:
                return x > 0

            retry_strategy = tasyncio.retry_if_result(lt_3) & retry_if_result(gt_0)
            async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy):
                with attempt:
                    attempts += 1
                attempt.retry_state.set_result(attempts)

            return attempts

        result = await test()

        self.assertEqual(3, result)

    @asynctest
    async def test_retry_with_async_result_rand(self):
        """``sync_strategy & async_strategy`` (reflected AND) behaves the same way."""

        async def test():
            attempts = 0

            async def lt_3(x: float) -> bool:
                return x < 3

            def gt_0(x: float) -> bool:
                return x > 0

            retry_strategy = retry_if_result(gt_0) & tasyncio.retry_if_result(lt_3)
            async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy):
                with attempt:
                    attempts += 1
                attempt.retry_state.set_result(attempts)

            return attempts

        result = await test()

        self.assertEqual(3, result)

    @asynctest
    async def test_async_retying_iterator(self):
        """Iterating ``AsyncRetrying`` with a plain ``for`` raises ``TypeError``."""
        thing = NoIOErrorAfterCount(5)

        with pytest.raises(TypeError):
            for attempts in AsyncRetrying():
                with attempts:
                    await _async_function(thing)
# Type-checking regression guard: mypy must accept an async sleep function
# passed as ``sleep=`` to ``retry``.
# https://github.com/jd/tenacity/issues/399
async def my_async_sleep(x: float) -> None:
    """Sleep asynchronously for ``x`` seconds by delegating to ``asyncio.sleep``."""
    await asyncio.sleep(delay=x)
@retry(sleep=my_async_sleep)
async def foo():
    """No-op coroutine decorated with ``retry`` using ``my_async_sleep`` as ``sleep``."""
    return None
if __name__ == "__main__":
    # Run this module's tests with unittest when it is executed as a script.
    unittest.main()