Add decryption tests to streaming_aead_test.
PiperOrigin-RevId: 270264435
diff --git a/python/streaming_aead/streaming_aead_test.py b/python/streaming_aead/streaming_aead_test.py
index 8b7a7ea..747e0a6 100644
--- a/python/streaming_aead/streaming_aead_test.py
+++ b/python/streaming_aead/streaming_aead_test.py
@@ -37,7 +37,7 @@
class StreamingAeadTest(absltest.TestCase):
- """End-to-end test of Streaming AEAD Encrypting Streams."""
+ """End-to-end test of Streaming AEAD Encrypting/Decrypting Streams."""
@staticmethod
def get_primitive():
@@ -51,7 +51,6 @@
return key_manager.primitive(key_data)
def test_get_encrypting_stream(self):
- # Get the primitive.
primitive = self.get_primitive()
# Use the primitive to get an encrypting stream.
@@ -64,7 +63,6 @@
def test_get_two_encrypting_streams(self):
"""Test that multiple EncryptingStreams can be obtained from a primitive."""
- # Get the primitive.
primitive = self.get_primitive()
f1 = TestBytesObject()
@@ -79,14 +77,13 @@
self.assertNotEmpty(f1.getvalue())
self.assertNotEmpty(f2.getvalue())
- def test_textiowrapper_compatibility(self):
+ def test_encrypting_textiowrapper(self):
"""A test that checks the TextIOWrapper works as expected.
It encrypts the same plaintext twice - once directly from bytes, and once
through TextIOWrapper's encoding. The two ciphertexts should have the same
length.
"""
- # Get the primitive.
primitive = self.get_primitive()
file_1 = TestBytesObject()
@@ -94,7 +91,7 @@
with primitive.new_encrypting_stream(file_1, b'aad') as es:
with io.TextIOWrapper(es) as wrapper:
- # Need to specify this is a unicode string for Python 2.
+ # Need to specify this is a unicode string for Python 2 (b/141106504).
wrapper.write(u'some data')
with primitive.new_encrypting_stream(file_2, b'aad') as es:
@@ -102,6 +99,98 @@
self.assertEqual(len(file_1.getvalue()), len(file_2.getvalue()))
+ def test_round_trip(self):
+ primitive = self.get_primitive()
+
+ f = TestBytesObject()
+
+ original_plaintext = b'some data'
+
+ with primitive.new_encrypting_stream(f, b'test aad') as es:
+ es.write(original_plaintext)
+
+ f.seek(0)
+
+ with primitive.new_decrypting_stream(f, b'test aad') as ds:
+ read_plaintext = ds.read()
+
+ self.assertEqual(read_plaintext, original_plaintext)
+
+ def test_round_trip_textiowrapper_single_line(self):
+ """Read and write a single line through a TextIOWrapper."""
+ primitive = self.get_primitive()
+ f = TestBytesObject()
+
+ # Mark this as unicode for Python 2 (b/141106504)
+ original_plaintext = u'One-line string.'
+ with primitive.new_encrypting_stream(f, b'test aad') as es:
+ with io.TextIOWrapper(es) as wrapper:
+ wrapper.write(original_plaintext)
+
+ f.seek(0)
+
+ with primitive.new_decrypting_stream(f, b'test aad') as ds:
+ with io.TextIOWrapper(ds) as wrapper:
+ read_plaintext = wrapper.read()
+
+ self.assertEqual(original_plaintext, read_plaintext)
+
+ def test_round_trip_decrypt_textiowrapper(self):
+ """Write bytes to EncryptingStream, then decrypt through TextIOWrapper."""
+ primitive = self.get_primitive()
+ f = TestBytesObject()
+ original_plaintext = '''some
+ data
+ on multiple lines.'''
+
+ with primitive.new_encrypting_stream(f, b'test aad') as es:
+ es.write(original_plaintext.encode('utf-8'))
+
+ f.seek(0)
+ with primitive.new_decrypting_stream(f, b'test aad') as ds:
+ with io.TextIOWrapper(ds) as wrapper:
+ data = wrapper.read()
+
+ self.assertEqual(data, original_plaintext)
+
+ def test_round_trip_encrypt_textiowrapper(self):
+ """Encrypt with TextIOWrapper, then decrypt direct bytes."""
+ primitive = self.get_primitive()
+ f = TestBytesObject()
+ # Mark this as unicode for Python 2 (b/141106504)
+ original_plaintext = u'''some
+ data
+ on multiple lines.'''
+
+ with primitive.new_encrypting_stream(f, b'test aad') as es:
+ with io.TextIOWrapper(es) as wrapper:
+ wrapper.write(original_plaintext)
+
+ f.seek(0)
+ with primitive.new_decrypting_stream(f, b'test aad') as ds:
+ data = ds.read().decode('utf-8')
+
+ self.assertEqual(data, original_plaintext)
+
+ def test_round_trip_encrypt_decrypt_textiowrapper(self):
+ """Use TextIOWrapper for both encryption and decryption."""
+ primitive = self.get_primitive()
+ f = TestBytesObject()
+ # Mark this as unicode for Python 2 (b/141106504)
+ original_plaintext = u'''some
+ data
+ on multiple lines.'''
+
+ with primitive.new_encrypting_stream(f, b'test aad') as es:
+ with io.TextIOWrapper(es) as wrapper:
+ wrapper.write(original_plaintext)
+
+ f.seek(0)
+ with primitive.new_decrypting_stream(f, b'test aad') as ds:
+ with io.TextIOWrapper(ds) as wrapper:
+ data = wrapper.read()
+
+ self.assertEqual(data, original_plaintext)
if __name__ == '__main__':
absltest.main()