blob: 6f20494c94f6e0572da320a4a0d22e5bb1549ce7 [file] [log] [blame]
import sys
if sys.version_info[:2] < (3, 8):
import asyncio, functools
from asyncio.coroutines import _is_coroutine
from inspect import ismethod, isfunction, CO_COROUTINE
from unittest import TestCase
def _unwrap_partial(func):
while isinstance(func, functools.partial):
func = func.func
return func
def _has_code_flag(f, flag):
    """Tell whether *f* is a Python function whose code object sets *flag*.

    Bound methods are first reduced to their underlying ``__func__`` and
    ``functools.partial`` wrappers to the callable they wrap. Anything
    that is still not a plain Python function afterwards yields ``False``.
    """
    target = f
    while ismethod(target):
        target = target.__func__
    target = _unwrap_partial(target)
    if isfunction(target):
        return bool(target.__code__.co_flags & flag)
    return False
def iscoroutinefunction(obj):
    """Return true if the object is a coroutine function.

    *obj* qualifies in either of two cases:

    - it is a function, possibly wrapped in a method or
      ``functools.partial``, whose code object has ``CO_COROUTINE`` set
      (i.e. it was defined with "async def");
    - its ``_is_coroutine`` attribute is the marker object imported from
      ``asyncio.coroutines``.
    """
    if _has_code_flag(obj, CO_COROUTINE):
        return True
    marker = getattr(obj, '_is_coroutine', None)
    return marker is _is_coroutine
class IsolatedAsyncioTestCase(TestCase):
    """Minimal stand-in for ``unittest.IsolatedAsyncioTestCase`` on Python < 3.8.

    Each call to :meth:`run` creates a fresh event loop with debug mode on.
    That loop is installed as the current event loop for the duration of
    the test and is closed afterwards. Any tasks still pending at the end
    are cancelled first.

    NOTE(review): unlike the stdlib class, nothing here awaits coroutine
    test methods, ``asyncSetUp``/``asyncTearDown`` or async cleanups. Only
    the per-test event loop is provided.
    """

    def __init__(self, methodName='runTest'):
        super().__init__(methodName)
        self._asyncioTestLoop = None
        self._asyncioCallsQueue = None
        self._asyncioCallsTask = None

    async def _asyncioLoopRunner(self, fut):
        """Create the calls queue on the test loop and wait for shutdown.

        Readiness is signalled through *fut*. The coroutine then returns
        once the ``None`` sentinel from :meth:`_tearDownAsyncioLoop` arrives,
        so its task is finished (not left pending) when the loop is closed.
        """
        self._asyncioCallsQueue = queue = asyncio.Queue()
        fut.set_result(None)
        query = await queue.get()
        queue.task_done()
        # The shutdown sentinel is the only item this class ever queues.
        assert query is None, query

    def _setupAsyncioLoop(self):
        """Create, install and prime the per-test event loop."""
        assert self._asyncioTestLoop is None
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.set_debug(True)
        self._asyncioTestLoop = loop
        fut = loop.create_future()
        self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut))
        # Run until the runner has created the queue and is waiting on it.
        loop.run_until_complete(fut)

    def _cancelAsyncioTasks(self, loop):
        """Cancel every unfinished task on *loop* and wait for them to end.

        Exceptions other than cancellation are passed to the loop's
        exception handler instead of being silently dropped.
        """
        to_cancel = asyncio.all_tasks(loop)
        if not to_cancel:
            return
        for task in to_cancel:
            task.cancel()
        loop.run_until_complete(
            asyncio.gather(*to_cancel, return_exceptions=True))
        for task in to_cancel:
            if task.cancelled():
                continue
            if task.exception() is not None:
                loop.call_exception_handler({
                    'message': 'unhandled exception during test shutdown',
                    'exception': task.exception(),
                    'task': task,
                })

    def _tearDownAsyncioLoop(self):
        """Stop the runner, clean up remaining tasks and close the loop."""
        assert self._asyncioTestLoop is not None
        loop = self._asyncioTestLoop
        self._asyncioTestLoop = None
        # Wake the runner with the shutdown sentinel and wait until it is consumed.
        self._asyncioCallsQueue.put_nowait(None)
        loop.run_until_complete(self._asyncioCallsQueue.join())
        try:
            self._cancelAsyncioTasks(loop)
            # shutdown asyncgens
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            loop.close()

    def run(self, result=None):
        """Run the test inside a fresh event loop that is always torn down."""
        self._setupAsyncioLoop()
        try:
            return super().run(result)
        finally:
            self._tearDownAsyncioLoop()
else:
from asyncio import iscoroutinefunction
from unittest import IsolatedAsyncioTestCase