blob: 921d01e73651ae8929164cca370affb6f343dc26 [file] [log] [blame]
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The raw JSON Web Token (JWT)."""
import copy
import datetime
import json
from typing import cast, Mapping, Set, List, Dict, Optional, Union, Any
from tink import core
from tink.jwt import _json_util
from tink.jwt import _jwt_error
_REGISTERED_NAMES = frozenset({'iss', 'sub', 'jti', 'aud', 'exp', 'nbf', 'iat'})
_MAX_TIMESTAMP_VALUE = 253402300799 # 31 Dec 9999, 23:59:59 GMT
Claim = Union[None, bool, int, float, str, List[Any], Dict[str, Any]]
def _from_datetime(t: datetime.datetime) -> int:
if not t.tzinfo:
raise _jwt_error.JwtInvalidError('datetime must have tzinfo')
return int(t.timestamp())
def _to_datetime(timestamp: float) -> datetime.datetime:
return datetime.datetime.fromtimestamp(timestamp, datetime.timezone.utc)
def _validate_custom_claim_name(name: str) -> None:
  """Rejects `name` if it is one of the registered claim names.

  Raises:
    JwtInvalidError: if `name` is in _REGISTERED_NAMES.
  """
  if name not in _REGISTERED_NAMES:
    return
  raise _jwt_error.JwtInvalidError(
      'registered name %s cannot be custom claim name' % name)
class RawJwt:
  """An unencoded and unsigned JSON Web Token (JWT).

  It contains all payload claims and a subset of the headers. It does not
  contain any headers that depend on the key, such as "alg" or "kid", because
  these headers are chosen when the token is signed and encoded, and should not
  be chosen by the user. This ensures that the key can be changed without any
  changes to the user code.

  Do not call the constructor directly (__new__ raises TinkError). Use
  RawJwt.create or the module-level new_raw_jwt; parsed payloads are turned
  into instances by _from_json.
  """

  def __new__(cls):
    # Blocks direct construction. The factory classmethods below allocate with
    # object.__new__ and then call __init__ explicitly.
    raise core.TinkError('RawJwt cannot be instantiated directly.')

  def __init__(self, type_header: Optional[str],
               payload: Dict[str, Any]) -> None:
    """Stores the header and payload and validates the registered claims.

    Args:
      type_header: the type header, or None if it is not set.
      payload: the payload claims. The dict is stored without copying, because
        only create() and _from_json() call this method and neither shares the
        dict it passes with its caller.

    Raises:
      JwtInvalidError: if payload is not a dict, or a registered claim has an
        invalid type or value.
    """
    if not isinstance(payload, dict):
      raise _jwt_error.JwtInvalidError('payload must be a dict')
    self._type_header = type_header
    self._payload = payload
    self._validate_string_claim('iss')
    self._validate_string_claim('sub')
    self._validate_string_claim('jti')
    self._validate_timestamp_claim('exp')
    self._validate_timestamp_claim('nbf')
    self._validate_timestamp_claim('iat')
    self._validate_audience_claim()

  def _validate_string_claim(self, name: str):
    """Checks that claim `name`, if present, is a string."""
    if name in self._payload:
      if not isinstance(self._payload[name], str):
        raise _jwt_error.JwtInvalidError('claim %s must be a String' % name)

  def _validate_timestamp_claim(self, name: str):
    """Checks that claim `name`, if present, is a number in the valid range.

    The valid range is 0 to _MAX_TIMESTAMP_VALUE inclusive.
    """
    if name not in self._payload:
      return
    timestamp = self._payload[name]
    # bool is a subclass of int; without the explicit check the JSON literals
    # true/false would be accepted as the timestamps 1/0.
    if isinstance(timestamp, bool) or not isinstance(timestamp, (int, float)):
      raise _jwt_error.JwtInvalidError('claim %s must be a Number' % name)
    # A single range check also rejects NaN (every comparison with NaN is
    # False) and the infinities.
    if not 0 <= timestamp <= _MAX_TIMESTAMP_VALUE:
      raise _jwt_error.JwtInvalidError(
          'timestamp of claim %s is out of range' % name)

  def _validate_audience_claim(self):
    """The 'aud' claim must either be a string or a non-empty list of strings."""
    if 'aud' not in self._payload:
      return
    audiences = self._payload['aud']
    if isinstance(audiences, str):
      return
    if not isinstance(audiences, list):
      raise _jwt_error.JwtInvalidError(
          'audiences must be a string or a list of strings')
    if not audiences:
      raise _jwt_error.JwtInvalidError('audiences cannot be an empty list')
    if not all(isinstance(value, str) for value in audiences):
      raise _jwt_error.JwtInvalidError('audiences must only contain strings')

  # TODO(juerg): Consider adding a raw_ prefix to all access methods
  def has_type_header(self) -> bool:
    """Returns True if a type header is set."""
    return self._type_header is not None

  def type_header(self) -> str:
    """Returns the type header; raises KeyError if it is not set."""
    if not self.has_type_header():
      raise KeyError('type header is not set')
    return self._type_header

  def has_issuer(self) -> bool:
    """Returns True if the 'iss' claim is present."""
    return 'iss' in self._payload

  def issuer(self) -> str:
    """Returns the 'iss' claim; raises KeyError if it is absent."""
    return cast(str, self._payload['iss'])

  def has_subject(self) -> bool:
    """Returns True if the 'sub' claim is present."""
    return 'sub' in self._payload

  def subject(self) -> str:
    """Returns the 'sub' claim; raises KeyError if it is absent."""
    return cast(str, self._payload['sub'])

  def has_audiences(self) -> bool:
    """Returns True if the 'aud' claim is present."""
    return 'aud' in self._payload

  def audiences(self) -> List[str]:
    """Returns the 'aud' claim as a new list.

    A single string audience is wrapped in a one-element list. Raises KeyError
    if the claim is absent.
    """
    aud = self._payload['aud']
    if isinstance(aud, str):
      return [aud]
    return list(aud)

  def has_jwt_id(self) -> bool:
    """Returns True if the 'jti' claim is present."""
    return 'jti' in self._payload

  def jwt_id(self) -> str:
    """Returns the 'jti' claim; raises KeyError if it is absent."""
    return cast(str, self._payload['jti'])

  def has_expiration(self) -> bool:
    """Returns True if the 'exp' claim is present."""
    return 'exp' in self._payload

  def expiration(self) -> datetime.datetime:
    """Returns the 'exp' claim as a UTC datetime; KeyError if absent."""
    return _to_datetime(self._payload['exp'])

  def has_not_before(self) -> bool:
    """Returns True if the 'nbf' claim is present."""
    return 'nbf' in self._payload

  def not_before(self) -> datetime.datetime:
    """Returns the 'nbf' claim as a UTC datetime; KeyError if absent."""
    return _to_datetime(self._payload['nbf'])

  def has_issued_at(self) -> bool:
    """Returns True if the 'iat' claim is present."""
    return 'iat' in self._payload

  def issued_at(self) -> datetime.datetime:
    """Returns the 'iat' claim as a UTC datetime; KeyError if absent."""
    return _to_datetime(self._payload['iat'])

  def custom_claim_names(self) -> Set[str]:
    """Returns the names of all claims that are not registered claims."""
    return {n for n in self._payload.keys() if n not in _REGISTERED_NAMES}

  def custom_claim(self, name: str) -> Claim:
    """Returns the custom claim `name`.

    Lists and dicts are deep-copied so callers cannot mutate the token.

    Raises:
      JwtInvalidError: if `name` is a registered claim name.
      KeyError: if the claim is absent.
    """
    _validate_custom_claim_name(name)
    value = self._payload[name]
    if isinstance(value, (list, dict)):
      return copy.deepcopy(value)
    return value

  def json_payload(self) -> str:
    """Returns the payload encoded as JSON string."""
    return _json_util.json_dumps(self._payload)

  @classmethod
  def create(cls,
             *,
             type_header: Optional[str] = None,
             issuer: Optional[str] = None,
             subject: Optional[str] = None,
             audience: Optional[str] = None,
             audiences: Optional[List[str]] = None,
             jwt_id: Optional[str] = None,
             expiration: Optional[datetime.datetime] = None,
             without_expiration: Optional[bool] = None,
             not_before: Optional[datetime.datetime] = None,
             issued_at: Optional[datetime.datetime] = None,
             custom_claims: Optional[Mapping[str, Claim]] = None) -> 'RawJwt':
    """Create a new RawJwt instance.

    Datetimes must be timezone-aware and are truncated to whole seconds.

    Args:
      type_header: optional type header.
      issuer: optional 'iss' claim. An empty string is treated as unset.
      subject: optional 'sub' claim. An empty string is treated as unset.
      audience: optional single 'aud' value; cannot be combined with audiences.
      audiences: optional list of 'aud' values (the list is copied); cannot be
        combined with audience.
      jwt_id: optional 'jti' claim.
      expiration: the 'exp' time. Exactly one of expiration and
        without_expiration must be set.
      without_expiration: set to True to create a token without 'exp'.
      not_before: optional 'nbf' time.
      issued_at: optional 'iat' time.
      custom_claims: optional non-registered claims. Values must be None, bool,
        int, float, str, or a JSON-serializable list or dict. Lists and dicts
        are deep-copied.

    Returns:
      The new RawJwt.

    Raises:
      ValueError: if expiration and without_expiration are both set or both
        unset.
      JwtInvalidError: if audience and audiences are both set, a datetime is
        naive, a custom claim name is not a string or is a registered name, a
        custom claim value has an unsupported type or cannot be serialized to
        JSON, or the resulting payload fails validation.
    """
    if not expiration and not without_expiration:
      raise ValueError('either expiration or without_expiration must be set')
    if expiration and without_expiration:
      raise ValueError(
          'expiration and without_expiration cannot be set at the same time')
    if audience is not None and audiences is not None:
      raise _jwt_error.JwtInvalidError(
          'audience and audiences cannot be set at the same time')
    payload = {}
    # Unlike jwt_id, an empty issuer/subject is treated as "not set".
    if issuer:
      payload['iss'] = issuer
    if subject:
      payload['sub'] = subject
    if jwt_id is not None:
      payload['jti'] = jwt_id
    if audience is not None:
      payload['aud'] = audience
    if audiences is not None:
      payload['aud'] = copy.copy(audiences)
    if expiration:
      payload['exp'] = _from_datetime(expiration)
    if not_before:
      payload['nbf'] = _from_datetime(not_before)
    if issued_at:
      payload['iat'] = _from_datetime(issued_at)
    if custom_claims:
      for name, value in custom_claims.items():
        if not isinstance(name, str):
          raise _jwt_error.JwtInvalidError('claim name must be Text')
        _validate_custom_claim_name(name)
        if value is None or isinstance(value, (bool, int, float, str)):
          payload[name] = value
        elif isinstance(value, (list, dict)):
          # The JSON round trip deep-copies the value, so later changes by the
          # caller do not leak into the token, and proves it is serializable.
          try:
            payload[name] = json.loads(json.dumps(value))
          except (TypeError, ValueError) as e:
            raise _jwt_error.JwtInvalidError(
                'claim %s is not JSON serializable' % name) from e
        else:
          raise _jwt_error.JwtInvalidError('claim %s has unknown type' % name)
    raw_jwt = object.__new__(cls)
    raw_jwt.__init__(type_header, payload)
    return raw_jwt

  @classmethod
  def _from_json(cls, type_header: Optional[str], payload: str) -> 'RawJwt':
    """Creates a RawJwt from payload encoded as JSON string."""
    raw_jwt = object.__new__(cls)
    raw_jwt.__init__(type_header, _json_util.json_loads(payload))
    return raw_jwt
def new_raw_jwt(*,
                type_header: Optional[str] = None,
                issuer: Optional[str] = None,
                subject: Optional[str] = None,
                audience: Optional[str] = None,
                audiences: Optional[List[str]] = None,
                jwt_id: Optional[str] = None,
                expiration: Optional[datetime.datetime] = None,
                without_expiration: bool = False,
                not_before: Optional[datetime.datetime] = None,
                issued_at: Optional[datetime.datetime] = None,
                custom_claims: Optional[Mapping[str, Claim]] = None) -> RawJwt:
  """Builds a RawJwt from the given header and claims.

  This is a keyword-only convenience wrapper. It forwards every argument
  unchanged to RawJwt.create, which performs all validation. In particular,
  exactly one of `expiration` and `without_expiration` must be set.
  """
  factory = RawJwt.create
  return factory(
      custom_claims=custom_claims,
      issued_at=issued_at,
      not_before=not_before,
      without_expiration=without_expiration,
      expiration=expiration,
      jwt_id=jwt_id,
      audiences=audiences,
      audience=audience,
      subject=subject,
      issuer=issuer,
      type_header=type_header,
  )
def raw_jwt_from_json(type_header: Optional[str], payload: str) -> RawJwt:
  """Parses a JSON-encoded payload into a RawJwt.

  Internal function used to verify JWT token; it delegates to
  RawJwt._from_json.
  """
  from_json = RawJwt._from_json  # pylint: disable=protected-access
  return from_json(type_header, payload)