meshca: CertificateProvider plugin implementation. (#3871)
diff --git a/credentials/sts/sts.go b/credentials/sts/sts.go
index a97b20a..f4d5801 100644
--- a/credentials/sts/sts.go
+++ b/credentials/sts/sts.go
@@ -111,6 +111,10 @@
ActorTokenType string // Optional.
}
+func (o Options) String() string {
+ return fmt.Sprintf("%s:%s:%s:%s:%s:%s:%s:%s:%s", o.TokenExchangeServiceURI, o.Resource, o.Audience, o.Scope, o.RequestedTokenType, o.SubjectTokenPath, o.SubjectTokenType, o.ActorTokenPath, o.ActorTokenType)
+}
+
// NewCredentials returns a new PerRPCCredentials implementation, configured
// using opts, which performs token exchange using STS.
func NewCredentials(opts Options) (credentials.PerRPCCredentials, error) {
@@ -213,7 +217,7 @@
return err
}
if u.Scheme != "http" && u.Scheme != "https" {
- return fmt.Errorf("scheme is not supported: %s. Only http(s) is supported", u.Scheme)
+ return fmt.Errorf("scheme is not supported: %q. Only http(s) is supported", u.Scheme)
}
if opts.SubjectTokenPath == "" {
diff --git a/credentials/sts/sts_test.go b/credentials/sts/sts_test.go
index cc8b08e..9cfa120 100644
--- a/credentials/sts/sts_test.go
+++ b/credentials/sts/sts_test.go
@@ -43,19 +43,20 @@
)
const (
- requestedTokenType = "urn:ietf:params:oauth:token-type:access-token"
- actorTokenPath = "/var/run/secrets/token.jwt"
- actorTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
- actorTokenContents = "actorToken.jwt.contents"
- accessTokenContents = "access_token"
- subjectTokenPath = "/var/run/secrets/token.jwt"
- subjectTokenType = "urn:ietf:params:oauth:token-type:id_token"
- subjectTokenContents = "subjectToken.jwt.contents"
- serviceURI = "http://localhost"
- exampleResource = "https://backend.example.com/api"
- exampleAudience = "example-backend-service"
- testScope = "https://www.googleapis.com/auth/monitoring"
- defaultTestTimeout = 1 * time.Second
+ requestedTokenType = "urn:ietf:params:oauth:token-type:access-token"
+ actorTokenPath = "/var/run/secrets/token.jwt"
+ actorTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
+ actorTokenContents = "actorToken.jwt.contents"
+ accessTokenContents = "access_token"
+ subjectTokenPath = "/var/run/secrets/token.jwt"
+ subjectTokenType = "urn:ietf:params:oauth:token-type:id_token"
+ subjectTokenContents = "subjectToken.jwt.contents"
+ serviceURI = "http://localhost"
+ exampleResource = "https://backend.example.com/api"
+ exampleAudience = "example-backend-service"
+ testScope = "https://www.googleapis.com/auth/monitoring"
+ defaultTestTimeout = 1 * time.Second
+ defaultTestShortTimeout = 10 * time.Millisecond
)
var (
@@ -132,35 +133,13 @@
}
}
-// fakeHTTPDoer helps mock out the http.Client.Do calls made by the credentials
-// code under test. It makes the http.Request made by the credentials available
-// through a channel, and makes it possible to inject various responses.
-type fakeHTTPDoer struct {
- reqCh *testutils.Channel
- respCh *testutils.Channel
- err error
-}
-
-func (fc *fakeHTTPDoer) Do(req *http.Request) (*http.Response, error) {
- fc.reqCh.Send(req)
-
- ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
- defer cancel()
-
- val, err := fc.respCh.Receive(ctx)
- if err != nil {
- return nil, err
- }
- return val.(*http.Response), fc.err
-}
-
// Overrides the http.Client with a fakeClient which sends a good response.
-func overrideHTTPClientGood() (*fakeHTTPDoer, func()) {
- fc := &fakeHTTPDoer{
- reqCh: testutils.NewChannel(),
- respCh: testutils.NewChannel(),
+func overrideHTTPClientGood() (*testutils.FakeHTTPClient, func()) {
+ fc := &testutils.FakeHTTPClient{
+ ReqChan: testutils.NewChannel(),
+ RespChan: testutils.NewChannel(),
}
- fc.respCh.Send(makeGoodResponse())
+ fc.RespChan.Send(makeGoodResponse())
origMakeHTTPDoer := makeHTTPDoer
makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
@@ -168,7 +147,7 @@
}
// Overrides the http.Client with the provided fakeClient.
-func overrideHTTPClient(fc *fakeHTTPDoer) func() {
+func overrideHTTPClient(fc *testutils.FakeHTTPClient) func() {
origMakeHTTPDoer := makeHTTPDoer
makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
return func() { makeHTTPDoer = origMakeHTTPDoer }
@@ -244,11 +223,11 @@
// expected goodRequest. This is expected to be called in a separate goroutine
// by the tests. So, any errors encountered are pushed to an error channel
// which is monitored by the test.
-func receiveAndCompareRequest(reqCh *testutils.Channel, errCh chan error) {
+func receiveAndCompareRequest(ReqChan *testutils.Channel, errCh chan error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
- val, err := reqCh.Receive(ctx)
+ val, err := ReqChan.Receive(ctx)
if err != nil {
errCh <- err
return
@@ -274,7 +253,7 @@
}
errCh := make(chan error, 1)
- go receiveAndCompareRequest(fc.reqCh, errCh)
+ go receiveAndCompareRequest(fc.ReqChan, errCh)
gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "")
if err != nil {
@@ -323,9 +302,9 @@
func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) {
const expiresInSecs = 1
defer overrideSubjectTokenGood()()
- fc := &fakeHTTPDoer{
- reqCh: testutils.NewChannel(),
- respCh: testutils.NewChannel(),
+ fc := &testutils.FakeHTTPClient{
+ ReqChan: testutils.NewChannel(),
+ RespChan: testutils.NewChannel(),
}
defer overrideHTTPClient(fc)()
@@ -340,7 +319,7 @@
// out a fresh request.
for i := 0; i < 2; i++ {
errCh := make(chan error, 1)
- go receiveAndCompareRequest(fc.reqCh, errCh)
+ go receiveAndCompareRequest(fc.ReqChan, errCh)
respJSON, _ := json.Marshal(responseParameters{
AccessToken: accessTokenContents,
@@ -354,7 +333,7 @@
StatusCode: http.StatusOK,
Body: respBody,
}
- fc.respCh.Send(resp)
+ fc.RespChan.Send(resp)
gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "")
if err != nil {
@@ -399,9 +378,9 @@
t.Run(test.name, func(t *testing.T) {
defer overrideSubjectTokenGood()()
- fc := &fakeHTTPDoer{
- reqCh: testutils.NewChannel(),
- respCh: testutils.NewChannel(),
+ fc := &testutils.FakeHTTPClient{
+ ReqChan: testutils.NewChannel(),
+ RespChan: testutils.NewChannel(),
}
defer overrideHTTPClient(fc)()
@@ -411,9 +390,9 @@
}
errCh := make(chan error, 1)
- go receiveAndCompareRequest(fc.reqCh, errCh)
+ go receiveAndCompareRequest(fc.ReqChan, errCh)
- fc.respCh.Send(test.response)
+ fc.RespChan.Send(test.response)
if _, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), ""); err == nil {
t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
}
@@ -438,10 +417,9 @@
errCh := make(chan error, 1)
go func() {
- ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
defer cancel()
-
- if _, err := fc.reqCh.Receive(ctx); err != context.DeadlineExceeded {
+ if _, err := fc.ReqChan.Receive(ctx); err != context.DeadlineExceeded {
errCh <- err
return
}
@@ -698,12 +676,12 @@
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- client := &fakeHTTPDoer{
- reqCh: testutils.NewChannel(),
- respCh: testutils.NewChannel(),
- err: test.respErr,
+ client := &testutils.FakeHTTPClient{
+ ReqChan: testutils.NewChannel(),
+ RespChan: testutils.NewChannel(),
+ Err: test.respErr,
}
- client.respCh.Send(test.resp)
+ client.RespChan.Send(test.resp)
_, err := sendRequest(client, req)
if (err != nil) != test.wantErr {
t.Errorf("sendRequest(%v) = %v, wantErr: %v", req, err, test.wantErr)
diff --git a/credentials/tls/certprovider/meshca/builder.go b/credentials/tls/certprovider/meshca/builder.go
new file mode 100644
index 0000000..3544a16
--- /dev/null
+++ b/credentials/tls/certprovider/meshca/builder.go
@@ -0,0 +1,169 @@
+// +build go1.13
+
+/*
+ *
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package meshca
+
+import (
+ "crypto/x509"
+ "encoding/json"
+ "fmt"
+ "sync"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/credentials/sts"
+ "google.golang.org/grpc/credentials/tls/certprovider"
+ "google.golang.org/grpc/internal/backoff"
+)
+
+const pluginName = "mesh_ca"
+
+// For overriding in unit tests.
+var (
+ grpcDialFunc = grpc.Dial
+ backoffFunc = backoff.DefaultExponential.Backoff
+)
+
+func init() {
+ certprovider.Register(newPluginBuilder())
+}
+
+func newPluginBuilder() *pluginBuilder {
+ return &pluginBuilder{clients: make(map[ccMapKey]*refCountedCC)}
+}
+
+// Key for the map containing ClientConns to the MeshCA server. Only the server
+// name and the STS options (which is used to create call creds) from the plugin
+// configuration determine if two configs can share the same ClientConn. Hence
+// only those form the key to this map.
+type ccMapKey struct {
+ name string
+ stsOpts sts.Options
+}
+
+// refCountedCC wraps a grpc.ClientConn to MeshCA along with a reference count.
+type refCountedCC struct {
+ cc *grpc.ClientConn
+ refCnt int
+}
+
+// pluginBuilder is an implementation of the certprovider.Builder interface,
+// which build certificate provider instances which get certificates signed from
+// the MeshCA.
+type pluginBuilder struct {
+ // A collection of ClientConns to the MeshCA server along with a reference
+ // count. Provider instances whose config point to the same server name will
+ // end up sharing the ClientConn.
+ mu sync.Mutex
+ clients map[ccMapKey]*refCountedCC
+}
+
+// Build returns a MeshCA certificate provider for the passed in configuration
+// and options.
+//
+// This builder takes care of sharing the ClientConn to the MeshCA server among
+// different plugin instantiations.
+func (b *pluginBuilder) Build(c certprovider.StableConfig, opts certprovider.Options) certprovider.Provider {
+ cfg, ok := c.(*pluginConfig)
+ if !ok {
+ // This is not expected when passing config returned by ParseConfig().
+ // This could indicate a bug in the certprovider.Store implementation or
+ // in cases where the user is directly using these APIs, could be a user
+ // error.
+ logger.Errorf("unsupported config type: %T", c)
+ return nil
+ }
+
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ ccmk := ccMapKey{
+ name: cfg.serverURI,
+ stsOpts: cfg.stsOpts,
+ }
+ rcc, ok := b.clients[ccmk]
+ if !ok {
+ // STS call credentials take care of exchanging a locally provisioned
+ // JWT token for an access token which will be accepted by the MeshCA.
+ callCreds, err := sts.NewCredentials(cfg.stsOpts)
+ if err != nil {
+ logger.Errorf("sts.NewCredentials() failed: %v", err)
+ return nil
+ }
+
+ // MeshCA is a public endpoint whose certificate is Web-PKI compliant.
+ // So, we just need to use the system roots to authenticate the MeshCA.
+ cp, err := x509.SystemCertPool()
+ if err != nil {
+ logger.Errorf("x509.SystemCertPool() failed: %v", err)
+ return nil
+ }
+ transportCreds := credentials.NewClientTLSFromCert(cp, "")
+
+ cc, err := grpcDialFunc(cfg.serverURI, grpc.WithTransportCredentials(transportCreds), grpc.WithPerRPCCredentials(callCreds))
+ if err != nil {
+ logger.Errorf("grpc.Dial(%s) failed: %v", cfg.serverURI, err)
+ return nil
+ }
+
+ rcc = &refCountedCC{cc: cc}
+ b.clients[ccmk] = rcc
+ }
+ rcc.refCnt++
+
+ p := newProviderPlugin(providerParams{
+ cc: rcc.cc,
+ cfg: cfg,
+ opts: opts,
+ backoff: backoffFunc,
+ doneFunc: func() {
+ // The plugin implementation will invoke this function when it is
+ // being closed, and here we take care of closing the ClientConn
+ // when there are no more plugins using it. We need to acquire the
+ // lock before accessing the rcc from the enclosing function.
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ rcc.refCnt--
+ if rcc.refCnt == 0 {
+ logger.Infof("Closing grpc.ClientConn to %s", ccmk.name)
+ rcc.cc.Close()
+ delete(b.clients, ccmk)
+ }
+ },
+ })
+ return p
+}
+
+// ParseConfig parses the configuration to be passed to the MeshCA plugin
+// implementation. Expects the config to be a json.RawMessage which contains a
+// serialized JSON representation of the meshca_experimental.GoogleMeshCaConfig
+// proto message.
+func (b *pluginBuilder) ParseConfig(c interface{}) (certprovider.StableConfig, error) {
+ data, ok := c.(json.RawMessage)
+ if !ok {
+ return nil, fmt.Errorf("meshca: unsupported config type: %T", c)
+ }
+ return pluginConfigFromJSON(data)
+}
+
+// Name returns the MeshCA plugin name.
+func (b *pluginBuilder) Name() string {
+ return pluginName
+}
diff --git a/credentials/tls/certprovider/meshca/builder_test.go b/credentials/tls/certprovider/meshca/builder_test.go
new file mode 100644
index 0000000..b395f4f
--- /dev/null
+++ b/credentials/tls/certprovider/meshca/builder_test.go
@@ -0,0 +1,182 @@
+// +build go1.13
+
+/*
+ *
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package meshca
+
+import (
+ "context"
+ "fmt"
+ "testing"
+
+ "github.com/golang/protobuf/proto"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/connectivity"
+ "google.golang.org/grpc/credentials/tls/certprovider"
+ configpb "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/meshca_experimental"
+ "google.golang.org/grpc/internal/testutils"
+)
+
+func overrideHTTPFuncs() func() {
+ // Directly override the functions which are used to read the zone and
+ // audience instead of overriding the http.Client.
+ origReadZone := readZoneFunc
+ readZoneFunc = func(httpDoer) string { return "test-zone" }
+ origReadAudience := readAudienceFunc
+ readAudienceFunc = func(httpDoer) string { return "test-audience" }
+ return func() {
+ readZoneFunc = origReadZone
+ readAudienceFunc = origReadAudience
+ }
+}
+
+func (s) TestBuildSameConfig(t *testing.T) {
+ defer overrideHTTPFuncs()()
+
+ // We will attempt to create `cnt` number of providers. So we create a
+ // channel of the same size here, even though we expect only one ClientConn
+ // to be pushed into this channel. This makes sure that even if more than
+ // one ClientConn ends up being created, the Build() call does not block.
+ const cnt = 5
+ ccChan := testutils.NewChannelWithSize(cnt)
+
+ // Override the dial func to dial a dummy MeshCA endpoint, and also push the
+ // returned ClientConn on a channel to be inspected by the test.
+ origDialFunc := grpcDialFunc
+ grpcDialFunc = func(string, ...grpc.DialOption) (*grpc.ClientConn, error) {
+ cc, err := grpc.Dial("dummy-meshca-endpoint", grpc.WithInsecure())
+ ccChan.Send(cc)
+ return cc, err
+ }
+ defer func() { grpcDialFunc = origDialFunc }()
+
+ // Parse a good config to generate a stable config which will be passed to
+ // invocations of Build().
+ inputConfig := makeJSONConfig(t, goodConfigFullySpecified)
+ builder := newPluginBuilder()
+ stableConfig, err := builder.ParseConfig(inputConfig)
+ if err != nil {
+ t.Fatalf("builder.ParseConfig(%q) failed: %v", inputConfig, err)
+ }
+
+ // Create multiple providers with the same config. All these providers must
+ // end up sharing the same ClientConn.
+ providers := []certprovider.Provider{}
+ for i := 0; i < cnt; i++ {
+ p := builder.Build(stableConfig, certprovider.Options{})
+ if p == nil {
+ t.Fatalf("builder.Build(%s) failed: %v", string(stableConfig.Canonical()), err)
+ }
+ providers = append(providers, p)
+ }
+
+ // Make sure only one ClientConn is created.
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ val, err := ccChan.Receive(ctx)
+ if err != nil {
+ t.Fatalf("Failed to create ClientConn: %v", err)
+ }
+ testCC := val.(*grpc.ClientConn)
+
+ // Attempt to read the second ClientConn should timeout.
+ ctx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
+ defer cancel()
+ if _, err := ccChan.Receive(ctx); err != context.DeadlineExceeded {
+ t.Fatal("Builder created more than one ClientConn")
+ }
+
+ for _, p := range providers {
+ p.Close()
+ }
+
+ for {
+ state := testCC.GetState()
+ if state == connectivity.Shutdown {
+ break
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ if !testCC.WaitForStateChange(ctx, state) {
+ t.Fatalf("timeout waiting for clientConn state to change from %s", state)
+ }
+ }
+}
+
+func (s) TestBuildDifferentConfig(t *testing.T) {
+ defer overrideHTTPFuncs()()
+
+ // We will attempt to create two providers with different configs. So we
+ // expect two ClientConns to be pushed on to this channel.
+ const cnt = 2
+ ccChan := testutils.NewChannelWithSize(cnt)
+
+ // Override the dial func to dial a dummy MeshCA endpoint, and also push the
+ // returned ClientConn on a channel to be inspected by the test.
+ origDialFunc := grpcDialFunc
+ grpcDialFunc = func(string, ...grpc.DialOption) (*grpc.ClientConn, error) {
+ cc, err := grpc.Dial("dummy-meshca-endpoint", grpc.WithInsecure())
+ ccChan.Send(cc)
+ return cc, err
+ }
+ defer func() { grpcDialFunc = origDialFunc }()
+
+ builder := newPluginBuilder()
+ providers := []certprovider.Provider{}
+ for i := 0; i < cnt; i++ {
+ // Copy the good test config and modify the serverURI to make sure that
+ // a new provider is created for the config.
+ cfg := proto.Clone(goodConfigFullySpecified).(*configpb.GoogleMeshCaConfig)
+ cfg.Server.GrpcServices[0].GetGoogleGrpc().TargetUri = fmt.Sprintf("test-mesh-ca:%d", i)
+ inputConfig := makeJSONConfig(t, cfg)
+ stableConfig, err := builder.ParseConfig(inputConfig)
+ if err != nil {
+ t.Fatalf("builder.ParseConfig(%q) failed: %v", inputConfig, err)
+ }
+
+ p := builder.Build(stableConfig, certprovider.Options{})
+ if p == nil {
+ t.Fatalf("builder.Build(%s) failed: %v", string(stableConfig.Canonical()), err)
+ }
+ providers = append(providers, p)
+ }
+
+ // Make sure two ClientConns are created.
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ for i := 0; i < cnt; i++ {
+ if _, err := ccChan.Receive(ctx); err != nil {
+ t.Fatalf("Failed to create ClientConn: %v", err)
+ }
+ }
+
+ // Close the first provider, and attempt to read key material from the
+ // second provider. The call to read key material should timeout, but it
+ // should not return certprovider.errProviderClosed.
+ providers[0].Close()
+ ctx, cancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
+ defer cancel()
+ if _, err := providers[1].KeyMaterial(ctx); err != context.DeadlineExceeded {
+ t.Fatalf("provider.KeyMaterial(ctx) = %v, want contextDeadlineExceeded", err)
+ }
+
+ // Close the second provider to make sure that the leakchecker is happy.
+ providers[1].Close()
+}
diff --git a/credentials/tls/certprovider/meshca/config.go b/credentials/tls/certprovider/meshca/config.go
new file mode 100644
index 0000000..38186fa
--- /dev/null
+++ b/credentials/tls/certprovider/meshca/config.go
@@ -0,0 +1,281 @@
+// +build go1.13
+
+/*
+ *
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package meshca
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httputil"
+ "path"
+ "strings"
+ "time"
+
+ v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
+ "github.com/golang/protobuf/jsonpb"
+ "github.com/golang/protobuf/ptypes"
+
+ "google.golang.org/grpc/credentials/sts"
+ configpb "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/meshca_experimental"
+)
+
+const (
+ // GKE metadata server endpoint.
+ mdsBaseURI = "http://metadata.google.internal/"
+ mdsRequestTimeout = 5 * time.Second
+
+ // The following are default values used in the interaction with MeshCA.
+ defaultMeshCaEndpoint = "meshca.googleapis.com"
+ defaultCallTimeout = 10 * time.Second
+ defaultCertLifetime = 24 * time.Hour
+ defaultCertGraceTime = 12 * time.Hour
+ defaultKeyTypeRSA = "RSA"
+ defaultKeySize = 2048
+
+ // The following are default values used in the interaction with STS or
+ // Secure Token Service, which is used to exchange the JWT token for an
+ // access token.
+ defaultSTSEndpoint = "securetoken.googleapis.com"
+ defaultCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
+ defaultRequestedTokenType = "urn:ietf:params:oauth:token-type:access_token"
+ defaultSubjectTokenType = "urn:ietf:params:oauth:token-type:jwt"
+)
+
+// For overriding in unit tests.
+var (
+ makeHTTPDoer = makeHTTPClient
+ readZoneFunc = readZone
+ readAudienceFunc = readAudience
+)
+
+// Implements the certprovider.StableConfig interface.
+type pluginConfig struct {
+ serverURI string
+ stsOpts sts.Options
+ callTimeout time.Duration
+ certLifetime time.Duration
+ certGraceTime time.Duration
+ keyType string
+ keySize int
+ location string
+}
+
+// pluginConfigFromJSON parses the provided config in JSON.
+//
+// For certain values missing in the config, we use default values defined at
+// the top of this file.
+//
+// If the location field or STS audience field is missing, we try talking to the
+// GKE Metadata server and try to infer these values. If this attempt does not
+// succeed, we let those fields have empty values.
+func pluginConfigFromJSON(data json.RawMessage) (*pluginConfig, error) {
+ cfgProto := &configpb.GoogleMeshCaConfig{}
+ m := jsonpb.Unmarshaler{AllowUnknownFields: true}
+ if err := m.Unmarshal(bytes.NewReader(data), cfgProto); err != nil {
+ return nil, fmt.Errorf("meshca: failed to unmarshal config: %v", err)
+ }
+
+ if api := cfgProto.GetServer().GetApiType(); api != v3corepb.ApiConfigSource_GRPC {
+ return nil, fmt.Errorf("meshca: server has apiType %s, want %s", api, v3corepb.ApiConfigSource_GRPC)
+ }
+
+ pc := &pluginConfig{}
+ gs := cfgProto.GetServer().GetGrpcServices()
+ if l := len(gs); l != 1 {
+ return nil, fmt.Errorf("meshca: number of gRPC services in config is %d, expected 1", l)
+ }
+ grpcService := gs[0]
+ googGRPC := grpcService.GetGoogleGrpc()
+ if googGRPC == nil {
+ return nil, errors.New("meshca: missing google gRPC service in config")
+ }
+ pc.serverURI = googGRPC.GetTargetUri()
+ if pc.serverURI == "" {
+ pc.serverURI = defaultMeshCaEndpoint
+ }
+
+ callCreds := googGRPC.GetCallCredentials()
+ if len(callCreds) == 0 {
+ return nil, errors.New("meshca: missing call credentials in config")
+ }
+ var stsCallCreds *v3corepb.GrpcService_GoogleGrpc_CallCredentials_StsService
+ for _, cc := range callCreds {
+ if stsCallCreds = cc.GetStsService(); stsCallCreds != nil {
+ break
+ }
+ }
+ if stsCallCreds == nil {
+ return nil, errors.New("meshca: missing STS call credentials in config")
+ }
+ if stsCallCreds.GetSubjectTokenPath() == "" {
+ return nil, errors.New("meshca: missing subjectTokenPath in STS call credentials config")
+ }
+ pc.stsOpts = makeStsOptsWithDefaults(stsCallCreds)
+
+ var err error
+ if pc.callTimeout, err = ptypes.Duration(grpcService.GetTimeout()); err != nil {
+ pc.callTimeout = defaultCallTimeout
+ }
+ if pc.certLifetime, err = ptypes.Duration(cfgProto.GetCertificateLifetime()); err != nil {
+ pc.certLifetime = defaultCertLifetime
+ }
+ if pc.certGraceTime, err = ptypes.Duration(cfgProto.GetRenewalGracePeriod()); err != nil {
+ pc.certGraceTime = defaultCertGraceTime
+ }
+ switch cfgProto.GetKeyType() {
+ case configpb.GoogleMeshCaConfig_KEY_TYPE_UNKNOWN, configpb.GoogleMeshCaConfig_KEY_TYPE_RSA:
+ pc.keyType = defaultKeyTypeRSA
+ default:
+ return nil, fmt.Errorf("meshca: unsupported key type: %s, only support RSA keys", pc.keyType)
+ }
+ pc.keySize = int(cfgProto.GetKeySize())
+ if pc.keySize == 0 {
+ pc.keySize = defaultKeySize
+ }
+ pc.location = cfgProto.GetLocation()
+ if pc.location == "" {
+ pc.location = readZoneFunc(makeHTTPDoer())
+ }
+
+ return pc, nil
+}
+
+func (pc *pluginConfig) Canonical() []byte {
+ return []byte(fmt.Sprintf("%s:%s:%s:%s:%s:%s:%d:%s", pc.serverURI, pc.stsOpts, pc.callTimeout, pc.certLifetime, pc.certGraceTime, pc.keyType, pc.keySize, pc.location))
+}
+
+func makeStsOptsWithDefaults(stsCallCreds *v3corepb.GrpcService_GoogleGrpc_CallCredentials_StsService) sts.Options {
+ opts := sts.Options{
+ TokenExchangeServiceURI: stsCallCreds.GetTokenExchangeServiceUri(),
+ Resource: stsCallCreds.GetResource(),
+ Audience: stsCallCreds.GetAudience(),
+ Scope: stsCallCreds.GetScope(),
+ RequestedTokenType: stsCallCreds.GetRequestedTokenType(),
+ SubjectTokenPath: stsCallCreds.GetSubjectTokenPath(),
+ SubjectTokenType: stsCallCreds.GetSubjectTokenType(),
+ ActorTokenPath: stsCallCreds.GetActorTokenPath(),
+ ActorTokenType: stsCallCreds.GetActorTokenType(),
+ }
+
+ // Use sane defaults for unspecified fields.
+ if opts.TokenExchangeServiceURI == "" {
+ opts.TokenExchangeServiceURI = defaultSTSEndpoint
+ }
+ if opts.Audience == "" {
+ opts.Audience = readAudienceFunc(makeHTTPDoer())
+ }
+ if opts.Scope == "" {
+ opts.Scope = defaultCloudPlatformScope
+ }
+ if opts.RequestedTokenType == "" {
+ opts.RequestedTokenType = defaultRequestedTokenType
+ }
+ if opts.SubjectTokenType == "" {
+ opts.SubjectTokenType = defaultSubjectTokenType
+ }
+ return opts
+}
+
+// httpDoer wraps the single method on the http.Client type that we use. This
+// helps with overriding in unit tests.
+type httpDoer interface {
+ Do(req *http.Request) (*http.Response, error)
+}
+
+func makeHTTPClient() httpDoer {
+ return &http.Client{Timeout: mdsRequestTimeout}
+}
+
+func readMetadata(client httpDoer, uriPath string) (string, error) {
+ req, err := http.NewRequest("GET", mdsBaseURI+uriPath, nil)
+ if err != nil {
+ return "", err
+ }
+ req.Header.Add("Metadata-Flavor", "Google")
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return "", err
+ }
+ defer resp.Body.Close()
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return "", err
+ }
+ if resp.StatusCode != http.StatusOK {
+ dump, err := httputil.DumpRequestOut(req, false)
+ if err != nil {
+ logger.Warningf("Failed to dump HTTP request: %v", err)
+ }
+ logger.Warningf("Request %q returned status %v", dump, resp.StatusCode)
+ }
+ return string(body), err
+}
+
+func readZone(client httpDoer) string {
+ zoneURI := "computeMetadata/v1/instance/zone"
+ data, err := readMetadata(client, zoneURI)
+ if err != nil {
+ logger.Warningf("GET %s failed: %v", path.Join(mdsBaseURI, zoneURI), err)
+ return ""
+ }
+
+ // The output returned by the metadata server looks like this:
+ // projects/<PROJECT-NUMBER>/zones/<ZONE>
+ parts := strings.Split(data, "/")
+ if len(parts) == 0 {
+ logger.Warningf("GET %s returned {%s}, does not match expected format {projects/<PROJECT-NUMBER>/zones/<ZONE>}", path.Join(mdsBaseURI, zoneURI))
+ return ""
+ }
+ return parts[len(parts)-1]
+}
+
+// readAudience constructs the audience field to be used in the STS request, if
+// it is not specified in the plugin configuration.
+//
+// "identitynamespace:{TRUST_DOMAIN}:{GKE_CLUSTER_URL}" is the format of the
+// audience field. When workload identity is enabled on a GCP project, a default
+// trust domain is created whose value is "{PROJECT_ID}.svc.id.goog". The format
+// of the GKE_CLUSTER_URL is:
+// https://container.googleapis.com/v1/projects/{PROJECT_ID}/zones/{ZONE}/clusters/{CLUSTER_NAME}.
+func readAudience(client httpDoer) string {
+ projURI := "computeMetadata/v1/project/project-id"
+ project, err := readMetadata(client, projURI)
+ if err != nil {
+ logger.Warningf("GET %s failed: %v", path.Join(mdsBaseURI, projURI), err)
+ return ""
+ }
+ trustDomain := fmt.Sprintf("%s.svc.id.goog", project)
+
+ clusterURI := "computeMetadata/v1/instance/attributes/cluster-name"
+ cluster, err := readMetadata(client, clusterURI)
+ if err != nil {
+ logger.Warningf("GET %s failed: %v", path.Join(mdsBaseURI, clusterURI), err)
+ return ""
+ }
+ zone := readZoneFunc(client)
+ clusterURL := fmt.Sprintf("https://container.googleapis.com/v1/projects/%s/zones/%s/clusters/%s", project, zone, cluster)
+ audience := fmt.Sprintf("identitynamespace:%s:%s", trustDomain, clusterURL)
+ return audience
+}
diff --git a/credentials/tls/certprovider/meshca/config_test.go b/credentials/tls/certprovider/meshca/config_test.go
new file mode 100644
index 0000000..34dd9f7
--- /dev/null
+++ b/credentials/tls/certprovider/meshca/config_test.go
@@ -0,0 +1,410 @@
+// +build go1.13
+
+/*
+ *
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package meshca
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "strings"
+ "testing"
+
+ v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
+ "github.com/golang/protobuf/jsonpb"
+ durationpb "github.com/golang/protobuf/ptypes/duration"
+ "github.com/google/go-cmp/cmp"
+
+ configpb "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/meshca_experimental"
+ "google.golang.org/grpc/internal/grpctest"
+ "google.golang.org/grpc/internal/testutils"
+)
+
+const (
+ testProjectID = "test-project-id"
+ testGKECluster = "test-gke-cluster"
+ testGCEZone = "test-zone"
+)
+
+type s struct {
+ grpctest.Tester
+}
+
+func Test(t *testing.T) {
+ grpctest.RunSubTests(t, s{})
+}
+
+var (
+ goodConfigFullySpecified = &configpb.GoogleMeshCaConfig{
+ Server: &v3corepb.ApiConfigSource{
+ ApiType: v3corepb.ApiConfigSource_GRPC,
+ GrpcServices: []*v3corepb.GrpcService{
+ {
+ TargetSpecifier: &v3corepb.GrpcService_GoogleGrpc_{
+ GoogleGrpc: &v3corepb.GrpcService_GoogleGrpc{
+ TargetUri: "test-meshca",
+ CallCredentials: []*v3corepb.GrpcService_GoogleGrpc_CallCredentials{
+ // This call creds should be ignored.
+ {
+ CredentialSpecifier: &v3corepb.GrpcService_GoogleGrpc_CallCredentials_AccessToken{},
+ },
+ {
+ CredentialSpecifier: &v3corepb.GrpcService_GoogleGrpc_CallCredentials_StsService_{
+ StsService: &v3corepb.GrpcService_GoogleGrpc_CallCredentials_StsService{
+ TokenExchangeServiceUri: "http://test-sts",
+ Resource: "test-resource",
+ Audience: "test-audience",
+ Scope: "test-scope",
+ RequestedTokenType: "test-requested-token-type",
+ SubjectTokenPath: "test-subject-token-path",
+ SubjectTokenType: "test-subject-token-type",
+ ActorTokenPath: "test-actor-token-path",
+ ActorTokenType: "test-actor-token-type",
+ },
+ },
+ },
+ },
+ },
+ },
+ Timeout: &durationpb.Duration{Seconds: 10}, // 10s
+ },
+ },
+ },
+ CertificateLifetime: &durationpb.Duration{Seconds: 86400}, // 1d
+ RenewalGracePeriod: &durationpb.Duration{Seconds: 43200}, //12h
+ KeyType: configpb.GoogleMeshCaConfig_KEY_TYPE_RSA,
+ KeySize: uint32(2048),
+ Location: "us-west1-b",
+ }
+ goodConfigWithDefaults = &configpb.GoogleMeshCaConfig{
+ Server: &v3corepb.ApiConfigSource{
+ ApiType: v3corepb.ApiConfigSource_GRPC,
+ GrpcServices: []*v3corepb.GrpcService{
+ {
+ TargetSpecifier: &v3corepb.GrpcService_GoogleGrpc_{
+ GoogleGrpc: &v3corepb.GrpcService_GoogleGrpc{
+ CallCredentials: []*v3corepb.GrpcService_GoogleGrpc_CallCredentials{
+ {
+ CredentialSpecifier: &v3corepb.GrpcService_GoogleGrpc_CallCredentials_StsService_{
+ StsService: &v3corepb.GrpcService_GoogleGrpc_CallCredentials_StsService{
+ SubjectTokenPath: "test-subject-token-path",
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+)
+
+// makeJSONConfig marshals the provided config proto into JSON. This makes it
+// possible for tests to specify the config in proto form, which is much easier
+// than specifying the config in JSON form.
+func makeJSONConfig(t *testing.T, cfg *configpb.GoogleMeshCaConfig) json.RawMessage {
+ t.Helper()
+
+ b := &bytes.Buffer{}
+ m := &jsonpb.Marshaler{EnumsAsInts: true}
+ if err := m.Marshal(b, cfg); err != nil {
+ t.Fatalf("jsonpb.Marshal(%+v) failed: %v", cfg, err)
+ }
+ return json.RawMessage(b.Bytes())
+}
+
+// verifyReceivedRequest reads the HTTP request received by the fake client
+// (exposed through a channel), and verifies that it matches the expected
+// request.
+func verifyReceivedRequest(fc *testutils.FakeHTTPClient, wantURI string) error {
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ val, err := fc.ReqChan.Receive(ctx)
+ if err != nil {
+ return err
+ }
+ gotReq := val.(*http.Request)
+ if gotURI := gotReq.URL.String(); gotURI != wantURI {
+ return fmt.Errorf("request contains URL %q want %q", gotURI, wantURI)
+ }
+ if got, want := gotReq.Header.Get("Metadata-Flavor"), "Google"; got != want {
+ return fmt.Errorf("request contains flavor %q want %q", got, want)
+ }
+ return nil
+}
+
+// TestParseConfigSuccessFullySpecified tests the case where the config is fully
+// specified and no defaults are required.
+func (s) TestParseConfigSuccessFullySpecified(t *testing.T) {
+ inputConfig := makeJSONConfig(t, goodConfigFullySpecified)
+ wantConfig := "test-meshca:http://test-sts:test-resource:test-audience:test-scope:test-requested-token-type:test-subject-token-path:test-subject-token-type:test-actor-token-path:test-actor-token-type:10s:24h0m0s:12h0m0s:RSA:2048:us-west1-b"
+
+ builder := newPluginBuilder()
+ gotConfig, err := builder.ParseConfig(inputConfig)
+ if err != nil {
+ t.Fatalf("builder.ParseConfig(%q) failed: %v", inputConfig, err)
+ }
+ if diff := cmp.Diff(wantConfig, string(gotConfig.Canonical())); diff != "" {
+ t.Errorf("builder.ParseConfig(%q) returned config does not match expected (-want +got):\n%s", inputConfig, diff)
+ }
+}
+
+// TestParseConfigSuccessWithDefaults tests cases where the config is not fully
+// specified, and we end up using some sane defaults.
+func (s) TestParseConfigSuccessWithDefaults(t *testing.T) {
+ inputConfig := makeJSONConfig(t, goodConfigWithDefaults)
+ wantConfig := fmt.Sprintf("%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s:%s",
+ "meshca.googleapis.com", // Mesh CA Server URI.
+ "securetoken.googleapis.com", // STS Server URI.
+ "", // STS Resource Name.
+ "identitynamespace:test-project-id.svc.id.goog:https://container.googleapis.com/v1/projects/test-project-id/zones/test-zone/clusters/test-gke-cluster", // STS Audience.
+ "https://www.googleapis.com/auth/cloud-platform", // STS Scope.
+ "urn:ietf:params:oauth:token-type:access_token", // STS requested token type.
+ "test-subject-token-path", // STS subject token path.
+ "urn:ietf:params:oauth:token-type:jwt", // STS subject token type.
+ "", // STS actor token path.
+ "", // STS actor token type.
+ "10s", // Call timeout.
+ "24h0m0s", // Cert life time.
+ "12h0m0s", // Cert grace time.
+ "RSA", // Key type
+ "2048", // Key size
+ "test-zone", // Zone
+ )
+
+ // We expect the config parser to make four HTTP requests and receive four
+ // responses. Hence we setup the request and response channels in the fake
+ // client with appropriate buffer size.
+ fc := &testutils.FakeHTTPClient{
+ ReqChan: testutils.NewChannelWithSize(4),
+ RespChan: testutils.NewChannelWithSize(4),
+ }
+ // Set up the responses to be delivered to the config parser by the fake
+ // client. The config parser expects responses with project_id,
+ // gke_cluster_id and gce_zone. The zone is read twice, once as part of
+ // reading the STS audience and once to get location metadata.
+ fc.RespChan.Send(&http.Response{
+ Status: "200 OK",
+ StatusCode: http.StatusOK,
+ Body: ioutil.NopCloser(bytes.NewReader([]byte(testProjectID))),
+ })
+ fc.RespChan.Send(&http.Response{
+ Status: "200 OK",
+ StatusCode: http.StatusOK,
+ Body: ioutil.NopCloser(bytes.NewReader([]byte(testGKECluster))),
+ })
+ fc.RespChan.Send(&http.Response{
+ Status: "200 OK",
+ StatusCode: http.StatusOK,
+ Body: ioutil.NopCloser(bytes.NewReader([]byte(fmt.Sprintf("projects/%s/zones/%s", testProjectID, testGCEZone)))),
+ })
+ fc.RespChan.Send(&http.Response{
+ Status: "200 OK",
+ StatusCode: http.StatusOK,
+ Body: ioutil.NopCloser(bytes.NewReader([]byte(fmt.Sprintf("projects/%s/zones/%s", testProjectID, testGCEZone)))),
+ })
+ // Override the http.Client with our fakeClient.
+ origMakeHTTPDoer := makeHTTPDoer
+ makeHTTPDoer = func() httpDoer { return fc }
+ defer func() { makeHTTPDoer = origMakeHTTPDoer }()
+
+ // Spawn a goroutine to verify the HTTP requests sent out as part of the
+ // config parsing.
+ errCh := make(chan error, 1)
+ go func() {
+ if err := verifyReceivedRequest(fc, "http://metadata.google.internal/computeMetadata/v1/project/project-id"); err != nil {
+ errCh <- err
+ return
+ }
+ if err := verifyReceivedRequest(fc, "http://metadata.google.internal/computeMetadata/v1/instance/attributes/cluster-name"); err != nil {
+ errCh <- err
+ return
+ }
+ if err := verifyReceivedRequest(fc, "http://metadata.google.internal/computeMetadata/v1/instance/zone"); err != nil {
+ errCh <- err
+ return
+ }
+ errCh <- nil
+ }()
+
+ builder := newPluginBuilder()
+ gotConfig, err := builder.ParseConfig(inputConfig)
+ if err != nil {
+ t.Fatalf("builder.ParseConfig(%q) failed: %v", inputConfig, err)
+
+ }
+ if diff := cmp.Diff(wantConfig, string(gotConfig.Canonical())); diff != "" {
+ t.Errorf("builder.ParseConfig(%q) returned config does not match expected (-want +got):\n%s", inputConfig, diff)
+ }
+
+ if err := <-errCh; err != nil {
+ t.Fatal(err)
+ }
+}
+
+// TestParseConfigFailureCases tests several invalid configs which all result in
+// config parsing failures.
+func (s) TestParseConfigFailureCases(t *testing.T) {
+ tests := []struct {
+ desc string
+ inputConfig interface{}
+ wantErr string
+ }{
+ {
+ desc: "bad config type",
+ inputConfig: struct{ foo string }{foo: "bar"},
+ wantErr: "unsupported config type",
+ },
+ {
+ desc: "invalid JSON",
+ inputConfig: json.RawMessage(`bad bad json`),
+ wantErr: "failed to unmarshal config",
+ },
+ {
+ desc: "bad apiType",
+ inputConfig: makeJSONConfig(t, &configpb.GoogleMeshCaConfig{
+ Server: &v3corepb.ApiConfigSource{
+ ApiType: v3corepb.ApiConfigSource_REST,
+ },
+ }),
+ wantErr: "server has apiType REST, want GRPC",
+ },
+ {
+ desc: "no grpc services",
+ inputConfig: makeJSONConfig(t, &configpb.GoogleMeshCaConfig{
+ Server: &v3corepb.ApiConfigSource{
+ ApiType: v3corepb.ApiConfigSource_GRPC,
+ },
+ }),
+ wantErr: "number of gRPC services in config is 0, expected 1",
+ },
+ {
+ desc: "too many grpc services",
+ inputConfig: makeJSONConfig(t, &configpb.GoogleMeshCaConfig{
+ Server: &v3corepb.ApiConfigSource{
+ ApiType: v3corepb.ApiConfigSource_GRPC,
+ GrpcServices: []*v3corepb.GrpcService{nil, nil},
+ },
+ }),
+ wantErr: "number of gRPC services in config is 2, expected 1",
+ },
+ {
+ desc: "missing google grpc service",
+ inputConfig: makeJSONConfig(t, &configpb.GoogleMeshCaConfig{
+ Server: &v3corepb.ApiConfigSource{
+ ApiType: v3corepb.ApiConfigSource_GRPC,
+ GrpcServices: []*v3corepb.GrpcService{
+ {
+ TargetSpecifier: &v3corepb.GrpcService_EnvoyGrpc_{
+ EnvoyGrpc: &v3corepb.GrpcService_EnvoyGrpc{
+ ClusterName: "foo",
+ },
+ },
+ },
+ },
+ },
+ }),
+ wantErr: "missing google gRPC service in config",
+ },
+ {
+ desc: "missing call credentials",
+ inputConfig: makeJSONConfig(t, &configpb.GoogleMeshCaConfig{
+ Server: &v3corepb.ApiConfigSource{
+ ApiType: v3corepb.ApiConfigSource_GRPC,
+ GrpcServices: []*v3corepb.GrpcService{
+ {
+ TargetSpecifier: &v3corepb.GrpcService_GoogleGrpc_{
+ GoogleGrpc: &v3corepb.GrpcService_GoogleGrpc{
+ TargetUri: "foo",
+ },
+ },
+ },
+ },
+ },
+ }),
+ wantErr: "missing call credentials in config",
+ },
+ {
+ desc: "missing STS call credentials",
+ inputConfig: makeJSONConfig(t, &configpb.GoogleMeshCaConfig{
+ Server: &v3corepb.ApiConfigSource{
+ ApiType: v3corepb.ApiConfigSource_GRPC,
+ GrpcServices: []*v3corepb.GrpcService{
+ {
+ TargetSpecifier: &v3corepb.GrpcService_GoogleGrpc_{
+ GoogleGrpc: &v3corepb.GrpcService_GoogleGrpc{
+ TargetUri: "foo",
+ CallCredentials: []*v3corepb.GrpcService_GoogleGrpc_CallCredentials{
+ {
+ CredentialSpecifier: &v3corepb.GrpcService_GoogleGrpc_CallCredentials_AccessToken{},
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ }),
+ wantErr: "missing STS call credentials in config",
+ },
+ {
+ desc: "with no defaults",
+ inputConfig: makeJSONConfig(t, &configpb.GoogleMeshCaConfig{
+ Server: &v3corepb.ApiConfigSource{
+ ApiType: v3corepb.ApiConfigSource_GRPC,
+ GrpcServices: []*v3corepb.GrpcService{
+ {
+ TargetSpecifier: &v3corepb.GrpcService_GoogleGrpc_{
+ GoogleGrpc: &v3corepb.GrpcService_GoogleGrpc{
+ CallCredentials: []*v3corepb.GrpcService_GoogleGrpc_CallCredentials{
+ {
+ CredentialSpecifier: &v3corepb.GrpcService_GoogleGrpc_CallCredentials_StsService_{
+ StsService: &v3corepb.GrpcService_GoogleGrpc_CallCredentials_StsService{},
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ }),
+ wantErr: "missing subjectTokenPath in STS call credentials config",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ builder := newPluginBuilder()
+ sc, err := builder.ParseConfig(test.inputConfig)
+ if err == nil {
+ t.Fatalf("builder.ParseConfig(%q) = %v, expected to return error (%v)", test.inputConfig, string(sc.Canonical()), test.wantErr)
+
+ }
+ if !strings.Contains(err.Error(), test.wantErr) {
+ t.Fatalf("builder.ParseConfig(%q) = (%v), want error (%v)", test.inputConfig, err, test.wantErr)
+ }
+ })
+ }
+}
diff --git a/credentials/tls/certprovider/meshca/logging.go b/credentials/tls/certprovider/meshca/logging.go
new file mode 100644
index 0000000..ae20059
--- /dev/null
+++ b/credentials/tls/certprovider/meshca/logging.go
@@ -0,0 +1,36 @@
+// +build go1.13
+
+/*
+ *
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package meshca
+
+import (
+ "fmt"
+
+ "google.golang.org/grpc/grpclog"
+ internalgrpclog "google.golang.org/grpc/internal/grpclog"
+)
+
+const prefix = "[%p] "
+
+var logger = grpclog.Component("meshca")
+
+func prefixLogger(p *providerPlugin) *internalgrpclog.PrefixLogger {
+ return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(prefix, p))
+}
diff --git a/credentials/tls/certprovider/meshca/plugin.go b/credentials/tls/certprovider/meshca/plugin.go
new file mode 100644
index 0000000..b00ad5f
--- /dev/null
+++ b/credentials/tls/certprovider/meshca/plugin.go
@@ -0,0 +1,288 @@
+// +build go1.13
+
+/*
+ *
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package meshca provides an implementation of the Provider interface which
+// communicates with MeshCA to get certificates signed.
+package meshca
+
+import (
+ "context"
+ "crypto"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/pem"
+ "fmt"
+ "time"
+
+ durationpb "github.com/golang/protobuf/ptypes/duration"
+ "github.com/google/uuid"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/tls/certprovider"
+ meshpb "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1"
+ "google.golang.org/grpc/internal/grpclog"
+ "google.golang.org/grpc/metadata"
+)
+
+// In requests sent to the MeshCA, we add a metadata header with this key and
+// the value being the GCE zone in which the workload is running in.
+const locationMetadataKey = "x-goog-request-params"
+
+// For overriding from unit tests.
+var newDistributorFunc = func() distributor { return certprovider.NewDistributor() }
+
+// distributor wraps the methods on certprovider.Distributor which are used by
+// the plugin. This is very useful in tests which need to know exactly when the
+// plugin updates its key material.
+type distributor interface {
+ KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error)
+ Set(km *certprovider.KeyMaterial, err error)
+ Stop()
+}
+
+// providerPlugin is an implementation of the certprovider.Provider interface,
+// which gets certificates signed by communicating with the MeshCA.
+type providerPlugin struct {
+ distributor // Holds the key material.
+ cancel context.CancelFunc
+ cc *grpc.ClientConn // Connection to MeshCA server.
+ cfg *pluginConfig // Plugin configuration.
+ opts certprovider.Options // Key material options.
+ logger *grpclog.PrefixLogger // Plugin instance specific prefix.
+ backoff func(int) time.Duration // Exponential backoff.
+ doneFunc func() // Notify the builder when done.
+}
+
+// providerParams wraps params passed to the provider plugin at creation time.
+type providerParams struct {
+ // This ClientConn to the MeshCA server is owned by the builder.
+ cc *grpc.ClientConn
+ cfg *pluginConfig
+ opts certprovider.Options
+ backoff func(int) time.Duration
+ doneFunc func()
+}
+
+func newProviderPlugin(params providerParams) *providerPlugin {
+ ctx, cancel := context.WithCancel(context.Background())
+ p := &providerPlugin{
+ cancel: cancel,
+ cc: params.cc,
+ cfg: params.cfg,
+ opts: params.opts,
+ backoff: params.backoff,
+ doneFunc: params.doneFunc,
+ distributor: newDistributorFunc(),
+ }
+ p.logger = prefixLogger((p))
+ p.logger.Infof("plugin created")
+ go p.run(ctx)
+ return p
+}
+
+func (p *providerPlugin) Close() {
+ p.logger.Infof("plugin closed")
+ p.Stop() // Stop the embedded distributor.
+ p.cancel()
+ p.doneFunc()
+}
+
+// run is a long running goroutine which periodically sends out CSRs to the
+// MeshCA, and updates the underlying Distributor with the new key material.
+func (p *providerPlugin) run(ctx context.Context) {
+ // We need to start fetching key material right away. The next attempt will
+ // be triggered by the timer firing.
+ for {
+ certValidity, err := p.updateKeyMaterial(ctx)
+ if err != nil {
+ return
+ }
+
+ // We request a certificate with the configured validity duration (which
+ // is usually twice as much as the grace period). But the server is free
+ // to return a certificate with whatever validity time it deems right.
+ refreshAfter := p.cfg.certGraceTime
+ if refreshAfter > certValidity {
+ // The default value of cert grace time is half that of the default
+ // cert validity time. So here, when we have to use a non-default
+ // cert life time, we will set the grace time again to half that of
+ // the validity time.
+ refreshAfter = certValidity / 2
+ }
+ timer := time.NewTimer(refreshAfter)
+ select {
+ case <-ctx.Done():
+ return
+ case <-timer.C:
+ }
+ }
+}
+
+// updateKeyMaterial generates a CSR and attempts to get it signed from the
+// MeshCA. It retries with an exponential backoff till it succeeds or the
+// deadline specified in ctx expires. Once it gets the CSR signed from the
+// MeshCA, it updates the Distributor with the new key material.
+//
+// It returns the amount of time the new certificate is valid for.
+func (p *providerPlugin) updateKeyMaterial(ctx context.Context) (time.Duration, error) {
+ client := meshpb.NewMeshCertificateServiceClient(p.cc)
+ retries := 0
+ for {
+ if ctx.Err() != nil {
+ return 0, ctx.Err()
+ }
+
+ if retries != 0 {
+ bi := p.backoff(retries)
+ p.logger.Warningf("Backing off for %s before attempting the next CreateCertificate() request", bi)
+ timer := time.NewTimer(bi)
+ select {
+ case <-timer.C:
+ case <-ctx.Done():
+ return 0, ctx.Err()
+ }
+ }
+ retries++
+
+ privKey, err := rsa.GenerateKey(rand.Reader, p.cfg.keySize)
+ if err != nil {
+ p.logger.Warningf("RSA key generation failed: %v", err)
+ continue
+ }
+ // We do not set any fields in the CSR (we use an empty
+ // x509.CertificateRequest as the template) because the MeshCA discards
+ // them anyways, and uses the workload identity from the access token
+ // that we present (as part of the STS call creds).
+ csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{}, crypto.PrivateKey(privKey))
+ if err != nil {
+ p.logger.Warningf("CSR creation failed: %v", err)
+ continue
+ }
+ csrPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrBytes})
+
+ // Send out the CSR with a call timeout and location metadata, as
+ // specified in the plugin configuration.
+ req := &meshpb.MeshCertificateRequest{
+ RequestId: uuid.New().String(),
+ Csr: string(csrPEM),
+ Validity: &durationpb.Duration{Seconds: int64(p.cfg.certLifetime / time.Second)},
+ }
+ p.logger.Debugf("Sending CreateCertificate() request: %v", req)
+
+ callCtx, ctxCancel := context.WithTimeout(context.Background(), p.cfg.callTimeout)
+ callCtx = metadata.NewOutgoingContext(callCtx, metadata.Pairs(locationMetadataKey, p.cfg.location))
+ resp, err := client.CreateCertificate(callCtx, req)
+ if err != nil {
+ p.logger.Warningf("CreateCertificate request failed: %v", err)
+ ctxCancel()
+ continue
+ }
+ ctxCancel()
+
+ // The returned cert chain must contain more than one cert. Leaf cert is
+ // element '0', while root cert is element 'n', and the intermediate
+ // entries form the chain from the root to the leaf.
+ certChain := resp.GetCertChain()
+ if l := len(certChain); l <= 1 {
+ p.logger.Errorf("Received certificate chain contains %d certificates, need more than one", l)
+ continue
+ }
+
+ // We need to explicitly parse the PEM cert contents as an
+ // x509.Certificate to read the certificate validity period. We use this
+ // to decide when to refresh the cert. Even though the call to
+ // tls.X509KeyPair actually parses the PEM contents into an
+ // x509.Certificate, it does not store that in the `Leaf` field. See:
+ // https://golang.org/pkg/crypto/tls/#X509KeyPair.
+ identity, intermediates, roots, err := parseCertChain(certChain)
+ if err != nil {
+ p.logger.Errorf(err.Error())
+ continue
+ }
+ _, err = identity.Verify(x509.VerifyOptions{
+ Intermediates: intermediates,
+ Roots: roots,
+ })
+ if err != nil {
+ p.logger.Errorf("Certificate verification failed for return certChain: %v", err)
+ continue
+ }
+
+ key := x509.MarshalPKCS1PrivateKey(privKey)
+ keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: key})
+ certPair, err := tls.X509KeyPair([]byte(certChain[0]), keyPEM)
+ if err != nil {
+ p.logger.Errorf("Failed to create x509 key pair: %v", err)
+ continue
+ }
+
+ // At this point, the received response has been deemed good.
+ retries = 0
+
+ // All certs signed by the MeshCA roll up to the same root. And treating
+ // the last element of the returned chain as the root is the only
+ // supported option to get the root certificate. So, we ignore the
+ // options specified in the call to Build(), which contain certificate
+ // name and whether the caller is interested in identity or root cert.
+ p.Set(&certprovider.KeyMaterial{Certs: []tls.Certificate{certPair}, Roots: roots}, nil)
+ return time.Until(identity.NotAfter), nil
+ }
+}
+
+// ParseCertChain parses the result returned by the MeshCA which consists of a
+// list of PEM encoded certs. The first element in the list is the leaf or
+// identity cert, while the last element is the root, and everything in between
+// form the chain of trust.
+//
+// Caller needs to make sure that certChain has at least two elements.
+func parseCertChain(certChain []string) (*x509.Certificate, *x509.CertPool, *x509.CertPool, error) {
+ identity, err := parseCert([]byte(certChain[0]))
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ intermediates := x509.NewCertPool()
+ for _, cert := range certChain[1 : len(certChain)-1] {
+ i, err := parseCert([]byte(cert))
+ if err != nil {
+ return nil, nil, nil, err
+ }
+ intermediates.AddCert(i)
+ }
+
+ roots := x509.NewCertPool()
+ root, err := parseCert([]byte(certChain[len(certChain)-1]))
+ if err != nil {
+ return nil, nil, nil, err
+ }
+ roots.AddCert(root)
+
+ return identity, intermediates, roots, nil
+}
+
+func parseCert(certPEM []byte) (*x509.Certificate, error) {
+ block, _ := pem.Decode(certPEM)
+ if block == nil {
+ return nil, fmt.Errorf("failed to decode received PEM data: %v", certPEM)
+ }
+ return x509.ParseCertificate(block.Bytes)
+}
diff --git a/credentials/tls/certprovider/meshca/plugin_test.go b/credentials/tls/certprovider/meshca/plugin_test.go
new file mode 100644
index 0000000..48740c1
--- /dev/null
+++ b/credentials/tls/certprovider/meshca/plugin_test.go
@@ -0,0 +1,464 @@
+// +build go1.13
+
+/*
+ *
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package meshca
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "math/big"
+ "net"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/golang/protobuf/proto"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/tls/certprovider"
+ configpb "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/meshca_experimental"
+ meshgrpc "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1"
+ meshpb "google.golang.org/grpc/credentials/tls/certprovider/meshca/internal/v1"
+ "google.golang.org/grpc/internal/testutils"
+)
+
+const (
+ // Used when waiting for something that is expected to *not* happen.
+ defaultTestShortTimeout = 10 * time.Millisecond
+ defaultTestTimeout = 5 * time.Second
+ defaultTestCertLife = time.Hour
+ shortTestCertLife = 2 * time.Second
+ maxErrCount = 2
+)
+
+// fakeCA provides a very simple fake implementation of the certificate signing
+// service as exported by the MeshCA.
+type fakeCA struct {
+ meshgrpc.UnimplementedMeshCertificateServiceServer
+
+ withErrors bool // Whether the CA returns errors to begin with.
+ withShortLife bool // Whether to create certs with short lifetime
+
+ ccChan *testutils.Channel // Channel to get notified about CreateCertificate calls.
+ errors int // Error count.
+ key *rsa.PrivateKey // Private key of CA.
+ cert *x509.Certificate // Signing certificate.
+ certPEM []byte // PEM encoding of signing certificate.
+}
+
+// Returns a new instance of the fake Mesh CA. It generates a new RSA key and a
+// self-signed certificate which will be used to sign CSRs received in incoming
+// requests.
+// withErrors controls whether the fake returns errors before succeeding, while
+// withShortLife controls whether the fake returns certs with very small
+// lifetimes (to test plugin refresh behavior). Every time a CreateCertificate()
+// call succeeds, an event is pushed on the ccChan.
+func newFakeMeshCA(ccChan *testutils.Channel, withErrors, withShortLife bool) (*fakeCA, error) {
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ return nil, fmt.Errorf("RSA key generation failed: %v", err)
+ }
+
+ now := time.Now()
+ tmpl := &x509.Certificate{
+ Subject: pkix.Name{CommonName: "my-fake-ca"},
+ SerialNumber: big.NewInt(10),
+ NotBefore: now.Add(-time.Hour),
+ NotAfter: now.Add(time.Hour),
+ KeyUsage: x509.KeyUsageCertSign,
+ IsCA: true,
+ BasicConstraintsValid: true,
+ }
+ certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
+ if err != nil {
+ return nil, fmt.Errorf("x509.CreateCertificate(%v) failed: %v", tmpl, err)
+ }
+ // The PEM encoding of the self-signed certificate is stored because we need
+ // to return a chain of certificates in the response, starting with the
+ // client certificate and ending in the root.
+ certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
+ cert, err := x509.ParseCertificate(certDER)
+ if err != nil {
+ return nil, fmt.Errorf("x509.ParseCertificate(%v) failed: %v", certDER, err)
+ }
+
+ return &fakeCA{
+ withErrors: withErrors,
+ withShortLife: withShortLife,
+ ccChan: ccChan,
+ key: key,
+ cert: cert,
+ certPEM: certPEM,
+ }, nil
+}
+
+// CreateCertificate helps implement the MeshCA service.
+//
+// If the fakeMeshCA was created with `withErrors` set to true, the first
+// `maxErrCount` number of RPC return errors. Subsequent requests are signed and
+// returned without error.
+func (f *fakeCA) CreateCertificate(ctx context.Context, req *meshpb.MeshCertificateRequest) (*meshpb.MeshCertificateResponse, error) {
+ if f.withErrors {
+ if f.errors < maxErrCount {
+ f.errors++
+ return nil, errors.New("fake Mesh CA error")
+
+ }
+ }
+
+ csrPEM := []byte(req.GetCsr())
+ block, _ := pem.Decode(csrPEM)
+ if block == nil {
+ return nil, fmt.Errorf("failed to decode received CSR: %v", csrPEM)
+ }
+ csr, err := x509.ParseCertificateRequest(block.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse received CSR: %v", csrPEM)
+ }
+
+ // By default, we create certs which are valid for an hour. But if
+ // `withShortLife` is set, we create certs which are valid only for a couple
+ // of seconds.
+ now := time.Now()
+ notBefore, notAfter := now.Add(-defaultTestCertLife), now.Add(defaultTestCertLife)
+ if f.withShortLife {
+ notBefore, notAfter = now.Add(-shortTestCertLife), now.Add(shortTestCertLife)
+ }
+ tmpl := &x509.Certificate{
+ Subject: pkix.Name{CommonName: "signed-cert"},
+ SerialNumber: big.NewInt(10),
+ NotBefore: notBefore,
+ NotAfter: notAfter,
+ KeyUsage: x509.KeyUsageDigitalSignature,
+ }
+ certDER, err := x509.CreateCertificate(rand.Reader, tmpl, f.cert, csr.PublicKey, f.key)
+ if err != nil {
+ return nil, fmt.Errorf("x509.CreateCertificate(%v) failed: %v", tmpl, err)
+ }
+ certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
+
+ // Push to ccChan to indicate that the RPC is processed.
+ f.ccChan.Send(nil)
+
+ certChain := []string{
+ string(certPEM), // Signed certificate corresponding to CSR
+ string(f.certPEM), // Root certificate
+ }
+ return &meshpb.MeshCertificateResponse{CertChain: certChain}, nil
+}
+
+// opts wraps the options to be passed to setup.
+type opts struct {
+ // Whether the CA returns certs with short lifetime. Used to test client refresh.
+ withShortLife bool
+ // Whether the CA returns errors to begin with. Used to test client backoff.
+ withbackoff bool
+}
+
+// events wraps channels which indicate different events.
+type events struct {
+ // Pushed to when the plugin dials the MeshCA.
+ dialDone *testutils.Channel
+ // Pushed to when CreateCertifcate() succeeds on the MeshCA.
+ createCertDone *testutils.Channel
+ // Pushed to when the plugin updates the distributor with new key material.
+ keyMaterialDone *testutils.Channel
+ // Pushed to when the client backs off after a failed CreateCertificate().
+ backoffDone *testutils.Channel
+}
+
+// setup performs tasks common to all tests in this file.
+func setup(t *testing.T, o opts) (events, string, func()) {
+ t.Helper()
+
+ // Create a fake MeshCA which pushes events on the passed channel for
+ // successful RPCs.
+ createCertDone := testutils.NewChannel()
+ fs, err := newFakeMeshCA(createCertDone, o.withbackoff, o.withShortLife)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create a gRPC server and register the fake MeshCA on it.
+ server := grpc.NewServer()
+ meshgrpc.RegisterMeshCertificateServiceServer(server, fs)
+
+ // Start a net.Listener on a local port, and pass it to the gRPC server
+ // created above and start serving.
+ lis, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ addr := lis.Addr().String()
+ go server.Serve(lis)
+
+ // Override the plugin's dial function and perform a blocking dial. Also
+ // push on dialDone once the dial is complete so that test can block on this
+ // event before verifying other things.
+ dialDone := testutils.NewChannel()
+ origDialFunc := grpcDialFunc
+ grpcDialFunc = func(uri string, _ ...grpc.DialOption) (*grpc.ClientConn, error) {
+ if uri != addr {
+ t.Fatalf("plugin dialing MeshCA at %s, want %s", uri, addr)
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ cc, err := grpc.DialContext(ctx, uri, grpc.WithInsecure(), grpc.WithBlock())
+ if err != nil {
+ t.Fatalf("grpc.DialContext(%s) failed: %v", addr, err)
+ }
+ dialDone.Send(nil)
+ return cc, nil
+ }
+
+ // Override the plugin's newDistributorFunc and return a wrappedDistributor
+ // which allows the test to be notified whenever the plugin pushes new key
+ // material into the distributor.
+ origDistributorFunc := newDistributorFunc
+ keyMaterialDone := testutils.NewChannel()
+ d := newWrappedDistributor(keyMaterialDone)
+ newDistributorFunc = func() distributor { return d }
+
+ // Override the plugin's backoff function to perform no real backoff, but
+ // push on a channel so that the test can verifiy that backoff actually
+ // happened.
+ backoffDone := testutils.NewChannelWithSize(maxErrCount)
+ origBackoffFunc := backoffFunc
+ if o.withbackoff {
+ // Override the plugin's backoff function with this, so that we can verify
+ // that a backoff actually was triggered.
+ backoffFunc = func(v int) time.Duration {
+ backoffDone.Send(v)
+ return 0
+ }
+ }
+
+ // Return all the channels, and a cancel function to undo all the overrides.
+ e := events{
+ dialDone: dialDone,
+ createCertDone: createCertDone,
+ keyMaterialDone: keyMaterialDone,
+ backoffDone: backoffDone,
+ }
+ done := func() {
+ server.Stop()
+ grpcDialFunc = origDialFunc
+ newDistributorFunc = origDistributorFunc
+ backoffFunc = origBackoffFunc
+ }
+ return e, addr, done
+}
+
+// wrappedDistributor wraps a distributor and pushes on a channel whenever new
+// key material is pushed to the distributor.
+type wrappedDistributor struct {
+ *certprovider.Distributor
+ kmChan *testutils.Channel
+}
+
+func newWrappedDistributor(kmChan *testutils.Channel) *wrappedDistributor {
+ return &wrappedDistributor{
+ kmChan: kmChan,
+ Distributor: certprovider.NewDistributor(),
+ }
+}
+
+func (wd *wrappedDistributor) Set(km *certprovider.KeyMaterial, err error) {
+ wd.Distributor.Set(km, err)
+ wd.kmChan.Send(nil)
+}
+
+// TestCreateCertificate verifies the simple case where the MeshCA server
+// returns a good certificate.
+func (s) TestCreateCertificate(t *testing.T) {
+ e, addr, cancel := setup(t, opts{})
+ defer cancel()
+
+ // Set the MeshCA targetURI in the plugin configuration to point to our fake
+ // MeshCA.
+ cfg := proto.Clone(goodConfigFullySpecified).(*configpb.GoogleMeshCaConfig)
+ cfg.Server.GrpcServices[0].GetGoogleGrpc().TargetUri = addr
+ inputConfig := makeJSONConfig(t, cfg)
+ prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.Options{})
+ if err != nil {
+ t.Fatalf("certprovider.GetProvider(%s, %s) failed: %v", pluginName, cfg, err)
+ }
+ defer prov.Close()
+
+ // Wait till the plugin dials the MeshCA server.
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ if _, err := e.dialDone.Receive(ctx); err != nil {
+ t.Fatal("timeout waiting for plugin to dial MeshCA")
+ }
+
+ // Wait till the plugin makes a CreateCertificate() call.
+ ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ if _, err := e.createCertDone.Receive(ctx); err != nil {
+ t.Fatal("timeout waiting for plugin to make CreateCertificate RPC")
+ }
+
+ // We don't really care about the exact key material returned here. All we
+ // care about is whether we get any key material at all, and that we don't
+ // get any errors.
+ ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ if _, err = prov.KeyMaterial(ctx); err != nil {
+ t.Fatalf("provider.KeyMaterial(ctx) failed: %v", err)
+ }
+}
+
+// TestCreateCertificateWithBackoff verifies the case where the MeshCA server
+// returns errors initially and then returns a good certificate. The test makes
+// sure that the client backs off when the server returns errors.
+func (s) TestCreateCertificateWithBackoff(t *testing.T) {
+ e, addr, cancel := setup(t, opts{withbackoff: true})
+ defer cancel()
+
+ // Set the MeshCA targetURI in the plugin configuration to point to our fake
+ // MeshCA.
+ cfg := proto.Clone(goodConfigFullySpecified).(*configpb.GoogleMeshCaConfig)
+ cfg.Server.GrpcServices[0].GetGoogleGrpc().TargetUri = addr
+ inputConfig := makeJSONConfig(t, cfg)
+ prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.Options{})
+ if err != nil {
+ t.Fatalf("certprovider.GetProvider(%s, %s) failed: %v", pluginName, cfg, err)
+ }
+ defer prov.Close()
+
+ // Wait till the plugin dials the MeshCA server.
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ if _, err := e.dialDone.Receive(ctx); err != nil {
+ t.Fatal("timeout waiting for plugin to dial MeshCA")
+ }
+
+ // Making the CreateCertificateRPC involves generating the keys, creating
+ // the CSR etc which seem to take reasonable amount of time. And in this
+ // test, the first two attempts will fail. Hence we give it a reasonable
+ // deadline here.
+ ctx, cancel = context.WithTimeout(context.Background(), 3*defaultTestTimeout)
+ defer cancel()
+ if _, err := e.createCertDone.Receive(ctx); err != nil {
+ t.Fatal("timeout waiting for plugin to make CreateCertificate RPC")
+ }
+
+ // The first `maxErrCount` calls to CreateCertificate end in failure, and
+ // should lead to a backoff.
+ for i := 0; i < maxErrCount; i++ {
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ if _, err := e.backoffDone.Receive(ctx); err != nil {
+ t.Fatalf("plugin failed to backoff after error from fake server: %v", err)
+ }
+ }
+
+ // We don't really care about the exact key material returned here. All we
+ // care about is whether we get any key material at all, and that we don't
+ // get any errors.
+ ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ if _, err = prov.KeyMaterial(ctx); err != nil {
+ t.Fatalf("provider.KeyMaterial(ctx) failed: %v", err)
+ }
+}
+
+// TestCreateCertificateWithRefresh verifies the case where the MeshCA returns a
+// certificate with a really short lifetime, and makes sure that the plugin
+// refreshes the cert in time.
+func (s) TestCreateCertificateWithRefresh(t *testing.T) {
+ e, addr, cancel := setup(t, opts{withShortLife: true})
+ defer cancel()
+
+ // Set the MeshCA targetURI in the plugin configuration to point to our fake
+ // MeshCA.
+ cfg := proto.Clone(goodConfigFullySpecified).(*configpb.GoogleMeshCaConfig)
+ cfg.Server.GrpcServices[0].GetGoogleGrpc().TargetUri = addr
+ inputConfig := makeJSONConfig(t, cfg)
+ prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.Options{})
+ if err != nil {
+ t.Fatalf("certprovider.GetProvider(%s, %s) failed: %v", pluginName, cfg, err)
+ }
+ defer prov.Close()
+
+ // Wait till the plugin dials the MeshCA server.
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ if _, err := e.dialDone.Receive(ctx); err != nil {
+ t.Fatal("timeout waiting for plugin to dial MeshCA")
+ }
+
+ // Wait till the plugin makes a CreateCertificate() call.
+ ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ if _, err := e.createCertDone.Receive(ctx); err != nil {
+ t.Fatal("timeout waiting for plugin to make CreateCertificate RPC")
+ }
+
+ ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ km1, err := prov.KeyMaterial(ctx)
+ if err != nil {
+ t.Fatalf("provider.KeyMaterial(ctx) failed: %v", err)
+ }
+
+ // At this point, we have read the first key material, and since the
+ // returned key material has a really short validity period, we expect the
+ // key material to be refreshed quite soon. We drain the channel on which
+ // the event corresponding to setting of new key material is pushed. This
+ // enables us to block on the same channel, waiting for refreshed key
+ // material.
+ // Since we do not expect this call to block, it is OK to pass the
+ // background context.
+ e.keyMaterialDone.Receive(context.Background())
+
+ // Wait for the next call to CreateCertificate() to refresh the certificate
+ // returned earlier.
+ ctx, cancel = context.WithTimeout(context.Background(), 2*shortTestCertLife)
+ defer cancel()
+ if _, err := e.keyMaterialDone.Receive(ctx); err != nil {
+ t.Fatalf("CreateCertificate() RPC not made: %v", err)
+ }
+
+ ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ km2, err := prov.KeyMaterial(ctx)
+ if err != nil {
+ t.Fatalf("provider.KeyMaterial(ctx) failed: %v", err)
+ }
+
+ // TODO(easwars): Remove all references to reflect.DeepEqual and use
+ // cmp.Equal instead. Currently, the later panics because x509.Certificate
+ // type defines an Equal method, but does not check for nil. This has been
+ // fixed in
+ // https://github.com/golang/go/commit/89865f8ba64ccb27f439cce6daaa37c9aa38f351,
+ // but this is only available starting go1.14. So, once we remove support
+ // for go1.13, we can make the switch.
+ if reflect.DeepEqual(km1, km2) {
+ t.Error("certificate refresh did not happen in the background")
+ }
+}
diff --git a/examples/go.sum b/examples/go.sum
index 1ef6201..fae9a83 100644
--- a/examples/go.sum
+++ b/examples/go.sum
@@ -93,6 +93,7 @@
github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
+github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
diff --git a/go.mod b/go.mod
index 31f2b01..0bcae73 100644
--- a/go.mod
+++ b/go.mod
@@ -8,6 +8,7 @@
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b
github.com/golang/protobuf v1.3.3
github.com/google/go-cmp v0.4.0
+ github.com/google/uuid v1.1.2
golang.org/x/net v0.0.0-20190311183353-d8887717615a
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a
diff --git a/go.sum b/go.sum
index be8078e..bab616e 100644
--- a/go.sum
+++ b/go.sum
@@ -25,6 +25,8 @@
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
+github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
diff --git a/internal/grpctest/tlogger.go b/internal/grpctest/tlogger.go
index a074fbf..95c3598 100644
--- a/internal/grpctest/tlogger.go
+++ b/internal/grpctest/tlogger.go
@@ -26,7 +26,6 @@
"regexp"
"runtime"
"strconv"
- "strings"
"sync"
"testing"
"time"
@@ -51,6 +50,7 @@
type tLogger struct {
v int
t *testing.T
+ start time.Time
initialized bool
m sync.Mutex // protects errors
@@ -74,14 +74,6 @@
return fmt.Sprintf("%s:%d", path.Base(file), line), nil
}
-// Returns the last component of the stringified current time, which is of the
-// format "m=±<value>", where value is the monotonic clock reading formatted as
-// a decimal number of seconds.
-func getTimeSuffix() string {
- parts := strings.Split(time.Now().String(), " ")
- return fmt.Sprintf(" (%s)", parts[len(parts)-1])
-}
-
// log logs the message with the specified parameters to the tLogger.
func (g *tLogger) log(ltype logType, depth int, format string, args ...interface{}) {
prefix, err := getCallingPrefix(callingFrame + depth)
@@ -90,7 +82,7 @@
return
}
args = append([]interface{}{prefix}, args...)
- args = append(args, getTimeSuffix())
+ args = append(args, fmt.Sprintf(" (t=+%s)", time.Since(g.start)))
if format == "" {
switch ltype {
@@ -107,7 +99,8 @@
g.t.Log(args...)
}
} else {
- format = "%v " + format
+ // Add formatting directives for the callingPrefix and timeSuffix.
+ format = "%v " + format + "%s"
switch ltype {
case errorLog:
if g.expected(fmt.Sprintf(format, args...)) {
@@ -131,6 +124,7 @@
g.initialized = true
}
g.t = t
+ g.start = time.Now()
g.m.Lock()
defer g.m.Unlock()
g.errors = map[*regexp.Regexp]int{}
diff --git a/internal/leakcheck/leakcheck.go b/internal/leakcheck/leakcheck.go
index 946c575..1d4fcef 100644
--- a/internal/leakcheck/leakcheck.go
+++ b/internal/leakcheck/leakcheck.go
@@ -42,6 +42,7 @@
"runtime_mcall",
"(*loggingT).flushDaemon",
"goroutine in C code",
+ "httputil.DumpRequestOut", // TODO: Remove this once Go1.13 support is removed. https://github.com/golang/go/issues/37669.
}
// RegisterIgnoreGoroutine appends s into the ignore goroutine list. The
diff --git a/internal/testutils/http_client.go b/internal/testutils/http_client.go
new file mode 100644
index 0000000..9832bf3
--- /dev/null
+++ b/internal/testutils/http_client.go
@@ -0,0 +1,63 @@
+/*
+ *
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package testutils
+
+import (
+ "context"
+ "net/http"
+ "time"
+)
+
+// DefaultHTTPRequestTimeout is the default timeout value for the amount of time
+// this client waits for a response to be pushed on RespChan before it fails the
+// Do() call.
+const DefaultHTTPRequestTimeout = 1 * time.Second
+
+// FakeHTTPClient helps mock out HTTP calls made by the code under test. It
+// makes HTTP requests made by the code under test available through a channel,
+// and makes it possible to inject various responses.
+type FakeHTTPClient struct {
+ // ReqChan exposes the HTTP.Request made by the code under test.
+ ReqChan *Channel
+ // RespChan is a channel on which this fake client accepts responses to be
+ // sent to the code under test.
+ RespChan *Channel
+ // Err, if set, is returned by Do().
+ Err error
+ // RecvTimeout is the amount of the time this client waits for a response to
+ // be pushed on RespChan before it fails the Do() call. If this field is
+ // left unspecified, DefaultHTTPRequestTimeout is used.
+ RecvTimeout time.Duration
+}
+
+// Do pushes req on ReqChan and returns the response available on RespChan.
+func (fc *FakeHTTPClient) Do(req *http.Request) (*http.Response, error) {
+ fc.ReqChan.Send(req)
+
+ timeout := fc.RecvTimeout
+ if timeout == 0 {
+ timeout = DefaultHTTPRequestTimeout
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ val, err := fc.RespChan.Receive(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return val.(*http.Response), fc.Err
+}
diff --git a/security/advancedtls/go.sum b/security/advancedtls/go.sum
index b1759c0..f2ab78d 100644
--- a/security/advancedtls/go.sum
+++ b/security/advancedtls/go.sum
@@ -87,6 +87,7 @@
github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
+github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=