# blob: 2b8f0d667274c5083ceab62ee5fa1635731c7cb8 [file] [log] [blame]
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tink.python.tink.jwt._jwt_signature_key_manager."""
import datetime
from typing import cast
from absl.testing import absltest
from absl.testing import parameterized
from tink.proto import jwt_ecdsa_pb2
from tink.proto import tink_pb2
import tink
from tink import jwt
from tink.jwt import _jwt_format
from tink.jwt import _jwt_signature_key_manager
from tink.jwt import _jwt_signature_wrappers
# Fixed, timezone-aware (UTC) instants used as "now" and as expiration values
# in the tests below.
# 12345 seconds after the Unix epoch.
DATETIME_1970 = datetime.datetime.fromtimestamp(
    12345, tz=datetime.timezone.utc)
# 1300819380 seconds after the Unix epoch (year 2011).
DATETIME_2011 = datetime.datetime.fromtimestamp(
    1300819380, tz=datetime.timezone.utc)
# 1582230020 seconds after the Unix epoch (year 2020).
DATETIME_2020 = datetime.datetime.fromtimestamp(
    1582230020, tz=datetime.timezone.utc)
def setUpModule() -> None:
  """Module-level fixture: calls jwt.register_jwt_signature() before tests."""
  jwt.register_jwt_signature()
def gen_compact(json_header: str, json_payload: str, raw_sign) -> str:
  """Builds a signed compact token from the given header and payload JSON.

  The header and payload strings are handed to the _jwt_format encoders
  unchanged, so callers can use this to produce tokens with unusual or
  malformed JSON that nevertheless carry a signature made by raw_sign.

  Args:
    json_header: Text to encode as the token header.
    json_payload: Text to encode as the token payload.
    raw_sign: Object whose sign() method is called with the encoded
      'header.payload' bytes and returns the signature.

  Returns:
    The result of _jwt_format.create_signed_compact for the encoded
    'header.payload' bytes and their signature.
  """
  encoded_header = _jwt_format.encode_header(json_header)
  encoded_payload = _jwt_format.encode_payload(json_payload)
  to_be_signed = encoded_header + b'.' + encoded_payload
  return _jwt_format.create_signed_compact(to_be_signed,
                                           raw_sign.sign(to_be_signed))
class JwtSignatureKeyManagerTest(parameterized.TestCase):
  """Tests JWT signing and verification with ES256 keyset handles."""

  def test_create_sign_verify(self):
    """Signs a token, verifies it, then checks several rejection cases."""
    handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    sign = handle.primitive(jwt.JwtPublicKeySign)
    verify = handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)
    raw_jwt = jwt.new_raw_jwt(
        issuer='joe',
        expiration=DATETIME_2011,
        custom_claims={'http://example.com/is_root': True})
    signed_compact = sign.sign_and_encode(raw_jwt)

    # "now" is fixed before the expiration, so the token is still valid.
    validator = jwt.new_validator(
        expected_issuer='joe', fixed_now=DATETIME_1970)
    verified_jwt = verify.verify_and_decode(signed_compact, validator)
    self.assertEqual(verified_jwt.issuer(), 'joe')
    self.assertEqual(verified_jwt.expiration().year, 2011)
    self.assertCountEqual(verified_jwt.custom_claim_names(),
                          ['http://example.com/is_root'])
    self.assertTrue(verified_jwt.custom_claim('http://example.com/is_root'))

    # fails because it is expired
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(
          signed_compact,
          jwt.new_validator(expected_issuer='joe', fixed_now=DATETIME_2020))

    # wrong issuer
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(
          signed_compact,
          jwt.new_validator(expected_issuer='jane', fixed_now=DATETIME_1970))

    # invalid format
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(signed_compact + '.123', validator)

    # invalid character
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(signed_compact + '?', validator)

    # modified signature
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(signed_compact + 'a', validator)

    # modified header
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode('a' + signed_compact, validator)

  def test_create_sign_verify_with_type_header(self):
    """Checks that a type header set when signing is returned on verify."""
    handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    sign = handle.primitive(jwt.JwtPublicKeySign)
    verify = handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)
    raw_jwt = jwt.new_raw_jwt(
        type_header='typeHeader', issuer='joe', without_expiration=True)
    signed_compact = sign.sign_and_encode(raw_jwt)

    validator = jwt.new_validator(
        expected_type_header='typeHeader',
        expected_issuer='joe',
        allow_missing_expiration=True)
    verified_jwt = verify.verify_and_decode(signed_compact, validator)
    self.assertEqual(verified_jwt.type_header(), 'typeHeader')

  def test_verify_with_other_key_fails(self):
    """Checks that a token is rejected by a verifier from another keyset."""
    handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    sign = handle.primitive(jwt.JwtPublicKeySign)
    raw_jwt = jwt.new_raw_jwt(issuer='issuer', without_expiration=True)
    compact = sign.sign_and_encode(raw_jwt)

    # A second, independently generated keyset.
    other_handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    other_verify = other_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)
    with self.assertRaises(tink.TinkError):
      other_verify.verify_and_decode(
          compact,
          jwt.new_validator(
              expected_issuer='issuer', allow_missing_expiration=True))

  def test_weird_tokens_with_valid_signatures(self):
    """Verifies hand-built tokens that all carry a signature from raw_sign.

    The tokens are assembled with gen_compact or directly with _jwt_format,
    so every accept/reject decision below is made on a token whose signature
    was produced by the keyset's own signer.
    """
    handle = tink.new_keyset_handle(jwt.raw_jwt_es256_template())
    sign = handle.primitive(jwt.JwtPublicKeySign)
    # Get the internal PublicKeySign primitive to create valid signatures.
    wrapped = cast(_jwt_signature_wrappers._WrappedJwtPublicKeySign, sign)
    raw_sign = cast(_jwt_signature_key_manager._JwtPublicKeySign,
                    wrapped._primitive_set.primary().primitive)._public_key_sign

    verify = handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)
    validator = jwt.new_validator(
        expected_issuer='issuer', allow_missing_expiration=True)

    # Normal token.
    valid = gen_compact('{"alg":"ES256"}', '{"iss":"issuer"}', raw_sign)
    verified_jwt = verify.verify_and_decode(valid, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with unknown header is valid.
    unknown_header = gen_compact('{"alg":"ES256","unknown_header":"abc"} \n ',
                                 '{"iss":"issuer" }', raw_sign)
    verified_jwt = verify.verify_and_decode(unknown_header, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with unknown kid is valid, since primitives with output prefix type
    # RAW ignore kid headers.
    unknown_kid = gen_compact('{"alg":"ES256","kid":"unknown"} \n ',
                              '{"iss":"issuer" }', raw_sign)
    verified_jwt = verify.verify_and_decode(unknown_kid, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with invalid alg header
    alg_invalid = gen_compact('{"alg":"ES384"}', '{"iss":"issuer"}', raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(alg_invalid, validator)

    # Token with empty header
    empty_header = gen_compact('{}', '{"iss":"issuer"}', raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(empty_header, validator)

    # Token header is not valid JSON
    header_invalid = gen_compact('{"alg":"ES256"', '{"iss":"issuer"}', raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(header_invalid, validator)

    # Token payload is not valid JSON
    payload_invalid = gen_compact('{"alg":"ES256"}', '{"iss":"issuer"',
                                  raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(payload_invalid, validator)

    # Token with whitespace in header JSON string is valid.
    whitespace_in_header = gen_compact(' {"alg": \n "ES256"} \n ',
                                       '{"iss":"issuer" }', raw_sign)
    verified_jwt = verify.verify_and_decode(whitespace_in_header, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with whitespace in payload JSON string is valid.
    whitespace_in_payload = gen_compact('{"alg":"ES256"}',
                                        ' {"iss": \n"issuer" } \n', raw_sign)
    verified_jwt = verify.verify_and_decode(whitespace_in_payload, validator)
    self.assertEqual(verified_jwt.issuer(), 'issuer')

    # Token with whitespace in base64-encoded header is invalid.
    with_whitespace = (
        _jwt_format.encode_header('{"alg":"ES256"}') + b' .' +
        _jwt_format.encode_payload('{"iss":"issuer"}'))
    token_with_whitespace = _jwt_format.create_signed_compact(
        with_whitespace, raw_sign.sign(with_whitespace))
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(token_with_whitespace, validator)

    # Token with invalid character is invalid.
    with_invalid_char = (
        _jwt_format.encode_header('{"alg":"ES256"}') + b'.?' +
        _jwt_format.encode_payload('{"iss":"issuer"}'))
    token_with_invalid_char = _jwt_format.create_signed_compact(
        with_invalid_char, raw_sign.sign(with_invalid_char))
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(token_with_invalid_char, validator)

    # Token with additional '.' is invalid.
    with_dot = (
        _jwt_format.encode_header('{"alg":"ES256"}') + b'.' +
        _jwt_format.encode_payload('{"iss":"issuer"}') + b'.')
    token_with_dot = _jwt_format.create_signed_compact(
        with_dot, raw_sign.sign(with_dot))
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(token_with_dot, validator)

    # num_recursions has been chosen such that parsing of this token fails
    # in all languages. We want to make sure that the algorithm does not
    # hang or crash in this case, but only returns a parsing error.
    num_recursions = 10000
    rec_payload = ('{"a":' * num_recursions) + '""' + ('}' * num_recursions)
    rec_token = gen_compact('{"alg":"ES256"}', rec_payload, raw_sign)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(
          rec_token, validator=jwt.new_validator(allow_missing_expiration=True))

    # test wrong types
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(cast(str, None), validator)
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(cast(str, 123), validator)
    # Encode outside the assertRaises block so that only the call under test
    # can satisfy the expected exception.
    valid_bytes = valid.encode('utf8')
    with self.assertRaises(tink.TinkError):
      verify.verify_and_decode(cast(str, valid_bytes), validator)

  def test_create_ecdsa_handle_with_invalid_algorithm_fails(self):
    """Checks that a key template with algorithm ES_UNKNOWN is rejected."""
    key_format = jwt_ecdsa_pb2.JwtEcdsaKeyFormat(
        algorithm=jwt_ecdsa_pb2.ES_UNKNOWN)
    template = tink_pb2.KeyTemplate(
        type_url='type.googleapis.com/google.crypto.tink.JwtEcdsaPrivateKey',
        value=key_format.SerializeToString(),
        output_prefix_type=tink_pb2.RAW)
    with self.assertRaises(tink.TinkError):
      tink.new_keyset_handle(template)

  def test_create_sign_primitive_with_invalid_algorithm_fails(self):
    """Checks that a private key with algorithm ES_UNKNOWN cannot sign."""
    handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    # Rewrite the stored key in place with its algorithm set to ES_UNKNOWN.
    key = jwt_ecdsa_pb2.JwtEcdsaPrivateKey.FromString(
        handle._keyset.key[0].key_data.value)
    key.public_key.algorithm = jwt_ecdsa_pb2.ES_UNKNOWN
    handle._keyset.key[0].key_data.value = key.SerializeToString()
    with self.assertRaises(tink.TinkError):
      handle.primitive(jwt.JwtPublicKeySign)

  def test_create_verify_primitive_with_invalid_algorithm_fails(self):
    """Checks that a public key with algorithm ES_UNKNOWN cannot verify."""
    private_handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    handle = private_handle.public_keyset_handle()
    # Rewrite the stored key in place with its algorithm set to ES_UNKNOWN.
    key = jwt_ecdsa_pb2.JwtEcdsaPublicKey.FromString(
        handle._keyset.key[0].key_data.value)
    key.algorithm = jwt_ecdsa_pb2.ES_UNKNOWN
    handle._keyset.key[0].key_data.value = key.SerializeToString()
    with self.assertRaises(tink.TinkError):
      handle.primitive(jwt.JwtPublicKeyVerify)
if __name__ == '__main__':
absltest.main()