blob: 7367c454fad5c3ab0e1602d0efdacbae4a211d05 [file]
# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import unittest
import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests_aio.unit import _common
from tests_aio.unit import _constants
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import _INITIAL_METADATA_KEY
from tests_aio.unit._test_server import _TRAILING_METADATA_KEY
from tests_aio.unit._test_server import start_test_server
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_INITIAL_METADATA_TO_INJECT = aio.Metadata(
(_INITIAL_METADATA_KEY, 'extra info'),
(_TRAILING_METADATA_KEY, b'\x13\x37'),
)
_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED = 1.0
class TestUnaryUnaryClientInterceptor(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
def test_invalid_interceptor(self):
class InvalidInterceptor:
"""Just an invalid Interceptor"""
with self.assertRaises(ValueError):
aio.insecure_channel("", interceptors=[InvalidInterceptor()])
async def test_executed_right_order(self):
interceptors_executed = []
class Interceptor(aio.UnaryUnaryClientInterceptor):
"""Interceptor used for testing if the interceptor is being called"""
async def intercept_unary_unary(self, continuation,
client_call_details, request):
interceptors_executed.append(self)
call = await continuation(client_call_details, request)
return call
interceptors = [Interceptor() for i in range(2)]
async with aio.insecure_channel(self._server_target,
interceptors=interceptors) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
# Check that all interceptors were executed, and were executed
# in the right order.
self.assertSequenceEqual(interceptors_executed, interceptors)
self.assertIsInstance(response, messages_pb2.SimpleResponse)
@unittest.expectedFailure
# TODO(https://github.com/grpc/grpc/issues/20144) Once metadata support is
# implemented in the client-side, this test must be implemented.
def test_modify_metadata(self):
raise NotImplementedError()
@unittest.expectedFailure
# TODO(https://github.com/grpc/grpc/issues/20532) Once credentials support is
# implemented in the client-side, this test must be implemented.
def test_modify_credentials(self):
raise NotImplementedError()
async def test_status_code_Ok(self):
class StatusCodeOkInterceptor(aio.UnaryUnaryClientInterceptor):
"""Interceptor used for observing status code Ok returned by the RPC"""
def __init__(self):
self.status_code_Ok_observed = False
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
code = await call.code()
if code == grpc.StatusCode.OK:
self.status_code_Ok_observed = True
return call
interceptor = StatusCodeOkInterceptor()
async with aio.insecure_channel(self._server_target,
interceptors=[interceptor]) as channel:
# when no error StatusCode.OK must be observed
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
await multicallable(messages_pb2.SimpleRequest())
self.assertTrue(interceptor.status_code_Ok_observed)
async def test_add_timeout(self):
class TimeoutInterceptor(aio.UnaryUnaryClientInterceptor):
"""Interceptor used for adding a timeout to the RPC"""
async def intercept_unary_unary(self, continuation,
client_call_details, request):
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready)
return await continuation(new_client_call_details, request)
interceptor = TimeoutInterceptor()
async with aio.insecure_channel(self._server_target,
interceptors=[interceptor]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCallWithSleep',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertEqual(exception_context.exception.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await
call.code())
async def test_retry(self):
class RetryInterceptor(aio.UnaryUnaryClientInterceptor):
"""Simulates a Retry Interceptor which ends up by making
two RPC calls."""
def __init__(self):
self.calls = []
async def intercept_unary_unary(self, continuation,
client_call_details, request):
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready)
try:
call = await continuation(new_client_call_details, request)
await call
except grpc.RpcError:
pass
self.calls.append(call)
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=None,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready)
call = await continuation(new_client_call_details, request)
self.calls.append(call)
return call
interceptor = RetryInterceptor()
async with aio.insecure_channel(self._server_target,
interceptors=[interceptor]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCallWithSleep',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
await call
self.assertEqual(grpc.StatusCode.OK, await call.code())
# Check that two calls were made, first one finishing with
# a deadline and second one finishing ok..
self.assertEqual(len(interceptor.calls), 2)
self.assertEqual(await interceptor.calls[0].code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await interceptor.calls[1].code(),
grpc.StatusCode.OK)
async def test_rpcresponse(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
"""Raw responses are seen as reegular calls"""
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
response = await call
return call
class ResponseInterceptor(aio.UnaryUnaryClientInterceptor):
"""Return a raw response"""
response = messages_pb2.SimpleResponse()
async def intercept_unary_unary(self, continuation,
client_call_details, request):
return ResponseInterceptor.response
interceptor, interceptor_response = Interceptor(), ResponseInterceptor()
async with aio.insecure_channel(
self._server_target,
interceptors=[interceptor, interceptor_response]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
# Check that the response returned is the one returned by the
# interceptor
self.assertEqual(id(response), id(ResponseInterceptor.response))
# Check all of the UnaryUnaryCallResponse attributes
self.assertTrue(call.done())
self.assertFalse(call.cancel())
self.assertFalse(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.details(), '')
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
self.assertEqual(await call.debug_error_string(), None)
class TestInterceptedUnaryUnaryCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
async def test_call_ok(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertEqual(type(response), messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.details(), '')
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
async def test_call_ok_awaited(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
await call
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertEqual(type(response), messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.details(), '')
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
async def test_call_rpc_error(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCallWithSleep',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(
messages_pb2.SimpleRequest(),
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertEqual(await call.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await call.details(), 'Deadline Exceeded')
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
async def test_call_rpc_error_awaited(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
await call
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCallWithSleep',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(
messages_pb2.SimpleRequest(),
timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertEqual(await call.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await call.details(), 'Deadline Exceeded')
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(await call.trailing_metadata(), aio.Metadata())
async def test_cancel_before_rpc(self):
interceptor_reached = asyncio.Event()
wait_for_ever = self.loop.create_future()
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
interceptor_reached.set()
await wait_for_ever
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
self.assertFalse(call.cancelled())
self.assertFalse(call.done())
await interceptor_reached.wait()
self.assertTrue(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
async def test_cancel_after_rpc(self):
interceptor_reached = asyncio.Event()
wait_for_ever = self.loop.create_future()
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
await call
interceptor_reached.set()
await wait_for_ever
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
self.assertFalse(call.cancelled())
self.assertFalse(call.done())
await interceptor_reached.wait()
self.assertTrue(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
async def test_cancel_inside_interceptor_after_rpc_awaiting(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
call.cancel()
await call
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
async def test_cancel_inside_interceptor_after_rpc_not_awaiting(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
call.cancel()
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), aio.Metadata())
self.assertEqual(
await call.trailing_metadata(), aio.Metadata(),
"When the raw response is None, empty metadata is returned")
async def test_initial_metadata_modification(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
new_metadata = aio.Metadata(*client_call_details.metadata,
*_INITIAL_METADATA_TO_INJECT)
new_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=client_call_details.timeout,
metadata=new_metadata,
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready,
)
return await continuation(new_details, request)
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.UnaryCall(messages_pb2.SimpleRequest())
# Expected to see the echoed initial metadata
self.assertTrue(
_common.seen_metadatum(
expected_key=_INITIAL_METADATA_KEY,
expected_value=_INITIAL_METADATA_TO_INJECT[
_INITIAL_METADATA_KEY],
actual=await call.initial_metadata(),
))
# Expected to see the echoed trailing metadata
self.assertTrue(
_common.seen_metadatum(
expected_key=_TRAILING_METADATA_KEY,
expected_value=_INITIAL_METADATA_TO_INJECT[
_TRAILING_METADATA_KEY],
actual=await call.trailing_metadata(),
))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_add_done_callback_before_finishes(self):
called = asyncio.Event()
interceptor_can_continue = asyncio.Event()
def callback(call):
called.set()
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
await interceptor_can_continue.wait()
call = await continuation(client_call_details, request)
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
call.add_done_callback(callback)
interceptor_can_continue.set()
await call
try:
await asyncio.wait_for(
called.wait(),
timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED)
except:
self.fail("Callback was not called")
async def test_add_done_callback_after_finishes(self):
called = asyncio.Event()
def callback(call):
called.set()
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
await call
call.add_done_callback(callback)
try:
await asyncio.wait_for(
called.wait(),
timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED)
except:
self.fail("Callback was not called")
async def test_add_done_callback_after_finishes_before_await(self):
called = asyncio.Event()
def callback(call):
called.set()
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
call.add_done_callback(callback)
await call
try:
await asyncio.wait_for(
called.wait(),
timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED)
except:
self.fail("Callback was not called")
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)