blob: 509befa0d8fe7ac0b041b786a5687427f31e5627 [file] [log] [blame]
/*
*
* Copyright 2019 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package advancedtls
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"io/ioutil"
"net"
"reflect"
"syscall"
"testing"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/security/advancedtls/testdata"
)
// TestClientServerHandshake is a table-driven test of a full TLS handshake
// between credentials built by NewClientCreds and NewServerCreds. For every
// case it listens on a fresh localhost port, runs the server handshake in a
// goroutine, dials and runs the client handshake, and then checks the observed
// creation/handshake outcomes against the expectations recorded in the table.
// When both sides succeed, the two AuthInfo values are compared with compare.
func TestClientServerHandshake(t *testing.T) {
// ------------------Load Client Trust Cert and Peer Cert-------------------
clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
if err != nil {
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
}
// getRootCAsForClient always hands back the client trust pool loaded above.
getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
return &GetRootCAsResults{TrustCerts: clientTrustPool}, nil
}
// verifyFuncGood accepts unconditionally: empty results and a nil error.
verifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) {
return &VerificationResults{}, nil
}
// verifyFuncBad rejects unconditionally by always returning an error.
verifyFuncBad := func(params *VerificationFuncParams) (*VerificationResults, error) {
return nil, fmt.Errorf("custom verification function failed")
}
clientPeerCert, err := tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"),
testdata.Path("client_key_1.pem"))
if err != nil {
t.Fatalf("Client is unable to parse peer certificates. Error: %v", err)
}
// ------------------Load Server Trust Cert and Peer Cert-------------------
serverTrustPool, err := readTrustCert(testdata.Path("server_trust_cert_1.pem"))
if err != nil {
t.Fatalf("Server is unable to load trust certs. Error: %v", err)
}
// getRootCAsForServer always hands back the server trust pool loaded above.
getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
return &GetRootCAsResults{TrustCerts: serverTrustPool}, nil
}
serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
testdata.Path("server_key_1.pem"))
if err != nil {
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
}
// getRootCAsForServerBad simulates a failing root reload: it always errors.
getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
return nil, fmt.Errorf("bad root certificate reloading")
}
// The cases below are positional struct literals, so each value's meaning is
// given purely by its position. Field order:
//   1 desc                        8 clientExpectHandshakeError
//   2 clientCert                  9 serverMutualTLS
//   3 clientGetClientCert        10 serverCert
//   4 clientRoot                 11 serverGetCert
//   5 clientGetRoot              12 serverRoot
//   6 clientVerifyFunc           13 serverGetRoot
//   7 clientExpectCreateError    14 serverExpectError
for _, test := range []struct {
desc string
clientCert []tls.Certificate
clientGetClientCert func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
clientRoot *x509.CertPool
clientGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
clientVerifyFunc CustomVerificationFunc
clientExpectCreateError bool
clientExpectHandshakeError bool
serverMutualTLS bool
serverCert []tls.Certificate
serverGetCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)
serverRoot *x509.CertPool
serverGetRoot func(params *GetRootCAsParams) (*GetRootCAsResults, error)
serverExpectError bool
}{
// Client: nil setting
// Server: only set serverCert with mutual TLS off
// Expected Behavior: client credential creation fails; server side failure
// Reason: this is the only case in the table where clientRoot, clientGetRoot
// and clientVerifyFunc are all nil, and the only one expecting creation of
// the client credentials to fail.
// NOTE(review): the exact rule enforced by NewClientCreds is not visible
// here — confirm against its implementation.
{
"Client_no_trust_cert_Server_peer_cert",
nil,
nil,
nil,
nil,
nil,
true,
false,
false,
[]tls.Certificate{serverPeerCert},
nil,
nil,
nil,
true,
},
// Client: nil setting except verifyFuncGood
// Server: only set serverCert with mutual TLS off
// Expected Behavior: success
// Reason: we will use verifyFuncGood to verify the server,
// if either clientCert or clientGetClientCert is not set
{
"Client_no_trust_cert_verifyFuncGood_Server_peer_cert",
nil,
nil,
nil,
nil,
verifyFuncGood,
false,
false,
false,
[]tls.Certificate{serverPeerCert},
nil,
nil,
nil,
false,
},
// Client: only set clientRoot
// Server: only set serverCert with mutual TLS off
// Expected Behavior: server side failure and client handshake failure
// Reason: not setting advanced TLS features will fall back to normal check, and will hence fail
// on default host name check. All the default hostname checks will fail in this test suite.
{
"Client_root_cert_Server_peer_cert",
nil,
nil,
clientTrustPool,
nil,
nil,
false,
true,
false,
[]tls.Certificate{serverPeerCert},
nil,
nil,
nil,
true,
},
// Client: only set clientGetRoot
// Server: only set serverCert with mutual TLS off
// Expected Behavior: server side failure and client handshake failure
// Reason: setting root reloading function without custom verifyFunc will also fail,
// since it will also fall back to default host name check
{
"Client_reload_root_Server_peer_cert",
nil,
nil,
nil,
getRootCAsForClient,
nil,
false,
true,
false,
[]tls.Certificate{serverPeerCert},
nil,
nil,
nil,
true,
},
// Client: set clientGetRoot and clientVerifyFunc
// Server: only set serverCert with mutual TLS off
// Expected Behavior: success
{
"Client_reload_root_verifyFuncGood_Server_peer_cert",
nil,
nil,
nil,
getRootCAsForClient,
verifyFuncGood,
false,
false,
false,
[]tls.Certificate{serverPeerCert},
nil,
nil,
nil,
false,
},
// Client: set clientGetRoot and bad clientVerifyFunc function
// Server: only set serverCert with mutual TLS off
// Expected Behavior: server side failure and client handshake failure
// Reason: custom verification function is bad
{
"Client_reload_root_verifyFuncBad_Server_peer_cert",
nil,
nil,
nil,
getRootCAsForClient,
verifyFuncBad,
false,
true,
false,
[]tls.Certificate{serverPeerCert},
nil,
nil,
nil,
true,
},
// Client: set clientGetRoot and clientVerifyFunc
// Server: nil setting
// Expected Behavior: server side failure
// Reason: server side must either set serverCert or serverGetCert
{
"Client_reload_root_verifyFuncGood_Server_nil",
nil,
nil,
nil,
getRootCAsForClient,
verifyFuncGood,
false,
false,
false,
nil,
nil,
nil,
nil,
true,
},
// Client: set clientGetRoot and clientVerifyFunc
// Server: only set serverCert with mutual TLS on
// Expected Behavior: server side failure
// Reason: server side must either set serverRoot or serverGetRoot when using mutual TLS
{
"Client_reload_root_verifyFuncGood_Server_peer_cert_no_root_cert_mutualTLS",
nil,
nil,
nil,
getRootCAsForClient,
verifyFuncGood,
false,
false,
true,
[]tls.Certificate{serverPeerCert},
nil,
nil,
nil,
true,
},
// Client: set clientGetRoot, clientVerifyFunc and clientCert
// Server: set serverRoot and serverCert with mutual TLS on
// Expected Behavior: success
{
"Client_peer_cert_reload_root_verifyFuncGood_Server_peer_cert_root_cert_mutualTLS",
[]tls.Certificate{clientPeerCert},
nil,
nil,
getRootCAsForClient,
verifyFuncGood,
false,
false,
true,
[]tls.Certificate{serverPeerCert},
nil,
serverTrustPool,
nil,
false,
},
// Client: set clientGetRoot, clientVerifyFunc and clientCert
// Server: set serverGetRoot and serverCert with mutual TLS on
// Expected Behavior: success
{
"Client_peer_cert_reload_root_verifyFuncGood_Server_peer_cert_reload_root_mutualTLS",
[]tls.Certificate{clientPeerCert},
nil,
nil,
getRootCAsForClient,
verifyFuncGood,
false,
false,
true,
[]tls.Certificate{serverPeerCert},
nil,
nil,
getRootCAsForServer,
false,
},
// Client: set clientGetRoot, clientVerifyFunc and clientCert
// Server: set serverGetRoot returning error and serverCert with mutual TLS on
// Expected Behavior: server side failure
// Reason: server side reloading returns failure
{
"Client_peer_cert_reload_root_verifyFuncGood_Server_peer_cert_bad_reload_root_mutualTLS",
[]tls.Certificate{clientPeerCert},
nil,
nil,
getRootCAsForClient,
verifyFuncGood,
false,
false,
true,
[]tls.Certificate{serverPeerCert},
nil,
nil,
getRootCAsForServerBad,
true,
},
// Client: set clientGetRoot, clientVerifyFunc and clientGetClientCert
// Server: set serverGetRoot and serverGetCert with mutual TLS on
// Expected Behavior: success
{
"Client_reload_both_certs_verifyFuncGood_Server_reload_both_certs_mutualTLS",
nil,
func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
},
nil,
getRootCAsForClient,
verifyFuncGood,
false,
false,
true,
nil,
func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
},
nil,
getRootCAsForServer,
false,
},
// Client: set everything but with the wrong peer cert not trusted by server
// (the client presents serverPeerCert as its own certificate)
// Server: set serverGetRoot and serverGetCert with mutual TLS on
// Expected Behavior: server side returns failure because of
// certificate mismatch
{
"Client_wrong_peer_cert_Server_reload_both_certs_mutualTLS",
nil,
func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
},
nil,
getRootCAsForClient,
verifyFuncGood,
false,
false,
true,
nil,
func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
},
nil,
getRootCAsForServer,
true,
},
// Client: set everything but with the wrong trust cert not trusting server
// (the client uses getRootCAsForServer as its root provider)
// Server: set serverGetRoot and serverGetCert with mutual TLS on
// Expected Behavior: server side and client side return failure due to
// certificate mismatch and handshake failure
{
"Client_wrong_trust_cert_Server_reload_both_certs_mutualTLS",
nil,
func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
},
nil,
getRootCAsForServer,
verifyFuncGood,
false,
true,
true,
nil,
func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
},
nil,
getRootCAsForServer,
true,
},
// Client: set clientGetRoot, clientVerifyFunc and clientGetClientCert
// Server: set everything but with the wrong peer cert not trusted by client
// (the server presents clientPeerCert as its own certificate)
// Expected Behavior: server side failure. clientExpectHandshakeError is
// false here, but it is never consulted because the runner returns as soon
// as the expected server failure is observed.
{
"Client_reload_both_certs_verifyFuncGood_Server_wrong_peer_cert",
nil,
func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
},
nil,
getRootCAsForClient,
verifyFuncGood,
false,
false,
true,
nil,
func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
},
nil,
getRootCAsForServer,
true,
},
// Client: set clientGetRoot, clientVerifyFunc and clientGetClientCert
// Server: set everything but with the wrong trust cert not trusting client
// (the server uses getRootCAsForClient as its root provider)
// Expected Behavior: server side and client side return failure due to
// certificate mismatch and handshake failure
{
"Client_reload_both_certs_verifyFuncGood_Server_wrong_trust_cert",
nil,
func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
},
nil,
getRootCAsForClient,
verifyFuncGood,
false,
true,
true,
nil,
func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
},
nil,
getRootCAsForClient,
true,
},
} {
// Re-bind the loop variable so the subtest closure captures this case.
test := test
t.Run(test.desc, func(t *testing.T) {
// done carries the server's result: one AuthInfo value on success, or a
// close with no value on any server-side failure. The buffer of one lets
// the server goroutine finish even if the client side returns early.
done := make(chan credentials.AuthInfo, 1)
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
// Start a server using ServerOptions in another goroutine.
serverOptions := &ServerOptions{
Certificates: test.serverCert,
GetCertificate: test.serverGetCert,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: test.serverRoot,
GetRootCAs: test.serverGetRoot,
},
RequireClientCert: test.serverMutualTLS,
}
go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) {
serverRawConn, err := lis.Accept()
if err != nil {
close(done)
return
}
// Credentials are only created after a connection is accepted, so a
// creation failure is reported through done like a handshake failure.
serverTLS, err := NewServerCreds(serverOptions)
if err != nil {
serverRawConn.Close()
close(done)
return
}
_, serverAuthInfo, err := serverTLS.ServerHandshake(serverRawConn)
if err != nil {
serverRawConn.Close()
close(done)
return
}
done <- serverAuthInfo
}(done, lis, serverOptions)
defer lis.Close()
// Start a client using ClientOptions and connects to the server.
lisAddr := lis.Addr().String()
conn, err := net.Dial("tcp", lisAddr)
if err != nil {
t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
}
defer conn.Close()
clientOptions := &ClientOptions{
Certificates: test.clientCert,
GetClientCertificate: test.clientGetClientCert,
VerifyPeer: test.clientVerifyFunc,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: test.clientRoot,
GetRootCAs: test.clientGetRoot,
},
}
clientTLS, newClientErr := NewClientCreds(clientOptions)
// An expected creation failure ends the case; the server outcome is not
// examined in that situation.
if newClientErr != nil && test.clientExpectCreateError {
return
}
// Any remaining disagreement between the error and the expectation fails.
if newClientErr != nil && !test.clientExpectCreateError ||
newClientErr == nil && test.clientExpectCreateError {
t.Fatalf("Expect error: %v, but err is %v",
test.clientExpectCreateError, newClientErr)
}
_, clientAuthInfo, handshakeErr := clientTLS.ClientHandshake(context.Background(),
lisAddr, conn)
// wait until server sends serverAuthInfo or fails.
// ok is false when done was closed, i.e. the server side failed.
serverAuthInfo, ok := <-done
// An expected server failure ends the case before the client handshake
// result is compared with clientExpectHandshakeError.
if !ok && test.serverExpectError {
return
}
if ok && test.serverExpectError || !ok && !test.serverExpectError {
t.Fatalf("Server side error mismatch, got %v, want %v", !ok, test.serverExpectError)
}
if handshakeErr != nil && test.clientExpectHandshakeError {
return
}
if handshakeErr != nil && !test.clientExpectHandshakeError ||
handshakeErr == nil && test.clientExpectHandshakeError {
t.Fatalf("Expect error: %v, but err is %v",
test.clientExpectHandshakeError, handshakeErr)
}
// Both handshakes succeeded: the two sides must agree on the session.
if !compare(clientAuthInfo, serverAuthInfo) {
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr,
clientAuthInfo, serverAuthInfo)
}
})
}
}
// readTrustCert reads fileName, decodes the first PEM block it contains,
// parses that block as an X.509 certificate and returns a new certificate pool
// holding only that certificate.
//
// It returns a non-nil error when the file cannot be read, when the file
// contains no PEM block, or when the block is not a parsable certificate. Any
// PEM blocks after the first one are ignored.
func readTrustCert(fileName string) (*x509.CertPool, error) {
	trustData, err := ioutil.ReadFile(fileName)
	if err != nil {
		return nil, err
	}
	trustBlock, _ := pem.Decode(trustData)
	if trustBlock == nil {
		// pem.Decode signals "no PEM data" with a nil block rather than an
		// error, and err is nil at this point. Returning it would hand callers
		// a nil pool together with a nil error, so build a real error instead.
		return nil, fmt.Errorf("no PEM data found in %q", fileName)
	}
	trustCert, err := x509.ParseCertificate(trustBlock.Bytes)
	if err != nil {
		return nil, err
	}
	trustPool := x509.NewCertPool()
	trustPool.AddCert(trustCert)
	return trustPool, nil
}
// compare reports whether two AuthInfo values describe the same TLS session
// parameters. Both values must report the auth type "tls"; any other auth type,
// or a mismatch between the two auth types, yields false. For TLS, the
// negotiated version, handshake-complete flag, cipher suite and negotiated
// protocol of the two connection states must all be equal.
func compare(a1, a2 credentials.AuthInfo) bool {
	if a1.AuthType() != a2.AuthType() {
		return false
	}
	if a1.AuthType() != "tls" {
		return false
	}
	left := a1.(credentials.TLSInfo).State
	right := a2.(credentials.TLSInfo).State
	return left.Version == right.Version &&
		left.HandshakeComplete == right.HandshakeComplete &&
		left.CipherSuite == right.CipherSuite &&
		left.NegotiatedProtocol == right.NegotiatedProtocol
}
// TestAdvancedTLSOverrideServerName builds client credentials with a server
// name override, calls OverrideServerName with the same name, and checks that
// Info().ServerName reports that name.
func TestAdvancedTLSOverrideServerName(t *testing.T) {
	wantName := "server.name"
	pool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
	if err != nil {
		t.Fatalf("Client is unable to load trust certs. Error: %v", err)
	}
	creds, err := NewClientCreds(&ClientOptions{
		RootCertificateOptions: RootCertificateOptions{
			RootCACerts: pool,
		},
		ServerNameOverride: wantName,
	})
	if err != nil {
		t.Fatalf("Client is unable to create credentials. Error: %v", err)
	}
	creds.OverrideServerName(wantName)
	if got := creds.Info().ServerName; got != wantName {
		t.Fatalf("c.Info().ServerName = %v, want %v", got, wantName)
	}
}
// TestTLSClone checks that a clone of the client credentials starts with the
// same server name as the original, and that overriding the server name on the
// clone leaves the original's server name untouched.
func TestTLSClone(t *testing.T) {
	wantName := "server.name"
	pool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
	if err != nil {
		t.Fatalf("Client is unable to load trust certs. Error: %v", err)
	}
	original, err := NewClientCreds(&ClientOptions{
		RootCertificateOptions: RootCertificateOptions{
			RootCACerts: pool,
		},
		ServerNameOverride: wantName,
	})
	if err != nil {
		t.Fatalf("Failed to create new client: %v", err)
	}
	clone := original.Clone()
	if got := clone.Info().ServerName; got != wantName {
		t.Fatalf("cc.Info().ServerName = %v, want %v", got, wantName)
	}
	// Mutate only the clone, then confirm the original is unaffected.
	clone.OverrideServerName("")
	if got := original.Info().ServerName; got != wantName {
		t.Fatalf("Change in clone should not affect the original, "+
			"c.Info().ServerName = %v, want %v", got, wantName)
	}
}
// TestAppendH2ToNextProtos runs appendH2ToNextProtos over a small table of
// protocol lists and checks each result against the expected list: "h2" is
// expected to be appended when absent and the list left as-is when present.
func TestAppendH2ToNextProtos(t *testing.T) {
	for _, tc := range []struct {
		name string
		ps   []string
		want []string
	}{
		{name: "empty", ps: nil, want: []string{"h2"}},
		{name: "only h2", ps: []string{"h2"}, want: []string{"h2"}},
		{name: "with h2", ps: []string{"alpn", "h2"}, want: []string{"alpn", "h2"}},
		{name: "no h2", ps: []string{"alpn"}, want: []string{"alpn", "h2"}},
	} {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			got := appendH2ToNextProtos(tc.ps)
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("appendH2ToNextProtos() = %v, want %v", got, tc.want)
		})
	}
}
// nonSyscallConn is a test-only connection type that embeds net.Conn and adds
// nothing else. TestWrapSyscallConn passes it to WrapSyscallConn as the second
// (new) connection argument.
type nonSyscallConn struct {
net.Conn
}
func TestWrapSyscallConn(t *testing.T) {
sc := &syscallConn{}
nsc := &nonSyscallConn{}
wrapConn := WrapSyscallConn(sc, nsc)
if _, ok := wrapConn.(syscall.Conn); !ok {
t.Errorf("returned conn (type %T) doesn't implement syscall.Conn, want implement",
wrapConn)
}
}