# blob: 24a7942045d5c8ec09b6146c5a4222194a7d516e [file] [log] [blame]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cross-language consistency tests for AEAD."""
import itertools
from absl.testing import absltest
from absl.testing import parameterized
import tink
from tink import aead
from tink.proto import aes_ctr_hmac_aead_pb2
from tink.proto import aes_eax_pb2
from tink.proto import aes_gcm_pb2
from tink.proto import common_pb2
from tink.proto import tink_pb2
from util import supported_key_types
from util import testing_servers
# Hash types iterated over by the AES-CTR-HMAC test cases, including
# UNKNOWN_HASH.
HASH_TYPES = [
    common_pb2.UNKNOWN_HASH,
    common_pb2.SHA1,
    common_pb2.SHA256,
    common_pb2.SHA384,
    common_pb2.SHA512,
]
# (test case name, language) pairs for which the language currently accepts
# the key although it should reject it.
SUCCEEDS_BUT_SHOULD_FAIL = [
    # TODO(b/159989251)
    # HMAC with SHA384 is accepted in go, but not in other langs.
    ('AesCtrHmacAeadKey(16,16,16,16,SHA384,0,0,0)', 'go'),
]

# (test case name, language) pairs for which the language currently rejects
# the key although it should accept it.
FAILS_BUT_SHOULD_SUCCEED = []
def setUpModule():
  """Module-level fixture: calls aead.register(), then starts the servers."""
  aead.register()
  testing_servers.start()
def tearDownModule():
  """Module-level fixture: stops the testing servers."""
  testing_servers.stop()
def _gen_keyset(
    type_url: str, value: bytes,
    key_material_type: tink_pb2.KeyData.KeyMaterialType) -> tink_pb2.Keyset:
  """Builds a Keyset that holds exactly one enabled key.

  Args:
    type_url: Type URL stored in the key's key_data.
    value: Serialized key proto stored in the key's key_data.
    key_material_type: Key material type stored in the key's key_data.

  Returns:
    A tink_pb2.Keyset whose single key has id 42, status ENABLED and output
    prefix type TINK, and whose primary_key_id points at that key.
  """
  # The one key in the keyset is also its primary key.
  key_id = 42

  result = tink_pb2.Keyset()
  result.primary_key_id = key_id

  entry = result.key.add()
  entry.key_id = key_id
  entry.status = tink_pb2.ENABLED
  entry.output_prefix_type = tink_pb2.TINK

  data = entry.key_data
  data.type_url = type_url
  data.value = value
  data.key_material_type = key_material_type
  return result
def _gen_key_value(size: int) -> bytes:
"""Returns a fixed key_value of a given size."""
return bytes(i for i in range(size))
def aes_eax_key_test_cases():
  """Yields (name, keyset) pairs for a range of AesEaxKey configurations.

  Every combination of the listed key sizes and IV sizes is produced with key
  version 0, followed by one default-sized key with key version 1.

  Yields:
    Tuples of a test case name formatted as
    'AesEaxKey(<key_size>,<iv_size>,<key_version>)' and the keyset returned by
    _gen_keyset for that key.
  """

  def _build(key_size=16, iv_size=16, key_version=0):
    """Returns the (name, keyset) pair for a single AesEaxKey."""
    proto = aes_eax_pb2.AesEaxKey()
    proto.version = key_version
    proto.key_value = _gen_key_value(key_size)
    proto.params.iv_size = iv_size
    serialized = proto.SerializeToString()
    keyset = _gen_keyset(
        'type.googleapis.com/google.crypto.tink.AesEaxKey',
        serialized,
        tink_pb2.KeyData.SYMMETRIC)
    name = 'AesEaxKey(%d,%d,%d)' % (key_size, iv_size, key_version)
    return (name, keyset)

  key_sizes = [15, 16, 24, 32, 64, 96]
  iv_sizes = [11, 12, 16, 17, 24, 32]
  # product() varies iv_size fastest, matching a nested key_size/iv_size loop.
  for key_size, iv_size in itertools.product(key_sizes, iv_sizes):
    yield _build(key_size=key_size, iv_size=iv_size)
  yield _build(key_version=1)
def aes_gcm_key_test_cases():
  """Yields (name, keyset) pairs for a range of AesGcmKey configurations.

  One case is produced per listed key size with key version 0, followed by
  one default-sized key with key version 1.

  Yields:
    Tuples of a test case name formatted as
    'AesGcmKey(<key_size>,<key_version>)' and the keyset returned by
    _gen_keyset for that key.
  """

  def _build(key_size=16, key_version=0):
    """Returns the (name, keyset) pair for a single AesGcmKey."""
    proto = aes_gcm_pb2.AesGcmKey()
    proto.version = key_version
    proto.key_value = _gen_key_value(key_size)
    serialized = proto.SerializeToString()
    keyset = _gen_keyset(
        'type.googleapis.com/google.crypto.tink.AesGcmKey',
        serialized,
        tink_pb2.KeyData.SYMMETRIC)
    name = 'AesGcmKey(%d,%d)' % (key_size, key_version)
    return (name, keyset)

  for size in (15, 16, 24, 32, 64, 96):
    yield _build(key_size=size)
  yield _build(key_version=1)
def aes_ctr_hmac_aead_key_test_cases():
  """Yields (name, keyset) pairs for AesCtrHmacAeadKey configurations.

  The cases are produced in this order: the all-defaults key; every
  combination of the listed AES key sizes and IV sizes; every combination of
  the listed HMAC key sizes and tag sizes; one key per entry of HASH_TYPES;
  and finally one key each with the outer key version, the AES-CTR key
  version and the HMAC key version set to 1.

  Yields:
    Tuples of a test case name formatted as
    'AesCtrHmacAeadKey(<aes_key_size>,<iv_size>,<hmac_key_size>,
    <hmac_tag_size>,<hash name>,<key_version>,<aes_ctr_version>,
    <hmac_version>)' and the keyset returned by _gen_keyset for that key.
  """

  def _build(aes_key_size=16,
             iv_size=16,
             hmac_key_size=16,
             hmac_tag_size=16,
             hash_type=common_pb2.SHA256,
             key_version=0,
             aes_ctr_version=0,
             hmac_version=0):
    """Returns the (name, keyset) pair for a single AesCtrHmacAeadKey."""
    proto = aes_ctr_hmac_aead_pb2.AesCtrHmacAeadKey()
    proto.version = key_version

    # AES-CTR half of the key.
    proto.aes_ctr_key.version = aes_ctr_version
    proto.aes_ctr_key.params.iv_size = iv_size
    proto.aes_ctr_key.key_value = _gen_key_value(aes_key_size)

    # HMAC half of the key.
    proto.hmac_key.version = hmac_version
    proto.hmac_key.params.tag_size = hmac_tag_size
    proto.hmac_key.params.hash = hash_type
    proto.hmac_key.key_value = _gen_key_value(hmac_key_size)

    serialized = proto.SerializeToString()
    keyset = _gen_keyset(
        'type.googleapis.com/google.crypto.tink.AesCtrHmacAeadKey',
        serialized,
        tink_pb2.KeyData.SYMMETRIC)
    name_fields = (aes_key_size, iv_size, hmac_key_size, hmac_tag_size,
                   common_pb2.HashType.Name(hash_type),
                   key_version, aes_ctr_version, hmac_version)
    name = 'AesCtrHmacAeadKey(%d,%d,%d,%d,%s,%d,%d,%d)' % name_fields
    return (name, keyset)

  yield _build()

  # product() varies its last argument fastest, matching nested loops with
  # the first list as the outer loop.
  aes_key_sizes = [15, 16, 24, 32, 64, 96]
  iv_sizes = [11, 12, 16, 17, 24, 32]
  for aes_key_size, iv_size in itertools.product(aes_key_sizes, iv_sizes):
    yield _build(aes_key_size=aes_key_size, iv_size=iv_size)

  hmac_key_sizes = [15, 16, 24, 32, 64, 96]
  hmac_tag_sizes = [9, 10, 16, 20, 21, 24, 32, 33, 64, 65]
  for hmac_key_size, hmac_tag_size in itertools.product(hmac_key_sizes,
                                                        hmac_tag_sizes):
    yield _build(hmac_key_size=hmac_key_size, hmac_tag_size=hmac_tag_size)

  for hash_type in HASH_TYPES:
    yield _build(hash_type=hash_type)

  yield _build(key_version=1)
  yield _build(aes_ctr_version=1)
  yield _build(hmac_version=1)
class AeadKeyConsistencyTest(parameterized.TestCase):
  """Checks that all supported languages agree on which AEAD keys are valid."""

  @parameterized.parameters(
      itertools.chain(aes_eax_key_test_cases(),
                      aes_gcm_key_test_cases(),
                      aes_ctr_hmac_aead_key_test_cases()))
  def test_keyset_validation_consistency(self, name, keyset):
    """Encrypts with the keyset in every language and compares the outcomes.

    The test fails if some languages accept the key while others reject it
    (after accounting for the known exceptions in SUCCEEDS_BUT_SHOULD_FAIL and
    FAILS_BUT_SHOULD_SUCCEED), if a listed exception no longer applies, or if
    a produced ciphertext does not decrypt back to the plaintext.

    Args:
      name: Test case name; looked up together with the language in the
        exception lists and used in failure messages.
      keyset: Keyset proto containing the key under test.
    """
    supported_langs = supported_key_types.SUPPORTED_LANGUAGES_PER_TYPE_URL[
        keyset.key[0].key_data.type_url]
    # Serialize once; the same bytes are handed to every language.
    serialized_keyset = keyset.SerializeToString()
    supported_aeads = [
        testing_servers.aead(lang, serialized_keyset)
        for lang in supported_langs
    ]
    plaintext = b'plaintext'
    associated_data = b'associated_data'
    failures = 0
    ciphertexts = {}
    results = {}
    for p in supported_aeads:
      try:
        ciphertexts[p.lang] = p.encrypt(plaintext, associated_data)
        if (name, p.lang) in SUCCEEDS_BUT_SHOULD_FAIL:
          # Known wrong acceptance: count it as a failure and drop the
          # ciphertext so it is not used in the decryption check below.
          failures += 1
          del ciphertexts[p.lang]
        if (name, p.lang) in FAILS_BUT_SHOULD_SUCCEED:
          self.fail('(%s, %s) succeeded, but is in FAILS_BUT_SHOULD_SUCCEED' %
                    (name, p.lang))
        results[p.lang] = 'success'
      except tink.TinkError as e:
        # Known wrong rejections are not counted as failures.
        if (name, p.lang) not in FAILS_BUT_SHOULD_SUCCEED:
          failures += 1
        if (name, p.lang) in SUCCEEDS_BUT_SHOULD_FAIL:
          self.fail(
              '(%s, %s) is in SUCCEEDS_BUT_SHOULD_FAIL, but failed with %s' %
              (name, p.lang, e))
        results[p.lang] = e
    # Test that either all supported langs accept the key, or all reject.
    if failures not in [0, len(supported_langs)]:
      self.fail('encryption for key %s is inconsistent: %s' %
                (name, results))
    # Test all generated ciphertexts can be decrypted. Decryption is only
    # performed with the first supported language.
    for enc_lang, ciphertext in ciphertexts.items():
      dec_aead = supported_aeads[0]
      output = dec_aead.decrypt(ciphertext, associated_data)
      if output != plaintext:
        self.fail('ciphertext encrypted with key %s in lang %s could not be '
                  'decrypted in lang %s.' % (name, enc_lang, dec_aead.lang))
# Run the tests via absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()