Change secure endpoint write code to use max_frame_size to control encrypted frame sizes at the sender. (#29990)

* use max_frame_size to control encrypted frame sizes on the sender

* Add comment

* adding logic to set max_frame_size in the chttp2 transport, guarded behind an experimental flag

* fix typo

* fix review comments

* set max frame size usage in endpoint_tests

* update endpoint_tests

* adding an interception layer to secure_endpoint_test

* add comments

* reverting some mistaken changes

* Automated change: Fix sanity tests

* try increasing deadline to check if msan passes

* Automated change: Fix sanity tests

Co-authored-by: Vignesh2208 <Vignesh2208@users.noreply.github.com>
diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc
index 2fe922c..d8349e7db 100644
--- a/src/core/ext/transport/chttp2/transport/chttp2_transport.cc
+++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.cc
@@ -62,6 +62,7 @@
 #include "src/core/lib/gpr/useful.h"
 #include "src/core/lib/gprpp/bitset.h"
 #include "src/core/lib/gprpp/debug_location.h"
+#include "src/core/lib/gprpp/global_config_env.h"
 #include "src/core/lib/gprpp/ref_counted.h"
 #include "src/core/lib/gprpp/status_helper.h"
 #include "src/core/lib/gprpp/time.h"
@@ -92,6 +93,12 @@
 #include "src/core/lib/transport/transport.h"
 #include "src/core/lib/transport/transport_impl.h"
 
+GPR_GLOBAL_CONFIG_DEFINE_BOOL(
+    grpc_experimental_enable_peer_state_based_framing, false,
+    "If set, the max sizes of frames sent to lower layers are controlled based "
+    "on the peer's memory pressure which is reflected in its max http2 frame "
+    "size.");
+
 #define DEFAULT_CONNECTION_WINDOW_TARGET (1024 * 1024)
 #define MAX_WINDOW 0x7fffffffu
 #define MAX_WRITE_BUFFER_SIZE (64 * 1024 * 1024)
@@ -979,14 +986,26 @@
 
 static void write_action(void* gt, grpc_error_handle /*error*/) {
   GPR_TIMER_SCOPE("write_action", 0);
+  static bool kEnablePeerStateBasedFraming =
+      GPR_GLOBAL_CONFIG_GET(grpc_experimental_enable_peer_state_based_framing);
   grpc_chttp2_transport* t = static_cast<grpc_chttp2_transport*>(gt);
   void* cl = t->cl;
   t->cl = nullptr;
+  // If grpc_experimental_enable_peer_state_based_framing is set to true,
+  // choose max_frame_size as 2 * max http2 frame size of peer. If peer is under
+  // high memory pressure, then it would advertise a smaller max http2 frame
+  // size. With this logic, the sender would automatically reduce the sending
+  // frame size as well.
+  int max_frame_size =
+      kEnablePeerStateBasedFraming
+          ? 2 * t->settings[GRPC_PEER_SETTINGS]
+                           [GRPC_CHTTP2_SETTINGS_MAX_FRAME_SIZE]
+          : INT_MAX;
   grpc_endpoint_write(
       t->ep, &t->outbuf,
       GRPC_CLOSURE_INIT(&t->write_action_end_locked, write_action_end, t,
                         grpc_schedule_on_exec_ctx),
-      cl, /*max_frame_size=*/INT_MAX);
+      cl, max_frame_size);
 }
 
 static void write_action_end(void* tp, grpc_error_handle error) {
diff --git a/src/core/lib/security/transport/secure_endpoint.cc b/src/core/lib/security/transport/secure_endpoint.cc
index b0c8a35..303a0dc 100644
--- a/src/core/lib/security/transport/secure_endpoint.cc
+++ b/src/core/lib/security/transport/secure_endpoint.cc
@@ -21,7 +21,6 @@
 #include "src/core/lib/security/transport/secure_endpoint.h"
 
 #include <inttypes.h>
-#include <limits.h>
 
 #include <algorithm>
 #include <atomic>
@@ -105,6 +104,7 @@
     }
     has_posted_reclaimer.store(false, std::memory_order_relaxed);
     min_progress_size = 1;
+    grpc_slice_buffer_init(&protector_staging_buffer);
     gpr_ref_init(&ref, 1);
   }
 
@@ -117,6 +117,7 @@
     grpc_slice_unref_internal(read_staging_buffer);
     grpc_slice_unref_internal(write_staging_buffer);
     grpc_slice_buffer_destroy_internal(&output_buffer);
+    grpc_slice_buffer_destroy_internal(&protector_staging_buffer);
     gpr_mu_destroy(&protector_mu);
   }
 
@@ -143,7 +144,7 @@
   grpc_core::MemoryAllocator::Reservation self_reservation;
   std::atomic<bool> has_posted_reclaimer;
   int min_progress_size;
-
+  grpc_slice_buffer protector_staging_buffer;
   gpr_refcount ref;
 };
 }  // namespace
@@ -384,8 +385,7 @@
 }
 
 static void endpoint_write(grpc_endpoint* secure_ep, grpc_slice_buffer* slices,
-                           grpc_closure* cb, void* arg,
-                           int /*max_frame_size*/) {
+                           grpc_closure* cb, void* arg, int max_frame_size) {
   GPR_TIMER_SCOPE("secure_endpoint.endpoint_write", 0);
 
   unsigned i;
@@ -410,8 +410,25 @@
 
     if (ep->zero_copy_protector != nullptr) {
       // Use zero-copy grpc protector to protect.
-      result = tsi_zero_copy_grpc_protector_protect(ep->zero_copy_protector,
-                                                    slices, &ep->output_buffer);
+      result = TSI_OK;
+      // Break the input slices into chunks of size = max_frame_size and call
+      // tsi_zero_copy_grpc_protector_protect on each chunk. This ensures that
+      // the protector cannot create frames larger than the specified
+      // max_frame_size.
+      while (slices->length > static_cast<size_t>(max_frame_size) &&
+             result == TSI_OK) {
+        grpc_slice_buffer_move_first(slices,
+                                     static_cast<size_t>(max_frame_size),
+                                     &ep->protector_staging_buffer);
+        result = tsi_zero_copy_grpc_protector_protect(
+            ep->zero_copy_protector, &ep->protector_staging_buffer,
+            &ep->output_buffer);
+      }
+      if (result == TSI_OK && slices->length > 0) {
+        result = tsi_zero_copy_grpc_protector_protect(
+            ep->zero_copy_protector, slices, &ep->output_buffer);
+      }
+      grpc_slice_buffer_reset_and_unref_internal(&ep->protector_staging_buffer);
     } else {
       // Use frame protector to protect.
       for (i = 0; i < slices->count; i++) {
@@ -479,7 +496,7 @@
   }
 
   grpc_endpoint_write(ep->wrapped_ep, &ep->output_buffer, cb, arg,
-                      /*max_frame_size=*/INT_MAX);
+                      max_frame_size);
 }
 
 static void endpoint_shutdown(grpc_endpoint* secure_ep, grpc_error_handle why) {
diff --git a/src/core/tsi/fake_transport_security.cc b/src/core/tsi/fake_transport_security.cc
index 18a42dd..2143fb5 100644
--- a/src/core/tsi/fake_transport_security.cc
+++ b/src/core/tsi/fake_transport_security.cc
@@ -143,6 +143,11 @@
   return load32_little_endian(frame_size_buffer);
 }
 
+uint32_t tsi_fake_zero_copy_grpc_protector_next_frame_size(
+    const grpc_slice_buffer* protected_slices) {
+  return read_frame_size(protected_slices);
+}
+
 static void tsi_fake_frame_reset(tsi_fake_frame* frame, int needs_draining) {
   frame->offset = 0;
   frame->needs_draining = needs_draining;
diff --git a/src/core/tsi/fake_transport_security.h b/src/core/tsi/fake_transport_security.h
index 704d9fb..b2cf5ff 100644
--- a/src/core/tsi/fake_transport_security.h
+++ b/src/core/tsi/fake_transport_security.h
@@ -21,6 +21,7 @@
 
 #include <grpc/support/port_platform.h>
 
+#include "src/core/lib/slice/slice_internal.h"
 #include "src/core/tsi/transport_security_interface.h"
 
 /* Value for the TSI_CERTIFICATE_TYPE_PEER_PROPERTY property for FAKE certs. */
@@ -44,4 +45,9 @@
 tsi_zero_copy_grpc_protector* tsi_create_fake_zero_copy_grpc_protector(
     size_t* max_protected_frame_size);
 
+/* Given a buffer containing slices encrypted by a fake zero-copy protector,
+ * parses these protected slices and returns the total frame size of the
+ * first contained frame. */
+uint32_t tsi_fake_zero_copy_grpc_protector_next_frame_size(
+    const grpc_slice_buffer* protected_slices);
 #endif /* GRPC_CORE_TSI_FAKE_TRANSPORT_SECURITY_H */
diff --git a/test/core/iomgr/endpoint_tests.cc b/test/core/iomgr/endpoint_tests.cc
index ad2488a..4fb9490 100644
--- a/test/core/iomgr/endpoint_tests.cc
+++ b/test/core/iomgr/endpoint_tests.cc
@@ -112,6 +112,7 @@
   uint8_t current_write_data;
   int read_done;
   int write_done;
+  int max_write_frame_size;
   grpc_slice_buffer incoming;
   grpc_slice_buffer outgoing;
   grpc_closure done_read;
@@ -153,7 +154,7 @@
   struct read_and_write_test_state* state =
       static_cast<struct read_and_write_test_state*>(data);
   grpc_endpoint_write(state->write_ep, &state->outgoing, &state->done_write,
-                      nullptr, /*max_frame_size=*/INT_MAX);
+                      nullptr, /*max_frame_size=*/state->max_write_frame_size);
 }
 
 static void read_and_write_test_write_handler(void* data,
@@ -197,13 +198,14 @@
  */
 static void read_and_write_test(grpc_endpoint_test_config config,
                                 size_t num_bytes, size_t write_size,
-                                size_t slice_size, bool shutdown) {
+                                size_t slice_size, int max_write_frame_size,
+                                bool shutdown) {
   struct read_and_write_test_state state;
   grpc_endpoint_test_fixture f =
       begin_test(config, "read_and_write_test", slice_size);
   grpc_core::ExecCtx exec_ctx;
   auto deadline = grpc_core::Timestamp::FromTimespecRoundUp(
-      grpc_timeout_seconds_to_deadline(20));
+      grpc_timeout_seconds_to_deadline(60));
   gpr_log(GPR_DEBUG,
           "num_bytes=%" PRIuPTR " write_size=%" PRIuPTR " slice_size=%" PRIuPTR
           " shutdown=%d",
@@ -223,6 +225,7 @@
   state.target_bytes = num_bytes;
   state.bytes_read = 0;
   state.current_write_size = write_size;
+  state.max_write_frame_size = max_write_frame_size;
   state.bytes_written = 0;
   state.read_done = 0;
   state.write_done = 0;
@@ -305,7 +308,6 @@
   grpc_endpoint_test_fixture f =
       begin_test(config, "multiple_shutdown_test", 128);
   int fail_count = 0;
-
   grpc_slice_buffer slice_buffer;
   grpc_slice_buffer_init(&slice_buffer);
 
@@ -346,11 +348,13 @@
   g_pollset = pollset;
   g_mu = mu;
   multiple_shutdown_test(config);
-  read_and_write_test(config, 10000000, 100000, 8192, false);
-  read_and_write_test(config, 1000000, 100000, 1, false);
-  read_and_write_test(config, 100000000, 100000, 1, true);
+  for (int i = 1; i <= 8192; i = i * 2) {
+    read_and_write_test(config, 10000000, 100000, 8192, i, false);
+    read_and_write_test(config, 1000000, 100000, 1, i, false);
+    read_and_write_test(config, 100000000, 100000, 1, i, true);
+  }
   for (i = 1; i < 1000; i = std::max(i + 1, i * 5 / 4)) {
-    read_and_write_test(config, 40320, i, i, false);
+    read_and_write_test(config, 40320, i, i, i, false);
   }
   g_pollset = nullptr;
   g_mu = nullptr;
diff --git a/test/core/security/secure_endpoint_test.cc b/test/core/security/secure_endpoint_test.cc
index 03118d8..708a499 100644
--- a/test/core/security/secure_endpoint_test.cc
+++ b/test/core/security/secure_endpoint_test.cc
@@ -36,6 +36,93 @@
 static gpr_mu* g_mu;
 static grpc_pollset* g_pollset;
 
+#define TSI_FAKE_FRAME_HEADER_SIZE 4
+
+typedef struct intercept_endpoint {
+  grpc_endpoint base;
+  grpc_endpoint* wrapped_ep;
+  grpc_slice_buffer staging_buffer;
+} intercept_endpoint;
+
+static void me_read(grpc_endpoint* ep, grpc_slice_buffer* slices,
+                    grpc_closure* cb, bool urgent, int min_progress_size) {
+  intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
+  grpc_endpoint_read(m->wrapped_ep, slices, cb, urgent, min_progress_size);
+}
+
+static void me_write(grpc_endpoint* ep, grpc_slice_buffer* slices,
+                     grpc_closure* cb, void* arg, int max_frame_size) {
+  intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
+  int remaining = slices->length;
+  while (remaining > 0) {
+    // Read the exact size of the next frame from its header.
+    int next_frame_size =
+        tsi_fake_zero_copy_grpc_protector_next_frame_size(slices);
+    GPR_ASSERT(next_frame_size > TSI_FAKE_FRAME_HEADER_SIZE);
+    // Ensure the protected data size does not exceed the max_frame_size.
+    GPR_ASSERT(next_frame_size - TSI_FAKE_FRAME_HEADER_SIZE <= max_frame_size);
+    // Move this frame into a staging buffer and repeat.
+    grpc_slice_buffer_move_first(slices, next_frame_size, &m->staging_buffer);
+    remaining -= next_frame_size;
+  }
+  grpc_slice_buffer_swap(&m->staging_buffer, slices);
+  grpc_endpoint_write(m->wrapped_ep, slices, cb, arg, max_frame_size);
+}
+
+static void me_add_to_pollset(grpc_endpoint* /*ep*/,
+                              grpc_pollset* /*pollset*/) {}
+
+static void me_add_to_pollset_set(grpc_endpoint* /*ep*/,
+                                  grpc_pollset_set* /*pollset*/) {}
+
+static void me_delete_from_pollset_set(grpc_endpoint* /*ep*/,
+                                       grpc_pollset_set* /*pollset*/) {}
+
+static void me_shutdown(grpc_endpoint* ep, grpc_error_handle why) {
+  intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
+  grpc_endpoint_shutdown(m->wrapped_ep, why);
+}
+
+static void me_destroy(grpc_endpoint* ep) {
+  intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
+  grpc_endpoint_destroy(m->wrapped_ep);
+  grpc_slice_buffer_destroy(&m->staging_buffer);
+  gpr_free(m);
+}
+
+static absl::string_view me_get_peer(grpc_endpoint* /*ep*/) {
+  return "fake:intercept-endpoint";
+}
+
+static absl::string_view me_get_local_address(grpc_endpoint* /*ep*/) {
+  return "fake:intercept-endpoint";
+}
+
+static int me_get_fd(grpc_endpoint* /*ep*/) { return -1; }
+
+static bool me_can_track_err(grpc_endpoint* /*ep*/) { return false; }
+
+static const grpc_endpoint_vtable vtable = {me_read,
+                                            me_write,
+                                            me_add_to_pollset,
+                                            me_add_to_pollset_set,
+                                            me_delete_from_pollset_set,
+                                            me_shutdown,
+                                            me_destroy,
+                                            me_get_peer,
+                                            me_get_local_address,
+                                            me_get_fd,
+                                            me_can_track_err};
+
+grpc_endpoint* wrap_with_intercept_endpoint(grpc_endpoint* wrapped_ep) {
+  intercept_endpoint* m =
+      static_cast<intercept_endpoint*>(gpr_malloc(sizeof(*m)));
+  m->base.vtable = &vtable;
+  m->wrapped_ep = wrapped_ep;
+  grpc_slice_buffer_init(&m->staging_buffer);
+  return &m->base;
+}
+
 static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair(
     size_t slice_size, grpc_slice* leftover_slices, size_t leftover_nslices,
     bool use_zero_copy_protector) {
@@ -68,6 +155,13 @@
   grpc_endpoint_add_to_pollset(tcp.client, g_pollset);
   grpc_endpoint_add_to_pollset(tcp.server, g_pollset);
 
+  // TODO(vigneshbabu): Extend the intercept endpoint logic to cover non-zero
+  // copy based frame protectors as well.
+  if (use_zero_copy_protector && leftover_nslices == 0) {
+    tcp.client = wrap_with_intercept_endpoint(tcp.client);
+    tcp.server = wrap_with_intercept_endpoint(tcp.server);
+  }
+
   if (leftover_nslices == 0) {
     f.client_ep = grpc_secure_endpoint_create(fake_read_protector,
                                               fake_read_zero_copy_protector,
@@ -125,7 +219,6 @@
                                             tcp.server, nullptr, &args, 0);
   grpc_resource_quota_unref(
       static_cast<grpc_resource_quota*>(a[1].value.pointer.p));
-
   return f;
 }