Add decryption tests to streaming_aead_test.

PiperOrigin-RevId: 270264435
diff --git a/python/streaming_aead/streaming_aead_test.py b/python/streaming_aead/streaming_aead_test.py
index 8b7a7ea..747e0a6 100644
--- a/python/streaming_aead/streaming_aead_test.py
+++ b/python/streaming_aead/streaming_aead_test.py
@@ -37,7 +37,7 @@
 
 
 class StreamingAeadTest(absltest.TestCase):
-  """End-to-end test of Streaming AEAD Encrypting Streams."""
+  """End-to-end test of Streaming AEAD Encrypting/Decrypting Streams."""
 
   @staticmethod
   def get_primitive():
@@ -51,7 +51,6 @@
     return key_manager.primitive(key_data)
 
   def test_get_encrypting_stream(self):
-    # Get the primitive.
     primitive = self.get_primitive()
 
     # Use the primitive to get an encrypting stream.
@@ -64,7 +63,6 @@
 
   def test_get_two_encrypting_streams(self):
     """Test that multiple EncryptingStreams can be obtained from a primitive."""
-    # Get the primitive.
     primitive = self.get_primitive()
 
     f1 = TestBytesObject()
@@ -79,14 +77,13 @@
     self.assertNotEmpty(f1.getvalue())
     self.assertNotEmpty(f2.getvalue())
 
-  def test_textiowrapper_compatibility(self):
+  def test_encrypting_textiowrapper(self):
     """A test that checks the TextIOWrapper works as expected.
 
     It encrypts the same plaintext twice - once directly from bytes, and once
     through TextIOWrapper's encoding. The two ciphertexts should have the same
     length.
     """
-    # Get the primitive.
     primitive = self.get_primitive()
 
     file_1 = TestBytesObject()
@@ -94,7 +91,7 @@
 
     with primitive.new_encrypting_stream(file_1, b'aad') as es:
       with io.TextIOWrapper(es) as wrapper:
-        # Need to specify this is a unicode string for Python 2.
+        # Need to specify this is a unicode string for Python 2 (b/141106504).
         wrapper.write(u'some data')
 
     with primitive.new_encrypting_stream(file_2, b'aad') as es:
@@ -102,6 +99,98 @@
 
     self.assertEqual(len(file_1.getvalue()), len(file_2.getvalue()))
 
+  def test_round_trip(self):
+    primitive = self.get_primitive()
+
+    f = TestBytesObject()
+
+    original_plaintext = b'some data'
+
+    with primitive.new_encrypting_stream(f, b'test aad') as es:
+      es.write(original_plaintext)
+
+    f.seek(0)
+
+    with primitive.new_decrypting_stream(f, b'test aad') as ds:
+      read_plaintext = ds.read()
+
+    self.assertEqual(read_plaintext, original_plaintext)
+
+  def test_round_trip_textiowrapper_single_line(self):
+    """Read and write a single line through a TextIOWrapper."""
+    primitive = self.get_primitive()
+    f = TestBytesObject()
+
+    # Mark this as unicode for Python 2 (b/141106504)
+    original_plaintext = u'One-line string.'
+    with primitive.new_encrypting_stream(f, b'test aad') as es:
+      with io.TextIOWrapper(es) as wrapper:
+        wrapper.write(original_plaintext)
+
+    f.seek(0)
+
+    with primitive.new_decrypting_stream(f, b'test aad') as ds:
+      with io.TextIOWrapper(ds) as wrapper:
+        read_plaintext = wrapper.read()
+
+    self.assertEqual(original_plaintext, read_plaintext)
+
+  def test_round_trip_decrypt_textiowrapper(self):
+    """Write bytes to EncryptingStream, then decrypt through TextIOWrapper."""
+    primitive = self.get_primitive()
+    f = TestBytesObject()
+    original_plaintext = '''some
+    data
+    on multiple lines.'''
+
+    with primitive.new_encrypting_stream(f, b'test aad') as es:
+      es.write(original_plaintext.encode('utf-8'))
+
+    f.seek(0)
+    with primitive.new_decrypting_stream(f, b'test aad') as ds:
+      with io.TextIOWrapper(ds) as wrapper:
+        data = wrapper.read()
+
+    self.assertEqual(data, original_plaintext)
+
+  def test_round_trip_encrypt_textiowrapper(self):
+    """Encrypt with TextIOWrapper, then decrypt direct bytes."""
+    primitive = self.get_primitive()
+    f = TestBytesObject()
+    # Mark this as unicode for Python 2 (b/141106504)
+    original_plaintext = u'''some
+    data
+    on multiple lines.'''
+
+    with primitive.new_encrypting_stream(f, b'test aad') as es:
+      with io.TextIOWrapper(es) as wrapper:
+        wrapper.write(original_plaintext)
+
+    f.seek(0)
+    with primitive.new_decrypting_stream(f, b'test aad') as ds:
+      data = ds.read().decode('utf-8')
+
+    self.assertEqual(data, original_plaintext)
+
+  def test_round_trip_encrypt_decrypt_textiowrapper(self):
+    """Use TextIOWrapper for both encryption and decryption."""
+    primitive = self.get_primitive()
+    f = TestBytesObject()
+    # Mark this as unicode for Python 2 (b/141106504)
+    original_plaintext = u'''some
+    data
+    on multiple lines.'''
+
+    with primitive.new_encrypting_stream(f, b'test aad') as es:
+      with io.TextIOWrapper(es) as wrapper:
+        wrapper.write(original_plaintext)
+
+    f.seek(0)
+    with primitive.new_decrypting_stream(f, b'test aad') as ds:
+      with io.TextIOWrapper(ds) as wrapper:
+        data = wrapper.read()
+
+    self.assertEqual(data, original_plaintext)
 
 if __name__ == '__main__':
   absltest.main()