Merge pull request #24782 from yashykt/xdssecenvvarbackport

Guard xDS security code behind the environment variable "GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT"
diff --git a/src/core/ext/xds/xds_api.cc b/src/core/ext/xds/xds_api.cc
index 31d862d..8ec9dfe 100644
--- a/src/core/ext/xds/xds_api.cc
+++ b/src/core/ext/xds/xds_api.cc
@@ -99,6 +99,17 @@
   return parse_succeeded && parsed_value;
 }
 
+// Reports whether GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT holds a value that
+// gpr_parse_bool_value() accepts as true. TODO(yashykt): Remove once xDS
+// security is fully integration-tested and enabled by default.
+bool XdsSecurityEnabled() {
+  char* env = gpr_getenv("GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT");
+  bool enabled = false;
+  const bool recognized = gpr_parse_bool_value(env, &enabled);
+  gpr_free(env);
+  return recognized && enabled;
+}
+
 //
 // XdsApi::Route::Matchers::PathMatcher
 //
@@ -1566,33 +1577,36 @@
       return GRPC_ERROR_CREATE_FROM_STATIC_STRING(
           "LB policy is not ROUND_ROBIN.");
     }
-    // Record Upstream tls context
-    auto* transport_socket =
-        envoy_config_cluster_v3_Cluster_transport_socket(cluster);
-    if (transport_socket != nullptr) {
-      absl::string_view name = UpbStringToAbsl(
-          envoy_config_core_v3_TransportSocket_name(transport_socket));
-      if (name == "envoy.transport_sockets.tls") {
-        auto* typed_config =
-            envoy_config_core_v3_TransportSocket_typed_config(transport_socket);
-        if (typed_config != nullptr) {
-          const upb_strview encoded_upstream_tls_context =
-              google_protobuf_Any_value(typed_config);
-          auto* upstream_tls_context =
-              envoy_extensions_transport_sockets_tls_v3_UpstreamTlsContext_parse(
-                  encoded_upstream_tls_context.data,
-                  encoded_upstream_tls_context.size, arena);
-          if (upstream_tls_context == nullptr) {
-            return GRPC_ERROR_CREATE_FROM_STATIC_STRING(
-                "Can't decode upstream tls context.");
-          }
-          auto* common_tls_context =
-              envoy_extensions_transport_sockets_tls_v3_UpstreamTlsContext_common_tls_context(
-                  upstream_tls_context);
-          if (common_tls_context != nullptr) {
-            grpc_error* error = CommonTlsContextParse(
-                common_tls_context, &cds_update.common_tls_context);
-            if (error != GRPC_ERROR_NONE) return error;
+    if (XdsSecurityEnabled()) {
+      // Record Upstream tls context
+      auto* transport_socket =
+          envoy_config_cluster_v3_Cluster_transport_socket(cluster);
+      if (transport_socket != nullptr) {
+        absl::string_view name = UpbStringToAbsl(
+            envoy_config_core_v3_TransportSocket_name(transport_socket));
+        if (name == "envoy.transport_sockets.tls") {
+          auto* typed_config =
+              envoy_config_core_v3_TransportSocket_typed_config(
+                  transport_socket);
+          if (typed_config != nullptr) {
+            const upb_strview encoded_upstream_tls_context =
+                google_protobuf_Any_value(typed_config);
+            auto* upstream_tls_context =
+                envoy_extensions_transport_sockets_tls_v3_UpstreamTlsContext_parse(
+                    encoded_upstream_tls_context.data,
+                    encoded_upstream_tls_context.size, arena);
+            if (upstream_tls_context == nullptr) {
+              return GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+                  "Can't decode upstream tls context.");
+            }
+            auto* common_tls_context =
+                envoy_extensions_transport_sockets_tls_v3_UpstreamTlsContext_common_tls_context(
+                    upstream_tls_context);
+            if (common_tls_context != nullptr) {
+              grpc_error* error = CommonTlsContextParse(
+                  common_tls_context, &cds_update.common_tls_context);
+              if (error != GRPC_ERROR_NONE) return error;
+            }
           }
         }
       }
diff --git a/src/core/ext/xds/xds_api.h b/src/core/ext/xds/xds_api.h
index 8e7b1f2..885dd2c 100644
--- a/src/core/ext/xds/xds_api.h
+++ b/src/core/ext/xds/xds_api.h
@@ -39,6 +39,11 @@
 
 namespace grpc_core {
 
+// Returns true only if GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT parses as a
+// true boolean. TODO(yashykt): Remove once xDS security is fully
+// integration-tested and enabled by default.
+bool XdsSecurityEnabled();
+
 class XdsClient;
 
 class XdsApi {
diff --git a/src/core/ext/xds/xds_bootstrap.cc b/src/core/ext/xds/xds_bootstrap.cc
index 3daeb73..e48d982 100644
--- a/src/core/ext/xds/xds_bootstrap.cc
+++ b/src/core/ext/xds/xds_bootstrap.cc
@@ -29,6 +29,7 @@
 #include "absl/strings/string_view.h"
 
 #include "src/core/ext/xds/certificate_provider_registry.h"
+#include "src/core/ext/xds/xds_api.h"
 #include "src/core/lib/gpr/env.h"
 #include "src/core/lib/gpr/string.h"
 #include "src/core/lib/iomgr/load_file.h"
@@ -204,14 +205,16 @@
       if (parse_error != GRPC_ERROR_NONE) error_list.push_back(parse_error);
     }
   }
-  it = json.mutable_object()->find("certificate_providers");
-  if (it != json.mutable_object()->end()) {
-    if (it->second.type() != Json::Type::OBJECT) {
-      error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
-          "\"certificate_providers\" field is not an object"));
-    } else {
-      grpc_error* parse_error = ParseCertificateProviders(&it->second);
-      if (parse_error != GRPC_ERROR_NONE) error_list.push_back(parse_error);
+  if (XdsSecurityEnabled()) {
+    it = json.mutable_object()->find("certificate_providers");
+    if (it != json.mutable_object()->end()) {
+      if (it->second.type() != Json::Type::OBJECT) {
+        error_list.push_back(GRPC_ERROR_CREATE_FROM_STATIC_STRING(
+            "\"certificate_providers\" field is not an object"));
+      } else {
+        grpc_error* parse_error = ParseCertificateProviders(&it->second);
+        if (parse_error != GRPC_ERROR_NONE) error_list.push_back(parse_error);
+      }
     }
   }
   *error = GRPC_ERROR_CREATE_FROM_VECTOR("errors parsing xds bootstrap file",
diff --git a/test/core/xds/xds_bootstrap_test.cc b/test/core/xds/xds_bootstrap_test.cc
index 596d823..7a7545d 100644
--- a/test/core/xds/xds_bootstrap_test.cc
+++ b/test/core/xds/xds_bootstrap_test.cc
@@ -34,14 +34,40 @@
 namespace grpc_core {
 namespace testing {
 
-class XdsBootstrapTest : public ::testing::Test {
+class TestType {  // Param: whether the xDS security env var is enabled.
 public:
-  XdsBootstrapTest() { grpc_init(); }
+  explicit TestType(bool parse_xds_certificate_providers)
+      : parse_xds_certificate_providers_(parse_xds_certificate_providers) {}
+
+  bool parse_xds_certificate_providers() const {
+    return parse_xds_certificate_providers_;
+  }
+
+  std::string AsString() const {  // Suffix used in the test name.
+    return parse_xds_certificate_providers_
+               ? "WithCertificateProvidersParsing"
+               : "WithoutCertificateProvidersParsing";
+  }
+
+ private:
+  const bool parse_xds_certificate_providers_;
+};
+
+class XdsBootstrapTest : public ::testing::TestWithParam<TestType> {
+ public:
+  XdsBootstrapTest() {  // Applies GetParam() to the security env var.
+    if (GetParam().parse_xds_certificate_providers()) {
+      gpr_setenv("GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT", "true");
+    } else {  // Clear any value left set by an earlier "true" instance.
+      gpr_unsetenv("GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT");
+    }
+    grpc_init();  // NOTE(review): dtor below never unsets the env var.
+  }
 
   ~XdsBootstrapTest() override { grpc_shutdown_blocking(); }
 };
 
-TEST_F(XdsBootstrapTest, Basic) {
+TEST_P(XdsBootstrapTest, Basic) {
   const char* json_str =
       "{"
       "  \"xds_servers\": ["
@@ -116,7 +142,7 @@
                           ::testing::Property(&Json::string_value, "1")))));
 }
 
-TEST_F(XdsBootstrapTest, ValidWithoutNode) {
+TEST_P(XdsBootstrapTest, ValidWithoutNode) {
   const char* json_str =
       "{"
       "  \"xds_servers\": ["
@@ -136,7 +162,7 @@
   EXPECT_EQ(bootstrap.node(), nullptr);
 }
 
-TEST_F(XdsBootstrapTest, InsecureCreds) {
+TEST_P(XdsBootstrapTest, InsecureCreds) {
   const char* json_str =
       "{"
       "  \"xds_servers\": ["
@@ -156,7 +182,7 @@
   EXPECT_EQ(bootstrap.node(), nullptr);
 }
 
-TEST_F(XdsBootstrapTest, GoogleDefaultCreds) {
+TEST_P(XdsBootstrapTest, GoogleDefaultCreds) {
   // Generate call creds file needed by GoogleDefaultCreds.
   const char token_str[] =
       "{ \"client_id\": \"32555999999.apps.googleusercontent.com\","
@@ -192,7 +218,7 @@
   EXPECT_EQ(bootstrap.node(), nullptr);
 }
 
-TEST_F(XdsBootstrapTest, MissingChannelCreds) {
+TEST_P(XdsBootstrapTest, MissingChannelCreds) {
   const char* json_str =
       "{"
       "  \"xds_servers\": ["
@@ -210,7 +236,7 @@
   GRPC_ERROR_UNREF(error);
 }
 
-TEST_F(XdsBootstrapTest, NoKnownChannelCreds) {
+TEST_P(XdsBootstrapTest, NoKnownChannelCreds) {
   const char* json_str =
       "{"
       "  \"xds_servers\": ["
@@ -230,7 +256,7 @@
   GRPC_ERROR_UNREF(error);
 }
 
-TEST_F(XdsBootstrapTest, MissingXdsServers) {
+TEST_P(XdsBootstrapTest, MissingXdsServers) {
   grpc_error* error = GRPC_ERROR_NONE;
   Json json = Json::Parse("{}", &error);
   ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error);
@@ -240,7 +266,7 @@
   GRPC_ERROR_UNREF(error);
 }
 
-TEST_F(XdsBootstrapTest, TopFieldsWrongTypes) {
+TEST_P(XdsBootstrapTest, TopFieldsWrongTypes) {
   const char* json_str =
       "{"
       "  \"xds_servers\":1,"
@@ -252,14 +278,17 @@
   ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error);
   XdsBootstrap bootstrap(std::move(json), &error);
   EXPECT_THAT(grpc_error_string(error),
-              ::testing::ContainsRegex(
-                  "\"xds_servers\" field is not an array.*"
-                  "\"node\" field is not an object.*"
-                  "\"certificate_providers\" field is not an object"));
+              ::testing::ContainsRegex("\"xds_servers\" field is not an array.*"
+                                       "\"node\" field is not an object.*"));
+  if (GetParam().parse_xds_certificate_providers()) {
+    EXPECT_THAT(grpc_error_string(error),
+                ::testing::ContainsRegex(
+                    "\"certificate_providers\" field is not an object"));
+  }
   GRPC_ERROR_UNREF(error);
 }
 
-TEST_F(XdsBootstrapTest, XdsServerMissingServerUri) {
+TEST_P(XdsBootstrapTest, XdsServerMissingServerUri) {
   const char* json_str =
       "{"
       "  \"xds_servers\":[{}]"
@@ -275,7 +304,7 @@
   GRPC_ERROR_UNREF(error);
 }
 
-TEST_F(XdsBootstrapTest, XdsServerUriAndCredsWrongTypes) {
+TEST_P(XdsBootstrapTest, XdsServerUriAndCredsWrongTypes) {
   const char* json_str =
       "{"
       "  \"xds_servers\":["
@@ -298,7 +327,7 @@
   GRPC_ERROR_UNREF(error);
 }
 
-TEST_F(XdsBootstrapTest, ChannelCredsFieldsWrongTypes) {
+TEST_P(XdsBootstrapTest, ChannelCredsFieldsWrongTypes) {
   const char* json_str =
       "{"
       "  \"xds_servers\":["
@@ -328,7 +357,7 @@
   GRPC_ERROR_UNREF(error);
 }
 
-TEST_F(XdsBootstrapTest, NodeFieldsWrongTypes) {
+TEST_P(XdsBootstrapTest, NodeFieldsWrongTypes) {
   const char* json_str =
       "{"
       "  \"node\":{"
@@ -351,7 +380,7 @@
   GRPC_ERROR_UNREF(error);
 }
 
-TEST_F(XdsBootstrapTest, LocalityFieldsWrongType) {
+TEST_P(XdsBootstrapTest, LocalityFieldsWrongType) {
   const char* json_str =
       "{"
       "  \"node\":{"
@@ -375,12 +404,13 @@
   GRPC_ERROR_UNREF(error);
 }
 
-TEST_F(XdsBootstrapTest, CertificateProvidersElementWrongType) {
+TEST_P(XdsBootstrapTest, CertificateProvidersElementWrongType) {
   const char* json_str =
       "{"
       "  \"xds_servers\": ["
       "    {"
-      "      \"server_uri\": \"fake:///lb\""
+      "      \"server_uri\": \"fake:///lb\","
+      "      \"channel_creds\": [{\"type\": \"fake\"}]"
       "    }"
       "  ],"
       "  \"certificate_providers\": {"
@@ -391,19 +421,24 @@
   Json json = Json::Parse(json_str, &error);
   ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error);
   XdsBootstrap bootstrap(std::move(json), &error);
-  EXPECT_THAT(grpc_error_string(error),
-              ::testing::ContainsRegex(
-                  "errors parsing \"certificate_providers\" object.*"
-                  "element \"plugin\" is not an object"));
+  if (GetParam().parse_xds_certificate_providers()) {
+    EXPECT_THAT(grpc_error_string(error),
+                ::testing::ContainsRegex(
+                    "errors parsing \"certificate_providers\" object.*"
+                    "element \"plugin\" is not an object"));
+  } else {
+    EXPECT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error);
+  }
   GRPC_ERROR_UNREF(error);
 }
 
-TEST_F(XdsBootstrapTest, CertificateProvidersPluginNameWrongType) {
+TEST_P(XdsBootstrapTest, CertificateProvidersPluginNameWrongType) {
   const char* json_str =
       "{"
       "  \"xds_servers\": ["
       "    {"
-      "      \"server_uri\": \"fake:///lb\""
+      "      \"server_uri\": \"fake:///lb\","
+      "      \"channel_creds\": [{\"type\": \"fake\"}]"
       "    }"
       "  ],"
       "  \"certificate_providers\": {"
@@ -416,11 +451,15 @@
   Json json = Json::Parse(json_str, &error);
   ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error);
   XdsBootstrap bootstrap(std::move(json), &error);
-  EXPECT_THAT(grpc_error_string(error),
-              ::testing::ContainsRegex(
-                  "errors parsing \"certificate_providers\" object.*"
-                  "errors parsing element \"plugin\".*"
-                  "\"plugin_name\" field is not a string"));
+  if (GetParam().parse_xds_certificate_providers()) {
+    EXPECT_THAT(grpc_error_string(error),
+                ::testing::ContainsRegex(
+                    "errors parsing \"certificate_providers\" object.*"
+                    "errors parsing element \"plugin\".*"
+                    "\"plugin_name\" field is not a string"));
+  } else {
+    EXPECT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error);
+  }
   GRPC_ERROR_UNREF(error);
 }
 
@@ -473,14 +512,15 @@
   }
 };
 
-TEST_F(XdsBootstrapTest, CertificateProvidersFakePluginParsingError) {
+TEST_P(XdsBootstrapTest, CertificateProvidersFakePluginParsingError) {
   CertificateProviderRegistry::RegisterCertificateProviderFactory(
       absl::make_unique<FakeCertificateProviderFactory>());
   const char* json_str =
       "{"
       "  \"xds_servers\": ["
       "    {"
-      "      \"server_uri\": \"fake:///lb\""
+      "      \"server_uri\": \"fake:///lb\","
+      "      \"channel_creds\": [{\"type\": \"fake\"}]"
       "    }"
       "  ],"
       "  \"certificate_providers\": {"
@@ -496,15 +536,19 @@
   Json json = Json::Parse(json_str, &error);
   ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error);
   XdsBootstrap bootstrap(std::move(json), &error);
-  EXPECT_THAT(grpc_error_string(error),
-              ::testing::ContainsRegex(
-                  "errors parsing \"certificate_providers\" object.*"
-                  "errors parsing element \"fake_plugin\".*"
-                  "field:config field:value not of type number"));
+  if (GetParam().parse_xds_certificate_providers()) {
+    EXPECT_THAT(grpc_error_string(error),
+                ::testing::ContainsRegex(
+                    "errors parsing \"certificate_providers\" object.*"
+                    "errors parsing element \"fake_plugin\".*"
+                    "field:config field:value not of type number"));
+  } else {
+    EXPECT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error);
+  }
   GRPC_ERROR_UNREF(error);
 }
 
-TEST_F(XdsBootstrapTest, CertificateProvidersFakePluginParsingSuccess) {
+TEST_P(XdsBootstrapTest, CertificateProvidersFakePluginParsingSuccess) {
   CertificateProviderRegistry::RegisterCertificateProviderFactory(
       absl::make_unique<FakeCertificateProviderFactory>());
   const char* json_str =
@@ -529,17 +573,22 @@
   ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error);
   XdsBootstrap bootstrap(std::move(json), &error);
   ASSERT_TRUE(error == GRPC_ERROR_NONE) << grpc_error_string(error);
-  const CertificateProviderStore::PluginDefinition& fake_plugin =
-      bootstrap.certificate_providers().at("fake_plugin");
-  ASSERT_EQ(fake_plugin.plugin_name, "fake");
-  ASSERT_STREQ(fake_plugin.config->name(), "fake");
-  ASSERT_EQ(static_cast<RefCountedPtr<FakeCertificateProviderFactory::Config>>(
-                fake_plugin.config)
-                ->value(),
-            10);
+  if (GetParam().parse_xds_certificate_providers()) {
+    const CertificateProviderStore::PluginDefinition& fake_plugin =
+        bootstrap.certificate_providers().at("fake_plugin");
+    ASSERT_EQ(fake_plugin.plugin_name, "fake");
+    ASSERT_STREQ(fake_plugin.config->name(), "fake");
+    ASSERT_EQ(
+        static_cast<RefCountedPtr<FakeCertificateProviderFactory::Config>>(
+            fake_plugin.config)
+            ->value(),
+        10);
+  } else {
+    EXPECT_TRUE(bootstrap.certificate_providers().empty());
+  }
 }
 
-TEST_F(XdsBootstrapTest, CertificateProvidersFakePluginEmptyConfig) {
+TEST_P(XdsBootstrapTest, CertificateProvidersFakePluginEmptyConfig) {
   CertificateProviderRegistry::RegisterCertificateProviderFactory(
       absl::make_unique<FakeCertificateProviderFactory>());
   const char* json_str =
@@ -561,16 +610,29 @@
   ASSERT_EQ(error, GRPC_ERROR_NONE) << grpc_error_string(error);
   XdsBootstrap bootstrap(std::move(json), &error);
   ASSERT_TRUE(error == GRPC_ERROR_NONE) << grpc_error_string(error);
-  const CertificateProviderStore::PluginDefinition& fake_plugin =
-      bootstrap.certificate_providers().at("fake_plugin");
-  ASSERT_EQ(fake_plugin.plugin_name, "fake");
-  ASSERT_STREQ(fake_plugin.config->name(), "fake");
-  ASSERT_EQ(static_cast<RefCountedPtr<FakeCertificateProviderFactory::Config>>(
-                fake_plugin.config)
-                ->value(),
-            0);
+  if (GetParam().parse_xds_certificate_providers()) {
+    const CertificateProviderStore::PluginDefinition& fake_plugin =
+        bootstrap.certificate_providers().at("fake_plugin");
+    ASSERT_EQ(fake_plugin.plugin_name, "fake");
+    ASSERT_STREQ(fake_plugin.config->name(), "fake");
+    ASSERT_EQ(
+        static_cast<RefCountedPtr<FakeCertificateProviderFactory::Config>>(
+            fake_plugin.config)
+            ->value(),
+        0);
+  } else {
+    EXPECT_TRUE(bootstrap.certificate_providers().empty());
+  }
 }
 
+std::string TestTypeName(const ::testing::TestParamInfo<TestType>& p) {
+  return p.param.AsString();  // Name suffix for this parameter value.
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    XdsBootstrap, XdsBootstrapTest,
+    ::testing::Values(TestType(false), TestType(true)), &TestTypeName);
+
 }  // namespace testing
 }  // namespace grpc_core