blob: 48e0bd2f1c3bf497c939421923946cf6ee301c1e [file] [log] [blame]
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package advancedtls
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"math/big"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/credentials/tls/certprovider"
"google.golang.org/grpc/security/advancedtls/internal/testutils"
"google.golang.org/grpc/security/advancedtls/testdata"
)
// TestNewPEMFileProvider checks which combinations of credential files
// NewPEMFileProvider accepts. A complete identity key/cert pair, a trust file,
// or both must be accepted. No files at all, or only one half of the key/cert
// pair, must be rejected. Every provider that is created successfully is
// closed before the subtest ends.
func (s) TestNewPEMFileProvider(t *testing.T) {
	tests := []struct {
		desc      string
		options   PEMFileProviderOptions
		wantError bool
	}{
		{
			desc:      "Expect error if no credential files specified",
			options:   PEMFileProviderOptions{},
			wantError: true,
		},
		{
			desc: "Expect error if only certFile is specified",
			options: PEMFileProviderOptions{
				CertFile: testdata.Path("client_cert_1.pem"),
			},
			wantError: true,
		},
		{
			// Mirror of the certFile-only case: half of an identity pair is
			// never enough on its own.
			desc: "Expect error if only keyFile is specified",
			options: PEMFileProviderOptions{
				KeyFile: testdata.Path("client_key_1.pem"),
			},
			wantError: true,
		},
		{
			desc: "Should be good if only identity key cert pairs are specified",
			options: PEMFileProviderOptions{
				KeyFile:  testdata.Path("client_key_1.pem"),
				CertFile: testdata.Path("client_cert_1.pem"),
			},
			wantError: false,
		},
		{
			desc: "Should be good if only root certs are specified",
			options: PEMFileProviderOptions{
				TrustFile: testdata.Path("client_trust_cert_1.pem"),
			},
			wantError: false,
		},
		{
			desc: "Should be good if both identity pairs and root certs are specified",
			options: PEMFileProviderOptions{
				KeyFile:   testdata.Path("client_key_1.pem"),
				CertFile:  testdata.Path("client_cert_1.pem"),
				TrustFile: testdata.Path("client_trust_cert_1.pem"),
			},
			wantError: false,
		},
	}
	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			provider, err := NewPEMFileProvider(test.options)
			if gotError := err != nil; gotError != test.wantError {
				t.Fatalf("NewPEMFileProvider(%+v) returned error %v, wantError %v", test.options, err, test.wantError)
			}
			if err != nil {
				return
			}
			provider.Close()
		})
	}
}
// TestWatchingRoutineUpdates replaces the package-level credential reading
// functions (readKeyCertPairFunc and readTrustCertFunc) used by the provider's
// watching goroutine. It then checks the key material the provider exposes
// across three stages:
//   - Stage 0: the reading functions return ClientCert1 and ServerTrust1, and
//     the provider is expected to pick them up.
//   - Stage 1: the reading functions return errors. The provider is expected
//     to keep serving the stage 0 credentials unchanged.
//   - Stage 2: the reading functions return ClientCert2 and ServerTrust2, and
//     the provider is expected to pick up the new credentials.
func (s) TestWatchingRoutineUpdates(t *testing.T) {
	const (
		// testInterval is the refresh interval configured on the provider.
		testInterval = 200 * time.Millisecond
		// refreshWait is how long each stage sleeps before reading the key
		// material. It spans several refresh intervals so the watching
		// goroutine has a chance to observe the stage change.
		refreshWait = 1 * time.Second
		// keyMaterialTimeout bounds every provider.KeyMaterial call.
		// KeyMaterial takes a context, presumably because it may block
		// waiting for material. With a deadline, a regression fails the
		// test instead of hanging it.
		keyMaterialTimeout = 5 * time.Second
	)

	// Load certificates.
	cs := &testutils.CertStore{}
	if err := cs.LoadCerts(); err != nil {
		t.Fatalf("cs.LoadCerts() failed, err: %v", err)
	}
	// big.Int and x509.CertPool have unexported fields that cmp must be
	// allowed to inspect when comparing certificates and root pools.
	cmpOpt := cmp.AllowUnexported(big.Int{}, x509.CertPool{})

	tests := []struct {
		desc         string
		options      PEMFileProviderOptions
		wantKmStage0 certprovider.KeyMaterial
		wantKmStage1 certprovider.KeyMaterial
		wantKmStage2 certprovider.KeyMaterial
	}{
		{
			desc: "use identity certs and root certs",
			options: PEMFileProviderOptions{
				CertFile:  "not_empty_cert_file",
				KeyFile:   "not_empty_key_file",
				TrustFile: "not_empty_trust_file",
			},
			wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1},
			wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}, Roots: cs.ServerTrust1},
			wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}, Roots: cs.ServerTrust2},
		},
		{
			desc: "use identity certs only",
			options: PEMFileProviderOptions{
				CertFile: "not_empty_cert_file",
				KeyFile:  "not_empty_key_file",
			},
			wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}},
			wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert1}},
			wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.ClientCert2}},
		},
		{
			desc: "use trust certs only",
			options: PEMFileProviderOptions{
				TrustFile: "not_empty_trust_file",
			},
			wantKmStage0: certprovider.KeyMaterial{Roots: cs.ServerTrust1},
			wantKmStage1: certprovider.KeyMaterial{Roots: cs.ServerTrust1},
			wantKmStage2: certprovider.KeyMaterial{Roots: cs.ServerTrust2},
		},
	}
	for _, test := range tests {
		// test is a per-iteration copy, so these assignments do not modify
		// the tests table itself.
		test.options.IdentityInterval = testInterval
		test.options.RootInterval = testInterval
		t.Run(test.desc, func(t *testing.T) {
			stage := &stageInfo{}

			// Swap in fake readers whose results depend on the current stage.
			// The originals are restored when the subtest ends.
			oldReadKeyCertPairFunc := readKeyCertPairFunc
			readKeyCertPairFunc = func(certFile, keyFile string) (tls.Certificate, error) {
				switch stage.read() {
				case 0:
					return cs.ClientCert1, nil
				case 1:
					return tls.Certificate{}, fmt.Errorf("error occurred while reloading")
				case 2:
					return cs.ClientCert2, nil
				default:
					return tls.Certificate{}, fmt.Errorf("test stage not supported")
				}
			}
			defer func() {
				readKeyCertPairFunc = oldReadKeyCertPairFunc
			}()
			oldReadTrustCertFunc := readTrustCertFunc
			readTrustCertFunc = func(trustFile string) (*x509.CertPool, error) {
				switch stage.read() {
				case 0:
					return cs.ServerTrust1, nil
				case 1:
					return nil, fmt.Errorf("error occurred while reloading")
				case 2:
					return cs.ServerTrust2, nil
				default:
					return nil, fmt.Errorf("test stage not supported")
				}
			}
			defer func() {
				readTrustCertFunc = oldReadTrustCertFunc
			}()

			provider, err := NewPEMFileProvider(test.options)
			if err != nil {
				t.Fatalf("NewPEMFileProvider failed: %v", err)
			}
			defer provider.Close()

			// checkKeyMaterial waits for the watching goroutine to refresh,
			// then fails the test unless the provider's current key material
			// equals want. stageName only labels failure messages.
			checkKeyMaterial := func(stageName string, want certprovider.KeyMaterial) {
				t.Helper()
				time.Sleep(refreshWait)
				ctx, cancel := context.WithTimeout(context.Background(), keyMaterialTimeout)
				defer cancel()
				km, err := provider.KeyMaterial(ctx)
				if err != nil {
					// Fail before dereferencing km, which may be nil on error.
					t.Fatalf("%s: provider.KeyMaterial() failed: %v", stageName, err)
				}
				if !cmp.Equal(*km, want, cmpOpt) {
					t.Fatalf("%s: provider.KeyMaterial() = %+v, want %+v", stageName, *km, want)
				}
			}

			// Stage 0: initial credentials are picked up.
			checkKeyMaterial("stage 0", test.wantKmStage0)
			// Stage 1: reload errors leave the stage 0 credentials in place.
			stage.increase()
			checkKeyMaterial("stage 1", test.wantKmStage1)
			// Stage 2: the new credentials are picked up.
			stage.increase()
			checkKeyMaterial("stage 2", test.wantKmStage2)
			stage.reset()
		})
	}
}