Change secure endpoint write code to use max_frame_size to control encrypted frame sizes at the sender. (#29990)
* use max_frame_size to control encrypted frame sizes on the sender
* Add comment
* adding logic to set max_frame_size in chttp2 transport and protecting it under a flag
* fix typo
* fix review comments
* set max frame size usage in endpoint_tests
* update endpoint_tests
* adding an interception layer to secure_endpoint_test
* add comments
* reverting some mistaken changes
* Automated change: Fix sanity tests
* try increasing deadline to check if msan passes
* Automated change: Fix sanity tests
Co-authored-by: Vignesh2208 <Vignesh2208@users.noreply.github.com>
diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc
index 2fe922c..d8349e7db 100644
--- a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc
+++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc
@@ -62,6 +62,7 @@
#include "src/core/lib/gpr/useful.h"
#include "src/core/lib/gprpp/bitset.h"
#include "src/core/lib/gprpp/debug_location.h"
+#include "src/core/lib/gprpp/global_config_env.h"
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/gprpp/status_helper.h"
#include "src/core/lib/gprpp/time.h"
@@ -92,6 +93,12 @@
#include "src/core/lib/transport/transport.h"
#include "src/core/lib/transport/transport_impl.h"
+GPR_GLOBAL_CONFIG_DEFINE_BOOL(
+ grpc_experimental_enable_peer_state_based_framing, false,
+ "If set, the max sizes of frames sent to lower layers is controlled based "
+ "on the peer's memory pressure which is reflected in its max http2 frame "
+ "size.");
+
#define DEFAULT_CONNECTION_WINDOW_TARGET (1024 * 1024)
#define MAX_WINDOW 0x7fffffffu
#define MAX_WRITE_BUFFER_SIZE (64 * 1024 * 1024)
@@ -979,14 +986,26 @@
static void write_action(void* gt, grpc_error_handle /*error*/) {
GPR_TIMER_SCOPE("write_action", 0);
+ static bool kEnablePeerStateBasedFraming =
+ GPR_GLOBAL_CONFIG_GET(grpc_experimental_enable_peer_state_based_framing);
grpc_chttp2_transport* t = static_cast<grpc_chttp2_transport*>(gt);
void* cl = t->cl;
t->cl = nullptr;
+ // If grpc_experimental_enable_peer_state_based_framing is set to true,
+ // choose max_frame_size as 2 * max http2 frame size of peer. If peer is under
+ // high memory pressure, then it would advertise a smaller max http2 frame
+ // size. With this logic, the sender would automatically reduce the sending
+ // frame size as well.
+ int max_frame_size =
+ kEnablePeerStateBasedFraming
+ ? 2 * t->settings[GRPC_PEER_SETTINGS]
+ [GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE]
+ : INT_MAX;
grpc_endpoint_write(
t->ep, &t->outbuf,
GRPC_CLOSURE_INIT(&t->write_action_end_locked, write_action_end, t,
grpc_schedule_on_exec_ctx),
- cl, /*max_frame_size=*/INT_MAX);
+ cl, max_frame_size);
}
static void write_action_end(void* tp, grpc_error_handle error) {
diff --git a/src/core/lib/security/transport/secure_endpoint.cc b/src/core/lib/security/transport/secure_endpoint.cc
index b0c8a35..303a0dc 100644
--- a/src/core/lib/security/transport/secure_endpoint.cc
+++ b/src/core/lib/security/transport/secure_endpoint.cc
@@ -21,7 +21,6 @@
#include "src/core/lib/security/transport/secure_endpoint.h"
#include <inttypes.h>
-#include <limits.h>
#include <algorithm>
#include <atomic>
@@ -105,6 +104,7 @@
}
has_posted_reclaimer.store(false, std::memory_order_relaxed);
min_progress_size = 1;
+ grpc_slice_buffer_init(&protector_staging_buffer);
gpr_ref_init(&ref, 1);
}
@@ -117,6 +117,7 @@
grpc_slice_unref_internal(read_staging_buffer);
grpc_slice_unref_internal(write_staging_buffer);
grpc_slice_buffer_destroy_internal(&output_buffer);
+ grpc_slice_buffer_destroy_internal(&protector_staging_buffer);
gpr_mu_destroy(&protector_mu);
}
@@ -143,7 +144,7 @@
grpc_core::MemoryAllocator::Reservation self_reservation;
std::atomic<bool> has_posted_reclaimer;
int min_progress_size;
-
+ grpc_slice_buffer protector_staging_buffer;
gpr_refcount ref;
};
} // namespace
@@ -384,8 +385,7 @@
}
static void endpoint_write(grpc_endpoint* secure_ep, grpc_slice_buffer* slices,
- grpc_closure* cb, void* arg,
- int /*max_frame_size*/) {
+ grpc_closure* cb, void* arg, int max_frame_size) {
GPR_TIMER_SCOPE("secure_endpoint.endpoint_write", 0);
unsigned i;
@@ -410,8 +410,25 @@
if (ep->zero_copy_protector != nullptr) {
// Use zero-copy grpc protector to protect.
- result = tsi_zero_copy_grpc_protector_protect(ep->zero_copy_protector,
- slices, &ep->output_buffer);
+ result = TSI_OK;
+ // Break the input slices into chunks of size = max_frame_size and call
+ // tsi_zero_copy_grpc_protector_protect on each chunk. This ensures that
+ // the protector cannot create frames larger than the specified
+ // max_frame_size.
+ while (slices->length > static_cast<size_t>(max_frame_size) &&
+ result == TSI_OK) {
+ grpc_slice_buffer_move_first(slices,
+ static_cast<size_t>(max_frame_size),
+ &ep->protector_staging_buffer);
+ result = tsi_zero_copy_grpc_protector_protect(
+ ep->zero_copy_protector, &ep->protector_staging_buffer,
+ &ep->output_buffer);
+ }
+ if (result == TSI_OK && slices->length > 0) {
+ result = tsi_zero_copy_grpc_protector_protect(
+ ep->zero_copy_protector, slices, &ep->output_buffer);
+ }
+ grpc_slice_buffer_reset_and_unref_internal(&ep->protector_staging_buffer);
} else {
// Use frame protector to protect.
for (i = 0; i < slices->count; i++) {
@@ -479,7 +496,7 @@
}
grpc_endpoint_write(ep->wrapped_ep, &ep->output_buffer, cb, arg,
- /*max_frame_size=*/INT_MAX);
+ max_frame_size);
}
static void endpoint_shutdown(grpc_endpoint* secure_ep, grpc_error_handle why) {
diff --git a/src/core/tsi/fake_transport_security.cc b/src/core/tsi/fake_transport_security.cc
index 18a42dd..2143fb5 100644
--- a/src/core/tsi/fake_transport_security.cc
+++ b/src/core/tsi/fake_transport_security.cc
@@ -143,6 +143,11 @@
return load32_little_endian(frame_size_buffer);
}
+uint32_t tsi_fake_zero_copy_grpc_protector_next_frame_size(
+ const grpc_slice_buffer* protected_slices) {
+ return read_frame_size(protected_slices);
+}
+
static void tsi_fake_frame_reset(tsi_fake_frame* frame, int needs_draining) {
frame->offset = 0;
frame->needs_draining = needs_draining;
diff --git a/src/core/tsi/fake_transport_security.h b/src/core/tsi/fake_transport_security.h
index 704d9fb..b2cf5ff 100644
--- a/src/core/tsi/fake_transport_security.h
+++ b/src/core/tsi/fake_transport_security.h
@@ -21,6 +21,7 @@
#include <grpc/support/port_platform.h>
+#include "src/core/lib/slice/slice_internal.h"
#include "src/core/tsi/transport_security_interface.h"
/* Value for the TSI_CERTIFICATE_TYPE_PEER_PROPERTY property for FAKE certs. */
@@ -44,4 +45,9 @@
tsi_zero_copy_grpc_protector* tsi_create_fake_zero_copy_grpc_protector(
size_t* max_protected_frame_size);
+/* Given a buffer containing slices encrypted by a fake_zero_copy_protector
+ * it parses these protected slices to return the total frame size of the first
+ * contained frame */
+uint32_t tsi_fake_zero_copy_grpc_protector_next_frame_size(
+ const grpc_slice_buffer* protected_slices);
#endif /* GRPC_CORE_TSI_FAKE_TRANSPORT_SECURITY_H */
diff --git a/test/core/iomgr/endpoint_tests.cc b/test/core/iomgr/endpoint_tests.cc
index ad2488a..4fb9490 100644
--- a/test/core/iomgr/endpoint_tests.cc
+++ b/test/core/iomgr/endpoint_tests.cc
@@ -112,6 +112,7 @@
uint8_t current_write_data;
int read_done;
int write_done;
+ int max_write_frame_size;
grpc_slice_buffer incoming;
grpc_slice_buffer outgoing;
grpc_closure done_read;
@@ -153,7 +154,7 @@
struct read_and_write_test_state* state =
static_cast<struct read_and_write_test_state*>(data);
grpc_endpoint_write(state->write_ep, &state->outgoing, &state->done_write,
- nullptr, /*max_frame_size=*/INT_MAX);
+ nullptr, /*max_frame_size=*/state->max_write_frame_size);
}
static void read_and_write_test_write_handler(void* data,
@@ -197,13 +198,14 @@
*/
static void read_and_write_test(grpc_endpoint_test_config config,
size_t num_bytes, size_t write_size,
- size_t slice_size, bool shutdown) {
+ size_t slice_size, int max_write_frame_size,
+ bool shutdown) {
struct read_and_write_test_state state;
grpc_endpoint_test_fixture f =
begin_test(config, "read_and_write_test", slice_size);
grpc_core::ExecCtx exec_ctx;
auto deadline = grpc_core::Timestamp::FromTimespecRoundUp(
- grpc_timeout_seconds_to_deadline(20));
+ grpc_timeout_seconds_to_deadline(60));
gpr_log(GPR_DEBUG,
"num_bytes=%" PRIuPTR " write_size=%" PRIuPTR " slice_size=%" PRIuPTR
" shutdown=%d",
@@ -223,6 +225,7 @@
state.target_bytes = num_bytes;
state.bytes_read = 0;
state.current_write_size = write_size;
+ state.max_write_frame_size = max_write_frame_size;
state.bytes_written = 0;
state.read_done = 0;
state.write_done = 0;
@@ -305,7 +308,6 @@
grpc_endpoint_test_fixture f =
begin_test(config, "multiple_shutdown_test", 128);
int fail_count = 0;
-
grpc_slice_buffer slice_buffer;
grpc_slice_buffer_init(&slice_buffer);
@@ -346,11 +348,13 @@
g_pollset = pollset;
g_mu = mu;
multiple_shutdown_test(config);
- read_and_write_test(config, 10000000, 100000, 8192, false);
- read_and_write_test(config, 1000000, 100000, 1, false);
- read_and_write_test(config, 100000000, 100000, 1, true);
+ for (int i = 1; i <= 8192; i = i * 2) {
+ read_and_write_test(config, 10000000, 100000, 8192, i, false);
+ read_and_write_test(config, 1000000, 100000, 1, i, false);
+ read_and_write_test(config, 100000000, 100000, 1, i, true);
+ }
for (i = 1; i < 1000; i = std::max(i + 1, i * 5 / 4)) {
- read_and_write_test(config, 40320, i, i, false);
+ read_and_write_test(config, 40320, i, i, i, false);
}
g_pollset = nullptr;
g_mu = nullptr;
diff --git a/test/core/security/secure_endpoint_test.cc b/test/core/security/secure_endpoint_test.cc
index 03118d8..708a499 100644
--- a/test/core/security/secure_endpoint_test.cc
+++ b/test/core/security/secure_endpoint_test.cc
@@ -36,6 +36,93 @@
static gpr_mu* g_mu;
static grpc_pollset* g_pollset;
+#define TSI_FAKE_FRAME_HEADER_SIZE 4
+
+typedef struct intercept_endpoint {
+ grpc_endpoint base;
+ grpc_endpoint* wrapped_ep;
+ grpc_slice_buffer staging_buffer;
+} intercept_endpoint;
+
+static void me_read(grpc_endpoint* ep, grpc_slice_buffer* slices,
+ grpc_closure* cb, bool urgent, int min_progress_size) {
+ intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
+ grpc_endpoint_read(m->wrapped_ep, slices, cb, urgent, min_progress_size);
+}
+
+static void me_write(grpc_endpoint* ep, grpc_slice_buffer* slices,
+ grpc_closure* cb, void* arg, int max_frame_size) {
+ intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
+ int remaining = slices->length;
+ while (remaining > 0) {
+ // Estimate the frame size of the next frame.
+ int next_frame_size =
+ tsi_fake_zero_copy_grpc_protector_next_frame_size(slices);
+ GPR_ASSERT(next_frame_size > TSI_FAKE_FRAME_HEADER_SIZE);
+ // Ensure the protected data size does not exceed the max_frame_size.
+ GPR_ASSERT(next_frame_size - TSI_FAKE_FRAME_HEADER_SIZE <= max_frame_size);
+ // Move this frame into a staging buffer and repeat.
+ grpc_slice_buffer_move_first(slices, next_frame_size, &m->staging_buffer);
+ remaining -= next_frame_size;
+ }
+ grpc_slice_buffer_swap(&m->staging_buffer, slices);
+ grpc_endpoint_write(m->wrapped_ep, slices, cb, arg, max_frame_size);
+}
+
+static void me_add_to_pollset(grpc_endpoint* /*ep*/,
+ grpc_pollset* /*pollset*/) {}
+
+static void me_add_to_pollset_set(grpc_endpoint* /*ep*/,
+ grpc_pollset_set* /*pollset*/) {}
+
+static void me_delete_from_pollset_set(grpc_endpoint* /*ep*/,
+ grpc_pollset_set* /*pollset*/) {}
+
+static void me_shutdown(grpc_endpoint* ep, grpc_error_handle why) {
+ intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
+ grpc_endpoint_shutdown(m->wrapped_ep, why);
+}
+
+static void me_destroy(grpc_endpoint* ep) {
+ intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
+ grpc_endpoint_destroy(m->wrapped_ep);
+ grpc_slice_buffer_destroy(&m->staging_buffer);
+ gpr_free(m);
+}
+
+static absl::string_view me_get_peer(grpc_endpoint* /*ep*/) {
+ return "fake:intercept-endpoint";
+}
+
+static absl::string_view me_get_local_address(grpc_endpoint* /*ep*/) {
+ return "fake:intercept-endpoint";
+}
+
+static int me_get_fd(grpc_endpoint* /*ep*/) { return -1; }
+
+static bool me_can_track_err(grpc_endpoint* /*ep*/) { return false; }
+
+static const grpc_endpoint_vtable vtable = {me_read,
+ me_write,
+ me_add_to_pollset,
+ me_add_to_pollset_set,
+ me_delete_from_pollset_set,
+ me_shutdown,
+ me_destroy,
+ me_get_peer,
+ me_get_local_address,
+ me_get_fd,
+ me_can_track_err};
+
+grpc_endpoint* wrap_with_intercept_endpoint(grpc_endpoint* wrapped_ep) {
+ intercept_endpoint* m =
+ static_cast<intercept_endpoint*>(gpr_malloc(sizeof(*m)));
+ m->base.vtable = &vtable;
+ m->wrapped_ep = wrapped_ep;
+ grpc_slice_buffer_init(&m->staging_buffer);
+ return &m->base;
+}
+
static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair(
size_t slice_size, grpc_slice* leftover_slices, size_t leftover_nslices,
bool use_zero_copy_protector) {
@@ -68,6 +155,13 @@
grpc_endpoint_add_to_pollset(tcp.client, g_pollset);
grpc_endpoint_add_to_pollset(tcp.server, g_pollset);
+ // TODO(vigneshbabu): Extend the intercept endpoint logic to cover non-zero
+ // copy based frame protectors as well.
+ if (use_zero_copy_protector && leftover_nslices == 0) {
+ tcp.client = wrap_with_intercept_endpoint(tcp.client);
+ tcp.server = wrap_with_intercept_endpoint(tcp.server);
+ }
+
if (leftover_nslices == 0) {
f.client_ep = grpc_secure_endpoint_create(fake_read_protector,
fake_read_zero_copy_protector,
@@ -125,7 +219,6 @@
tcp.server, nullptr, &args, 0);
grpc_resource_quota_unref(
static_cast<grpc_resource_quota*>(a[1].value.pointer.p));
-
return f;
}