blob: 609d52b14933c41334edc805ab88de2b0ea691e3 [file] [log] [blame]
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
package jwt
import (
"encoding/base64"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/proto"
"github.com/google/tink/go/core/registry"
"github.com/google/tink/go/subtle/random"
jwtmacpb "github.com/google/tink/go/proto/jwt_hmac_go_proto"
tinkpb "github.com/google/tink/go/proto/tink_go_proto"
)
// jwtKeyManagerTestCase is one entry of the table-driven tests below. Tests
// that exercise NewKey/NewKeyData populate keyFormat; tests that exercise
// Primitive populate key.
type jwtKeyManagerTestCase struct {
// tag is the subtest name passed to t.Run.
tag string
// keyFormat is serialized and handed to km.NewKey / km.NewKeyData.
keyFormat *jwtmacpb.JwtHmacKeyFormat
// key is serialized and handed to km.Primitive.
key *jwtmacpb.JwtHmacKey
}
// typeURL is the type URL that the tests look up in the Tink registry to
// obtain the JWT HMAC key manager.
const typeURL = "type.googleapis.com/google.crypto.tink.JwtHmacKey"
// generateKeyFormat returns a JwtHmacKeyFormat that requests a key of keySize
// bytes for the given HMAC algorithm.
func generateKeyFormat(keySize uint32, algorithm jwtmacpb.JwtHmacAlgorithm) *jwtmacpb.JwtHmacKeyFormat {
	format := new(jwtmacpb.JwtHmacKeyFormat)
	format.Algorithm = algorithm
	format.KeySize = keySize
	return format
}
// TestDoesSupport checks that the key manager registered under typeURL reports
// support for that same type URL.
func TestDoesSupport(t *testing.T) {
	km, err := registry.GetKeyManager(typeURL)
	if err != nil {
		// Fatal, not Error: km cannot be used below if the lookup failed.
		t.Fatalf("registry.GetKeyManager(%q) error = %v, want nil", typeURL, err)
	}
	if !km.DoesSupport(typeURL) {
		t.Errorf("km.DoesSupport(%q) = false, want true", typeURL)
	}
}
// TestTypeURL checks that the key manager registered under typeURL reports
// that same value from TypeURL().
func TestTypeURL(t *testing.T) {
	km, err := registry.GetKeyManager(typeURL)
	if err != nil {
		// Fatal, not Error: km cannot be used below if the lookup failed.
		t.Fatalf("registry.GetKeyManager(%q) error = %v, want nil", typeURL, err)
	}
	if got := km.TypeURL(); got != typeURL {
		t.Errorf("km.TypeURL() = %q, want %q", got, typeURL)
	}
}
// invalidKeyFormatTestCases lists key formats that NewKey and NewKeyData are
// expected to reject: an unknown algorithm, key sizes one byte below the sizes
// used in validKeyFormatTestCases, an empty format and a nil format.
var invalidKeyFormatTestCases = []jwtKeyManagerTestCase{
	{
		tag:       "invalid hash algorithm",
		keyFormat: &jwtmacpb.JwtHmacKeyFormat{KeySize: 32, Algorithm: jwtmacpb.JwtHmacAlgorithm_HS_UNKNOWN},
	},
	{
		tag:       "invalid HS256 key size",
		keyFormat: &jwtmacpb.JwtHmacKeyFormat{KeySize: 31, Algorithm: jwtmacpb.JwtHmacAlgorithm_HS256},
	},
	{
		tag:       "invalid HS384 key size",
		keyFormat: &jwtmacpb.JwtHmacKeyFormat{KeySize: 47, Algorithm: jwtmacpb.JwtHmacAlgorithm_HS384},
	},
	{
		tag:       "invalid HS512 key size",
		keyFormat: &jwtmacpb.JwtHmacKeyFormat{KeySize: 63, Algorithm: jwtmacpb.JwtHmacAlgorithm_HS512},
	},
	{
		// All fields left at their zero values.
		tag:       "empty key format",
		keyFormat: &jwtmacpb.JwtHmacKeyFormat{},
	},
	{
		// A typed nil pointer, handed to proto.Marshal as-is.
		tag:       "nil key format",
		keyFormat: nil,
	},
}
func TestNewKeyInvalidFormatFails(t *testing.T) {
for _, tc := range invalidKeyFormatTestCases {
t.Run(tc.tag, func(t *testing.T) {
km, err := registry.GetKeyManager(typeURL)
if err != nil {
t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err)
}
serializedKeyFormat, err := proto.Marshal(tc.keyFormat)
if err != nil {
t.Errorf("serializing key format: %v", err)
}
if _, err := km.NewKey(serializedKeyFormat); err == nil {
t.Errorf("km.NewKey() err = nil, want error")
}
})
}
}
func TestNewDataInvalidFormatFails(t *testing.T) {
for _, tc := range invalidKeyFormatTestCases {
t.Run(tc.tag, func(t *testing.T) {
km, err := registry.GetKeyManager(typeURL)
if err != nil {
t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err)
}
serializedKeyFormat, err := proto.Marshal(tc.keyFormat)
if err != nil {
t.Errorf("serializing key format: %v", err)
}
if _, err := km.NewKeyData(serializedKeyFormat); err == nil {
t.Errorf("km.NewKey() err = nil, want error")
}
})
}
}
// validKeyFormatTestCases lists key formats that NewKey and NewKeyData are
// expected to accept. Each key size equals the output length of the
// algorithm's hash: 32, 48 and 64 bytes for SHA-256, SHA-384 and SHA-512.
var validKeyFormatTestCases = []jwtKeyManagerTestCase{
	{
		tag:       "SHA256 hash algorithm",
		keyFormat: &jwtmacpb.JwtHmacKeyFormat{KeySize: 32, Algorithm: jwtmacpb.JwtHmacAlgorithm_HS256},
	},
	{
		tag:       "SHA384 hash algorithm",
		keyFormat: &jwtmacpb.JwtHmacKeyFormat{KeySize: 48, Algorithm: jwtmacpb.JwtHmacAlgorithm_HS384},
	},
	{
		tag:       "SHA512 hash algorithm",
		keyFormat: &jwtmacpb.JwtHmacKeyFormat{KeySize: 64, Algorithm: jwtmacpb.JwtHmacAlgorithm_HS512},
	},
}
func TestNewKey(t *testing.T) {
for _, tc := range validKeyFormatTestCases {
t.Run(tc.tag, func(t *testing.T) {
km, err := registry.GetKeyManager(typeURL)
if err != nil {
t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err)
}
serializedKeyFormat, err := proto.Marshal(tc.keyFormat)
if err != nil {
t.Errorf("serializing key format: %v", err)
}
k, err := km.NewKey(serializedKeyFormat)
if err != nil {
t.Errorf("km.NewKey() err = %v, want nil", err)
}
key, ok := k.(*jwtmacpb.JwtHmacKey)
if !ok {
t.Errorf("key isn't of type JwtHmacKey")
}
if key.Algorithm != tc.keyFormat.Algorithm {
t.Errorf("k.Algorithm = %v, want %v", key.Algorithm, tc.keyFormat.Algorithm)
}
if len(key.KeyValue) != int(tc.keyFormat.KeySize) {
t.Errorf("len(key.KeyValue) = %d, want %d", len(key.KeyValue), tc.keyFormat.KeySize)
}
})
}
}
func TestNewKeyData(t *testing.T) {
for _, tc := range validKeyFormatTestCases {
t.Run(tc.tag, func(t *testing.T) {
km, err := registry.GetKeyManager(typeURL)
if err != nil {
t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err)
}
serializedKeyFormat, err := proto.Marshal(tc.keyFormat)
if err != nil {
t.Errorf("serializing key format: %v", err)
}
k, err := km.NewKeyData(serializedKeyFormat)
if err != nil {
t.Errorf("km.NewKeyData() err = %v, want nil", err)
}
if k.GetTypeUrl() != typeURL {
t.Errorf("k.GetTypeUrl() = %q, want %q", k.GetTypeUrl(), typeURL)
}
if k.GetKeyMaterialType() != tinkpb.KeyData_SYMMETRIC {
t.Errorf("k.GetKeyMaterialType() = %q, want %q", k.GetKeyMaterialType(), tinkpb.KeyData_SYMMETRIC)
}
})
}
}
// generateKey builds a JwtHmacKey holding keySize random bytes of key material
// with the given version and algorithm. kid becomes the key's CustomKid; the
// tests pass nil when they want no custom kid.
func generateKey(keySize, version uint32, algorithm jwtmacpb.JwtHmacAlgorithm, kid *jwtmacpb.JwtHmacKey_CustomKid) *jwtmacpb.JwtHmacKey {
	key := &jwtmacpb.JwtHmacKey{
		Version:   version,
		Algorithm: algorithm,
	}
	key.CustomKid = kid
	key.KeyValue = random.GetRandomBytes(keySize)
	return key
}
// TestGetPrimitiveWithValidKeys checks that, for each valid key, km.Primitive
// returns a *macWithKID that can compute a token and verify it back, and that
// the verified token carries the audience that was put into it.
func TestGetPrimitiveWithValidKeys(t *testing.T) {
	rawJWT, err := NewRawJWT(&RawJWTOptions{WithoutExpiration: true, Audiences: []string{"tink-aud"}})
	if err != nil {
		t.Fatalf("NewRawJWT() err = %v, want nil", err)
	}
	validator, err := NewValidator(&ValidatorOpts{AllowMissingExpiration: true, ExpectedAudience: refString("tink-aud")})
	if err != nil {
		t.Fatalf("NewValidator() err = %v, want nil", err)
	}
	// The key manager lookup does not depend on the test case, so do it once.
	km, err := registry.GetKeyManager(typeURL)
	if err != nil {
		t.Fatalf("registry.GetKeyManager(%q): %v", typeURL, err)
	}
	for _, tc := range []jwtKeyManagerTestCase{
		{
			tag: "SHA256 hash algorithm",
			key: generateKey(32, 0, jwtmacpb.JwtHmacAlgorithm_HS256, nil),
		},
		{
			tag: "SHA384 hash algorithm",
			key: generateKey(48, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil),
		},
		{
			tag: "SHA512 hash algorithm",
			key: generateKey(64, 0, jwtmacpb.JwtHmacAlgorithm_HS512, nil),
		},
		{
			tag: "with custom kid",
			key: generateKey(64, 0, jwtmacpb.JwtHmacAlgorithm_HS512, &jwtmacpb.JwtHmacKey_CustomKid{Value: "1235"}),
		},
	} {
		t.Run(tc.tag, func(t *testing.T) {
			// Each step feeds the next, so any failure is fatal: continuing
			// would dereference a nil primitive or verified token.
			serializedKey, err := proto.Marshal(tc.key)
			if err != nil {
				t.Fatalf("serializing key: %v", err)
			}
			p, err := km.Primitive(serializedKey)
			if err != nil {
				t.Fatalf("km.Primitive() err = %v, want nil", err)
			}
			primitive, ok := p.(*macWithKID)
			if !ok {
				t.Fatalf("km.Primitive() returned %T, want *macWithKID", p)
			}
			compact, err := primitive.ComputeMACAndEncodeWithKID(rawJWT, nil)
			if err != nil {
				t.Fatalf("ComputeMACAndEncodeWithKID() err = %v, want nil", err)
			}
			verifiedJWT, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, nil)
			if err != nil {
				t.Fatalf("VerifyMACAndDecodeWithKID() err = %v, want nil", err)
			}
			audiences, err := verifiedJWT.Audiences()
			if err != nil {
				t.Fatalf("verifiedJWT.Audiences() err = %v, want nil", err)
			}
			if want := []string{"tink-aud"}; !cmp.Equal(audiences, want) {
				t.Errorf("verifiedJWT.Audiences() = %q, want %q", audiences, want)
			}
		})
	}
}
// TestGetPrimitiveWithInvalidKeys checks that km.Primitive rejects keys that
// are one byte shorter than the key size used for each algorithm in the
// valid-key tests.
func TestGetPrimitiveWithInvalidKeys(t *testing.T) {
	for _, tc := range []jwtKeyManagerTestCase{
		{
			tag: "HS256",
			key: generateKey(31, 0, jwtmacpb.JwtHmacAlgorithm_HS256, nil),
		},
		{
			tag: "HS384",
			key: generateKey(47, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil),
		},
		{
			tag: "HS512",
			key: generateKey(63, 0, jwtmacpb.JwtHmacAlgorithm_HS512, nil),
		},
	} {
		t.Run(tc.tag, func(t *testing.T) {
			km, err := registry.GetKeyManager(typeURL)
			if err != nil {
				t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", typeURL, err)
			}
			serializedKey, err := proto.Marshal(tc.key)
			if err != nil {
				t.Fatalf("proto.Marshal(tc.key) err = %v, want nil", err)
			}
			if _, err := km.Primitive(serializedKey); err == nil {
				t.Error("km.Primitive(serializedKey) err = nil, want error")
			}
		})
	}
}
// TestSpecyfingCustomKIDAndTINKKIDFails checks that a primitive built from a
// key that already has a custom kid rejects an additional caller-supplied kid,
// both when computing and when verifying. It also checks that the same
// primitive still verifies a token when no extra kid is supplied.
// (The misspelled name is kept so existing `go test -run` filters still match.)
func TestSpecyfingCustomKIDAndTINKKIDFails(t *testing.T) {
	// key and compact are examples from: https://datatracker.ietf.org/doc/html/rfc7515#appendix-A.1.1
	// The compact token's header has typ "JWT", and its payload has iss "joe"
	// and exp 1300819380. The validator below is configured to match those values.
	compact := "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
	rawKey, err := base64.RawURLEncoding.DecodeString("AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow")
	if err != nil {
		t.Fatalf("failed decoding test key: %v", err)
	}
	key := &jwtmacpb.JwtHmacKey{
		KeyValue:  rawKey,
		Algorithm: jwtmacpb.JwtHmacAlgorithm_HS256,
		CustomKid: &jwtmacpb.JwtHmacKey_CustomKid{Value: "1235"},
		Version:   0,
	}
	// Every setup step below feeds a later one, so failures are fatal.
	km, err := registry.GetKeyManager(typeURL)
	if err != nil {
		t.Fatalf("registry.GetKeyManager(%q): %v", typeURL, err)
	}
	serializedKey, err := proto.Marshal(key)
	if err != nil {
		t.Fatalf("serializing key: %v", err)
	}
	p, err := km.Primitive(serializedKey)
	if err != nil {
		t.Fatalf("km.Primitive() err = %v, want nil", err)
	}
	primitive, ok := p.(*macWithKID)
	if !ok {
		t.Fatalf("km.Primitive() returned %T, want *macWithKID", p)
	}
	rawJWT, err := NewRawJWT(&RawJWTOptions{WithoutExpiration: true})
	if err != nil {
		t.Fatalf("creating new RawJWT: %v", err)
	}
	opts := &ValidatorOpts{
		ExpectedTypeHeader: refString("JWT"),
		ExpectedIssuer:     refString("joe"),
		// Pin the clock well before the token's exp (1300819380).
		FixedNow: time.Unix(12345, 0),
	}
	validator, err := NewValidator(opts)
	if err != nil {
		t.Fatalf("creating new JWTValidator: %v", err)
	}
	// The key already carries custom kid "1235", so a second kid must be refused.
	if _, err := primitive.ComputeMACAndEncodeWithKID(rawJWT, refString("4566")); err == nil {
		t.Errorf("primitive.ComputeMACAndEncodeWithKID() err = nil, want error")
	}
	if _, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, refString("4566")); err == nil {
		t.Errorf("primitive.VerifyMACAndDecodeWithKID(kid = 4566) err = nil, want error")
	}
	// Verify success without KID
	if _, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, nil); err != nil {
		t.Errorf("primitive.VerifyMACAndDecodeWithKID(kid = nil) err = %v, want nil", err)
	}
}
func TestGetPrimitiveWithInvalidKeyFails(t *testing.T) {
for _, tc := range []jwtKeyManagerTestCase{
{
tag: "empty key",
key: &jwtmacpb.JwtHmacKey{},
},
{
tag: "nil key",
key: nil,
},
{
tag: "unsupported hash algorithm",
key: generateKey(32, 0, jwtmacpb.JwtHmacAlgorithm_HS_UNKNOWN, nil),
},
{
tag: "short key length",
key: generateKey(20, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil),
},
{
tag: "unsupported version",
key: generateKey(48, 1, jwtmacpb.JwtHmacAlgorithm_HS384, nil),
},
} {
t.Run(tc.tag, func(t *testing.T) {
km, err := registry.GetKeyManager(typeURL)
if err != nil {
t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err)
}
serializedKey, err := proto.Marshal(tc.key)
if err != nil {
t.Errorf("serializing key format: %v", err)
}
if _, err := km.Primitive(serializedKey); err == nil {
t.Errorf("km.Primitive() err = nil, want error")
}
})
}
}
// TestGeneratesDifferentKeys checks that two calls to km.NewKey with the same
// HS256 format produce different key material.
func TestGeneratesDifferentKeys(t *testing.T) {
	// All setup failures are fatal. Otherwise a nil km would panic, or two nil
	// keys would compare equal and produce a misleading failure.
	km, err := registry.GetKeyManager(typeURL)
	if err != nil {
		t.Fatalf("registry.GetKeyManager(%q): %v", typeURL, err)
	}
	serializedKeyFormat, err := proto.Marshal(generateKeyFormat(32, jwtmacpb.JwtHmacAlgorithm_HS256))
	if err != nil {
		t.Fatalf("serializing key format: %v", err)
	}
	k1, err := km.NewKey(serializedKeyFormat)
	if err != nil {
		t.Fatalf("km.NewKey() err = %v, want nil", err)
	}
	k2, err := km.NewKey(serializedKeyFormat)
	if err != nil {
		t.Fatalf("km.NewKey() err = %v, want nil", err)
	}
	key1, ok := k1.(*jwtmacpb.JwtHmacKey)
	if !ok {
		t.Fatalf("k1 is %T, want *jwtmacpb.JwtHmacKey", k1)
	}
	key2, ok := k2.(*jwtmacpb.JwtHmacKey)
	if !ok {
		t.Fatalf("k2 is %T, want *jwtmacpb.JwtHmacKey", k2)
	}
	if cmp.Equal(key1.GetKeyValue(), key2.GetKeyValue()) {
		t.Errorf("key material should differ")
	}
}