Add type annotations to cover all code. (#315)
`tenacity` is marked as "typed" package,
which means all public APIs should be annotated
diff --git a/releasenotes/notes/annotate_code-197b93130df14042.yaml b/releasenotes/notes/annotate_code-197b93130df14042.yaml
new file mode 100644
index 0000000..faf4163
--- /dev/null
+++ b/releasenotes/notes/annotate_code-197b93130df14042.yaml
@@ -0,0 +1,3 @@
+---
+other:
+ - Add type annotations to cover all public API.
diff --git a/tenacity/__init__.py b/tenacity/__init__.py
index e258d8a..bc52383 100644
--- a/tenacity/__init__.py
+++ b/tenacity/__init__.py
@@ -89,6 +89,7 @@
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable)
+_RetValT = t.TypeVar("_RetValT")
@t.overload
@@ -388,14 +389,14 @@
break
@abstractmethod
- def __call__(self, fn: WrappedFn, *args: t.Any, **kwargs: t.Any) -> t.Any:
+ def __call__(self, fn: t.Callable[..., _RetValT], *args: t.Any, **kwargs: t.Any) -> _RetValT:
pass
class Retrying(BaseRetrying):
"""Retrying controller."""
- def __call__(self, fn: WrappedFn, *args: t.Any, **kwargs: t.Any) -> t.Any:
+ def __call__(self, fn: t.Callable[..., _RetValT], *args: t.Any, **kwargs: t.Any) -> _RetValT:
self.begin()
retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
diff --git a/tenacity/_asyncio.py b/tenacity/_asyncio.py
index 7b2447d..374ef20 100644
--- a/tenacity/_asyncio.py
+++ b/tenacity/_asyncio.py
@@ -17,6 +17,7 @@
import functools
import sys
+import typing
from asyncio import sleep
from tenacity import AttemptManager
@@ -25,13 +26,21 @@
from tenacity import DoSleep
from tenacity import RetryCallState
+WrappedFn = typing.TypeVar("WrappedFn", bound=typing.Callable)
+_RetValT = typing.TypeVar("_RetValT")
+
class AsyncRetrying(BaseRetrying):
- def __init__(self, sleep=sleep, **kwargs):
+ def __init__(self, sleep: typing.Callable[[float], typing.Awaitable] = sleep, **kwargs: typing.Any) -> None:
super().__init__(**kwargs)
self.sleep = sleep
- async def __call__(self, fn, *args, **kwargs):
+ async def __call__( # type: ignore # Change signature from supertype
+ self,
+ fn: typing.Callable[..., typing.Awaitable[_RetValT]],
+ *args: typing.Any,
+ **kwargs: typing.Any,
+ ) -> _RetValT:
self.begin()
retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
@@ -50,12 +59,12 @@
else:
return do
- def __aiter__(self):
+ def __aiter__(self) -> "AsyncRetrying":
self.begin()
self._retry_state = RetryCallState(self, fn=None, args=(), kwargs={})
return self
- async def __anext__(self):
+ async def __anext__(self) -> typing.Union[AttemptManager, typing.Any]:
while True:
do = self.iter(retry_state=self._retry_state)
if do is None:
@@ -68,12 +77,12 @@
else:
return do
- def wraps(self, fn):
+ def wraps(self, fn: WrappedFn) -> WrappedFn:
fn = super().wraps(fn)
# Ensure wrapper is recognized as a coroutine function.
@functools.wraps(fn)
- async def async_wrapped(*args, **kwargs):
+ async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return await fn(*args, **kwargs)
# Preserve attributes
diff --git a/tenacity/tornadoweb.py b/tenacity/tornadoweb.py
index 1175bec..9d7b395 100644
--- a/tenacity/tornadoweb.py
+++ b/tenacity/tornadoweb.py
@@ -13,6 +13,7 @@
# limitations under the License.
import sys
+import typing
from tenacity import BaseRetrying
from tenacity import DoAttempt
@@ -21,14 +22,24 @@
from tornado import gen
+if typing.TYPE_CHECKING:
+ from tornado.concurrent import Future
+
+_RetValT = typing.TypeVar("_RetValT")
+
class TornadoRetrying(BaseRetrying):
- def __init__(self, sleep=gen.sleep, **kwargs):
+ def __init__(self, sleep: "typing.Callable[[float], Future[None]]" = gen.sleep, **kwargs: typing.Any) -> None:
super().__init__(**kwargs)
self.sleep = sleep
@gen.coroutine
- def __call__(self, fn, *args, **kwargs):
+ def __call__( # type: ignore # Change signature from supertype
+ self,
+ fn: "typing.Callable[..., typing.Union[typing.Generator[typing.Any, typing.Any, _RetValT], Future[_RetValT]]]",
+ *args: typing.Any,
+ **kwargs: typing.Any,
+ ) -> "typing.Generator[typing.Any, typing.Any, _RetValT]":
self.begin()
retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
diff --git a/tenacity/wait.py b/tenacity/wait.py
index 2c1f4fb..aacb58d 100644
--- a/tenacity/wait.py
+++ b/tenacity/wait.py
@@ -111,7 +111,12 @@
(and restricting the upper limit to some maximum value).
"""
- def __init__(self, start=0, increment=100, max=_utils.MAX_WAIT): # noqa
+ def __init__(
+ self,
+ start: typing.Union[int, float] = 0,
+ increment: typing.Union[int, float] = 100,
+ max: typing.Union[int, float] = _utils.MAX_WAIT, # noqa
+ ) -> None:
self.start = start
self.increment = increment
self.max = max
@@ -134,7 +139,13 @@
wait_random_exponential for the latter case.
"""
- def __init__(self, multiplier=1, max=_utils.MAX_WAIT, exp_base=2, min=0): # noqa
+ def __init__(
+ self,
+ multiplier: typing.Union[int, float] = 1,
+ max: typing.Union[int, float] = _utils.MAX_WAIT, # noqa
+ exp_base: typing.Union[int, float] = 2,
+ min: typing.Union[int, float] = 0, # noqa
+ ) -> None:
self.multiplier = multiplier
self.min = min
self.max = max