blob: 85c4ed2637168cf647c0dfad9a44f30527a18fc3 [file] [log] [blame]
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements tink primitives from gRPC testing_api stubs."""
from __future__ import absolute_import
from __future__ import division
# Placeholder for import for type annotations
from __future__ import print_function
import datetime
import io
import json
from typing import BinaryIO, Mapping, Text, Tuple
import tink
from tink import aead
from tink import daead
from tink import hybrid
from tink import mac
from tink import prf
from tink import signature as tink_signature
from tink import streaming_aead
from tink.proto import tink_pb2
from proto.testing import testing_api_pb2
from proto.testing import testing_api_pb2_grpc
from tink import jwt
def new_keyset(stub: testing_api_pb2_grpc.KeysetStub,
               key_template: tink_pb2.KeyTemplate) -> bytes:
  """Asks the keyset service to generate a keyset from a key template.

  Args:
    stub: Keyset service stub whose Generate method is called.
    key_template: Template that is serialized into the request.

  Returns:
    The `keyset` field of the response.

  Raises:
    tink.TinkError: If the response carries a non-empty `err` field.
  """
  serialized_template = key_template.SerializeToString()
  response = stub.Generate(
      testing_api_pb2.KeysetGenerateRequest(template=serialized_template))
  if response.err:
    raise tink.TinkError(response.err)
  return response.keyset
def public_keyset(stub: testing_api_pb2_grpc.KeysetStub,
                  private_keyset: bytes) -> bytes:
  """Asks the keyset service for the public keyset of a private keyset.

  Args:
    stub: Keyset service stub whose Public method is called.
    private_keyset: Private keyset bytes placed in the request.

  Returns:
    The `public_keyset` field of the response.

  Raises:
    tink.TinkError: If the response carries a non-empty `err` field.
  """
  result = stub.Public(
      testing_api_pb2.KeysetPublicRequest(private_keyset=private_keyset))
  if result.err:
    raise tink.TinkError(result.err)
  return result.public_keyset
def keyset_to_json(
    stub: testing_api_pb2_grpc.KeysetStub,
    keyset: bytes) -> Text:
  """Asks the keyset service to convert a keyset into its JSON form.

  Args:
    stub: Keyset service stub whose ToJson method is called.
    keyset: Keyset bytes placed in the request.

  Returns:
    The `json_keyset` field of the response.

  Raises:
    tink.TinkError: If the response carries a non-empty `err` field.
  """
  result = stub.ToJson(testing_api_pb2.KeysetToJsonRequest(keyset=keyset))
  if result.err:
    raise tink.TinkError(result.err)
  return result.json_keyset
def keyset_from_json(
    stub: testing_api_pb2_grpc.KeysetStub,
    json_keyset: Text) -> bytes:
  """Asks the keyset service to convert a JSON keyset into keyset bytes.

  Args:
    stub: Keyset service stub whose FromJson method is called.
    json_keyset: JSON keyset text placed in the request.

  Returns:
    The `keyset` field of the response.

  Raises:
    tink.TinkError: If the response carries a non-empty `err` field.
  """
  result = stub.FromJson(
      testing_api_pb2.KeysetFromJsonRequest(json_keyset=json_keyset))
  if result.err:
    raise tink.TinkError(result.err)
  return result.keyset
class Aead(aead.Aead):
  """Aead primitive whose operations are forwarded to an AEAD service stub."""

  def __init__(self, lang: Text, stub: testing_api_pb2_grpc.AeadStub,
               keyset: bytes) -> None:
    """Stores the language tag, the service stub and the keyset bytes."""
    self.lang = lang
    self._stub = stub
    self._keyset = keyset

  def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes:
    """Calls the stub's Encrypt and returns the response ciphertext.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    response = self._stub.Encrypt(
        testing_api_pb2.AeadEncryptRequest(
            keyset=self._keyset,
            plaintext=plaintext,
            associated_data=associated_data))
    if response.err:
      raise tink.TinkError(response.err)
    return response.ciphertext

  def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes:
    """Calls the stub's Decrypt and returns the response plaintext.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    response = self._stub.Decrypt(
        testing_api_pb2.AeadDecryptRequest(
            keyset=self._keyset,
            ciphertext=ciphertext,
            associated_data=associated_data))
    if response.err:
      raise tink.TinkError(response.err)
    return response.plaintext
class DeterministicAead(daead.DeterministicAead):
  """DeterministicAead primitive forwarding to a DAEAD service stub."""

  def __init__(self, lang: Text,
               stub: testing_api_pb2_grpc.DeterministicAeadStub,
               keyset: bytes) -> None:
    """Stores the language tag, the service stub and the keyset bytes."""
    self.lang = lang
    self._stub = stub
    self._keyset = keyset

  def encrypt_deterministically(self, plaintext: bytes,
                                associated_data: bytes) -> bytes:
    """Calls the stub's EncryptDeterministically and returns the ciphertext.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    response = self._stub.EncryptDeterministically(
        testing_api_pb2.DeterministicAeadEncryptRequest(
            keyset=self._keyset,
            plaintext=plaintext,
            associated_data=associated_data))
    if response.err:
      raise tink.TinkError(response.err)
    return response.ciphertext

  def decrypt_deterministically(self, ciphertext: bytes,
                                associated_data: bytes) -> bytes:
    """Calls the stub's DecryptDeterministically and returns the plaintext.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    response = self._stub.DecryptDeterministically(
        testing_api_pb2.DeterministicAeadDecryptRequest(
            keyset=self._keyset,
            ciphertext=ciphertext,
            associated_data=associated_data))
    if response.err:
      raise tink.TinkError(response.err)
    return response.plaintext
class StreamingAead(streaming_aead.StreamingAead):
  """StreamingAead primitive forwarding to a Streaming AEAD service stub.

  Both methods read their input stream completely with read(), send the
  bytes in a single request, and hand the result back as an in-memory
  io.BytesIO.
  """

  def __init__(self, lang: Text, stub: testing_api_pb2_grpc.StreamingAeadStub,
               keyset: bytes) -> None:
    """Stores the language tag, the service stub and the keyset bytes."""
    self.lang = lang
    self._stub = stub
    self._keyset = keyset

  def new_encrypting_stream(self, plaintext: BinaryIO,
                            associated_data: bytes) -> BinaryIO:
    """Encrypts everything readable from `plaintext` via the stub.

    Returns:
      An io.BytesIO holding the response ciphertext.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    response = self._stub.Encrypt(
        testing_api_pb2.StreamingAeadEncryptRequest(
            keyset=self._keyset,
            plaintext=plaintext.read(),
            associated_data=associated_data))
    if response.err:
      raise tink.TinkError(response.err)
    return io.BytesIO(response.ciphertext)

  def new_decrypting_stream(self, ciphertext: BinaryIO,
                            associated_data: bytes) -> BinaryIO:
    """Decrypts everything readable from `ciphertext` via the stub.

    Returns:
      An io.BytesIO holding the response plaintext.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    response = self._stub.Decrypt(
        testing_api_pb2.StreamingAeadDecryptRequest(
            keyset=self._keyset,
            ciphertext=ciphertext.read(),
            associated_data=associated_data))
    if response.err:
      raise tink.TinkError(response.err)
    return io.BytesIO(response.plaintext)
class Mac(mac.Mac):
  """Mac primitive whose operations are forwarded to a MAC service stub."""

  def __init__(self, lang: Text, stub: testing_api_pb2_grpc.MacStub,
               keyset: bytes) -> None:
    """Stores the language tag, the service stub and the keyset bytes."""
    self.lang = lang
    self._stub = stub
    self._keyset = keyset

  def compute_mac(self, data: bytes) -> bytes:
    """Calls the stub's ComputeMac and returns the response `mac_value`.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    result = self._stub.ComputeMac(
        testing_api_pb2.ComputeMacRequest(keyset=self._keyset, data=data))
    if result.err:
      raise tink.TinkError(result.err)
    return result.mac_value

  def verify_mac(self, mac_value: bytes, data: bytes) -> None:
    """Calls the stub's VerifyMac; returns None when no error is reported.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    result = self._stub.VerifyMac(
        testing_api_pb2.VerifyMacRequest(
            keyset=self._keyset, mac_value=mac_value, data=data))
    if result.err:
      raise tink.TinkError(result.err)
class HybridEncrypt(hybrid.HybridEncrypt):
  """HybridEncrypt primitive forwarding to a hybrid service stub."""

  def __init__(self, lang: Text, stub: testing_api_pb2_grpc.HybridStub,
               public_handle: bytes) -> None:
    """Stores the language tag, the service stub and the public keyset."""
    self.lang = lang
    self._stub = stub
    # Sent as `public_keyset` in every encrypt request.
    self._public_handle = public_handle

  def encrypt(self, plaintext: bytes, context_info: bytes) -> bytes:
    """Calls the stub's Encrypt and returns the response ciphertext.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    response = self._stub.Encrypt(
        testing_api_pb2.HybridEncryptRequest(
            public_keyset=self._public_handle,
            plaintext=plaintext,
            context_info=context_info))
    if response.err:
      raise tink.TinkError(response.err)
    return response.ciphertext
class HybridDecrypt(hybrid.HybridDecrypt):
  """HybridDecrypt primitive forwarding to a hybrid service stub."""

  def __init__(self, lang: Text, stub: testing_api_pb2_grpc.HybridStub,
               private_handle: bytes) -> None:
    """Stores the language tag, the service stub and the private keyset."""
    self.lang = lang
    self._stub = stub
    # Sent as `private_keyset` in every decrypt request.
    self._private_handle = private_handle

  def decrypt(self, ciphertext: bytes, context_info: bytes) -> bytes:
    """Calls the stub's Decrypt and returns the response plaintext.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    response = self._stub.Decrypt(
        testing_api_pb2.HybridDecryptRequest(
            private_keyset=self._private_handle,
            ciphertext=ciphertext,
            context_info=context_info))
    if response.err:
      raise tink.TinkError(response.err)
    return response.plaintext
class PublicKeySign(tink_signature.PublicKeySign):
  """PublicKeySign primitive forwarding to a signature service stub."""

  def __init__(self, lang: Text, stub: testing_api_pb2_grpc.SignatureStub,
               private_handle: bytes) -> None:
    """Stores the language tag, the service stub and the private keyset."""
    self.lang = lang
    self._stub = stub
    # Sent as `private_keyset` in every sign request.
    self._private_handle = private_handle

  def sign(self, data: bytes) -> bytes:
    """Calls the stub's Sign and returns the response signature.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    result = self._stub.Sign(
        testing_api_pb2.SignatureSignRequest(
            private_keyset=self._private_handle, data=data))
    if result.err:
      raise tink.TinkError(result.err)
    return result.signature
class PublicKeyVerify(tink_signature.PublicKeyVerify):
  """PublicKeyVerify primitive forwarding to a signature service stub."""

  def __init__(self, lang: Text, stub: testing_api_pb2_grpc.SignatureStub,
               public_handle: bytes) -> None:
    """Stores the language tag, the service stub and the public keyset."""
    self.lang = lang
    self._stub = stub
    # Sent as `public_keyset` in every verify request.
    self._public_handle = public_handle

  def verify(self, signature: bytes, data: bytes) -> None:
    """Calls the stub's Verify; returns None when no error is reported.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    result = self._stub.Verify(
        testing_api_pb2.SignatureVerifyRequest(
            public_keyset=self._public_handle,
            signature=signature,
            data=data))
    if result.err:
      raise tink.TinkError(result.err)
class _Prf(prf.Prf):
  """Prf for one key id, computed through a PrfSet service stub."""

  def __init__(self, lang: Text, stub: testing_api_pb2_grpc.PrfSetStub,
               keyset: bytes, key_id: int) -> None:
    """Stores the language tag, stub, keyset bytes and the key id to use."""
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    self._key_id = key_id

  def compute(self, input_data: bytes, output_length: int) -> bytes:
    """Calls the stub's Compute for the stored key id and returns the output.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    result = self._stub.Compute(
        testing_api_pb2.PrfSetComputeRequest(
            keyset=self._keyset,
            key_id=self._key_id,
            input_data=input_data,
            output_length=output_length))
    if result.err:
      raise tink.TinkError(result.err)
    return result.output
class PrfSet(prf.PrfSet):
  """PrfSet whose key ids and Prfs come from a PrfSet service stub.

  The key ids are fetched with a single KeyIds call the first time any
  accessor is used; later calls reuse the cached result.
  """

  def __init__(self, lang: Text, stub: testing_api_pb2_grpc.PrfSetStub,
               keyset: bytes) -> None:
    """Stores the language tag, stub and keyset; defers the KeyIds call."""
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    # The three fields below are filled in by _initialize_key_ids().
    self._key_ids_initialized = False
    self._primary_key_id = None
    self._prfs = None

  def _initialize_key_ids(self) -> None:
    """Fetches the key ids once and builds one _Prf per key id.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    if self._key_ids_initialized:
      return
    result = self._stub.KeyIds(
        testing_api_pb2.PrfSetKeyIdsRequest(keyset=self._keyset))
    if result.err:
      raise tink.TinkError(result.err)
    self._primary_key_id = result.output.primary_key_id
    self._prfs = {
        key_id: _Prf(self.lang, self._stub, self._keyset, key_id)
        for key_id in result.output.key_id
    }
    self._key_ids_initialized = True

  def primary_id(self) -> int:
    """Returns the primary key id reported by the service."""
    self._initialize_key_ids()
    return self._primary_key_id

  def all(self) -> Mapping[int, prf.Prf]:
    """Returns a shallow copy of the key id to Prf mapping."""
    self._initialize_key_ids()
    return dict(self._prfs)

  def primary(self) -> prf.Prf:
    """Returns the Prf registered under the primary key id."""
    self._initialize_key_ids()
    primary_key_id = self._primary_key_id
    return self._prfs[primary_key_id]
def split_datetime(dt: datetime.datetime) -> Tuple[int, int]:
  """Splits a datetime into whole seconds and nanoseconds since the epoch.

  The split uses integer timedelta arithmetic rather than a float timestamp,
  so the microsecond component is carried over exactly. The nanosecond part
  is never negative, even for instants before 1970-01-01T00:00:00Z; the
  seconds are floored instead of truncated toward zero.

  Args:
    dt: The instant to convert. A naive datetime is interpreted as local
      time, which is also what `dt.timestamp()` does.

  Returns:
    A `(seconds, nanos)` tuple with `0 <= nanos < 1_000_000_000` such that
    `seconds + nanos / 1e9` is the POSIX timestamp of `dt`.
  """
  if dt.utcoffset() is None:
    # Naive value: attach the system local timezone so it can be compared
    # with the timezone-aware epoch below.
    dt = dt.astimezone()
  epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
  delta = dt - epoch
  # timedelta is normalized so that only `days` can be negative, which makes
  # this a floor division of the offset into seconds and a remainder.
  seconds = delta.days * 86400 + delta.seconds
  nanos = delta.microseconds * 1000
  return (seconds, nanos)
def to_datetime(seconds: int, nanos: int) -> datetime.datetime:
  """Builds a UTC datetime from seconds and nanoseconds since the epoch.

  The result is computed as epoch + timedelta instead of going through a
  float timestamp and `fromtimestamp`, so it does not depend on the range
  supported by the platform's C time functions.

  Args:
    seconds: Whole seconds since 1970-01-01T00:00:00Z; may be negative.
    nanos: Nanoseconds to add to `seconds`. datetime only has microsecond
      resolution, so this is rounded to the nearest microsecond.

  Returns:
    A timezone-aware datetime in UTC.
  """
  epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
  # A float `microseconds` argument is rounded to the nearest microsecond by
  # timedelta itself.
  offset = datetime.timedelta(seconds=seconds, microseconds=nanos / 1000)
  return epoch + offset
def raw_jwt_to_proto(raw_jwt: jwt.RawJwt) -> testing_api_pb2.JwtToken:
  """Converts a jwt.RawJwt into a testing_api_pb2.JwtToken proto.

  Each registered claim is copied only if the matching `has_*()` accessor
  reports it as present. Each custom claim is stored in exactly one typed
  field of its `custom_claims` entry.

  Args:
    raw_jwt: The token to convert.

  Returns:
    A new JwtToken proto holding the claims of `raw_jwt`.
  """
  raw_token = testing_api_pb2.JwtToken()
  if raw_jwt.has_issuer():
    raw_token.issuer.value = raw_jwt.issuer()
  if raw_jwt.has_subject():
    raw_token.subject.value = raw_jwt.subject()
  if raw_jwt.has_audiences():
    raw_token.audiences.extend(raw_jwt.audiences())
  if raw_jwt.has_jwt_id():
    raw_token.jwt_id.value = raw_jwt.jwt_id()
  if raw_jwt.has_expiration():
    seconds, nanos = split_datetime(raw_jwt.expiration())
    raw_token.expiration.seconds = seconds
    raw_token.expiration.nanos = nanos
  if raw_jwt.has_not_before():
    seconds, nanos = split_datetime(raw_jwt.not_before())
    raw_token.not_before.seconds = seconds
    raw_token.not_before.nanos = nanos
  if raw_jwt.has_issued_at():
    seconds, nanos = split_datetime(raw_jwt.issued_at())
    raw_token.issued_at.seconds = seconds
    raw_token.issued_at.nanos = nanos
  for name in raw_jwt.custom_claim_names():
    value = raw_jwt.custom_claim(name)
    # The branches are mutually exclusive. bool must be tested before the
    # numeric types because bool is a subclass of int; otherwise True/False
    # would first be written as a number.
    # NOTE(review): the typed fields look like members of a oneof (the
    # reverse conversion uses HasField on them) -- confirm in the .proto.
    if value is None:
      raw_token.custom_claims[name].null_value = testing_api_pb2.NULL_VALUE
    elif isinstance(value, bool):
      raw_token.custom_claims[name].bool_value = value
    elif isinstance(value, (int, float)):
      raw_token.custom_claims[name].number_value = value
    elif isinstance(value, Text):
      raw_token.custom_claims[name].string_value = value
    elif isinstance(value, dict):
      raw_token.custom_claims[name].json_object_value = json.dumps(value)
    elif isinstance(value, list):
      raw_token.custom_claims[name].json_array_value = json.dumps(value)
    # Values of any other type are skipped, as before.
  return raw_token
def proto_to_verified_jwt(
    token: testing_api_pb2.JwtToken) -> jwt.VerifiedJwt:
  """Builds a jwt.VerifiedJwt from a testing_api_pb2.JwtToken proto.

  Claims that are not set on the proto are passed to jwt.new_raw_jwt as
  None; an empty `audiences` list is likewise passed as None.

  Args:
    token: The proto to convert.

  Returns:
    A VerifiedJwt wrapping the raw JWT rebuilt from `token`.
  """

  def wrapped_value(field_name):
    # Value of a wrapper-message field, or None when the field is unset.
    if not token.HasField(field_name):
      return None
    return getattr(token, field_name).value

  def timestamp(field_name):
    # Datetime for a seconds/nanos field, or None when the field is unset.
    if not token.HasField(field_name):
      return None
    stamp = getattr(token, field_name)
    return to_datetime(stamp.seconds, stamp.nanos)

  issuer = wrapped_value('issuer')
  subject = wrapped_value('subject')
  jwt_id = wrapped_value('jwt_id')
  audiences = list(token.audiences) if token.audiences else None
  expiration = timestamp('expiration')
  not_before = timestamp('not_before')
  issued_at = timestamp('issued_at')

  custom_claims = {}
  for name in token.custom_claims:
    claim = token.custom_claims[name]
    # Kept as independent checks so that, should more than one field be
    # set, the last matching one wins.
    if claim.HasField('null_value'):
      custom_claims[name] = None
    if claim.HasField('number_value'):
      custom_claims[name] = claim.number_value
    if claim.HasField('string_value'):
      custom_claims[name] = claim.string_value
    if claim.HasField('bool_value'):
      custom_claims[name] = claim.bool_value
    if claim.HasField('json_object_value'):
      custom_claims[name] = json.loads(claim.json_object_value)
    if claim.HasField('json_array_value'):
      custom_claims[name] = json.loads(claim.json_array_value)

  rebuilt = jwt.new_raw_jwt(issuer, subject, audiences, jwt_id, expiration,
                            not_before, issued_at, custom_claims)
  return jwt.VerifiedJwt._create(rebuilt)  # pylint: disable=protected-access
def jwt_validator_to_proto(
    validator: jwt.JwtValidator) -> testing_api_pb2.JwtValidator:
  """Builds a testing_api_pb2.JwtValidator proto from a jwt.JwtValidator.

  Issuer, subject, audience and the fixed "now" are copied only when the
  validator reports them as present; the clock skew is always copied.

  Args:
    validator: The validator to convert.

  Returns:
    A new JwtValidator proto.
  """
  result = testing_api_pb2.JwtValidator()
  if validator.has_issuer():
    result.issuer.value = validator.issuer()
  if validator.has_subject():
    result.subject.value = validator.subject()
  if validator.has_audience():
    result.audience.value = validator.audience()
  # NOTE(review): if clock_skew() is a datetime.timedelta, `.seconds` drops
  # the days component -- confirm skews of a day or more cannot occur.
  skew = validator.clock_skew()
  result.clock_skew.seconds = skew.seconds
  if validator.has_fixed_now():
    now_seconds, now_nanos = split_datetime(validator.fixed_now())
    result.now.seconds = now_seconds
    result.now.nanos = now_nanos
  return result
class JwtMac:
  """JwtMac whose operations are forwarded to a Jwt service stub."""

  def __init__(self, lang: Text, stub: testing_api_pb2_grpc.JwtStub,
               keyset: bytes) -> None:
    """Stores the language tag, the service stub and the keyset bytes."""
    self.lang = lang
    self._stub = stub
    self._keyset = keyset

  def compute_mac_and_encode(self, raw_jwt: jwt.RawJwt) -> Text:
    """Calls the stub's ComputeMacAndEncode and returns the compact JWT.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    token_proto = raw_jwt_to_proto(raw_jwt)
    result = self._stub.ComputeMacAndEncode(
        testing_api_pb2.JwtSignRequest(
            keyset=self._keyset, raw_jwt=token_proto))
    if result.err:
      raise tink.TinkError(result.err)
    return result.signed_compact_jwt

  def verify_mac_and_decode(self, signed_compact_jwt: Text,
                            validator: jwt.JwtValidator) -> jwt.VerifiedJwt:
    """Calls the stub's VerifyMacAndDecode and converts the returned token.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    validator_proto = jwt_validator_to_proto(validator)
    result = self._stub.VerifyMacAndDecode(
        testing_api_pb2.JwtVerifyRequest(
            keyset=self._keyset,
            validator=validator_proto,
            signed_compact_jwt=signed_compact_jwt))
    if result.err:
      raise tink.TinkError(result.err)
    return proto_to_verified_jwt(result.verified_jwt)
class JwtPublicKeySign:
  """JwtPublicKeySign whose operation is forwarded to a Jwt service stub."""

  def __init__(self, lang: Text, stub: testing_api_pb2_grpc.JwtStub,
               keyset: bytes) -> None:
    """Stores the language tag, the service stub and the keyset bytes."""
    self.lang = lang
    self._stub = stub
    self._keyset = keyset

  def sign_and_encode(self, raw_jwt: jwt.RawJwt) -> Text:
    """Calls the stub's PublicKeySignAndEncode and returns the compact JWT.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    token_proto = raw_jwt_to_proto(raw_jwt)
    result = self._stub.PublicKeySignAndEncode(
        testing_api_pb2.JwtSignRequest(
            keyset=self._keyset, raw_jwt=token_proto))
    if result.err:
      raise tink.TinkError(result.err)
    return result.signed_compact_jwt
class JwtPublicKeyVerify:
  """JwtPublicKeyVerify whose operation is forwarded to a Jwt service stub."""

  def __init__(self, lang: Text, stub: testing_api_pb2_grpc.JwtStub,
               keyset: bytes) -> None:
    """Stores the language tag, the service stub and the keyset bytes."""
    self.lang = lang
    self._stub = stub
    self._keyset = keyset

  def verify_and_decode(self, signed_compact_jwt: Text,
                        validator: jwt.JwtValidator) -> jwt.VerifiedJwt:
    """Calls the stub's PublicKeyVerifyAndDecode and converts the token.

    Raises:
      tink.TinkError: If the response carries a non-empty `err` field.
    """
    validator_proto = jwt_validator_to_proto(validator)
    result = self._stub.PublicKeyVerifyAndDecode(
        testing_api_pb2.JwtVerifyRequest(
            keyset=self._keyset,
            validator=validator_proto,
            signed_compact_jwt=signed_compact_jwt))
    if result.err:
      raise tink.TinkError(result.err)
    return proto_to_verified_jwt(result.verified_jwt)