blob: 0ab1f7c59f07fa822cbf1de520f56203d608cfa6 [file] [log] [blame]
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
package jwt
import (
"encoding/base64"
"encoding/binary"
"fmt"
"strings"
spb "google.golang.org/protobuf/types/known/structpb"
tpb "github.com/google/tink/go/proto/tink_go_proto"
)
// keyID returns the base64 form of a key ID for keys with a TINK output prefix.
//
// When outPrefixType is OutputPrefixType_TINK, the result points to the
// unpadded URL-safe base64 encoding of keyID serialized as 4 big-endian
// bytes. For every other output prefix type the result is nil.
func keyID(keyID uint32, outPrefixType tpb.OutputPrefixType) *string {
	if outPrefixType == tpb.OutputPrefixType_TINK {
		var raw [4]byte
		binary.BigEndian.PutUint32(raw[:], keyID)
		encoded := base64Encode(raw[:])
		return &encoded
	}
	return nil
}
// createUnsigned builds the unsigned portion of a token: the encoded header
// and the base64 encoded JSON payload of rawJWT, joined by a dot.
//
// algo is passed to createHeader as the algorithm. At most one of tinkKID and
// customKID may be non-nil; whichever is set is passed to createHeader as the
// key ID. If rawJWT reports a type header, its value is passed along as well.
//
// An error is returned when rawJWT is nil, when both key IDs are set, or when
// reading the type header, creating the header, or producing the JSON payload
// fails.
func createUnsigned(rawJWT *RawJWT, algo string, tinkKID *string, customKID *string) (string, error) {
	if rawJWT == nil {
		return "", fmt.Errorf("rawJWT is nil")
	}

	// The type header is optional; typ stays nil when the token has none.
	var typ *string
	if rawJWT.HasTypeHeader() {
		value, err := rawJWT.TypeHeader()
		if err != nil {
			return "", err
		}
		typ = &value
	}

	if tinkKID != nil && customKID != nil {
		return "", fmt.Errorf("TINK Keys are not allowed to have a kid value set")
	}
	// Exactly one (or neither) of the two IDs is set at this point.
	kid := customKID
	if tinkKID != nil {
		kid = tinkKID
	}

	header, err := createHeader(algo, typ, kid)
	if err != nil {
		return "", err
	}
	body, err := rawJWT.JSONPayload()
	if err != nil {
		return "", err
	}
	return dotConcat(header, base64Encode(body)), nil
}
// combineUnsignedAndSignature combines the token with the raw signature to provide a signed token.
func combineUnsignedAndSignature(unsigned string, signature []byte) string {
return dotConcat(unsigned, base64Encode(signature))
}
// splitSignedCompact splits a signed compact token into its decoded signature
// (the witness) and the unsigned token that precedes it.
//
// The text after the last '.' is base64-decoded and returned as the witness;
// everything before that '.' is returned unchanged as the unsigned token.
//
// An error is returned when compact contains no '.', when the signature part
// cannot be decoded or decodes to zero bytes, when the unsigned part is empty,
// or when the unsigned part does not contain exactly one '.'.
func splitSignedCompact(compact string) ([]byte, string, error) {
	i := strings.LastIndex(compact, ".")
	if i < 0 {
		return nil, "", fmt.Errorf("invalid token")
	}
	unsigned, encodedSignature := compact[:i], compact[i+1:]

	witness, err := base64Decode(encodedSignature)
	if err != nil {
		// %w keeps the decode error in the chain so callers can inspect it
		// with errors.Is / errors.As; the message text is unchanged.
		return nil, "", fmt.Errorf("%q: %w", encodedSignature, err)
	}
	if len(witness) == 0 {
		return nil, "", fmt.Errorf("empty signature")
	}
	if len(unsigned) == 0 {
		return nil, "", fmt.Errorf("empty content")
	}
	// The unsigned part must be exactly "<header>.<payload>".
	if strings.Count(unsigned, ".") != 1 {
		return nil, "", fmt.Errorf("only tokens in JWS compact serialization formats are supported")
	}
	return witness, unsigned, nil
}
// decodeUnsignedTokenAndValidateHeader checks the header of an unsigned token
// and turns its payload into a RawJWT.
//
// unsigned must consist of exactly two dot-separated base64 segments: header
// and payload. The header is decoded, parsed as JSON and validated against
// algorithm, tinkKID and customKID before the payload is decoded. The type
// header, if any, is handed to NewRawJWTFromJSON together with the payload.
//
// Callers should verify the signature before calling this function.
func decodeUnsignedTokenAndValidateHeader(unsigned, algorithm string, tinkKID, customKID *string) (*RawJWT, error) {
	segments := strings.Split(unsigned, ".")
	if len(segments) != 2 {
		return nil, fmt.Errorf("only tokens in JWS compact serialization formats are supported")
	}
	encodedHeader, encodedPayload := segments[0], segments[1]

	rawHeader, err := base64Decode(encodedHeader)
	if err != nil {
		return nil, err
	}
	header, err := jsonToStruct(rawHeader)
	if err != nil {
		return nil, err
	}
	err = validateHeader(header, algorithm, tinkKID, customKID)
	if err != nil {
		return nil, err
	}
	typ, err := extractTypeHeader(header)
	if err != nil {
		return nil, err
	}

	// The payload is only decoded once the header has been accepted.
	rawPayload, err := base64Decode(encodedPayload)
	if err != nil {
		return nil, err
	}
	return NewRawJWTFromJSON(typ, rawPayload)
}
// base64Encode returns the URL-safe base64 encoding of content without any
// '=' padding characters.
func base64Encode(content []byte) string {
	return base64.RawURLEncoding.EncodeToString(content)
}
// base64Decode decodes an unpadded URL-safe base64 string into bytes.
//
// Every character of content must be accepted by isValidURLsafeBase64Char;
// anything else (for example '=' padding, whitespace or line breaks) is
// rejected with an "invalid encoding" error before the decoder is invoked.
func base64Decode(content string) ([]byte, error) {
	isInvalid := func(r rune) bool { return !isValidURLsafeBase64Char(r) }
	if strings.IndexFunc(content, isInvalid) >= 0 {
		return nil, fmt.Errorf("invalid encoding")
	}
	return base64.RawURLEncoding.DecodeString(content)
}
// isValidURLsafeBase64Char reports whether c is an ASCII letter, an ASCII
// digit, '-' or '_'.
func isValidURLsafeBase64Char(c rune) bool {
	switch {
	case 'a' <= c && c <= 'z':
		return true
	case 'A' <= c && c <= 'Z':
		return true
	case '0' <= c && c <= '9':
		return true
	case c == '-', c == '_':
		return true
	default:
		return false
	}
}
// dotConcat returns a and b joined by a single '.'.
//
// Plain concatenation is used instead of fmt.Sprintf: the result is the same,
// without format-string parsing or boxing the arguments into interfaces.
func dotConcat(a, b string) string {
	return a + "." + b
}
// jsonToStruct parses jsonPayload into a new structpb Struct using the
// Struct's UnmarshalJSON method. It returns the unmarshalling error, and no
// Struct, when parsing fails.
func jsonToStruct(jsonPayload []byte) (*spb.Struct, error) {
	parsed := new(spb.Struct)
	err := parsed.UnmarshalJSON(jsonPayload)
	if err != nil {
		return nil, err
	}
	return parsed, nil
}
// extractTypeHeader returns the value of the "typ" field of header.
//
// It returns (nil, nil) when header has no "typ" field, and an error when
// header has no fields at all or when "typ" is present but is not a string.
func extractTypeHeader(header *spb.Struct) (*string, error) {
	fields := header.GetFields()
	if fields == nil {
		return nil, fmt.Errorf("header contains no fields")
	}
	typ, present := fields["typ"]
	if !present {
		// A missing type header is allowed.
		return nil, nil
	}
	if s, isString := typ.Kind.(*spb.Value_StringValue); isString {
		return &s.StringValue, nil
	}
	return nil, fmt.Errorf("type header isn't a string")
}
// createHeader builds a header with an "alg" field set to algorithm, plus
// "typ" and "kid" fields when typeHeader and kid are non-nil, serializes it
// to JSON and returns the base64 encoding of that JSON.
func createHeader(algorithm string, typeHeader, kid *string) (string, error) {
	fields := map[string]*spb.Value{
		"alg": spb.NewStringValue(algorithm),
	}
	// Optional fields are only emitted when the caller supplied a value.
	if typeHeader != nil {
		fields["typ"] = spb.NewStringValue(*typeHeader)
	}
	if kid != nil {
		fields["kid"] = spb.NewStringValue(*kid)
	}

	header := &spb.Struct{Fields: fields}
	serialized, err := header.MarshalJSON()
	if err != nil {
		return "", err
	}
	return base64Encode(serialized), nil
}
// validateHeader checks a parsed token header against the expected algorithm
// and key ID.
//
// The header is rejected when it has no fields, when "alg" is missing, is not
// a string or differs from algorithm, or when a "crit" field is present.
//
// Key ID rules:
//   - tinkKID and customKID must not both be set.
//   - If tinkKID is set, the header must carry a "kid" equal to it.
//   - If only customKID is set, a "kid" in the header must equal it, but the
//     header may omit "kid".
//   - If neither is set, any "kid" in the header is ignored.
func validateHeader(header *spb.Struct, algorithm string, tinkKID, customKID *string) error {
	fields := header.GetFields()
	if fields == nil {
		return fmt.Errorf("header contains no fields")
	}

	headerAlg, err := headerStringField(fields, "alg")
	if err != nil {
		return err
	}
	if headerAlg != algorithm {
		return fmt.Errorf("invalid alg")
	}
	if _, hasCrit := fields["crit"]; hasCrit {
		return fmt.Errorf("all tokens with crit headers are rejected")
	}

	_, hasKID := fields["kid"]
	switch {
	case tinkKID != nil && customKID != nil:
		return fmt.Errorf("custom_kid can only be set for RAW keys")
	case tinkKID != nil && !hasKID:
		return fmt.Errorf("missing kid in header")
	case tinkKID != nil:
		return validateKIDInHeader(fields, tinkKID)
	case customKID != nil && hasKID:
		return validateKIDInHeader(fields, customKID)
	}
	return nil
}
// validateKIDInHeader checks that fields contains a string "kid" entry equal
// to *kid. kid must not be nil.
func validateKIDInHeader(fields map[string]*spb.Value, kid *string) error {
	got, err := headerStringField(fields, "kid")
	switch {
	case err != nil:
		return err
	case got != *kid:
		return fmt.Errorf("invalid kid header")
	}
	return nil
}
// headerStringField returns the string stored under name in fields. It
// returns an error when the entry is absent or holds a non-string value.
func headerStringField(fields map[string]*spb.Value, name string) (string, error) {
	field, found := fields[name]
	if !found {
		return "", fmt.Errorf("header is missing %q", name)
	}
	if s, isString := field.Kind.(*spb.Value_StringValue); isString {
		return s.StringValue, nil
	}
	return "", fmt.Errorf("%q header isn't a string", name)
}