fix: Avoid overwriting local contexts with retry decorator (#479)
* Avoid overwriting local contexts with retry decorator
* Add reno release note
diff --git a/releasenotes/notes/fix-local-context-overwrite-94190ba06a481631.yaml b/releasenotes/notes/fix-local-context-overwrite-94190ba06a481631.yaml
new file mode 100644
index 0000000..ff2ba7e
--- /dev/null
+++ b/releasenotes/notes/fix-local-context-overwrite-94190ba06a481631.yaml
@@ -0,0 +1,4 @@
+---
+fixes:
+ - |
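+ Fix the retry decorator overwriting its per-call local context (such as ``statistics``) when the same decorated function is called again before a previous call has finished, for example from inside its own ``retry`` callback.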
diff --git a/tenacity/__init__.py b/tenacity/__init__.py
index 7de36d4..06251ed 100644
--- a/tenacity/__init__.py
+++ b/tenacity/__init__.py
@@ -329,13 +329,24 @@
             f, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__")
         )
         def wrapped_f(*args: t.Any, **kw: t.Any) -> t.Any:
-            return self(f, *args, **kw)
+            # Always create a copy to prevent overwriting the local contexts when
+            # calling the same wrapped functions multiple times in the same stack
+            copy = self.copy()
+            wrapped_f.statistics = copy.statistics  # type: ignore[attr-defined]
+            # Mirror the live per-call statistics onto the shared Retrying
+            # object so that ``wrapped_f.retry.statistics`` keeps working.
+            self._local.statistics = copy.statistics
+            return copy(f, *args, **kw)

         def retry_with(*args: t.Any, **kwargs: t.Any) -> WrappedFn:
             return self.copy(*args, **kwargs).wraps(f)

+        # Preserve attributes. ``retry`` stays the Retrying instance itself (not
+        # the wrapper) so callers keep the full Retrying API on it, as before;
+        # each individual call runs on a fresh ``self.copy()``.
         wrapped_f.retry = self  # type: ignore[attr-defined]
         wrapped_f.retry_with = retry_with  # type: ignore[attr-defined]
+        wrapped_f.statistics = {}  # type: ignore[attr-defined]

         return wrapped_f  # type: ignore[return-value]

diff --git a/tenacity/asyncio/__init__.py b/tenacity/asyncio/__init__.py
index 6d63ebc..38b76c7 100644
--- a/tenacity/asyncio/__init__.py
+++ b/tenacity/asyncio/__init__.py
@@ -175,18 +175,28 @@
                 raise StopAsyncIteration

     def wraps(self, fn: WrappedFn) -> WrappedFn:
-        fn = super().wraps(fn)
+        # The synchronous wrapper is only needed for its ``retry_with``.
+        wrapped = super().wraps(fn)
         # Ensure wrapper is recognized as a coroutine function.

         @functools.wraps(
             fn, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__")
         )
         async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
-            return await fn(*args, **kwargs)
+            # Always create a copy to prevent overwriting the local contexts when
+            # calling the same wrapped functions multiple times in the same stack
+            copy = self.copy()
+            async_wrapped.statistics = copy.statistics  # type: ignore[attr-defined]
+            # Mirror the live per-call statistics onto the shared retrying object
+            # so that ``async_wrapped.retry.statistics`` keeps working.
+            self._local.statistics = copy.statistics
+            return await copy(fn, *args, **kwargs)

         # Preserve attributes
-        async_wrapped.retry = fn.retry  # type: ignore[attr-defined]
-        async_wrapped.retry_with = fn.retry_with  # type: ignore[attr-defined]
+        # ``retry`` stays the retrying object itself, exactly as ``fn.retry`` was.
+        async_wrapped.retry = self  # type: ignore[attr-defined]
+        async_wrapped.retry_with = wrapped.retry_with  # type: ignore[attr-defined]
+        async_wrapped.statistics = {}  # type: ignore[attr-defined]

         return async_wrapped  # type: ignore[return-value]

diff --git a/tests/test_issue_478.py b/tests/test_issue_478.py
new file mode 100644
index 0000000..7489ad7
--- /dev/null
+++ b/tests/test_issue_478.py
@@ -0,0 +1,142 @@
+import asyncio
+import typing
+import unittest
+
+from functools import wraps
+
+from tenacity import RetryCallState, retry
+
+
+def asynctest(
+    callable_: typing.Callable[..., typing.Any],
+) -> typing.Callable[..., typing.Any]:
+    """Run an ``async def`` test method to completion on a private event loop.
+
+    A fresh loop is created and closed for every call instead of using
+    ``asyncio.get_event_loop()``, which is deprecated when no loop is set.
+    The global "current event loop" setting is never modified, so other test
+    modules are unaffected.
+    """
+
+    @wraps(callable_)
+    def wrapper(*a: typing.Any, **kw: typing.Any) -> typing.Any:
+        loop = asyncio.new_event_loop()
+        try:
+            return loop.run_until_complete(callable_(*a, **kw))
+        finally:
+            loop.close()
+
+    return wrapper
+
+
+# Attempt count at which the retry callback stops retrying the outer
+# (non-"Fix") call.
+MAX_RETRY_FIX_ATTEMPTS = 2
+
+
+class TestIssue478(unittest.TestCase):
+    """Re-entrant calls to the same ``@retry``-decorated function.
+
+    The outer call's retry callback invokes the *same* decorated function
+    (the "Fix" call). Each call must keep its own retry state, so the outer
+    call still gives up after ``MAX_RETRY_FIX_ATTEMPTS`` attempts.
+    """
+
+    def test_issue(self) -> None:
+        results: typing.List[str] = []
+
+        def do_retry(retry_state: RetryCallState) -> bool:
+            outcome = retry_state.outcome
+            assert outcome  # narrows Optional for type checkers
+            ex = outcome.exception()
+            _subject_: str = retry_state.args[0]
+
+            if _subject_ == "Fix":  # no retry on fix failure
+                return False
+
+            if retry_state.attempt_number >= MAX_RETRY_FIX_ATTEMPTS:
+                return False
+
+            if ex:
+                # Re-enter the same decorated function from inside its own
+                # retry callback.
+                do_fix_work()
+                return True
+
+            return False
+
+        @retry(reraise=True, retry=do_retry)
+        def _do_work(subject: str) -> None:
+            if subject == "Error":
+                results.append(f"{subject} is not working")
+                raise Exception(f"{subject} is not working")
+            results.append(f"{subject} is working")
+
+        def do_any_work(subject: str) -> None:
+            _do_work(subject)
+
+        def do_fix_work() -> None:
+            _do_work("Fix")
+
+        with self.assertRaises(Exception) as ctx:
+            do_any_work("Error")
+        self.assertEqual(str(ctx.exception), "Error is not working")
+
+        self.assertEqual(
+            results,
+            [
+                "Error is not working",
+                "Fix is working",
+                "Error is not working",
+            ],
+        )
+
+    @asynctest
+    async def test_async(self) -> None:
+        results: typing.List[str] = []
+
+        async def do_retry(retry_state: RetryCallState) -> bool:
+            outcome = retry_state.outcome
+            assert outcome  # narrows Optional for type checkers
+            ex = outcome.exception()
+            _subject_: str = retry_state.args[0]
+
+            if _subject_ == "Fix":  # no retry on fix failure
+                return False
+
+            if retry_state.attempt_number >= MAX_RETRY_FIX_ATTEMPTS:
+                return False
+
+            if ex:
+                # Re-enter the same decorated coroutine function from inside
+                # its own retry callback.
+                await do_fix_work()
+                return True
+
+            return False
+
+        @retry(reraise=True, retry=do_retry)
+        async def _do_work(subject: str) -> None:
+            if subject == "Error":
+                results.append(f"{subject} is not working")
+                raise Exception(f"{subject} is not working")
+            results.append(f"{subject} is working")
+
+        async def do_any_work(subject: str) -> None:
+            await _do_work(subject)
+
+        async def do_fix_work() -> None:
+            await _do_work("Fix")
+
+        with self.assertRaises(Exception) as ctx:
+            await do_any_work("Error")
+        self.assertEqual(str(ctx.exception), "Error is not working")
+
+        self.assertEqual(
+            results,
+            [
+                "Error is not working",
+                "Fix is working",
+                "Error is not working",
+            ],
+        )