blob: 373a3dd2aa5143404a25faaa0996163c27450aca [file] [log] [blame]
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions that help to serialize and deserialize from/to the JWT format."""
import base64
import binascii
import struct
from typing import Any, Optional, Tuple
from tink.proto import tink_pb2
from tink.jwt import _json_util
from tink.jwt import _jwt_error
from tink.jwt import _raw_jwt
_VALID_ALGORITHMS = frozenset({
'HS256', 'HS384', 'HS512', 'ES256', 'ES384', 'ES512', 'RS256', 'RS384',
'RS384', 'RS512', 'PS256', 'PS384', 'PS512'
})
def base64_encode(data: bytes) -> bytes:
"""Does a URL-safe base64 encoding without padding."""
return base64.urlsafe_b64encode(data).rstrip(b'=')
def _is_valid_urlsafe_base64_char(c: int) -> bool:
if c >= ord('a') and c <= ord('z'):
return True
if c >= ord('A') and c <= ord('Z'):
return True
if c >= ord('0') and c <= ord('9'):
return True
if c == ord('-') or c == ord('_'):
return True
return False
def base64_decode(encoded_data: bytes) -> bytes:
"""Does a URL-safe base64 decoding without padding."""
# base64.urlsafe_b64decode ignores all non-base64 chars. We don't want that.
for c in encoded_data:
if not _is_valid_urlsafe_base64_char(c):
raise _jwt_error.JwtInvalidError('invalid base64 encoding')
# base64.urlsafe_b64decode requires padding, but does not mind too much
# padding. So we simply add the maximum amount of padding needed.
padded_encoded_data = encoded_data + b'==='
try:
return base64.urlsafe_b64decode(padded_encoded_data)
except binascii.Error:
# Throws when the length of encoded_data is (4*i + 1) for some i
raise _jwt_error.JwtInvalidError('invalid base64 encoding')
def _validate_algorithm(algorithm: str) -> None:
if algorithm not in _VALID_ALGORITHMS:
raise _jwt_error.JwtInvalidError('Invalid algorithm %s' % algorithm)
def encode_header(json_header: str) -> bytes:
try:
return base64_encode(json_header.encode('utf8'))
except UnicodeEncodeError:
raise _jwt_error.JwtInvalidError('invalid token')
def decode_header(encoded_header: bytes) -> str:
try:
return base64_decode(encoded_header).decode('utf8')
except UnicodeDecodeError:
raise _jwt_error.JwtInvalidError('invalid token')
def encode_payload(json_payload: str) -> bytes:
"""Encodes the payload into compact form."""
try:
return base64_encode(json_payload.encode('utf8'))
except UnicodeEncodeError:
raise _jwt_error.JwtInvalidError('invalid token')
def decode_payload(encoded_payload: bytes) -> str:
"""Decodes the payload from compact form."""
try:
return base64_decode(encoded_payload).decode('utf8')
except UnicodeDecodeError:
raise _jwt_error.JwtInvalidError('invalid token')
def encode_signature(signature: bytes) -> bytes:
"""Encodes the signature."""
return base64_encode(signature)
def decode_signature(encoded_signature: bytes) -> bytes:
"""Decodes the signature."""
return base64_decode(encoded_signature)
def create_header(algorithm: str, type_header: Optional[str],
kid: Optional[str]) -> bytes:
_validate_algorithm(algorithm)
header = {}
if kid:
header['kid'] = kid
header['alg'] = algorithm
if type_header:
header['typ'] = type_header
return encode_header(_json_util.json_dumps(header))
def get_kid(key_id: int, prefix: tink_pb2.OutputPrefixType) -> Optional[str]:
"""Returns the encoded key_id, or None."""
if prefix == tink_pb2.RAW:
return None
if prefix == tink_pb2.TINK:
if key_id < 0 or key_id > 2**32:
raise _jwt_error.JwtInvalidError('invalid key_id')
return base64_encode(struct.pack('>L', key_id)).decode('utf8')
raise _jwt_error.JwtInvalidError('unexpected output prefix type')
def split_signed_compact(signed_compact: str) -> Tuple[bytes, str, str, bytes]:
"""Splits a signed compact into its parts.
Args:
signed_compact: A signed compact JWT.
Returns:
A (unsigned_compact, json_header, json_payload, signature_or_mac) tuple.
Raises:
_jwt_error.JwtInvalidError if it fails.
"""
if not isinstance(signed_compact, str):
raise _jwt_error.JwtInvalidError('invalid token: not a str')
try:
encoded = signed_compact.encode('utf8')
except UnicodeEncodeError:
raise _jwt_error.JwtInvalidError('invalid token')
try:
unsigned_compact, encoded_signature = encoded.rsplit(b'.', 1)
except ValueError:
raise _jwt_error.JwtInvalidError('invalid token')
signature_or_mac = decode_signature(encoded_signature)
try:
encoded_header, encoded_payload = unsigned_compact.split(b'.')
except ValueError:
raise _jwt_error.JwtInvalidError('invalid token')
json_header = decode_header(encoded_header)
json_payload = decode_payload(encoded_payload)
return (unsigned_compact, json_header, json_payload, signature_or_mac)
def _validate_kid_header(header: Any, kid: str) -> None:
if header['kid'] != kid:
raise _jwt_error.JwtInvalidError('invalid kid header')
def validate_header(header: Any,
algorithm: str,
tink_kid: Optional[str] = None,
custom_kid: Optional[str] = None) -> None:
"""Parses the header and validates its values."""
_validate_algorithm(algorithm)
hdr_algorithm = header.get('alg', '')
if hdr_algorithm.upper() != algorithm:
raise _jwt_error.JwtInvalidError('Invalid algorithm; expected %s, got %s' %
(algorithm, hdr_algorithm))
if 'crit' in header:
raise _jwt_error.JwtInvalidError(
'all tokens with crit headers are rejected')
if tink_kid is not None and custom_kid is not None:
raise _jwt_error.JwtInvalidError('custom_kid can only be set for RAW keys')
if tink_kid is not None:
if 'kid' not in header:
# for output prefix type TINK, the kid header is required
raise _jwt_error.JwtInvalidError('missing kid in header')
_validate_kid_header(header, tink_kid)
if custom_kid is not None and 'kid' in header:
_validate_kid_header(header, custom_kid)
def get_type_header(header: Any) -> Optional[str]:
return header.get('typ', None)
def create_unsigned_compact(algorithm: str, kid: Optional[str],
raw_jwt: _raw_jwt.RawJwt) -> bytes:
if raw_jwt.has_type_header():
header = create_header(algorithm, raw_jwt.type_header(), kid)
else:
header = create_header(algorithm, None, kid)
return header + b'.' + encode_payload(raw_jwt.json_payload())
def create_signed_compact(unsigned_compact: bytes, signature: bytes) -> str:
return (unsigned_compact + b'.' + encode_signature(signature)).decode('utf8')