| # Lint as: python3 |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Tests for tink.python.streaming_aead.encrypting_stream.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import io |
| import sys |
| |
| from absl.testing import absltest |
| # TODO(b/141106504) Replace this with unittest.mock |
| import mock |
| |
| from tink.python.streaming_aead import encrypting_stream |
| |
| |
| def fake_get_output_stream_adapter(self, cc_primitive, aad, destination): |
| del cc_primitive, aad, self # unused |
| return destination |
| |
| |
| class TestBytesObject(io.BytesIO): |
| """A BytesIO object that does not close.""" |
| |
| def close(self): |
| pass |
| |
| |
| def get_encrypting_stream(ciphertext_destination, aad): |
| return encrypting_stream.EncryptingStream(None, ciphertext_destination, aad) |
| |
| |
| class EncryptingStreamTest(absltest.TestCase): |
| |
| def setUp(self): |
| super(EncryptingStreamTest, self).setUp() |
| # Replace the EncryptingStream's staticmethod with a custom function to |
| # avoid the need for a Streaming AEAD primitive. |
| self.addCleanup(mock.patch.stopall) |
| mock.patch.object( |
| encrypting_stream.EncryptingStream, |
| '_get_output_stream_adapter', |
| new=fake_get_output_stream_adapter).start() |
| |
| def test_non_writable_object(self): |
| f = mock.Mock() |
| f.writable = mock.Mock(return_value=False) |
| with self.assertRaisesRegex(ValueError, 'writable'): |
| get_encrypting_stream(f, b'aad') |
| |
| def test_write(self): |
| f = TestBytesObject() |
| with get_encrypting_stream(f, b'aad') as es: |
| es.write(b'Hello world!') |
| |
| self.assertEqual(b'Hello world!', f.getvalue()) |
| |
| @absltest.skipIf(sys.version_info[0] == 2, 'Python 2 strings are bytes') |
| def test_write_non_bytes(self): |
| with io.BytesIO() as f, get_encrypting_stream(f, b'aad') as es: |
| with self.assertRaisesRegex(TypeError, 'bytes-like object is required'): |
| es.write('This is a string, not a bytes object') |
| |
| def test_textiowrapper_compatibility(self): |
| """A test that checks the TextIOWrapper works as expected. |
| |
| It encrypts the same plaintext twice - once directly from bytes, and once |
| through TextIOWrapper's encoding. The two ciphertexts should have the same |
| length. |
| """ |
| file_1 = TestBytesObject() |
| file_2 = TestBytesObject() |
| |
| with get_encrypting_stream(file_1, b'aad') as es: |
| with io.TextIOWrapper(es) as wrapper: |
| # Need to specify this is a unicode string for Python 2. |
| wrapper.write(u'some data') |
| |
| with get_encrypting_stream(file_2, b'aad') as es: |
| es.write(b'some data') |
| |
| self.assertEqual(len(file_1.getvalue()), len(file_2.getvalue())) |
| |
| def test_flush(self): |
| with io.BytesIO() as f, get_encrypting_stream(f, b'assoc') as es: |
| es.write(b'Hello world!') |
| es.flush() |
| |
| def test_closed(self): |
| f = io.BytesIO() |
| es = get_encrypting_stream(f, b'assoc') |
| es.write(b'Hello world!') |
| es.close() |
| |
| self.assertTrue(es.closed) |
| self.assertTrue(f.closed) |
| |
| def test_closed_methods_raise(self): |
| f = io.BytesIO() |
| es = get_encrypting_stream(f, b'assoc') |
| es.write(b'Hello world!') |
| es.close() |
| |
| with self.assertRaisesRegex(ValueError, 'closed'): |
| es.write(b'Goodbye world.') |
| with self.assertRaisesRegex(ValueError, 'closed'): |
| with es: |
| pass |
| with self.assertRaisesRegex(ValueError, 'closed'): |
| es.flush() |
| |
| def test_unsupported_operation(self): |
| with io.BytesIO() as f, get_encrypting_stream(f, b'assoc') as es: |
| with self.assertRaisesRegex(io.UnsupportedOperation, 'seek'): |
| es.seek(0, 2) |
| with self.assertRaisesRegex(io.UnsupportedOperation, 'truncate'): |
| es.truncate(0) |
| with self.assertRaisesRegex(io.UnsupportedOperation, 'read'): |
| es.read(-1) |
| |
| def test_inquiries(self): |
| with io.BytesIO() as f, get_encrypting_stream(f, b'assoc') as es: |
| self.assertTrue(es.writable()) |
| self.assertFalse(es.readable()) |
| self.assertFalse(es.seekable()) |
| |
| def test_position(self): |
| with io.BytesIO() as f: |
| with get_encrypting_stream(f, b'assoc') as es: |
| es.write(b'Hello world!') |
| self.assertEqual(es.position(), 12) |
| |
| def test_position_works_closed(self): |
| with io.BytesIO() as f: |
| es = get_encrypting_stream(f, b'assoc') |
| |
| es.write(b'Hello world!') |
| es.close() |
| |
| self.assertTrue(es.closed) |
| self.assertEqual(es.position(), 12) |
| |
| def test_blocking_io(self): |
| |
| class OnlyWritesFirstFiveBytes(io.BytesIO): |
| |
| def write(self, buffer): |
| buffer = buffer[:5] |
| n = super(OnlyWritesFirstFiveBytes, self).write(buffer) |
| return n |
| |
| with OnlyWritesFirstFiveBytes() as f: |
| with get_encrypting_stream(f, b'assoc') as es: |
| with self.assertRaisesRegex(io.BlockingIOError, 'could not complete'): |
| es.write(b'Hello world!') |
| |
| def test_context_manager_exception_close(self): |
| """Tests that exceptional exits do not trigger normal file closure. |
| |
| Instead, the file will be closed without a proper final ciphertext block, |
| and will result in an invalid ciphertext. The ciphertext_destination file |
| object itself should in most cases still be closed when garbage collected. |
| """ |
| f = io.BytesIO() |
| with self.assertRaisesRegex(ValueError, 'raised inside'): |
| with get_encrypting_stream(f, b'assoc') as es: |
| es.write(b'some message') |
| raise ValueError('Error raised inside context manager') |
| |
| self.assertFalse(f.closed) |
| |
| |
| if __name__ == '__main__': |
| absltest.main() |