| # Test cases related to the functools.singledispatch decorator |
| # Tests whose names end in -xfail cover features that mypyc doesn't support yet |
| # (These tests will be re-enabled when mypyc supports those features) |
| |
| [case testSpecializedImplementationUsed] |
| from functools import singledispatch |
| |
@singledispatch
def fun(arg) -> bool:
    """Fallback implementation, used when no registered type matches."""
    return False


@fun.register
def fun_specialized(arg: str) -> bool:
    """Implementation registered for str (the type comes from the annotation)."""
    return True


def test_specialize() -> None:
    # A str argument is routed to the specialized implementation ...
    assert fun('a')
    # ... while any other type falls back to the default.
    assert not fun(3)
| |
| [case testSubclassesOfExpectedTypeUseSpecialized] |
| from functools import singledispatch |
class A:
    pass


class B(A):
    pass


@singledispatch
def fun(arg) -> bool:
    """Fallback implementation for types with no registered handler."""
    return False


@fun.register
def fun_specialized(arg: A) -> bool:
    """Handler registered for A."""
    return True


def test_specialize() -> None:
    # An instance of the subclass B must use the handler registered for A,
    # exactly like an instance of A itself.
    assert fun(B())
    assert fun(A())
| |
| [case testSuperclassImplementationNotUsedWhenSubclassHasImplementation] |
| from functools import singledispatch |
class A:
    pass


class B(A):
    pass


@singledispatch
def fun(arg) -> bool:
    """Fallback that fails loudly; the test below never passes an unregistered type."""
    # shouldn't be using this
    assert False


@fun.register
def fun_specialized(arg: A) -> bool:
    """Handler for the base class A."""
    return False


@fun.register
def fun_specialized2(arg: B) -> bool:
    """Handler for the subclass B."""
    return True


def test_specialize() -> None:
    # B has its own handler, so the handler of its superclass A must not be used.
    assert fun(B())
    assert not fun(A())
| |
| [case testMultipleUnderscoreFunctionsIsntError] |
| from functools import singledispatch |
| |
@singledispatch
def fun(arg) -> str:
    """Fallback: report that no registered type matched."""
    return 'default'


@fun.register
def _(arg: str) -> str:
    return 'str'


@fun.register
def _(arg: int) -> str:
    return 'int'


# extra function to make sure all 3 underscore functions aren't treated as one OverloadedFuncDef
def a(b): pass


@fun.register
def _(arg: list) -> str:
    return 'list'


def test_singledispatch() -> None:
    # All three implementations named "_" must remain registered.
    assert fun(0) == 'int'
    assert fun('a') == 'str'
    assert fun([1, 2]) == 'list'
    # A dict has no registered implementation.
    assert fun({'a': 'b'}) == 'default'
| |
| [case testCanRegisterCompiledClasses] |
| from functools import singledispatch |
class A:
    pass


@singledispatch
def fun(arg) -> bool:
    """Fallback implementation."""
    return False


@fun.register
def fun_specialized(arg: A) -> bool:
    """Handler registered for the class A defined in this module."""
    return True


def test_singledispatch() -> None:
    assert fun(A())
    # int has no handler, so the fallback answers.
    assert not fun(1)
| |
| [case testTypeUsedAsArgumentToRegister] |
| from functools import singledispatch |
| |
@singledispatch
def fun(arg) -> bool:
    """Fallback implementation."""
    return False


# The dispatch type is passed to register() explicitly; the implementation's
# own parameter is left unannotated.
@fun.register(int)
def fun_specialized(arg) -> bool:
    return True


def test_singledispatch() -> None:
    assert fun(1)
    assert not fun('a')
| |
| [case testUseRegisterAsAFunction] |
| from functools import singledispatch |
| |
@singledispatch
def fun(arg) -> bool:
    """Fallback implementation."""
    return False


def fun_specialized_impl(arg) -> bool:
    """Plain function that is registered for int by the call below."""
    return True


# register() used as an ordinary two-argument call instead of as a decorator.
fun.register(int, fun_specialized_impl)


def test_singledispatch() -> None:
    assert fun(0)
    assert not fun('a')
| |
| [case testRegisterDoesntChangeFunction] |
| from functools import singledispatch |
| |
@singledispatch
def fun(arg) -> bool:
    """Fallback implementation."""
    return False


@fun.register(int)
def fun_specialized(arg) -> bool:
    """Implementation registered for int."""
    return True


def test_singledispatch() -> None:
    # Calling the registered function directly bypasses dispatch: it still runs
    # its own body even for a str argument.
    assert fun_specialized('a')
| |
| # TODO: turn this into a mypy error |
| [case testNoneIsntATypeWhenUsedAsArgumentToRegister] |
| from functools import singledispatch |
| |
@singledispatch
def fun(arg) -> bool:
    """Fallback implementation."""
    return False


# Register an implementation whose argument is annotated with None. A TypeError
# raised by the registration is deliberately tolerated; this case only checks
# that the module can be compiled and imported.
try:
    @fun.register
    def fun_specialized(arg: None) -> bool:
        return True
except TypeError:
    pass
| |
| [case testRegisteringTheSameFunctionSeveralTimes] |
| from functools import singledispatch |
| |
@singledispatch
def fun(arg) -> bool:
    """Fallback implementation."""
    return False


# One function registered for two different types by stacking register().
@fun.register(int)
@fun.register(str)
def fun_specialized(arg) -> bool:
    return True


def test_singledispatch() -> None:
    assert fun(0)
    assert fun('a')
    # list was not registered.
    assert not fun([1, 2])
| |
| [case testTypeIsAnABC] |
| from functools import singledispatch |
| from collections.abc import Mapping |
| |
@singledispatch
def fun(arg) -> bool:
    """Fallback implementation."""
    return False


@fun.register
def fun_specialized(arg: Mapping) -> bool:
    """Handler registered for the abstract base class Mapping."""
    return True


def test_singledispatch() -> None:
    assert not fun(1)
    # dict is a Mapping, so it is dispatched to the ABC handler.
    assert fun({'a': 'b'})
| |
| [case testSingleDispatchMethod-xfail] |
| from functools import singledispatchmethod |
class A:
    @singledispatchmethod
    def fun(self, arg) -> str:
        """Fallback method for argument types without a registered handler."""
        return 'default'

    @fun.register
    def fun_int(self, arg: int) -> str:
        return 'int'

    @fun.register
    def fun_str(self, arg: str) -> str:
        return 'str'


def test_singledispatchmethod() -> None:
    obj = A()
    assert obj.fun(5) == 'int'
    assert obj.fun('a') == 'str'
    # list has no registered handler.
    assert obj.fun([1, 2]) == 'default'
| |
| [case testSingleDispatchMethodWithOtherDecorator-xfail] |
| from functools import singledispatchmethod |
| class A: |
| @singledispatchmethod |
| @staticmethod |
| def fun(arg) -> str: |
| return 'default' |
| |
| @fun.register |
| @staticmethod |
| def fun_int(arg: int) -> str: |
| return 'int' |
| |
| @fun.register |
| @staticmethod |
| def fun_str(arg: str) -> str: |
| return 'str' |
| |
| def test_singledispatchmethod() -> None: |
| x = A() |
| assert x.fun(5) == 'int' |
| assert x.fun('a') == 'str' |
| assert x.fun([1, 2]) == 'default' |
| |
| [case testSingledispatchTreeSumAndEqual] |
| from functools import singledispatch |
| |
class Tree:
    pass


class Leaf(Tree):
    pass


class Node(Tree):
    def __init__(self, value: int, left: Tree, right: Tree) -> None:
        self.value = value
        self.left = left
        self.right = right


@singledispatch
def calc_sum(x: Tree) -> int:
    """Return the sum of all node values in the tree x.

    This fallback only runs for types without a registered implementation
    and raises TypeError.
    """
    raise TypeError('invalid type for x')


@calc_sum.register
def _(x: Leaf) -> int:
    # A leaf carries no value.
    return 0


@calc_sum.register
def _(x: Node) -> int:
    return x.value + calc_sum(x.left) + calc_sum(x.right)


@singledispatch
def equal(to_compare: Tree, known: Tree) -> bool:
    """Return whether the two trees have the same shape and node values.

    Dispatch happens on to_compare only. This fallback runs for types without
    a registered implementation and raises TypeError.
    """
    # The message names the dispatched parameter of this function (the previous
    # text was copied from calc_sum and referred to a non-existent "x").
    raise TypeError('invalid type for to_compare')


@equal.register
def _(to_compare: Leaf, known: Tree) -> bool:
    return isinstance(known, Leaf)


@equal.register
def _(to_compare: Node, known: Tree) -> bool:
    # A node can only equal another node with the same value and equal subtrees.
    if not isinstance(known, Node):
        return False
    if to_compare.value != known.value:
        return False
    return equal(to_compare.left, known.left) and equal(to_compare.right, known.right)


def build(n: int) -> Tree:
    """Build a complete binary tree of depth n whose nodes hold their own depth."""
    if n == 0:
        return Leaf()
    return Node(n, build(n - 1), build(n - 1))


def test_sum_and_equal():
    # Deliberately left unannotated: the attribute chain below goes through
    # Tree-typed attributes and is only valid when accessed dynamically.
    tree = build(5)
    tree2 = build(5)
    tree2.right.right.right.value = 10
    assert calc_sum(tree) == 57
    assert calc_sum(tree2) == 65
    assert equal(tree, tree)
    assert not equal(tree, tree2)
    tree3 = build(4)
    assert not equal(tree, tree3)
| |
| [case testSimulateMypySingledispatch] |
| from functools import singledispatch |
| from mypy_extensions import trait |
| from typing import Iterator, Union, TypeVar, Any, List, Type |
| # based on use of singledispatch in stubtest.py |
class Error:
    """Minimal error record that only carries a message."""

    def __init__(self, msg: str) -> None:
        self.msg = msg


@trait
class Node: pass

class MypyFile(Node): pass
class TypeInfo(Node): pass


@trait
class SymbolNode(Node): pass
@trait
class Expression(Node): pass
class TypeVarLikeExpr(SymbolNode, Expression): pass
class TypeVarExpr(TypeVarLikeExpr): pass
class TypeAlias(SymbolNode): pass


class Missing: pass
MISSING = Missing()

T = TypeVar("T")

MaybeMissing = Union[T, Missing]


@singledispatch
def verify(stub: Node, a: MaybeMissing[Any], b: List[str]) -> Iterator[Error]:
    """Fallback for node types without a registered verifier: yield one error."""
    yield Error('unknown node type')


@verify.register(MypyFile)
def verify_mypyfile(stub: MypyFile, a: MaybeMissing[int], b: List[str]) -> Iterator[Error]:
    if isinstance(a, Missing):
        yield Error("shouldn't be missing")
        return
    if not isinstance(a, int):
        # this check should be unnecessary because of the type signature and the previous check,
        # but stubtest.py has this check
        yield Error("should be an int")
        return
    # Re-enter the dispatcher from inside a registered implementation.
    yield from verify(TypeInfo(), str, ['abc', 'def'])


@verify.register(TypeInfo)
def verify_typeinfo(stub: TypeInfo, a: MaybeMissing[Type[Any]], b: List[str]) -> Iterator[Error]:
    yield Error('in TypeInfo')
    yield Error('hello')


@verify.register(TypeVarExpr)
def verify_typevarexpr(stub: TypeVarExpr, a: MaybeMissing[Any], b: List[str]) -> Iterator[Error]:
    # The never-executed yield makes this function a generator that produces nothing.
    if False:
        yield None


def verify_list(stub, a, b) -> List[str]:
    """Run verify() and collect the messages of the errors it yields into a list."""
    return [err.msg for err in verify(stub, a, b)]


def test_verify() -> None:
    assert verify_list(TypeAlias(), 'a', ['a', 'b']) == ['unknown node type']
    assert verify_list(MypyFile(), MISSING, ['a', 'b']) == ["shouldn't be missing"]
    assert verify_list(MypyFile(), 5, ['a', 'b']) == ['in TypeInfo', 'hello']
    assert verify_list(TypeInfo(), str, ['a', 'b']) == ['in TypeInfo', 'hello']
    assert verify_list(TypeVarExpr(), 'a', ['x', 'y']) == []
| |
| |
| [case testArgsInRegisteredImplNamedDifferentlyFromMainFunction] |
| from functools import singledispatch |
| |
@singledispatch
def f(a) -> bool:
    """Fallback implementation; its parameter is named a."""
    return False


@f.register
def g(b: int) -> bool:
    """int handler whose parameter name (b) differs from the fallback's (a)."""
    return True


def test_singledispatch():
    assert f(5)
    assert not f('a')
| |
| [case testKeywordArguments] |
| from functools import singledispatch |
| |
@singledispatch
def f(arg, *, kwarg: int = 0) -> int:
    """Fallback: add ten to the keyword-only argument (default 0)."""
    return kwarg + 10


@f.register
def g(arg: int, *, kwarg: int = 5) -> int:
    """int handler: subtract ten from the keyword-only argument (default 5)."""
    return kwarg - 10


def test_keywords():
    # Non-int argument: the fallback and its own default are used.
    assert f('a') == 10
    assert f('a', kwarg=3) == 13
    assert f('a', kwarg=7) == 17

    # int argument: the registered handler and its own default are used.
    assert f(1) == -5
    assert f(1, kwarg=4) == -6
    assert f(1, kwarg=6) == -4
| |
| [case testGeneratorAndMultipleTypesOfIterable] |
| from functools import singledispatch |
| from typing import * |
| |
@singledispatch
def f(arg: Any) -> Iterable[int]:
    """Fallback written as a generator that yields a single 1."""
    yield 1


@f.register
def g(arg: str) -> Iterable[int]:
    """str handler that returns a real list rather than a generator."""
    return [0]


def test_iterables():
    # The fallback returns a generator object, which never compares equal to a list ...
    assert f(1) != [1]
    # ... but produces the expected items once consumed.
    assert list(f(1)) == [1]
    assert f('a') == [0]
| |
| [case testRegisterUsedAtSameTimeAsOtherDecorators] |
| from functools import singledispatch |
| from typing import TypeVar |
| |
class A:
    pass


class B:
    pass


T = TypeVar('T')


def decorator(f: T) -> T:
    """Identity decorator: hand the decorated object back unchanged."""
    return f


@singledispatch
def f(arg) -> int:
    """Fallback implementation."""
    return 0


# register() applied on top of another decorator.
@f.register
@decorator
def h(arg: str) -> int:
    return 2


def test_singledispatch():
    assert f(1) == 0
    assert f('a') == 2
| |
| [case testDecoratorModifiesFunction] |
| from functools import singledispatch |
| from typing import Callable, Any |
| |
| class A: pass |
| |
| def decorator(f: Callable[[Any], int]) -> Callable[[Any], int]: |
| def wrapper(x) -> int: |
| return f(x) * 7 |
| return wrapper |
| |
| @singledispatch |
| def f(arg) -> int: |
| return 10 |
| |
| @f.register |
| @decorator |
| def h(arg: str) -> int: |
| return 5 |
| |
| |
| def test_singledispatch(): |
| assert f('a') == 35 |
| assert f(A()) == 10 |
| |
| [case testMoreSpecificTypeBeforeLessSpecificType] |
| from functools import singledispatch |
class A:
    pass


class B(A):
    pass


@singledispatch
def f(arg) -> str:
    """Fallback implementation."""
    return 'default'


# The subclass handler is registered before the handler of its base class.
@f.register
def g(arg: B) -> str:
    return 'b'


@f.register
def h(arg: A) -> str:
    return 'a'


def test_singledispatch():
    assert f(B()) == 'b'
    assert f(A()) == 'a'
    assert f(5) == 'default'
| |
| [case testMultipleRelatedClassesBeingRegistered] |
| from functools import singledispatch |
| |
class A:
    pass


class B(A):
    pass


class C(B):
    pass


@singledispatch
def f(arg) -> str:
    """Fallback implementation."""
    return 'default'


# Handlers for a three-level class chain, registered out of hierarchy order
# (A, then C, then B).
@f.register
def _(arg: A) -> str:
    return 'a'


@f.register
def _(arg: C) -> str:
    return 'c'


@f.register
def _(arg: B) -> str:
    return 'b'


def test_singledispatch():
    assert f(A()) == 'a'
    assert f(B()) == 'b'
    assert f(C()) == 'c'
    assert f(1) == 'default'
| |
| [case testRegisteredImplementationsInDifferentFiles] |
| from other_a import f, A, B, C |
# f and its handler for B are defined in other_a; this module adds handlers
# for A and C.
@f.register
def a(arg: A) -> int:
    return 2


@f.register
def _(arg: C) -> int:
    return 3


def test_singledispatch():
    # 1 comes from the handler for B registered in other_a.
    assert f(B()) == 1
    assert f(A()) == 2
    assert f(C()) == 3
    # 0 is the fallback defined in other_a.
    assert f(1) == 0
| |
| [file other_a.py] |
| from functools import singledispatch |
| |
class A:
    pass


class B(A):
    pass


class C(B):
    pass


@singledispatch
def f(arg) -> int:
    """Fallback implementation."""
    return 0


@f.register
def g(arg: B) -> int:
    """Handler for B, the only one registered in this module."""
    return 1
| |
| [case testOrderCanOnlyBeDeterminedFromMRONotIsinstanceChecks] |
| from mypy_extensions import trait |
| from functools import singledispatch |
| |
# Two unrelated traits ...
@trait
class A:
    pass


@trait
class B:
    pass


# ... and two classes that inherit from both, listing them in opposite base order.
class AB(A, B):
    pass


class BA(B, A):
    pass
| |
@singledispatch
def f(arg) -> str:
    """Fallback implementation: return "default" for types without a handler."""
    # The unreachable `pass` that used to follow this return was removed.
    return "default"
| |
@f.register
def fa(arg: A) -> str:
    """Handler for the trait A."""
    return "a"


@f.register
def fb(arg: B) -> str:
    """Handler for the trait B."""
    return "b"


def test_singledispatch():
    # AB and BA are instances of both A and B; the expected handler is the one
    # for the base each class lists first.
    assert f(AB()) == "a"
    assert f(BA()) == "b"
| |
| [case testCallingFunctionBeforeAllImplementationsRegistered] |
| from functools import singledispatch |
| |
class A:
    pass


class B(A):
    pass


@singledispatch
def f(arg) -> str:
    """Fallback implementation."""
    return 'default'


# Nothing is registered yet: every call reaches the fallback.
assert f(A()) == 'default'
assert f(B()) == 'default'
assert f(1) == 'default'


@f.register
def g(arg: A) -> str:
    return 'a'


# With only the A handler registered, instances of B use it as well.
assert f(A()) == 'a'
assert f(B()) == 'a'
assert f(1) == 'default'


@f.register
def _(arg: B) -> str:
    return 'b'


# Once B has its own handler, instances of B stop using the A handler.
assert f(A()) == 'a'
assert f(B()) == 'b'
assert f(1) == 'default'
| |
| |
| [case testDynamicallyRegisteringFunctionFromInterpretedCode] |
| from functools import singledispatch |
| |
class A:
    pass


class B(A):
    pass


class C(B):
    pass


class D(C):
    pass


@singledispatch
def f(arg) -> str:
    """Fallback implementation."""
    return "default"


# Only B is registered in this module; handlers for A and C are added from
# register_impl.py.
@f.register
def _(arg: B) -> str:
    return 'b'
| |
| [file register_impl.py] |
| from native import f, A, B, C |
| |
# Dispatch type given explicitly as an argument to register().
@f.register(A)
def a(arg) -> str:
    return 'a'


# Dispatch type taken from the parameter annotation.
@f.register
def c(arg: C) -> str:
    return 'c'
| |
| [file driver.py] |
| from native import f, A, B, C |
| from register_impl import a, c |
# A custom driver is required here: register_impl must have been imported (so that
# its additional implementations are registered) before f is exercised.
assert f(C()) == 'c'
assert f(A()) == 'a'
assert f(B()) == 'b'
# Calling a and c directly bypasses dispatch, so the argument type is irrelevant.
assert a(C()) == 'a'
assert c(A()) == 'c'
| |
| [case testMalformedDynamicRegisterCall] |
| from functools import singledispatch |
| |
@singledispatch
def f(arg) -> None:
    """Fallback implementation that does nothing."""
    pass
| [file register.py] |
| from native import f |
| from testutil import assertRaises |
| |
# Trying to register a function that takes no arguments must raise TypeError
# with exactly this message.
with assertRaises(TypeError, 'Invalid first argument to `register()`'):
    @f.register
    def _():
        pass
| |
| [file driver.py] |
| import register |
| |
| [case testCacheClearedWhenNewFunctionRegistered] |
| from functools import singledispatch |
| |
@singledispatch
def f(arg) -> str:
    """Fallback implementation; all handlers are registered from register.py."""
    return 'default'
| |
| [file register.py] |
| from native import f |
class A:
    pass


class B:
    pass


class C:
    pass


# Each section calls f for a type first (the fallback answers), then registers a
# handler for that type and checks that the next call reaches the new handler.

# annotated function
assert f(A()) == 'default'

@f.register
def _(arg: A) -> str:
    return 'a'

assert f(A()) == 'a'

# type passed as argument
assert f(B()) == 'default'

@f.register(B)
def _(arg: B) -> str:
    return 'b'

assert f(B()) == 'b'

# 2 argument form
assert f(C()) == 'default'

def c(arg) -> str:
    return 'c'

f.register(C, c)
assert f(C()) == 'c'
| |
| |
| [file driver.py] |
| import register |