Add DecryptingStream.
Added DecryptingStream, which wraps a C++ decrypting stream and is returned by streaming_aead.new_decrypting_stream(...).
PiperOrigin-RevId: 270236129
diff --git a/python/streaming_aead/decrypting_stream.py b/python/streaming_aead/decrypting_stream.py
new file mode 100644
index 0000000..215249d
--- /dev/null
+++ b/python/streaming_aead/decrypting_stream.py
@@ -0,0 +1,244 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A file-like object that decrypts the data it reads.
+
+It reads the ciphertext from a given other file-like object, and decrypts it.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import google_type_annotations
+from __future__ import print_function
+
+import errno
+import io
+from typing import BinaryIO
+
+from tink.python.cc.clif import cc_streaming_aead_wrappers
+from tink.python.core import tink_error
+from tink.python.util import file_object_adapter
+
+# Numeric code of the C++ OUT_OF_RANGE status. DecryptingStream treats a
+# TinkError carrying this code as the end-of-file signal rather than a failure.
+_OUT_OF_RANGE_ERROR_CODE = 11
+
+
+class DecryptingStream(io.BufferedIOBase):
+ """A file-like object which decrypts reads from an underlying object.
+
+ It reads the ciphertext from the wrapped file-like object, and decrypts it.
+
+ The additional method position() returns the number of read plaintext bytes.
+
+ Closing this wrapper also closes the underlying object.
+ """
+
+ def __init__(self, stream_aead, ciphertext_source: BinaryIO,
+ associated_data: bytes):
+ """Create a new DecryptingStream.
+
+ Args:
+ stream_aead: C++ StreamingAead primitive from which a C++ DecryptingStream
+ will be obtained.
+ ciphertext_source: A readable file-like object from which ciphertext bytes
+ will be read.
+ associated_data: The associated data to use for decryption.
+ """
+ super(DecryptingStream, self).__init__()
+ self._closed = False
+ self._bytes_read = 0
+ self._ciphertext_source = ciphertext_source
+
+ # Create FileObjectAdapter
+ if not ciphertext_source.readable():
+ raise ValueError('ciphertext_source must be readable')
+ cc_ciphertext_source = file_object_adapter.FileObjectAdapter(
+ ciphertext_source)
+ # Get InputStreamAdapter of C++ DecryptingStream
+ self._input_stream_adapter = self._get_input_stream_adapter(
+ stream_aead, associated_data, cc_ciphertext_source)
+
+ @staticmethod
+ @tink_error.use_tink_errors
+ def _get_input_stream_adapter(cc_primitive, aad, source):
+ """Implemented as a separate method to ensure correct error transform."""
+ return cc_streaming_aead_wrappers.new_cc_decrypting_stream(
+ cc_primitive, aad, source)
+
+ ### Reading ###
+
+ def read(self, size: int = -1) -> bytes:
+ """Read and return up to size bytes.
+
+ Multiple reads may be issued to the underlying object.
+
+ Args:
+ size: Maximum number of bytes to read. If the argument is omitted, None,
+ or negative, data is read and returned until EOF or if the read call
+ would block in non-blocking mode.
+
+ Returns:
+ Bytes read. An empty bytes object is returned if the stream is already at
+ EOF.
+
+ Raises:
+ BlockingIOError if no data is available at the moment.
+ TinkError if there was a permanent error.
+ """
+ return self._read(size, read1=False)
+
+ def read1(self, size: int = -1) -> bytes:
+ """Read and return up to size bytes.
+
+ At most one read will be issued to the underlying object.
+
+ Args:
+ size: Maximum number of bytes to read. If the argument is omitted, None,
+ or negative, an arbitrary number of bytes are returned.
+
+ Returns:
+ Bytes read. An empty bytes object is returned if the stream is already at
+ EOF.
+
+ Raises:
+ BlockingIOError if no data is available at the moment.
+ TinkError if there was a permanent error.
+ """
+ return self._read(size, read1=True)
+
+ def readinto(self, b: bytearray) -> int:
+ """Read bytes into a pre-allocated bytes-like object b.
+
+ Multiple reads may be issued to the underlying object.
+
+ Args:
+ b: Bytes-like object to which data will be read.
+
+ Returns:
+ Number of bytes read. If 0 is returned it means EOF is reached.
+
+ Raises:
+ BlockingIOError if no data is available at the moment.
+ TinkError if there was a permanent error.
+ """
+ return self._readinto(b, read1=False)
+
+ def readinto1(self, b: bytearray) -> int:
+ """Read bytes into a pre-allocated bytes-like object b.
+
+ At most one read will be issued to the underlying object.
+
+ Args:
+ b: Bytes-like object to which data will be read.
+
+ Returns:
+ Number of bytes read. If 0 is returned it means EOF is reached.
+
+ Raises:
+ BlockingIOError if no data is available at the moment.
+ TinkError if there was a permanent error.
+ """
+ return self._readinto(b, read1=True)
+
+ def _read(self, size: int, read1: bool) -> bytes:
+ self._check_not_closed()
+
+ if size is None:
+ size = -1
+
+ try:
+ if read1:
+ data = self._read1_with_tink_error(size)
+ else:
+ data = self._read_with_tink_error(size)
+
+ if not data:
+ raise io.BlockingIOError(errno.EAGAIN,
+ 'No data available at the moment.')
+ else:
+ self._bytes_read += len(data)
+ return data
+ except tink_error.TinkError as e:
+ # We are checking if the exception was raised because of C++
+ # OUT_OF_RANGE status, which signals EOF.
+ if e.args[0].code == _OUT_OF_RANGE_ERROR_CODE:
+ return b''
+ else:
+ raise e
+
+ # TODO(b/141344377) use the implementation in parent class
+ def _readinto(self, b: bytearray, read1: bool) -> int:
+ data = self._read(len(b), read1)
+ n = len(data)
+ b[:n] = data
+ return n
+
+ @tink_error.use_tink_errors
+ def _read_with_tink_error(self, size: int) -> bytes:
+ """Implemented as a separate method to ensure correct error transform."""
+ return self._input_stream_adapter.read(size)
+
+ @tink_error.use_tink_errors
+ def _read1_with_tink_error(self, size: int) -> bytes:
+ """Implemented as a separate method to ensure correct error transform."""
+ return self._input_stream_adapter.read1(size)
+
+ ### Internal ###
+
+ # TODO(b/141344377) use parent class _checkClosed() instead
+ def _check_not_closed(self, msg=None):
+ """Internal: raise a ValueError if file is closed."""
+ if self.closed:
+ raise ValueError('I/O operation on closed file.' if msg is None else msg)
+
+ ### Positioning ###
+
+ def position(self) -> int:
+ """Returns total number of read plaintext bytes."""
+ return self._bytes_read
+
+ ### Flush and close ###
+
+ def flush(self) -> None:
+ """This has no effect because the stream is read-only."""
+ self._check_not_closed()
+
+ def close(self) -> None:
+ """Close the stream.
+
+ This has no effect on a closed stream.
+ """
+ if self.closed:
+ return
+ self._ciphertext_source.close()
+ self._closed = True
+
+ ### Inquiries ###
+
+ def readable(self) -> bool:
+ """Indicates whether object was opened for reading.
+
+ Returns:
+ Whether object was opened for reading.
+
+ If False, read() will raise UnsupportedOperation.
+ """
+ return True
+
+ @property
+ def closed(self) -> bool:
+ """Indicates if the file has been closed.
+
+ Returns:
+ True if and only if the file has been closed.
+
+ For backwards compatibility, this is a property, not a predicate.
+ """
+ return self._closed
diff --git a/python/streaming_aead/decrypting_stream_test.py b/python/streaming_aead/decrypting_stream_test.py
new file mode 100644
index 0000000..08852cf
--- /dev/null
+++ b/python/streaming_aead/decrypting_stream_test.py
@@ -0,0 +1,245 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for tink.python.streaming_aead.decrypting_stream."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import io
+
+from absl.testing import absltest
+# TODO(b/141106504) Replace this with unittest.mock
+import mock
+
+from tink.python.core import tink_error
+from tink.python.streaming_aead import decrypting_stream
+from tink.util import error as clif_error
+
+
+class FakeInputStreamAdapter(object):
+
+ def __init__(self, file_object_adapter):
+ self._adapter = file_object_adapter
+
+ @tink_error.use_tink_errors
+ def read(self, size=-1):
+ try:
+ if size < 0:
+ size = 100
+ return self._adapter.read(size)
+ except EOFError:
+ raise clif_error.StatusNotOk(11, 'EOF')
+
+ def read1(self, size=-1):
+ del size # unused
+ return self.read(4)
+
+
+def fake_get_input_stream_adapter(self, cc_primitive, aad, source):
+ del cc_primitive, aad, self # unused
+ return FakeInputStreamAdapter(source)
+
+
+def get_decrypting_stream(ciphertext_source, aad):
+ return decrypting_stream.DecryptingStream(None, ciphertext_source, aad)
+
+
+class DecryptingStreamTest(absltest.TestCase):
+
+ def setUp(self):
+ super(DecryptingStreamTest, self).setUp()
+ # Replace the DecryptingStream's staticmethod with a custom function to
+ # avoid the need for a Streaming AEAD primitive.
+ self.addCleanup(mock.patch.stopall)
+ mock.patch.object(
+ decrypting_stream.DecryptingStream,
+ '_get_input_stream_adapter',
+ new=fake_get_input_stream_adapter).start()
+
+ def test_non_readable_object(self):
+ f = mock.Mock()
+ f.readable = mock.Mock(return_value=False)
+
+ with self.assertRaisesRegex(ValueError, 'readable'):
+ get_decrypting_stream(f, b'aad')
+
+ def test_read(self):
+ f = io.BytesIO(b'something')
+ ds = get_decrypting_stream(f, b'aad')
+
+ self.assertEqual(ds.read(9), b'something')
+
+ def test_read1(self):
+ f = io.BytesIO(b'something')
+ ds = get_decrypting_stream(f, b'aad')
+
+ self.assertEqual(ds.read1(9), b'some')
+
+ def test_readinto(self):
+ f = io.BytesIO(b'something')
+ ds = get_decrypting_stream(f, b'aad')
+
+ b = bytearray(9)
+ self.assertEqual(ds.readinto(b), 9)
+ self.assertEqual(bytes(b), b'something')
+
+ def test_readinto1(self):
+ f = io.BytesIO(b'something')
+ ds = get_decrypting_stream(f, b'aad')
+
+ b = bytearray(9)
+ self.assertEqual(ds.readinto1(b), 4)
+ self.assertEqual(bytes(b[:4]), b'some')
+
+ def test_read_until_eof(self):
+ f = io.BytesIO(b'something')
+ ds = get_decrypting_stream(f, b'aad')
+
+ self.assertEqual(ds.read(), b'something')
+
+ def test_read_eof_reached(self):
+ f = io.BytesIO()
+ ds = get_decrypting_stream(f, b'aad')
+
+ self.assertEqual(ds.read(), b'')
+
+ def test_read_no_data_available(self):
+ f = mock.Mock()
+ f.read = mock.Mock(return_value=None)
+ f.readable = mock.Mock(return_value=True)
+ ds = get_decrypting_stream(f, b'aad')
+
+ self.assertRaises(io.BlockingIOError, ds.read, 5)
+
+ def test_unsupported_operation(self):
+ f = io.BytesIO(b'something')
+ ds = get_decrypting_stream(f, b'aad')
+
+ with self.assertRaises(io.UnsupportedOperation):
+ ds.seek(0, 0)
+ self.assertRaises(io.UnsupportedOperation, ds.tell)
+ self.assertRaises(io.UnsupportedOperation, ds.truncate)
+ with self.assertRaises(io.UnsupportedOperation):
+ ds.write(b'data')
+ with self.assertRaises(io.UnsupportedOperation):
+ ds.writelines([b'data'])
+ self.assertRaises(io.UnsupportedOperation, ds.fileno)
+ self.assertRaises(io.UnsupportedOperation, ds.detach)
+
+ def test_closed(self):
+ f = io.BytesIO(b'something')
+ ds = get_decrypting_stream(f, b'aad')
+
+ self.assertFalse(ds.closed)
+ self.assertFalse(f.closed)
+ ds.close()
+ self.assertTrue(ds.closed)
+ self.assertTrue(f.closed)
+ ds.close()
+
+ def test_closed_methods_raise(self):
+ f = io.BytesIO(b'something')
+ ds = get_decrypting_stream(f, b'aad')
+
+ ds.close()
+ self.assertRaisesRegex(ValueError, 'closed', ds.read)
+ self.assertRaisesRegex(ValueError, 'closed', ds.flush)
+ self.assertRaisesRegex(ValueError, 'closed', ds.__enter__)
+ self.assertRaisesRegex(ValueError, 'closed', ds.__iter__)
+ self.assertRaisesRegex(ValueError, 'closed', ds.isatty)
+
+ def test_position(self):
+ f = io.BytesIO(b'something')
+ ds = get_decrypting_stream(f, b'aad')
+
+ self.assertEqual(ds.position(), 0)
+ ds.read(4)
+ self.assertEqual(ds.position(), 4)
+ ds.read(4)
+ self.assertEqual(ds.position(), 8)
+ ds.close()
+ self.assertEqual(ds.position(), 8)
+
+ def test_inquiries(self):
+ f = io.BytesIO(b'something')
+ ds = get_decrypting_stream(f, b'aad')
+
+ self.assertTrue(ds.readable())
+ self.assertFalse(ds.writable())
+ self.assertFalse(ds.seekable())
+ self.assertFalse(ds.isatty())
+
+ def test_context_manager(self):
+ f = io.BytesIO(b'something')
+
+ with get_decrypting_stream(f, b'aad') as ds:
+ self.assertEqual(ds.read(), b'something')
+ self.assertTrue(ds.closed)
+
+ def test_readline(self):
+ f = io.BytesIO(b'hello\nworld\n')
+ ds = get_decrypting_stream(f, b'aad')
+
+ self.assertEqual(ds.readline(), b'hello\n')
+ self.assertEqual(ds.readline(), b'world\n')
+
+ def test_readline_with_size(self):
+ f = io.BytesIO(b'hello\nworld\n')
+ ds = get_decrypting_stream(f, b'aad')
+
+ self.assertEqual(ds.readline(4), b'hell')
+ self.assertEqual(ds.readline(4), b'o\n')
+
+ def test_readlines(self):
+ f = io.BytesIO(b'hello\nworld\n')
+ ds = get_decrypting_stream(f, b'aad')
+
+ self.assertEqual(ds.readlines(), [b'hello\n', b'world\n'])
+
+ def test_readlines_with_hint(self):
+ f = io.BytesIO(b'hello\nworld\n!!!\n')
+ ds = get_decrypting_stream(f, b'aad')
+
+ self.assertEqual(ds.readlines(10), [b'hello\n', b'world\n'])
+
+ def test_iterator(self):
+ f = io.BytesIO(b'hello\nworld\n')
+
+ result = []
+ for line in get_decrypting_stream(f, b'aad'):
+ result.append(line)
+
+ self.assertEqual(result, [b'hello\n', b'world\n'])
+
+ def test_textiowrapper_compatibility(self):
+ """A test that checks the TextIOWrapper works as expected.
+
+ It decrypts the same ciphertext twice - once directly from bytes, and once
+ through TextIOWrapper's encoding. The two plaintexts should have the same
+ length.
+ """
+ file_1 = io.BytesIO(b'something')
+ file_2 = io.BytesIO(b'something')
+
+ with get_decrypting_stream(file_1, b'aad') as ds:
+ with io.TextIOWrapper(ds) as wrapper:
+ data_1 = wrapper.read()
+
+ with get_decrypting_stream(file_2, b'aad') as ds:
+ data_2 = ds.read()
+
+ self.assertEqual(len(data_1), len(data_2))
+
+
+# Run the test cases through absltest when executed as a script.
+if __name__ == '__main__':
+ absltest.main()
diff --git a/python/streaming_aead/streaming_aead.py b/python/streaming_aead/streaming_aead.py
index 31d9ada..f6b9f14 100644
--- a/python/streaming_aead/streaming_aead.py
+++ b/python/streaming_aead/streaming_aead.py
@@ -114,10 +114,8 @@
readinto1()
readline()
readlines()
- readall()
close()
closed
- tell()
isatty()
flush() (no-op)
readable()
@@ -129,6 +127,9 @@
io.UnsupportedOperation.
Closing the wrapper also closes the ciphertext_source.
+ The wrapper also supports the position() method, which returns the number
+ of plaintext bytes read.
+
Raises:
tink.TinkError if the creation fails.