blob: 861962de50738d4e1c0f8a08453d991e3b455666 [file] [log] [blame]
import sys
import pickle
import typing
from unittest import TestCase, main, skipUnless
from mypy_extensions import TypedDict
class BaseTestCase(TestCase):
    """TestCase adding assertions about subclass relationships."""

    def _raise_subclass_failure(self, template, cls, class_or_tuple, msg):
        """Raise ``self.failureException`` describing a failed subclass check.

        *template* receives ``cls`` and ``class_or_tuple`` (both repr'd);
        when *msg* is not None it is appended after ``' : '``.
        """
        text = template.format(cls, class_or_tuple)
        if msg is not None:
            text = text + ' : %s' % msg
        raise self.failureException(text)

    def assertIsSubclass(self, cls, class_or_tuple, msg=None):
        """Fail unless ``issubclass(cls, class_or_tuple)`` is true."""
        if issubclass(cls, class_or_tuple):
            return
        self._raise_subclass_failure(
            '{!r} is not a subclass of {!r}', cls, class_or_tuple, msg)

    def assertNotIsSubclass(self, cls, class_or_tuple, msg=None):
        """Fail if ``issubclass(cls, class_or_tuple)`` is true."""
        if not issubclass(cls, class_or_tuple):
            return
        self._raise_subclass_failure(
            '{!r} is a subclass of {!r}', cls, class_or_tuple, msg)
# Class-based TypedDict syntax relies on variable annotations (``x: int``),
# which are a SyntaxError before Python 3.6.  The definitions are therefore
# kept in a string and exec'd at module level only on new enough
# interpreters, which keeps the file importable elsewhere.  When PY36 is true
# this defines Label, Point2D, LabelPoint2D and Options as module globals.
PY36 = sys.version_info >= (3, 6)
PY36_TESTS = """
Label = TypedDict('Label', [('label', str)])
class Point2D(TypedDict):
    x: int
    y: int
class LabelPoint2D(Point2D, Label): ...
class Options(TypedDict, total=False):
    log_level: int
    log_path: str
"""
if PY36:
    exec(PY36_TESTS)
class TypedDictTests(BaseTestCase):
    """Runtime behaviour of ``mypy_extensions.TypedDict``."""

    def _check_emp_basics(self, Emp):
        """Shared checks for an ``Emp`` TypedDict with fields ``name: str``, ``id: int``.

        Verifies the subclass relationships, that instances are plain dicts,
        and that the introspection attributes are populated.
        """
        self.assertIsSubclass(Emp, dict)
        self.assertIsSubclass(Emp, typing.MutableMapping)
        if sys.version_info[0] >= 3:
            # collections.abc is only available on Python 3.
            import collections.abc
            self.assertNotIsSubclass(Emp, collections.abc.Sequence)
        jim = Emp(name='Jim', id=1)
        self.assertIs(type(jim), dict)
        self.assertEqual(jim['name'], 'Jim')
        self.assertEqual(jim['id'], 1)
        self.assertEqual(Emp.__name__, 'Emp')
        self.assertEqual(Emp.__module__, __name__)
        self.assertEqual(Emp.__bases__, (dict,))
        self.assertEqual(Emp.__annotations__, {'name': str, 'id': int})
        self.assertEqual(Emp.__total__, True)

    def test_basics_iterable_syntax(self):
        # Functional form with the fields given as a dict.
        Emp = TypedDict('Emp', {'name': str, 'id': int})
        self._check_emp_basics(Emp)

    def test_basics_keywords_syntax(self):
        # Functional form with the fields given as keyword arguments.
        Emp = TypedDict('Emp', name=str, id=int)  # type: ignore # mypy doesn't support keyword syntax yet
        self._check_emp_basics(Emp)

    def test_typeddict_errors(self):
        Emp = TypedDict('Emp', {'name': str, 'id': int})
        self.assertEqual(TypedDict.__module__, 'mypy_extensions')
        jim = Emp(name='Jim', id=1)
        # TypedDict classes reject runtime isinstance/issubclass checks.
        with self.assertRaises(TypeError):
            isinstance({}, Emp)  # type: ignore
        with self.assertRaises(TypeError):
            isinstance(jim, Emp)  # type: ignore
        with self.assertRaises(TypeError):
            issubclass(dict, Emp)  # type: ignore
        # A field type that is not a type, or mixing the list form with
        # keyword fields, must be rejected.
        with self.assertRaises(TypeError):
            TypedDict('Hi', x=())
        with self.assertRaises(TypeError):
            TypedDict('Hi', [('x', int), ('y', ())])
        with self.assertRaises(TypeError):
            TypedDict('Hi', [('x', int)], y=int)

    @skipUnless(PY36, 'Python 3.6 required')
    def test_py36_class_syntax_usage(self):
        # LabelPoint2D / Point2D come from the exec'd PY36_TESTS source.
        self.assertEqual(LabelPoint2D.__name__, 'LabelPoint2D')  # noqa
        self.assertEqual(LabelPoint2D.__module__, __name__)  # noqa
        self.assertEqual(LabelPoint2D.__annotations__, {'x': int, 'y': int, 'label': str})  # noqa
        self.assertEqual(LabelPoint2D.__bases__, (dict,))  # noqa
        self.assertEqual(LabelPoint2D.__total__, True)  # noqa
        self.assertNotIsSubclass(LabelPoint2D, typing.Sequence)  # noqa
        point = Point2D(x=0, y=1)  # noqa
        self.assertEqual(point['x'], 0)
        self.assertEqual(point['y'], 1)
        labelled = LabelPoint2D(x=0, y=1, label='hi')  # noqa
        self.assertEqual(labelled['label'], 'hi')

    def test_pickle(self):
        # pickle locates classes by module-qualified name, so the TypedDict
        # must be bound at module level rather than as a local.
        global EmpD
        EmpD = TypedDict('EmpD', name=str, id=int)
        jane = EmpD({'name': 'jane', 'id': 37})
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            restored = pickle.loads(pickle.dumps(jane, protocol))
            self.assertEqual(restored, jane)
            self.assertEqual(restored, {'name': 'jane', 'id': 37})
            restored_cls = pickle.loads(pickle.dumps(EmpD, protocol))
            self.assertEqual(restored_cls({'name': 'jane', 'id': 37}), jane)

    def test_optional(self):
        EmpD = TypedDict('EmpD', name=str, id=int)
        self.assertEqual(typing.Optional[EmpD], typing.Union[None, EmpD])
        self.assertNotEqual(typing.List[EmpD], typing.Tuple[EmpD])

    def test_total(self):
        # total=False: every key may be omitted.
        D = TypedDict('D', {'x': int}, total=False)
        self.assertEqual(D(), {})
        self.assertEqual(D(x=1), {'x': 1})
        self.assertEqual(D.__total__, False)
        if not PY36:
            return
        # Same checks for the class-syntax Options from PY36_TESTS.
        self.assertEqual(Options(), {})  # noqa
        self.assertEqual(Options(log_level=2), {'log_level': 2})  # noqa
        self.assertEqual(Options.__total__, False)  # noqa
if __name__ == '__main__':
    # Executed as a script: hand control to unittest's command-line runner,
    # which discovers the TestCase classes in this module.
    main()