blob: 1f63e3e6650bea3cfeaf2d9b318de8a38be9e008 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This class implements helper functions for testing."""
import os
from typing import Mapping
from tink.proto import tink_pb2
from tink import aead
from tink import core
from tink import daead
from tink import hybrid
from tink import mac
from tink import prf
from tink import signature as pk_signature
_RELATIVE_TESTDATA_PATH = 'tink_py/testdata'


def tink_py_testdata_path() -> str:
  """Locates the directory holding the Tink Python test data.

  Environment variables are consulted in priority order:
  TINK_PYTHON_ROOT_PATH (joined with 'testdata'), then TEST_SRCDIR (joined
  with 'tink_py/testdata'). Only the first variable that is set is used, and
  the directory derived from it must exist.

  Returns:
    The test data directory derived from the highest-priority variable set.

  Raises:
    FileNotFoundError: if that variable is set but the derived directory does
      not exist (lower-priority variables are not tried).
    ValueError: if none of the environment variables is set.
  """
  candidates = (
      ('TINK_PYTHON_ROOT_PATH', 'testdata'),
      ('TEST_SRCDIR', _RELATIVE_TESTDATA_PATH),
  )
  for env_variable, relative_path in candidates:
    root = os.environ.get(env_variable)
    if root is None:
      continue
    testdata_path = os.path.join(root, relative_path)
    # The first variable that is set wins; a bad path is an error, not a
    # reason to fall through to the next candidate.
    if not os.path.exists(testdata_path):
      raise FileNotFoundError(f'Variable {env_variable} is set but has an '
                              f'invalid path {testdata_path}')
    return testdata_path
  raise ValueError('No path environment variable set among '
                   'TINK_PYTHON_ROOT_PATH, TEST_SRCDIR')
def fake_key(
    value: bytes = b'fakevalue',
    type_url: str = 'fakeurl',
    key_material_type: tink_pb2.KeyData.KeyMaterialType = (
        tink_pb2.KeyData.SYMMETRIC),
    key_id: int = 1234,
    status: tink_pb2.KeyStatusType = tink_pb2.ENABLED,
    output_prefix_type: tink_pb2.OutputPrefixType = tink_pb2.TINK
) -> tink_pb2.Keyset.Key:
  """Builds a fake but structurally valid keyset key for tests.

  Args:
    value: bytes stored in the key's KeyData.value.
    type_url: string stored in the key's KeyData.type_url.
    key_material_type: the key's KeyData.key_material_type.
    key_id: the key's key_id.
    status: the key's status.
    output_prefix_type: the key's output_prefix_type.

  Returns:
    A tink_pb2.Keyset.Key populated from the arguments.
  """
  key_data = tink_pb2.KeyData(
      type_url=type_url,
      value=value,
      key_material_type=key_material_type)
  return tink_pb2.Keyset.Key(
      key_data=key_data,
      key_id=key_id,
      status=status,
      output_prefix_type=output_prefix_type)
class FakeMac(mac.Mac):
  """A fake MAC implementation for tests.

  The "tag" is the data followed by b'|' and the UTF-8 encoded name. No key
  is involved, so it offers no security.
  """

  def __init__(self, name: str = 'FakeMac'):
    self._name = name

  def compute_mac(self, data: bytes) -> bytes:
    """Returns data + b'|' + the encoded name."""
    return data + b'|' + self._name.encode()

  def verify_mac(self, mac_value: bytes, data: bytes) -> None:
    """Checks that mac_value equals compute_mac(data) for this name.

    Raises:
      core.TinkError: if mac_value does not match. This also holds when
        mac_value is not valid UTF-8.
    """
    if mac_value != data + b'|' + self._name.encode():
      # backslashreplace: a non-UTF-8 tag must still surface as TinkError,
      # not UnicodeDecodeError. Valid UTF-8 yields the same text as decode().
      raise core.TinkError(
          'invalid mac ' + mac_value.decode('utf-8', 'backslashreplace'))
class FakeAead(aead.Aead):
  """A fake AEAD implementation for tests.

  The "ciphertext" is plaintext + b'|' + associated_data + b'|' + name, where
  name is UTF-8 encoded. The plaintext is visible verbatim, so this provides
  no confidentiality or integrity.
  """

  def __init__(self, name: str = 'FakeAead'):
    self._name = name

  def _suffix(self, associated_data: bytes) -> bytes:
    """Returns the bytes that encrypt() appends to the plaintext."""
    return b'|' + associated_data + b'|' + self._name.encode()

  def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes:
    """Returns plaintext + b'|' + associated_data + b'|' + encoded name."""
    return plaintext + self._suffix(associated_data)

  def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes:
    """Inverts encrypt().

    Args:
      ciphertext: a value produced by encrypt() of an instance with this name.
      associated_data: must equal the associated data used for encryption.

    Returns:
      The original plaintext, which may itself contain b'|'.

    Raises:
      core.TinkError: if ciphertext does not end with the suffix encrypt()
        would append for this associated_data and name.
    """
    suffix = self._suffix(associated_data)
    # Match the exact suffix instead of splitting on b'|'. Splitting mis-parses
    # plaintexts or associated data containing b'|' and accepts trailing junk.
    if not ciphertext.endswith(suffix):
      # backslashreplace: non-UTF-8 input must still raise TinkError.
      raise core.TinkError('failed to decrypt ciphertext ' +
                           ciphertext.decode('utf-8', 'backslashreplace'))
    return ciphertext[:len(ciphertext) - len(suffix)]
class FakeDeterministicAead(daead.DeterministicAead):
  """A fake Deterministic AEAD implementation for tests.

  The "ciphertext" is plaintext + b'|' + associated_data + b'|' + name, where
  name is UTF-8 encoded. The plaintext is visible verbatim, so this provides
  no confidentiality or integrity.
  """

  def __init__(self, name: str = 'FakeDeterministicAead'):
    self._name = name

  def _suffix(self, associated_data: bytes) -> bytes:
    """Returns the bytes encrypt_deterministically() appends to the input."""
    return b'|' + associated_data + b'|' + self._name.encode()

  def encrypt_deterministically(self, plaintext: bytes,
                                associated_data: bytes) -> bytes:
    """Returns plaintext + b'|' + associated_data + b'|' + encoded name."""
    return plaintext + self._suffix(associated_data)

  def decrypt_deterministically(self, ciphertext: bytes,
                                associated_data: bytes) -> bytes:
    """Inverts encrypt_deterministically().

    Args:
      ciphertext: a value produced by encrypt_deterministically() of an
        instance with this name.
      associated_data: must equal the associated data used for encryption.

    Returns:
      The original plaintext, which may itself contain b'|'.

    Raises:
      core.TinkError: if ciphertext does not end with the expected suffix for
        this associated_data and name.
    """
    suffix = self._suffix(associated_data)
    # Match the exact suffix instead of splitting on b'|'. Splitting mis-parses
    # plaintexts or associated data containing b'|' and accepts trailing junk.
    if not ciphertext.endswith(suffix):
      # backslashreplace: non-UTF-8 input must still raise TinkError.
      raise core.TinkError('failed to decrypt ciphertext ' +
                           ciphertext.decode('utf-8', 'backslashreplace'))
    return ciphertext[:len(ciphertext) - len(suffix)]
class FakeHybridDecrypt(hybrid.HybridDecrypt):
  """A fake HybridDecrypt implementation for tests.

  Accepts ciphertexts of the form plaintext + b'|' + context_info + b'|' +
  name (name UTF-8 encoded), i.e. the output of FakeHybridEncrypt with the
  same name.
  """

  def __init__(self, name: str = 'Hybrid'):
    self._name = name

  def decrypt(self, ciphertext: bytes, context_info: bytes) -> bytes:
    """Strips the b'|' + context_info + b'|' + name suffix from ciphertext.

    Args:
      ciphertext: a value produced by FakeHybridEncrypt with this name.
      context_info: must equal the context info used for encryption.

    Returns:
      The original plaintext, which may itself contain b'|'.

    Raises:
      core.TinkError: if ciphertext does not end with the expected suffix.
    """
    suffix = b'|' + context_info + b'|' + self._name.encode()
    # Match the exact suffix instead of splitting on b'|'. Splitting mis-parses
    # plaintexts or context info containing b'|' and accepts trailing junk.
    if not ciphertext.endswith(suffix):
      # backslashreplace: non-UTF-8 input must still raise TinkError.
      raise core.TinkError('failed to decrypt ciphertext ' +
                           ciphertext.decode('utf-8', 'backslashreplace'))
    return ciphertext[:len(ciphertext) - len(suffix)]
class FakeHybridEncrypt(hybrid.HybridEncrypt):
  """A fake HybridEncrypt implementation for tests.

  Encryption appends b'|' + context_info + b'|' + the UTF-8 encoded name to
  the plaintext.
  """

  def __init__(self, name: str = 'Hybrid'):
    self._name = name

  def encrypt(self, plaintext: bytes, context_info: bytes) -> bytes:
    name_tag = b'|' + self._name.encode()
    return plaintext + b'|' + context_info + name_tag
class FakePublicKeySign(pk_signature.PublicKeySign):
  """A fake PublicKeySign implementation for tests.

  The "signature" is the data followed by b'|' and the UTF-8 encoded name.
  """

  def __init__(self, name: str = 'FakePublicKeySign'):
    self._name = name

  def sign(self, data: bytes) -> bytes:
    separator_and_name = b'|' + self._name.encode()
    return data + separator_and_name
class FakePublicKeyVerify(pk_signature.PublicKeyVerify):
  """A fake PublicKeyVerify implementation for tests.

  Accepts exactly data + b'|' + the UTF-8 encoded name, which is what
  FakePublicKeySign with the same name produces. The default names of the two
  classes differ, so pass the same name explicitly to pair them.
  """

  def __init__(self, name: str = 'FakePublicKeyVerify'):
    self._name = name

  def verify(self, signature: bytes, data: bytes) -> None:
    """Checks that signature equals data + b'|' + encoded name.

    Raises:
      core.TinkError: if the signature does not match. This also holds when
        the signature is not valid UTF-8.
    """
    if signature != data + b'|' + self._name.encode():
      # backslashreplace: a non-UTF-8 signature must still surface as
      # TinkError, not UnicodeDecodeError.
      raise core.TinkError('invalid signature ' +
                           signature.decode('utf-8', 'backslashreplace'))
class FakePrf(prf.Prf):
  """A fake Prf implementation for tests.

  The output is the first output_length bytes of
  input_data + b'|' + name + b'|' + b'*' * output_length.
  """

  def __init__(self, name: str = 'FakePrf'):
    self._name = name

  def compute(self, input_data: bytes, output_length: int) -> bytes:
    """Returns exactly output_length deterministic bytes.

    Args:
      input_data: the PRF input.
      output_length: requested number of output bytes, in [0, 32].

    Raises:
      core.TinkError: if output_length is negative or greater than 32.
    """
    # A negative length would otherwise slice from the end and return a
    # result of the wrong size.
    if output_length < 0 or output_length > 32:
      raise core.TinkError('invalid output_length')
    # The '*' padding guarantees at least output_length bytes before slicing.
    output = (input_data + b'|' + self._name.encode() + b'|' +
              b'*' * output_length)
    return output[:output_length]
class FakePrfSet(prf.PrfSet):
  """A fake PrfSet holding a single FakePrf under id 0, which is primary."""

  _ONLY_ID = 0

  def __init__(self, name: str = 'FakePrf'):
    self._prf = FakePrf(name)

  def primary_id(self) -> int:
    return self._ONLY_ID

  def all(self) -> Mapping[int, prf.Prf]:
    # A fresh dict is built on every call, so mutating it does not affect
    # later calls.
    return {self._ONLY_ID: self._prf}

  def primary(self) -> prf.Prf:
    return self._prf