Automated change: Fix sanity tests
diff --git a/src/core/ext/transport/chaotic_good/server_transport.h b/src/core/ext/transport/chaotic_good/server_transport.h index b5d0cf9..9ce9292 100644 --- a/src/core/ext/transport/chaotic_good/server_transport.h +++ b/src/core/ext/transport/chaotic_good/server_transport.h
@@ -15,11 +15,8 @@ #ifndef GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_SERVER_TRANSPORT_H #define GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_SERVER_TRANSPORT_H -#include <grpc/event_engine/event_engine.h> -#include <grpc/event_engine/memory_allocator.h> -#include <grpc/slice.h> -#include <grpc/support/log.h> #include <grpc/support/port_platform.h> + #include <stdint.h> #include <stdio.h> @@ -42,6 +39,12 @@ #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/variant.h" + +#include <grpc/event_engine/event_engine.h> +#include <grpc/event_engine/memory_allocator.h> +#include <grpc/slice.h> +#include <grpc/support/log.h> + #include "src/core/ext/transport/chaotic_good/chaotic_good_transport.h" #include "src/core/ext/transport/chaotic_good/frame.h" #include "src/core/ext/transport/chaotic_good/frame_header.h"
diff --git a/src/core/lib/event_engine/trace.cc b/src/core/lib/event_engine/trace.cc index 33b2fa7..20ae3ae 100644 --- a/src/core/lib/event_engine/trace.cc +++ b/src/core/lib/event_engine/trace.cc
@@ -11,10 +11,10 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "src/core/lib/debug/trace.h" - #include <grpc/support/port_platform.h> +#include "src/core/lib/debug/trace.h" + grpc_core::TraceFlag grpc_event_engine_trace(false, "event_engine"); grpc_core::TraceFlag grpc_event_engine_dns_trace(false, "event_engine_dns"); grpc_core::TraceFlag grpc_event_engine_endpoint_trace(false,
diff --git a/src/core/lib/security/authorization/authorization_policy_provider_vtable.cc b/src/core/lib/security/authorization/authorization_policy_provider_vtable.cc index ddf1a12..75498eb 100644 --- a/src/core/lib/security/authorization/authorization_policy_provider_vtable.cc +++ b/src/core/lib/security/authorization/authorization_policy_provider_vtable.cc
@@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <grpc/support/port_platform.h> + #include <grpc/grpc.h> #include <grpc/grpc_security.h> -#include <grpc/support/port_platform.h> #include "src/core/lib/gpr/useful.h" #include "src/core/lib/gprpp/ref_counted_ptr.h"
diff --git a/src/core/lib/security/certificate_provider/certificate_provider_registry.h b/src/core/lib/security/certificate_provider/certificate_provider_registry.h index 6368845..5cc05ae 100644 --- a/src/core/lib/security/certificate_provider/certificate_provider_registry.h +++ b/src/core/lib/security/certificate_provider/certificate_provider_registry.h
@@ -26,6 +26,7 @@ #include <utility> #include "absl/strings/string_view.h" + #include "src/core/lib/security/certificate_provider/certificate_provider_factory.h" namespace grpc_core {
diff --git a/src/core/lib/security/credentials/insecure/insecure_credentials.cc b/src/core/lib/security/credentials/insecure/insecure_credentials.cc index e996072..cce66f6 100644 --- a/src/core/lib/security/credentials/insecure/insecure_credentials.cc +++ b/src/core/lib/security/credentials/insecure/insecure_credentials.cc
@@ -16,10 +16,10 @@ // // -#include "src/core/lib/security/credentials/insecure/insecure_credentials.h" - #include <grpc/support/port_platform.h> +#include "src/core/lib/security/credentials/insecure/insecure_credentials.h" + #include <utility> #include "src/core/lib/channel/channel_args.h"
diff --git a/src/core/lib/transport/promise_endpoint.h b/src/core/lib/transport/promise_endpoint.h index 36c3713..fbdc467 100644 --- a/src/core/lib/transport/promise_endpoint.h +++ b/src/core/lib/transport/promise_endpoint.h
@@ -15,12 +15,8 @@ #ifndef GRPC_SRC_CORE_LIB_TRANSPORT_PROMISE_ENDPOINT_H #define GRPC_SRC_CORE_LIB_TRANSPORT_PROMISE_ENDPOINT_H -#include <grpc/event_engine/event_engine.h> -#include <grpc/event_engine/slice.h> -#include <grpc/event_engine/slice_buffer.h> -#include <grpc/slice_buffer.h> -#include <grpc/support/log.h> #include <grpc/support/port_platform.h> + #include <stddef.h> #include <stdint.h> @@ -33,6 +29,13 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" + +#include <grpc/event_engine/event_engine.h> +#include <grpc/event_engine/slice.h> +#include <grpc/event_engine/slice_buffer.h> +#include <grpc/slice_buffer.h> +#include <grpc/support/log.h> + #include "src/core/lib/gprpp/sync.h" #include "src/core/lib/promise/activity.h" #include "src/core/lib/promise/if.h"
diff --git a/src/core/tsi/transport_security_grpc.cc b/src/core/tsi/transport_security_grpc.cc index 4d120d9..3f04ce9 100644 --- a/src/core/tsi/transport_security_grpc.cc +++ b/src/core/tsi/transport_security_grpc.cc
@@ -16,10 +16,10 @@ // // -#include "src/core/tsi/transport_security_grpc.h" - #include <grpc/support/port_platform.h> +#include "src/core/tsi/transport_security_grpc.h" + // This method creates a tsi_zero_copy_grpc_protector object. tsi_result tsi_handshaker_result_create_zero_copy_grpc_protector( const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
diff --git a/src/python/grpcio_tests/tests/qps/benchmark_client.py b/src/python/grpcio_tests/tests/qps/benchmark_client.py index 1a44fc7..9310f8a 100644 --- a/src/python/grpcio_tests/tests/qps/benchmark_client.py +++ b/src/python/grpcio_tests/tests/qps/benchmark_client.py
@@ -20,6 +20,7 @@ import time import grpc + from src.proto.grpc.testing import benchmark_service_pb2_grpc from src.proto.grpc.testing import messages_pb2 from tests.unit import resources @@ -29,229 +30,229 @@ class GenericStub(object): - - def __init__(self, channel): - self.UnaryCall = channel.unary_unary( - "/grpc.testing.BenchmarkService/UnaryCall", - _registered_method=True, - ) - self.StreamingFromServer = channel.unary_stream( - "/grpc.testing.BenchmarkService/StreamingFromServer", - _registered_method=True, - ) - self.StreamingCall = channel.stream_stream( - "/grpc.testing.BenchmarkService/StreamingCall", - _registered_method=True, - ) + def __init__(self, channel): + self.UnaryCall = channel.unary_unary( + "/grpc.testing.BenchmarkService/UnaryCall", + _registered_method=True, + ) + self.StreamingFromServer = channel.unary_stream( + "/grpc.testing.BenchmarkService/StreamingFromServer", + _registered_method=True, + ) + self.StreamingCall = channel.stream_stream( + "/grpc.testing.BenchmarkService/StreamingCall", + _registered_method=True, + ) class BenchmarkClient: - """Benchmark client interface that exposes a non-blocking send_request().""" + """Benchmark client interface that exposes a non-blocking send_request().""" - __metaclass__ = abc.ABCMeta + __metaclass__ = abc.ABCMeta - def __init__(self, server, config, hist): - # Create the stub - if config.HasField("security_params"): - creds = grpc.ssl_channel_credentials(resources.test_root_certificates()) - channel = test_common.test_secure_channel( - server, creds, config.security_params.server_host_override - ) - else: - channel = grpc.insecure_channel(server) + def __init__(self, server, config, hist): + # Create the stub + if config.HasField("security_params"): + creds = grpc.ssl_channel_credentials( + resources.test_root_certificates() + ) + channel = test_common.test_secure_channel( + server, creds, config.security_params.server_host_override + ) + else: + channel = grpc.insecure_channel(server) - # waits for the channel to be ready before we start sending messages - grpc.channel_ready_future(channel).result() + # waits for the channel to be ready before we start sending messages + grpc.channel_ready_future(channel).result() - if config.payload_config.WhichOneof("payload") == "simple_params": - self._generic = False - self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(channel) - payload = messages_pb2.Payload( - body=bytes(b"\0" * config.payload_config.simple_params.req_size) - ) - self._request = messages_pb2.SimpleRequest( - payload=payload, - response_size=config.payload_config.simple_params.resp_size, - ) - else: - self._generic = True - self._stub = GenericStub(channel) - self._request = bytes( - b"\0" * config.payload_config.bytebuf_params.req_size - ) + if config.payload_config.WhichOneof("payload") == "simple_params": + self._generic = False + self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub( + channel + ) + payload = messages_pb2.Payload( + body=bytes(b"\0" * config.payload_config.simple_params.req_size) + ) + self._request = messages_pb2.SimpleRequest( + payload=payload, + response_size=config.payload_config.simple_params.resp_size, + ) + else: + self._generic = True + self._stub = GenericStub(channel) + self._request = bytes( + b"\0" * config.payload_config.bytebuf_params.req_size + ) - self._hist = hist - self._response_callbacks = [] + self._hist = hist + self._response_callbacks = [] - def add_response_callback(self, callback): - """callback will be invoked as callback(client, query_time)""" - self._response_callbacks.append(callback) + def add_response_callback(self, callback): + """callback will be invoked as callback(client, query_time)""" + self._response_callbacks.append(callback) - @abc.abstractmethod - def send_request(self): - """Non-blocking wrapper for a client's request operation.""" - raise NotImplementedError() + @abc.abstractmethod + def send_request(self): + """Non-blocking wrapper for a client's request operation.""" + raise NotImplementedError() - def start(self): - pass + def start(self): + pass - def stop(self): - pass + def stop(self): + pass - def _handle_response(self, client, query_time): - self._hist.add(query_time * 1e9) # Report times in nanoseconds - for callback in self._response_callbacks: - callback(client, query_time) + def _handle_response(self, client, query_time): + self._hist.add(query_time * 1e9) # Report times in nanoseconds + for callback in self._response_callbacks: + callback(client, query_time) class UnarySyncBenchmarkClient(BenchmarkClient): + def __init__(self, server, config, hist): + super(UnarySyncBenchmarkClient, self).__init__(server, config, hist) + self._pool = futures.ThreadPoolExecutor( + max_workers=config.outstanding_rpcs_per_channel + ) - def __init__(self, server, config, hist): - super(UnarySyncBenchmarkClient, self).__init__(server, config, hist) - self._pool = futures.ThreadPoolExecutor( - max_workers=config.outstanding_rpcs_per_channel - ) + def send_request(self): + # Send requests in separate threads to support multiple outstanding rpcs + # (See src/proto/grpc/testing/control.proto) + self._pool.submit(self._dispatch_request) - def send_request(self): - # Send requests in separate threads to support multiple outstanding rpcs - # (See src/proto/grpc/testing/control.proto) - self._pool.submit(self._dispatch_request) + def stop(self): + self._pool.shutdown(wait=True) + self._stub = None - def stop(self): - self._pool.shutdown(wait=True) - self._stub = None - - def _dispatch_request(self): - start_time = time.time() - self._stub.UnaryCall(self._request, _TIMEOUT) - end_time = time.time() - self._handle_response(self, end_time - start_time) + def _dispatch_request(self): + start_time = time.time() + self._stub.UnaryCall(self._request, _TIMEOUT) + end_time = time.time() + self._handle_response(self, end_time - start_time) class UnaryAsyncBenchmarkClient(BenchmarkClient): + def send_request(self): + # Use the Future callback api to support multiple outstanding rpcs + start_time = time.time() + response_future = self._stub.UnaryCall.future(self._request, _TIMEOUT) + response_future.add_done_callback( + lambda resp: self._response_received(start_time, resp) + ) - def send_request(self): - # Use the Future callback api to support multiple outstanding rpcs - start_time = time.time() - response_future = self._stub.UnaryCall.future(self._request, _TIMEOUT) - response_future.add_done_callback( - lambda resp: self._response_received(start_time, resp) - ) + def _response_received(self, start_time, resp): + resp.result() + end_time = time.time() + self._handle_response(self, end_time - start_time) - def _response_received(self, start_time, resp): - resp.result() - end_time = time.time() - self._handle_response(self, end_time - start_time) - - def stop(self): - self._stub = None + def stop(self): + self._stub = None class _SyncStream(object): + def __init__(self, stub, generic, request, handle_response): + self._stub = stub + self._generic = generic + self._request = request + self._handle_response = handle_response + self._is_streaming = False + self._request_queue = queue.Queue() + self._send_time_queue = queue.Queue() - def __init__(self, stub, generic, request, handle_response): - self._stub = stub - self._generic = generic - self._request = request - self._handle_response = handle_response - self._is_streaming = False - self._request_queue = queue.Queue() - self._send_time_queue = queue.Queue() + def send_request(self): + self._send_time_queue.put(time.time()) + self._request_queue.put(self._request) - def send_request(self): - self._send_time_queue.put(time.time()) - self._request_queue.put(self._request) + def start(self): + self._is_streaming = True + response_stream = self._stub.StreamingCall( + self._request_generator(), _TIMEOUT + ) + for _ in response_stream: + self._handle_response( + self, time.time() - self._send_time_queue.get_nowait() + ) - def start(self): - self._is_streaming = True - response_stream = self._stub.StreamingCall( - self._request_generator(), _TIMEOUT - ) - for _ in response_stream: - self._handle_response( - self, time.time() - self._send_time_queue.get_nowait() - ) + def stop(self): + self._is_streaming = False - def stop(self): - self._is_streaming = False - - def _request_generator(self): - while self._is_streaming: - try: - request = self._request_queue.get(block=True, timeout=1.0) - yield request - except queue.Empty: - pass + def _request_generator(self): + while self._is_streaming: + try: + request = self._request_queue.get(block=True, timeout=1.0) + yield request + except queue.Empty: + pass class StreamingSyncBenchmarkClient(BenchmarkClient): - - def __init__(self, server, config, hist): - super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist) - self._pool = futures.ThreadPoolExecutor( - max_workers=config.outstanding_rpcs_per_channel - ) - self._streams = [ - _SyncStream( - self._stub, self._generic, self._request, self._handle_response + def __init__(self, server, config, hist): + super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist) + self._pool = futures.ThreadPoolExecutor( + max_workers=config.outstanding_rpcs_per_channel ) - for _ in range(config.outstanding_rpcs_per_channel) - ] - self._curr_stream = 0 + self._streams = [ + _SyncStream( + self._stub, self._generic, self._request, self._handle_response + ) + for _ in range(config.outstanding_rpcs_per_channel) + ] + self._curr_stream = 0 - def send_request(self): - # Use a round_robin scheduler to determine what stream to send on - self._streams[self._curr_stream].send_request() - self._curr_stream = (self._curr_stream + 1) % len(self._streams) + def send_request(self): + # Use a round_robin scheduler to determine what stream to send on + self._streams[self._curr_stream].send_request() + self._curr_stream = (self._curr_stream + 1) % len(self._streams) - def start(self): - for stream in self._streams: - self._pool.submit(stream.start) + def start(self): + for stream in self._streams: + self._pool.submit(stream.start) - def stop(self): - for stream in self._streams: - stream.stop() - self._pool.shutdown(wait=True) - self._stub = None + def stop(self): + for stream in self._streams: + stream.stop() + self._pool.shutdown(wait=True) + self._stub = None class ServerStreamingSyncBenchmarkClient(BenchmarkClient): + def __init__(self, server, config, hist): + super(ServerStreamingSyncBenchmarkClient, self).__init__( + server, config, hist + ) + if config.outstanding_rpcs_per_channel == 1: + self._pool = None + else: + self._pool = futures.ThreadPoolExecutor( + max_workers=config.outstanding_rpcs_per_channel + ) + self._rpcs = [] + self._sender = None - def __init__(self, server, config, hist): - super(ServerStreamingSyncBenchmarkClient, self).__init__( - server, config, hist - ) - if config.outstanding_rpcs_per_channel == 1: - self._pool = None - else: - self._pool = futures.ThreadPoolExecutor( - max_workers=config.outstanding_rpcs_per_channel - ) - self._rpcs = [] - self._sender = None + def send_request(self): + if self._pool is None: + self._sender = threading.Thread( + target=self._one_stream_streaming_rpc, daemon=True + ) + self._sender.start() + else: + self._pool.submit(self._one_stream_streaming_rpc) - def send_request(self): - if self._pool is None: - self._sender = threading.Thread( - target=self._one_stream_streaming_rpc, daemon=True - ) - self._sender.start() - else: - self._pool.submit(self._one_stream_streaming_rpc) + def _one_stream_streaming_rpc(self): + response_stream = self._stub.StreamingFromServer( + self._request, _TIMEOUT + ) + self._rpcs.append(response_stream) + start_time = time.time() + for _ in response_stream: + self._handle_response(self, time.time() - start_time) + start_time = time.time() - def _one_stream_streaming_rpc(self): - response_stream = self._stub.StreamingFromServer(self._request, _TIMEOUT) - self._rpcs.append(response_stream) - start_time = time.time() - for _ in response_stream: - self._handle_response(self, time.time() - start_time) - start_time = time.time() - - def stop(self): - for call in self._rpcs: - call.cancel() - if self._sender is not None: - self._sender.join() - if self._pool is not None: - self._pool.shutdown(wait=False) - self._stub = None + def stop(self): + for call in self._rpcs: + call.cancel() + if self._sender is not None: + self._sender.join() + if self._pool is not None: + self._pool.shutdown(wait=False) + self._stub = None
diff --git a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py index 8882155..1b7e2e5 100644 --- a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py +++ b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py
@@ -19,6 +19,7 @@ import time import grpc + from tests.unit.framework.common import test_constants WAIT_TIME = 1000 @@ -57,198 +58,195 @@ def hang_unary_unary(request, servicer_context): - time.sleep(WAIT_TIME) + time.sleep(WAIT_TIME) def hang_unary_stream(request, servicer_context): - time.sleep(WAIT_TIME) + time.sleep(WAIT_TIME) def hang_partial_unary_stream(request, servicer_context): - for _ in range(test_constants.STREAM_LENGTH // 2): - yield request - time.sleep(WAIT_TIME) + for _ in range(test_constants.STREAM_LENGTH // 2): + yield request + time.sleep(WAIT_TIME) def hang_stream_unary(request_iterator, servicer_context): - time.sleep(WAIT_TIME) + time.sleep(WAIT_TIME) def hang_partial_stream_unary(request_iterator, servicer_context): - for _ in range(test_constants.STREAM_LENGTH // 2): - next(request_iterator) - time.sleep(WAIT_TIME) + for _ in range(test_constants.STREAM_LENGTH // 2): + next(request_iterator) + time.sleep(WAIT_TIME) def hang_stream_stream(request_iterator, servicer_context): - time.sleep(WAIT_TIME) + time.sleep(WAIT_TIME) def hang_partial_stream_stream(request_iterator, servicer_context): - for _ in range(test_constants.STREAM_LENGTH // 2): - yield next(request_iterator) # pylint: disable=stop-iteration-return - time.sleep(WAIT_TIME) + for _ in range(test_constants.STREAM_LENGTH // 2): + yield next(request_iterator) # pylint: disable=stop-iteration-return + time.sleep(WAIT_TIME) class MethodHandler(grpc.RpcMethodHandler): - - def __init__(self, request_streaming, response_streaming, partial_hang): - self.request_streaming = request_streaming - self.response_streaming = response_streaming - self.request_deserializer = None - self.response_serializer = None - self.unary_unary = None - self.unary_stream = None - self.stream_unary = None - self.stream_stream = None - if self.request_streaming and self.response_streaming: - if partial_hang: - self.stream_stream = hang_partial_stream_stream - else: - self.stream_stream = hang_stream_stream - elif self.request_streaming: - if partial_hang: - self.stream_unary = hang_partial_stream_unary - else: - self.stream_unary = hang_stream_unary - elif self.response_streaming: - if partial_hang: - self.unary_stream = hang_partial_unary_stream - else: - self.unary_stream = hang_unary_stream - else: - self.unary_unary = hang_unary_unary + def __init__(self, request_streaming, response_streaming, partial_hang): + self.request_streaming = request_streaming + self.response_streaming = response_streaming + self.request_deserializer = None + self.response_serializer = None + self.unary_unary = None + self.unary_stream = None + self.stream_unary = None + self.stream_stream = None + if self.request_streaming and self.response_streaming: + if partial_hang: + self.stream_stream = hang_partial_stream_stream + else: + self.stream_stream = hang_stream_stream + elif self.request_streaming: + if partial_hang: + self.stream_unary = hang_partial_stream_unary + else: + self.stream_unary = hang_stream_unary + elif self.response_streaming: + if partial_hang: + self.unary_stream = hang_partial_unary_stream + else: + self.unary_stream = hang_unary_stream + else: + self.unary_unary = hang_unary_unary class GenericHandler(grpc.GenericRpcHandler): - - def service(self, handler_call_details): - if handler_call_details.method == UNARY_UNARY: - return MethodHandler(False, False, False) - elif handler_call_details.method == UNARY_STREAM: - return MethodHandler(False, True, False) - elif handler_call_details.method == STREAM_UNARY: - return MethodHandler(True, False, False) - elif handler_call_details.method == STREAM_STREAM: - return MethodHandler(True, True, False) - elif handler_call_details.method == PARTIAL_UNARY_STREAM: - return MethodHandler(False, True, True) - elif handler_call_details.method == PARTIAL_STREAM_UNARY: - return MethodHandler(True, False, True) - elif handler_call_details.method == PARTIAL_STREAM_STREAM: - return MethodHandler(True, True, True) - else: - return None + def service(self, handler_call_details): + if handler_call_details.method == UNARY_UNARY: + return MethodHandler(False, False, False) + elif handler_call_details.method == UNARY_STREAM: + return MethodHandler(False, True, False) + elif handler_call_details.method == STREAM_UNARY: + return MethodHandler(True, False, False) + elif handler_call_details.method == STREAM_STREAM: + return MethodHandler(True, True, False) + elif handler_call_details.method == PARTIAL_UNARY_STREAM: + return MethodHandler(False, True, True) + elif handler_call_details.method == PARTIAL_STREAM_UNARY: + return MethodHandler(True, False, True) + elif handler_call_details.method == PARTIAL_STREAM_STREAM: + return MethodHandler(True, True, True) + else: + return None # Traditional executors will not exit until all their # current jobs complete. Because we submit jobs that will # never finish, we don't want to block exit on these jobs. class DaemonPool(object): + def submit(self, fn, *args, **kwargs): + thread = threading.Thread(target=fn, args=args, kwargs=kwargs) + thread.daemon = True + thread.start() - def submit(self, fn, *args, **kwargs): - thread = threading.Thread(target=fn, args=args, kwargs=kwargs) - thread.daemon = True - thread.start() - - def shutdown(self, wait=True): - pass + def shutdown(self, wait=True): + pass def infinite_request_iterator(): - while True: - yield REQUEST + while True: + yield REQUEST if __name__ == "__main__": - logging.basicConfig() - parser = argparse.ArgumentParser() - parser.add_argument("scenario", type=str) - parser.add_argument( - "--wait_for_interrupt", dest="wait_for_interrupt", action="store_true" - ) - args = parser.parse_args() + logging.basicConfig() + parser = argparse.ArgumentParser() + parser.add_argument("scenario", type=str) + parser.add_argument( + "--wait_for_interrupt", dest="wait_for_interrupt", action="store_true" + ) + args = parser.parse_args() - if args.scenario == UNSTARTED_SERVER: - server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),)) - if args.wait_for_interrupt: - time.sleep(WAIT_TIME) - elif args.scenario == RUNNING_SERVER: - server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),)) - port = server.add_insecure_port("[::]:0") - server.start() - if args.wait_for_interrupt: - time.sleep(WAIT_TIME) - elif args.scenario == POLL_CONNECTIVITY_NO_SERVER: - channel = grpc.insecure_channel("localhost:12345") + if args.scenario == UNSTARTED_SERVER: + server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),)) + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + elif args.scenario == RUNNING_SERVER: + server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),)) + port = server.add_insecure_port("[::]:0") + server.start() + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + elif args.scenario == POLL_CONNECTIVITY_NO_SERVER: + channel = grpc.insecure_channel("localhost:12345") - def connectivity_callback(connectivity): - pass + def connectivity_callback(connectivity): + pass - channel.subscribe(connectivity_callback, try_to_connect=True) - if args.wait_for_interrupt: - time.sleep(WAIT_TIME) - elif args.scenario == POLL_CONNECTIVITY: - server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),)) - port = server.add_insecure_port("[::]:0") - server.start() - channel = grpc.insecure_channel("localhost:%d" % port) + channel.subscribe(connectivity_callback, try_to_connect=True) + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) + elif args.scenario == POLL_CONNECTIVITY: + server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),)) + port = server.add_insecure_port("[::]:0") + server.start() + channel = grpc.insecure_channel("localhost:%d" % port) - def connectivity_callback(connectivity): - pass + def connectivity_callback(connectivity): + pass - channel.subscribe(connectivity_callback, try_to_connect=True) - if args.wait_for_interrupt: - time.sleep(WAIT_TIME) + channel.subscribe(connectivity_callback, try_to_connect=True) + if args.wait_for_interrupt: + time.sleep(WAIT_TIME) - else: - handler = GenericHandler() - server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),)) - port = server.add_insecure_port("[::]:0") - server.add_generic_rpc_handlers((handler,)) - server.start() - channel = grpc.insecure_channel("localhost:%d" % port) + else: + handler = GenericHandler() + server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),)) + port = server.add_insecure_port("[::]:0") + server.add_generic_rpc_handlers((handler,)) + server.start() + channel = grpc.insecure_channel("localhost:%d" % port) - method = TEST_TO_METHOD[args.scenario] + method = TEST_TO_METHOD[args.scenario] - if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL: - multi_callable = channel.unary_unary( - method, - _registered_method=True, - ) - future = multi_callable.future(REQUEST) - result, call = multi_callable.with_call(REQUEST) - elif ( - args.scenario == IN_FLIGHT_UNARY_STREAM_CALL - or args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL - ): - multi_callable = channel.unary_stream( - method, - _registered_method=True, - ) - response_iterator = multi_callable(REQUEST) - for response in response_iterator: - pass - elif ( - args.scenario == IN_FLIGHT_STREAM_UNARY_CALL - or args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL - ): - multi_callable = channel.stream_unary( - method, - _registered_method=True, - ) - future = multi_callable.future(infinite_request_iterator()) - result, call = multi_callable.with_call( - iter([REQUEST] * test_constants.STREAM_LENGTH) - ) - elif ( - args.scenario == IN_FLIGHT_STREAM_STREAM_CALL - or args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL - ): - multi_callable = channel.stream_stream( - method, - _registered_method=True, - ) - response_iterator = multi_callable(infinite_request_iterator()) - for response in response_iterator: - pass + if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL: + multi_callable = channel.unary_unary( + method, + _registered_method=True, + ) + future = multi_callable.future(REQUEST) + result, call = multi_callable.with_call(REQUEST) + elif ( + args.scenario == IN_FLIGHT_UNARY_STREAM_CALL + or args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL + ): + multi_callable = channel.unary_stream( + method, + _registered_method=True, + ) + response_iterator = multi_callable(REQUEST) + for response in response_iterator: + pass + elif ( + args.scenario == IN_FLIGHT_STREAM_UNARY_CALL + or args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL + ): + multi_callable = channel.stream_unary( + method, + _registered_method=True, + ) + future = multi_callable.future(infinite_request_iterator()) + result, call = multi_callable.with_call( + iter([REQUEST] * test_constants.STREAM_LENGTH) + ) + elif ( + args.scenario == IN_FLIGHT_STREAM_STREAM_CALL + or args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL + ): + multi_callable = channel.stream_stream( + method, + _registered_method=True, + ) + response_iterator = multi_callable(infinite_request_iterator()) + for response in response_iterator: + pass
diff --git a/test/cpp/end2end/rls_server.h b/test/cpp/end2end/rls_server.h index 576db11..a169c15 100644 --- a/test/cpp/end2end/rls_server.h +++ b/test/cpp/end2end/rls_server.h
@@ -18,6 +18,7 @@ #define GRPC_TEST_CPP_END2END_RLS_SERVER_H #include "absl/types/optional.h" + #include "src/core/lib/gprpp/time.h" #include "src/proto/grpc/lookup/v1/rls.grpc.pb.h" #include "src/proto/grpc/lookup/v1/rls.pb.h"
diff --git a/test/http2_test/http2_server_health_check.py b/test/http2_test/http2_server_health_check.py index 988212c..c91e4aa 100644 --- a/test/http2_test/http2_server_health_check.py +++ b/test/http2_test/http2_server_health_check.py
@@ -20,16 +20,16 @@ # Utility to healthcheck the http2 server. Used when starting the server to # verify that the server is live before tests begin. if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--server_host", type=str, default="localhost") - parser.add_argument("--server_port", type=int, default=8080) - args = parser.parse_args() - server_host = args.server_host - server_port = args.server_port - conn = hyper.HTTP20Connection("%s:%d" % (server_host, server_port)) - conn.request("POST", "/grpc.testing.TestService/UnaryCall") - resp = conn.get_response() - if resp.headers.get("grpc-encoding") is None: - sys.exit(1) - else: - sys.exit(0) + parser = argparse.ArgumentParser() + parser.add_argument("--server_host", type=str, default="localhost") + parser.add_argument("--server_port", type=int, default=8080) + args = parser.parse_args() + server_host = args.server_host + server_port = args.server_port + conn = hyper.HTTP20Connection("%s:%d" % (server_host, server_port)) + conn.request("POST", "/grpc.testing.TestService/UnaryCall") + resp = conn.get_response() + if resp.headers.get("grpc-encoding") is None: + sys.exit(1) + else: + sys.exit(0)
diff --git a/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py b/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py index 6a90657..1c3cca0 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py
@@ -48,475 +48,483 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp): - """Represents RPC services implemented in Client component of the xds test app. + """Represents RPC services implemented in Client component of the xds test app. - https://github.com/grpc/grpc/blob/master/doc/xds-test-descriptions.md#client - """ - - # A unique string identifying each client replica. Used in logging. - hostname: str - - def __init__( - self, - *, - ip: str, - rpc_port: int, - server_target: str, - hostname: str, - rpc_host: Optional[str] = None, - maintenance_port: Optional[int] = None, - ): - super().__init__(rpc_host=(rpc_host or ip)) - self.ip = ip - self.rpc_port = rpc_port - self.server_target = server_target - self.maintenance_port = maintenance_port or rpc_port - self.hostname = hostname - - @property - @functools.lru_cache(None) - def load_balancer_stats(self) -> _LoadBalancerStatsServiceClient: - return _LoadBalancerStatsServiceClient( - self._make_channel(self.rpc_port), - log_target=f"{self.hostname}:{self.rpc_port}", - ) - - @property - @functools.lru_cache(None) - def update_config(self): - return _XdsUpdateClientConfigureServiceClient( - self._make_channel(self.rpc_port), - log_target=f"{self.hostname}:{self.rpc_port}", - ) - - @property - @functools.lru_cache(None) - def channelz(self) -> _ChannelzServiceClient: - return _ChannelzServiceClient( - self._make_channel(self.maintenance_port), - log_target=f"{self.hostname}:{self.maintenance_port}", - ) - - @property - @functools.lru_cache(None) - def csds(self) -> _CsdsClient: - return _CsdsClient( - self._make_channel(self.maintenance_port), - log_target=f"{self.hostname}:{self.maintenance_port}", - ) - - def get_load_balancer_stats( - self, - *, - num_rpcs: int, - metadata_keys: Optional[tuple[str, ...]] = None, - timeout_sec: Optional[int] = None, - ) -> grpc_testing.LoadBalancerStatsResponse: - """Shortcut to LoadBalancerStatsServiceClient.get_client_stats()""" - return self.load_balancer_stats.get_client_stats( - num_rpcs=num_rpcs, - timeout_sec=timeout_sec, - metadata_keys=metadata_keys, - ) - - def get_load_balancer_accumulated_stats( - self, - *, - timeout_sec: Optional[int] = None, - ) -> grpc_testing.LoadBalancerAccumulatedStatsResponse: - """Shortcut to LoadBalancerStatsServiceClient.get_client_accumulated_stats()""" - return self.load_balancer_stats.get_client_accumulated_stats( - timeout_sec=timeout_sec - ) - - def wait_for_server_channel_ready( - self, - *, - timeout: Optional[_timedelta] = None, - rpc_deadline: Optional[_timedelta] = None, - ) -> _ChannelzChannel: - """Wait for the channel to the server to transition to READY. - - Raises: - GrpcApp.NotFound: If the channel never transitioned to READY. + https://github.com/grpc/grpc/blob/master/doc/xds-test-descriptions.md#client """ - try: - return self.wait_for_server_channel_state( - _ChannelzChannelState.READY, - timeout=timeout, - rpc_deadline=rpc_deadline, - ) - except retryers.RetryError as retry_err: - if isinstance(retry_err.exception(), self.ChannelNotFound): - retry_err.add_note( - framework.errors.FrameworkError.note_blanket_error( - "The client couldn't connect to the server." + + # A unique string identifying each client replica. Used in logging. + hostname: str + + def __init__( + self, + *, + ip: str, + rpc_port: int, + server_target: str, + hostname: str, + rpc_host: Optional[str] = None, + maintenance_port: Optional[int] = None, + ): + super().__init__(rpc_host=(rpc_host or ip)) + self.ip = ip + self.rpc_port = rpc_port + self.server_target = server_target + self.maintenance_port = maintenance_port or rpc_port + self.hostname = hostname + + @property + @functools.lru_cache(None) + def load_balancer_stats(self) -> _LoadBalancerStatsServiceClient: + return _LoadBalancerStatsServiceClient( + self._make_channel(self.rpc_port), + log_target=f"{self.hostname}:{self.rpc_port}", + ) + + @property + @functools.lru_cache(None) + def update_config(self): + return _XdsUpdateClientConfigureServiceClient( + self._make_channel(self.rpc_port), + log_target=f"{self.hostname}:{self.rpc_port}", + ) + + @property + @functools.lru_cache(None) + def channelz(self) -> _ChannelzServiceClient: + return _ChannelzServiceClient( + self._make_channel(self.maintenance_port), + log_target=f"{self.hostname}:{self.maintenance_port}", + ) + + @property + @functools.lru_cache(None) + def csds(self) -> _CsdsClient: + return _CsdsClient( + self._make_channel(self.maintenance_port), + log_target=f"{self.hostname}:{self.maintenance_port}", + ) + + def get_load_balancer_stats( + self, + *, + num_rpcs: int, + metadata_keys: Optional[tuple[str, ...]] = None, + timeout_sec: Optional[int] = None, + ) -> grpc_testing.LoadBalancerStatsResponse: + """Shortcut to LoadBalancerStatsServiceClient.get_client_stats()""" + return self.load_balancer_stats.get_client_stats( + num_rpcs=num_rpcs, + timeout_sec=timeout_sec, + metadata_keys=metadata_keys, + ) + + def get_load_balancer_accumulated_stats( + self, + *, + timeout_sec: Optional[int] = None, + ) -> grpc_testing.LoadBalancerAccumulatedStatsResponse: + """Shortcut to LoadBalancerStatsServiceClient.get_client_accumulated_stats()""" + return self.load_balancer_stats.get_client_accumulated_stats( + timeout_sec=timeout_sec + ) + + def wait_for_server_channel_ready( + self, + *, + timeout: Optional[_timedelta] = None, + rpc_deadline: Optional[_timedelta] = None, + ) -> _ChannelzChannel: + """Wait for the channel to the server to transition to READY. + + Raises: + GrpcApp.NotFound: If the channel never transitioned to READY. + """ + try: + return self.wait_for_server_channel_state( + _ChannelzChannelState.READY, + timeout=timeout, + rpc_deadline=rpc_deadline, ) - ) - raise + except retryers.RetryError as retry_err: + if isinstance(retry_err.exception(), self.ChannelNotFound): + retry_err.add_note( + framework.errors.FrameworkError.note_blanket_error( + "The client couldn't connect to the server." + ) + ) + raise - def wait_for_active_xds_channel( - self, - *, - xds_server_uri: Optional[str] = None, - timeout: Optional[_timedelta] = None, - rpc_deadline: Optional[_timedelta] = None, - ) -> _ChannelzChannel: - """Wait until the xds channel is active or timeout. + def wait_for_active_xds_channel( + self, + *, + xds_server_uri: Optional[str] = None, + timeout: Optional[_timedelta] = None, + rpc_deadline: Optional[_timedelta] = None, + ) -> _ChannelzChannel: + """Wait until the xds channel is active or timeout. - Raises: - GrpcApp.NotFound: If the channel to xds never transitioned to active. - """ - try: - return self.wait_for_xds_channel_active( - xds_server_uri=xds_server_uri, - timeout=timeout, - rpc_deadline=rpc_deadline, - ) - except retryers.RetryError as retry_err: - if isinstance(retry_err.exception(), self.ChannelNotFound): - retry_err.add_note( - framework.errors.FrameworkError.note_blanket_error( - "The client couldn't connect to the xDS control plane." + Raises: + GrpcApp.NotFound: If the channel to xds never transitioned to active. + """ + try: + return self.wait_for_xds_channel_active( + xds_server_uri=xds_server_uri, + timeout=timeout, + rpc_deadline=rpc_deadline, ) + except retryers.RetryError as retry_err: + if isinstance(retry_err.exception(), self.ChannelNotFound): + retry_err.add_note( + framework.errors.FrameworkError.note_blanket_error( + "The client couldn't connect to the xDS control plane." + ) + ) + raise + + def get_active_server_channel_socket(self) -> _ChannelzSocket: + channel = self.find_server_channel_with_state( + _ChannelzChannelState.READY ) - raise - - def get_active_server_channel_socket(self) -> _ChannelzSocket: - channel = self.find_server_channel_with_state(_ChannelzChannelState.READY) - # Get the first subchannel of the active channel to the server. - logger.debug( - ( - "[%s] Retrieving client -> server socket, " - "channel_id: %s, subchannel: %s" - ), - self.hostname, - channel.ref.channel_id, - channel.subchannel_ref[0].name, - ) - subchannel, *subchannels = list( - self.channelz.list_channel_subchannels(channel) - ) - if subchannels: - logger.warning( - "[%s] Unexpected subchannels: %r", self.hostname, subchannels - ) - # Get the first socket of the subchannel - socket, *sockets = list(self.channelz.list_subchannels_sockets(subchannel)) - if sockets: - logger.warning("[%s] Unexpected sockets: %r", self.hostname, subchannels) - logger.debug( - "[%s] Found client -> server socket: %s", - self.hostname, - socket.ref.name, - ) - return socket - - def wait_for_server_channel_state( - self, - state: _ChannelzChannelState, - *, - timeout: Optional[_timedelta] = None, - rpc_deadline: Optional[_timedelta] = None, - ) -> _ChannelzChannel: - # When polling for a state, prefer smaller wait times to avoid - # exhausting all allowed time on a single long RPC. - if rpc_deadline is None: - rpc_deadline = _timedelta(seconds=30) - - # Fine-tuned to wait for the channel to the server. - retryer = retryers.exponential_retryer_with_timeout( - wait_min=_timedelta(seconds=10), - wait_max=_timedelta(seconds=25), - timeout=_timedelta(minutes=5) if timeout is None else timeout, - ) - - logger.info( - "[%s] Waiting to report a %s channel to %s", - self.hostname, - _ChannelzChannelState.Name(state), - self.server_target, - ) - channel = retryer( - self.find_server_channel_with_state, - state, - rpc_deadline=rpc_deadline, - ) - logger.info( - "[%s] Channel to %s transitioned to state %s: %s", - self.hostname, - self.server_target, - _ChannelzChannelState.Name(state), - _ChannelzServiceClient.channel_repr(channel), - ) - return channel - - def wait_for_xds_channel_active( - self, - *, - xds_server_uri: Optional[str] = None, - timeout: Optional[_timedelta] = None, - rpc_deadline: Optional[_timedelta] = None, - ) -> _ChannelzChannel: - if not xds_server_uri: - xds_server_uri = DEFAULT_TD_XDS_URI - # When polling for a state, prefer smaller wait times to avoid - # exhausting all allowed time on a single long RPC. - if rpc_deadline is None: - rpc_deadline = _timedelta(seconds=30) - - retryer = retryers.exponential_retryer_with_timeout( - wait_min=_timedelta(seconds=10), - wait_max=_timedelta(seconds=25), - timeout=_timedelta(minutes=5) if timeout is None else timeout, - ) - - logger.info( - "[%s] ADS: Waiting for active calls to xDS control plane to %s", - self.hostname, - xds_server_uri, - ) - channel = retryer( - self.find_active_xds_channel, - xds_server_uri=xds_server_uri, - rpc_deadline=rpc_deadline, - ) - logger.info( - "[%s] ADS: Detected active calls to xDS control plane %s", - self.hostname, - xds_server_uri, - ) - return channel - - def find_active_xds_channel( - self, - xds_server_uri: str, - *, - rpc_deadline: Optional[_timedelta] = None, - ) -> _ChannelzChannel: - rpc_params = {} - if rpc_deadline is not None: - rpc_params["deadline_sec"] = rpc_deadline.total_seconds() - - for channel in self.find_channels(xds_server_uri, **rpc_params): - logger.info( - "[%s] xDS control plane channel: %s", - self.hostname, - _ChannelzServiceClient.channel_repr(channel), - ) - - try: - channel_upd = self.check_channel_in_flight_calls(channel, **rpc_params) - logger.info( - "[%s] Detected active calls to xDS control plane %s, channel: %s", + # Get the first subchannel of the active channel to the server. + logger.debug( + ( + "[%s] Retrieving client -> server socket, " + "channel_id: %s, subchannel: %s" + ), self.hostname, - xds_server_uri, - _ChannelzServiceClient.channel_repr(channel_upd), + channel.ref.channel_id, + channel.subchannel_ref[0].name, ) - return channel_upd - except self.NotFound: - # Continue checking other channels to the same target on - # not found. - continue - except framework.rpc.grpc.RpcError as err: - # Logged at 'info' and not at 'warning' because this method is - # expected to be called in a retryer. If this error eventually - # causes the retryer to fail, it will be logged fully at 'error' - logger.info( - "[%s] Unexpected error while checking xDS control plane" - " channel %s: %r", + subchannel, *subchannels = list( + self.channelz.list_channel_subchannels(channel) + ) + if subchannels: + logger.warning( + "[%s] Unexpected subchannels: %r", self.hostname, subchannels + ) + # Get the first socket of the subchannel + socket, *sockets = list( + self.channelz.list_subchannels_sockets(subchannel) + ) + if sockets: + logger.warning( + "[%s] Unexpected sockets: %r", self.hostname, subchannels + ) + logger.debug( + "[%s] Found client -> server socket: %s", self.hostname, + socket.ref.name, + ) + return socket + + def wait_for_server_channel_state( + self, + state: _ChannelzChannelState, + *, + timeout: Optional[_timedelta] = None, + rpc_deadline: Optional[_timedelta] = None, + ) -> _ChannelzChannel: + # When polling for a state, prefer smaller wait times to avoid + # exhausting all allowed time on a single long RPC. + if rpc_deadline is None: + rpc_deadline = _timedelta(seconds=30) + + # Fine-tuned to wait for the channel to the server. + retryer = retryers.exponential_retryer_with_timeout( + wait_min=_timedelta(seconds=10), + wait_max=_timedelta(seconds=25), + timeout=_timedelta(minutes=5) if timeout is None else timeout, + ) + + logger.info( + "[%s] Waiting to report a %s channel to %s", + self.hostname, + _ChannelzChannelState.Name(state), + self.server_target, + ) + channel = retryer( + self.find_server_channel_with_state, + state, + rpc_deadline=rpc_deadline, + ) + logger.info( + "[%s] Channel to %s transitioned to state %s: %s", + self.hostname, + self.server_target, + _ChannelzChannelState.Name(state), _ChannelzServiceClient.channel_repr(channel), - err, ) - raise - - raise self.ChannelNotActive( - f"[{self.hostname}] Client has no" - f" active channel with xDS control plane {xds_server_uri}", - src=self.hostname, - dst=xds_server_uri, - ) - - def find_server_channel_with_state( - self, - expected_state: _ChannelzChannelState, - *, - rpc_deadline: Optional[_timedelta] = None, - check_subchannel=True, - ) -> _ChannelzChannel: - rpc_params = {} - if rpc_deadline is not None: - rpc_params["deadline_sec"] = rpc_deadline.total_seconds() - - expected_state_name: str = _ChannelzChannelState.Name(expected_state) - target: str = self.server_target - - for channel in self.find_channels(target, **rpc_params): - channel_state: _ChannelzChannelState = channel.data.state.state - logger.info( - "[%s] Server channel: %s", - self.hostname, - _ChannelzServiceClient.channel_repr(channel), - ) - if channel_state is expected_state: - if check_subchannel: - # When requested, check if the channel has at least - # one subchannel in the requested state. - try: - subchannel = self.find_subchannel_with_state( - channel, expected_state, **rpc_params - ) - logger.info( - "[%s] Found subchannel in state %s: %s", - self.hostname, - expected_state_name, - _ChannelzServiceClient.subchannel_repr(subchannel), - ) - except self.NotFound as e: - # Otherwise, keep searching. - logger.info(e.message) - continue return channel - raise self.ChannelNotFound( - f"[{self.hostname}] Client has no" - f" {expected_state_name} channel with server {target}", - src=self.hostname, - dst=target, - expected_state=expected_state, - ) - - def find_channels( - self, - target: str, - **rpc_params, - ) -> Iterable[_ChannelzChannel]: - return self.channelz.find_channels_for_target(target, **rpc_params) - - def find_subchannel_with_state( - self, channel: _ChannelzChannel, state: _ChannelzChannelState, **kwargs - ) -> _ChannelzSubchannel: - subchannels = self.channelz.list_channel_subchannels(channel, **kwargs) - for subchannel in subchannels: - if subchannel.data.state.state is state: - return subchannel - - raise self.NotFound( - f"[{self.hostname}] Not found " - f"a {_ChannelzChannelState.Name(state)} subchannel " - f"for channel_id {channel.ref.channel_id}" - ) - - def find_subchannels_with_state( - self, state: _ChannelzChannelState, **kwargs - ) -> List[_ChannelzSubchannel]: - subchannels = [] - for channel in self.channelz.find_channels_for_target( - self.server_target, **kwargs - ): - for subchannel in self.channelz.list_channel_subchannels( - channel, **kwargs - ): - if subchannel.data.state.state is state: - subchannels.append(subchannel) - return subchannels - - def check_channel_in_flight_calls( - self, - channel: _ChannelzChannel, - *, - wait_between_checks: Optional[_timedelta] = None, - **rpc_params, - ) -> Optional[_ChannelzChannel]: - """Checks if the channel has calls that started, but didn't complete. - - We consider the channel is active if channel is in READY state and - calls_started is greater than calls_failed. - - This method address race where a call to the xDS control plane server - has just started and a channelz request comes in before the call has - had a chance to fail. - - With channels to the xDS control plane, the channel can be READY but the - calls could be failing to initialize, f.e. due to a failure to fetch - OAUTH2 token. To increase the confidence that we have a valid channel - with working OAUTH2 tokens, we check whether the channel is in a READY - state with active calls twice with an interval of 2 seconds between the - two attempts. If the OAUTH2 token is not valid, the call would fail and - be caught in either the first attempt, or the second attempt. It is - possible that between the two attempts, a call fails and a new call is - started, so we also test for equality between the started calls of the - two channelz results. - - There still exists a possibility that a call fails on fetching OAUTH2 - token after 2 seconds (maybe because there is a slowdown in the - system.) If such a case is observed, consider increasing the interval - from 2 seconds to 5 seconds. - - Returns updated channel on success, or None on failure. - """ - if not self.calc_calls_in_flight(channel): - return None - - if not wait_between_checks: - wait_between_checks = _timedelta(seconds=2) - - # Load the channel second time after the timeout. - time.sleep(wait_between_checks.total_seconds()) - channel_upd: _ChannelzChannel = self.channelz.get_channel( - channel.ref.channel_id, **rpc_params - ) - if ( - not self.calc_calls_in_flight(channel_upd) - or channel.data.calls_started != channel_upd.data.calls_started - ): - return None - return channel_upd - - @classmethod - def calc_calls_in_flight(cls, channel: _ChannelzChannel) -> int: - cdata: _ChannelzChannelData = channel.data - if cdata.state.state is not _ChannelzChannelState.READY: - return 0 - - return cdata.calls_started - cdata.calls_succeeded - cdata.calls_failed - - class ChannelNotFound(framework.rpc.grpc.GrpcApp.NotFound): - """Channel with expected status not found""" - - src: str - dst: str - expected_state: object - - def __init__( + def wait_for_xds_channel_active( self, - message: str, *, - src: str, - dst: str, + xds_server_uri: Optional[str] = None, + timeout: Optional[_timedelta] = None, + rpc_deadline: Optional[_timedelta] = None, + ) -> _ChannelzChannel: + if not xds_server_uri: + xds_server_uri = DEFAULT_TD_XDS_URI + # When polling for a state, prefer smaller wait times to avoid + # exhausting all allowed time on a single long RPC. + if rpc_deadline is None: + rpc_deadline = _timedelta(seconds=30) + + retryer = retryers.exponential_retryer_with_timeout( + wait_min=_timedelta(seconds=10), + wait_max=_timedelta(seconds=25), + timeout=_timedelta(minutes=5) if timeout is None else timeout, + ) + + logger.info( + "[%s] ADS: Waiting for active calls to xDS control plane to %s", + self.hostname, + xds_server_uri, + ) + channel = retryer( + self.find_active_xds_channel, + xds_server_uri=xds_server_uri, + rpc_deadline=rpc_deadline, + ) + logger.info( + "[%s] ADS: Detected active calls to xDS control plane %s", + self.hostname, + xds_server_uri, + ) + return channel + + def find_active_xds_channel( + self, + xds_server_uri: str, + *, + rpc_deadline: Optional[_timedelta] = None, + ) -> _ChannelzChannel: + rpc_params = {} + if rpc_deadline is not None: + rpc_params["deadline_sec"] = rpc_deadline.total_seconds() + + for channel in self.find_channels(xds_server_uri, **rpc_params): + logger.info( + "[%s] xDS control plane channel: %s", + self.hostname, + _ChannelzServiceClient.channel_repr(channel), + ) + + try: + channel_upd = self.check_channel_in_flight_calls( + channel, **rpc_params + ) + logger.info( + "[%s] Detected active calls to xDS control plane %s, channel: %s", + self.hostname, + xds_server_uri, + _ChannelzServiceClient.channel_repr(channel_upd), + ) + return channel_upd + except self.NotFound: + # Continue checking other channels to the same target on + # not found. + continue + except framework.rpc.grpc.RpcError as err: + # Logged at 'info' and not at 'warning' because this method is + # expected to be called in a retryer. If this error eventually + # causes the retryer to fail, it will be logged fully at 'error' + logger.info( + "[%s] Unexpected error while checking xDS control plane" + " channel %s: %r", + self.hostname, + _ChannelzServiceClient.channel_repr(channel), + err, + ) + raise + + raise self.ChannelNotActive( + f"[{self.hostname}] Client has no" + f" active channel with xDS control plane {xds_server_uri}", + src=self.hostname, + dst=xds_server_uri, + ) + + def find_server_channel_with_state( + self, expected_state: _ChannelzChannelState, - **kwargs, - ): - self.src = src - self.dst = dst - self.expected_state = expected_state - super().__init__(message, src, dst, expected_state, **kwargs) - - class ChannelNotActive(framework.rpc.grpc.GrpcApp.NotFound): - """No active channel was found""" - - src: str - dst: str - - def __init__( - self, - message: str, *, - src: str, - dst: str, - **kwargs, - ): - self.src = src - self.dst = dst - super().__init__(message, src, dst, **kwargs) + rpc_deadline: Optional[_timedelta] = None, + check_subchannel=True, + ) -> _ChannelzChannel: + rpc_params = {} + if rpc_deadline is not None: + rpc_params["deadline_sec"] = rpc_deadline.total_seconds() + + expected_state_name: str = _ChannelzChannelState.Name(expected_state) + target: str = self.server_target + + for channel in self.find_channels(target, **rpc_params): + channel_state: _ChannelzChannelState = channel.data.state.state + logger.info( + "[%s] Server channel: %s", + self.hostname, + _ChannelzServiceClient.channel_repr(channel), + ) + if channel_state is expected_state: + if check_subchannel: + # When requested, check if the channel has at least + # one subchannel in the requested state. + try: + subchannel = self.find_subchannel_with_state( + channel, expected_state, **rpc_params + ) + logger.info( + "[%s] Found subchannel in state %s: %s", + self.hostname, + expected_state_name, + _ChannelzServiceClient.subchannel_repr(subchannel), + ) + except self.NotFound as e: + # Otherwise, keep searching. + logger.info(e.message) + continue + return channel + + raise self.ChannelNotFound( + f"[{self.hostname}] Client has no" + f" {expected_state_name} channel with server {target}", + src=self.hostname, + dst=target, + expected_state=expected_state, + ) + + def find_channels( + self, + target: str, + **rpc_params, + ) -> Iterable[_ChannelzChannel]: + return self.channelz.find_channels_for_target(target, **rpc_params) + + def find_subchannel_with_state( + self, channel: _ChannelzChannel, state: _ChannelzChannelState, **kwargs + ) -> _ChannelzSubchannel: + subchannels = self.channelz.list_channel_subchannels(channel, **kwargs) + for subchannel in subchannels: + if subchannel.data.state.state is state: + return subchannel + + raise self.NotFound( + f"[{self.hostname}] Not found " + f"a {_ChannelzChannelState.Name(state)} subchannel " + f"for channel_id {channel.ref.channel_id}" + ) + + def find_subchannels_with_state( + self, state: _ChannelzChannelState, **kwargs + ) -> List[_ChannelzSubchannel]: + subchannels = [] + for channel in self.channelz.find_channels_for_target( + self.server_target, **kwargs + ): + for subchannel in self.channelz.list_channel_subchannels( + channel, **kwargs + ): + if subchannel.data.state.state is state: + subchannels.append(subchannel) + return subchannels + + def check_channel_in_flight_calls( + self, + channel: _ChannelzChannel, + *, + wait_between_checks: Optional[_timedelta] = None, + **rpc_params, + ) -> Optional[_ChannelzChannel]: + """Checks if the channel has calls that started, but didn't complete. + + We consider the channel is active if channel is in READY state and + calls_started is greater than calls_failed. + + This method address race where a call to the xDS control plane server + has just started and a channelz request comes in before the call has + had a chance to fail. + + With channels to the xDS control plane, the channel can be READY but the + calls could be failing to initialize, f.e. due to a failure to fetch + OAUTH2 token. To increase the confidence that we have a valid channel + with working OAUTH2 tokens, we check whether the channel is in a READY + state with active calls twice with an interval of 2 seconds between the + two attempts. If the OAUTH2 token is not valid, the call would fail and + be caught in either the first attempt, or the second attempt. It is + possible that between the two attempts, a call fails and a new call is + started, so we also test for equality between the started calls of the + two channelz results. + + There still exists a possibility that a call fails on fetching OAUTH2 + token after 2 seconds (maybe because there is a slowdown in the + system.) If such a case is observed, consider increasing the interval + from 2 seconds to 5 seconds. + + Returns updated channel on success, or None on failure. + """ + if not self.calc_calls_in_flight(channel): + return None + + if not wait_between_checks: + wait_between_checks = _timedelta(seconds=2) + + # Load the channel second time after the timeout. + time.sleep(wait_between_checks.total_seconds()) + channel_upd: _ChannelzChannel = self.channelz.get_channel( + channel.ref.channel_id, **rpc_params + ) + if ( + not self.calc_calls_in_flight(channel_upd) + or channel.data.calls_started != channel_upd.data.calls_started + ): + return None + return channel_upd + + @classmethod + def calc_calls_in_flight(cls, channel: _ChannelzChannel) -> int: + cdata: _ChannelzChannelData = channel.data + if cdata.state.state is not _ChannelzChannelState.READY: + return 0 + + return cdata.calls_started - cdata.calls_succeeded - cdata.calls_failed + + class ChannelNotFound(framework.rpc.grpc.GrpcApp.NotFound): + """Channel with expected status not found""" + + src: str + dst: str + expected_state: object + + def __init__( + self, + message: str, + *, + src: str, + dst: str, + expected_state: _ChannelzChannelState, + **kwargs, + ): + self.src = src + self.dst = dst + self.expected_state = expected_state + super().__init__(message, src, dst, expected_state, **kwargs) + + class ChannelNotActive(framework.rpc.grpc.GrpcApp.NotFound): + """No active channel was found""" + + src: str + dst: str + + def __init__( + self, + message: str, + *, + src: str, + dst: str, + **kwargs, + ): + self.src = src + self.dst = dst + super().__init__(message, src, dst, **kwargs)