Add DecryptingStream.

Added DecryptingStream, which wraps a C++ decrypting stream and is returned by streaming_aead.new_decrypting_stream(...).

PiperOrigin-RevId: 270236129
diff --git a/python/streaming_aead/decrypting_stream.py b/python/streaming_aead/decrypting_stream.py
new file mode 100644
index 0000000..215249d
--- /dev/null
+++ b/python/streaming_aead/decrypting_stream.py
@@ -0,0 +1,244 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A file-like object that decrypts the data it reads.
+
+It reads the ciphertext from a given other file-like object, and decrypts it.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import google_type_annotations
+from __future__ import print_function
+
+import errno
+import io
+from typing import BinaryIO
+
+from tink.python.cc.clif import cc_streaming_aead_wrappers
+from tink.python.core import tink_error
+from tink.python.util import file_object_adapter
+
+_OUT_OF_RANGE_ERROR_CODE = 11
+
+
+class DecryptingStream(io.BufferedIOBase):
+  """A file-like object which decrypts reads from an underlying object.
+
+  It reads the ciphertext from the wrapped file-like object, and decrypts it.
+
+  The additional method position() returns the number of read plaintext bytes.
+
+  Closing this wrapper also closes the underlying object.
+  """
+
+  def __init__(self, stream_aead, ciphertext_source: BinaryIO,
+               associated_data: bytes):
+    """Create a new DecryptingStream.
+
+    Args:
+      stream_aead: C++ StreamingAead primitive from which a C++ DecryptingStream
+        will be obtained.
+      ciphertext_source: A readable file-like object from which ciphertext bytes
+        will be read.
+      associated_data: The associated data to use for decryption.
+    """
+    super(DecryptingStream, self).__init__()
+    self._closed = False
+    self._bytes_read = 0
+    self._ciphertext_source = ciphertext_source
+
+    # Create FileObjectAdapter
+    if not ciphertext_source.readable():
+      raise ValueError('ciphertext_source must be readable')
+    cc_ciphertext_source = file_object_adapter.FileObjectAdapter(
+        ciphertext_source)
+    # Get InputStreamAdapter of C++ DecryptingStream
+    self._input_stream_adapter = self._get_input_stream_adapter(
+        stream_aead, associated_data, cc_ciphertext_source)
+
+  @staticmethod
+  @tink_error.use_tink_errors
+  def _get_input_stream_adapter(cc_primitive, aad, source):
+    """Implemented as a separate method to ensure correct error transform."""
+    return cc_streaming_aead_wrappers.new_cc_decrypting_stream(
+        cc_primitive, aad, source)
+
+  ### Reading ###
+
+  def read(self, size: int = -1) -> bytes:
+    """Read and return up to size bytes.
+
+    Multiple reads may be issued to the underlying object.
+
+    Args:
+      size: Maximum number of bytes to read. If the argument is omitted, None,
+        or negative, data is read and returned until EOF or if the read call
+        would block in non-blocking mode.
+
+    Returns:
+      Bytes read. An empty bytes object is returned if the stream is already at
+      EOF.
+
+    Raises:
+      BlockingIOError if no data is available at the moment.
+      TinkError if there was a permanent error.
+    """
+    return self._read(size, read1=False)
+
+  def read1(self, size: int = -1) -> bytes:
+    """Read and return up to size bytes.
+
+    At most one read will be issued to the underlying object.
+
+    Args:
+      size: Maximum number of bytes to read. If the argument is omitted, None,
+        or negative, an arbitrary number of bytes are returned.
+
+    Returns:
+      Bytes read. An empty bytes object is returned if the stream is already at
+      EOF.
+
+    Raises:
+      BlockingIOError if no data is available at the moment.
+      TinkError if there was a permanent error.
+    """
+    return self._read(size, read1=True)
+
+  def readinto(self, b: bytearray) -> int:
+    """Read bytes into a pre-allocated bytes-like object b.
+
+    Multiple reads may be issued to the underlying object.
+
+    Args:
+      b: Bytes-like object to which data will be read.
+
+    Returns:
+      Number of bytes read. If 0 is returned it means EOF is reached.
+
+    Raises:
+      BlockingIOError if no data is available at the moment.
+      TinkError if there was a permanent error.
+    """
+    return self._readinto(b, read1=False)
+
+  def readinto1(self, b: bytearray) -> int:
+    """Read bytes into a pre-allocated bytes-like object b.
+
+    At most one read will be issued to the underlying object.
+
+    Args:
+      b: Bytes-like object to which data will be read.
+
+    Returns:
+      Number of bytes read. If 0 is returned it means EOF is reached.
+
+    Raises:
+      BlockingIOError if no data is available at the moment.
+      TinkError if there was a permanent error.
+    """
+    return self._readinto(b, read1=True)
+
+  def _read(self, size: int, read1: bool) -> bytes:
+    self._check_not_closed()
+
+    if size is None:
+      size = -1
+
+    try:
+      if read1:
+        data = self._read1_with_tink_error(size)
+      else:
+        data = self._read_with_tink_error(size)
+
+      if not data:
+        raise io.BlockingIOError(errno.EAGAIN,
+                                 'No data available at the moment.')
+      else:
+        self._bytes_read += len(data)
+        return data
+    except tink_error.TinkError as e:
+      # We are checking if the exception was raised because of C++
+      # OUT_OF_RANGE status, which signals EOF.
+      if e.args[0].code == _OUT_OF_RANGE_ERROR_CODE:
+        return b''
+      else:
+        raise e
+
+  # TODO(b/141344377) use the implementation in parent class
+  def _readinto(self, b: bytearray, read1: bool) -> int:
+    data = self._read(len(b), read1)
+    n = len(data)
+    b[:n] = data
+    return n
+
+  @tink_error.use_tink_errors
+  def _read_with_tink_error(self, size: int) -> bytes:
+    """Implemented as a separate method to ensure correct error transform."""
+    return self._input_stream_adapter.read(size)
+
+  @tink_error.use_tink_errors
+  def _read1_with_tink_error(self, size: int) -> bytes:
+    """Implemented as a separate method to ensure correct error transform."""
+    return self._input_stream_adapter.read1(size)
+
+  ### Internal ###
+
+  # TODO(b/141344377) use parent class _checkClosed() instead
+  def _check_not_closed(self, msg=None):
+    """Internal: raise a ValueError if file is closed."""
+    if self.closed:
+      raise ValueError('I/O operation on closed file.' if msg is None else msg)
+
+  ### Positioning ###
+
+  def position(self) -> int:
+    """Returns total number of read plaintext bytes."""
+    return self._bytes_read
+
+  ### Flush and close ###
+
+  def flush(self) -> None:
+    """This has no effect because the stream is read-only."""
+    self._check_not_closed()
+
+  def close(self) -> None:
+    """Close the stream.
+
+    This has no effect on a closed stream.
+    """
+    if self.closed:
+      return
+    self._ciphertext_source.close()
+    self._closed = True
+
+  ### Inquiries ###
+
+  def readable(self) -> bool:
+    """Indicates whether object was opened for reading.
+
+    Returns:
+      Whether object was opened for reading.
+
+    If False, read() will raise UnsupportedOperation.
+    """
+    return True
+
+  @property
+  def closed(self) -> bool:
+    """Indicates if the file has been closed.
+
+    Returns:
+      True if and only if the file has been closed.
+
+    For backwards compatibility, this is a property, not a predicate.
+    """
+    return self._closed
diff --git a/python/streaming_aead/decrypting_stream_test.py b/python/streaming_aead/decrypting_stream_test.py
new file mode 100644
index 0000000..08852cf
--- /dev/null
+++ b/python/streaming_aead/decrypting_stream_test.py
@@ -0,0 +1,245 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for tink.python.streaming_aead.decrypting_stream."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import io
+
+from absl.testing import absltest
+# TODO(b/141106504) Replace this with unittest.mock
+import mock
+
+from tink.python.core import tink_error
+from tink.python.streaming_aead import decrypting_stream
+from tink.util import error as clif_error
+
+
+class FakeInputStreamAdapter(object):
+
+  def __init__(self, file_object_adapter):
+    self._adapter = file_object_adapter
+
+  @tink_error.use_tink_errors
+  def read(self, size=-1):
+    try:
+      if size < 0:
+        size = 100
+      return self._adapter.read(size)
+    except EOFError:
+      raise clif_error.StatusNotOk(11, 'EOF')
+
+  def read1(self, size=-1):
+    del size  # unused
+    return self.read(4)
+
+
+def fake_get_input_stream_adapter(self, cc_primitive, aad, source):
+  del cc_primitive, aad, self  # unused
+  return FakeInputStreamAdapter(source)
+
+
+def get_decrypting_stream(ciphertext_source, aad):
+  return decrypting_stream.DecryptingStream(None, ciphertext_source, aad)
+
+
+class DecryptingStreamTest(absltest.TestCase):
+
+  def setUp(self):
+    super(DecryptingStreamTest, self).setUp()
+    # Replace the DecryptingStream's staticmethod with a custom function to
+    # avoid the need for a Streaming AEAD primitive.
+    self.addCleanup(mock.patch.stopall)
+    mock.patch.object(
+        decrypting_stream.DecryptingStream,
+        '_get_input_stream_adapter',
+        new=fake_get_input_stream_adapter).start()
+
+  def test_non_readable_object(self):
+    f = mock.Mock()
+    f.readable = mock.Mock(return_value=False)
+
+    with self.assertRaisesRegex(ValueError, 'readable'):
+      get_decrypting_stream(f, b'aad')
+
+  def test_read(self):
+    f = io.BytesIO(b'something')
+    ds = get_decrypting_stream(f, b'aad')
+
+    self.assertEqual(ds.read(9), b'something')
+
+  def test_read1(self):
+    f = io.BytesIO(b'something')
+    ds = get_decrypting_stream(f, b'aad')
+
+    self.assertEqual(ds.read1(9), b'some')
+
+  def test_readinto(self):
+    f = io.BytesIO(b'something')
+    ds = get_decrypting_stream(f, b'aad')
+
+    b = bytearray(9)
+    self.assertEqual(ds.readinto(b), 9)
+    self.assertEqual(bytes(b), b'something')
+
+  def test_readinto1(self):
+    f = io.BytesIO(b'something')
+    ds = get_decrypting_stream(f, b'aad')
+
+    b = bytearray(9)
+    self.assertEqual(ds.readinto1(b), 4)
+    self.assertEqual(bytes(b[:4]), b'some')
+
+  def test_read_until_eof(self):
+    f = io.BytesIO(b'something')
+    ds = get_decrypting_stream(f, b'aad')
+
+    self.assertEqual(ds.read(), b'something')
+
+  def test_read_eof_reached(self):
+    f = io.BytesIO()
+    ds = get_decrypting_stream(f, b'aad')
+
+    self.assertEqual(ds.read(), b'')
+
+  def test_read_no_data_available(self):
+    f = mock.Mock()
+    f.read = mock.Mock(return_value=None)
+    f.readable = mock.Mock(return_value=True)
+    ds = get_decrypting_stream(f, b'aad')
+
+    self.assertRaises(io.BlockingIOError, ds.read, 5)
+
+  def test_unsupported_operation(self):
+    f = io.BytesIO(b'something')
+    ds = get_decrypting_stream(f, b'aad')
+
+    with self.assertRaises(io.UnsupportedOperation):
+      ds.seek(0, 0)
+    self.assertRaises(io.UnsupportedOperation, ds.tell)
+    self.assertRaises(io.UnsupportedOperation, ds.truncate)
+    with self.assertRaises(io.UnsupportedOperation):
+      ds.write(b'data')
+    with self.assertRaises(io.UnsupportedOperation):
+      ds.writelines([b'data'])
+    self.assertRaises(io.UnsupportedOperation, ds.fileno)
+    self.assertRaises(io.UnsupportedOperation, ds.detach)
+
+  def test_closed(self):
+    f = io.BytesIO(b'something')
+    ds = get_decrypting_stream(f, b'aad')
+
+    self.assertFalse(ds.closed)
+    self.assertFalse(f.closed)
+    ds.close()
+    self.assertTrue(ds.closed)
+    self.assertTrue(f.closed)
+    ds.close()
+
+  def test_closed_methods_raise(self):
+    f = io.BytesIO(b'something')
+    ds = get_decrypting_stream(f, b'aad')
+
+    ds.close()
+    self.assertRaisesRegex(ValueError, 'closed', ds.read)
+    self.assertRaisesRegex(ValueError, 'closed', ds.flush)
+    self.assertRaisesRegex(ValueError, 'closed', ds.__enter__)
+    self.assertRaisesRegex(ValueError, 'closed', ds.__iter__)
+    self.assertRaisesRegex(ValueError, 'closed', ds.isatty)
+
+  def test_position(self):
+    f = io.BytesIO(b'something')
+    ds = get_decrypting_stream(f, b'aad')
+
+    self.assertEqual(ds.position(), 0)
+    ds.read(4)
+    self.assertEqual(ds.position(), 4)
+    ds.read(4)
+    self.assertEqual(ds.position(), 8)
+    ds.close()
+    self.assertEqual(ds.position(), 8)
+
+  def test_inquiries(self):
+    f = io.BytesIO(b'something')
+    ds = get_decrypting_stream(f, b'aad')
+
+    self.assertTrue(ds.readable())
+    self.assertFalse(ds.writable())
+    self.assertFalse(ds.seekable())
+    self.assertFalse(ds.isatty())
+
+  def test_context_manager(self):
+    f = io.BytesIO(b'something')
+
+    with get_decrypting_stream(f, b'aad') as ds:
+      self.assertEqual(ds.read(), b'something')
+    self.assertTrue(ds.closed)
+
+  def test_readline(self):
+    f = io.BytesIO(b'hello\nworld\n')
+    ds = get_decrypting_stream(f, b'aad')
+
+    self.assertEqual(ds.readline(), b'hello\n')
+    self.assertEqual(ds.readline(), b'world\n')
+
+  def test_readline_with_size(self):
+    f = io.BytesIO(b'hello\nworld\n')
+    ds = get_decrypting_stream(f, b'aad')
+
+    self.assertEqual(ds.readline(4), b'hell')
+    self.assertEqual(ds.readline(4), b'o\n')
+
+  def test_readlines(self):
+    f = io.BytesIO(b'hello\nworld\n')
+    ds = get_decrypting_stream(f, b'aad')
+
+    self.assertEqual(ds.readlines(), [b'hello\n', b'world\n'])
+
+  def test_readlines_with_hint(self):
+    f = io.BytesIO(b'hello\nworld\n!!!\n')
+    ds = get_decrypting_stream(f, b'aad')
+
+    self.assertEqual(ds.readlines(10), [b'hello\n', b'world\n'])
+
+  def test_iterator(self):
+    f = io.BytesIO(b'hello\nworld\n')
+
+    result = []
+    for line in get_decrypting_stream(f, b'aad'):
+      result.append(line)
+
+    self.assertEqual(result, [b'hello\n', b'world\n'])
+
+  def test_textiowrapper_compatibility(self):
+    """A test that checks the TextIOWrapper works as expected.
+
+    It decrypts the same ciphertext twice - once directly from bytes, and once
+    through TextIOWrapper's encoding. The two plaintexts should have the same
+    length.
+    """
+    file_1 = io.BytesIO(b'something')
+    file_2 = io.BytesIO(b'something')
+
+    with get_decrypting_stream(file_1, b'aad') as ds:
+      with io.TextIOWrapper(ds) as wrapper:
+        data_1 = wrapper.read()
+
+    with get_decrypting_stream(file_2, b'aad') as ds:
+      data_2 = ds.read()
+
+    self.assertEqual(len(data_1), len(data_2))
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/python/streaming_aead/streaming_aead.py b/python/streaming_aead/streaming_aead.py
index 31d9ada..f6b9f14 100644
--- a/python/streaming_aead/streaming_aead.py
+++ b/python/streaming_aead/streaming_aead.py
@@ -114,10 +114,8 @@
         readinto1()
         readline()
         readlines()
-        readall()
         close()
         closed
-        tell()
         isatty()
         flush() (no-op)
         readable()
@@ -129,6 +127,9 @@
       io.UnsupportedOperation.
       Closing the wrapper also closes the ciphertext_source.
 
+      The wrapper also supports the position() method, which returns the number
+      of plaintext bytes read.
+
     Raises:
       tink.TinkError if the creation fails.