blob: 67c0b9dba42aedcd9f3063c55149894731242280 [file] [log] [blame]
/*
*
* Copyright 2019 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package advancedtls
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"math/big"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/security/advancedtls/internal/testutils"
)
// s is the receiver type for this file's tests. Every test below is declared
// as a method `func (s) TestXxx(t *testing.T)`, and the embedded
// grpctest.Tester is what grpctest uses to run them.
type s struct {
	grpctest.Tester
}

// Test is the only top-level Go test entry point in this file. It passes a
// zero-valued s to grpctest.RunSubTests, which runs the Test* methods of s
// as subtests.
func Test(t *testing.T) {
	var suite s
	grpctest.RunSubTests(t, suite)
}
// provType selects which kind of key material a fakeProvider returns.
type provType int

// The provider kinds understood by fakeProvider.KeyMaterial.
const (
	// provTypeRoot asks for trust roots (KeyMaterial.Roots).
	provTypeRoot provType = 0
	// provTypeIdentity asks for identity certificates (KeyMaterial.Certs).
	provTypeIdentity provType = 1
)
// fakeProvider is a test certprovider.Provider. Depending on its fields it
// returns either trust roots or identity certificates, for either the client
// or the server side, optionally two identity chains, or an error.
type fakeProvider struct {
	// pt selects roots (provTypeRoot) or identity certificates
	// (provTypeIdentity).
	pt provType
	// isClient selects client-side material; otherwise server-side material
	// is returned.
	isClient bool
	// wantMultiCert makes identity requests return two certificates instead
	// of one.
	wantMultiCert bool
	// wantError makes KeyMaterial fail before loading anything.
	wantError bool
}

// KeyMaterial loads the test certificates from a fresh testutils.CertStore
// and returns the subset selected by f's fields. ctx is not used.
//
// Client identity certificates are returned only when pt is provTypeIdentity
// and isClient is set. Every other non-root combination yields the server
// identity certificates.
func (f fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
	if f.wantError {
		return nil, errors.New("bad fakeProvider")
	}
	cs := &testutils.CertStore{}
	if err := cs.LoadCerts(); err != nil {
		// %w keeps the underlying load error inspectable with errors.Is/As.
		return nil, fmt.Errorf("cs.LoadCerts() failed, err: %w", err)
	}
	switch {
	case f.pt == provTypeRoot && f.isClient:
		return &certprovider.KeyMaterial{Roots: cs.ClientTrust1}, nil
	case f.pt == provTypeRoot:
		return &certprovider.KeyMaterial{Roots: cs.ServerTrust1}, nil
	case f.pt == provTypeIdentity && f.isClient:
		certs := []tls.Certificate{cs.ClientCert1}
		if f.wantMultiCert {
			certs = append(certs, cs.ClientCert2)
		}
		return &certprovider.KeyMaterial{Certs: certs}, nil
	default:
		certs := []tls.Certificate{cs.ServerCert1}
		if f.wantMultiCert {
			certs = append(certs, cs.ServerCert2)
		}
		return &certprovider.KeyMaterial{Certs: certs}, nil
	}
}

// Close is a no-op: fakeProvider holds no resources between calls.
func (f fakeProvider) Close() {}
// TestClientOptionsConfigErrorCases checks that ClientOptions.config returns
// an error for each invalid client configuration in the table. These include
// skipping verification with no root credentials, setting more than one
// field of RootCertificateOptions or IdentityCertificateOptions, and setting
// the server-only GetIdentityCertificatesForServer.
func (s) TestClientOptionsConfigErrorCases(t *testing.T) {
	for _, tc := range []struct {
		desc         string
		clientVType  VerificationType
		identityOpts IdentityCertificateOptions
		rootOpts     RootCertificateOptions
	}{
		{
			desc:        "Skip default verification and provide no root credentials",
			clientVType: SkipVerification,
		},
		{
			desc:        "More than one fields in RootCertificateOptions is specified",
			clientVType: CertVerification,
			rootOpts: RootCertificateOptions{
				RootCACerts:  x509.NewCertPool(),
				RootProvider: fakeProvider{},
			},
		},
		{
			desc:        "More than one fields in IdentityCertificateOptions is specified",
			clientVType: CertVerification,
			identityOpts: IdentityCertificateOptions{
				GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
					return nil, nil
				},
				IdentityProvider: fakeProvider{pt: provTypeIdentity},
			},
		},
		{
			desc: "Specify GetIdentityCertificatesForServer",
			identityOpts: IdentityCertificateOptions{
				GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
					return nil, nil
				},
			},
		},
	} {
		tc := tc
		t.Run(tc.desc, func(t *testing.T) {
			opts := &ClientOptions{
				VType:           tc.clientVType,
				IdentityOptions: tc.identityOpts,
				RootOptions:     tc.rootOpts,
			}
			if _, err := opts.config(); err == nil {
				t.Fatalf("ClientOptions{%v}.config() returns no err, wantErr != nil", opts)
			}
		})
	}
}
// TestClientOptionsConfigSuccessCases checks that ClientOptions.config
// accepts valid client configurations. When no root source is configured,
// it also checks that the resulting tls.Config has a non-nil RootCAs, which
// the test treats as the system-provided pool.
func (s) TestClientOptionsConfigSuccessCases(t *testing.T) {
	for _, tc := range []struct {
		desc         string
		clientVType  VerificationType
		identityOpts IdentityCertificateOptions
		rootOpts     RootCertificateOptions
	}{
		{
			desc:        "Use system default if no fields in RootCertificateOptions is specified",
			clientVType: CertVerification,
		},
		{
			desc:        "Good case with mutual TLS",
			clientVType: CertVerification,
			rootOpts: RootCertificateOptions{
				RootProvider: fakeProvider{},
			},
			identityOpts: IdentityCertificateOptions{
				IdentityProvider: fakeProvider{pt: provTypeIdentity},
			},
		},
	} {
		tc := tc
		t.Run(tc.desc, func(t *testing.T) {
			opts := &ClientOptions{
				VType:           tc.clientVType,
				IdentityOptions: tc.identityOpts,
				RootOptions:     tc.rootOpts,
			}
			cfg, err := opts.config()
			if err != nil {
				t.Fatalf("ClientOptions{%v}.config() = %v, wantErr == nil", opts, err)
			}
			roots := opts.RootOptions
			noRootSource := roots.RootCACerts == nil &&
				roots.GetRootCertificates == nil &&
				roots.RootProvider == nil
			// With no root source configured, config() is expected to fall back
			// to the system certificates.
			if noRootSource && cfg.RootCAs == nil {
				t.Fatalf("Failed to assign system-provided certificates on the client side.")
			}
		})
	}
}
// TestServerOptionsConfigErrorCases checks that ServerOptions.config returns
// an error for each invalid server configuration in the table. These include
// skipping verification with no root credentials, setting more than one
// field of RootCertificateOptions or IdentityCertificateOptions, setting no
// identity source, and setting the client-only
// GetIdentityCertificatesForClient.
func (s) TestServerOptionsConfigErrorCases(t *testing.T) {
	for _, tc := range []struct {
		desc              string
		requireClientCert bool
		serverVType       VerificationType
		identityOpts      IdentityCertificateOptions
		rootOpts          RootCertificateOptions
	}{
		{
			desc:              "Skip default verification and provide no root credentials",
			requireClientCert: true,
			serverVType:       SkipVerification,
		},
		{
			desc:              "More than one fields in RootCertificateOptions is specified",
			requireClientCert: true,
			serverVType:       CertVerification,
			rootOpts: RootCertificateOptions{
				RootCACerts: x509.NewCertPool(),
				GetRootCertificates: func(*GetRootCAsParams) (*GetRootCAsResults, error) {
					return nil, nil
				},
			},
		},
		{
			desc:        "More than one fields in IdentityCertificateOptions is specified",
			serverVType: CertVerification,
			identityOpts: IdentityCertificateOptions{
				Certificates:     []tls.Certificate{},
				IdentityProvider: fakeProvider{pt: provTypeIdentity},
			},
		},
		{
			desc:        "no field in IdentityCertificateOptions is specified",
			serverVType: CertVerification,
		},
		{
			desc: "Specify GetIdentityCertificatesForClient",
			identityOpts: IdentityCertificateOptions{
				GetIdentityCertificatesForClient: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
					return nil, nil
				},
			},
		},
	} {
		tc := tc
		t.Run(tc.desc, func(t *testing.T) {
			opts := &ServerOptions{
				VType:             tc.serverVType,
				RequireClientCert: tc.requireClientCert,
				IdentityOptions:   tc.identityOpts,
				RootOptions:       tc.rootOpts,
			}
			if _, err := opts.config(); err == nil {
				t.Fatalf("ServerOptions{%v}.config() returns no err, wantErr != nil", opts)
			}
		})
	}
}
// TestServerOptionsConfigSuccessCases checks that ServerOptions.config
// accepts valid server configurations. When no root source is configured,
// it also checks that the resulting tls.Config has a non-nil ClientCAs,
// which the test treats as the system-provided pool.
func (s) TestServerOptionsConfigSuccessCases(t *testing.T) {
	for _, tc := range []struct {
		desc              string
		requireClientCert bool
		serverVType       VerificationType
		identityOpts      IdentityCertificateOptions
		rootOpts          RootCertificateOptions
	}{
		{
			desc:              "Use system default if no fields in RootCertificateOptions is specified",
			requireClientCert: true,
			serverVType:       CertVerification,
			identityOpts: IdentityCertificateOptions{
				Certificates: []tls.Certificate{},
			},
		},
		{
			desc:              "Good case with mutual TLS",
			requireClientCert: true,
			serverVType:       CertVerification,
			rootOpts: RootCertificateOptions{
				RootProvider: fakeProvider{},
			},
			identityOpts: IdentityCertificateOptions{
				GetIdentityCertificatesForServer: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
					return nil, nil
				},
			},
		},
	} {
		tc := tc
		t.Run(tc.desc, func(t *testing.T) {
			opts := &ServerOptions{
				VType:             tc.serverVType,
				RequireClientCert: tc.requireClientCert,
				IdentityOptions:   tc.identityOpts,
				RootOptions:       tc.rootOpts,
			}
			cfg, err := opts.config()
			if err != nil {
				t.Fatalf("ServerOptions{%v}.config() = %v, wantErr == nil", opts, err)
			}
			roots := opts.RootOptions
			noRootSource := roots.RootCACerts == nil &&
				roots.GetRootCertificates == nil &&
				roots.RootProvider == nil
			// With no root source configured, config() is expected to fall back
			// to the system certificates for verifying clients.
			if noRootSource && cfg.ClientCAs == nil {
				t.Fatalf("Failed to assign system-provided certificates on the server side.")
			}
		})
	}
}
// TestClientServerHandshake runs a real TLS handshake over a local TCP
// connection for each table entry. Client and server are configured with the
// entry's certificates, root sources, providers, custom verification
// functions and verification types. The test then checks whether each side
// fails as the entry expects. When both sides succeed, the client's and
// server's AuthInfo values are compared with compare.
func (s) TestClientServerHandshake(t *testing.T) {
	cs := &testutils.CertStore{}
	if err := cs.LoadCerts(); err != nil {
		t.Fatalf("cs.LoadCerts() failed, err: %v", err)
	}
	getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
		return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
	}
	clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) {
		if params.ServerName == "" {
			return nil, errors.New("client side server name should have a value")
		}
		// "foo.bar.com" is the common name on server certificate server_cert_1.pem.
		if len(params.VerifiedChains) > 0 && (params.Leaf == nil || params.Leaf.Subject.CommonName != "foo.bar.com") {
			return nil, errors.New("client side params parsing error")
		}
		return &VerificationResults{}, nil
	}
	verifyFuncBad := func(params *VerificationFuncParams) (*VerificationResults, error) {
		return nil, fmt.Errorf("custom verification function failed")
	}
	getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
		return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil
	}
	serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) {
		if params.ServerName != "" {
			return nil, errors.New("server side server name should not have a value")
		}
		// "foo.bar.hoo.com" is the common name on client certificate client_cert_1.pem.
		if len(params.VerifiedChains) > 0 && (params.Leaf == nil || params.Leaf.Subject.CommonName != "foo.bar.hoo.com") {
			return nil, errors.New("server side params parsing error")
		}
		return &VerificationResults{}, nil
	}
	getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
		return nil, fmt.Errorf("bad root certificate reloading")
	}
	for _, test := range []struct {
		desc                       string
		clientCert                 []tls.Certificate
		clientGetCert              func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
		clientRoot                 *x509.CertPool
		clientGetRoot              func(params *GetRootCAsParams) (*GetRootCAsResults, error)
		clientVerifyFunc           CustomVerificationFunc
		clientVType                VerificationType
		clientRootProvider         certprovider.Provider
		clientIdentityProvider     certprovider.Provider
		clientExpectHandshakeError bool
		serverMutualTLS            bool
		serverCert                 []tls.Certificate
		serverGetCert              func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
		serverRoot                 *x509.CertPool
		serverGetRoot              func(params *GetRootCAsParams) (*GetRootCAsResults, error)
		serverVerifyFunc           CustomVerificationFunc
		serverVType                VerificationType
		serverRootProvider         certprovider.Provider
		serverIdentityProvider     certprovider.Provider
		serverExpectError          bool
	}{
		// Client: nil setting except verifyFuncGood
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: success
		// Reason: the client skips default verification and has no trust
		// certs, so only verifyFuncGood decides whether the server is
		// accepted.
		{
			desc:             "Client has no trust cert with verifyFuncGood; server sends peer cert",
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      SkipVerification,
			serverCert:       []tls.Certificate{cs.ServerCert1},
			serverVType:      CertAndHostVerification,
		},
		// Client: only set clientRoot
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: server side failure and client handshake failure
		// Reason: the client sets vType to CertAndHostVerification and so
		// performs the default hostname check, which fails for every case in
		// this test suite.
		{
			desc:                       "Client has root cert; server sends peer cert",
			clientRoot:                 cs.ClientTrust1,
			clientVType:                CertAndHostVerification,
			clientExpectHandshakeError: true,
			serverCert:                 []tls.Certificate{cs.ServerCert1},
			serverVType:                CertAndHostVerification,
			serverExpectError:          true,
		},
		// Client: only set clientGetRoot
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: server side failure and client handshake failure
		// Reason: the client sets vType to CertAndHostVerification and so
		// performs the default hostname check, which fails for every case in
		// this test suite.
		{
			desc:                       "Client sets reload root function; server sends peer cert",
			clientGetRoot:              getRootCAsForClient,
			clientVType:                CertAndHostVerification,
			clientExpectHandshakeError: true,
			serverCert:                 []tls.Certificate{cs.ServerCert1},
			serverVType:                CertAndHostVerification,
			serverExpectError:          true,
		},
		// Client: set clientGetRoot and clientVerifyFunc
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: success
		{
			desc:             "Client sets reload root function with verifyFuncGood; server sends peer cert",
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverCert:       []tls.Certificate{cs.ServerCert1},
			serverVType:      CertAndHostVerification,
		},
		// Client: set clientGetRoot and a bad clientVerifyFunc
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: server side failure and client handshake failure
		// Reason: the custom verification function always fails
		{
			desc:                       "Client sets reload root function with verifyFuncBad; server sends peer cert",
			clientGetRoot:              getRootCAsForClient,
			clientVerifyFunc:           verifyFuncBad,
			clientVType:                CertVerification,
			clientExpectHandshakeError: true,
			serverCert:                 []tls.Certificate{cs.ServerCert1},
			serverVType:                CertVerification,
			serverExpectError:          true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set serverRoot and serverCert with mutual TLS on
		// Expected Behavior: success
		{
			desc:             "Client sets peer cert, reload root function with verifyFuncGood; server sets peer cert and root cert; mutualTLS",
			clientCert:       []tls.Certificate{cs.ClientCert1},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverCert:       []tls.Certificate{cs.ServerCert1},
			serverRoot:       cs.ServerTrust1,
			serverVType:      CertVerification,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set serverGetRoot and serverCert with mutual TLS on
		// Expected Behavior: success
		{
			desc:             "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS",
			clientCert:       []tls.Certificate{cs.ClientCert1},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverCert:       []tls.Certificate{cs.ServerCert1},
			serverGetRoot:    getRootCAsForServer,
			serverVType:      CertVerification,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set a serverGetRoot that returns an error, and serverCert,
		// with mutual TLS on
		// Expected Behavior: server side failure
		// Reason: server side root reloading returns an error
		{
			desc:              "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, bad reload root function; mutualTLS",
			clientCert:        []tls.Certificate{cs.ClientCert1},
			clientGetRoot:     getRootCAsForClient,
			clientVerifyFunc:  clientVerifyFuncGood,
			clientVType:       CertVerification,
			serverMutualTLS:   true,
			serverCert:        []tls.Certificate{cs.ServerCert1},
			serverGetRoot:     getRootCAsForServerBad,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientGetCert
		// Server: set serverGetRoot and serverGetCert with mutual TLS on
		// Expected Behavior: success
		{
			desc: "Client sets reload peer/root function with verifyFuncGood; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &cs.ClientCert1, nil
			},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
				return []*tls.Certificate{&cs.ServerCert1}, nil
			},
			serverGetRoot:    getRootCAsForServer,
			serverVerifyFunc: serverVerifyFunc,
			serverVType:      CertVerification,
		},
		// Client: set everything, but send a peer cert the server does not
		// trust
		// Server: set serverGetRoot and serverGetCert with mutual TLS on
		// Expected Behavior: server side failure because of the certificate
		// mismatch
		{
			desc: "Client sends wrong peer cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &cs.ServerCert1, nil
			},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
				return []*tls.Certificate{&cs.ServerCert1}, nil
			},
			serverGetRoot:     getRootCAsForServer,
			serverVerifyFunc:  serverVerifyFunc,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set everything, but with trust certs that do not trust the
		// server
		// Server: set serverGetRoot and serverGetCert with mutual TLS on
		// Expected Behavior: server side and client side failure due to the
		// certificate mismatch
		{
			desc: "Client has wrong trust cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &cs.ClientCert1, nil
			},
			clientGetRoot:              getRootCAsForServer,
			clientVerifyFunc:           clientVerifyFuncGood,
			clientVType:                CertVerification,
			clientExpectHandshakeError: true,
			serverMutualTLS:            true,
			serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
				return []*tls.Certificate{&cs.ServerCert1}, nil
			},
			serverGetRoot:     getRootCAsForServer,
			serverVerifyFunc:  serverVerifyFunc,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientGetCert
		// Server: set everything, but send a peer cert the client does not
		// trust
		// Expected Behavior: server side failure due to the certificate
		// mismatch. The client side is not checked once the server has failed
		// as expected.
		{
			desc: "Client sets reload peer/root function with verifyFuncGood; Server sends wrong peer cert; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &cs.ClientCert1, nil
			},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
				return []*tls.Certificate{&cs.ClientCert1}, nil
			},
			serverGetRoot:     getRootCAsForServer,
			serverVerifyFunc:  serverVerifyFunc,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientGetCert
		// Server: set everything, but with trust certs that do not trust the
		// client
		// Expected Behavior: server side and client side failure due to the
		// certificate mismatch
		{
			desc: "Client sets reload peer/root function with verifyFuncGood; Server has wrong trust cert; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &cs.ClientCert1, nil
			},
			clientGetRoot:              getRootCAsForClient,
			clientVerifyFunc:           clientVerifyFuncGood,
			clientVType:                CertVerification,
			clientExpectHandshakeError: true,
			serverMutualTLS:            true,
			serverGetCert: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
				return []*tls.Certificate{&cs.ServerCert1}, nil
			},
			serverGetRoot:     getRootCAsForClient,
			serverVerifyFunc:  serverVerifyFunc,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set serverGetRoot and serverCert, but with a bad verifyFunc
		// Expected Behavior: server side and client side failure because the
		// server's custom check fails
		{
			desc:                       "Client sets peer cert, reload root function with verifyFuncGood; Server sets bad custom check; mutualTLS",
			clientCert:                 []tls.Certificate{cs.ClientCert1},
			clientGetRoot:              getRootCAsForClient,
			clientVerifyFunc:           clientVerifyFuncGood,
			clientVType:                CertVerification,
			clientExpectHandshakeError: true,
			serverMutualTLS:            true,
			serverCert:                 []tls.Certificate{cs.ServerCert1},
			serverGetRoot:              getRootCAsForServer,
			serverVerifyFunc:           verifyFuncBad,
			serverVType:                CertVerification,
			serverExpectError:          true,
		},
		// Client: set a clientIdentityProvider that returns multiple cert
		// chains
		// Server: set serverIdentityProvider and serverRootProvider with mutual
		// TLS on
		// Expected Behavior: server side failure due to multiple cert chains
		// in clientIdentityProvider
		{
			desc:                   "Client sets multiple certs in clientIdentityProvider; Server sets root and identity provider; mutualTLS",
			clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantMultiCert: true},
			clientRootProvider:     fakeProvider{isClient: true},
			clientVerifyFunc:       clientVerifyFuncGood,
			clientVType:            CertVerification,
			serverMutualTLS:        true,
			serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
			serverRootProvider:     fakeProvider{isClient: false},
			serverVType:            CertVerification,
			serverExpectError:      true,
		},
		// Client: set a bad clientIdentityProvider
		// Server: set serverIdentityProvider and serverRootProvider with mutual
		// TLS on
		// Expected Behavior: server side failure due to the bad
		// clientIdentityProvider
		{
			desc:                   "Client sets bad clientIdentityProvider; Server sets root and identity provider; mutualTLS",
			clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true, wantError: true},
			clientRootProvider:     fakeProvider{isClient: true},
			clientVerifyFunc:       clientVerifyFuncGood,
			clientVType:            CertVerification,
			serverMutualTLS:        true,
			serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
			serverRootProvider:     fakeProvider{isClient: false},
			serverVType:            CertVerification,
			serverExpectError:      true,
		},
		// Client: set clientIdentityProvider and clientRootProvider
		// Server: set a bad serverRootProvider with mutual TLS on
		// Expected Behavior: server side failure due to the bad
		// serverRootProvider
		{
			desc:                   "Client sets root and identity provider; Server sets bad root provider; mutualTLS",
			clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
			clientRootProvider:     fakeProvider{isClient: true},
			clientVerifyFunc:       clientVerifyFuncGood,
			clientVType:            CertVerification,
			serverMutualTLS:        true,
			serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
			serverRootProvider:     fakeProvider{isClient: false, wantError: true},
			serverVType:            CertVerification,
			serverExpectError:      true,
		},
		// Client: set clientIdentityProvider and clientRootProvider
		// Server: set serverIdentityProvider and serverRootProvider with mutual
		// TLS on
		// Expected Behavior: success
		{
			desc:                   "Client sets root and identity provider; Server sets root and identity provider; mutualTLS",
			clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
			clientRootProvider:     fakeProvider{isClient: true},
			clientVerifyFunc:       clientVerifyFuncGood,
			clientVType:            CertVerification,
			serverMutualTLS:        true,
			serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false},
			serverRootProvider:     fakeProvider{isClient: false},
			serverVType:            CertVerification,
		},
		// Client: set clientIdentityProvider and clientRootProvider
		// Server: set a serverIdentityProvider that returns multiple cert
		// chains, and serverRootProvider, with mutual TLS on
		// Expected Behavior: success, because the server side can select a
		// chain by SNI
		{
			desc:                   "Client sets root and identity provider; Server sets multiple certs in serverIdentityProvider; mutualTLS",
			clientIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: true},
			clientRootProvider:     fakeProvider{isClient: true},
			clientVerifyFunc:       clientVerifyFuncGood,
			clientVType:            CertVerification,
			serverMutualTLS:        true,
			serverIdentityProvider: fakeProvider{pt: provTypeIdentity, isClient: false, wantMultiCert: true},
			serverRootProvider:     fakeProvider{isClient: false},
			serverVType:            CertVerification,
		},
	} {
		test := test
		t.Run(test.desc, func(t *testing.T) {
			// done receives the server's AuthInfo after a successful handshake.
			// It is closed, without a value, if the server fails at any step.
			done := make(chan credentials.AuthInfo, 1)
			lis, err := net.Listen("tcp", "localhost:0")
			if err != nil {
				t.Fatalf("Failed to listen: %v", err)
			}
			// Start a server using ServerOptions in another goroutine.
			serverOptions := &ServerOptions{
				IdentityOptions: IdentityCertificateOptions{
					Certificates:                     test.serverCert,
					GetIdentityCertificatesForServer: test.serverGetCert,
					IdentityProvider:                 test.serverIdentityProvider,
				},
				RootOptions: RootCertificateOptions{
					RootCACerts:         test.serverRoot,
					GetRootCertificates: test.serverGetRoot,
					RootProvider:        test.serverRootProvider,
				},
				RequireClientCert: test.serverMutualTLS,
				VerifyPeer:        test.serverVerifyFunc,
				VType:             test.serverVType,
			}
			go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) {
				serverRawConn, err := lis.Accept()
				if err != nil {
					close(done)
					return
				}
				serverTLS, err := NewServerCreds(serverOptions)
				if err != nil {
					serverRawConn.Close()
					close(done)
					return
				}
				serverConn, serverAuthInfo, err := serverTLS.ServerHandshake(serverRawConn)
				if err != nil {
					serverRawConn.Close()
					close(done)
					return
				}
				// Only the AuthInfo is needed by the test, so release the
				// server side of the connection once it has been reported.
				// Previously this connection was never closed on success.
				defer serverConn.Close()
				done <- serverAuthInfo
			}(done, lis, serverOptions)
			defer lis.Close()
			// Start a client using ClientOptions and connect to the server.
			lisAddr := lis.Addr().String()
			conn, err := net.Dial("tcp", lisAddr)
			if err != nil {
				t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
			}
			defer conn.Close()
			clientOptions := &ClientOptions{
				IdentityOptions: IdentityCertificateOptions{
					Certificates:                     test.clientCert,
					GetIdentityCertificatesForClient: test.clientGetCert,
					IdentityProvider:                 test.clientIdentityProvider,
				},
				VerifyPeer: test.clientVerifyFunc,
				RootOptions: RootCertificateOptions{
					RootCACerts:         test.clientRoot,
					GetRootCertificates: test.clientGetRoot,
					RootProvider:        test.clientRootProvider,
				},
				VType: test.clientVType,
			}
			clientTLS, err := NewClientCreds(clientOptions)
			if err != nil {
				t.Fatalf("NewClientCreds failed: %v", err)
			}
			_, clientAuthInfo, handshakeErr := clientTLS.ClientHandshake(context.Background(),
				lisAddr, conn)
			// Wait until the server sends serverAuthInfo or fails.
			serverAuthInfo, ok := <-done
			if !ok && test.serverExpectError {
				// The server failed as expected. The client result is not
				// checked here, presumably because the client can finish its
				// side of the handshake before the server rejects it.
				return
			}
			// Past the early return above, ok == serverExpectError means the
			// server either succeeded unexpectedly or failed unexpectedly.
			if ok == test.serverExpectError {
				t.Fatalf("Server side error mismatch, got %v, want %v", !ok, test.serverExpectError)
			}
			if handshakeErr != nil && test.clientExpectHandshakeError {
				return
			}
			if (handshakeErr != nil) != test.clientExpectHandshakeError {
				t.Fatalf("Expect error: %v, but err is %v",
					test.clientExpectHandshakeError, handshakeErr)
			}
			if !compare(clientAuthInfo, serverAuthInfo) {
				t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr,
					clientAuthInfo, serverAuthInfo)
			}
		})
	}
}
// compare reports whether two AuthInfo values describe matching TLS sessions.
// Both must have AuthType "tls" and be credentials.TLSInfo values, and their
// connection states must agree on Version, HandshakeComplete, CipherSuite
// and NegotiatedProtocol.
//
// It returns false for any other AuthType, for a nil AuthInfo, and for an
// AuthInfo that reports "tls" but is not a credentials.TLSInfo. The original
// version panicked in the last two cases.
func compare(a1, a2 credentials.AuthInfo) bool {
	// A nil AuthInfo has no state to compare; calling AuthType on it would
	// panic.
	if a1 == nil || a2 == nil {
		return false
	}
	if a1.AuthType() != a2.AuthType() || a1.AuthType() != "tls" {
		return false
	}
	info1, ok1 := a1.(credentials.TLSInfo)
	info2, ok2 := a2.(credentials.TLSInfo)
	if !ok1 || !ok2 {
		return false
	}
	state1, state2 := info1.State, info2.State
	return state1.Version == state2.Version &&
		state1.HandshakeComplete == state2.HandshakeComplete &&
		state1.CipherSuite == state2.CipherSuite &&
		state1.NegotiatedProtocol == state2.NegotiatedProtocol
}
// TestAdvancedTLSOverrideServerName builds client credentials with
// ServerNameOverride set, calls OverrideServerName with the same name, and
// checks that Info() reports it as the ServerName.
func (s) TestAdvancedTLSOverrideServerName(t *testing.T) {
	expectedServerName := "server.name"
	cs := &testutils.CertStore{}
	if err := cs.LoadCerts(); err != nil {
		t.Fatalf("cs.LoadCerts() failed, err: %v", err)
	}
	clientOptions := &ClientOptions{
		RootOptions: RootCertificateOptions{
			RootCACerts: cs.ClientTrust1,
		},
		ServerNameOverride: expectedServerName,
	}
	c, err := NewClientCreds(clientOptions)
	if err != nil {
		t.Fatalf("Client is unable to create credentials. Error: %v", err)
	}
	// OverrideServerName returns an error; the original test ignored it.
	if err := c.OverrideServerName(expectedServerName); err != nil {
		t.Fatalf("c.OverrideServerName(%q) failed: %v", expectedServerName, err)
	}
	if got := c.Info().ServerName; got != expectedServerName {
		t.Fatalf("c.Info().ServerName = %v, want %v", got, expectedServerName)
	}
}
// TestGetCertificatesSNI configures a server whose
// GetIdentityCertificatesForServer returns three certificates. For each
// table entry it checks that the server config's GetCertificate returns the
// certificate matching the ClientHello's SNI server name.
func (s) TestGetCertificatesSNI(t *testing.T) {
	cs := &testutils.CertStore{}
	if err := cs.LoadCerts(); err != nil {
		t.Fatalf("cs.LoadCerts() failed, err: %v", err)
	}
	tests := []struct {
		desc       string
		serverName string
		wantCert   tls.Certificate
	}{
		{
			desc: "Select ServerCert1",
			// "foo.bar.com" is the common name on server certificate server_cert_1.pem.
			serverName: "foo.bar.com",
			wantCert:   cs.ServerCert1,
		},
		{
			desc: "Select ServerCert2",
			// "foo.bar.server2.com" is the common name on server certificate server_cert_2.pem.
			serverName: "foo.bar.server2.com",
			wantCert:   cs.ServerCert2,
		},
		{
			desc: "Select serverCert3",
			// "google.com" is one of the DNS names on server certificate server_cert_3.pem.
			serverName: "google.com",
			wantCert:   cs.ServerPeer3,
		},
	}
	for _, test := range tests {
		test := test
		t.Run(test.desc, func(t *testing.T) {
			if len(test.wantCert.Certificate) == 0 {
				t.Fatalf("expected certificate for %q has an empty DER chain", test.desc)
			}
			serverOptions := &ServerOptions{
				IdentityOptions: IdentityCertificateOptions{
					GetIdentityCertificatesForServer: func(info *tls.ClientHelloInfo) ([]*tls.Certificate, error) {
						return []*tls.Certificate{&cs.ServerCert1, &cs.ServerCert2, &cs.ServerPeer3}, nil
					},
				},
			}
			serverConfig, err := serverOptions.config()
			if err != nil {
				t.Fatalf("serverOptions.config() failed: %v", err)
			}
			pointFormatUncompressed := uint8(0)
			clientHello := &tls.ClientHelloInfo{
				CipherSuites:      []uint16{tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA},
				ServerName:        test.serverName,
				SupportedCurves:   []tls.CurveID{tls.CurveP256},
				SupportedPoints:   []uint8{pointFormatUncompressed},
				SupportedVersions: []uint16{tls.VersionTLS10},
			}
			gotCertificate, err := serverConfig.GetCertificate(clientHello)
			if err != nil {
				t.Fatalf("serverConfig.GetCertificate(clientHello) failed: %v", err)
			}
			if gotCertificate == nil {
				t.Fatalf("serverConfig.GetCertificate(clientHello) returned a nil certificate")
			}
			// Compare the DER-encoded chains rather than Leaf. tls.Certificate.Leaf
			// may be left nil, and (*x509.Certificate).Equal reports two nil
			// leaves as equal, which would make a Leaf comparison vacuous.
			if !cmp.Equal(gotCertificate.Certificate, test.wantCert.Certificate) {
				t.Errorf("GetCertificates() returned leaf certificate does not match expected (-want +got):\n%s", cmp.Diff(test.wantCert, *gotCertificate, cmp.AllowUnexported(big.Int{})))
			}
		})
	}
}