blob: cc8b08e5d933dc1d9340cb53c24a0c99233b42f6 [file] [log] [blame]
// +build go1.13
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package sts
import (
"bytes"
"context"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
)
// Fixture values shared by the tests in this file.
//
// requestedTokenType uses the RFC 8693 spelling of the access token
// identifier ("access_token", with an underscore), which is the same
// identifier the fake token exchange responses in this file report as their
// issued token type.
//
// actorTokenPath and subjectTokenPath intentionally share a value: the tests
// replace the token readers with fakes, so the path is never opened.
const (
	requestedTokenType   = "urn:ietf:params:oauth:token-type:access_token"
	actorTokenPath       = "/var/run/secrets/token.jwt"
	actorTokenType       = "urn:ietf:params:oauth:token-type:refresh_token"
	actorTokenContents   = "actorToken.jwt.contents"
	accessTokenContents  = "access_token"
	subjectTokenPath     = "/var/run/secrets/token.jwt"
	subjectTokenType     = "urn:ietf:params:oauth:token-type:id_token"
	subjectTokenContents = "subjectToken.jwt.contents"
	serviceURI           = "http://localhost"
	exampleResource      = "https://backend.example.com/api"
	exampleAudience      = "example-backend-service"
	testScope            = "https://www.googleapis.com/auth/monitoring"
	defaultTestTimeout   = 1 * time.Second
)
// goodOptions is a set of options that the tests in this file expect
// NewCredentials and validateOptions to accept.
var goodOptions = Options{
	TokenExchangeServiceURI: serviceURI,
	Audience:                exampleAudience,
	RequestedTokenType:      requestedTokenType,
	SubjectTokenPath:        subjectTokenPath,
	SubjectTokenType:        subjectTokenType,
}

// goodRequestParams are the request parameters the tests expect to be sent
// when credentials are built from goodOptions and the subject token read
// returns subjectTokenContents. goodOptions sets no scope, so the expected
// scope is defaultCloudPlatformScope.
var goodRequestParams = &requestParameters{
	GrantType:          tokenExchangeGrantType,
	Audience:           exampleAudience,
	Scope:              defaultCloudPlatformScope,
	RequestedTokenType: requestedTokenType,
	SubjectToken:       subjectTokenContents,
	SubjectTokenType:   subjectTokenType,
}

// goodMetadata is the request metadata the tests expect back after a
// successful token exchange that yields accessTokenContents.
var goodMetadata = map[string]string{
	"Authorization": "Bearer " + accessTokenContents,
}
// s is the test-suite type for this file: Test passes an s value to
// grpctest.RunSubTests, and the individual tests are declared as methods on s.
// It embeds grpctest.Tester.
type s struct {
grpctest.Tester
}
// Test is the only top-level test function in this file. It hands an s value
// to grpctest.RunSubTests, which presumably runs each Test* method of s as a
// subtest (see the grpctest package to confirm).
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
// testAuthInfo is the AuthInfo that createTestContext attaches to the context
// passed to GetRequestMetadata by the tests. It embeds
// credentials.CommonAuthInfo, whose SecurityLevel field is set by
// createTestContext.
type testAuthInfo struct {
credentials.CommonAuthInfo
}
// AuthType returns the fixed string "testAuthInfo".
func (ta testAuthInfo) AuthType() string {
return "testAuthInfo"
}
// createTestContext returns a context derived from ctx that carries a
// credentials.RequestInfo with method "testInfo" and a testAuthInfo whose
// security level is s.
func createTestContext(ctx context.Context, s credentials.SecurityLevel) context.Context {
	// internal.NewRequestInfoContext is stored as an untyped hook; assert it
	// to its concrete function type before calling it.
	newRequestInfoContext := internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)
	return newRequestInfoContext(ctx, credentials.RequestInfo{
		Method: "testInfo",
		AuthInfo: &testAuthInfo{
			CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: s},
		},
	})
}
// errReader is an io.Reader whose Read method always fails. Tests use it to
// simulate an unreadable response body.
type errReader struct{}

// Read never fills the buffer; it reports zero bytes read together with a
// "read error" error.
func (errReader) Read([]byte) (int, error) {
	return 0, errors.New("read error")
}
// makeGoodResponse builds a fresh "200 OK" response whose body is the JSON
// encoding of a responseParameters value carrying accessTokenContents.
//
// This is a function rather than a package-level variable because the
// response body is consumed by the credentials code, so every use needs a new
// response.
func makeGoodResponse() *http.Response {
	params := responseParameters{
		AccessToken:     accessTokenContents,
		IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
		TokenType:       "Bearer",
		ExpiresIn:       3600,
	}
	// The marshal error is ignored, exactly as before.
	body, _ := json.Marshal(params)
	return &http.Response{
		Status:     "200 OK",
		StatusCode: http.StatusOK,
		Body:       ioutil.NopCloser(bytes.NewReader(body)),
	}
}
// fakeHTTPDoer helps mock out the http.Client.Do calls made by the credentials
// code under test. It makes the http.Request made by the credentials available
// through a channel, and makes it possible to inject various responses.
type fakeHTTPDoer struct {
// reqCh receives every request passed to Do.
reqCh *testutils.Channel
// respCh supplies the response that Do returns; tests push a
// *http.Response onto it.
respCh *testutils.Channel
// err is returned by Do alongside the response taken from respCh.
err error
}
// Do records req on reqCh and then waits, for at most defaultTestTimeout, for
// a response to be available on respCh.
//
// If no response arrives in time, the error from the receive is returned with
// a nil response. Otherwise the received *http.Response is returned together
// with the doer's configured err.
func (fc *fakeHTTPDoer) Do(req *http.Request) (*http.Response, error) {
	fc.reqCh.Send(req)

	waitCtx, stop := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer stop()

	queued, recvErr := fc.respCh.Receive(waitCtx)
	if recvErr != nil {
		return nil, recvErr
	}
	resp := queued.(*http.Response)
	return resp, fc.err
}
// overrideHTTPClientGood installs a fakeHTTPDoer that already has one good
// response queued as the doer returned by makeHTTPDoer. It returns the fake
// and a function that restores the previous makeHTTPDoer.
func overrideHTTPClientGood() (*fakeHTTPDoer, func()) {
	doer := &fakeHTTPDoer{
		reqCh:  testutils.NewChannel(),
		respCh: testutils.NewChannel(),
	}
	doer.respCh.Send(makeGoodResponse())
	// overrideHTTPClient performs the same save/replace/restore sequence.
	return doer, overrideHTTPClient(doer)
}
// overrideHTTPClient makes makeHTTPDoer return fc regardless of the cert pool
// it is given. The returned function restores the previous makeHTTPDoer.
func overrideHTTPClient(fc *fakeHTTPDoer) func() {
	saved := makeHTTPDoer
	restore := func() { makeHTTPDoer = saved }
	makeHTTPDoer = func(*x509.CertPool) httpDoer { return fc }
	return restore
}
// overrideSubjectTokenGood replaces readSubjectTokenFrom with a fake that
// ignores the path and returns subjectTokenContents, a constant the tests can
// compare against. The returned function restores the previous reader.
func overrideSubjectTokenGood() func() {
	saved := readSubjectTokenFrom
	restore := func() { readSubjectTokenFrom = saved }
	readSubjectTokenFrom = func(string) ([]byte, error) {
		return []byte(subjectTokenContents), nil
	}
	return restore
}
// overrideSubjectTokenError replaces readSubjectTokenFrom with a fake that
// fails for every path. The returned function restores the previous reader.
func overrideSubjectTokenError() func() {
	saved := readSubjectTokenFrom
	restore := func() { readSubjectTokenFrom = saved }
	readSubjectTokenFrom = func(string) ([]byte, error) {
		return nil, errors.New("error reading subject token")
	}
	return restore
}
// overrideActorTokenGood replaces readActorTokenFrom with a fake that ignores
// the path and returns actorTokenContents, a constant the tests can compare
// against. The returned function restores the previous reader.
func overrideActorTokenGood() func() {
	saved := readActorTokenFrom
	restore := func() { readActorTokenFrom = saved }
	readActorTokenFrom = func(string) ([]byte, error) {
		return []byte(actorTokenContents), nil
	}
	return restore
}
// overrideActorTokenError replaces readActorTokenFrom with a fake that fails
// for every path. The returned function restores the previous reader.
func overrideActorTokenError() func() {
	saved := readActorTokenFrom
	restore := func() { readActorTokenFrom = saved }
	readActorTokenFrom = func(string) ([]byte, error) {
		return nil, errors.New("error reading actor token")
	}
	return restore
}
// compareRequest checks gotRequest against the request that is expected for
// wantReqParams.
//
// The expected request is a POST to serviceURI with a JSON-encoded body and a
// JSON Content-Type header. Both requests are rendered with
// httputil.DumpRequestOut (bodies included) and the resulting text is diffed.
// A nil return means the dumps are identical; otherwise the error carries the
// diff or the failure that prevented the comparison.
func compareRequest(gotRequest *http.Request, wantReqParams *requestParameters) error {
	wantBody, err := json.Marshal(wantReqParams)
	if err != nil {
		return err
	}

	wantRequest, err := http.NewRequest("POST", serviceURI, bytes.NewBuffer(wantBody))
	if err != nil {
		return fmt.Errorf("failed to create http request: %v", err)
	}
	wantRequest.Header.Set("Content-Type", "application/json")

	wantDump, err := httputil.DumpRequestOut(wantRequest, true)
	if err != nil {
		return err
	}
	gotDump, err := httputil.DumpRequestOut(gotRequest, true)
	if err != nil {
		return err
	}

	diff := cmp.Diff(string(wantDump), string(gotDump))
	if diff == "" {
		return nil
	}
	return fmt.Errorf("sts request diff (-want +got):\n%s", diff)
}
// receiveAndCompareRequest waits up to defaultTestTimeout for a request to
// show up on reqCh and compares it with goodRequestParams.
//
// It is meant to run in its own goroutine, so the outcome is reported on errCh
// rather than returned: exactly one value is sent, which is nil on success and
// the receive or comparison error otherwise.
func receiveAndCompareRequest(reqCh *testutils.Channel, errCh chan error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()

	got, err := reqCh.Receive(ctx)
	if err == nil {
		err = compareRequest(got.(*http.Request), goodRequestParams)
	}
	errCh <- err
}
// TestGetRequestMetadataSuccess verifies the successful case of sending a
// token exchange request and processing the response, and that a subsequent
// call is answered without a second exchange.
func (s) TestGetRequestMetadataSuccess(t *testing.T) {
	defer overrideSubjectTokenGood()()
	fc, restoreClient := overrideHTTPClientGood()
	defer restoreClient()

	creds, err := NewCredentials(goodOptions)
	if err != nil {
		t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
	}

	// checkMetadata asks the credentials for request metadata over a
	// PrivacyAndIntegrity context and fails the test unless goodMetadata
	// comes back.
	checkMetadata := func() {
		t.Helper()
		ctx := createTestContext(context.Background(), credentials.PrivacyAndIntegrity)
		md, err := creds.GetRequestMetadata(ctx, "")
		if err != nil {
			t.Fatalf("creds.GetRequestMetadata() = %v", err)
		}
		if !cmp.Equal(md, goodMetadata) {
			t.Fatalf("creds.GetRequestMetadata() = %v, want %v", md, goodMetadata)
		}
	}

	// First call: a token exchange request must go out and match
	// goodRequestParams.
	errCh := make(chan error, 1)
	go receiveAndCompareRequest(fc.reqCh, errCh)
	checkMetadata()
	if err := <-errCh; err != nil {
		t.Fatal(err)
	}

	// Second call: this should be served from the cache. The fake client has
	// no further response queued, so a fresh request here would make the
	// call fail.
	checkMetadata()
}
// TestGetRequestMetadataBadSecurityLevel verifies that GetRequestMetadata
// fails when the context passed to it carries a security level
// (IntegrityOnly) that is not sufficient.
func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) {
	defer overrideSubjectTokenGood()()

	creds, err := NewCredentials(goodOptions)
	if err != nil {
		t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
	}

	ctx := createTestContext(context.Background(), credentials.IntegrityOnly)
	if gotMetadata, err := creds.GetRequestMetadata(ctx, ""); err == nil {
		t.Fatalf("creds.GetRequestMetadata() succeeded with metadata %v, expected to fail", gotMetadata)
	}
}
// TestGetRequestMetadataCacheExpiry verifies the case where the cached access
// token has expired, and the credentials implementation has to send a fresh
// token exchange request.
func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) {
	const (
		expiresInSecs = 1
		numCalls      = 2
	)
	defer overrideSubjectTokenGood()()
	fc := &fakeHTTPDoer{
		reqCh:  testutils.NewChannel(),
		respCh: testutils.NewChannel(),
	}
	defer overrideHTTPClient(fc)()

	creds, err := NewCredentials(goodOptions)
	if err != nil {
		t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
	}

	// The fakeClient is configured to return an access_token with a one second
	// expiry. So, in the second iteration, the credentials will find the cache
	// entry, but that would have expired, and therefore we expect it to send
	// out a fresh request.
	for i := 0; i < numCalls; i++ {
		errCh := make(chan error, 1)
		go receiveAndCompareRequest(fc.reqCh, errCh)

		respJSON, err := json.Marshal(responseParameters{
			AccessToken:     accessTokenContents,
			IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
			TokenType:       "Bearer",
			ExpiresIn:       expiresInSecs,
		})
		if err != nil {
			t.Fatalf("json.Marshal() failed: %v", err)
		}
		fc.respCh.Send(&http.Response{
			Status:     "200 OK",
			StatusCode: http.StatusOK,
			Body:       ioutil.NopCloser(bytes.NewReader(respJSON)),
		})

		gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "")
		if err != nil {
			t.Fatalf("creds.GetRequestMetadata() = %v", err)
		}
		if !cmp.Equal(gotMetadata, goodMetadata) {
			t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
		}
		if err := <-errCh; err != nil {
			t.Fatal(err)
		}

		// Wait for the cached token to expire before the next call. Nothing
		// follows the final call, so sleeping there would only slow the test.
		if i < numCalls-1 {
			time.Sleep(expiresInSecs * time.Second)
		}
	}
}
// TestGetRequestMetadataBadResponses verifies that GetRequestMetadata fails
// when the token exchange server answers "200 OK" with an unusable body, while
// the outgoing request is still the expected one.
func (s) TestGetRequestMetadataBadResponses(t *testing.T) {
	tests := []struct {
		name string
		body string // body of the 200 OK response handed to the credentials
	}{
		{name: "bad JSON", body: "not JSON"},
		{name: "no access token", body: "{}"},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			defer overrideSubjectTokenGood()()
			fc := &fakeHTTPDoer{
				reqCh:  testutils.NewChannel(),
				respCh: testutils.NewChannel(),
			}
			defer overrideHTTPClient(fc)()

			creds, err := NewCredentials(goodOptions)
			if err != nil {
				t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
			}

			errCh := make(chan error, 1)
			go receiveAndCompareRequest(fc.reqCh, errCh)
			fc.respCh.Send(&http.Response{
				Status:     "200 OK",
				StatusCode: http.StatusOK,
				Body:       ioutil.NopCloser(strings.NewReader(test.body)),
			})

			ctx := createTestContext(context.Background(), credentials.PrivacyAndIntegrity)
			if _, err := creds.GetRequestMetadata(ctx, ""); err == nil {
				t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
			}
			if err := <-errCh; err != nil {
				t.Fatal(err)
			}
		})
	}
}
// TestGetRequestMetadataBadSubjectTokenRead verifies the scenario where the
// attempt to read the subjectToken fails: GetRequestMetadata must return an
// error and no token exchange request may be sent.
func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) {
	defer overrideSubjectTokenError()()
	fc, cancel := overrideHTTPClientGood()
	defer cancel()

	creds, err := NewCredentials(goodOptions)
	if err != nil {
		t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
	}

	errCh := make(chan error, 1)
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
		defer cancel()
		// The only acceptable outcome is a timeout. A successful receive
		// means a request was sent and yields a nil error, so a fresh error
		// is built here rather than forwarding err, which would report
		// success in exactly the case this check exists to catch.
		if _, err := fc.reqCh.Receive(ctx); err != context.DeadlineExceeded {
			errCh <- fmt.Errorf("fc.reqCh.Receive() = %v, want %v: no request should be sent when the subject token cannot be read", err, context.DeadlineExceeded)
			return
		}
		errCh <- nil
	}()

	if _, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), ""); err == nil {
		t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
	}
	if err := <-errCh; err != nil {
		t.Fatal(err)
	}
}
// TestNewCredentials checks NewCredentials against invalid options, a failing
// system cert pool loader, and a good configuration. When creation succeeds,
// the credentials must report that they require transport security.
func (s) TestNewCredentials(t *testing.T) {
	tests := []struct {
		name           string
		opts           Options
		errSystemRoots bool // make loadSystemCertPool fail for this case
		wantErr        bool
	}{
		{
			name:    "invalid options - empty subjectTokenPath",
			opts:    Options{TokenExchangeServiceURI: serviceURI},
			wantErr: true,
		},
		{
			name:           "invalid system root certs",
			opts:           goodOptions,
			errSystemRoots: true,
			wantErr:        true,
		},
		{
			name: "good case",
			opts: goodOptions,
		},
	}

	// failingCertPool stands in for loadSystemCertPool in the
	// errSystemRoots case.
	failingCertPool := func() (*x509.CertPool, error) {
		return nil, errors.New("failed to load system cert pool")
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			if test.errSystemRoots {
				saved := loadSystemCertPool
				loadSystemCertPool = failingCertPool
				defer func() { loadSystemCertPool = saved }()
			}

			creds, err := NewCredentials(test.opts)
			failed := err != nil
			if failed != test.wantErr {
				t.Fatalf("NewCredentials(%v) = %v, want %v", test.opts, err, test.wantErr)
			}
			if failed {
				return
			}
			if !creds.RequireTransportSecurity() {
				t.Errorf("creds.RequireTransportSecurity() returned false")
			}
		})
	}
}
// TestValidateOptions runs validateOptions over a table of option sets. A
// case with an empty wantErrPrefix must validate cleanly; any other case must
// fail with an error whose text contains wantErrPrefix.
func (s) TestValidateOptions(t *testing.T) {
	tests := []struct {
		name          string
		opts          Options
		wantErrPrefix string // empty means validation must succeed
	}{
		{
			name:          "empty token exchange service URI",
			opts:          Options{},
			wantErrPrefix: "empty token_exchange_service_uri in options",
		},
		{
			name:          "invalid URI",
			opts:          Options{TokenExchangeServiceURI: "\tI'm a bad URI\n"},
			wantErrPrefix: "invalid control character in URL",
		},
		{
			name:          "unsupported scheme",
			opts:          Options{TokenExchangeServiceURI: "unix:///path/to/socket"},
			wantErrPrefix: "scheme is not supported",
		},
		{
			name:          "empty subjectTokenPath",
			opts:          Options{TokenExchangeServiceURI: serviceURI},
			wantErrPrefix: "required field SubjectTokenPath is not specified",
		},
		{
			name: "empty subjectTokenType",
			opts: Options{
				TokenExchangeServiceURI: serviceURI,
				SubjectTokenPath:        subjectTokenPath,
			},
			wantErrPrefix: "required field SubjectTokenType is not specified",
		},
		{
			name: "good options",
			opts: goodOptions,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			err := validateOptions(test.opts)

			// The two failure modes cannot both hold for one result, so a
			// single report covers either of them.
			gotErr, wantErr := err != nil, test.wantErrPrefix != ""
			wrongOutcome := gotErr != wantErr
			wrongMessage := gotErr && !strings.Contains(err.Error(), test.wantErrPrefix)
			if wrongOutcome || wrongMessage {
				t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
			}
		})
	}
}
// TestConstructRequest exercises constructRequest with faked subject and actor
// token readers. Failing readers must make it return an error; otherwise the
// constructed request is compared against the expected request parameters.
func (s) TestConstructRequest(t *testing.T) {
	tests := []struct {
		name                string
		opts                Options
		subjectTokenReadErr bool // subject token reader fails
		actorTokenReadErr   bool // actor token reader fails
		wantReqParams       *requestParameters
		wantErr             bool
	}{
		{
			name:                "subject token read failure",
			opts:                goodOptions,
			subjectTokenReadErr: true,
			wantErr:             true,
		},
		{
			name: "actor token read failure",
			opts: Options{
				TokenExchangeServiceURI: serviceURI,
				Audience:                exampleAudience,
				RequestedTokenType:      requestedTokenType,
				SubjectTokenPath:        subjectTokenPath,
				SubjectTokenType:        subjectTokenType,
				ActorTokenPath:          actorTokenPath,
				ActorTokenType:          actorTokenType,
			},
			actorTokenReadErr: true,
			wantErr:           true,
		},
		{
			name:          "default cloud platform scope",
			opts:          goodOptions,
			wantReqParams: goodRequestParams,
		},
		{
			name: "all good",
			opts: Options{
				TokenExchangeServiceURI: serviceURI,
				Resource:                exampleResource,
				Audience:                exampleAudience,
				Scope:                   testScope,
				RequestedTokenType:      requestedTokenType,
				SubjectTokenPath:        subjectTokenPath,
				SubjectTokenType:        subjectTokenType,
				ActorTokenPath:          actorTokenPath,
				ActorTokenType:          actorTokenType,
			},
			wantReqParams: &requestParameters{
				GrantType:          tokenExchangeGrantType,
				Resource:           exampleResource,
				Audience:           exampleAudience,
				Scope:              testScope,
				RequestedTokenType: requestedTokenType,
				SubjectToken:       subjectTokenContents,
				SubjectTokenType:   subjectTokenType,
				ActorToken:         actorTokenContents,
				ActorTokenType:     actorTokenType,
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// Pick the fake reader for each token, then install it for the
			// duration of this subtest (subject first, then actor).
			overrideSubject := overrideSubjectTokenGood
			if test.subjectTokenReadErr {
				overrideSubject = overrideSubjectTokenError
			}
			defer overrideSubject()()

			overrideActor := overrideActorTokenGood
			if test.actorTokenReadErr {
				overrideActor = overrideActorTokenError
			}
			defer overrideActor()()

			gotRequest, err := constructRequest(context.Background(), test.opts)
			if gotErr := err != nil; gotErr != test.wantErr {
				t.Fatalf("constructRequest(%v) = %v, wantErr: %v", test.opts, err, test.wantErr)
			}
			if !test.wantErr {
				if err := compareRequest(gotRequest, test.wantReqParams); err != nil {
					t.Fatal(err)
				}
			}
		})
	}
}
// TestSendRequest feeds sendRequest a fake HTTP doer that returns, per case, a
// client error, an unreadable body, a non-OK status, or a good response, and
// checks whether sendRequest reports an error.
func (s) TestSendRequest(t *testing.T) {
	defer overrideSubjectTokenGood()()

	// One request, built from goodOptions, is reused by every subtest.
	req, err := constructRequest(context.Background(), goodOptions)
	if err != nil {
		t.Fatal(err)
	}

	tests := []struct {
		name    string
		resp    *http.Response // response handed back by the fake doer
		respErr error          // error handed back by the fake doer
		wantErr bool
	}{
		{
			name:    "client error",
			respErr: errors.New("http.Client.Do failed"),
			wantErr: true,
		},
		{
			name:    "bad response body",
			wantErr: true,
			resp: &http.Response{
				Status:     "200 OK",
				StatusCode: http.StatusOK,
				Body:       ioutil.NopCloser(errReader{}),
			},
		},
		{
			name:    "nonOK status code",
			wantErr: true,
			resp: &http.Response{
				Status:     "400 BadRequest",
				StatusCode: http.StatusBadRequest,
				Body:       ioutil.NopCloser(strings.NewReader("")),
			},
		},
		{
			name: "good case",
			resp: makeGoodResponse(),
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			doer := &fakeHTTPDoer{
				reqCh:  testutils.NewChannel(),
				respCh: testutils.NewChannel(),
				err:    test.respErr,
			}
			doer.respCh.Send(test.resp)

			_, err := sendRequest(doer, req)
			if gotErr := err != nil; gotErr != test.wantErr {
				t.Errorf("sendRequest(%v) = %v, wantErr: %v", req, err, test.wantErr)
			}
		})
	}
}
// TestTokenInfoFromResponse checks tokenInfoFromResponse against malformed,
// empty, token-less and well-formed response bodies. For the well-formed body
// only the token type and token are compared.
func (s) TestTokenInfoFromResponse(t *testing.T) {
	noAccessToken, err := json.Marshal(responseParameters{
		IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
		TokenType:       "Bearer",
		ExpiresIn:       3600,
	})
	if err != nil {
		t.Fatalf("json.Marshal() failed: %v", err)
	}
	goodResponse, err := json.Marshal(responseParameters{
		IssuedTokenType: requestedTokenType,
		AccessToken:     accessTokenContents,
		TokenType:       "Bearer",
		ExpiresIn:       3600,
	})
	if err != nil {
		t.Fatalf("json.Marshal() failed: %v", err)
	}

	tests := []struct {
		name          string
		respBody      []byte
		wantTokenInfo *tokenInfo
		wantErr       bool
	}{
		{
			name:     "bad JSON",
			respBody: []byte("not JSON"),
			wantErr:  true,
		},
		{
			name:     "empty response",
			respBody: []byte(""),
			wantErr:  true,
		},
		{
			name:     "non-empty response with no access token",
			respBody: noAccessToken,
			wantErr:  true,
		},
		{
			name:     "good response",
			respBody: goodResponse,
			wantTokenInfo: &tokenInfo{
				tokenType: "Bearer",
				token:     accessTokenContents,
			},
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// respBody is printed with %q so a failure shows the body as
			// text rather than as a list of byte values.
			gotTokenInfo, err := tokenInfoFromResponse(test.respBody)
			if (err != nil) != test.wantErr {
				t.Fatalf("tokenInfoFromResponse(%q) = %v, wantErr: %v", test.respBody, err, test.wantErr)
			}
			if test.wantErr {
				return
			}
			// Can't do a cmp.Equal on the whole struct since the expiryField
			// is populated based on time.Now().
			if gotTokenInfo.tokenType != test.wantTokenInfo.tokenType || gotTokenInfo.token != test.wantTokenInfo.token {
				t.Errorf("tokenInfoFromResponse(%q) = %+v, want: %+v", test.respBody, gotTokenInfo, test.wantTokenInfo)
			}
		})
	}
}