# blob: d2bc81b46171b81b8c2350344761f1318436bbb9 [file] [log] [blame]
# -*- coding: utf-8 -*-
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
"""Test decoder."""
import io
import unittest
from google.protobuf import message
from google.protobuf.internal import api_implementation
from google.protobuf.internal import decoder
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import wire_format
from absl.testing import parameterized
# Shared varint fixture: the tests below decode a first value that ends at
# offset 2 (b'\x84r') and a second value that ends at offset 3 (b'\x12').
_INPUT_BYTES = b'\x84r\x12'
# Values the tests expect from decoding _INPUT_BYTES, in order.
_EXPECTED = (14596, 18)
@testing_refleaks.TestCase
class DecoderTest(parameterized.TestCase):
  """Tests for helpers in ``google.protobuf.internal.decoder``.

  Covers ``_DecodeVarint`` on both indexable buffers and ``io.BytesIO``
  streams, ``_DecodeUnknownField`` for start-group wire types, end-group
  handling in MessageSet parsing, and ``IsDefaultScalarValue``.
  """

  def test_decode_varint_bytes(self):
    # With an explicit offset the decoder returns (value, next_offset).
    first_value, after_first = decoder._DecodeVarint(_INPUT_BYTES, 0)
    self.assertEqual(first_value, _EXPECTED[0])
    self.assertEqual(after_first, 2)

    second_value, after_second = decoder._DecodeVarint(_INPUT_BYTES, 2)
    self.assertEqual(second_value, _EXPECTED[1])
    self.assertEqual(after_second, 3)

  def test_decode_varint_bytes_empty(self):
    # Decoding at offset 0 of an empty buffer is expected to raise
    # IndexError.
    with self.assertRaises(IndexError) as raised:
      decoder._DecodeVarint(b'', 0)
    self.assertIn('index out of range', str(raised.exception))

  def test_decode_varint_bytesio(self):
    # Without an offset the decoder reads from the stream; a None result
    # marks the end of the input.
    stream = io.BytesIO(_INPUT_BYTES)
    decoded_count = 0
    value = decoder._DecodeVarint(stream)
    while value is not None:
      self.assertEqual(value, _EXPECTED[decoded_count])
      decoded_count += 1
      value = decoder._DecodeVarint(stream)
    # Every expected value must have been produced, no more and no fewer.
    self.assertEqual(decoded_count, len(_EXPECTED))

  def test_decode_varint_bytesio_empty(self):
    # An exhausted stream yields None instead of raising.
    empty_stream = io.BytesIO(b'')
    self.assertIsNone(decoder._DecodeVarint(empty_stream))

  def test_decode_unknown_group_field(self):
    buf = memoryview(b'\013\020\003\014\040\005')
    fields, end_pos = decoder._DecodeUnknownField(
        buf, 1, len(buf), 1, wire_format.WIRETYPE_START_GROUP
    )
    # The returned position (4) is short of len(buf) == 6: the trailing two
    # bytes are not consumed as part of the group.
    self.assertEqual(end_pos, 4)
    self.assertEqual(len(fields), 1)
    only_field = fields[0]
    self.assertEqual(only_field.field_number, 2)
    self.assertEqual(only_field.data, 3)

  def test_decode_unknown_group_field_nested(self):
    buf = memoryview(b'\013\023\013\030\004\014\024\014\050\006')
    fields, end_pos = decoder._DecodeUnknownField(
        buf, 1, len(buf), 1, wire_format.WIRETYPE_START_GROUP
    )
    # Decoding stops at offset 8 of the 10-byte buffer.
    self.assertEqual(end_pos, 8)

    # Walk down the nesting: field 2 -> field 1 -> field 3 with value 4.
    self.assertEqual(len(fields), 1)
    outer = fields[0]
    self.assertEqual(outer.field_number, 2)
    self.assertEqual(len(outer.data), 1)
    middle = outer.data[0]
    self.assertEqual(middle.field_number, 1)
    self.assertEqual(len(middle.data), 1)
    innermost = middle.data[0]
    self.assertEqual(innermost.field_number, 3)
    self.assertEqual(innermost.data, 4)

  def test_decode_unknown_group_field_too_many_levels(self):
    # Five million repeated 0x13 bytes must be rejected with a DecodeError.
    buf = memoryview(b'\023' * 5_000_000)
    with self.assertRaisesRegex(message.DecodeError, 'Error parsing message'):
      decoder._DecodeUnknownField(
          buf, 1, len(buf), 1, wire_format.WIRETYPE_START_GROUP
      )

  def test_decode_unknown_mismatched_end_group(self):
    buf = memoryview(b'\013\024')
    with self.assertRaisesRegex(
        message.DecodeError, 'Missing group end tag.*'
    ):
      decoder._DecodeUnknownField(
          buf, 1, 2, 1, wire_format.WIRETYPE_START_GROUP
      )

  def test_decode_unknown_mismatched_end_group_nested(self):
    buf = memoryview(b'\013\023\034\024\014')
    with self.assertRaisesRegex(
        message.DecodeError, 'Missing group end tag.*'
    ):
      decoder._DecodeUnknownField(
          buf, 1, 5, 1, wire_format.WIRETYPE_START_GROUP
      )

  def test_decode_message_set_unknown_mismatched_end_group(self):
    proto = message_set_extensions_pb2.TestMessageSet()
    # Only the pure-Python implementation is held to a specific message;
    # any DecodeError text is accepted from the other implementations.
    if api_implementation.Type() == 'python':
      expected_message = 'Unexpected end-group tag.'
    else:
      expected_message = '.*'
    with self.assertRaisesRegex(message.DecodeError, expected_message):
      proto.ParseFromString(b'\013\054\014')

  def test_unknown_message_set_decoder_mismatched_end_group(self):
    # This behavior isn't actually reachable in practice, but it is still
    # worth covering.
    decode = decoder.UnknownMessageSetItemDecoder()
    with self.assertRaisesRegex(
        message.DecodeError, 'Unexpected end-group tag.'
    ):
      decode(memoryview(b'\054\014'))

  @parameterized.parameters(0, 0.0, False, '')
  def test_default_scalar(self, value):
    # Zero-like int/float/bool values and the empty string count as default.
    self.assertTrue(decoder.IsDefaultScalarValue(value))

  @parameterized.parameters(1, -0.0, 1.0, True, 'a')
  def test_not_default_scalar(self, value):
    # Negative zero is expected to be reported as non-default.
    self.assertFalse(decoder.IsDefaultScalarValue(value))
if __name__ == '__main__':
  # Run this module's tests when the file is executed as a script.
  unittest.main()