# blob: 60bb220d3cdc07224c084fa0214509ea9ef22343
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for _hcvault_kms_aead.py."""
import base64
import http.server
import json
import threading
from absl.testing import absltest
from absl.testing import parameterized
import hvac
import tink
from tink.integration import hcvault
# Token handed to the hvac client. MockHandler never inspects request
# credentials, so an empty token is sufficient.
_TOKEN = ''
# Key path passed to hcvault.new_aead in the encrypt/decrypt tests.
_KEY_PATH = '/transit/keys/key-1'
# Local port the mock server binds to in setUp.
_PORT = 8205
# Base URL the hvac client uses to reach the mock server.
_VAULT_URI = f'http://localhost:{_PORT}'
# A GCP KMS key URI, used as an input that hcvault.new_aead must reject.
_GCP_KEY_URI = (
    'gcp-kms://projects/tink-test-infrastructure/locations/global/'
    'keyRings/unit-and-integration-testing/cryptoKeys/aead-key'
)
class MockHandler(http.server.BaseHTTPRequestHandler):
  """Minimal in-process HTTP handler used in place of a real Vault server.

  POST requests whose path contains 'encrypt' or 'decrypt' are answered with a
  JSON document of the form {"data": {...}}. Every other request gets a 404.

  The "encryption" is a reversible text encoding, not real cryptography:
  'enc:<base64 context>:<base64 plaintext>'.
  """

  def do_POST(self):  # pylint: disable=invalid-name
    """Dispatches encrypt/decrypt requests; replies 404 to anything else."""
    if 'encrypt' in self.path:
      self._handle_operation(self.encrypt)
    elif 'decrypt' in self.path:
      self._handle_operation(self.decrypt)
    else:
      self.send_error(404)

  def do_GET(self):  # pylint: disable=invalid-name
    """GET is not supported by the mock; always replies 404."""
    self.send_error(404)

  def do_PUT(self):  # pylint: disable=invalid-name
    """PUT is not supported by the mock; always replies 404."""
    self.send_error(404)

  def _handle_operation(self, operation):
    """Reads the JSON request body, runs `operation`, and writes the reply.

    Any failure while reading, parsing or processing the request is reported
    to the client as HTTP 400 instead of escaping the handler, which would
    drop the connection without sending a response.

    Args:
      operation: Callable taking the parsed JSON body and returning the JSON
        response body as a string (self.encrypt or self.decrypt).
    """
    try:
      # A missing Content-Length is treated as an empty body, which then
      # fails JSON parsing and is reported as a malformed request.
      length = int(self.headers.get('Content-Length', 0))
      request = json.loads(self.rfile.read(length))
      body = operation(request)
    except (ValueError, KeyError, TypeError, OSError):
      # Invalid JSON, invalid base64, missing fields or a failed read.
      self.send_error(400, 'malformed request')
      return
    except tink.TinkError as e:
      # Raised by _decrypt for a malformed ciphertext or context mismatch.
      self.send_error(400, str(e))
      return
    self.send_response(200)
    self.send_header('Content-Type', 'application/json')
    self.end_headers()
    self.wfile.write(body.encode('utf-8'))

  def encrypt(self, data) -> str:
    """Builds the JSON response for an encrypt request.

    Args:
      data: Parsed request body holding base64 'plaintext' and 'context'.

    Returns:
      JSON text of the form {"data": {"ciphertext": ...}}.
    """
    plaintext = base64.b64decode(data['plaintext'])
    context = base64.b64decode(data['context'])
    return json.dumps({'data': {'ciphertext': self._encrypt(plaintext, context)}})

  def decrypt(self, data) -> str:
    """Builds the JSON response for a decrypt request.

    Args:
      data: Parsed request body holding 'ciphertext' (as produced by
        _encrypt) and base64 'context'.

    Returns:
      JSON text of the form {"data": {"plaintext": <base64 plaintext>}}.

    Raises:
      tink.TinkError: If the ciphertext is malformed or the context differs
        from the one embedded in the ciphertext.
    """
    context = base64.b64decode(data['context'])
    plaintext = self._decrypt(data['ciphertext'], context)
    return json.dumps({'data': {'plaintext': base64.b64encode(plaintext).decode()}})

  def _encrypt(self, plaintext: bytes, context: bytes) -> str:
    """Returns 'enc:<base64 context>:<base64 plaintext>'."""
    context64 = base64.b64encode(context).decode()
    plaintext64 = base64.b64encode(plaintext).decode()
    return f'enc:{context64}:{plaintext64}'

  def _decrypt(self, ciphertext: str, context: bytes) -> bytes:
    """Reverses _encrypt after checking that `context` matches.

    Raises:
      tink.TinkError: If `ciphertext` does not have the 'enc:<ctx>:<pt>' shape
        or the embedded context differs from `context`.
    """
    parts = ciphertext.split(':')
    if len(parts) != 3 or parts[0] != 'enc':
      raise tink.TinkError('malformed ciphertext')
    if context != base64.b64decode(parts[1]):
      raise tink.TinkError("context doesn't match")
    return base64.b64decode(parts[2])
class HcVaultKmsAeadTest(parameterized.TestCase):
  """Tests for hcvault.new_aead and the key-path parsing helper.

  Each test runs against a MockHandler HTTP server bound to
  localhost:_PORT, started in setUp and stopped through registered cleanups.
  """

  def setUp(self):
    super().setUp()
    self.server = http.server.HTTPServer(('localhost', _PORT), MockHandler)
    # Cleanups run in reverse order of registration and also run when a later
    # setUp step fails, so the socket is always released for the next test.
    self.addCleanup(self.server.server_close)
    # Daemon thread: a stuck server must not keep the test process alive.
    self._server_thread = threading.Thread(
        target=self.server.serve_forever, daemon=True
    )
    self._server_thread.start()
    self.addCleanup(self._stop_serving)

  def _stop_serving(self):
    """Stops serve_forever and waits for the server thread to finish."""
    self.server.shutdown()
    self._server_thread.join()

  def _new_client(self):
    """Returns an hvac client pointed at the local mock server."""
    return hvac.Client(url=_VAULT_URI, token=_TOKEN, verify=False)

  def test_encrypt_decrypt(self):
    vaultaead = hcvault.new_aead(_KEY_PATH, self._new_client())
    plaintext = b'hello'
    # Round-trip with both a non-empty and an empty associated data.
    for associated_data in (b'world', b''):
      with self.subTest(associated_data=associated_data):
        ciphertext = vaultaead.encrypt(plaintext, associated_data)
        self.assertEqual(
            plaintext, vaultaead.decrypt(ciphertext, associated_data)
        )

  def test_invalid_context(self):
    vaultaead = hcvault.new_aead(_KEY_PATH, self._new_client())
    plaintext = b'helloworld'
    ciphertext = vaultaead.encrypt(plaintext, b'')
    self.assertEqual(plaintext, vaultaead.decrypt(ciphertext, b''))
    # Decrypting with different associated data must fail.
    with self.assertRaises(tink.TinkError):
      vaultaead.decrypt(ciphertext, b'a')

  def test_encrypt_with_bad_uri(self):
    client = self._new_client()
    with self.assertRaises(tink.TinkError):
      hcvault.new_aead(_GCP_KEY_URI, client)

  @parameterized.named_parameters([
      ('simple', '/transit/keys/key-1', 'transit', 'key-1'),
      (
          'escaped',
          '/transit/keys/this%2Band+that',
          'transit',
          'this%252Band%2Bthat',
      ),
      (
          'sub_path',
          '/teams/billing/something/transit/keys/pci-key',
          'teams/billing/something/transit',
          'pci-key',
      ),
      (
          'transit_twice',
          '/transit/keys/something/transit/keys/my-key',
          'transit/keys/something/transit',
          'my-key',
      ),
      ('mount_not_named_transit', '/cipher/keys/hi', 'cipher', 'hi'),
  ])
  def test_valid_get_endpoint_paths(
      self, path, expected_mount, expected_key_name
  ):
    mount, key_name = hcvault._hcvault_kms_aead._get_endpoint_paths(path)
    self.assertEqual(mount, expected_mount)
    self.assertEqual(key_name, expected_key_name)

  @parameterized.named_parameters([
      ('empty', ''),
      ('slash_only', '/'),
      ('no_mount', '/keys/foo'),
      ('traling_slash', 'mount/keys/foo/'),
      ('no_leading_slash', 'mount/keys/foo'),
      ('invalid_mount', '////keys/foo'),
      ('invalid_mount_and_trailing_slash', '////keys/foo/'),
      ('invalid_mount_with_empty_component_in_between', '/foo//bar/keys/baz'),
      ('invalid_mount_with_empty_trailing_components', '/foo/bar///keys/baz'),
      ('invalid_key_name', '/transit/keys/bar/baz'),
  ])
  def test_invalid_get_endpoint_paths(self, path):
    with self.assertRaises(tink.TinkError):
      hcvault._hcvault_kms_aead._get_endpoint_paths(path)
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()