xds: Add bootstrap support for certificate providers. (#3901)

diff --git a/credentials/tls/certprovider/provider.go b/credentials/tls/certprovider/provider.go
index 204e556..8d8ae80 100644
--- a/credentials/tls/certprovider/provider.go
+++ b/credentials/tls/certprovider/provider.go
@@ -29,8 +29,14 @@
 	"crypto/tls"
 	"crypto/x509"
 	"errors"
+
+	"google.golang.org/grpc/internal"
 )
 
+func init() {
+	internal.GetCertificateProviderBuilder = getBuilder
+}
+
 var (
 	// errProviderClosed is returned by Distributor.KeyMaterial when it is
 	// closed.
diff --git a/internal/internal.go b/internal/internal.go
index 818ca85..716d928 100644
--- a/internal/internal.go
+++ b/internal/internal.go
@@ -52,6 +52,11 @@
 	// This function compares the config without rawJSON stripped, in case the
 	// there's difference in white space.
 	EqualServiceConfigForTesting func(a, b serviceconfig.Config) bool
+	// GetCertificateProviderBuilder returns the registered builder for the
+	// given name. This is set by package certprovider for use from xDS
+	// bootstrap code while parsing certificate provider configs in the
+	// bootstrap file.
+	GetCertificateProviderBuilder interface{} // func(string) certprovider.Builder
 )
 
 // HealthChecker defines the signature of the client-side LB channel health checking function.
diff --git a/xds/internal/client/bootstrap/bootstrap.go b/xds/internal/client/bootstrap/bootstrap.go
index e3f2ce5..93e6d3e 100644
--- a/xds/internal/client/bootstrap/bootstrap.go
+++ b/xds/internal/client/bootstrap/bootstrap.go
@@ -33,6 +33,8 @@
 	"github.com/golang/protobuf/proto"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/credentials/google"
+	"google.golang.org/grpc/credentials/tls/certprovider"
+	"google.golang.org/grpc/internal"
 	"google.golang.org/grpc/xds/internal/version"
 )
 
@@ -77,6 +79,18 @@
 	// NodeProto contains the Node proto to be used in xDS requests. The actual
 	// type depends on the transport protocol version used.
 	NodeProto proto.Message
+	// CertProviderConfigs contain parsed configs for supported certificate
+	// provider plugins found in the bootstrap file.
+	CertProviderConfigs map[string]CertProviderConfig
+}
+
+// CertProviderConfig wraps the certificate provider plugin name and config
+// (corresponding to one plugin instance) found in the bootstrap file.
+type CertProviderConfig struct {
+	// Name is the registered name of the certificate provider.
+	Name string
+	// Config is the parsed config to be passed to the certificate provider.
+	Config certprovider.StableConfig
 }
 
 type channelCreds struct {
@@ -103,6 +117,10 @@
 //        }
 //      ],
 //      "server_features": [ ... ]
+//		"certificate_providers" : {
+//			"default": { default cert provider config },
+//			"foo": { config for provider foo }
+//		}
 //    },
 //    "node": <JSON form of Node proto>
 // }
@@ -182,6 +200,35 @@
 					serverSupportsV3 = true
 				}
 			}
+		case "certificate_providers":
+			var providerInstances map[string]json.RawMessage
+			if err := json.Unmarshal(v, &providerInstances); err != nil {
+				return nil, fmt.Errorf("xds: json.Unmarshal(%v) for field %q failed during bootstrap: %v", string(v), k, err)
+			}
+			configs := make(map[string]CertProviderConfig)
+			getBuilder := internal.GetCertificateProviderBuilder.(func(string) certprovider.Builder)
+			for instance, data := range providerInstances {
+				var providerConfigs map[string]json.RawMessage
+				if err := json.Unmarshal(data, &providerConfigs); err != nil {
+					return nil, fmt.Errorf("xds: json.Unmarshal(%v) for field %q failed during bootstrap: %v", string(v), instance, err)
+				}
+				for name, cfg := range providerConfigs {
+					parser := getBuilder(name)
+					if parser == nil {
+						// We ignore plugins that we do not know about.
+						continue
+					}
+					c, err := parser.ParseConfig(cfg)
+					if err != nil {
+						return nil, fmt.Errorf("xds: Config parsing for plugin %q failed: %v", name, err)
+					}
+					configs[instance] = CertProviderConfig{
+						Name:   name,
+						Config: c,
+					}
+				}
+			}
+			config.CertProviderConfigs = configs
 		}
 		// Do not fail the xDS bootstrap when an unknown field is seen. This can
 		// happen when an older version client reads a newer version bootstrap
diff --git a/xds/internal/client/bootstrap/bootstrap_test.go b/xds/internal/client/bootstrap/bootstrap_test.go
index a7734f0..353bcd9 100644
--- a/xds/internal/client/bootstrap/bootstrap_test.go
+++ b/xds/internal/client/bootstrap/bootstrap_test.go
@@ -19,6 +19,8 @@
 package bootstrap
 
 import (
+	"encoding/json"
+	"errors"
 	"fmt"
 	"os"
 	"testing"
@@ -28,8 +30,11 @@
 	"github.com/golang/protobuf/proto"
 	structpb "github.com/golang/protobuf/ptypes/struct"
 	"github.com/google/go-cmp/cmp"
+
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/credentials/google"
+	"google.golang.org/grpc/credentials/tls/certprovider"
+	"google.golang.org/grpc/internal"
 	"google.golang.org/grpc/xds/internal/version"
 )
 
@@ -233,6 +238,24 @@
 	if diff := cmp.Diff(want.NodeProto, c.NodeProto, cmp.Comparer(proto.Equal)); diff != "" {
 		return fmt.Errorf("config.NodeProto diff (-want, +got):\n%s", diff)
 	}
+
+	// A vanilla cmp.Equal or cmp.Diff will not produce useful error message
+	// here. So, we iterate through the list of configs and compare them one at
+	// a time.
+	gotCfgs := c.CertProviderConfigs
+	wantCfgs := want.CertProviderConfigs
+	if len(gotCfgs) != len(wantCfgs) {
+		return fmt.Errorf("config.CertProviderConfigs is %d entries, want %d", len(gotCfgs), len(wantCfgs))
+	}
+	for instance, gotCfg := range gotCfgs {
+		wantCfg, ok := wantCfgs[instance]
+		if !ok {
+			return fmt.Errorf("config.CertProviderConfigs has unexpected plugin instance %q with config %q", instance, string(gotCfg.Config.Canonical()))
+		}
+		if gotCfg.Name != wantCfg.Name || !cmp.Equal(gotCfg.Config.Canonical(), wantCfg.Config.Canonical()) {
+			return fmt.Errorf("config.CertProviderConfigs for plugin instance %q has config {%s, %s, want {%s, %s}", instance, gotCfg.Name, string(gotCfg.Config.Canonical()), wantCfg.Name, string(wantCfg.Config.Canonical()))
+		}
+	}
 	return nil
 }
 
@@ -452,3 +475,231 @@
 		t.Errorf("NewConfig() returned nil error, expected to fail")
 	}
 }
+
+func init() {
+	certprovider.Register(&fakeCertProviderBuilder{})
+}
+
+const fakeCertProviderName = "fake-certificate-provider"
+
+// fakeCertProviderBuilder builds new instances of fakeCertProvider and
+// interprets the config provided to it as JSON with a single key and value.
+type fakeCertProviderBuilder struct{}
+
+func (b *fakeCertProviderBuilder) Build(certprovider.StableConfig, certprovider.Options) certprovider.Provider {
+	return &fakeCertProvider{}
+}
+
+// ParseConfig expects input in JSON format containing a map from string to
+// string, with a single entry and mapKey being "configKey".
+func (b *fakeCertProviderBuilder) ParseConfig(cfg interface{}) (certprovider.StableConfig, error) {
+	config, ok := cfg.(json.RawMessage)
+	if !ok {
+		return nil, fmt.Errorf("fakeCertProviderBuilder received config of type %T, want []byte", config)
+	}
+	var cfgData map[string]string
+	if err := json.Unmarshal(config, &cfgData); err != nil {
+		return nil, fmt.Errorf("fakeCertProviderBuilder config parsing failed: %v", err)
+	}
+	if len(cfgData) != 1 || cfgData["configKey"] == "" {
+		return nil, errors.New("fakeCertProviderBuilder received invalid config")
+	}
+	return &fakeStableConfig{config: cfgData}, nil
+}
+
+func (b *fakeCertProviderBuilder) Name() string {
+	return fakeCertProviderName
+}
+
+type fakeStableConfig struct {
+	config map[string]string
+}
+
+func (c *fakeStableConfig) Canonical() []byte {
+	var cfg string
+	for k, v := range c.config {
+		cfg = fmt.Sprintf("%s:%s", k, v)
+	}
+	return []byte(cfg)
+}
+
+// fakeCertProvider is an empty implementation of the Provider interface.
+type fakeCertProvider struct {
+	certprovider.Provider
+}
+
+func TestNewConfigWithCertificateProviders(t *testing.T) {
+	bootstrapFileMap := map[string]string{
+		"badJSONCertProviderConfig": `
+		{
+			"node": {
+				"id": "ENVOY_NODE_ID",
+				"metadata": {
+				    "TRAFFICDIRECTOR_GRPC_HOSTNAME": "trafficdirector"
+			    }
+			},
+			"xds_servers" : [{
+				"server_uri": "trafficdirector.googleapis.com:443",
+				"channel_creds": [
+					{ "type": "google_default" }
+				]
+			}],
+			"server_features" : ["foo", "bar", "xds_v3"],
+			"certificate_providers": "bad JSON"
+		}`,
+		"allUnknownCertProviders": `
+		{
+			"node": {
+				"id": "ENVOY_NODE_ID",
+				"metadata": {
+				    "TRAFFICDIRECTOR_GRPC_HOSTNAME": "trafficdirector"
+			    }
+			},
+			"xds_servers" : [{
+				"server_uri": "trafficdirector.googleapis.com:443",
+				"channel_creds": [
+					{ "type": "google_default" }
+				]
+			}],
+			"server_features" : ["foo", "bar", "xds_v3"],
+			"certificate_providers": {
+				"unknownProviderInstance1": {
+					"foo1": "bar1"
+				},
+				"unknownProviderInstance2": {
+					"foo2": "bar2"
+				}
+			}
+		}`,
+		"badCertProviderConfig": `
+		{
+			"node": {
+				"id": "ENVOY_NODE_ID",
+				"metadata": {
+				    "TRAFFICDIRECTOR_GRPC_HOSTNAME": "trafficdirector"
+			    }
+			},
+			"xds_servers" : [{
+				"server_uri": "trafficdirector.googleapis.com:443",
+				"channel_creds": [
+					{ "type": "google_default" }
+				]
+			}],
+			"server_features" : ["foo", "bar", "xds_v3"],
+			"certificate_providers": {
+				"unknownProviderInstance": {
+					"foo": "bar"
+				},
+				"fakeProviderInstance": {
+					"fake-certificate-provider": {
+						"configKey": "configValue"
+					}
+				},
+				"fakeProviderInstanceBad": {
+					"fake-certificate-provider": {
+						"configKey": 666
+					}
+				}
+			}
+		}`,
+		"goodCertProviderConfig": `
+		{
+			"node": {
+				"id": "ENVOY_NODE_ID",
+				"metadata": {
+				    "TRAFFICDIRECTOR_GRPC_HOSTNAME": "trafficdirector"
+			    }
+			},
+			"xds_servers" : [{
+				"server_uri": "trafficdirector.googleapis.com:443",
+				"channel_creds": [
+					{ "type": "google_default" }
+				]
+			}],
+			"server_features" : ["foo", "bar", "xds_v3"],
+			"certificate_providers": {
+				"unknownProviderInstance": {
+					"foo": "bar"
+				},
+				"fakeProviderInstance": {
+					"fake-certificate-provider": {
+						"configKey": "configValue"
+					}
+				}
+			}
+		}`,
+	}
+
+	getBuilder := internal.GetCertificateProviderBuilder.(func(string) certprovider.Builder)
+	parser := getBuilder(fakeCertProviderName)
+	if parser == nil {
+		t.Fatalf("missing certprovider plugin %q", fakeCertProviderName)
+	}
+	wantCfg, err := parser.ParseConfig(json.RawMessage(`{"configKey": "configValue"}`))
+	if err != nil {
+		t.Fatalf("config parsing for plugin %q failed: %v", fakeCertProviderName, err)
+	}
+
+	if err := os.Setenv(v3SupportEnv, "true"); err != nil {
+		t.Fatalf("os.Setenv(%s, %s) failed with error: %v", v3SupportEnv, "true", err)
+	}
+	defer os.Unsetenv(v3SupportEnv)
+
+	cancel := setupBootstrapOverride(bootstrapFileMap)
+	defer cancel()
+
+	goodConfig := &Config{
+		BalancerName: "trafficdirector.googleapis.com:443",
+		Creds:        grpc.WithCredentialsBundle(google.NewComputeEngineCredentials()),
+		TransportAPI: version.TransportV3,
+		NodeProto:    v3NodeProto,
+		CertProviderConfigs: map[string]CertProviderConfig{
+			"fakeProviderInstance": {
+				Name:   fakeCertProviderName,
+				Config: wantCfg,
+			},
+		},
+	}
+	tests := []struct {
+		name       string
+		wantConfig *Config
+		wantErr    bool
+	}{
+		{
+			name:    "badJSONCertProviderConfig",
+			wantErr: true,
+		},
+		{
+
+			name:    "badCertProviderConfig",
+			wantErr: true,
+		},
+		{
+
+			name:       "allUnknownCertProviders",
+			wantConfig: nonNilCredsConfigV3,
+		},
+		{
+			name:       "goodCertProviderConfig",
+			wantConfig: goodConfig,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			if err := os.Setenv(bootstrapFileEnv, test.name); err != nil {
+				t.Fatalf("os.Setenv(%s, %s) failed with error: %v", bootstrapFileEnv, test.name, err)
+			}
+			c, err := NewConfig()
+			if (err != nil) != test.wantErr {
+				t.Fatalf("NewConfig() returned: %v, wantErr: %v", err, test.wantErr)
+			}
+			if test.wantErr {
+				return
+			}
+			if err := c.compare(test.wantConfig); err != nil {
+				t.Fatal(err)
+			}
+		})
+	}
+}