advancedtls: clean up test files and shared code (#3897)
* advancedtls: clean up test files and shared code
diff --git a/security/advancedtls/advancedtls.go b/security/advancedtls/advancedtls.go
index 7456463..ea93d64 100644
--- a/security/advancedtls/advancedtls.go
+++ b/security/advancedtls/advancedtls.go
@@ -28,7 +28,6 @@
"fmt"
"net"
"reflect"
- "syscall"
"time"
"google.golang.org/grpc/credentials"
@@ -374,7 +373,7 @@
func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
// Use local cfg to avoid clobbering ServerName if using multiple endpoints.
- cfg := cloneTLSConfig(c.config)
+ cfg := credinternal.CloneTLSConfig(c.config)
// We return the full authority name to users if ServerName is empty without
// stripping the trailing port.
if cfg.ServerName == "" {
@@ -404,11 +403,11 @@
},
}
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
- return WrapSyscallConn(rawConn, conn), info, nil
+ return credinternal.WrapSyscallConn(rawConn, conn), info, nil
}
func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
- cfg := cloneTLSConfig(c.config)
+ cfg := credinternal.CloneTLSConfig(c.config)
cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn)
conn := tls.Server(rawConn, cfg)
if err := conn.Handshake(); err != nil {
@@ -422,12 +421,12 @@
},
}
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
- return WrapSyscallConn(rawConn, conn), info, nil
+ return credinternal.WrapSyscallConn(rawConn, conn), info, nil
}
func (c *advancedTLSCreds) Clone() credentials.TransportCredentials {
return &advancedTLSCreds{
- config: cloneTLSConfig(c.config),
+ config: credinternal.CloneTLSConfig(c.config),
verifyFunc: c.verifyFunc,
getRootCAs: c.getRootCAs,
isClient: c.isClient,
@@ -530,7 +529,7 @@
verifyFunc: o.VerifyPeer,
vType: o.VType,
}
- tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos)
+ tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
return tc, nil
}
@@ -548,64 +547,6 @@
verifyFunc: o.VerifyPeer,
vType: o.VType,
}
- tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos)
+ tc.config.NextProtos = credinternal.AppendH2ToNextProtos(tc.config.NextProtos)
return tc, nil
}
-
-// TODO(ZhenLian): The code below are duplicates with gRPC-Go under
-// credentials/internal. Consider refactoring in the future.
-const alpnProtoStrH2 = "h2"
-
-func appendH2ToNextProtos(ps []string) []string {
- for _, p := range ps {
- if p == alpnProtoStrH2 {
- return ps
- }
- }
- ret := make([]string, 0, len(ps)+1)
- ret = append(ret, ps...)
- return append(ret, alpnProtoStrH2)
-}
-
-// We give syscall.Conn a new name here since syscall.Conn and net.Conn used
-// below have the same names.
-type sysConn = syscall.Conn
-
-// syscallConn keeps reference of rawConn to support syscall.Conn for channelz.
-// SyscallConn() (the method in interface syscall.Conn) is explicitly
-// implemented on this type,
-//
-// Interface syscall.Conn is implemented by most net.Conn implementations (e.g.
-// TCPConn, UnixConn), but is not part of net.Conn interface. So wrapper conns
-// that embed net.Conn don't implement syscall.Conn. (Side note: tls.Conn
-// doesn't embed net.Conn, so even if syscall.Conn is part of net.Conn, it won't
-// help here).
-type syscallConn struct {
- net.Conn
- // sysConn is a type alias of syscall.Conn. It's necessary because the name
- // `Conn` collides with `net.Conn`.
- sysConn
-}
-
-// WrapSyscallConn tries to wrap rawConn and newConn into a net.Conn that
-// implements syscall.Conn. rawConn will be used to support syscall, and newConn
-// will be used for read/write.
-//
-// This function returns newConn if rawConn doesn't implement syscall.Conn.
-func WrapSyscallConn(rawConn, newConn net.Conn) net.Conn {
- sysConn, ok := rawConn.(syscall.Conn)
- if !ok {
- return newConn
- }
- return &syscallConn{
- Conn: newConn,
- sysConn: sysConn,
- }
-}
-
-func cloneTLSConfig(cfg *tls.Config) *tls.Config {
- if cfg == nil {
- return &tls.Config{}
- }
- return cfg.Clone()
-}
diff --git a/security/advancedtls/advancedtls_integration_test.go b/security/advancedtls/advancedtls_integration_test.go
index a95fa56..3f4e705 100644
--- a/security/advancedtls/advancedtls_integration_test.go
+++ b/security/advancedtls/advancedtls_integration_test.go
@@ -31,7 +31,7 @@
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
pb "google.golang.org/grpc/examples/helloworld/helloworld"
- "google.golang.org/grpc/security/advancedtls/testdata"
+ "google.golang.org/grpc/security/advancedtls/internal/testutils"
)
var (
@@ -67,69 +67,6 @@
s.stage = 0
}
-// certStore contains all the certificates used in the integration tests.
-type certStore struct {
- // clientPeer1 is the certificate sent by client to prove its identity.
- // It is trusted by serverTrust1.
- clientPeer1 tls.Certificate
- // clientPeer2 is the certificate sent by client to prove its identity.
- // It is trusted by serverTrust2.
- clientPeer2 tls.Certificate
- // serverPeer1 is the certificate sent by server to prove its identity.
- // It is trusted by clientTrust1.
- serverPeer1 tls.Certificate
- // serverPeer2 is the certificate sent by server to prove its identity.
- // It is trusted by clientTrust2.
- serverPeer2 tls.Certificate
- clientTrust1 *x509.CertPool
- clientTrust2 *x509.CertPool
- serverTrust1 *x509.CertPool
- serverTrust2 *x509.CertPool
-}
-
-// loadCerts function is used to load test certificates at the beginning of
-// each integration test.
-func (cs *certStore) loadCerts() error {
- var err error
- cs.clientPeer1, err = tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"),
- testdata.Path("client_key_1.pem"))
- if err != nil {
- return err
- }
- cs.clientPeer2, err = tls.LoadX509KeyPair(testdata.Path("client_cert_2.pem"),
- testdata.Path("client_key_2.pem"))
- if err != nil {
- return err
- }
- cs.serverPeer1, err = tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
- testdata.Path("server_key_1.pem"))
- if err != nil {
- return err
- }
- cs.serverPeer2, err = tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"),
- testdata.Path("server_key_2.pem"))
- if err != nil {
- return err
- }
- cs.clientTrust1, err = readTrustCert(testdata.Path("client_trust_cert_1.pem"))
- if err != nil {
- return err
- }
- cs.clientTrust2, err = readTrustCert(testdata.Path("client_trust_cert_2.pem"))
- if err != nil {
- return err
- }
- cs.serverTrust1, err = readTrustCert(testdata.Path("server_trust_cert_1.pem"))
- if err != nil {
- return err
- }
- cs.serverTrust2, err = readTrustCert(testdata.Path("server_trust_cert_2.pem"))
- if err != nil {
- return err
- }
- return nil
-}
-
type greeterServer struct {
pb.UnimplementedGreeterServer
}
@@ -183,10 +120,9 @@
// (could be change the client's trust certificate, or change custom
// verification function, etc)
func (s) TestEnd2End(t *testing.T) {
- cs := &certStore{}
- err := cs.loadCerts()
- if err != nil {
- t.Fatalf("failed to load certs: %v", err)
+ cs := &testutils.CertStore{}
+ if err := cs.LoadCerts(); err != nil {
+ t.Fatalf("cs.LoadCerts() failed, err: %v", err)
}
stage := &stageInfo{}
for _, test := range []struct {
@@ -206,38 +142,38 @@
}{
// Test Scenarios:
// At initialization(stage = 0), client will be initialized with cert
- // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
- // The mutual authentication works at the beginning, since clientPeer1 is
- // trusted by serverTrust1, and serverPeer1 by clientTrust1.
- // At stage 1, client changes clientPeer1 to clientPeer2. Since clientPeer2
- // is not trusted by serverTrust1, following rpc calls are expected to
+ // ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
+ // The mutual authentication works at the beginning, since ClientCert1 is
+ // trusted by ServerTrust1, and ServerCert1 by ClientTrust1.
+ // At stage 1, client changes ClientCert1 to ClientCert2. Since ClientCert2
+ // is not trusted by ServerTrust1, following rpc calls are expected to
// fail, while the previous rpc calls are still good because those are
// already authenticated.
- // At stage 2, the server changes serverTrust1 to serverTrust2, and we
- // should see it again accepts the connection, since clientPeer2 is trusted
- // by serverTrust2.
+ // At stage 2, the server changes ServerTrust1 to ServerTrust2, and we
+ // should see it again accepts the connection, since ClientCert2 is trusted
+ // by ServerTrust2.
{
desc: "TestClientPeerCertReloadServerTrustCertReload",
clientGetCert: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
switch stage.read() {
case 0:
- return &cs.clientPeer1, nil
+ return &cs.ClientCert1, nil
default:
- return &cs.clientPeer2, nil
+ return &cs.ClientCert2, nil
}
},
- clientRoot: cs.clientTrust1,
+ clientRoot: cs.ClientTrust1,
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
clientVType: CertVerification,
- serverCert: []tls.Certificate{cs.serverPeer1},
+ serverCert: []tls.Certificate{cs.ServerCert1},
serverGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
switch stage.read() {
case 0, 1:
- return &GetRootCAsResults{TrustCerts: cs.serverTrust1}, nil
+ return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil
default:
- return &GetRootCAsResults{TrustCerts: cs.serverTrust2}, nil
+ return &GetRootCAsResults{TrustCerts: cs.ServerTrust2}, nil
}
},
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
@@ -247,25 +183,25 @@
},
// Test Scenarios:
// At initialization(stage = 0), client will be initialized with cert
- // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
- // The mutual authentication works at the beginning, since clientPeer1 is
- // trusted by serverTrust1, and serverPeer1 by clientTrust1.
- // At stage 1, server changes serverPeer1 to serverPeer2. Since serverPeer2
- // is not trusted by clientTrust1, following rpc calls are expected to
+ // ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
+ // The mutual authentication works at the beginning, since ClientCert1 is
+ // trusted by ServerTrust1, and ServerCert1 by ClientTrust1.
+ // At stage 1, server changes ServerCert1 to ServerCert2. Since ServerCert2
+ // is not trusted by ClientTrust1, following rpc calls are expected to
// fail, while the previous rpc calls are still good because those are
// already authenticated.
- // At stage 2, the client changes clientTrust1 to clientTrust2, and we
- // should see it again accepts the connection, since serverPeer2 is trusted
- // by clientTrust2.
+ // At stage 2, the client changes ClientTrust1 to ClientTrust2, and we
+ // should see it again accepts the connection, since ServerCert2 is trusted
+ // by ClientTrust2.
{
desc: "TestServerPeerCertReloadClientTrustCertReload",
- clientCert: []tls.Certificate{cs.clientPeer1},
+ clientCert: []tls.Certificate{cs.ClientCert1},
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
switch stage.read() {
case 0, 1:
- return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil
+ return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
default:
- return &GetRootCAsResults{TrustCerts: cs.clientTrust2}, nil
+ return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
}
},
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
@@ -275,12 +211,12 @@
serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
switch stage.read() {
case 0:
- return []*tls.Certificate{&cs.serverPeer1}, nil
+ return []*tls.Certificate{&cs.ServerCert1}, nil
default:
- return []*tls.Certificate{&cs.serverPeer2}, nil
+ return []*tls.Certificate{&cs.ServerCert2}, nil
}
},
- serverRoot: cs.serverTrust1,
+ serverRoot: cs.ServerTrust1,
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
@@ -288,26 +224,26 @@
},
// Test Scenarios:
// At initialization(stage = 0), client will be initialized with cert
- // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
- // The mutual authentication works at the beginning, since clientPeer1
- // trusted by serverTrust1, serverPeer1 by clientTrust1, and also the
- // custom verification check allows the CommonName on serverPeer1.
- // At stage 1, server changes serverPeer1 to serverPeer2, and client
- // changes clientTrust1 to clientTrust2. Although serverPeer2 is trusted by
- // clientTrust2, our authorization check only accepts serverPeer1, and
+ // ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
+ // The mutual authentication works at the beginning, since ClientCert1
+ // trusted by ServerTrust1, ServerCert1 by ClientTrust1, and also the
+ // custom verification check allows the CommonName on ServerCert1.
+ // At stage 1, server changes ServerCert1 to ServerCert2, and client
+ // changes ClientTrust1 to ClientTrust2. Although ServerCert2 is trusted by
+ // ClientTrust2, our authorization check only accepts ServerCert1, and
// hence the following calls should fail. Previous connections should
// not be affected.
// At stage 2, the client changes authorization check to only accept
- // serverPeer2. Now we should see the connection becomes normal again.
+ // ServerCert2. Now we should see the connection becomes normal again.
{
desc: "TestClientCustomVerification",
- clientCert: []tls.Certificate{cs.clientPeer1},
+ clientCert: []tls.Certificate{cs.ClientCert1},
clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
switch stage.read() {
case 0:
- return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil
+ return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
default:
- return &GetRootCAsResults{TrustCerts: cs.clientTrust2}, nil
+ return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
}
},
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
@@ -321,12 +257,12 @@
authzCheck := false
switch stage.read() {
case 0, 1:
- // foo.bar.com is the common name on serverPeer1
+ // foo.bar.com is the common name on ServerCert1
if cert.Subject.CommonName == "foo.bar.com" {
authzCheck = true
}
default:
- // foo.bar.server2.com is the common name on serverPeer2
+ // foo.bar.server2.com is the common name on ServerCert2
if cert.Subject.CommonName == "foo.bar.server2.com" {
authzCheck = true
}
@@ -340,12 +276,12 @@
serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
switch stage.read() {
case 0:
- return []*tls.Certificate{&cs.serverPeer1}, nil
+ return []*tls.Certificate{&cs.ServerCert1}, nil
default:
- return []*tls.Certificate{&cs.serverPeer2}, nil
+ return []*tls.Certificate{&cs.ServerCert2}, nil
}
},
- serverRoot: cs.serverTrust1,
+ serverRoot: cs.ServerTrust1,
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
@@ -353,9 +289,9 @@
},
// Test Scenarios:
// At initialization(stage = 0), client will be initialized with cert
- // clientPeer1 and clientTrust1, server with serverPeer1 and serverTrust1.
- // The mutual authentication works at the beginning, since clientPeer1
- // trusted by serverTrust1, serverPeer1 by clientTrust1, and also the
+ // ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
+ // The mutual authentication works at the beginning, since ClientCert1
+ // trusted by ServerTrust1, ServerCert1 by ClientTrust1, and also the
// custom verification check on server side allows all connections.
// At stage 1, server disallows the the connections by setting custom
// verification check. The following calls should fail. Previous
@@ -364,14 +300,14 @@
// authentications should go back to normal.
{
desc: "TestServerCustomVerification",
- clientCert: []tls.Certificate{cs.clientPeer1},
- clientRoot: cs.clientTrust1,
+ clientCert: []tls.Certificate{cs.ClientCert1},
+ clientRoot: cs.ClientTrust1,
clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
},
clientVType: CertVerification,
- serverCert: []tls.Certificate{cs.serverPeer1},
- serverRoot: cs.serverTrust1,
+ serverCert: []tls.Certificate{cs.ServerCert1},
+ serverRoot: cs.ServerTrust1,
serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
switch stage.read() {
case 0, 2:
diff --git a/security/advancedtls/advancedtls_test.go b/security/advancedtls/advancedtls_test.go
index a631ee4..a7ecf27 100644
--- a/security/advancedtls/advancedtls_test.go
+++ b/security/advancedtls/advancedtls_test.go
@@ -22,21 +22,17 @@
"context"
"crypto/tls"
"crypto/x509"
- "encoding/pem"
"errors"
"fmt"
- "io/ioutil"
"math/big"
"net"
- "reflect"
- "syscall"
"testing"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider"
"google.golang.org/grpc/internal/grpctest"
- "google.golang.org/grpc/security/advancedtls/testdata"
+ "google.golang.org/grpc/security/advancedtls/internal/testutils"
)
type s struct {
@@ -65,27 +61,26 @@
if f.wantError {
return nil, fmt.Errorf("bad fakeProvider")
}
- cs := &certStore{}
- err := cs.loadCerts()
- if err != nil {
- return nil, fmt.Errorf("failed to load certs: %v", err)
+ cs := &testutils.CertStore{}
+ if err := cs.LoadCerts(); err != nil {
+ return nil, fmt.Errorf("cs.LoadCerts() failed, err: %v", err)
}
if f.pt == provTypeRoot && f.isClient {
- return &certprovider.KeyMaterial{Roots: cs.clientTrust1}, nil
+ return &certprovider.KeyMaterial{Roots: cs.ClientTrust1}, nil
}
if f.pt == provTypeRoot && !f.isClient {
- return &certprovider.KeyMaterial{Roots: cs.serverTrust1}, nil
+ return &certprovider.KeyMaterial{Roots: cs.ServerTrust1}, nil
}
if f.pt == provTypeIdentity && f.isClient {
if f.wantMultiCert {
- return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1, cs.clientPeer2}}, nil
+ return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1, cs.ClientCert2}}, nil
}
- return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}}, nil
+ return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}}, nil
}
if f.wantMultiCert {
- return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1, cs.serverPeer2}}, nil
+ return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ServerCert1, cs.ServerCert2}}, nil
}
- return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.serverPeer1}}, nil
+ return &certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ServerCert1}}, nil
}
func (f fakeProvider) Close() {}
@@ -308,13 +303,12 @@
}
func (s) TestClientServerHandshake(t *testing.T) {
- cs := &certStore{}
- err := cs.loadCerts()
- if err != nil {
- t.Fatalf("Failed to load certs: %v", err)
+ cs := &testutils.CertStore{}
+ if err := cs.LoadCerts(); err != nil {
+ t.Fatalf("cs.LoadCerts() failed, err: %v", err)
}
getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
- return &GetRootCAsResults{TrustCerts: cs.clientTrust1}, nil
+ return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
}
clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) {
if params.ServerName == "" {
@@ -331,7 +325,7 @@
return nil, fmt.Errorf("custom verification function failed")
}
getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
- return &GetRootCAsResults{TrustCerts: cs.serverTrust1}, nil
+ return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil
}
serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) {
if params.ServerName != "" {
@@ -378,7 +372,7 @@
desc: "Client has no trust cert with verifyFuncGood; server sends peer cert",
clientVerifyFunc: clientVerifyFuncGood,
clientVType: SkipVerification,
- serverCert: []tls.Certificate{cs.serverPeer1},
+ serverCert: []tls.Certificate{cs.ServerCert1},
serverVType: CertAndHostVerification,
},
// Client: only set clientRoot
@@ -389,10 +383,10 @@
// this test suites.
{
desc: "Client has root cert; server sends peer cert",
- clientRoot: cs.clientTrust1,
+ clientRoot: cs.ClientTrust1,
clientVType: CertAndHostVerification,
clientExpectHandshakeError: true,
- serverCert: []tls.Certificate{cs.serverPeer1},
+ serverCert: []tls.Certificate{cs.ServerCert1},
serverVType: CertAndHostVerification,
serverExpectError: true,
},
@@ -407,7 +401,7 @@
clientGetRoot: getRootCAsForClient,
clientVType: CertAndHostVerification,
clientExpectHandshakeError: true,
- serverCert: []tls.Certificate{cs.serverPeer1},
+ serverCert: []tls.Certificate{cs.ServerCert1},
serverVType: CertAndHostVerification,
serverExpectError: true,
},
@@ -419,7 +413,7 @@
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
- serverCert: []tls.Certificate{cs.serverPeer1},
+ serverCert: []tls.Certificate{cs.ServerCert1},
serverVType: CertAndHostVerification,
},
// Client: set clientGetRoot and bad clientVerifyFunc function
@@ -432,7 +426,7 @@
clientVerifyFunc: verifyFuncBad,
clientVType: CertVerification,
clientExpectHandshakeError: true,
- serverCert: []tls.Certificate{cs.serverPeer1},
+ serverCert: []tls.Certificate{cs.ServerCert1},
serverVType: CertVerification,
serverExpectError: true,
},
@@ -441,13 +435,13 @@
// Expected Behavior: success
{
desc: "Client sets peer cert, reload root function with verifyFuncGood; server sets peer cert and root cert; mutualTLS",
- clientCert: []tls.Certificate{cs.clientPeer1},
+ clientCert: []tls.Certificate{cs.ClientCert1},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
- serverCert: []tls.Certificate{cs.serverPeer1},
- serverRoot: cs.serverTrust1,
+ serverCert: []tls.Certificate{cs.ServerCert1},
+ serverRoot: cs.ServerTrust1,
serverVType: CertVerification,
},
// Client: set clientGetRoot, clientVerifyFunc and clientCert
@@ -455,12 +449,12 @@
// Expected Behavior: success
{
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS",
- clientCert: []tls.Certificate{cs.clientPeer1},
+ clientCert: []tls.Certificate{cs.ClientCert1},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
- serverCert: []tls.Certificate{cs.serverPeer1},
+ serverCert: []tls.Certificate{cs.ServerCert1},
serverGetRoot: getRootCAsForServer,
serverVType: CertVerification,
},
@@ -471,12 +465,12 @@
// Reason: server side reloading returns failure
{
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, bad reload root function; mutualTLS",
- clientCert: []tls.Certificate{cs.clientPeer1},
+ clientCert: []tls.Certificate{cs.ClientCert1},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
- serverCert: []tls.Certificate{cs.serverPeer1},
+ serverCert: []tls.Certificate{cs.ServerCert1},
serverGetRoot: getRootCAsForServerBad,
serverVType: CertVerification,
serverExpectError: true,
@@ -487,14 +481,14 @@
{
desc: "Client sets reload peer/root function with verifyFuncGood; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
- return &cs.clientPeer1, nil
+ return &cs.ClientCert1, nil
},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
- return []*tls.Certificate{&cs.serverPeer1}, nil
+ return []*tls.Certificate{&cs.ServerCert1}, nil
},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: serverVerifyFunc,
@@ -508,14 +502,14 @@
{
desc: "Client sends wrong peer cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
- return &cs.serverPeer1, nil
+ return &cs.ServerCert1, nil
},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
- return []*tls.Certificate{&cs.serverPeer1}, nil
+ return []*tls.Certificate{&cs.ServerCert1}, nil
},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: serverVerifyFunc,
@@ -529,7 +523,7 @@
{
desc: "Client has wrong trust cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
- return &cs.clientPeer1, nil
+ return &cs.ClientCert1, nil
},
clientGetRoot: getRootCAsForServer,
clientVerifyFunc: clientVerifyFuncGood,
@@ -537,7 +531,7 @@
clientExpectHandshakeError: true,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
- return []*tls.Certificate{&cs.serverPeer1}, nil
+ return []*tls.Certificate{&cs.ServerCert1}, nil
},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: serverVerifyFunc,
@@ -552,14 +546,14 @@
{
desc: "Client sets reload peer/root function with verifyFuncGood; Server sends wrong peer cert; mutualTLS",
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
- return &cs.clientPeer1, nil
+ return &cs.ClientCert1, nil
},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
- return []*tls.Certificate{&cs.clientPeer1}, nil
+ return []*tls.Certificate{&cs.ClientCert1}, nil
},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: serverVerifyFunc,
@@ -573,7 +567,7 @@
{
desc: "Client sets reload peer/root function with verifyFuncGood; Server has wrong trust cert; mutualTLS",
clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
- return &cs.clientPeer1, nil
+ return &cs.ClientCert1, nil
},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
@@ -581,7 +575,7 @@
clientExpectHandshakeError: true,
serverMutualTLS: true,
serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
- return []*tls.Certificate{&cs.serverPeer1}, nil
+ return []*tls.Certificate{&cs.ServerCert1}, nil
},
serverGetRoot: getRootCAsForClient,
serverVerifyFunc: serverVerifyFunc,
@@ -594,13 +588,13 @@
// server custom check fails
{
desc: "Client sets peer cert, reload root function with verifyFuncGood; Server sets bad custom check; mutualTLS",
- clientCert: []tls.Certificate{cs.clientPeer1},
+ clientCert: []tls.Certificate{cs.ClientCert1},
clientGetRoot: getRootCAsForClient,
clientVerifyFunc: clientVerifyFuncGood,
clientVType: CertVerification,
clientExpectHandshakeError: true,
serverMutualTLS: true,
- serverCert: []tls.Certificate{cs.serverPeer1},
+ serverCert: []tls.Certificate{cs.ServerCert1},
serverGetRoot: getRootCAsForServer,
serverVerifyFunc: verifyFuncBad,
serverVType: CertVerification,
@@ -776,24 +770,6 @@
}
}
-func readTrustCert(fileName string) (*x509.CertPool, error) {
- trustData, err := ioutil.ReadFile(fileName)
- if err != nil {
- return nil, err
- }
- trustBlock, _ := pem.Decode(trustData)
- if trustBlock == nil {
- return nil, err
- }
- trustCert, err := x509.ParseCertificate(trustBlock.Bytes)
- if err != nil {
- return nil, err
- }
- trustPool := x509.NewCertPool()
- trustPool.AddCert(trustCert)
- return trustPool, nil
-}
-
func compare(a1, a2 credentials.AuthInfo) bool {
if a1.AuthType() != a2.AuthType() {
return false
@@ -816,13 +792,13 @@
func (s) TestAdvancedTLSOverrideServerName(t *testing.T) {
expectedServerName := "server.name"
- clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
- if err != nil {
- t.Fatalf("Client is unable to load trust certs. Error: %v", err)
+ cs := &testutils.CertStore{}
+ if err := cs.LoadCerts(); err != nil {
+ t.Fatalf("cs.LoadCerts() failed, err: %v", err)
}
clientOptions := &ClientOptions{
RootOptions: RootCertificateOptions{
- RootCACerts: clientTrustPool,
+ RootCACerts: cs.ClientTrust1,
},
ServerNameOverride: expectedServerName,
}
@@ -836,122 +812,33 @@
}
}
-func (s) TestTLSClone(t *testing.T) {
- expectedServerName := "server.name"
- clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
- if err != nil {
- t.Fatalf("Client is unable to load trust certs. Error: %v", err)
- }
- clientOptions := &ClientOptions{
- RootOptions: RootCertificateOptions{
- RootCACerts: clientTrustPool,
- },
- ServerNameOverride: expectedServerName,
- }
- c, err := NewClientCreds(clientOptions)
- if err != nil {
- t.Fatalf("Failed to create new client: %v", err)
- }
- cc := c.Clone()
- if cc.Info().ServerName != expectedServerName {
- t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName)
- }
- cc.OverrideServerName("")
- if c.Info().ServerName != expectedServerName {
- t.Fatalf("Change in clone should not affect the original, "+
- "c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
- }
-
-}
-
-func (s) TestAppendH2ToNextProtos(t *testing.T) {
- tests := []struct {
- name string
- ps []string
- want []string
- }{
- {
- name: "empty",
- ps: nil,
- want: []string{"h2"},
- },
- {
- name: "only h2",
- ps: []string{"h2"},
- want: []string{"h2"},
- },
- {
- name: "with h2",
- ps: []string{"alpn", "h2"},
- want: []string{"alpn", "h2"},
- },
- {
- name: "no h2",
- ps: []string{"alpn"},
- want: []string{"alpn", "h2"},
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := appendH2ToNextProtos(tt.ps); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("appendH2ToNextProtos() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-type nonSyscallConn struct {
- net.Conn
-}
-
-func (s) TestWrapSyscallConn(t *testing.T) {
- sc := &syscallConn{}
- nsc := &nonSyscallConn{}
-
- wrapConn := WrapSyscallConn(sc, nsc)
- if _, ok := wrapConn.(syscall.Conn); !ok {
- t.Errorf("returned conn (type %T) doesn't implement syscall.Conn, want implement",
- wrapConn)
- }
-}
-
func (s) TestGetCertificatesSNI(t *testing.T) {
- // Load server certificates for setting the serverGetCert callback function.
- serverCert1, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem"))
- if err != nil {
- t.Fatalf("tls.LoadX509KeyPair(server_cert_1.pem, server_key_1.pem) failed: %v", err)
+ cs := &testutils.CertStore{}
+ if err := cs.LoadCerts(); err != nil {
+ t.Fatalf("cs.LoadCerts() failed, err: %v", err)
}
- serverCert2, err := tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem"))
- if err != nil {
- t.Fatalf("tls.LoadX509KeyPair(server_cert_2.pem, server_key_2.pem) failed: %v", err)
- }
- serverCert3, err := tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem"))
- if err != nil {
- t.Fatalf("tls.LoadX509KeyPair(server_cert_3.pem, server_key_3.pem) failed: %v", err)
- }
-
tests := []struct {
desc string
serverName string
wantCert tls.Certificate
}{
{
- desc: "Select serverCert1",
+ desc: "Select ServerCert1",
// "foo.bar.com" is the common name on server certificate server_cert_1.pem.
serverName: "foo.bar.com",
- wantCert: serverCert1,
+ wantCert: cs.ServerCert1,
},
{
- desc: "Select serverCert2",
+ desc: "Select ServerCert2",
// "foo.bar.server2.com" is the common name on server certificate server_cert_2.pem.
serverName: "foo.bar.server2.com",
- wantCert: serverCert2,
+ wantCert: cs.ServerCert2,
},
{
desc: "Select serverCert3",
// "google.com" is one of the DNS names on server certificate server_cert_3.pem.
serverName: "google.com",
- wantCert: serverCert3,
+ wantCert: cs.ServerPeer3,
},
}
for _, test := range tests {
@@ -960,7 +847,7 @@
serverOptions := &ServerOptions{
IdentityOptions: IdentityCertificateOptions{
GetIdentityCertificatesForServer: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
- return []*tls.Certificate{&serverCert1, &serverCert2, &serverCert3}, nil
+ return []*tls.Certificate{&cs.ServerCert1, &cs.ServerCert2, &cs.ServerPeer3}, nil
},
},
}
diff --git a/security/advancedtls/internal/testutils/testutils.go b/security/advancedtls/internal/testutils/testutils.go
new file mode 100644
index 0000000..665cc60
--- /dev/null
+++ b/security/advancedtls/internal/testutils/testutils.go
@@ -0,0 +1,100 @@
+/*
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package testutils contains helper functions for advancedtls.
+package testutils
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "io/ioutil"
+
+ "google.golang.org/grpc/security/advancedtls/testdata"
+)
+
+// CertStore contains all the certificates used in the integration tests.
+type CertStore struct {
+ // ClientCert1 is the certificate sent by client to prove its identity.
+ // It is trusted by ServerTrust1.
+ ClientCert1 tls.Certificate
+ // ClientCert2 is the certificate sent by client to prove its identity.
+ // It is trusted by ServerTrust2.
+ ClientCert2 tls.Certificate
+ // ServerCert1 is the certificate sent by server to prove its identity.
+ // It is trusted by ClientTrust1.
+ ServerCert1 tls.Certificate
+ // ServerCert2 is the certificate sent by server to prove its identity.
+ // It is trusted by ClientTrust2.
+ ServerCert2 tls.Certificate
+ // ServerPeer3 is the certificate sent by server to prove its identity.
+ ServerPeer3 tls.Certificate
+ // ClientTrust1 is the root certificate used on the client side.
+ ClientTrust1 *x509.CertPool
+ // ClientTrust2 is the root certificate used on the client side.
+ ClientTrust2 *x509.CertPool
+ // ServerTrust1 is the root certificate used on the server side.
+ ServerTrust1 *x509.CertPool
+ // ServerTrust2 is the root certificate used on the server side.
+ ServerTrust2 *x509.CertPool
+}
+
+func readTrustCert(fileName string) (*x509.CertPool, error) {
+ trustData, err := ioutil.ReadFile(fileName)
+ if err != nil {
+ return nil, err
+ }
+ trustPool := x509.NewCertPool()
+ if !trustPool.AppendCertsFromPEM(trustData) {
+ return nil, fmt.Errorf("error loading trust certificates")
+ }
+ return trustPool, nil
+}
+
+// LoadCerts function is used to load test certificates at the beginning of
+// each integration test.
+func (cs *CertStore) LoadCerts() error {
+ var err error
+ if cs.ClientCert1, err = tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"), testdata.Path("client_key_1.pem")); err != nil {
+ return err
+ }
+ if cs.ClientCert2, err = tls.LoadX509KeyPair(testdata.Path("client_cert_2.pem"), testdata.Path("client_key_2.pem")); err != nil {
+ return err
+ }
+ if cs.ServerCert1, err = tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"), testdata.Path("server_key_1.pem")); err != nil {
+ return err
+ }
+ if cs.ServerCert2, err = tls.LoadX509KeyPair(testdata.Path("server_cert_2.pem"), testdata.Path("server_key_2.pem")); err != nil {
+ return err
+ }
+ if cs.ServerPeer3, err = tls.LoadX509KeyPair(testdata.Path("server_cert_3.pem"), testdata.Path("server_key_3.pem")); err != nil {
+ return err
+ }
+ if cs.ClientTrust1, err = readTrustCert(testdata.Path("client_trust_cert_1.pem")); err != nil {
+ return err
+ }
+ if cs.ClientTrust2, err = readTrustCert(testdata.Path("client_trust_cert_2.pem")); err != nil {
+ return err
+ }
+ if cs.ServerTrust1, err = readTrustCert(testdata.Path("server_trust_cert_1.pem")); err != nil {
+ return err
+ }
+ if cs.ServerTrust2, err = readTrustCert(testdata.Path("server_trust_cert_2.pem")); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/security/advancedtls/pemfile_provider_test.go b/security/advancedtls/pemfile_provider_test.go
index abc494b..48e0bd2 100644
--- a/security/advancedtls/pemfile_provider_test.go
+++ b/security/advancedtls/pemfile_provider_test.go
@@ -29,6 +29,7 @@
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/credentials/tls/certprovider"
+ "google.golang.org/grpc/security/advancedtls/internal/testutils"
"google.golang.org/grpc/security/advancedtls/testdata"
)
@@ -95,17 +96,17 @@
// This test overwrites the credential reading function used by the watching
// goroutine. It is tested under different stages:
-// At stage 0, we force reading function to load clientPeer1 and serverTrust1,
+// At stage 0, we force reading function to load ClientCert1 and ServerTrust1,
// and see if the credentials are picked up by the watching go routine.
// At stage 1, we force reading function to cause an error. The watching go
// routine should log the error while leaving the credentials unchanged.
-// At stage 2, we force reading function to load clientPeer2 and serverTrust2,
+// At stage 2, we force reading function to load ClientCert2 and ServerTrust2,
// and see if the new credentials are picked up.
func (s) TestWatchingRoutineUpdates(t *testing.T) {
// Load certificates.
- cs := &certStore{}
- if err := cs.loadCerts(); err != nil {
- t.Fatalf("cs.loadCerts() failed: %v", err)
+ cs := &testutils.CertStore{}
+ if err := cs.LoadCerts(); err != nil {
+ t.Fatalf("cs.LoadCerts() failed, err: %v", err)
}
tests := []struct {
desc string
@@ -121,9 +122,9 @@
KeyFile: "not_empty_key_file",
TrustFile: "not_empty_trust_file",
},
- wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}, Roots: cs.serverTrust1},
- wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}, Roots: cs.serverTrust1},
- wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer2}, Roots: cs.serverTrust2},
+ wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1},
+ wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1},
+ wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}, Roots: cs.ServerTrust2},
},
{
desc: "use identity certs only",
@@ -131,18 +132,18 @@
CertFile: "not_empty_cert_file",
KeyFile: "not_empty_key_file",
},
- wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}},
- wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}},
- wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer2}},
+ wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}},
+ wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}},
+ wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}},
},
{
desc: "use trust certs only",
options: PEMFileProviderOptions{
TrustFile: "not_empty_trust_file",
},
- wantKmStage0: certprovider.KeyMaterial{Roots: cs.serverTrust1},
- wantKmStage1: certprovider.KeyMaterial{Roots: cs.serverTrust1},
- wantKmStage2: certprovider.KeyMaterial{Roots: cs.serverTrust2},
+ wantKmStage0: certprovider.KeyMaterial{Roots: cs.ServerTrust1},
+ wantKmStage1: certprovider.KeyMaterial{Roots: cs.ServerTrust1},
+ wantKmStage2: certprovider.KeyMaterial{Roots: cs.ServerTrust2},
},
}
for _, test := range tests {
@@ -155,11 +156,11 @@
readKeyCertPairFunc = func(certFile, keyFile string) (tls.Certificate, error) {
switch stage.read() {
case 0:
- return cs.clientPeer1, nil
+ return cs.ClientCert1, nil
case 1:
return tls.Certificate{}, fmt.Errorf("error occurred while reloading")
case 2:
- return cs.clientPeer2, nil
+ return cs.ClientCert2, nil
default:
return tls.Certificate{}, fmt.Errorf("test stage not supported")
}
@@ -171,11 +172,11 @@
readTrustCertFunc = func(trustFile string) (*x509.CertPool, error) {
switch stage.read() {
case 0:
- return cs.serverTrust1, nil
+ return cs.ServerTrust1, nil
case 1:
return nil, fmt.Errorf("error occurred while reloading")
case 2:
- return cs.serverTrust2, nil
+ return cs.ServerTrust2, nil
default:
return nil, fmt.Errorf("test stage not supported")
}