blob: 982ec90333416ad78a0b9c1d1fab3ae40479a970 [file]
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests metadata flags feature by testing wait-for-ready semantics"""
import logging
import socket
import threading
import time
import unittest
import weakref
import grpc
from six.moves import queue
from tests.unit import test_common
import tests.unit.framework.common
from tests.unit.framework.common import get_socket
from tests.unit.framework.common import test_constants
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x00\x00\x00'
def handle_unary_unary(test, request, servicer_context):
    """Serve a unary-unary RPC by replying with the canned response.

    The `test`, `request` and `servicer_context` arguments are accepted only
    to satisfy the handler signature; none of them is inspected.
    """
    canned_response = _RESPONSE
    return canned_response
def handle_unary_stream(test, request, servicer_context):
    """Serve a unary-stream RPC with a fixed-length stream of responses.

    Yields the canned response `test_constants.STREAM_LENGTH` times. The
    arguments are not inspected.
    """
    remaining = test_constants.STREAM_LENGTH
    while remaining > 0:
        yield _RESPONSE
        remaining -= 1
def handle_stream_unary(test, request_iterator, servicer_context):
    """Serve a stream-unary RPC: drain the requests, then reply once.

    The request payloads are discarded; only exhausting the iterator matters.
    `test` and `servicer_context` are not inspected.
    """
    for _unused_request in request_iterator:
        continue
    return _RESPONSE
def handle_stream_stream(test, request_iterator, servicer_context):
    """Serve a stream-stream RPC by answering every request as it arrives.

    Yields one canned response per item pulled from `request_iterator`; the
    request payloads, `test` and `servicer_context` are not inspected.
    """
    for _unused_request in request_iterator:
        yield _RESPONSE
class _MethodHandler(grpc.RpcMethodHandler):
    """Method handler that exposes exactly one of the four RPC behaviors.

    The behavior is selected from the request/response streaming flags and
    delegates to the matching module-level `handle_*` function, passing `test`
    through as the first argument.
    """

    def __init__(self, test, request_streaming, response_streaming):
        self.request_streaming = request_streaming
        self.response_streaming = response_streaming
        # No (de)serializer is installed on this handler.
        self.request_deserializer = None
        self.response_serializer = None
        # All behaviors start unset; exactly one is filled in below.
        self.unary_unary = None
        self.unary_stream = None
        self.stream_unary = None
        self.stream_stream = None

        def _unary_unary(request, context):
            return handle_unary_unary(test, request, context)

        def _unary_stream(request, context):
            return handle_unary_stream(test, request, context)

        def _stream_unary(request_iterator, context):
            return handle_stream_unary(test, request_iterator, context)

        def _stream_stream(request_iterator, context):
            return handle_stream_stream(test, request_iterator, context)

        if self.request_streaming:
            if self.response_streaming:
                self.stream_stream = _stream_stream
            else:
                self.stream_unary = _stream_unary
        elif self.response_streaming:
            self.unary_stream = _unary_stream
        else:
            self.unary_unary = _unary_unary
class _GenericHandler(grpc.GenericRpcHandler):
    """Generic handler that routes the four test methods by method name."""

    def __init__(self, test):
        self._test = test

    def service(self, handler_call_details):
        """Return a `_MethodHandler` for a known method name, else None."""
        # Maps method name -> (request_streaming, response_streaming).
        streaming_by_method = {
            _UNARY_UNARY: (False, False),
            _UNARY_STREAM: (False, True),
            _STREAM_UNARY: (True, False),
            _STREAM_STREAM: (True, True),
        }
        flags = streaming_by_method.get(handler_call_details.method)
        if flags is None:
            return None
        request_streaming, response_streaming = flags
        return _MethodHandler(self._test, request_streaming,
                              response_streaming)
def create_phony_channel():
    """Return an insecure channel aimed at a port whose socket was just closed.

    Creating phony channels is a workaround for retries. The socket obtained
    from `get_socket` is closed before the channel is created, so it is not
    held open by this function when the channel tries to connect.
    """
    host, port, placeholder_sock = get_socket(
        sock_options=(socket.SO_REUSEADDR,))
    placeholder_sock.close()
    target = '{}:{}'.format(host, port)
    return grpc.insecure_channel(target)
def perform_unary_unary_call(channel, wait_for_ready=None):
    """Invoke the unary-unary method through `__call__`.

    The call uses `test_constants.LONG_TIMEOUT` and forwards `wait_for_ready`
    unchanged. The response is discarded.
    """
    invoke = channel.unary_unary(_UNARY_UNARY).__call__
    invoke(_REQUEST,
           timeout=test_constants.LONG_TIMEOUT,
           wait_for_ready=wait_for_ready)
def perform_unary_unary_with_call(channel, wait_for_ready=None):
    """Invoke the unary-unary method through `with_call`.

    The call uses `test_constants.LONG_TIMEOUT` and forwards `wait_for_ready`
    unchanged. The returned value is discarded.
    """
    multi_callable = channel.unary_unary(_UNARY_UNARY)
    multi_callable.with_call(_REQUEST,
                             timeout=test_constants.LONG_TIMEOUT,
                             wait_for_ready=wait_for_ready)
def perform_unary_unary_future(channel, wait_for_ready=None):
    """Invoke the unary-unary method through `future` and wait on the result.

    Both the RPC and the wait on its result use `test_constants.LONG_TIMEOUT`;
    `wait_for_ready` is forwarded unchanged. The result is discarded.
    """
    multi_callable = channel.unary_unary(_UNARY_UNARY)
    response_future = multi_callable.future(
        _REQUEST,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready)
    response_future.result(timeout=test_constants.LONG_TIMEOUT)
def perform_unary_stream_call(channel, wait_for_ready=None):
    """Invoke the unary-stream method and consume every response.

    The call uses `test_constants.LONG_TIMEOUT` and forwards `wait_for_ready`
    unchanged. Responses are iterated to exhaustion and discarded.
    """
    invoke = channel.unary_stream(_UNARY_STREAM).__call__
    responses = invoke(_REQUEST,
                       timeout=test_constants.LONG_TIMEOUT,
                       wait_for_ready=wait_for_ready)
    for _unused_response in responses:
        continue
def perform_stream_unary_call(channel, wait_for_ready=None):
    """Invoke the stream-unary method through `__call__`.

    Sends `test_constants.STREAM_LENGTH` copies of the canned request, uses
    `test_constants.LONG_TIMEOUT`, and forwards `wait_for_ready` unchanged.
    The response is discarded.
    """
    invoke = channel.stream_unary(_STREAM_UNARY).__call__
    requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
    invoke(requests,
           timeout=test_constants.LONG_TIMEOUT,
           wait_for_ready=wait_for_ready)
def perform_stream_unary_with_call(channel, wait_for_ready=None):
    """Invoke the stream-unary method through `with_call`.

    Sends `test_constants.STREAM_LENGTH` copies of the canned request, uses
    `test_constants.LONG_TIMEOUT`, and forwards `wait_for_ready` unchanged.
    The returned value is discarded.
    """
    multi_callable = channel.stream_unary(_STREAM_UNARY)
    requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
    multi_callable.with_call(requests,
                             timeout=test_constants.LONG_TIMEOUT,
                             wait_for_ready=wait_for_ready)
def perform_stream_unary_future(channel, wait_for_ready=None):
    """Invoke the stream-unary method through `future` and wait on the result.

    Sends `test_constants.STREAM_LENGTH` copies of the canned request. Both
    the RPC and the wait on its result use `test_constants.LONG_TIMEOUT`;
    `wait_for_ready` is forwarded unchanged. The result is discarded.
    """
    multi_callable = channel.stream_unary(_STREAM_UNARY)
    requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
    response_future = multi_callable.future(
        requests,
        timeout=test_constants.LONG_TIMEOUT,
        wait_for_ready=wait_for_ready)
    response_future.result(timeout=test_constants.LONG_TIMEOUT)
def perform_stream_stream_call(channel, wait_for_ready=None):
    """Invoke the stream-stream method and consume every response.

    Sends `test_constants.STREAM_LENGTH` copies of the canned request, uses
    `test_constants.LONG_TIMEOUT`, and forwards `wait_for_ready` unchanged.
    Responses are iterated to exhaustion and discarded.
    """
    invoke = channel.stream_stream(_STREAM_STREAM).__call__
    requests = iter([_REQUEST] * test_constants.STREAM_LENGTH)
    responses = invoke(requests,
                       timeout=test_constants.LONG_TIMEOUT,
                       wait_for_ready=wait_for_ready)
    for _unused_response in responses:
        continue
# Every client invocation style exercised by the tests below. Each entry is
# called as `fn(channel, wait_for_ready)`.
_ALL_CALL_CASES = [
    perform_unary_unary_call,
    perform_unary_unary_with_call,
    perform_unary_unary_future,
    perform_unary_stream_call,
    perform_stream_unary_call,
    perform_stream_unary_with_call,
    perform_stream_unary_future,
    perform_stream_stream_call,
]
class MetadataFlagsTest(unittest.TestCase):
    """Checks fail-fast versus wait-for-ready behavior for each call style."""

    def check_connection_does_failfast(self, fn, channel, wait_for_ready=None):
        """Assert that `fn(channel, wait_for_ready)` fails with UNAVAILABLE.

        Args:
          fn: One of the `perform_*` helpers listed in `_ALL_CALL_CASES`.
          channel: The channel handed to `fn`.
          wait_for_ready: Forwarded to `fn` unchanged.
        """
        # Catch only RPC errors. The failure raised when the call unexpectedly
        # succeeds must not be swallowed by this handler (it has no `code()`),
        # and unrelated exceptions should surface with their own traceback.
        try:
            fn(channel, wait_for_ready)
        except grpc.RpcError as rpc_error:
            self.assertIs(grpc.StatusCode.UNAVAILABLE, rpc_error.code())
        else:
            self.fail("The Call should fail")

    def test_call_wait_for_ready_default(self):
        """With the flag left unset, every call style must fail fast."""
        for perform_call in _ALL_CALL_CASES:
            with create_phony_channel() as channel:
                self.check_connection_does_failfast(perform_call, channel)

    def test_call_wait_for_ready_disabled(self):
        """With wait_for_ready=False, every call style must fail fast."""
        for perform_call in _ALL_CALL_CASES:
            with create_phony_channel() as channel:
                self.check_connection_does_failfast(perform_call,
                                                    channel,
                                                    wait_for_ready=False)

    def test_call_wait_for_ready_enabled(self):
        """With wait_for_ready=True, calls succeed once a server appears."""
        # To test the wait mechanism, Python threads are used so that all
        # clients are set up first without handling them case by case.
        # Python threads do not pass unhandled exceptions to the main
        # thread, so a queue stores them to be re-raised in the main thread.
        unhandled_exceptions = queue.Queue()

        # We just need an unused TCP port.
        host, port, sock = get_socket(sock_options=(socket.SO_REUSEADDR,))
        sock.close()

        addr = '{}:{}'.format(host, port)
        # One slot per call case; the server is started only after wait()
        # returns. NOTE(review): WaitGroup is defined in test_common and not
        # visible here -- presumably done() decrements the outstanding count.
        wg = test_common.WaitGroup(len(_ALL_CALL_CASES))

        def wait_for_transient_failure(channel_connectivity):
            if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE:
                wg.done()

        def test_call(perform_call):
            with grpc.insecure_channel(addr) as channel:
                try:
                    channel.subscribe(wait_for_transient_failure)
                    perform_call(channel, wait_for_ready=True)
                except BaseException as e:  # pylint: disable=broad-except
                    # If the call failed, the thread would be destroyed. The
                    # channel object can be collected before calling the
                    # callback, which will result in a deadlock.
                    wg.done()
                    unhandled_exceptions.put(e, True)

        test_threads = []
        for perform_call in _ALL_CALL_CASES:
            test_thread = threading.Thread(target=test_call,
                                           args=(perform_call,))
            test_thread.daemon = True
            test_thread.start()
            test_threads.append(test_thread)

        # Start the server after the connections are waiting.
        wg.wait()
        server = test_common.test_server(reuse_port=True)
        # A weak proxy is handed to the handler rather than `self` itself.
        server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),))
        server.add_insecure_port(addr)
        server.start()

        for test_thread in test_threads:
            test_thread.join()

        # Stop the server to make the test end properly.
        server.stop(0)

        # Re-raise the first exception captured in a worker thread, if any.
        if not unhandled_exceptions.empty():
            raise unhandled_exceptions.get(True)
if __name__ == '__main__':
    # Run as a script: enable DEBUG logging and use the verbose test runner.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main(verbosity=2)