blob: 557f03999e7df4141d918e6a3ce01b0ca149777a [file] [log] [blame]
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Tink Keyset with JWT keys from and to JWK sets."""
import io
import json
import random
from typing import Dict, List, Optional, Union
from tink.proto import jwt_ecdsa_pb2
from tink.proto import jwt_rsa_ssa_pkcs1_pb2
from tink.proto import jwt_rsa_ssa_pss_pb2
from tink.proto import tink_pb2
import tink
from tink import cleartext_keyset_handle
from tink.jwt import _jwt_format
# Type URLs of the public key protos handled by this module.
_JWT_ECDSA_PUBLIC_KEY_TYPE = (
    'type.googleapis.com/google.crypto.tink.JwtEcdsaPublicKey')
_JWT_RSA_SSA_PKCS1_PUBLIC_KEY_TYPE = (
    'type.googleapis.com/google.crypto.tink.JwtRsaSsaPkcs1PublicKey')
_JWT_RSA_SSA_PSS_PUBLIC_KEY_TYPE = (
    'type.googleapis.com/google.crypto.tink.JwtRsaSsaPssPublicKey')

# ECDSA: proto algorithm enum -> (JWK "alg" value, JWK "crv" value).
_ECDSA_PARAMS = {
    jwt_ecdsa_pb2.ES256: ('ES256', 'P-256'),
    jwt_ecdsa_pb2.ES384: ('ES384', 'P-384'),
    jwt_ecdsa_pb2.ES512: ('ES512', 'P-521'),
}
# Reverse lookup: JWK "alg" value -> proto algorithm enum.
_ECDSA_NAME_TO_ALGORITHM = dict(
    (params[0], algorithm) for algorithm, params in _ECDSA_PARAMS.items())

# RSA SSA PKCS1: proto algorithm enum -> JWK "alg" value.
_RSA_SSA_PKCS1_PARAMS = {
    jwt_rsa_ssa_pkcs1_pb2.RS256: 'RS256',
    jwt_rsa_ssa_pkcs1_pb2.RS384: 'RS384',
    jwt_rsa_ssa_pkcs1_pb2.RS512: 'RS512',
}
# Reverse lookup: JWK "alg" value -> proto algorithm enum.
_RSA_SSA_PKCS1_NAME_TO_ALGORITHM = dict(
    (name, algorithm) for algorithm, name in _RSA_SSA_PKCS1_PARAMS.items())

# RSA SSA PSS: proto algorithm enum -> JWK "alg" value.
_RSA_SSA_PSS_PARAMS = {
    jwt_rsa_ssa_pss_pb2.PS256: 'PS256',
    jwt_rsa_ssa_pss_pb2.PS384: 'PS384',
    jwt_rsa_ssa_pss_pb2.PS512: 'PS512',
}
# Reverse lookup: JWK "alg" value -> proto algorithm enum.
_RSA_SSA_PSS_NAME_TO_ALGORITHM = dict(
    (name, algorithm) for algorithm, name in _RSA_SSA_PSS_PARAMS.items())
def _base64_encode(data: bytes) -> str:
  """Encodes bytes with _jwt_format.base64_encode and returns the text as str."""
  encoded = _jwt_format.base64_encode(data)
  return encoded.decode('utf8')
def _base64_decode(data: str) -> bytes:
  """Decodes a str with _jwt_format.base64_decode and returns the raw bytes."""
  raw = data.encode('utf8')
  return _jwt_format.base64_decode(raw)
def from_public_keyset_handle(keyset_handle: tink.KeysetHandle) -> str:
  """Converts a Tink KeysetHandle with JWT keys into a Json Web Key (JWK) set.

  JWK is defined in https://www.rfc-editor.org/rfc/rfc7517.txt.

  Disabled keys are skipped.

  Keys with output prefix type "TINK" will include the encoded key ID as "kid"
  value. Keys with output prefix type "RAW" will not have a "kid" value set.

  Currently, public keys for algorithms ES256, ES384, ES512, RS256, RS384,
  RS512, PS256, PS384 and PS512 are supported.

  Args:
    keyset_handle: A Tink KeysetHandle that contains JWT Keys.

  Returns:
    A JWK set, which is a JSON encoded string.

  Raises:
    TinkError if the keys are not of the expected type, or if they have an
    output prefix type that is not supported.
  """
  # Serialize the handle (without secrets) and re-parse it to get at the
  # underlying Keyset proto.
  buffer = io.BytesIO()
  keyset_handle.write_no_secret(tink.BinaryKeysetWriter(buffer))
  keyset = tink_pb2.Keyset.FromString(buffer.getvalue())

  converters = {
      _JWT_ECDSA_PUBLIC_KEY_TYPE: _convert_jwt_ecdsa_key,
      _JWT_RSA_SSA_PKCS1_PUBLIC_KEY_TYPE: _convert_jwt_rsa_ssa_pkcs1_key,
      _JWT_RSA_SSA_PSS_PUBLIC_KEY_TYPE: _convert_jwt_rsa_ssa_pss_key,
  }
  supported_prefix_types = [tink_pb2.RAW, tink_pb2.TINK]

  jwks = []
  for key in keyset.key:
    if key.status != tink_pb2.ENABLED:
      continue
    key_data = key.key_data
    if key_data.key_material_type != tink_pb2.KeyData.ASYMMETRIC_PUBLIC:
      raise tink.TinkError('wrong key material type')
    if key.output_prefix_type not in supported_prefix_types:
      raise tink.TinkError('unsupported output prefix type')
    convert = converters.get(key_data.type_url)
    if convert is None:
      raise tink.TinkError('unknown key type: %s' % key_data.type_url)
    jwks.append(convert(key))
  # Compact separators: no whitespace in the emitted JSON.
  return json.dumps({'keys': jwks}, separators=(',', ':'))
# Deprecated. Use from_public_keyset_handle instead.
def from_keyset_handle(keyset_handle: tink.KeysetHandle,
                       key_access: Optional[tink.KeyAccess] = None) -> str:
  """Deprecated wrapper around from_public_keyset_handle.

  Args:
    keyset_handle: A Tink KeysetHandle that contains JWT Keys.
    key_access: Ignored.

  Returns:
    Whatever from_public_keyset_handle(keyset_handle) returns.
  """
  del key_access  # Unused.
  return from_public_keyset_handle(keyset_handle)
def _convert_jwt_ecdsa_key(
    key: tink_pb2.Keyset.Key) -> Dict[str, Union[str, List[str]]]:
  """Builds the JWK dictionary for a keyset entry holding a JwtEcdsaPublicKey.

  Args:
    key: Keyset entry whose key_data.value is a serialized JwtEcdsaPublicKey.

  Returns:
    The JWK as a dict. "kid" is the value from _jwt_format.get_kid when that
    is non-empty, otherwise the key's custom_kid if set, otherwise absent.

  Raises:
    TinkError if the key's algorithm is not in _ECDSA_PARAMS.
  """
  public_key = jwt_ecdsa_pb2.JwtEcdsaPublicKey.FromString(key.key_data.value)
  params = _ECDSA_PARAMS.get(public_key.algorithm)
  if params is None:
    raise tink.TinkError('unknown ecdsa algorithm')
  alg_name, curve = params
  # Insertion order is kept as-is; it determines the member order in the JSON.
  jwk = {
      'kty': 'EC',
      'crv': curve,
      'x': _base64_encode(public_key.x),
      'y': _base64_encode(public_key.y),
      'use': 'sig',
      'alg': alg_name,
      'key_ops': ['verify'],
  }
  tink_kid = _jwt_format.get_kid(key.key_id, key.output_prefix_type)
  if tink_kid:
    jwk['kid'] = tink_kid
    return jwk
  if public_key.HasField('custom_kid'):
    jwk['kid'] = public_key.custom_kid.value
  return jwk
def _convert_jwt_rsa_ssa_pkcs1_key(
    key: tink_pb2.Keyset.Key) -> Dict[str, Union[str, List[str]]]:
  """Builds the JWK dictionary for a keyset entry holding a JwtRsaSsaPkcs1PublicKey.

  Args:
    key: Keyset entry whose key_data.value is a serialized
      JwtRsaSsaPkcs1PublicKey.

  Returns:
    The JWK as a dict. "kid" is the value from _jwt_format.get_kid when that
    is non-empty, otherwise the key's custom_kid if set, otherwise absent.

  Raises:
    TinkError if the key's algorithm is not in _RSA_SSA_PKCS1_PARAMS.
  """
  rsa_key = jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PublicKey.FromString(
      key.key_data.value)
  alg_name = _RSA_SSA_PKCS1_PARAMS.get(rsa_key.algorithm)
  if alg_name is None:
    raise tink.TinkError('unknown RSA SSA PKCS1 algorithm')
  # Insertion order is kept as-is; it determines the member order in the JSON.
  jwk = {
      'kty': 'RSA',
      'n': _base64_encode(rsa_key.n),
      'e': _base64_encode(rsa_key.e),
      'use': 'sig',
      'alg': alg_name,
      'key_ops': ['verify'],
  }
  tink_kid = _jwt_format.get_kid(key.key_id, key.output_prefix_type)
  if tink_kid:
    jwk['kid'] = tink_kid
    return jwk
  if rsa_key.HasField('custom_kid'):
    jwk['kid'] = rsa_key.custom_kid.value
  return jwk
def _convert_jwt_rsa_ssa_pss_key(
    key: tink_pb2.Keyset.Key) -> Dict[str, Union[str, List[str]]]:
  """Builds the JWK dictionary for a keyset entry holding a JwtRsaSsaPssPublicKey.

  Args:
    key: Keyset entry whose key_data.value is a serialized
      JwtRsaSsaPssPublicKey.

  Returns:
    The JWK as a dict. "kid" is the value from _jwt_format.get_kid when that
    is non-empty, otherwise the key's custom_kid if set, otherwise absent.

  Raises:
    TinkError if the key's algorithm is not in _RSA_SSA_PSS_PARAMS.
  """
  rsa_key = jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPublicKey.FromString(
      key.key_data.value)
  alg_name = _RSA_SSA_PSS_PARAMS.get(rsa_key.algorithm)
  if alg_name is None:
    raise tink.TinkError('unknown RSA SSA PSS algorithm')
  # Insertion order is kept as-is; it determines the member order in the JSON.
  jwk = {
      'kty': 'RSA',
      'n': _base64_encode(rsa_key.n),
      'e': _base64_encode(rsa_key.e),
      'use': 'sig',
      'alg': alg_name,
      'key_ops': ['verify'],
  }
  tink_kid = _jwt_format.get_kid(key.key_id, key.output_prefix_type)
  if tink_kid:
    jwk['kid'] = tink_kid
    return jwk
  if rsa_key.HasField('custom_kid'):
    jwk['kid'] = rsa_key.custom_kid.value
  return jwk
def _generate_unused_key_id(keyset: tink_pb2.Keyset) -> int:
  """Returns a random key ID that no key in the keyset uses yet.

  Args:
    keyset: The keyset whose existing key IDs must be avoided. It is only
      read, never modified.

  Returns:
    An integer in [1, 2147483647] that differs from every key.key_id in
    keyset.key.
  """
  # The keyset does not change while we retry, so collect the taken IDs once
  # instead of rebuilding the set on every iteration.
  used_ids = {key.key_id for key in keyset.key}
  # NOTE(review): IDs come from `random`, not `secrets`; presumably fine since
  # key IDs are identifiers rather than secret material -- confirm.
  while True:
    key_id = random.randint(1, 2147483647)
    if key_id not in used_ids:
      return key_id
def to_public_keyset_handle(jwk_set: str) -> tink.KeysetHandle:
  """Converts a Json Web Key (JWK) set into a Tink KeysetHandle with JWT keys.

  JWK is defined in https://www.rfc-editor.org/rfc/rfc7517.txt.

  All keys are converted into Tink keys with output prefix type "RAW".

  Currently, public keys for algorithms ES256, ES384, ES512, RS256, RS384,
  RS512, PS256, PS384 and PS512 are supported.

  Args:
    jwk_set: A JWK set, which is a JSON encoded string.

  Returns:
    A tink.KeysetHandle.

  Raises:
    TinkError if the JWK set is not valid JSON, is not a JSON object with a
    "keys" array, or contains a key that cannot be converted.
  """
  try:
    keys_dict = json.loads(jwk_set)
  except json.decoder.JSONDecodeError as e:
    # Keep the parser error attached as the cause for easier debugging.
    raise tink.TinkError('error parsing JWK set: %s' % e.msg) from e
  # json.loads may return any JSON value (list, str, number, ...); only an
  # object can carry the "keys" member.
  if not isinstance(keys_dict, dict) or 'keys' not in keys_dict:
    raise tink.TinkError('invalid JWK set: keys not found')
  jwks = keys_dict['keys']
  if not isinstance(jwks, list):
    raise tink.TinkError('invalid JWK set: keys is not an array')
  proto_keyset = tink_pb2.Keyset()
  for key in jwks:
    if not isinstance(key, dict) or 'alg' not in key:
      raise tink.TinkError('invalid JWK: alg not found')
    alg = key['alg']
    if not isinstance(alg, str):
      # A non-string "alg" cannot name any supported algorithm.
      raise tink.TinkError('unknown alg')
    if alg.startswith('ES'):
      proto_key = _convert_to_ecdsa_key(key)
    elif alg.startswith('RS'):
      proto_key = _convert_to_rsa_ssa_pkcs1_key(key)
    elif alg.startswith('PS'):
      proto_key = _convert_to_rsa_ssa_pss_key(key)
    else:
      raise tink.TinkError('unknown alg')
    new_id = _generate_unused_key_id(proto_keyset)
    proto_key.key_id = new_id
    proto_keyset.key.append(proto_key)
    # JWK sets do not really have a primary key (see RFC 7517, Section 5.1).
    # To verify signature, it also does not matter which key is primary. We
    # simply set it to the last key. Assigning inside the loop means new_id is
    # never read unbound when "keys" is empty.
    proto_keyset.primary_key_id = new_id
  return cleartext_keyset_handle.from_keyset(proto_keyset)
# Deprecated. Use to_public_keyset_handle instead.
def to_keyset_handle(
    jwk_set: str,
    key_access: Optional[tink.KeyAccess] = None) -> tink.KeysetHandle:
  """Deprecated wrapper around to_public_keyset_handle.

  Args:
    jwk_set: A JWK set, which is a JSON encoded string.
    key_access: Ignored.

  Returns:
    Whatever to_public_keyset_handle(jwk_set) returns.
  """
  del key_access  # Unused.
  return to_public_keyset_handle(jwk_set)
def _validate_use_and_key_ops(key: Dict[str, Union[str, List[str]]]):
  """Checks that 'key_ops' and 'use' have the right values if present.

  Args:
    key: A JWK as a dictionary, typically parsed from JSON.

  Raises:
    TinkError if 'key_ops' is present and is anything other than a
    one-element sequence containing 'verify', or if 'use' is present and is
    not 'sig'.
  """
  if 'key_ops' in key:
    key_ops = key['key_ops']
    # The JWK comes from parsed JSON, so key_ops may be null, a number or an
    # object. Reject those here instead of letting len() or indexing fail
    # with TypeError/KeyError.
    if (not isinstance(key_ops, (list, tuple)) or len(key_ops) != 1 or
        key_ops[0] != 'verify'):
      raise tink.TinkError('invalid key_ops')
  if 'use' in key and key['use'] != 'sig':
    raise tink.TinkError('invalid use')
def _convert_to_ecdsa_key(
    key: Dict[str, Union[str, List[str]]]) -> tink_pb2.Keyset.Key:
  """Converts an EC Json Web Key (JWK) into a tink_pb2.Keyset.Key.

  The returned key wraps a serialized JwtEcdsaPublicKey, is ENABLED and has
  output prefix type RAW. Its key_id is left unset for the caller to fill in.

  Args:
    key: The JWK as a dictionary, typically parsed from JSON.

  Returns:
    A tink_pb2.Keyset.Key containing the public key.

  Raises:
    TinkError if 'alg' is not a known ECDSA algorithm, if 'kty', 'crv', 'use'
    or 'key_ops' have unexpected values, if the JWK contains the private
    member 'd', or if 'x', 'y' or 'kid' are missing or malformed.
  """
  algorithm = _ECDSA_NAME_TO_ALGORITHM.get(key.get('alg'))
  # Test for a failed lookup explicitly rather than relying on the truthiness
  # of the enum value.
  if algorithm is None:
    raise tink.TinkError('unknown ECDSA algorithm')
  if key.get('kty') != 'EC':
    raise tink.TinkError('invalid kty')
  _, crv = _ECDSA_PARAMS[algorithm]
  if key.get('crv') != crv:
    raise tink.TinkError('invalid crv')
  _validate_use_and_key_ops(key)
  if 'd' in key:
    raise tink.TinkError('cannot convert private ECDSA key')
  # The coordinates are required; report their absence (or a non-string value)
  # as TinkError instead of leaking KeyError/AttributeError.
  for member in ('x', 'y'):
    if not isinstance(key.get(member), str):
      raise tink.TinkError(
          'invalid JWK: %s not found or not a string' % member)
  ecdsa_public_key = jwt_ecdsa_pb2.JwtEcdsaPublicKey()
  ecdsa_public_key.algorithm = algorithm
  ecdsa_public_key.x = _base64_decode(key['x'])
  ecdsa_public_key.y = _base64_decode(key['y'])
  if 'kid' in key:
    # RFC 7517, Section 4.5: the "kid" value is a string.
    if not isinstance(key['kid'], str):
      raise tink.TinkError('invalid JWK: kid is not a string')
    ecdsa_public_key.custom_kid.value = key['kid']
  proto_key = tink_pb2.Keyset.Key()
  proto_key.key_data.type_url = _JWT_ECDSA_PUBLIC_KEY_TYPE
  proto_key.key_data.value = ecdsa_public_key.SerializeToString()
  proto_key.key_data.key_material_type = tink_pb2.KeyData.ASYMMETRIC_PUBLIC
  proto_key.output_prefix_type = tink_pb2.RAW
  proto_key.status = tink_pb2.ENABLED
  return proto_key
def _convert_to_rsa_ssa_pkcs1_key(
    key: Dict[str, Union[str, List[str]]]) -> tink_pb2.Keyset.Key:
  """Converts an RSA Json Web Key (JWK) with an RS* alg into a Keyset.Key.

  The returned key wraps a serialized JwtRsaSsaPkcs1PublicKey, is ENABLED and
  has output prefix type RAW. Its key_id is left unset for the caller to fill
  in.

  Args:
    key: The JWK as a dictionary, typically parsed from JSON.

  Returns:
    A tink_pb2.Keyset.Key containing the public key.

  Raises:
    TinkError if 'alg' is not a known RSA SSA PKCS1 algorithm, if 'kty',
    'use' or 'key_ops' have unexpected values, if the JWK contains private
    key members, or if 'n', 'e' or 'kid' are missing or malformed.
  """
  algorithm = _RSA_SSA_PKCS1_NAME_TO_ALGORITHM.get(key.get('alg'))
  # Test for a failed lookup explicitly rather than relying on the truthiness
  # of the enum value.
  if algorithm is None:
    raise tink.TinkError('unknown RSA SSA PKCS1 algorithm')
  if key.get('kty') != 'RSA':
    raise tink.TinkError('invalid kty')
  _validate_use_and_key_ops(key)
  # Any of these members means the JWK carries private key material.
  if any(member in key for member in ('p', 'q', 'dp', 'dq', 'd', 'qi')):
    raise tink.TinkError('importing RSA private keys is not implemented')
  # Modulus and exponent are required; report their absence (or a non-string
  # value) as TinkError instead of leaking KeyError/AttributeError.
  for member in ('n', 'e'):
    if not isinstance(key.get(member), str):
      raise tink.TinkError(
          'invalid JWK: %s not found or not a string' % member)
  public_key = jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PublicKey()
  public_key.algorithm = algorithm
  public_key.n = _base64_decode(key['n'])
  public_key.e = _base64_decode(key['e'])
  if 'kid' in key:
    # RFC 7517, Section 4.5: the "kid" value is a string.
    if not isinstance(key['kid'], str):
      raise tink.TinkError('invalid JWK: kid is not a string')
    public_key.custom_kid.value = key['kid']
  proto_key = tink_pb2.Keyset.Key()
  proto_key.key_data.type_url = _JWT_RSA_SSA_PKCS1_PUBLIC_KEY_TYPE
  proto_key.key_data.value = public_key.SerializeToString()
  proto_key.key_data.key_material_type = tink_pb2.KeyData.ASYMMETRIC_PUBLIC
  proto_key.output_prefix_type = tink_pb2.RAW
  proto_key.status = tink_pb2.ENABLED
  return proto_key
def _convert_to_rsa_ssa_pss_key(
    key: Dict[str, Union[str, List[str]]]) -> tink_pb2.Keyset.Key:
  """Converts an RSA Json Web Key (JWK) with a PS* alg into a Keyset.Key.

  The returned key wraps a serialized JwtRsaSsaPssPublicKey, is ENABLED and
  has output prefix type RAW. Its key_id is left unset for the caller to fill
  in.

  Args:
    key: The JWK as a dictionary, typically parsed from JSON.

  Returns:
    A tink_pb2.Keyset.Key containing the public key.

  Raises:
    TinkError if 'alg' is not a known RSA SSA PSS algorithm, if 'kty', 'use'
    or 'key_ops' have unexpected values, if the JWK contains private key
    members, or if 'n', 'e' or 'kid' are missing or malformed.
  """
  algorithm = _RSA_SSA_PSS_NAME_TO_ALGORITHM.get(key.get('alg'))
  # Test for a failed lookup explicitly rather than relying on the truthiness
  # of the enum value.
  if algorithm is None:
    raise tink.TinkError('unknown RSA SSA PSS algorithm')
  if key.get('kty') != 'RSA':
    raise tink.TinkError('invalid kty')
  _validate_use_and_key_ops(key)
  # Any of these members means the JWK carries private key material.
  if any(member in key for member in ('p', 'q', 'dp', 'dq', 'd', 'qi')):
    raise tink.TinkError('importing RSA private keys is not implemented')
  # Modulus and exponent are required; report their absence (or a non-string
  # value) as TinkError instead of leaking KeyError/AttributeError.
  for member in ('n', 'e'):
    if not isinstance(key.get(member), str):
      raise tink.TinkError(
          'invalid JWK: %s not found or not a string' % member)
  public_key = jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPublicKey()
  public_key.algorithm = algorithm
  public_key.n = _base64_decode(key['n'])
  public_key.e = _base64_decode(key['e'])
  if 'kid' in key:
    # RFC 7517, Section 4.5: the "kid" value is a string.
    if not isinstance(key['kid'], str):
      raise tink.TinkError('invalid JWK: kid is not a string')
    public_key.custom_kid.value = key['kid']
  proto_key = tink_pb2.Keyset.Key()
  proto_key.key_data.type_url = _JWT_RSA_SSA_PSS_PUBLIC_KEY_TYPE
  proto_key.key_data.value = public_key.SerializeToString()
  proto_key.key_data.key_material_type = tink_pb2.KeyData.ASYMMETRIC_PUBLIC
  proto_key.output_prefix_type = tink_pb2.RAW
  proto_key.status = tink_pb2.ENABLED
  return proto_key