blob: 70259101ed76230049cb444f976991768fc49dea [file] [log] [blame]
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
package jwt
import (
"fmt"
"time"
"unicode/utf8"
spb "google.golang.org/protobuf/types/known/structpb"
)
// Names of the registered claims (RFC 7519, section 4.1) handled by this file.
const (
	claimIssuer     = "iss"
	claimSubject    = "sub"
	claimAudience   = "aud"
	claimExpiration = "exp"
	claimNotBefore  = "nbf"
	claimIssuedAt   = "iat"
	claimJWTID      = "jti"
)

// Inclusive bounds, in seconds since the Unix epoch, accepted for the time
// claims ("exp", "nbf", "iat"). The upper bound is 9999-12-31T23:59:59Z.
const (
	jwtTimestampMin = 0
	jwtTimestampMax = 253402300799
)
// RawJWTOptions holds the claims and headers that NewRawJWT uses to build a
// RawJWT, an unsigned JSON Web Token (JWT), https://tools.ietf.org/html/rfc7519.
//
// It contains all payload claims and a subset of the headers. It does not
// contain any headers that depend on the key, such as "alg" or "kid", because
// these headers are chosen when the token is signed and encoded, and should not
// be chosen by the user. This ensures that the key can be changed without any
// changes to the user code.
//
// Nil pointer fields are left out of the payload. Exactly one of ExpiresAt and
// WithoutExpiration must be set, and at most one of Audience and Audiences may
// be set; NewRawJWT returns an error otherwise.
type RawJWTOptions struct {
	// Audiences is the list form of the audience claim ("aud"). It must not be
	// set together with Audience. An empty, non-nil list fails validation.
	Audiences []string

	// Audience is the single-string form of the audience claim ("aud"). It must
	// not be set together with Audiences.
	Audience *string

	// Subject is the subject claim ("sub").
	Subject *string

	// Issuer is the issuer claim ("iss").
	Issuer *string

	// JWTID is the JWT ID claim ("jti").
	JWTID *string

	// IssuedAt is the issued-at claim ("iat"); stored with whole-second precision.
	IssuedAt *time.Time

	// ExpiresAt is the expiration claim ("exp"); stored with whole-second
	// precision. Required unless WithoutExpiration is true.
	ExpiresAt *time.Time

	// NotBefore is the not-before claim ("nbf"); stored with whole-second precision.
	NotBefore *time.Time

	// CustomClaims holds additional claims. Keys must not be registered claim
	// names, and values must be convertible by structpb.NewValue.
	CustomClaims map[string]interface{}

	// TypeHeader is the optional JWT type header.
	TypeHeader *string

	// WithoutExpiration must be true when, and only when, ExpiresAt is nil.
	WithoutExpiration bool
}
// RawJWT is an unsigned JSON Web Token (JWT), https://tools.ietf.org/html/rfc7519.
//
// In this file, RawJWT values are built by NewRawJWT and NewRawJWTFromJSON,
// and both validate the payload before returning it.
type RawJWT struct {
	// jsonpb holds every payload claim, registered and custom, as a protobuf Struct.
	jsonpb *spb.Struct

	// typeHeader is the optional JWT type header; nil when absent.
	typeHeader *string
}
// NewRawJWT constructs a new RawJWT token from the provided RawJWTOptions.
//
// It returns an error if opts is nil, if the options are inconsistent (for
// example, neither an expiration nor WithoutExpiration is given), or if the
// resulting payload fails validation.
func NewRawJWT(opts *RawJWTOptions) (*RawJWT, error) {
	if opts == nil {
		return nil, fmt.Errorf("jwt options can't be nil")
	}
	var (
		payload *spb.Struct
		err     error
	)
	if payload, err = createPayload(opts); err != nil {
		return nil, err
	}
	if err = validatePayload(payload); err != nil {
		return nil, err
	}
	return &RawJWT{typeHeader: opts.TypeHeader, jsonpb: payload}, nil
}
// NewRawJWTFromJSON builds a RawJWT from a JSON-encoded payload and an optional
// type header. It returns an error if the JSON cannot be parsed or if the
// payload's registered claims are malformed.
//
// Users shouldn't call this function and instead use NewRawJWT.
func NewRawJWTFromJSON(typeHeader *string, jsonPayload []byte) (*RawJWT, error) {
	payload := new(spb.Struct)
	if err := payload.UnmarshalJSON(jsonPayload); err != nil {
		return nil, err
	}
	if err := validatePayload(payload); err != nil {
		return nil, err
	}
	return &RawJWT{typeHeader: typeHeader, jsonpb: payload}, nil
}
// JSONPayload returns the JSON encoding of the RawJWT payload.
func (r *RawJWT) JSONPayload() ([]byte, error) {
	payload := r.jsonpb
	return payload.MarshalJSON()
}
// HasTypeHeader reports whether the RawJWT carries a type header.
func (r *RawJWT) HasTypeHeader() bool {
	present := r.typeHeader != nil
	return present
}
// TypeHeader returns the JWT type header, or an error if none is present.
func (r *RawJWT) TypeHeader() (string, error) {
	if th := r.typeHeader; th != nil {
		return *th, nil
	}
	return "", fmt.Errorf("no type header present")
}
// HasAudiences reports whether the JWT contains an audience claim ('aud').
func (r *RawJWT) HasAudiences() bool {
	_, found := r.field(claimAudience)
	return found
}
// Audiences returns the audiences in the 'aud' claim. A single-string claim is
// returned as a one-element slice. An error is returned if the claim is absent
// or is not a string or a non-empty list of valid UTF-8 strings.
func (r *RawJWT) Audiences() ([]string, error) {
	aud, found := r.field(claimAudience)
	if !found {
		return nil, fmt.Errorf("no audience claim found")
	}
	if err := validateAudienceClaim(aud); err != nil {
		return nil, err
	}
	switch kind := aud.GetKind().(type) {
	case *spb.Value_StringValue:
		return []string{kind.StringValue}, nil
	default:
		entries := aud.GetListValue().GetValues()
		out := make([]string, 0, len(entries))
		for _, e := range entries {
			out = append(out, e.GetStringValue())
		}
		return out, nil
	}
}
// HasSubject checks whether a JWT contains a subject claim ('sub').
func (r *RawJWT) HasSubject() bool {
	return r.hasField(claimSubject)
}
// Subject returns the subject claim ('sub'). It returns an error if the claim
// is absent, is not a string, or is not valid UTF-8.
func (r *RawJWT) Subject() (string, error) {
	sub, err := r.stringClaim(claimSubject)
	return sub, err
}
// HasIssuer reports whether the JWT contains an issuer claim ('iss').
func (r *RawJWT) HasIssuer() bool {
	_, found := r.field(claimIssuer)
	return found
}
// Issuer returns the issuer claim ('iss'). It returns an error if the claim is
// absent, is not a string, or is not valid UTF-8.
func (r *RawJWT) Issuer() (string, error) {
	iss, err := r.stringClaim(claimIssuer)
	return iss, err
}
// HasJWTID reports whether the JWT contains a JWT ID claim ('jti').
func (r *RawJWT) HasJWTID() bool {
	_, found := r.field(claimJWTID)
	return found
}
// JWTID returns the JWT ID claim ('jti'). It returns an error if the claim is
// absent, is not a string, or is not valid UTF-8.
func (r *RawJWT) JWTID() (string, error) {
	jti, err := r.stringClaim(claimJWTID)
	return jti, err
}
// HasIssuedAt reports whether the JWT contains an issued-at claim ('iat').
func (r *RawJWT) HasIssuedAt() bool {
	_, found := r.field(claimIssuedAt)
	return found
}
// IssuedAt returns the issued-at claim ('iat') with whole-second precision.
// It returns an error if the claim is absent or is not a number.
func (r *RawJWT) IssuedAt() (time.Time, error) {
	iat, err := r.timeClaim(claimIssuedAt)
	return iat, err
}
// HasExpiration reports whether the JWT contains an expiration time claim ('exp').
func (r *RawJWT) HasExpiration() bool {
	_, found := r.field(claimExpiration)
	return found
}
// ExpiresAt returns the expiration claim ('exp') with whole-second precision.
// It returns an error if the claim is absent or is not a number.
func (r *RawJWT) ExpiresAt() (time.Time, error) {
	exp, err := r.timeClaim(claimExpiration)
	return exp, err
}
// HasNotBefore reports whether the JWT contains a not-before claim ('nbf').
func (r *RawJWT) HasNotBefore() bool {
	_, found := r.field(claimNotBefore)
	return found
}
// NotBefore returns the not-before claim ('nbf') with whole-second precision.
// It returns an error if the claim is absent or is not a number.
func (r *RawJWT) NotBefore() (time.Time, error) {
	nbf, err := r.timeClaim(claimNotBefore)
	return nbf, err
}
// HasStringClaim reports whether a custom (non-registered) claim called name
// is present and holds a JSON string. Registered claim names always yield false.
func (r *RawJWT) HasStringClaim(name string) bool {
	if isRegisteredClaim(name) {
		return false
	}
	return r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_StringValue{}})
}
// StringClaim returns the custom claim called name as a string. It returns an
// error if name is a registered claim, if the claim is absent, if it is not a
// string, or if it is not valid UTF-8.
func (r *RawJWT) StringClaim(name string) (string, error) {
	if isRegisteredClaim(name) {
		// %q already quotes the name; wrapping it in '...' would double-quote it.
		return "", fmt.Errorf("claim %q is a registered claim", name)
	}
	return r.stringClaim(name)
}
// HasNumberClaim reports whether a custom (non-registered) claim called name
// is present and holds a JSON number. Registered claim names always yield false.
func (r *RawJWT) HasNumberClaim(name string) bool {
	if isRegisteredClaim(name) {
		return false
	}
	return r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_NumberValue{}})
}
// NumberClaim returns the custom claim called name as a float64. It returns an
// error if name is a registered claim, if the claim is absent, or if it is not
// a number.
func (r *RawJWT) NumberClaim(name string) (float64, error) {
	if isRegisteredClaim(name) {
		// %q already quotes the name; wrapping it in '...' would double-quote it.
		return 0, fmt.Errorf("claim %q is a registered claim", name)
	}
	return r.numberClaim(name)
}
// HasBooleanClaim checks whether a custom claim of type boolean is present.
//
// Registered claim names always yield false. This matches HasStringClaim,
// HasNumberClaim and HasArrayClaim, and BooleanClaim also rejects registered names.
func (r *RawJWT) HasBooleanClaim(name string) bool {
	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_BoolValue{}})
}
// BooleanClaim returns the custom claim called name as a bool. It returns an
// error if name is a registered claim, if the claim is absent, or if it is not
// a JSON boolean.
func (r *RawJWT) BooleanClaim(name string) (bool, error) {
	val, err := r.customClaim(name)
	if err != nil {
		return false, err
	}
	// GetKind is nil-safe, unlike reading val.Kind directly.
	b, ok := val.GetKind().(*spb.Value_BoolValue)
	if !ok {
		return false, fmt.Errorf("claim %q is not a boolean", name)
	}
	return b.BoolValue, nil
}
// HasNullClaim checks whether a custom claim of type null is present.
//
// Registered claim names always yield false, consistent with the other
// Has*Claim methods for custom claims.
func (r *RawJWT) HasNullClaim(name string) bool {
	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_NullValue{}})
}
// HasArrayClaim reports whether a custom (non-registered) claim called name is
// present and holds a JSON array. Registered claim names (notably a list-valued
// 'aud') always yield false.
func (r *RawJWT) HasArrayClaim(name string) bool {
	if isRegisteredClaim(name) {
		return false
	}
	return r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_ListValue{}})
}
// ArrayClaim returns the custom claim called name as a slice of its JSON array
// elements, as converted by structpb's ListValue.AsSlice. It returns an error
// if name is a registered claim, if the claim is absent, or if it is not a
// JSON array.
func (r *RawJWT) ArrayClaim(name string) ([]interface{}, error) {
	val, err := r.customClaim(name)
	if err != nil {
		return nil, err
	}
	list := val.GetListValue()
	if list == nil {
		return nil, fmt.Errorf("claim %q is not a list", name)
	}
	return list.AsSlice(), nil
}
// HasObjectClaim checks whether a custom claim of type JSON object is present.
//
// Registered claim names always yield false, consistent with the other
// Has*Claim methods; ObjectClaim also rejects registered names.
func (r *RawJWT) HasObjectClaim(name string) bool {
	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_StructValue{}})
}
// ObjectClaim returns the custom claim called name as a map, as converted by
// structpb's Struct.AsMap. It returns an error if name is a registered claim,
// if the claim is absent, or if it is not a JSON object.
func (r *RawJWT) ObjectClaim(name string) (map[string]interface{}, error) {
	val, err := r.customClaim(name)
	if err != nil {
		return nil, err
	}
	obj := val.GetStructValue()
	if obj == nil {
		return nil, fmt.Errorf("claim %q is not a JSON object", name)
	}
	return obj.AsMap(), nil
}
// CustomClaimNames returns the names of all non-registered claims in the
// RawJWT. The order follows Go map iteration and is therefore unspecified.
// The result is never nil.
func (r *RawJWT) CustomClaimNames() []string {
	fields := r.jsonpb.GetFields()
	names := make([]string, 0, len(fields))
	for name := range fields {
		if isRegisteredClaim(name) {
			continue
		}
		names = append(names, name)
	}
	return names
}
// timeClaim reads claim name as seconds since the Unix epoch and converts it
// with time.Unix; any fractional part is discarded.
func (r *RawJWT) timeClaim(name string) (time.Time, error) {
	secs, err := r.numberClaim(name)
	if err != nil {
		return time.Time{}, err
	}
	return time.Unix(int64(secs), 0), nil
}
// numberClaim returns claim name as a float64. It returns an error if the
// claim is absent or is not a JSON number.
func (r *RawJWT) numberClaim(name string) (float64, error) {
	val, ok := r.field(name)
	if !ok {
		return 0, fmt.Errorf("no %q claim found", name)
	}
	// GetKind is nil-safe: a nil *spb.Value in the map is reported as "not a
	// number" rather than panicking.
	n, ok := val.GetKind().(*spb.Value_NumberValue)
	if !ok {
		return 0, fmt.Errorf("claim %q is not a number", name)
	}
	return n.NumberValue, nil
}
// stringClaim returns claim name as a string. It returns an error if the claim
// is absent, is not a JSON string, or is not valid UTF-8.
func (r *RawJWT) stringClaim(name string) (string, error) {
	val, ok := r.field(name)
	if !ok {
		return "", fmt.Errorf("no %q claim found", name)
	}
	// GetKind is nil-safe: a nil *spb.Value in the map is reported as "not a
	// string" rather than panicking.
	s, ok := val.GetKind().(*spb.Value_StringValue)
	if !ok {
		return "", fmt.Errorf("claim %q is not a string", name)
	}
	if !utf8.ValidString(s.StringValue) {
		return "", fmt.Errorf("claim %q is not a valid utf-8 encoded string", name)
	}
	return s.StringValue, nil
}
// hasClaimOfKind reports whether claim name is present and its value has the
// same kind (struct, null, bool, list, string or number) as exp. It returns
// false when exp is nil, when the claim is absent, or when exp's kind is unset.
func (r *RawJWT) hasClaimOfKind(name string, exp *spb.Value) bool {
	if exp == nil {
		return false
	}
	val, present := r.field(name)
	if !present {
		return false
	}
	got := val.GetKind()
	switch exp.GetKind().(type) {
	case *spb.Value_StructValue:
		_, match := got.(*spb.Value_StructValue)
		return match
	case *spb.Value_NullValue:
		_, match := got.(*spb.Value_NullValue)
		return match
	case *spb.Value_BoolValue:
		_, match := got.(*spb.Value_BoolValue)
		return match
	case *spb.Value_ListValue:
		_, match := got.(*spb.Value_ListValue)
		return match
	case *spb.Value_StringValue:
		_, match := got.(*spb.Value_StringValue)
		return match
	case *spb.Value_NumberValue:
		_, match := got.(*spb.Value_NumberValue)
		return match
	}
	return false
}
// customClaim returns the raw value of the non-registered claim called name.
// It returns an error if name is a registered claim or if the claim is absent.
func (r *RawJWT) customClaim(name string) (*spb.Value, error) {
	if isRegisteredClaim(name) {
		return nil, fmt.Errorf("%q is a registered claim", name)
	}
	val, ok := r.field(name)
	if !ok {
		return nil, fmt.Errorf("claim %q not found", name)
	}
	return val, nil
}
// hasField reports whether the payload has a claim called name, whatever its type.
func (r *RawJWT) hasField(name string) bool {
	_, present := r.jsonpb.GetFields()[name]
	return present
}
// field looks up claim name in the payload. The bool result reports whether
// the key exists.
func (r *RawJWT) field(name string) (*spb.Value, bool) {
	fields := r.jsonpb.GetFields()
	v, present := fields[name]
	return v, present
}
// createPayload builds the JWT payload described by opts.
//
// It returns an error if a custom claim uses a registered claim name, if
// neither or both of ExpiresAt and WithoutExpiration are set, if both Audience
// and Audiences are set, or if a custom claim value cannot be converted by
// structpb.NewValue. Nil option fields are left out of the payload.
func createPayload(opts *RawJWTOptions) (*spb.Struct, error) {
	if err := validateCustomClaims(opts.CustomClaims); err != nil {
		return nil, err
	}
	if opts.ExpiresAt == nil && !opts.WithoutExpiration {
		return nil, fmt.Errorf("jwt options must contain an expiration or must be marked WithoutExpiration")
	}
	if opts.ExpiresAt != nil && opts.WithoutExpiration {
		return nil, fmt.Errorf("jwt options can't be marked WithoutExpiration when expiration is specified")
	}
	if opts.Audience != nil && opts.Audiences != nil {
		return nil, fmt.Errorf("jwt options can either contain a single Audience or a list of Audiences but not both")
	}
	payload := &spb.Struct{
		Fields: map[string]*spb.Value{},
	}
	setStringValue(payload, claimJWTID, opts.JWTID)
	setStringValue(payload, claimIssuer, opts.Issuer)
	setStringValue(payload, claimSubject, opts.Subject)
	// At most one of the two audience setters writes, because of the check above.
	setStringValue(payload, claimAudience, opts.Audience)
	setTimeValue(payload, claimIssuedAt, opts.IssuedAt)
	setTimeValue(payload, claimNotBefore, opts.NotBefore)
	setTimeValue(payload, claimExpiration, opts.ExpiresAt)
	setAudiences(payload, claimAudience, opts.Audiences)
	for k, v := range opts.CustomClaims {
		val, err := spb.NewValue(v)
		if err != nil {
			// Name the offending claim; %w keeps the structpb error inspectable.
			return nil, fmt.Errorf("custom claim %q: %w", k, err)
		}
		setValue(payload, k, val)
	}
	return payload, nil
}
// validatePayload checks the registered claims present in payload. "aud" must
// be a string or a non-empty list of strings. "iss", "sub" and "jti" must be
// valid UTF-8 strings. "exp", "nbf" and "iat" must be numbers within the
// accepted timestamp range. Custom claims are not inspected, and an empty
// payload is valid.
func validatePayload(payload *spb.Struct) error {
	fields := payload.Fields
	if len(fields) == 0 {
		return nil
	}
	if err := validateAudienceClaim(fields[claimAudience]); err != nil {
		return err
	}
	for claim, val := range fields {
		var err error
		switch {
		case isRegisteredTimeClaim(claim):
			err = validateTimeClaim(claim, val)
		case isRegisteredStringClaim(claim):
			err = validateStringClaim(claim, val)
		}
		if err != nil {
			return err
		}
	}
	return nil
}
// validateStringClaim checks that val holds a JSON string that is valid UTF-8.
// A nil val is rejected as "not a string" instead of causing a panic.
func validateStringClaim(claim string, val *spb.Value) error {
	v, ok := val.GetKind().(*spb.Value_StringValue)
	if !ok {
		return fmt.Errorf("claim %q MUST be a string", claim)
	}
	if !utf8.ValidString(v.StringValue) {
		return fmt.Errorf("claim %q isn't a valid UTF-8 string", claim)
	}
	return nil
}
// validateTimeClaim checks that val holds a JSON number in the inclusive range
// [jwtTimestampMin, jwtTimestampMax].
//
// The range check runs on the float64 itself, not on an int64 conversion.
// Converting a NaN or out-of-range float64 to int64 has an
// implementation-dependent result in Go, which could land inside the valid
// range. Truncation would also let values such as -0.5 slip into range, so
// checking the float64 avoids both problems.
func validateTimeClaim(claim string, val *spb.Value) error {
	n, ok := val.GetKind().(*spb.Value_NumberValue)
	if !ok {
		return fmt.Errorf("claim %q MUST be a numeric value", claim)
	}
	t := n.NumberValue
	// Negated range check: every comparison with NaN is false, so NaN is rejected too.
	if !(t >= jwtTimestampMin && t <= jwtTimestampMax) {
		return fmt.Errorf("invalid timestamp: %v for claim: %q", t, claim)
	}
	return nil
}
// validateAudienceClaim checks the audience claim. A nil val (claim absent) is
// valid. Otherwise val must be a valid UTF-8 string, or a list containing at
// least one element where every element is a valid UTF-8 string.
func validateAudienceClaim(val *spb.Value) error {
	if val == nil {
		return nil
	}
	switch k := val.GetKind().(type) {
	case *spb.Value_StringValue:
		return validateStringClaim(claimAudience, val)
	case *spb.Value_ListValue:
		// GetValues is nil-safe, so a list kind with a nil ListValue is reported
		// as empty instead of dereferencing a nil pointer.
		auds := k.ListValue.GetValues()
		if len(auds) == 0 {
			return fmt.Errorf("there MUST be at least one value present in the audience claim")
		}
		for _, aud := range auds {
			v, ok := aud.GetKind().(*spb.Value_StringValue)
			if !ok {
				return fmt.Errorf("audience value is not a string")
			}
			if !utf8.ValidString(v.StringValue) {
				return fmt.Errorf("audience value is not a valid UTF-8 string")
			}
		}
		return nil
	default:
		return fmt.Errorf("audience claim MUST be a list with at least one string or a single string value")
	}
}
// validateCustomClaims returns an error if any key in cc is a registered claim
// name. A nil or empty map is valid; ranging over a nil map is a no-op in Go.
func validateCustomClaims(cc map[string]interface{}) error {
	for key := range cc {
		if isRegisteredClaim(key) {
			return fmt.Errorf("claim %q is a registered claim, it can't be declared as a custom claim", key)
		}
	}
	return nil
}
// setTimeValue stores *val in p under claim as whole seconds since the Unix
// epoch. A nil val leaves p untouched.
func setTimeValue(p *spb.Struct, claim string, val *time.Time) {
	if val != nil {
		secs := float64(val.Unix())
		setValue(p, claim, spb.NewNumberValue(secs))
	}
}
// setStringValue stores *val in p under claim as a JSON string. A nil val
// leaves p untouched.
func setStringValue(p *spb.Struct, claim string, val *string) {
	if val != nil {
		setValue(p, claim, spb.NewStringValue(*val))
	}
}
// setAudiences stores vals in p under claim as a JSON list of strings. A nil
// vals leaves p untouched. An empty, non-nil vals stores an empty list, which
// validatePayload later rejects.
func setAudiences(p *spb.Struct, claim string, vals []string) {
	if vals == nil {
		return
	}
	entries := make([]*spb.Value, 0, len(vals))
	for _, aud := range vals {
		entries = append(entries, spb.NewStringValue(aud))
	}
	setValue(p, claim, spb.NewListValue(&spb.ListValue{Values: entries}))
}
// setValue stores val in p under claim. It allocates p's field map first if
// that map is nil.
func setValue(p *spb.Struct, claim string, val *spb.Value) {
	if p.Fields == nil {
		p.Fields = map[string]*spb.Value{}
	}
	p.Fields[claim] = val
}
func isRegisteredClaim(c string) bool {
return isRegisteredStringClaim(c) || isRegisteredTimeClaim(c) || c == claimAudience
}
// isRegisteredStringClaim reports whether c is "iss", "sub" or "jti".
func isRegisteredStringClaim(c string) bool {
	switch c {
	case claimIssuer, claimSubject, claimJWTID:
		return true
	}
	return false
}
// isRegisteredTimeClaim reports whether c is "exp", "nbf" or "iat".
func isRegisteredTimeClaim(c string) bool {
	switch c {
	case claimExpiration, claimNotBefore, claimIssuedAt:
		return true
	}
	return false
}