Implement new_decrypting_stream in the key manager and add a round-trip test.

PiperOrigin-RevId: 270236806
diff --git a/python/streaming_aead/streaming_aead_key_manager.py b/python/streaming_aead/streaming_aead_key_manager.py
index c40cdea..0379f29 100644
--- a/python/streaming_aead/streaming_aead_key_manager.py
+++ b/python/streaming_aead/streaming_aead_key_manager.py
@@ -22,6 +22,7 @@
 from tink.python.cc.clif import cc_key_manager
 from tink.python.core import key_manager
 from tink.python.core import tink_error
+from tink.python.streaming_aead import decrypting_stream
 from tink.python.streaming_aead import encrypting_stream
 from tink.python.streaming_aead import streaming_aead
 
@@ -43,8 +44,10 @@
   @tink_error.use_tink_errors
   def new_decrypting_stream(self, ciphertext_source: BinaryIO,
                             associated_data: bytes) -> BinaryIO:
-    # TODO(tink-dev) implement DecryptingStream
-    return typing.cast(BinaryIO, None)
+    stream = decrypting_stream.DecryptingStream(self._streaming_aead,
+                                                ciphertext_source,
+                                                associated_data)
+    return typing.cast(BinaryIO, stream)
 
 
 def from_cc_registry(
diff --git a/python/streaming_aead/streaming_aead_key_manager_test.py b/python/streaming_aead/streaming_aead_key_manager_test.py
index a2a5121..ab054bd 100644
--- a/python/streaming_aead/streaming_aead_key_manager_test.py
+++ b/python/streaming_aead/streaming_aead_key_manager_test.py
@@ -15,6 +15,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import io
+
 from absl.testing import absltest
 from tink.proto import aes_ctr_hmac_streaming_pb2
 from tink.proto import aes_gcm_hkdf_streaming_pb2
@@ -27,6 +29,13 @@
 from tink.python.streaming_aead import streaming_aead_key_templates
 
 
+class TestBytesObject(io.BytesIO):
+  """A BytesIO object that does not close."""
+
+  def close(self):
+    pass
+
+
 def setUpModule():
   tink_config.register()
 
@@ -99,9 +108,41 @@
       self.key_manager_ctr.new_key_data(key_template)
 
   def test_encrypt_decrypt(self):
-    pass
-    # TODO(tanujdhir) Consider putting round-trip encryption test here once
-    # implemented
+    saead_primitive = self.key_manager_ctr.primitive(
+        self.key_manager_ctr.new_key_data(
+            streaming_aead_key_templates.AES128_CTR_HMAC_SHA256_4KB))
+    plaintext = b'plaintext'
+    aad = b'associated_data'
+
+    # Encrypt
+    ct_destination = TestBytesObject()
+    with saead_primitive.new_encrypting_stream(ct_destination, aad) as es:
+      self.assertLen(plaintext, es.write(plaintext))
+    self.assertNotEqual(ct_destination.getvalue(), plaintext)
+
+    # Decrypt
+    ct_source = TestBytesObject(ct_destination.getvalue())
+    with saead_primitive.new_decrypting_stream(ct_source, aad) as ds:
+      self.assertEqual(ds.read(), plaintext)
+
+  def test_encrypt_decrypt_wrong_aad(self):
+    saead_primitive = self.key_manager_ctr.primitive(
+        self.key_manager_ctr.new_key_data(
+            streaming_aead_key_templates.AES128_CTR_HMAC_SHA256_4KB))
+    plaintext = b'plaintext'
+    aad = b'associated_data'
+
+    # Encrypt
+    ct_destination = TestBytesObject()
+    with saead_primitive.new_encrypting_stream(ct_destination, aad) as es:
+      self.assertLen(plaintext, es.write(plaintext))
+    self.assertNotEqual(ct_destination.getvalue(), plaintext)
+
+    # Decrypt
+    ct_source = TestBytesObject(ct_destination.getvalue())
+    with saead_primitive.new_decrypting_stream(ct_source, b'bad ' + aad) as ds:
+      with self.assertRaises(tink_error.TinkError):
+        ds.read()
 
 
 if __name__ == '__main__':