| /* |
| * |
| * Copyright 2020 gRPC authors. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * |
| */ |
| |
| package xds |
| |
| import ( |
| "context" |
| "crypto/tls" |
| "crypto/x509" |
| "errors" |
| "fmt" |
| "io/ioutil" |
| "net" |
| "strings" |
| "testing" |
| "time" |
| |
| "google.golang.org/grpc/credentials" |
| "google.golang.org/grpc/credentials/tls/certprovider" |
| "google.golang.org/grpc/internal" |
| xdsinternal "google.golang.org/grpc/internal/credentials/xds" |
| "google.golang.org/grpc/internal/grpctest" |
| "google.golang.org/grpc/internal/testutils" |
| "google.golang.org/grpc/resolver" |
| "google.golang.org/grpc/testdata" |
| ) |
| |
| const ( |
| defaultTestTimeout = 10 * time.Second |
| defaultTestShortTimeout = 10 * time.Millisecond |
| defaultTestCertSAN = "*.test.example.com" |
| authority = "authority" |
| ) |
| |
| type s struct { |
| grpctest.Tester |
| } |
| |
| func Test(t *testing.T) { |
| grpctest.RunSubTests(t, s{}) |
| } |
| |
| // Helper function to create a real TLS client credentials which is used as |
| // fallback credentials from multiple tests. |
| func makeFallbackClientCreds(t *testing.T) credentials.TransportCredentials { |
| creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com") |
| if err != nil { |
| t.Fatal(err) |
| } |
| return creds |
| } |
| |
| // testServer is a no-op server which listens on a local TCP port for incoming |
| // connections, and performs a manual TLS handshake on the received raw |
| // connection using a user specified handshake function. It then makes the |
| // result of the handshake operation available through a channel for tests to |
| // inspect. Tests should stop the testServer as part of their cleanup. |
| type testServer struct { |
| lis net.Listener |
| address string // Listening address of the test server. |
| handshakeFunc testHandshakeFunc // Test specified handshake function. |
| hsResult *testutils.Channel // Channel to deliver handshake results. |
| } |
| |
| // handshakeResult wraps the result of the handshake operation on the test |
| // server. It consists of TLS connection state and an error, if the handshake |
| // failed. This result is delivered on the `hsResult` channel on the testServer. |
| type handshakeResult struct { |
| connState tls.ConnectionState |
| err error |
| } |
| |
| // Configurable handshake function for the testServer. Tests can set this to |
| // simulate different conditions like handshake success, failure, timeout etc. |
| type testHandshakeFunc func(net.Conn) handshakeResult |
| |
| // newTestServerWithHandshakeFunc starts a new testServer which listens for |
| // connections on a local TCP port, and uses the provided custom handshake |
| // function to perform TLS handshake. |
| func newTestServerWithHandshakeFunc(f testHandshakeFunc) *testServer { |
| ts := &testServer{ |
| handshakeFunc: f, |
| hsResult: testutils.NewChannel(), |
| } |
| ts.start() |
| return ts |
| } |
| |
| // starts actually starts listening on a local TCP port, and spawns a goroutine |
| // to handle new connections. |
| func (ts *testServer) start() error { |
| lis, err := net.Listen("tcp", "localhost:0") |
| if err != nil { |
| return err |
| } |
| ts.lis = lis |
| ts.address = lis.Addr().String() |
| go ts.handleConn() |
| return nil |
| } |
| |
| // handleconn accepts a new raw connection, and invokes the test provided |
| // handshake function to perform TLS handshake, and returns the result on the |
| // `hsResult` channel. |
| func (ts *testServer) handleConn() { |
| for { |
| rawConn, err := ts.lis.Accept() |
| if err != nil { |
| // Once the listeners closed, Accept() will return with an error. |
| return |
| } |
| hsr := ts.handshakeFunc(rawConn) |
| ts.hsResult.Send(hsr) |
| } |
| } |
| |
| // stop closes the associated listener which causes the connection handling |
| // goroutine to exit. |
| func (ts *testServer) stop() { |
| ts.lis.Close() |
| } |
| |
| // A handshake function which simulates a successful handshake without client |
| // authentication (server does not request for client certificate during the |
| // handshake here). |
| func testServerTLSHandshake(rawConn net.Conn) handshakeResult { |
| cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) |
| if err != nil { |
| return handshakeResult{err: err} |
| } |
| cfg := &tls.Config{Certificates: []tls.Certificate{cert}} |
| conn := tls.Server(rawConn, cfg) |
| if err := conn.Handshake(); err != nil { |
| return handshakeResult{err: err} |
| } |
| return handshakeResult{connState: conn.ConnectionState()} |
| } |
| |
| // A handshake function which simulates a successful handshake with mutual |
| // authentication. |
| func testServerMutualTLSHandshake(rawConn net.Conn) handshakeResult { |
| cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) |
| if err != nil { |
| return handshakeResult{err: err} |
| } |
| pemData, err := ioutil.ReadFile(testdata.Path("x509/client_ca_cert.pem")) |
| if err != nil { |
| return handshakeResult{err: err} |
| } |
| roots := x509.NewCertPool() |
| roots.AppendCertsFromPEM(pemData) |
| cfg := &tls.Config{ |
| Certificates: []tls.Certificate{cert}, |
| ClientCAs: roots, |
| } |
| conn := tls.Server(rawConn, cfg) |
| if err := conn.Handshake(); err != nil { |
| return handshakeResult{err: err} |
| } |
| return handshakeResult{connState: conn.ConnectionState()} |
| } |
| |
| // fakeProvider is an implementation of the certprovider.Provider interface |
| // which returns the configured key material and error in calls to |
| // KeyMaterial(). |
| type fakeProvider struct { |
| km *certprovider.KeyMaterial |
| err error |
| } |
| |
| func (f *fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) { |
| return f.km, f.err |
| } |
| |
| func (f *fakeProvider) Close() {} |
| |
| // makeIdentityProvider creates a new instance of the fakeProvider returning the |
| // identity key material specified in the provider file paths. |
| func makeIdentityProvider(t *testing.T, certPath, keyPath string) certprovider.Provider { |
| t.Helper() |
| cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath)) |
| if err != nil { |
| t.Fatal(err) |
| } |
| return &fakeProvider{km: &certprovider.KeyMaterial{Certs: []tls.Certificate{cert}}} |
| } |
| |
| // makeRootProvider creates a new instance of the fakeProvider returning the |
| // root key material specified in the provider file paths. |
| func makeRootProvider(t *testing.T, caPath string) *fakeProvider { |
| pemData, err := ioutil.ReadFile(testdata.Path(caPath)) |
| if err != nil { |
| t.Fatal(err) |
| } |
| roots := x509.NewCertPool() |
| roots.AppendCertsFromPEM(pemData) |
| return &fakeProvider{km: &certprovider.KeyMaterial{Roots: roots}} |
| } |
| |
| // newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo |
| // context value added to it. |
| func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sans ...string) context.Context { |
| // Creating the HandshakeInfo and adding it to the attributes is very |
| // similar to what the CDS balancer would do when it intercepts calls to |
| // NewSubConn(). |
| info := xdsinternal.NewHandshakeInfo(root, identity, sans...) |
| addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info) |
| |
| // Moving the attributes from the resolver.Address to the context passed to |
| // the handshaker is done in the transport layer. Since we directly call the |
| // handshaker in these tests, we need to do the same here. |
| contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) |
| return contextWithHandshakeInfo(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) |
| } |
| |
| // compareAuthInfo compares the AuthInfo received on the client side after a |
| // successful handshake with the authInfo available on the testServer. |
| func compareAuthInfo(ctx context.Context, ts *testServer, ai credentials.AuthInfo) error { |
| if ai.AuthType() != "tls" { |
| return fmt.Errorf("ClientHandshake returned authType %q, want %q", ai.AuthType(), "tls") |
| } |
| info, ok := ai.(credentials.TLSInfo) |
| if !ok { |
| return fmt.Errorf("ClientHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{}) |
| } |
| gotState := info.State |
| |
| // Read the handshake result from the testServer which contains the TLS |
| // connection state and compare it with the one received on the client-side. |
| val, err := ts.hsResult.Receive(ctx) |
| if err != nil { |
| return fmt.Errorf("testServer failed to return handshake result: %v", err) |
| } |
| hsr := val.(handshakeResult) |
| if hsr.err != nil { |
| return fmt.Errorf("testServer handshake failure: %v", hsr.err) |
| } |
| // AuthInfo contains a variety of information. We only verify a subset here. |
| // This is the same subset which is verified in TLS credentials tests. |
| if err := compareConnState(gotState, hsr.connState); err != nil { |
| return err |
| } |
| return nil |
| } |
| |
| func compareConnState(got, want tls.ConnectionState) error { |
| switch { |
| case got.Version != want.Version: |
| return fmt.Errorf("TLS.ConnectionState got Version: %v, want: %v", got.Version, want.Version) |
| case got.HandshakeComplete != want.HandshakeComplete: |
| return fmt.Errorf("TLS.ConnectionState got HandshakeComplete: %v, want: %v", got.HandshakeComplete, want.HandshakeComplete) |
| case got.CipherSuite != want.CipherSuite: |
| return fmt.Errorf("TLS.ConnectionState got CipherSuite: %v, want: %v", got.CipherSuite, want.CipherSuite) |
| case got.NegotiatedProtocol != want.NegotiatedProtocol: |
| return fmt.Errorf("TLS.ConnectionState got NegotiatedProtocol: %v, want: %v", got.NegotiatedProtocol, want.NegotiatedProtocol) |
| } |
| return nil |
| } |
| |
| // TestClientCredsWithoutFallback verifies that the call to |
| // NewClientCredentials() fails when no fallback is specified. |
| func (s) TestClientCredsWithoutFallback(t *testing.T) { |
| if _, err := NewClientCredentials(ClientOptions{}); err == nil { |
| t.Fatal("NewClientCredentials() succeeded without specifying fallback") |
| } |
| } |
| |
| // TestClientCredsInvalidHandshakeInfo verifies scenarios where the passed in |
| // HandshakeInfo is invalid because it does not contain the expected certificate |
| // providers. |
| func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) { |
| opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} |
| creds, err := NewClientCredentials(opts) |
| if err != nil { |
| t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) |
| } |
| |
| pCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) |
| defer cancel() |
| ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}) |
| if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil { |
| t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo") |
| } |
| } |
| |
| // TestClientCredsProviderFailure verifies the cases where an expected |
| // certificate provider is missing in the HandshakeInfo value in the context. |
| func (s) TestClientCredsProviderFailure(t *testing.T) { |
| opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} |
| creds, err := NewClientCredentials(opts) |
| if err != nil { |
| t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) |
| } |
| |
| tests := []struct { |
| desc string |
| rootProvider certprovider.Provider |
| identityProvider certprovider.Provider |
| wantErr string |
| }{ |
| { |
| desc: "erroring root provider", |
| rootProvider: &fakeProvider{err: errors.New("root provider error")}, |
| wantErr: "root provider error", |
| }, |
| { |
| desc: "erroring identity provider", |
| rootProvider: &fakeProvider{km: &certprovider.KeyMaterial{}}, |
| identityProvider: &fakeProvider{err: errors.New("identity provider error")}, |
| wantErr: "identity provider error", |
| }, |
| } |
| for _, test := range tests { |
| t.Run(test.desc, func(t *testing.T) { |
| ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) |
| defer cancel() |
| ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider) |
| if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) { |
| t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr) |
| } |
| }) |
| } |
| } |
| |
| // TestClientCredsSuccess verifies successful client handshake cases. |
| func (s) TestClientCredsSuccess(t *testing.T) { |
| tests := []struct { |
| desc string |
| handshakeFunc testHandshakeFunc |
| handshakeInfoCtx func(ctx context.Context) context.Context |
| }{ |
| { |
| desc: "fallback", |
| handshakeFunc: testServerTLSHandshake, |
| handshakeInfoCtx: func(ctx context.Context) context.Context { |
| // Since we don't add a HandshakeInfo to the context, the |
| // ClientHandshake() method will delegate to the fallback. |
| return ctx |
| }, |
| }, |
| { |
| desc: "TLS", |
| handshakeFunc: testServerTLSHandshake, |
| handshakeInfoCtx: func(ctx context.Context) context.Context { |
| return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN) |
| }, |
| }, |
| { |
| desc: "mTLS", |
| handshakeFunc: testServerMutualTLSHandshake, |
| handshakeInfoCtx: func(ctx context.Context) context.Context { |
| return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), defaultTestCertSAN) |
| }, |
| }, |
| { |
| desc: "mTLS with no acceptedSANs specified", |
| handshakeFunc: testServerMutualTLSHandshake, |
| handshakeInfoCtx: func(ctx context.Context) context.Context { |
| return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem")) |
| }, |
| }, |
| } |
| |
| for _, test := range tests { |
| t.Run(test.desc, func(t *testing.T) { |
| ts := newTestServerWithHandshakeFunc(test.handshakeFunc) |
| defer ts.stop() |
| |
| opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} |
| creds, err := NewClientCredentials(opts) |
| if err != nil { |
| t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) |
| } |
| |
| conn, err := net.Dial("tcp", ts.address) |
| if err != nil { |
| t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) |
| } |
| defer conn.Close() |
| |
| ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) |
| defer cancel() |
| _, ai, err := creds.ClientHandshake(test.handshakeInfoCtx(ctx), authority, conn) |
| if err != nil { |
| t.Fatalf("ClientHandshake() returned failed: %q", err) |
| } |
| if err := compareAuthInfo(ctx, ts, ai); err != nil { |
| t.Fatal(err) |
| } |
| }) |
| } |
| } |
| |
| func (s) TestClientCredsHandshakeTimeout(t *testing.T) { |
| clientDone := make(chan struct{}) |
| // A handshake function which simulates a handshake timeout from the |
| // server-side by simply blocking on the client-side handshake to timeout |
| // and not writing any handshake data. |
| hErr := errors.New("server handshake error") |
| ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { |
| <-clientDone |
| return handshakeResult{err: hErr} |
| }) |
| defer ts.stop() |
| |
| opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} |
| creds, err := NewClientCredentials(opts) |
| if err != nil { |
| t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) |
| } |
| |
| conn, err := net.Dial("tcp", ts.address) |
| if err != nil { |
| t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) |
| } |
| defer conn.Close() |
| |
| sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) |
| defer sCancel() |
| ctx := newTestContextWithHandshakeInfo(sCtx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN) |
| if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil { |
| t.Fatal("ClientHandshake() succeeded when expected to timeout") |
| } |
| close(clientDone) |
| |
| // Read the handshake result from the testServer and make sure the expected |
| // error is returned. |
| ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) |
| defer cancel() |
| val, err := ts.hsResult.Receive(ctx) |
| if err != nil { |
| t.Fatalf("testServer failed to return handshake result: %v", err) |
| } |
| hsr := val.(handshakeResult) |
| if hsr.err != hErr { |
| t.Fatalf("testServer handshake returned error: %v, want: %v", hsr.err, hErr) |
| } |
| } |
| |
| // TestClientCredsHandshakeFailure verifies different handshake failure cases. |
| func (s) TestClientCredsHandshakeFailure(t *testing.T) { |
| tests := []struct { |
| desc string |
| handshakeFunc testHandshakeFunc |
| rootProvider certprovider.Provider |
| san string |
| wantErr string |
| }{ |
| { |
| desc: "cert validation failure", |
| handshakeFunc: testServerTLSHandshake, |
| rootProvider: makeRootProvider(t, "x509/client_ca_cert.pem"), |
| san: defaultTestCertSAN, |
| wantErr: "x509: certificate signed by unknown authority", |
| }, |
| { |
| desc: "SAN mismatch", |
| handshakeFunc: testServerTLSHandshake, |
| rootProvider: makeRootProvider(t, "x509/server_ca_cert.pem"), |
| san: "bad-san", |
| wantErr: "does not match any of the accepted SANs", |
| }, |
| } |
| |
| for _, test := range tests { |
| t.Run(test.desc, func(t *testing.T) { |
| ts := newTestServerWithHandshakeFunc(test.handshakeFunc) |
| defer ts.stop() |
| |
| opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} |
| creds, err := NewClientCredentials(opts) |
| if err != nil { |
| t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) |
| } |
| |
| conn, err := net.Dial("tcp", ts.address) |
| if err != nil { |
| t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) |
| } |
| defer conn.Close() |
| |
| ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) |
| defer cancel() |
| ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, nil, test.san) |
| if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil || !strings.Contains(err.Error(), test.wantErr) { |
| t.Fatalf("ClientHandshake() returned %q, wantErr %q", err, test.wantErr) |
| } |
| }) |
| } |
| } |
| |
| // TestClientCredsProviderSwitch verifies the case where the first attempt of |
| // ClientHandshake fails because of a handshake failure. Then we update the |
| // certificate provider and the second attempt succeeds. This is an |
| // approximation of the flow of events when the control plane specifies new |
| // security config which results in new certificate providers being used. |
| func (s) TestClientCredsProviderSwitch(t *testing.T) { |
| ts := newTestServerWithHandshakeFunc(testServerTLSHandshake) |
| defer ts.stop() |
| |
| opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} |
| creds, err := NewClientCredentials(opts) |
| if err != nil { |
| t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) |
| } |
| |
| conn, err := net.Dial("tcp", ts.address) |
| if err != nil { |
| t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) |
| } |
| defer conn.Close() |
| |
| ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) |
| defer cancel() |
| // Create a root provider which will fail the handshake because it does not |
| // use the correct trust roots. |
| root1 := makeRootProvider(t, "x509/client_ca_cert.pem") |
| handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, defaultTestCertSAN) |
| |
| // We need to repeat most of what newTestContextWithHandshakeInfo() does |
| // here because we need access to the underlying HandshakeInfo so that we |
| // can update it before the next call to ClientHandshake(). |
| addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo) |
| contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) |
| ctx = contextWithHandshakeInfo(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) |
| if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil { |
| t.Fatal("ClientHandshake() succeeded when expected to fail") |
| } |
| // Drain the result channel on the test server so that we can inspect the |
| // result for the next handshake. |
| _, err = ts.hsResult.Receive(ctx) |
| if err != nil { |
| t.Errorf("testServer failed to return handshake result: %v", err) |
| } |
| |
| conn, err = net.Dial("tcp", ts.address) |
| if err != nil { |
| t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) |
| } |
| defer conn.Close() |
| |
| // Create a new root provider which uses the correct trust roots. And update |
| // the HandshakeInfo with the new provider. |
| root2 := makeRootProvider(t, "x509/server_ca_cert.pem") |
| handshakeInfo.SetRootCertProvider(root2) |
| _, ai, err := creds.ClientHandshake(ctx, authority, conn) |
| if err != nil { |
| t.Fatalf("ClientHandshake() returned failed: %q", err) |
| } |
| if err := compareAuthInfo(ctx, ts, ai); err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| // TestClientClone verifies the Clone() method on client credentials. |
| func (s) TestClientClone(t *testing.T) { |
| opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)} |
| orig, err := NewClientCredentials(opts) |
| if err != nil { |
| t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err) |
| } |
| |
| // The credsImpl does not have any exported fields, and it does not make |
| // sense to use any cmp options to look deep into. So, all we make sure here |
| // is that the cloned object points to a different location in memory. |
| if clone := orig.Clone(); clone == orig { |
| t.Fatal("return value from Clone() doesn't point to new credentials instance") |
| } |
| } |