blob: 08852cfe93fc505304733782141c87771fe0480c [file] [log] [blame]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tink.python.streaming_aead.decrypting_stream."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
from absl.testing import absltest
# TODO(b/141106504) Replace this with unittest.mock
import mock
from tink.python.core import tink_error
from tink.python.streaming_aead import decrypting_stream
from tink.util import error as clif_error
class FakeInputStreamAdapter(object):
  """Pass-through test double for the adapter DecryptingStream reads from.

  It performs no decryption: bytes from the wrapped source are returned
  unchanged, so tests can exercise DecryptingStream's file-like behavior
  without a real Streaming AEAD primitive.
  """

  # Bytes requested from the source when read() gets a negative size.
  _DEFAULT_READ_SIZE = 100

  # Upper bound on the bytes served by one read1() call. It is deliberately
  # small so tests can observe short reads from read1()/readinto1().
  _READ1_CHUNK_SIZE = 4

  def __init__(self, file_object_adapter):
    """Stores `file_object_adapter`, the object the fake reads bytes from."""
    self._adapter = file_object_adapter

  @tink_error.use_tink_errors
  def read(self, size=-1):
    """Returns up to `size` bytes from the wrapped source.

    Args:
      size: Maximum number of bytes to read. A negative value requests
        _DEFAULT_READ_SIZE bytes.

    Returns:
      Whatever the wrapped source's read() returns for that size.

    Raises:
      Whatever tink_error.use_tink_errors makes of
      clif_error.StatusNotOk(11, 'EOF'), which is raised when the wrapped
      source raises EOFError. NOTE(review): 11 is OUT_OF_RANGE in the
      canonical status-code numbering; presumably this mirrors how the real
      adapter reports end of stream. Confirm.
    """
    if size < 0:
      size = self._DEFAULT_READ_SIZE
    try:
      return self._adapter.read(size)
    except EOFError:
      raise clif_error.StatusNotOk(11, 'EOF')

  def read1(self, size=-1):
    """Returns one short chunk of at most _READ1_CHUNK_SIZE bytes.

    read1() must never return more than `size` bytes. The chunk is therefore
    also capped at `size` when `size` is non-negative, and a zero-byte
    request returns b'' without touching the source.

    Args:
      size: Maximum number of bytes to return. A negative value means
        "no limit beyond the chunk size".

    Returns:
      Up to min(size, _READ1_CHUNK_SIZE) bytes, as produced by read().
    """
    if size == 0:
      return b''
    chunk = self._READ1_CHUNK_SIZE
    if size > 0:
      chunk = min(chunk, size)
    return self.read(chunk)
def fake_get_input_stream_adapter(self, cc_primitive, aad, source):
  """Stand-in for DecryptingStream._get_input_stream_adapter in these tests.

  Args:
    self: Ignored. Presumably the DecryptingStream instance, since the
      function is patched onto that class and bound like a method.
    cc_primitive: Ignored; the fake needs no Streaming AEAD primitive.
    aad: Ignored; the fake performs no authentication.
    source: The object the returned fake adapter reads bytes from.

  Returns:
    A FakeInputStreamAdapter wrapping `source`.
  """
  del self, cc_primitive, aad  # Only `source` matters to the fake.
  adapter = FakeInputStreamAdapter(source)
  return adapter
def get_decrypting_stream(ciphertext_source, aad):
  """Builds a DecryptingStream over `ciphertext_source` with no primitive.

  The primitive argument is None. This relies on the tests patching
  DecryptingStream._get_input_stream_adapter so that no real Streaming AEAD
  primitive is consulted.

  Args:
    ciphertext_source: Readable object handed to DecryptingStream.
    aad: Associated data handed to DecryptingStream.

  Returns:
    The newly constructed DecryptingStream.
  """
  no_primitive = None
  return decrypting_stream.DecryptingStream(no_primitive, ciphertext_source,
                                            aad)
class DecryptingStreamTest(absltest.TestCase):
  """Tests DecryptingStream's file-like API using a pass-through fake adapter.

  setUp swaps DecryptingStream._get_input_stream_adapter for
  fake_get_input_stream_adapter. The bytes read back from a stream are
  therefore exactly the bytes placed in its ciphertext source.
  """

  def setUp(self):
    super(DecryptingStreamTest, self).setUp()
    # Replace the DecryptingStream's staticmethod with a plain function to
    # avoid the need for a Streaming AEAD primitive. Stored on the class, the
    # function is presumably bound like a method, which is why the fake
    # accepts (and ignores) `self`.
    patcher = mock.patch.object(
        decrypting_stream.DecryptingStream,
        '_get_input_stream_adapter',
        new=fake_get_input_stream_adapter)
    patcher.start()
    # Undo only this patch, rather than every active patch as stopall() would.
    self.addCleanup(patcher.stop)

  def test_non_readable_object(self):
    f = mock.Mock()
    f.readable = mock.Mock(return_value=False)
    with self.assertRaisesRegex(ValueError, 'readable'):
      get_decrypting_stream(f, b'aad')

  def test_read(self):
    f = io.BytesIO(b'something')
    ds = get_decrypting_stream(f, b'aad')
    self.assertEqual(ds.read(9), b'something')

  def test_read1(self):
    f = io.BytesIO(b'something')
    ds = get_decrypting_stream(f, b'aad')
    # The fake adapter serves read1() in 4-byte chunks, so a single call
    # returns only a prefix of the requested 9 bytes.
    self.assertEqual(ds.read1(9), b'some')

  def test_readinto(self):
    f = io.BytesIO(b'something')
    ds = get_decrypting_stream(f, b'aad')
    b = bytearray(9)
    self.assertEqual(ds.readinto(b), 9)
    self.assertEqual(bytes(b), b'something')

  def test_readinto1(self):
    f = io.BytesIO(b'something')
    ds = get_decrypting_stream(f, b'aad')
    b = bytearray(9)
    # Same 4-byte chunking as test_read1: only the first 4 bytes are filled.
    self.assertEqual(ds.readinto1(b), 4)
    self.assertEqual(bytes(b[:4]), b'some')

  def test_read_until_eof(self):
    f = io.BytesIO(b'something')
    ds = get_decrypting_stream(f, b'aad')
    self.assertEqual(ds.read(), b'something')

  def test_read_eof_reached(self):
    f = io.BytesIO()
    ds = get_decrypting_stream(f, b'aad')
    self.assertEqual(ds.read(), b'')

  def test_read_no_data_available(self):
    # A source whose read() returns None models a non-blocking stream that
    # currently has no data; DecryptingStream must surface that as
    # BlockingIOError.
    f = mock.Mock()
    f.read = mock.Mock(return_value=None)
    f.readable = mock.Mock(return_value=True)
    ds = get_decrypting_stream(f, b'aad')
    self.assertRaises(io.BlockingIOError, ds.read, 5)

  def test_unsupported_operation(self):
    f = io.BytesIO(b'something')
    ds = get_decrypting_stream(f, b'aad')
    with self.assertRaises(io.UnsupportedOperation):
      ds.seek(0, 0)
    self.assertRaises(io.UnsupportedOperation, ds.tell)
    self.assertRaises(io.UnsupportedOperation, ds.truncate)
    with self.assertRaises(io.UnsupportedOperation):
      ds.write(b'data')
    with self.assertRaises(io.UnsupportedOperation):
      ds.writelines([b'data'])
    self.assertRaises(io.UnsupportedOperation, ds.fileno)
    self.assertRaises(io.UnsupportedOperation, ds.detach)

  def test_closed(self):
    f = io.BytesIO(b'something')
    ds = get_decrypting_stream(f, b'aad')
    self.assertFalse(ds.closed)
    self.assertFalse(f.closed)
    ds.close()
    self.assertTrue(ds.closed)
    # Closing the decrypting stream also closes the ciphertext source.
    self.assertTrue(f.closed)
    # A second close() must be a harmless no-op rather than an error.
    ds.close()

  def test_closed_methods_raise(self):
    f = io.BytesIO(b'something')
    ds = get_decrypting_stream(f, b'aad')
    ds.close()
    self.assertRaisesRegex(ValueError, 'closed', ds.read)
    self.assertRaisesRegex(ValueError, 'closed', ds.flush)
    self.assertRaisesRegex(ValueError, 'closed', ds.__enter__)
    self.assertRaisesRegex(ValueError, 'closed', ds.__iter__)
    self.assertRaisesRegex(ValueError, 'closed', ds.isatty)

  def test_position(self):
    f = io.BytesIO(b'something')
    ds = get_decrypting_stream(f, b'aad')
    self.assertEqual(ds.position(), 0)
    ds.read(4)
    self.assertEqual(ds.position(), 4)
    ds.read(4)
    self.assertEqual(ds.position(), 8)
    ds.close()
    # position() stays queryable after close and keeps its last value.
    self.assertEqual(ds.position(), 8)

  def test_inquiries(self):
    f = io.BytesIO(b'something')
    ds = get_decrypting_stream(f, b'aad')
    self.assertTrue(ds.readable())
    self.assertFalse(ds.writable())
    self.assertFalse(ds.seekable())
    self.assertFalse(ds.isatty())

  def test_context_manager(self):
    f = io.BytesIO(b'something')
    with get_decrypting_stream(f, b'aad') as ds:
      self.assertEqual(ds.read(), b'something')
    self.assertTrue(ds.closed)

  def test_readline(self):
    f = io.BytesIO(b'hello\nworld\n')
    ds = get_decrypting_stream(f, b'aad')
    self.assertEqual(ds.readline(), b'hello\n')
    self.assertEqual(ds.readline(), b'world\n')

  def test_readline_with_size(self):
    f = io.BytesIO(b'hello\nworld\n')
    ds = get_decrypting_stream(f, b'aad')
    self.assertEqual(ds.readline(4), b'hell')
    self.assertEqual(ds.readline(4), b'o\n')

  def test_readlines(self):
    f = io.BytesIO(b'hello\nworld\n')
    ds = get_decrypting_stream(f, b'aad')
    self.assertEqual(ds.readlines(), [b'hello\n', b'world\n'])

  def test_readlines_with_hint(self):
    f = io.BytesIO(b'hello\nworld\n!!!\n')
    ds = get_decrypting_stream(f, b'aad')
    # Reading stops once the total size (6 + 6 = 12) exceeds the hint of 10.
    self.assertEqual(ds.readlines(10), [b'hello\n', b'world\n'])

  def test_iterator(self):
    f = io.BytesIO(b'hello\nworld\n')
    result = []
    for line in get_decrypting_stream(f, b'aad'):
      result.append(line)
    self.assertEqual(result, [b'hello\n', b'world\n'])

  def test_textiowrapper_compatibility(self):
    """Checks that io.TextIOWrapper can decode text from a DecryptingStream.

    The same ciphertext is decrypted twice: once through TextIOWrapper and
    once directly as bytes. Both reads must yield the same content.
    """
    file_1 = io.BytesIO(b'something')
    file_2 = io.BytesIO(b'something')
    with get_decrypting_stream(file_1, b'aad') as ds:
      # An explicit encoding keeps the test independent of the host locale.
      with io.TextIOWrapper(ds, encoding='utf-8') as wrapper:
        data_1 = wrapper.read()
    with get_decrypting_stream(file_2, b'aad') as ds:
      data_2 = ds.read()
    self.assertEqual(len(data_1), len(data_2))
    # Equal lengths alone could hide corrupted bytes; compare content too.
    self.assertEqual(data_1, data_2.decode('utf-8'))
if __name__ == '__main__':
  # Entry point when this file is executed directly: hand control to
  # absltest, which runs the test cases defined in this module.
  absltest.main()