Implement new_decrypting_stream in the key manager and add a round-trip test.
PiperOrigin-RevId: 270236806
diff --git a/python/streaming_aead/streaming_aead_key_manager.py b/python/streaming_aead/streaming_aead_key_manager.py
index c40cdea..0379f29 100644
--- a/python/streaming_aead/streaming_aead_key_manager.py
+++ b/python/streaming_aead/streaming_aead_key_manager.py
@@ -22,6 +22,7 @@
from tink.python.cc.clif import cc_key_manager
from tink.python.core import key_manager
from tink.python.core import tink_error
+from tink.python.streaming_aead import decrypting_stream
from tink.python.streaming_aead import encrypting_stream
from tink.python.streaming_aead import streaming_aead
@@ -43,8 +44,10 @@
@tink_error.use_tink_errors
def new_decrypting_stream(self, ciphertext_source: BinaryIO,
associated_data: bytes) -> BinaryIO:
- # TODO(tink-dev) implement DecryptingStream
- return typing.cast(BinaryIO, None)
+ stream = decrypting_stream.DecryptingStream(self._streaming_aead,
+ ciphertext_source,
+ associated_data)
+ return typing.cast(BinaryIO, stream)
def from_cc_registry(
diff --git a/python/streaming_aead/streaming_aead_key_manager_test.py b/python/streaming_aead/streaming_aead_key_manager_test.py
index a2a5121..ab054bd 100644
--- a/python/streaming_aead/streaming_aead_key_manager_test.py
+++ b/python/streaming_aead/streaming_aead_key_manager_test.py
@@ -15,6 +15,8 @@
from __future__ import division
from __future__ import print_function
+import io
+
from absl.testing import absltest
from tink.proto import aes_ctr_hmac_streaming_pb2
from tink.proto import aes_gcm_hkdf_streaming_pb2
@@ -27,6 +29,13 @@
from tink.python.streaming_aead import streaming_aead_key_templates
+class TestBytesObject(io.BytesIO):
+ """A BytesIO object that does not close."""
+
+ def close(self):
+ pass
+
+
def setUpModule():
tink_config.register()
@@ -99,9 +108,41 @@
self.key_manager_ctr.new_key_data(key_template)
def test_encrypt_decrypt(self):
- pass
- # TODO(tanujdhir) Consider putting round-trip encryption test here once
- # implemented
+ saead_primitive = self.key_manager_ctr.primitive(
+ self.key_manager_ctr.new_key_data(
+ streaming_aead_key_templates.AES128_CTR_HMAC_SHA256_4KB))
+ plaintext = b'plaintext'
+ aad = b'associated_data'
+
+ # Encrypt
+ ct_destination = TestBytesObject()
+ with saead_primitive.new_encrypting_stream(ct_destination, aad) as es:
+ self.assertLen(plaintext, es.write(plaintext))
+ self.assertNotEqual(ct_destination.getvalue(), plaintext)
+
+ # Decrypt
+ ct_source = TestBytesObject(ct_destination.getvalue())
+ with saead_primitive.new_decrypting_stream(ct_source, aad) as ds:
+ self.assertEqual(ds.read(), plaintext)
+
+ def test_encrypt_decrypt_wrong_aad(self):
+ saead_primitive = self.key_manager_ctr.primitive(
+ self.key_manager_ctr.new_key_data(
+ streaming_aead_key_templates.AES128_CTR_HMAC_SHA256_4KB))
+ plaintext = b'plaintext'
+ aad = b'associated_data'
+
+ # Encrypt
+ ct_destination = TestBytesObject()
+ with saead_primitive.new_encrypting_stream(ct_destination, aad) as es:
+ self.assertLen(plaintext, es.write(plaintext))
+ self.assertNotEqual(ct_destination.getvalue(), plaintext)
+
+ # Decrypt
+ ct_source = TestBytesObject(ct_destination.getvalue())
+ with saead_primitive.new_decrypting_stream(ct_source, b'bad ' + aad) as ds:
+ with self.assertRaises(tink_error.TinkError):
+ ds.read()
if __name__ == '__main__':