Automated change: Fix sanity tests

Re-orders the C/C++ #include blocks flagged by the sanity checks so that
<grpc/support/port_platform.h> comes first, followed by C system headers,
third-party headers (absl), public gRPC headers, and finally project
headers, with a blank line between groups. Also re-indents and re-wraps
the touched Python test sources to match the repository's current Python
formatter settings.
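For reference, a minimal sketch of the include layout the C/C++ hunks below
converge on, using headers that appear in
src/core/lib/transport/promise_endpoint.h (the grouping is the point here;
the trailing comments are annotation only, not part of the change):

    #include <grpc/support/port_platform.h>      // platform macros: always first

    #include <stddef.h>                          // C system headers
    #include <stdint.h>

    #include "absl/status/status.h"              // third-party headers
    #include "absl/types/optional.h"

    #include <grpc/event_engine/event_engine.h>  // public gRPC C headers
    #include <grpc/support/log.h>

    #include "src/core/lib/gprpp/sync.h"         // project headers last
    #include "src/core/lib/promise/activity.h"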
diff --git a/src/core/ext/transport/chaotic_good/server_transport.h b/src/core/ext/transport/chaotic_good/server_transport.h
index b5d0cf9..9ce9292 100644
--- a/src/core/ext/transport/chaotic_good/server_transport.h
+++ b/src/core/ext/transport/chaotic_good/server_transport.h
@@ -15,11 +15,8 @@
#ifndef GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_SERVER_TRANSPORT_H
#define GRPC_SRC_CORE_EXT_TRANSPORT_CHAOTIC_GOOD_SERVER_TRANSPORT_H
-#include <grpc/event_engine/event_engine.h>
-#include <grpc/event_engine/memory_allocator.h>
-#include <grpc/slice.h>
-#include <grpc/support/log.h>
#include <grpc/support/port_platform.h>
+
#include <stdint.h>
#include <stdio.h>
@@ -42,6 +39,12 @@
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
+
+#include <grpc/event_engine/event_engine.h>
+#include <grpc/event_engine/memory_allocator.h>
+#include <grpc/slice.h>
+#include <grpc/support/log.h>
+
#include "src/core/ext/transport/chaotic_good/chaotic_good_transport.h"
#include "src/core/ext/transport/chaotic_good/frame.h"
#include "src/core/ext/transport/chaotic_good/frame_header.h"
diff --git a/src/core/lib/event_engine/trace.cc b/src/core/lib/event_engine/trace.cc
index 33b2fa7..20ae3ae 100644
--- a/src/core/lib/event_engine/trace.cc
+++ b/src/core/lib/event_engine/trace.cc
@@ -11,10 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "src/core/lib/debug/trace.h"
-
#include <grpc/support/port_platform.h>
+#include "src/core/lib/debug/trace.h"
+
grpc_core::TraceFlag grpc_event_engine_trace(false, "event_engine");
grpc_core::TraceFlag grpc_event_engine_dns_trace(false, "event_engine_dns");
grpc_core::TraceFlag grpc_event_engine_endpoint_trace(false,
diff --git a/src/core/lib/security/authorization/authorization_policy_provider_vtable.cc b/src/core/lib/security/authorization/authorization_policy_provider_vtable.cc
index ddf1a12..75498eb 100644
--- a/src/core/lib/security/authorization/authorization_policy_provider_vtable.cc
+++ b/src/core/lib/security/authorization/authorization_policy_provider_vtable.cc
@@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <grpc/support/port_platform.h>
+
#include <grpc/grpc.h>
#include <grpc/grpc_security.h>
-#include <grpc/support/port_platform.h>
#include "src/core/lib/gpr/useful.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
diff --git a/src/core/lib/security/certificate_provider/certificate_provider_registry.h b/src/core/lib/security/certificate_provider/certificate_provider_registry.h
index 6368845..5cc05ae 100644
--- a/src/core/lib/security/certificate_provider/certificate_provider_registry.h
+++ b/src/core/lib/security/certificate_provider/certificate_provider_registry.h
@@ -26,6 +26,7 @@
#include <utility>
#include "absl/strings/string_view.h"
+
#include "src/core/lib/security/certificate_provider/certificate_provider_factory.h"
namespace grpc_core {
diff --git a/src/core/lib/security/credentials/insecure/insecure_credentials.cc b/src/core/lib/security/credentials/insecure/insecure_credentials.cc
index e996072..cce66f6 100644
--- a/src/core/lib/security/credentials/insecure/insecure_credentials.cc
+++ b/src/core/lib/security/credentials/insecure/insecure_credentials.cc
@@ -16,10 +16,10 @@
//
//
-#include "src/core/lib/security/credentials/insecure/insecure_credentials.h"
-
#include <grpc/support/port_platform.h>
+#include "src/core/lib/security/credentials/insecure/insecure_credentials.h"
+
#include <utility>
#include "src/core/lib/channel/channel_args.h"
diff --git a/src/core/lib/transport/promise_endpoint.h b/src/core/lib/transport/promise_endpoint.h
index 36c3713..fbdc467 100644
--- a/src/core/lib/transport/promise_endpoint.h
+++ b/src/core/lib/transport/promise_endpoint.h
@@ -15,12 +15,8 @@
#ifndef GRPC_SRC_CORE_LIB_TRANSPORT_PROMISE_ENDPOINT_H
#define GRPC_SRC_CORE_LIB_TRANSPORT_PROMISE_ENDPOINT_H
-#include <grpc/event_engine/event_engine.h>
-#include <grpc/event_engine/slice.h>
-#include <grpc/event_engine/slice_buffer.h>
-#include <grpc/slice_buffer.h>
-#include <grpc/support/log.h>
#include <grpc/support/port_platform.h>
+
#include <stddef.h>
#include <stdint.h>
@@ -33,6 +29,13 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
+
+#include <grpc/event_engine/event_engine.h>
+#include <grpc/event_engine/slice.h>
+#include <grpc/event_engine/slice_buffer.h>
+#include <grpc/slice_buffer.h>
+#include <grpc/support/log.h>
+
#include "src/core/lib/gprpp/sync.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/if.h"
diff --git a/src/core/tsi/transport_security_grpc.cc b/src/core/tsi/transport_security_grpc.cc
index 4d120d9..3f04ce9 100644
--- a/src/core/tsi/transport_security_grpc.cc
+++ b/src/core/tsi/transport_security_grpc.cc
@@ -16,10 +16,10 @@
//
//
-#include "src/core/tsi/transport_security_grpc.h"
-
#include <grpc/support/port_platform.h>
+#include "src/core/tsi/transport_security_grpc.h"
+
// This method creates a tsi_zero_copy_grpc_protector object.
tsi_result tsi_handshaker_result_create_zero_copy_grpc_protector(
const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
diff --git a/src/python/grpcio_tests/tests/qps/benchmark_client.py b/src/python/grpcio_tests/tests/qps/benchmark_client.py
index 1a44fc7..9310f8a 100644
--- a/src/python/grpcio_tests/tests/qps/benchmark_client.py
+++ b/src/python/grpcio_tests/tests/qps/benchmark_client.py
@@ -20,6 +20,7 @@
import time
import grpc
+
from src.proto.grpc.testing import benchmark_service_pb2_grpc
from src.proto.grpc.testing import messages_pb2
from tests.unit import resources
@@ -29,229 +30,229 @@
class GenericStub(object):
-
- def __init__(self, channel):
- self.UnaryCall = channel.unary_unary(
- "/grpc.testing.BenchmarkService/UnaryCall",
- _registered_method=True,
- )
- self.StreamingFromServer = channel.unary_stream(
- "/grpc.testing.BenchmarkService/StreamingFromServer",
- _registered_method=True,
- )
- self.StreamingCall = channel.stream_stream(
- "/grpc.testing.BenchmarkService/StreamingCall",
- _registered_method=True,
- )
+ def __init__(self, channel):
+ self.UnaryCall = channel.unary_unary(
+ "/grpc.testing.BenchmarkService/UnaryCall",
+ _registered_method=True,
+ )
+ self.StreamingFromServer = channel.unary_stream(
+ "/grpc.testing.BenchmarkService/StreamingFromServer",
+ _registered_method=True,
+ )
+ self.StreamingCall = channel.stream_stream(
+ "/grpc.testing.BenchmarkService/StreamingCall",
+ _registered_method=True,
+ )
class BenchmarkClient:
- """Benchmark client interface that exposes a non-blocking send_request()."""
+ """Benchmark client interface that exposes a non-blocking send_request()."""
- __metaclass__ = abc.ABCMeta
+ __metaclass__ = abc.ABCMeta
- def __init__(self, server, config, hist):
- # Create the stub
- if config.HasField("security_params"):
- creds = grpc.ssl_channel_credentials(resources.test_root_certificates())
- channel = test_common.test_secure_channel(
- server, creds, config.security_params.server_host_override
- )
- else:
- channel = grpc.insecure_channel(server)
+ def __init__(self, server, config, hist):
+ # Create the stub
+ if config.HasField("security_params"):
+ creds = grpc.ssl_channel_credentials(
+ resources.test_root_certificates()
+ )
+ channel = test_common.test_secure_channel(
+ server, creds, config.security_params.server_host_override
+ )
+ else:
+ channel = grpc.insecure_channel(server)
- # waits for the channel to be ready before we start sending messages
- grpc.channel_ready_future(channel).result()
+ # waits for the channel to be ready before we start sending messages
+ grpc.channel_ready_future(channel).result()
- if config.payload_config.WhichOneof("payload") == "simple_params":
- self._generic = False
- self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(channel)
- payload = messages_pb2.Payload(
- body=bytes(b"\0" * config.payload_config.simple_params.req_size)
- )
- self._request = messages_pb2.SimpleRequest(
- payload=payload,
- response_size=config.payload_config.simple_params.resp_size,
- )
- else:
- self._generic = True
- self._stub = GenericStub(channel)
- self._request = bytes(
- b"\0" * config.payload_config.bytebuf_params.req_size
- )
+ if config.payload_config.WhichOneof("payload") == "simple_params":
+ self._generic = False
+ self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub(
+ channel
+ )
+ payload = messages_pb2.Payload(
+ body=bytes(b"\0" * config.payload_config.simple_params.req_size)
+ )
+ self._request = messages_pb2.SimpleRequest(
+ payload=payload,
+ response_size=config.payload_config.simple_params.resp_size,
+ )
+ else:
+ self._generic = True
+ self._stub = GenericStub(channel)
+ self._request = bytes(
+ b"\0" * config.payload_config.bytebuf_params.req_size
+ )
- self._hist = hist
- self._response_callbacks = []
+ self._hist = hist
+ self._response_callbacks = []
- def add_response_callback(self, callback):
- """callback will be invoked as callback(client, query_time)"""
- self._response_callbacks.append(callback)
+ def add_response_callback(self, callback):
+ """callback will be invoked as callback(client, query_time)"""
+ self._response_callbacks.append(callback)
- @abc.abstractmethod
- def send_request(self):
- """Non-blocking wrapper for a client's request operation."""
- raise NotImplementedError()
+ @abc.abstractmethod
+ def send_request(self):
+ """Non-blocking wrapper for a client's request operation."""
+ raise NotImplementedError()
- def start(self):
- pass
+ def start(self):
+ pass
- def stop(self):
- pass
+ def stop(self):
+ pass
- def _handle_response(self, client, query_time):
- self._hist.add(query_time * 1e9) # Report times in nanoseconds
- for callback in self._response_callbacks:
- callback(client, query_time)
+ def _handle_response(self, client, query_time):
+ self._hist.add(query_time * 1e9) # Report times in nanoseconds
+ for callback in self._response_callbacks:
+ callback(client, query_time)
class UnarySyncBenchmarkClient(BenchmarkClient):
+ def __init__(self, server, config, hist):
+ super(UnarySyncBenchmarkClient, self).__init__(server, config, hist)
+ self._pool = futures.ThreadPoolExecutor(
+ max_workers=config.outstanding_rpcs_per_channel
+ )
- def __init__(self, server, config, hist):
- super(UnarySyncBenchmarkClient, self).__init__(server, config, hist)
- self._pool = futures.ThreadPoolExecutor(
- max_workers=config.outstanding_rpcs_per_channel
- )
+ def send_request(self):
+ # Send requests in separate threads to support multiple outstanding rpcs
+ # (See src/proto/grpc/testing/control.proto)
+ self._pool.submit(self._dispatch_request)
- def send_request(self):
- # Send requests in separate threads to support multiple outstanding rpcs
- # (See src/proto/grpc/testing/control.proto)
- self._pool.submit(self._dispatch_request)
+ def stop(self):
+ self._pool.shutdown(wait=True)
+ self._stub = None
- def stop(self):
- self._pool.shutdown(wait=True)
- self._stub = None
-
- def _dispatch_request(self):
- start_time = time.time()
- self._stub.UnaryCall(self._request, _TIMEOUT)
- end_time = time.time()
- self._handle_response(self, end_time - start_time)
+ def _dispatch_request(self):
+ start_time = time.time()
+ self._stub.UnaryCall(self._request, _TIMEOUT)
+ end_time = time.time()
+ self._handle_response(self, end_time - start_time)
class UnaryAsyncBenchmarkClient(BenchmarkClient):
+ def send_request(self):
+ # Use the Future callback api to support multiple outstanding rpcs
+ start_time = time.time()
+ response_future = self._stub.UnaryCall.future(self._request, _TIMEOUT)
+ response_future.add_done_callback(
+ lambda resp: self._response_received(start_time, resp)
+ )
- def send_request(self):
- # Use the Future callback api to support multiple outstanding rpcs
- start_time = time.time()
- response_future = self._stub.UnaryCall.future(self._request, _TIMEOUT)
- response_future.add_done_callback(
- lambda resp: self._response_received(start_time, resp)
- )
+ def _response_received(self, start_time, resp):
+ resp.result()
+ end_time = time.time()
+ self._handle_response(self, end_time - start_time)
- def _response_received(self, start_time, resp):
- resp.result()
- end_time = time.time()
- self._handle_response(self, end_time - start_time)
-
- def stop(self):
- self._stub = None
+ def stop(self):
+ self._stub = None
class _SyncStream(object):
+ def __init__(self, stub, generic, request, handle_response):
+ self._stub = stub
+ self._generic = generic
+ self._request = request
+ self._handle_response = handle_response
+ self._is_streaming = False
+ self._request_queue = queue.Queue()
+ self._send_time_queue = queue.Queue()
- def __init__(self, stub, generic, request, handle_response):
- self._stub = stub
- self._generic = generic
- self._request = request
- self._handle_response = handle_response
- self._is_streaming = False
- self._request_queue = queue.Queue()
- self._send_time_queue = queue.Queue()
+ def send_request(self):
+ self._send_time_queue.put(time.time())
+ self._request_queue.put(self._request)
- def send_request(self):
- self._send_time_queue.put(time.time())
- self._request_queue.put(self._request)
+ def start(self):
+ self._is_streaming = True
+ response_stream = self._stub.StreamingCall(
+ self._request_generator(), _TIMEOUT
+ )
+ for _ in response_stream:
+ self._handle_response(
+ self, time.time() - self._send_time_queue.get_nowait()
+ )
- def start(self):
- self._is_streaming = True
- response_stream = self._stub.StreamingCall(
- self._request_generator(), _TIMEOUT
- )
- for _ in response_stream:
- self._handle_response(
- self, time.time() - self._send_time_queue.get_nowait()
- )
+ def stop(self):
+ self._is_streaming = False
- def stop(self):
- self._is_streaming = False
-
- def _request_generator(self):
- while self._is_streaming:
- try:
- request = self._request_queue.get(block=True, timeout=1.0)
- yield request
- except queue.Empty:
- pass
+ def _request_generator(self):
+ while self._is_streaming:
+ try:
+ request = self._request_queue.get(block=True, timeout=1.0)
+ yield request
+ except queue.Empty:
+ pass
class StreamingSyncBenchmarkClient(BenchmarkClient):
-
- def __init__(self, server, config, hist):
- super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist)
- self._pool = futures.ThreadPoolExecutor(
- max_workers=config.outstanding_rpcs_per_channel
- )
- self._streams = [
- _SyncStream(
- self._stub, self._generic, self._request, self._handle_response
+ def __init__(self, server, config, hist):
+ super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist)
+ self._pool = futures.ThreadPoolExecutor(
+ max_workers=config.outstanding_rpcs_per_channel
)
- for _ in range(config.outstanding_rpcs_per_channel)
- ]
- self._curr_stream = 0
+ self._streams = [
+ _SyncStream(
+ self._stub, self._generic, self._request, self._handle_response
+ )
+ for _ in range(config.outstanding_rpcs_per_channel)
+ ]
+ self._curr_stream = 0
- def send_request(self):
- # Use a round-robin scheduler to determine which stream to send on
- self._streams[self._curr_stream].send_request()
- self._curr_stream = (self._curr_stream + 1) % len(self._streams)
+ def send_request(self):
+ # Use a round-robin scheduler to determine which stream to send on
+ self._streams[self._curr_stream].send_request()
+ self._curr_stream = (self._curr_stream + 1) % len(self._streams)
- def start(self):
- for stream in self._streams:
- self._pool.submit(stream.start)
+ def start(self):
+ for stream in self._streams:
+ self._pool.submit(stream.start)
- def stop(self):
- for stream in self._streams:
- stream.stop()
- self._pool.shutdown(wait=True)
- self._stub = None
+ def stop(self):
+ for stream in self._streams:
+ stream.stop()
+ self._pool.shutdown(wait=True)
+ self._stub = None
class ServerStreamingSyncBenchmarkClient(BenchmarkClient):
+ def __init__(self, server, config, hist):
+ super(ServerStreamingSyncBenchmarkClient, self).__init__(
+ server, config, hist
+ )
+ if config.outstanding_rpcs_per_channel == 1:
+ self._pool = None
+ else:
+ self._pool = futures.ThreadPoolExecutor(
+ max_workers=config.outstanding_rpcs_per_channel
+ )
+ self._rpcs = []
+ self._sender = None
- def __init__(self, server, config, hist):
- super(ServerStreamingSyncBenchmarkClient, self).__init__(
- server, config, hist
- )
- if config.outstanding_rpcs_per_channel == 1:
- self._pool = None
- else:
- self._pool = futures.ThreadPoolExecutor(
- max_workers=config.outstanding_rpcs_per_channel
- )
- self._rpcs = []
- self._sender = None
+ def send_request(self):
+ if self._pool is None:
+ self._sender = threading.Thread(
+ target=self._one_stream_streaming_rpc, daemon=True
+ )
+ self._sender.start()
+ else:
+ self._pool.submit(self._one_stream_streaming_rpc)
- def send_request(self):
- if self._pool is None:
- self._sender = threading.Thread(
- target=self._one_stream_streaming_rpc, daemon=True
- )
- self._sender.start()
- else:
- self._pool.submit(self._one_stream_streaming_rpc)
+ def _one_stream_streaming_rpc(self):
+ response_stream = self._stub.StreamingFromServer(
+ self._request, _TIMEOUT
+ )
+ self._rpcs.append(response_stream)
+ start_time = time.time()
+ for _ in response_stream:
+ self._handle_response(self, time.time() - start_time)
+ start_time = time.time()
- def _one_stream_streaming_rpc(self):
- response_stream = self._stub.StreamingFromServer(self._request, _TIMEOUT)
- self._rpcs.append(response_stream)
- start_time = time.time()
- for _ in response_stream:
- self._handle_response(self, time.time() - start_time)
- start_time = time.time()
-
- def stop(self):
- for call in self._rpcs:
- call.cancel()
- if self._sender is not None:
- self._sender.join()
- if self._pool is not None:
- self._pool.shutdown(wait=False)
- self._stub = None
+ def stop(self):
+ for call in self._rpcs:
+ call.cancel()
+ if self._sender is not None:
+ self._sender.join()
+ if self._pool is not None:
+ self._pool.shutdown(wait=False)
+ self._stub = None
diff --git a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py
index 8882155..1b7e2e5 100644
--- a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py
+++ b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py
@@ -19,6 +19,7 @@
import time
import grpc
+
from tests.unit.framework.common import test_constants
WAIT_TIME = 1000
@@ -57,198 +58,195 @@
def hang_unary_unary(request, servicer_context):
- time.sleep(WAIT_TIME)
+ time.sleep(WAIT_TIME)
def hang_unary_stream(request, servicer_context):
- time.sleep(WAIT_TIME)
+ time.sleep(WAIT_TIME)
def hang_partial_unary_stream(request, servicer_context):
- for _ in range(test_constants.STREAM_LENGTH // 2):
- yield request
- time.sleep(WAIT_TIME)
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ yield request
+ time.sleep(WAIT_TIME)
def hang_stream_unary(request_iterator, servicer_context):
- time.sleep(WAIT_TIME)
+ time.sleep(WAIT_TIME)
def hang_partial_stream_unary(request_iterator, servicer_context):
- for _ in range(test_constants.STREAM_LENGTH // 2):
- next(request_iterator)
- time.sleep(WAIT_TIME)
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ next(request_iterator)
+ time.sleep(WAIT_TIME)
def hang_stream_stream(request_iterator, servicer_context):
- time.sleep(WAIT_TIME)
+ time.sleep(WAIT_TIME)
def hang_partial_stream_stream(request_iterator, servicer_context):
- for _ in range(test_constants.STREAM_LENGTH // 2):
- yield next(request_iterator) # pylint: disable=stop-iteration-return
- time.sleep(WAIT_TIME)
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ yield next(request_iterator) # pylint: disable=stop-iteration-return
+ time.sleep(WAIT_TIME)
class MethodHandler(grpc.RpcMethodHandler):
-
- def __init__(self, request_streaming, response_streaming, partial_hang):
- self.request_streaming = request_streaming
- self.response_streaming = response_streaming
- self.request_deserializer = None
- self.response_serializer = None
- self.unary_unary = None
- self.unary_stream = None
- self.stream_unary = None
- self.stream_stream = None
- if self.request_streaming and self.response_streaming:
- if partial_hang:
- self.stream_stream = hang_partial_stream_stream
- else:
- self.stream_stream = hang_stream_stream
- elif self.request_streaming:
- if partial_hang:
- self.stream_unary = hang_partial_stream_unary
- else:
- self.stream_unary = hang_stream_unary
- elif self.response_streaming:
- if partial_hang:
- self.unary_stream = hang_partial_unary_stream
- else:
- self.unary_stream = hang_unary_stream
- else:
- self.unary_unary = hang_unary_unary
+ def __init__(self, request_streaming, response_streaming, partial_hang):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = None
+ self.response_serializer = None
+ self.unary_unary = None
+ self.unary_stream = None
+ self.stream_unary = None
+ self.stream_stream = None
+ if self.request_streaming and self.response_streaming:
+ if partial_hang:
+ self.stream_stream = hang_partial_stream_stream
+ else:
+ self.stream_stream = hang_stream_stream
+ elif self.request_streaming:
+ if partial_hang:
+ self.stream_unary = hang_partial_stream_unary
+ else:
+ self.stream_unary = hang_stream_unary
+ elif self.response_streaming:
+ if partial_hang:
+ self.unary_stream = hang_partial_unary_stream
+ else:
+ self.unary_stream = hang_unary_stream
+ else:
+ self.unary_unary = hang_unary_unary
class GenericHandler(grpc.GenericRpcHandler):
-
- def service(self, handler_call_details):
- if handler_call_details.method == UNARY_UNARY:
- return MethodHandler(False, False, False)
- elif handler_call_details.method == UNARY_STREAM:
- return MethodHandler(False, True, False)
- elif handler_call_details.method == STREAM_UNARY:
- return MethodHandler(True, False, False)
- elif handler_call_details.method == STREAM_STREAM:
- return MethodHandler(True, True, False)
- elif handler_call_details.method == PARTIAL_UNARY_STREAM:
- return MethodHandler(False, True, True)
- elif handler_call_details.method == PARTIAL_STREAM_UNARY:
- return MethodHandler(True, False, True)
- elif handler_call_details.method == PARTIAL_STREAM_STREAM:
- return MethodHandler(True, True, True)
- else:
- return None
+ def service(self, handler_call_details):
+ if handler_call_details.method == UNARY_UNARY:
+ return MethodHandler(False, False, False)
+ elif handler_call_details.method == UNARY_STREAM:
+ return MethodHandler(False, True, False)
+ elif handler_call_details.method == STREAM_UNARY:
+ return MethodHandler(True, False, False)
+ elif handler_call_details.method == STREAM_STREAM:
+ return MethodHandler(True, True, False)
+ elif handler_call_details.method == PARTIAL_UNARY_STREAM:
+ return MethodHandler(False, True, True)
+ elif handler_call_details.method == PARTIAL_STREAM_UNARY:
+ return MethodHandler(True, False, True)
+ elif handler_call_details.method == PARTIAL_STREAM_STREAM:
+ return MethodHandler(True, True, True)
+ else:
+ return None
# Traditional executors will not exit until all their
# current jobs complete. Because we submit jobs that will
# never finish, we don't want to block exit on these jobs.
class DaemonPool(object):
+ def submit(self, fn, *args, **kwargs):
+ thread = threading.Thread(target=fn, args=args, kwargs=kwargs)
+ thread.daemon = True
+ thread.start()
- def submit(self, fn, *args, **kwargs):
- thread = threading.Thread(target=fn, args=args, kwargs=kwargs)
- thread.daemon = True
- thread.start()
-
- def shutdown(self, wait=True):
- pass
+ def shutdown(self, wait=True):
+ pass
def infinite_request_iterator():
- while True:
- yield REQUEST
+ while True:
+ yield REQUEST
if __name__ == "__main__":
- logging.basicConfig()
- parser = argparse.ArgumentParser()
- parser.add_argument("scenario", type=str)
- parser.add_argument(
- "--wait_for_interrupt", dest="wait_for_interrupt", action="store_true"
- )
- args = parser.parse_args()
+ logging.basicConfig()
+ parser = argparse.ArgumentParser()
+ parser.add_argument("scenario", type=str)
+ parser.add_argument(
+ "--wait_for_interrupt", dest="wait_for_interrupt", action="store_true"
+ )
+ args = parser.parse_args()
- if args.scenario == UNSTARTED_SERVER:
- server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
- if args.wait_for_interrupt:
- time.sleep(WAIT_TIME)
- elif args.scenario == RUNNING_SERVER:
- server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
- port = server.add_insecure_port("[::]:0")
- server.start()
- if args.wait_for_interrupt:
- time.sleep(WAIT_TIME)
- elif args.scenario == POLL_CONNECTIVITY_NO_SERVER:
- channel = grpc.insecure_channel("localhost:12345")
+ if args.scenario == UNSTARTED_SERVER:
+ server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
+ if args.wait_for_interrupt:
+ time.sleep(WAIT_TIME)
+ elif args.scenario == RUNNING_SERVER:
+ server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
+ port = server.add_insecure_port("[::]:0")
+ server.start()
+ if args.wait_for_interrupt:
+ time.sleep(WAIT_TIME)
+ elif args.scenario == POLL_CONNECTIVITY_NO_SERVER:
+ channel = grpc.insecure_channel("localhost:12345")
- def connectivity_callback(connectivity):
- pass
+ def connectivity_callback(connectivity):
+ pass
- channel.subscribe(connectivity_callback, try_to_connect=True)
- if args.wait_for_interrupt:
- time.sleep(WAIT_TIME)
- elif args.scenario == POLL_CONNECTIVITY:
- server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
- port = server.add_insecure_port("[::]:0")
- server.start()
- channel = grpc.insecure_channel("localhost:%d" % port)
+ channel.subscribe(connectivity_callback, try_to_connect=True)
+ if args.wait_for_interrupt:
+ time.sleep(WAIT_TIME)
+ elif args.scenario == POLL_CONNECTIVITY:
+ server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
+ port = server.add_insecure_port("[::]:0")
+ server.start()
+ channel = grpc.insecure_channel("localhost:%d" % port)
- def connectivity_callback(connectivity):
- pass
+ def connectivity_callback(connectivity):
+ pass
- channel.subscribe(connectivity_callback, try_to_connect=True)
- if args.wait_for_interrupt:
- time.sleep(WAIT_TIME)
+ channel.subscribe(connectivity_callback, try_to_connect=True)
+ if args.wait_for_interrupt:
+ time.sleep(WAIT_TIME)
- else:
- handler = GenericHandler()
- server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
- port = server.add_insecure_port("[::]:0")
- server.add_generic_rpc_handlers((handler,))
- server.start()
- channel = grpc.insecure_channel("localhost:%d" % port)
+ else:
+ handler = GenericHandler()
+ server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),))
+ port = server.add_insecure_port("[::]:0")
+ server.add_generic_rpc_handlers((handler,))
+ server.start()
+ channel = grpc.insecure_channel("localhost:%d" % port)
- method = TEST_TO_METHOD[args.scenario]
+ method = TEST_TO_METHOD[args.scenario]
- if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL:
- multi_callable = channel.unary_unary(
- method,
- _registered_method=True,
- )
- future = multi_callable.future(REQUEST)
- result, call = multi_callable.with_call(REQUEST)
- elif (
- args.scenario == IN_FLIGHT_UNARY_STREAM_CALL
- or args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL
- ):
- multi_callable = channel.unary_stream(
- method,
- _registered_method=True,
- )
- response_iterator = multi_callable(REQUEST)
- for response in response_iterator:
- pass
- elif (
- args.scenario == IN_FLIGHT_STREAM_UNARY_CALL
- or args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL
- ):
- multi_callable = channel.stream_unary(
- method,
- _registered_method=True,
- )
- future = multi_callable.future(infinite_request_iterator())
- result, call = multi_callable.with_call(
- iter([REQUEST] * test_constants.STREAM_LENGTH)
- )
- elif (
- args.scenario == IN_FLIGHT_STREAM_STREAM_CALL
- or args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL
- ):
- multi_callable = channel.stream_stream(
- method,
- _registered_method=True,
- )
- response_iterator = multi_callable(infinite_request_iterator())
- for response in response_iterator:
- pass
+ if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL:
+ multi_callable = channel.unary_unary(
+ method,
+ _registered_method=True,
+ )
+ future = multi_callable.future(REQUEST)
+ result, call = multi_callable.with_call(REQUEST)
+ elif (
+ args.scenario == IN_FLIGHT_UNARY_STREAM_CALL
+ or args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL
+ ):
+ multi_callable = channel.unary_stream(
+ method,
+ _registered_method=True,
+ )
+ response_iterator = multi_callable(REQUEST)
+ for response in response_iterator:
+ pass
+ elif (
+ args.scenario == IN_FLIGHT_STREAM_UNARY_CALL
+ or args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL
+ ):
+ multi_callable = channel.stream_unary(
+ method,
+ _registered_method=True,
+ )
+ future = multi_callable.future(infinite_request_iterator())
+ result, call = multi_callable.with_call(
+ iter([REQUEST] * test_constants.STREAM_LENGTH)
+ )
+ elif (
+ args.scenario == IN_FLIGHT_STREAM_STREAM_CALL
+ or args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL
+ ):
+ multi_callable = channel.stream_stream(
+ method,
+ _registered_method=True,
+ )
+ response_iterator = multi_callable(infinite_request_iterator())
+ for response in response_iterator:
+ pass
diff --git a/test/cpp/end2end/rls_server.h b/test/cpp/end2end/rls_server.h
index 576db11..a169c15 100644
--- a/test/cpp/end2end/rls_server.h
+++ b/test/cpp/end2end/rls_server.h
@@ -18,6 +18,7 @@
#define GRPC_TEST_CPP_END2END_RLS_SERVER_H
#include "absl/types/optional.h"
+
#include "src/core/lib/gprpp/time.h"
#include "src/proto/grpc/lookup/v1/rls.grpc.pb.h"
#include "src/proto/grpc/lookup/v1/rls.pb.h"
diff --git a/test/http2_test/http2_server_health_check.py b/test/http2_test/http2_server_health_check.py
index 988212c..c91e4aa 100644
--- a/test/http2_test/http2_server_health_check.py
+++ b/test/http2_test/http2_server_health_check.py
@@ -20,16 +20,16 @@
# Utility to healthcheck the http2 server. Used when starting the server to
# verify that the server is live before tests begin.
if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--server_host", type=str, default="localhost")
- parser.add_argument("--server_port", type=int, default=8080)
- args = parser.parse_args()
- server_host = args.server_host
- server_port = args.server_port
- conn = hyper.HTTP20Connection("%s:%d" % (server_host, server_port))
- conn.request("POST", "/grpc.testing.TestService/UnaryCall")
- resp = conn.get_response()
- if resp.headers.get("grpc-encoding") is None:
- sys.exit(1)
- else:
- sys.exit(0)
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--server_host", type=str, default="localhost")
+ parser.add_argument("--server_port", type=int, default=8080)
+ args = parser.parse_args()
+ server_host = args.server_host
+ server_port = args.server_port
+ conn = hyper.HTTP20Connection("%s:%d" % (server_host, server_port))
+ conn.request("POST", "/grpc.testing.TestService/UnaryCall")
+ resp = conn.get_response()
+ if resp.headers.get("grpc-encoding") is None:
+ sys.exit(1)
+ else:
+ sys.exit(0)
diff --git a/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py b/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py
index 6a90657..1c3cca0 100644
--- a/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py
+++ b/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py
@@ -48,475 +48,483 @@
class XdsTestClient(framework.rpc.grpc.GrpcApp):
- """Represents RPC services implemented in Client component of the xds test app.
+ """Represents RPC services implemented in Client component of the xds test app.
- https://github.com/grpc/grpc/blob/master/doc/xds-test-descriptions.md#client
- """
-
- # A unique string identifying each client replica. Used in logging.
- hostname: str
-
- def __init__(
- self,
- *,
- ip: str,
- rpc_port: int,
- server_target: str,
- hostname: str,
- rpc_host: Optional[str] = None,
- maintenance_port: Optional[int] = None,
- ):
- super().__init__(rpc_host=(rpc_host or ip))
- self.ip = ip
- self.rpc_port = rpc_port
- self.server_target = server_target
- self.maintenance_port = maintenance_port or rpc_port
- self.hostname = hostname
-
- @property
- @functools.lru_cache(None)
- def load_balancer_stats(self) -> _LoadBalancerStatsServiceClient:
- return _LoadBalancerStatsServiceClient(
- self._make_channel(self.rpc_port),
- log_target=f"{self.hostname}:{self.rpc_port}",
- )
-
- @property
- @functools.lru_cache(None)
- def update_config(self):
- return _XdsUpdateClientConfigureServiceClient(
- self._make_channel(self.rpc_port),
- log_target=f"{self.hostname}:{self.rpc_port}",
- )
-
- @property
- @functools.lru_cache(None)
- def channelz(self) -> _ChannelzServiceClient:
- return _ChannelzServiceClient(
- self._make_channel(self.maintenance_port),
- log_target=f"{self.hostname}:{self.maintenance_port}",
- )
-
- @property
- @functools.lru_cache(None)
- def csds(self) -> _CsdsClient:
- return _CsdsClient(
- self._make_channel(self.maintenance_port),
- log_target=f"{self.hostname}:{self.maintenance_port}",
- )
-
- def get_load_balancer_stats(
- self,
- *,
- num_rpcs: int,
- metadata_keys: Optional[tuple[str, ...]] = None,
- timeout_sec: Optional[int] = None,
- ) -> grpc_testing.LoadBalancerStatsResponse:
- """Shortcut to LoadBalancerStatsServiceClient.get_client_stats()"""
- return self.load_balancer_stats.get_client_stats(
- num_rpcs=num_rpcs,
- timeout_sec=timeout_sec,
- metadata_keys=metadata_keys,
- )
-
- def get_load_balancer_accumulated_stats(
- self,
- *,
- timeout_sec: Optional[int] = None,
- ) -> grpc_testing.LoadBalancerAccumulatedStatsResponse:
- """Shortcut to LoadBalancerStatsServiceClient.get_client_accumulated_stats()"""
- return self.load_balancer_stats.get_client_accumulated_stats(
- timeout_sec=timeout_sec
- )
-
- def wait_for_server_channel_ready(
- self,
- *,
- timeout: Optional[_timedelta] = None,
- rpc_deadline: Optional[_timedelta] = None,
- ) -> _ChannelzChannel:
- """Wait for the channel to the server to transition to READY.
-
- Raises:
- GrpcApp.NotFound: If the channel never transitioned to READY.
+ https://github.com/grpc/grpc/blob/master/doc/xds-test-descriptions.md#client
"""
- try:
- return self.wait_for_server_channel_state(
- _ChannelzChannelState.READY,
- timeout=timeout,
- rpc_deadline=rpc_deadline,
- )
- except retryers.RetryError as retry_err:
- if isinstance(retry_err.exception(), self.ChannelNotFound):
- retry_err.add_note(
- framework.errors.FrameworkError.note_blanket_error(
- "The client couldn't connect to the server."
+
+ # A unique string identifying each client replica. Used in logging.
+ hostname: str
+
+ def __init__(
+ self,
+ *,
+ ip: str,
+ rpc_port: int,
+ server_target: str,
+ hostname: str,
+ rpc_host: Optional[str] = None,
+ maintenance_port: Optional[int] = None,
+ ):
+ super().__init__(rpc_host=(rpc_host or ip))
+ self.ip = ip
+ self.rpc_port = rpc_port
+ self.server_target = server_target
+ self.maintenance_port = maintenance_port or rpc_port
+ self.hostname = hostname
+
+ @property
+ @functools.lru_cache(None)
+ def load_balancer_stats(self) -> _LoadBalancerStatsServiceClient:
+ return _LoadBalancerStatsServiceClient(
+ self._make_channel(self.rpc_port),
+ log_target=f"{self.hostname}:{self.rpc_port}",
+ )
+
+ @property
+ @functools.lru_cache(None)
+ def update_config(self):
+ return _XdsUpdateClientConfigureServiceClient(
+ self._make_channel(self.rpc_port),
+ log_target=f"{self.hostname}:{self.rpc_port}",
+ )
+
+ @property
+ @functools.lru_cache(None)
+ def channelz(self) -> _ChannelzServiceClient:
+ return _ChannelzServiceClient(
+ self._make_channel(self.maintenance_port),
+ log_target=f"{self.hostname}:{self.maintenance_port}",
+ )
+
+ @property
+ @functools.lru_cache(None)
+ def csds(self) -> _CsdsClient:
+ return _CsdsClient(
+ self._make_channel(self.maintenance_port),
+ log_target=f"{self.hostname}:{self.maintenance_port}",
+ )
+
+ def get_load_balancer_stats(
+ self,
+ *,
+ num_rpcs: int,
+ metadata_keys: Optional[tuple[str, ...]] = None,
+ timeout_sec: Optional[int] = None,
+ ) -> grpc_testing.LoadBalancerStatsResponse:
+ """Shortcut to LoadBalancerStatsServiceClient.get_client_stats()"""
+ return self.load_balancer_stats.get_client_stats(
+ num_rpcs=num_rpcs,
+ timeout_sec=timeout_sec,
+ metadata_keys=metadata_keys,
+ )
+
+ def get_load_balancer_accumulated_stats(
+ self,
+ *,
+ timeout_sec: Optional[int] = None,
+ ) -> grpc_testing.LoadBalancerAccumulatedStatsResponse:
+ """Shortcut to LoadBalancerStatsServiceClient.get_client_accumulated_stats()"""
+ return self.load_balancer_stats.get_client_accumulated_stats(
+ timeout_sec=timeout_sec
+ )
+
+ def wait_for_server_channel_ready(
+ self,
+ *,
+ timeout: Optional[_timedelta] = None,
+ rpc_deadline: Optional[_timedelta] = None,
+ ) -> _ChannelzChannel:
+ """Wait for the channel to the server to transition to READY.
+
+ Raises:
+ GrpcApp.NotFound: If the channel never transitioned to READY.
+ """
+ try:
+ return self.wait_for_server_channel_state(
+ _ChannelzChannelState.READY,
+ timeout=timeout,
+ rpc_deadline=rpc_deadline,
)
- )
- raise
+ except retryers.RetryError as retry_err:
+ if isinstance(retry_err.exception(), self.ChannelNotFound):
+ retry_err.add_note(
+ framework.errors.FrameworkError.note_blanket_error(
+ "The client couldn't connect to the server."
+ )
+ )
+ raise
- def wait_for_active_xds_channel(
- self,
- *,
- xds_server_uri: Optional[str] = None,
- timeout: Optional[_timedelta] = None,
- rpc_deadline: Optional[_timedelta] = None,
- ) -> _ChannelzChannel:
- """Wait until the xds channel is active or timeout.
+ def wait_for_active_xds_channel(
+ self,
+ *,
+ xds_server_uri: Optional[str] = None,
+ timeout: Optional[_timedelta] = None,
+ rpc_deadline: Optional[_timedelta] = None,
+ ) -> _ChannelzChannel:
+ """Wait until the xds channel is active or timeout.
- Raises:
- GrpcApp.NotFound: If the channel to xds never transitioned to active.
- """
- try:
- return self.wait_for_xds_channel_active(
- xds_server_uri=xds_server_uri,
- timeout=timeout,
- rpc_deadline=rpc_deadline,
- )
- except retryers.RetryError as retry_err:
- if isinstance(retry_err.exception(), self.ChannelNotFound):
- retry_err.add_note(
- framework.errors.FrameworkError.note_blanket_error(
- "The client couldn't connect to the xDS control plane."
+ Raises:
+ GrpcApp.NotFound: If the channel to xds never transitioned to active.
+ """
+ try:
+ return self.wait_for_xds_channel_active(
+ xds_server_uri=xds_server_uri,
+ timeout=timeout,
+ rpc_deadline=rpc_deadline,
)
+ except retryers.RetryError as retry_err:
+ if isinstance(retry_err.exception(), self.ChannelNotFound):
+ retry_err.add_note(
+ framework.errors.FrameworkError.note_blanket_error(
+ "The client couldn't connect to the xDS control plane."
+ )
+ )
+ raise
+
+ def get_active_server_channel_socket(self) -> _ChannelzSocket:
+ channel = self.find_server_channel_with_state(
+ _ChannelzChannelState.READY
)
- raise
-
- def get_active_server_channel_socket(self) -> _ChannelzSocket:
- channel = self.find_server_channel_with_state(_ChannelzChannelState.READY)
- # Get the first subchannel of the active channel to the server.
- logger.debug(
- (
- "[%s] Retrieving client -> server socket, "
- "channel_id: %s, subchannel: %s"
- ),
- self.hostname,
- channel.ref.channel_id,
- channel.subchannel_ref[0].name,
- )
- subchannel, *subchannels = list(
- self.channelz.list_channel_subchannels(channel)
- )
- if subchannels:
- logger.warning(
- "[%s] Unexpected subchannels: %r", self.hostname, subchannels
- )
- # Get the first socket of the subchannel
- socket, *sockets = list(self.channelz.list_subchannels_sockets(subchannel))
- if sockets:
- logger.warning("[%s] Unexpected sockets: %r", self.hostname, subchannels)
- logger.debug(
- "[%s] Found client -> server socket: %s",
- self.hostname,
- socket.ref.name,
- )
- return socket
-
- def wait_for_server_channel_state(
- self,
- state: _ChannelzChannelState,
- *,
- timeout: Optional[_timedelta] = None,
- rpc_deadline: Optional[_timedelta] = None,
- ) -> _ChannelzChannel:
- # When polling for a state, prefer smaller wait times to avoid
- # exhausting all allowed time on a single long RPC.
- if rpc_deadline is None:
- rpc_deadline = _timedelta(seconds=30)
-
- # Fine-tuned to wait for the channel to the server.
- retryer = retryers.exponential_retryer_with_timeout(
- wait_min=_timedelta(seconds=10),
- wait_max=_timedelta(seconds=25),
- timeout=_timedelta(minutes=5) if timeout is None else timeout,
- )
-
- logger.info(
- "[%s] Waiting to report a %s channel to %s",
- self.hostname,
- _ChannelzChannelState.Name(state),
- self.server_target,
- )
- channel = retryer(
- self.find_server_channel_with_state,
- state,
- rpc_deadline=rpc_deadline,
- )
- logger.info(
- "[%s] Channel to %s transitioned to state %s: %s",
- self.hostname,
- self.server_target,
- _ChannelzChannelState.Name(state),
- _ChannelzServiceClient.channel_repr(channel),
- )
- return channel
-
- def wait_for_xds_channel_active(
- self,
- *,
- xds_server_uri: Optional[str] = None,
- timeout: Optional[_timedelta] = None,
- rpc_deadline: Optional[_timedelta] = None,
- ) -> _ChannelzChannel:
- if not xds_server_uri:
- xds_server_uri = DEFAULT_TD_XDS_URI
- # When polling for a state, prefer smaller wait times to avoid
- # exhausting all allowed time on a single long RPC.
- if rpc_deadline is None:
- rpc_deadline = _timedelta(seconds=30)
-
- retryer = retryers.exponential_retryer_with_timeout(
- wait_min=_timedelta(seconds=10),
- wait_max=_timedelta(seconds=25),
- timeout=_timedelta(minutes=5) if timeout is None else timeout,
- )
-
- logger.info(
- "[%s] ADS: Waiting for active calls to xDS control plane to %s",
- self.hostname,
- xds_server_uri,
- )
- channel = retryer(
- self.find_active_xds_channel,
- xds_server_uri=xds_server_uri,
- rpc_deadline=rpc_deadline,
- )
- logger.info(
- "[%s] ADS: Detected active calls to xDS control plane %s",
- self.hostname,
- xds_server_uri,
- )
- return channel
-
- def find_active_xds_channel(
- self,
- xds_server_uri: str,
- *,
- rpc_deadline: Optional[_timedelta] = None,
- ) -> _ChannelzChannel:
- rpc_params = {}
- if rpc_deadline is not None:
- rpc_params["deadline_sec"] = rpc_deadline.total_seconds()
-
- for channel in self.find_channels(xds_server_uri, **rpc_params):
- logger.info(
- "[%s] xDS control plane channel: %s",
- self.hostname,
- _ChannelzServiceClient.channel_repr(channel),
- )
-
- try:
- channel_upd = self.check_channel_in_flight_calls(channel, **rpc_params)
- logger.info(
- "[%s] Detected active calls to xDS control plane %s, channel: %s",
+ # Get the first subchannel of the active channel to the server.
+ logger.debug(
+ (
+ "[%s] Retrieving client -> server socket, "
+ "channel_id: %s, subchannel: %s"
+ ),
self.hostname,
- xds_server_uri,
- _ChannelzServiceClient.channel_repr(channel_upd),
+ channel.ref.channel_id,
+ channel.subchannel_ref[0].name,
)
- return channel_upd
- except self.NotFound:
- # Continue checking other channels to the same target on
- # not found.
- continue
- except framework.rpc.grpc.RpcError as err:
- # Logged at 'info' and not at 'warning' because this method is
- # expected to be called in a retryer. If this error eventually
- # causes the retryer to fail, it will be logged fully at 'error'
- logger.info(
- "[%s] Unexpected error while checking xDS control plane"
- " channel %s: %r",
+ subchannel, *subchannels = list(
+ self.channelz.list_channel_subchannels(channel)
+ )
+ if subchannels:
+ logger.warning(
+ "[%s] Unexpected subchannels: %r", self.hostname, subchannels
+ )
+ # Get the first socket of the subchannel
+ socket, *sockets = list(
+ self.channelz.list_subchannels_sockets(subchannel)
+ )
+ if sockets:
+ logger.warning(
+ "[%s] Unexpected sockets: %r", self.hostname, subchannels
+ )
+ logger.debug(
+ "[%s] Found client -> server socket: %s",
self.hostname,
+ socket.ref.name,
+ )
+ return socket
+
+ def wait_for_server_channel_state(
+ self,
+ state: _ChannelzChannelState,
+ *,
+ timeout: Optional[_timedelta] = None,
+ rpc_deadline: Optional[_timedelta] = None,
+ ) -> _ChannelzChannel:
+ # When polling for a state, prefer smaller wait times to avoid
+ # exhausting all allowed time on a single long RPC.
+ if rpc_deadline is None:
+ rpc_deadline = _timedelta(seconds=30)
+
+ # Fine-tuned to wait for the channel to the server.
+ retryer = retryers.exponential_retryer_with_timeout(
+ wait_min=_timedelta(seconds=10),
+ wait_max=_timedelta(seconds=25),
+ timeout=_timedelta(minutes=5) if timeout is None else timeout,
+ )
+
+ logger.info(
+ "[%s] Waiting to report a %s channel to %s",
+ self.hostname,
+ _ChannelzChannelState.Name(state),
+ self.server_target,
+ )
+ channel = retryer(
+ self.find_server_channel_with_state,
+ state,
+ rpc_deadline=rpc_deadline,
+ )
+ logger.info(
+ "[%s] Channel to %s transitioned to state %s: %s",
+ self.hostname,
+ self.server_target,
+ _ChannelzChannelState.Name(state),
_ChannelzServiceClient.channel_repr(channel),
- err,
)
- raise
-
- raise self.ChannelNotActive(
- f"[{self.hostname}] Client has no"
- f" active channel with xDS control plane {xds_server_uri}",
- src=self.hostname,
- dst=xds_server_uri,
- )
-
- def find_server_channel_with_state(
- self,
- expected_state: _ChannelzChannelState,
- *,
- rpc_deadline: Optional[_timedelta] = None,
- check_subchannel=True,
- ) -> _ChannelzChannel:
- rpc_params = {}
- if rpc_deadline is not None:
- rpc_params["deadline_sec"] = rpc_deadline.total_seconds()
-
- expected_state_name: str = _ChannelzChannelState.Name(expected_state)
- target: str = self.server_target
-
- for channel in self.find_channels(target, **rpc_params):
- channel_state: _ChannelzChannelState = channel.data.state.state
- logger.info(
- "[%s] Server channel: %s",
- self.hostname,
- _ChannelzServiceClient.channel_repr(channel),
- )
- if channel_state is expected_state:
- if check_subchannel:
- # When requested, check if the channel has at least
- # one subchannel in the requested state.
- try:
- subchannel = self.find_subchannel_with_state(
- channel, expected_state, **rpc_params
- )
- logger.info(
- "[%s] Found subchannel in state %s: %s",
- self.hostname,
- expected_state_name,
- _ChannelzServiceClient.subchannel_repr(subchannel),
- )
- except self.NotFound as e:
- # Otherwise, keep searching.
- logger.info(e.message)
- continue
return channel
- raise self.ChannelNotFound(
- f"[{self.hostname}] Client has no"
- f" {expected_state_name} channel with server {target}",
- src=self.hostname,
- dst=target,
- expected_state=expected_state,
- )
-
- def find_channels(
- self,
- target: str,
- **rpc_params,
- ) -> Iterable[_ChannelzChannel]:
- return self.channelz.find_channels_for_target(target, **rpc_params)
-
- def find_subchannel_with_state(
- self, channel: _ChannelzChannel, state: _ChannelzChannelState, **kwargs
- ) -> _ChannelzSubchannel:
- subchannels = self.channelz.list_channel_subchannels(channel, **kwargs)
- for subchannel in subchannels:
- if subchannel.data.state.state is state:
- return subchannel
-
- raise self.NotFound(
- f"[{self.hostname}] Not found "
- f"a {_ChannelzChannelState.Name(state)} subchannel "
- f"for channel_id {channel.ref.channel_id}"
- )
-
- def find_subchannels_with_state(
- self, state: _ChannelzChannelState, **kwargs
- ) -> List[_ChannelzSubchannel]:
- subchannels = []
- for channel in self.channelz.find_channels_for_target(
- self.server_target, **kwargs
- ):
- for subchannel in self.channelz.list_channel_subchannels(
- channel, **kwargs
- ):
- if subchannel.data.state.state is state:
- subchannels.append(subchannel)
- return subchannels
-
- def check_channel_in_flight_calls(
- self,
- channel: _ChannelzChannel,
- *,
- wait_between_checks: Optional[_timedelta] = None,
- **rpc_params,
- ) -> Optional[_ChannelzChannel]:
- """Checks if the channel has calls that started, but didn't complete.
-
- We consider the channel active if it is in READY state and
- calls_started is greater than calls_failed.
-
- This method addresses a race where a call to the xDS control plane server
- has just started and a channelz request comes in before the call has
- had a chance to fail.
-
- With channels to the xDS control plane, the channel can be READY but the
- calls could be failing to initialize, e.g. due to a failure to fetch
- OAUTH2 token. To increase the confidence that we have a valid channel
- with working OAUTH2 tokens, we check whether the channel is in a READY
- state with active calls twice with an interval of 2 seconds between the
- two attempts. If the OAUTH2 token is not valid, the call would fail and
- be caught in either the first attempt, or the second attempt. It is
- possible that between the two attempts, a call fails and a new call is
- started, so we also test for equality between the started calls of the
- two channelz results.
-
- There still exists a possibility that a call fails on fetching OAUTH2
- token after 2 seconds (maybe because there is a slowdown in the
- system). If such a case is observed, consider increasing the interval
- from 2 seconds to 5 seconds.
-
- Returns updated channel on success, or None on failure.
- """
- if not self.calc_calls_in_flight(channel):
- return None
-
- if not wait_between_checks:
- wait_between_checks = _timedelta(seconds=2)
-
- # Load the channel second time after the timeout.
- time.sleep(wait_between_checks.total_seconds())
- channel_upd: _ChannelzChannel = self.channelz.get_channel(
- channel.ref.channel_id, **rpc_params
- )
- if (
- not self.calc_calls_in_flight(channel_upd)
- or channel.data.calls_started != channel_upd.data.calls_started
- ):
- return None
- return channel_upd
-
- @classmethod
- def calc_calls_in_flight(cls, channel: _ChannelzChannel) -> int:
- cdata: _ChannelzChannelData = channel.data
- if cdata.state.state is not _ChannelzChannelState.READY:
- return 0
-
- return cdata.calls_started - cdata.calls_succeeded - cdata.calls_failed
-
- class ChannelNotFound(framework.rpc.grpc.GrpcApp.NotFound):
- """Channel with expected status not found"""
-
- src: str
- dst: str
- expected_state: object
-
- def __init__(
+ def wait_for_xds_channel_active(
self,
- message: str,
*,
- src: str,
- dst: str,
+ xds_server_uri: Optional[str] = None,
+ timeout: Optional[_timedelta] = None,
+ rpc_deadline: Optional[_timedelta] = None,
+ ) -> _ChannelzChannel:
+ if not xds_server_uri:
+ xds_server_uri = DEFAULT_TD_XDS_URI
+ # When polling for a state, prefer smaller wait times to avoid
+ # exhausting all allowed time on a single long RPC.
+ if rpc_deadline is None:
+ rpc_deadline = _timedelta(seconds=30)
+
+ retryer = retryers.exponential_retryer_with_timeout(
+ wait_min=_timedelta(seconds=10),
+ wait_max=_timedelta(seconds=25),
+ timeout=_timedelta(minutes=5) if timeout is None else timeout,
+ )
+
+ logger.info(
+ "[%s] ADS: Waiting for active calls to xDS control plane to %s",
+ self.hostname,
+ xds_server_uri,
+ )
+ channel = retryer(
+ self.find_active_xds_channel,
+ xds_server_uri=xds_server_uri,
+ rpc_deadline=rpc_deadline,
+ )
+ logger.info(
+ "[%s] ADS: Detected active calls to xDS control plane %s",
+ self.hostname,
+ xds_server_uri,
+ )
+ return channel
+
+ def find_active_xds_channel(
+ self,
+ xds_server_uri: str,
+ *,
+ rpc_deadline: Optional[_timedelta] = None,
+ ) -> _ChannelzChannel:
+ rpc_params = {}
+ if rpc_deadline is not None:
+ rpc_params["deadline_sec"] = rpc_deadline.total_seconds()
+
+ for channel in self.find_channels(xds_server_uri, **rpc_params):
+ logger.info(
+ "[%s] xDS control plane channel: %s",
+ self.hostname,
+ _ChannelzServiceClient.channel_repr(channel),
+ )
+
+ try:
+ channel_upd = self.check_channel_in_flight_calls(
+ channel, **rpc_params
+ )
+ logger.info(
+ "[%s] Detected active calls to xDS control plane %s, channel: %s",
+ self.hostname,
+ xds_server_uri,
+ _ChannelzServiceClient.channel_repr(channel_upd),
+ )
+ return channel_upd
+ except self.NotFound:
+ # Continue checking other channels to the same target on
+ # not found.
+ continue
+ except framework.rpc.grpc.RpcError as err:
+ # Logged at 'info' and not at 'warning' because this method is
+ # expected to be called in a retryer. If this error eventually
+ # causes the retryer to fail, it will be logged fully at 'error'
+ logger.info(
+ "[%s] Unexpected error while checking xDS control plane"
+ " channel %s: %r",
+ self.hostname,
+ _ChannelzServiceClient.channel_repr(channel),
+ err,
+ )
+ raise
+
+ raise self.ChannelNotActive(
+ f"[{self.hostname}] Client has no"
+ f" active channel with xDS control plane {xds_server_uri}",
+ src=self.hostname,
+ dst=xds_server_uri,
+ )
+
+ def find_server_channel_with_state(
+ self,
expected_state: _ChannelzChannelState,
- **kwargs,
- ):
- self.src = src
- self.dst = dst
- self.expected_state = expected_state
- super().__init__(message, src, dst, expected_state, **kwargs)
-
- class ChannelNotActive(framework.rpc.grpc.GrpcApp.NotFound):
- """No active channel was found"""
-
- src: str
- dst: str
-
- def __init__(
- self,
- message: str,
*,
- src: str,
- dst: str,
- **kwargs,
- ):
- self.src = src
- self.dst = dst
- super().__init__(message, src, dst, **kwargs)
+ rpc_deadline: Optional[_timedelta] = None,
+ check_subchannel=True,
+ ) -> _ChannelzChannel:
+ rpc_params = {}
+ if rpc_deadline is not None:
+ rpc_params["deadline_sec"] = rpc_deadline.total_seconds()
+
+ expected_state_name: str = _ChannelzChannelState.Name(expected_state)
+ target: str = self.server_target
+
+ for channel in self.find_channels(target, **rpc_params):
+ channel_state: _ChannelzChannelState = channel.data.state.state
+ logger.info(
+ "[%s] Server channel: %s",
+ self.hostname,
+ _ChannelzServiceClient.channel_repr(channel),
+ )
+ if channel_state is expected_state:
+ if check_subchannel:
+ # When requested, check if the channel has at least
+ # one subchannel in the requested state.
+ try:
+ subchannel = self.find_subchannel_with_state(
+ channel, expected_state, **rpc_params
+ )
+ logger.info(
+ "[%s] Found subchannel in state %s: %s",
+ self.hostname,
+ expected_state_name,
+ _ChannelzServiceClient.subchannel_repr(subchannel),
+ )
+ except self.NotFound as e:
+ # Otherwise, keep searching.
+ logger.info(e.message)
+ continue
+ return channel
+
+ raise self.ChannelNotFound(
+ f"[{self.hostname}] Client has no"
+ f" {expected_state_name} channel with server {target}",
+ src=self.hostname,
+ dst=target,
+ expected_state=expected_state,
+ )
+
+ def find_channels(
+ self,
+ target: str,
+ **rpc_params,
+ ) -> Iterable[_ChannelzChannel]:
+ return self.channelz.find_channels_for_target(target, **rpc_params)
+
+ def find_subchannel_with_state(
+ self, channel: _ChannelzChannel, state: _ChannelzChannelState, **kwargs
+ ) -> _ChannelzSubchannel:
+ subchannels = self.channelz.list_channel_subchannels(channel, **kwargs)
+ for subchannel in subchannels:
+ if subchannel.data.state.state is state:
+ return subchannel
+
+ raise self.NotFound(
+ f"[{self.hostname}] Not found "
+ f"a {_ChannelzChannelState.Name(state)} subchannel "
+ f"for channel_id {channel.ref.channel_id}"
+ )
+
+ def find_subchannels_with_state(
+ self, state: _ChannelzChannelState, **kwargs
+ ) -> List[_ChannelzSubchannel]:
+ subchannels = []
+ for channel in self.channelz.find_channels_for_target(
+ self.server_target, **kwargs
+ ):
+ for subchannel in self.channelz.list_channel_subchannels(
+ channel, **kwargs
+ ):
+ if subchannel.data.state.state is state:
+ subchannels.append(subchannel)
+ return subchannels
+
+ def check_channel_in_flight_calls(
+ self,
+ channel: _ChannelzChannel,
+ *,
+ wait_between_checks: Optional[_timedelta] = None,
+ **rpc_params,
+ ) -> Optional[_ChannelzChannel]:
+ """Checks if the channel has calls that started, but didn't complete.
+
+ We consider the channel active if it is in READY state and
+ calls_started is greater than calls_failed.
+
+ This method addresses a race where a call to the xDS control plane server
+ has just started and a channelz request comes in before the call has
+ had a chance to fail.
+
+ With channels to the xDS control plane, the channel can be READY but the
+ calls could be failing to initialize, e.g. due to a failure to fetch
+ OAUTH2 token. To increase the confidence that we have a valid channel
+ with working OAUTH2 tokens, we check whether the channel is in a READY
+ state with active calls twice with an interval of 2 seconds between the
+ two attempts. If the OAUTH2 token is not valid, the call would fail and
+ be caught in either the first attempt, or the second attempt. It is
+ possible that between the two attempts, a call fails and a new call is
+ started, so we also test for equality between the started calls of the
+ two channelz results.
+
+ There still exists a possibility that a call fails on fetching OAUTH2
+ token after 2 seconds (maybe because there is a slowdown in the
+ system). If such a case is observed, consider increasing the interval
+ from 2 seconds to 5 seconds.
+
+ Returns updated channel on success, or None on failure.
+ """
+ if not self.calc_calls_in_flight(channel):
+ return None
+
+ if not wait_between_checks:
+ wait_between_checks = _timedelta(seconds=2)
+
+ # Load the channel second time after the timeout.
+ time.sleep(wait_between_checks.total_seconds())
+ channel_upd: _ChannelzChannel = self.channelz.get_channel(
+ channel.ref.channel_id, **rpc_params
+ )
+ if (
+ not self.calc_calls_in_flight(channel_upd)
+ or channel.data.calls_started != channel_upd.data.calls_started
+ ):
+ return None
+ return channel_upd
+
+ @classmethod
+ def calc_calls_in_flight(cls, channel: _ChannelzChannel) -> int:
+ cdata: _ChannelzChannelData = channel.data
+ if cdata.state.state is not _ChannelzChannelState.READY:
+ return 0
+
+ return cdata.calls_started - cdata.calls_succeeded - cdata.calls_failed
+
+ class ChannelNotFound(framework.rpc.grpc.GrpcApp.NotFound):
+ """Channel with expected status not found"""
+
+ src: str
+ dst: str
+ expected_state: object
+
+ def __init__(
+ self,
+ message: str,
+ *,
+ src: str,
+ dst: str,
+ expected_state: _ChannelzChannelState,
+ **kwargs,
+ ):
+ self.src = src
+ self.dst = dst
+ self.expected_state = expected_state
+ super().__init__(message, src, dst, expected_state, **kwargs)
+
+ class ChannelNotActive(framework.rpc.grpc.GrpcApp.NotFound):
+ """No active channel was found"""
+
+ src: str
+ dst: str
+
+ def __init__(
+ self,
+ message: str,
+ *,
+ src: str,
+ dst: str,
+ **kwargs,
+ ):
+ self.src = src
+ self.dst = dst
+ super().__init__(message, src, dst, **kwargs)