blob: d518d5774317f0e5aa24297cc060d1a8a09b572c [file] [log] [blame]
# Copyright 2019 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module implements helper functions and fake primitives for testing."""
from __future__ import absolute_import
from __future__ import division
# Placeholder for import for type annotations
from __future__ import print_function
import os
from typing import Text, Mapping, Optional
from tink.proto import tink_pb2
from tink import aead
from tink import core
from tink import daead
from tink import hybrid
from tink import mac
from tink import prf
from tink import signature as pk_signature
from google.protobuf import text_format
def tink_root_path() -> Text:
    """Locates the Tink root directory used by the test environment.

    Candidate directories are tried in the following order and the first one
    that exists on disk is returned:

      1. The value of the TINK_SRC_PATH environment variable, if it is set.
      2. ``$TEST_SRCDIR/tink_base``, if TEST_SRCDIR is set (Bazel).
      3. ``$TEST_SRCDIR/google3/third_party/tink``, if TEST_SRCDIR is set.

    Returns:
      The first candidate path that exists.

    Raises:
      ValueError: If none of the candidate paths exists (including the case
        where neither environment variable is set).
    """
    env = os.environ
    candidates = []
    if 'TINK_SRC_PATH' in env:
        candidates.append(env['TINK_SRC_PATH'])
    if 'TEST_SRCDIR' in env:
        # Bazel environment: both known workspace layouts are tried.
        bazel_srcdir = env['TEST_SRCDIR']
        for workspace in ('tink_base', 'google3/third_party/tink'):
            candidates.append(os.path.join(bazel_srcdir, workspace))

    # Pick the first candidate that is actually present on disk.
    found = next((path for path in candidates if os.path.exists(path)), None)
    if found is None:
        raise ValueError(
            'Could not find path to Tink root directory. Make sure that '
            'TINK_SRC_PATH is set.')
    return found
def template_from_testdata(
    template_name: Text,
    dir_name: Optional[Text] = None) -> tink_pb2.KeyTemplate:
    """Loads a text-format KeyTemplate from the Tink testdata directory.

    Args:
      template_name: File name of the template inside the templates directory.
      dir_name: Optional sub-directory of ``testdata/templates`` that contains
        the template. When empty or None the template is read directly from
        ``testdata/templates``.

    Returns:
      The KeyTemplate parsed from the file's text-format contents.
    """
    segments = [tink_root_path(), 'testdata/templates']
    if dir_name:
        segments.append(dir_name)
    segments.append(template_name)

    with open(os.path.join(*segments), mode='rt') as template_file:
        serialized = template_file.read()
    return text_format.Parse(serialized, tink_pb2.KeyTemplate())
def fake_key(
    value: bytes = b'fakevalue',
    type_url: Text = 'fakeurl',
    key_material_type: tink_pb2.KeyData.KeyMaterialType = (
        tink_pb2.KeyData.SYMMETRIC),
    key_id: int = 1234,
    status: tink_pb2.KeyStatusType = tink_pb2.ENABLED,
    output_prefix_type: tink_pb2.OutputPrefixType = tink_pb2.TINK
) -> tink_pb2.Keyset.Key:
    """Builds a Keyset.Key proto populated with the given (fake) values.

    Args:
      value: Bytes stored in ``key_data.value``.
      type_url: String stored in ``key_data.type_url``.
      key_material_type: Enum value stored in ``key_data.key_material_type``.
      key_id: Numeric id of the key.
      status: Status of the key.
      output_prefix_type: Output prefix type of the key.

    Returns:
      A new ``tink_pb2.Keyset.Key`` carrying the values above.
    """
    fake = tink_pb2.Keyset.Key(
        key_id=key_id,
        status=status,
        output_prefix_type=output_prefix_type)
    # Fill in the nested KeyData message field by field.
    key_data = fake.key_data
    key_data.type_url = type_url
    key_data.value = value
    key_data.key_material_type = key_material_type
    return fake
class FakeMac(mac.Mac):
    """A fake MAC implementation.

    The tag for ``data`` is simply ``data + b'|' + name``, where ``name`` is
    the UTF-8 encoding of the name given at construction time.
    """

    def __init__(self, name: Text = 'FakeMac'):
        """Creates the fake; ``name`` distinguishes instances from each other."""
        self._name = name

    def compute_mac(self, data: bytes) -> bytes:
        """Returns ``data + b'|' + name`` as the fake tag."""
        return data + b'|' + self._name.encode()

    def verify_mac(self, mac_value: bytes, data: bytes) -> None:
        """Checks that ``mac_value`` is the fake tag for ``data``.

        Raises:
          core.TinkError: If ``mac_value`` does not match.
        """
        if mac_value != data + b'|' + self._name.encode():
            # The tag may be arbitrary bytes. A strict decode would raise
            # UnicodeDecodeError and mask the TinkError callers expect, so
            # undecodable bytes are rendered as backslash escapes instead.
            raise core.TinkError(
                'invalid mac ' + mac_value.decode(errors='backslashreplace'))
class FakeAead(aead.Aead):
    """A fake AEAD implementation.

    The "ciphertext" is ``plaintext + b'|' + associated_data + b'|' + name``,
    where ``name`` is the UTF-8 encoding of the name given at construction
    time. Nothing is actually encrypted.
    """

    def __init__(self, name: Text = 'FakeAead'):
        """Creates the fake; ``name`` distinguishes instances from each other."""
        self._name = name

    def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes:
        """Returns the ``b'|'``-joined plaintext, associated data and name."""
        return plaintext + b'|' + associated_data + b'|' + self._name.encode()

    def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes:
        """Returns the first ``b'|'``-separated field of ``ciphertext``.

        Raises:
          core.TinkError: If the ciphertext has fewer than three fields, or
            its second/third field does not equal ``associated_data`` / name.
        """
        data = ciphertext.split(b'|')
        if (len(data) < 3 or data[1] != associated_data or
                data[2] != self._name.encode()):
            # The ciphertext may be arbitrary bytes. A strict decode would
            # raise UnicodeDecodeError and mask the TinkError callers expect,
            # so undecodable bytes are rendered as backslash escapes instead.
            raise core.TinkError(
                'failed to decrypt ciphertext ' +
                ciphertext.decode(errors='backslashreplace'))
        return data[0]
class FakeDeterministicAead(daead.DeterministicAead):
    """A fake Deterministic AEAD implementation.

    The "ciphertext" is ``plaintext + b'|' + associated_data + b'|' + name``,
    where ``name`` is the UTF-8 encoding of the name given at construction
    time. Nothing is actually encrypted.
    """

    def __init__(self, name: Text = 'FakeDeterministicAead'):
        """Creates the fake; ``name`` distinguishes instances from each other."""
        self._name = name

    def encrypt_deterministically(self, plaintext: bytes,
                                  associated_data: bytes) -> bytes:
        """Returns the ``b'|'``-joined plaintext, associated data and name."""
        return plaintext + b'|' + associated_data + b'|' + self._name.encode()

    def decrypt_deterministically(self, ciphertext: bytes,
                                  associated_data: bytes) -> bytes:
        """Returns the first ``b'|'``-separated field of ``ciphertext``.

        Raises:
          core.TinkError: If the ciphertext has fewer than three fields, or
            its second/third field does not equal ``associated_data`` / name.
        """
        data = ciphertext.split(b'|')
        if (len(data) < 3 or
                data[1] != associated_data or
                data[2] != self._name.encode()):
            # The ciphertext may be arbitrary bytes. A strict decode would
            # raise UnicodeDecodeError and mask the TinkError callers expect,
            # so undecodable bytes are rendered as backslash escapes instead.
            raise core.TinkError(
                'failed to decrypt ciphertext ' +
                ciphertext.decode(errors='backslashreplace'))
        return data[0]
class FakeHybridDecrypt(hybrid.HybridDecrypt):
    """A fake HybridDecrypt implementation.

    Accepts "ciphertexts" of the form
    ``plaintext + b'|' + context_info + b'|' + name``, where ``name`` is the
    UTF-8 encoding of the name given at construction time.
    """

    def __init__(self, name: Text = 'Hybrid'):
        """Creates the fake; ``name`` must match the one used to encrypt."""
        self._name = name

    def decrypt(self, ciphertext: bytes, context_info: bytes) -> bytes:
        """Returns the first ``b'|'``-separated field of ``ciphertext``.

        Raises:
          core.TinkError: If the ciphertext has fewer than three fields, or
            its second/third field does not equal ``context_info`` / name.
        """
        data = ciphertext.split(b'|')
        if (len(data) < 3 or
                data[1] != context_info or
                data[2] != self._name.encode()):
            # The ciphertext may be arbitrary bytes. A strict decode would
            # raise UnicodeDecodeError and mask the TinkError callers expect,
            # so undecodable bytes are rendered as backslash escapes instead.
            raise core.TinkError(
                'failed to decrypt ciphertext ' +
                ciphertext.decode(errors='backslashreplace'))
        return data[0]
class FakeHybridEncrypt(hybrid.HybridEncrypt):
    """A fake HybridEncrypt implementation.

    Produces ``plaintext + b'|' + context_info + b'|' + name`` as the
    "ciphertext", where ``name`` is the UTF-8 encoding of the name given at
    construction time.
    """

    def __init__(self, name: Text = 'Hybrid'):
        """Stores the name that is appended to every ciphertext."""
        self._name = name

    def encrypt(self, plaintext: bytes, context_info: bytes) -> bytes:
        """Concatenates plaintext, context info and name, ``b'|'``-separated."""
        with_context = plaintext + b'|' + context_info
        return with_context + b'|' + self._name.encode()
class FakePublicKeySign(pk_signature.PublicKeySign):
    """A fake PublicKeySign implementation.

    The "signature" of ``data`` is ``data + b'|' + name``, where ``name`` is
    the UTF-8 encoding of the name given at construction time.
    """

    def __init__(self, name: Text = 'FakePublicKeySign'):
        """Stores the name that is appended to every signature."""
        self._name = name

    def sign(self, data: bytes) -> bytes:
        """Returns ``data`` followed by ``b'|'`` and the encoded name."""
        signed = data + b'|'
        signed += self._name.encode()
        return signed
class FakePublicKeyVerify(pk_signature.PublicKeyVerify):
    """A fake PublicKeyVerify implementation.

    A signature is accepted when it equals ``data + b'|' + name``, where
    ``name`` is the UTF-8 encoding of the name given at construction time.
    """

    def __init__(self, name: Text = 'FakePublicKeyVerify'):
        """Creates the fake; ``name`` is the suffix expected on signatures."""
        self._name = name

    def verify(self, signature: bytes, data: bytes) -> None:
        """Checks that ``signature`` is the fake signature for ``data``.

        Raises:
          core.TinkError: If ``signature`` does not match.
        """
        if signature != data + b'|' + self._name.encode():
            # The signature may be arbitrary bytes. A strict decode would
            # raise UnicodeDecodeError and mask the TinkError callers expect,
            # so undecodable bytes are rendered as backslash escapes instead.
            raise core.TinkError(
                'invalid signature ' +
                signature.decode(errors='backslashreplace'))
class FakePrf(prf.Prf):
    """A fake Prf implementation.

    The output is the first ``output_length`` bytes of
    ``input_data + b'|' + name + b'|' + b'*' * output_length``, where ``name``
    is the UTF-8 encoding of the name given at construction time.
    """

    def __init__(self, name: Text = 'FakePrf'):
        """Stores the name that is mixed into every output."""
        self._name = name

    def compute(self, input_data: bytes, output_length: int) -> bytes:
        """Returns the fake PRF output truncated to ``output_length`` bytes.

        Raises:
          core.TinkError: If ``output_length`` is greater than 32.
        """
        if output_length > 32:
            raise core.TinkError('invalid output_length')
        # Pad with enough '*' bytes that the slice below can always be filled.
        padding = b'*' * output_length
        stream = input_data + b'|' + self._name.encode() + b'|' + padding
        return stream[:output_length]
class FakePrfSet(prf.PrfSet):
    """A fake PrfSet implementation that contains exactly one Prf.

    The single ``FakePrf`` is registered under key id 0 and is also the
    primary.
    """

    # Key id under which the only Prf of this set is exposed.
    _PRIMARY_ID = 0

    def __init__(self, name: Text = 'FakePrf'):
        """Creates the set with one ``FakePrf`` built from ``name``."""
        self._prf = FakePrf(name)

    def primary_id(self) -> int:
        """Returns the id of the primary Prf, which is always 0."""
        return FakePrfSet._PRIMARY_ID

    def all(self) -> Mapping[int, prf.Prf]:
        """Returns a new mapping from id 0 to the only Prf in this set."""
        return {FakePrfSet._PRIMARY_ID: self._prf}

    def primary(self) -> prf.Prf:
        """Returns the only Prf in this set."""
        return self._prf