# blob: 05f324039c985839274c1a55487755a4dc74093d [file] [log] [blame]
"""Mypy type checker."""
from __future__ import annotations
import itertools
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from typing import (
AbstractSet,
Callable,
Dict,
Final,
Generic,
Iterable,
Iterator,
Mapping,
NamedTuple,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import TypeAlias as _TypeAlias
import mypy.checkexpr
from mypy import errorcodes as codes, message_registry, nodes, operators
from mypy.binder import ConditionalTypeBinder, Frame, get_declaration
from mypy.checkmember import (
MemberContext,
analyze_decorator_or_funcbase_access,
analyze_descriptor_access,
analyze_member_access,
type_object_type,
)
from mypy.checkpattern import PatternChecker
from mypy.constraints import SUPERTYPE_OF
from mypy.erasetype import erase_type, erase_typevars, remove_instance_last_known_values
from mypy.errorcodes import TYPE_VAR, UNUSED_AWAITABLE, UNUSED_COROUTINE, ErrorCode
from mypy.errors import Errors, ErrorWatcher, report_internal_error
from mypy.expandtype import expand_self_type, expand_type, expand_type_by_instance
from mypy.join import join_types
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash
from mypy.maptype import map_instance_to_supertype
from mypy.meet import is_overlapping_erased_types, is_overlapping_types
from mypy.message_registry import ErrorMessage
from mypy.messages import (
SUGGESTED_TEST_FIXTURES,
MessageBuilder,
append_invariance_notes,
format_type,
format_type_bare,
format_type_distinctly,
make_inferred_type_note,
pretty_seq,
)
from mypy.mro import MroError, calculate_mro
from mypy.nodes import (
ARG_NAMED,
ARG_POS,
ARG_STAR,
CONTRAVARIANT,
COVARIANT,
FUNC_NO_INFO,
GDEF,
IMPLICITLY_ABSTRACT,
INVARIANT,
IS_ABSTRACT,
LDEF,
LITERAL_TYPE,
MDEF,
NOT_ABSTRACT,
AssertStmt,
AssignmentExpr,
AssignmentStmt,
Block,
BreakStmt,
BytesExpr,
CallExpr,
ClassDef,
ComparisonExpr,
Context,
ContinueStmt,
Decorator,
DelStmt,
EllipsisExpr,
Expression,
ExpressionStmt,
FloatExpr,
ForStmt,
FuncBase,
FuncDef,
FuncItem,
IfStmt,
Import,
ImportAll,
ImportBase,
ImportFrom,
IndexExpr,
IntExpr,
LambdaExpr,
ListExpr,
Lvalue,
MatchStmt,
MemberExpr,
MypyFile,
NameExpr,
Node,
OperatorAssignmentStmt,
OpExpr,
OverloadedFuncDef,
PassStmt,
PromoteExpr,
RaiseStmt,
RefExpr,
ReturnStmt,
StarExpr,
Statement,
StrExpr,
SymbolNode,
SymbolTable,
SymbolTableNode,
TempNode,
TryStmt,
TupleExpr,
TypeAlias,
TypeInfo,
TypeVarExpr,
UnaryExpr,
Var,
WhileStmt,
WithStmt,
YieldExpr,
is_final_node,
)
from mypy.options import Options
from mypy.patterns import AsPattern, StarredPattern
from mypy.plugin import CheckerPluginInterface, Plugin
from mypy.plugins import dataclasses as dataclasses_plugin
from mypy.scope import Scope
from mypy.semanal import is_trivial_body, refers_to_fullname, set_callable_name
from mypy.semanal_enum import ENUM_BASES, ENUM_SPECIAL_PROPS
from mypy.sharedparse import BINARY_MAGIC_METHODS
from mypy.state import state
from mypy.subtypes import (
find_member,
is_callable_compatible,
is_equivalent,
is_more_precise,
is_proper_subtype,
is_same_type,
is_subtype,
restrict_subtype_away,
unify_generic_callable,
)
from mypy.traverser import TraverserVisitor, all_return_statements, has_return_statement
from mypy.treetransform import TransformVisitor
from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type, make_optional_type
from mypy.typeops import (
bind_self,
coerce_to_literal,
custom_special_method,
erase_def_to_union_or_bound,
erase_to_bound,
erase_to_union_or_bound,
false_only,
fixup_partial_type,
function_type,
get_type_vars,
is_literal_type_like,
is_singleton_type,
make_simplified_union,
map_type_from_supertype,
true_only,
try_expanding_sum_type_to_union,
try_getting_int_literals_from_type,
try_getting_str_literals,
try_getting_str_literals_from_type,
tuple_fallback,
)
from mypy.types import (
ANY_STRATEGY,
MYPYC_NATIVE_INT_NAMES,
OVERLOAD_NAMES,
AnyType,
BoolTypeQuery,
CallableType,
DeletedType,
ErasedType,
FunctionLike,
Instance,
LiteralType,
NoneType,
Overloaded,
PartialType,
ProperType,
TupleType,
Type,
TypeAliasType,
TypedDictType,
TypeGuardedType,
TypeOfAny,
TypeTranslator,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarType,
UnboundType,
UninhabitedType,
UnionType,
flatten_nested_unions,
get_proper_type,
get_proper_types,
is_literal_type,
is_named_instance,
)
from mypy.types_utils import is_overlapping_none, remove_optional, store_argument_type, strip_type
from mypy.typetraverser import TypeTraverserVisitor
from mypy.typevars import fill_typevars, fill_typevars_with_any, has_no_typevars
from mypy.util import is_dunder, is_sunder, is_typeshed_file
from mypy.visitor import NodeVisitor
# Generic type variable.
T = TypeVar("T")

# Default value of TypeChecker.last_pass ("last pass number to take").
DEFAULT_LAST_PASS: Final = 1  # Pass numbers start at 0

# Node kinds that DeferredNode.node may hold.
DeferredNodeType: _TypeAlias = Union[FuncDef, LambdaExpr, OverloadedFuncDef, Decorator]
# Node kinds that FineGrainedDeferredNode.node may hold.
FineGrainedDeferredNodeType: _TypeAlias = Union[FuncDef, MypyFile, OverloadedFuncDef]
# A node which is postponed to be processed during the next pass.
# In normal mode one can defer functions and methods (also decorated and/or overloaded)
# and lambda expressions. Nested functions can't be deferred -- only top-level functions
# and methods of classes not defined within a function can be deferred.
class DeferredNode(NamedTuple):
    # The postponed definition or lambda expression.
    node: DeferredNodeType
    # And its TypeInfo (for semantic analysis self type handling); for methods this is
    # the class where the method is defined.
    active_typeinfo: TypeInfo | None
# A postponed node for fine-grained mode targets. Only top-level functions/methods
# and module top levels are allowed as such.
class FineGrainedDeferredNode(NamedTuple):
    # The postponed function, overloaded function or whole module.
    node: FineGrainedDeferredNodeType
    # TypeInfo that is made the active class while the node is re-checked, if any.
    active_typeinfo: TypeInfo | None
# Data structure returned by find_isinstance_check representing
# information learned from the truth or falsehood of a condition. The
# dict maps nodes representing expressions like 'a[0].x' to their
# refined types under the assumption that the condition has a
# particular truth value. A value of None means that the condition can
# never have that truth value.
#
# NB: The keys of this dict are nodes in the original source program,
# which are compared by reference equality--effectively, being *the
# same* expression of the program, not just two identical expressions
# (such as two references to the same variable).
#
# TODO: it would probably be better to have the dict keyed by the
# nodes' literal_hash field instead.
TypeMap: _TypeAlias = Optional[Dict[Expression, Type]]
# An object that represents either a precise type or a type with an upper bound;
# it is important for correct type inference with isinstance.
class TypeRange(NamedTuple):
    # The type itself (precise) or the upper bound, depending on is_upper_bound.
    item: Type
    is_upper_bound: bool  # False => precise type
# Keeps track of partial types in a single scope. In fine-grained incremental
# mode partial types initially defined at the top level cannot be completed in
# a function, and we use the 'is_function' attribute to enforce this.
class PartialTypeScope(NamedTuple):
    # Variables with partial types tracked in this scope, each with an associated
    # Context (presumably where the partial type was introduced -- TODO confirm).
    map: dict[Var, Context]
    # Whether this scope is a function scope (see the note above).
    is_function: bool
    # NOTE(review): meaning not visible in this part of the file; presumably marks
    # scopes whose partial types are local to the scope -- confirm against users.
    is_local: bool
class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
    """Mypy type checker.

    Type check mypy source files that have been semantically analyzed.

    You must create a separate instance for each source file.
    """

    # Are we type checking a stub?
    is_stub = False
    # Error message reporter
    errors: Errors
    # Utility for generating messages
    msg: MessageBuilder
    # Types of type checked nodes. The first item is the "master" type
    # map that will store the final, exported types. Additional items
    # are temporary type maps used during type inference, and these
    # will be eventually popped and either discarded or merged into
    # the master type map.
    #
    # Avoid accessing this directly, but prefer the lookup_type(),
    # has_type() etc. helpers instead.
    _type_maps: list[dict[Expression, Type]]

    # Helper for managing conditional types
    binder: ConditionalTypeBinder
    # Helper for type checking expressions
    expr_checker: mypy.checkexpr.ExpressionChecker

    # Helper for type checking patterns
    pattern_checker: PatternChecker

    # Scope tracker entered via module_scope()/class_scope()/function_scope()
    tscope: Scope
    # Checker scope; provides the enclosing class/function lookups
    scope: CheckerScope
    # Stack of function return types
    return_types: list[Type]
    # Flags; true for dynamically typed functions
    dynamic_funcs: list[bool]
    # Stack of collections of variables with partial types
    partial_types: list[PartialTypeScope]
    # Vars for which partial type errors are already reported
    # (to avoid logically duplicate errors with different error context).
    partial_reported: set[Var]
    # Symbol table of the file being checked (set from tree.names in __init__)
    globals: SymbolTable
    # Modules passed to the constructor, keyed by str
    modules: dict[str, MypyFile]
    # Nodes that couldn't be checked because some types weren't available. We'll run
    # another pass and try these again.
    deferred_nodes: list[DeferredNode]
    # Type checking pass number (0 = first pass)
    pass_num = 0
    # Last pass number to take
    last_pass = DEFAULT_LAST_PASS
    # Have we deferred the current function? If yes, don't infer additional
    # types during this pass within the function.
    current_node_deferred = False
    # Is this file a typeshed stub?
    is_typeshed_stub = False
    # Options passed to the constructor
    options: Options
    # Used for collecting inferred attribute types so that they can be checked
    # for consistency.
    inferred_attribute_types: dict[Var, Type] | None = None
    # Don't infer partial None types if we are processing assignment from Union
    no_partial_types: bool = False

    # The set of all dependencies (suppressed or not) that this module accesses, either
    # directly or indirectly.
    module_refs: set[str]

    # A map from variable nodes to a snapshot of the frame ids of the
    # frames that were active when the variable was declared. This can
    # be used to determine nearest common ancestor frame of a variable's
    # declaration and the current frame, which lets us determine if it
    # was declared in a different branch of the same `if` statement
    # (if that frame is a conditional_frame).
    var_decl_frames: dict[Var, set[int]]

    # Plugin that provides special type checking rules for specific library
    # functions such as open(), etc.
    plugin: Plugin
def __init__(
    self,
    errors: Errors,
    modules: dict[str, MypyFile],
    options: Options,
    tree: MypyFile,
    path: str,
    plugin: Plugin,
    per_line_checking_time_ns: dict[int, int],
) -> None:
    """Set up a type checker for a single source file.

    Type check errors are reported through `errors`. The remaining arguments
    are stored on the instance, and `per_line_checking_time_ns` is handed on
    to the expression checker.
    """
    # --- Inputs supplied by the caller ---
    self.errors = errors
    self.modules = modules
    self.options = options
    self.tree = tree
    self.path = path
    self.plugin = plugin

    # --- Helper objects that do not depend on the checker itself ---
    self.msg = MessageBuilder(errors, modules)
    self.tscope = Scope()
    self.scope = CheckerScope(tree)
    self.binder = ConditionalTypeBinder()

    # --- Facts derived from the file being checked ---
    self.globals = tree.names
    self.is_stub = tree.is_stub
    self.is_typeshed_stub = is_typeshed_file(options.abs_custom_typeshed_dir, path)

    # --- Fresh per-run bookkeeping ---
    self._type_maps = [{}]
    self.return_types = []
    self.dynamic_funcs = []
    self.partial_types = []
    self.partial_reported = set()
    self.var_decl_frames = {}
    self.deferred_nodes = []
    self.module_refs = set()
    self.inferred_attribute_types = None
    self.pass_num = 0
    self.current_node_deferred = False

    # --- Mode flags ---
    # When True, function definitions are processed; when False they are
    # skipped. The latter is used for processing module top levels in
    # fine-grained incremental mode.
    self.recurse_into_functions = True
    # Internal flag: are we currently type-checking a final declaration
    # (assignment)? Some errors are suppressed while it is set. Do not set
    # it manually; use get_final_context/enter_final_context instead.
    # NOTE: a context manager is used so that no extra `is_final_def`
    # argument has to be threaded through `checker`/`checkmember` functions.
    self._is_final_def = False
    # Set while a type check or attribute access check is run only to produce
    # a note about a possibly missing "await"; guards against infinite recursion.
    self.checking_missing_await = False
    # While True, an abstract class may be passed where Type[T] is expected.
    # Technically unsafe, but desirable in some contexts, for example when
    # type-checking class decorators.
    self.allow_abstract_call = False

    # --- Child checkers for specific AST node types (created last) ---
    self.expr_checker = mypy.checkexpr.ExpressionChecker(
        self, self.msg, self.plugin, per_line_checking_time_ns
    )
    self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options)
@property
def type_context(self) -> list[Type | None]:
    """Return the type context list owned by the expression checker."""
    expression_checker = self.expr_checker
    return expression_checker.type_context
def reset(self) -> None:
    """Cleanup stale state that might be left over from a typechecking run.

    This allows us to reuse TypeChecker objects in fine-grained
    incremental mode.

    The per-run collections are emptied and the binder is replaced. State
    that a finished run must already have released is only asserted, not
    cleared.
    """
    # TODO: verify this is still actually worth it over creating new checkers
    self.partial_reported.clear()
    self.module_refs.clear()
    self.binder = ConditionalTypeBinder()
    # Drop every temporary type map and empty the first ("master") map in
    # place, so the list always keeps exactly one map object.
    self._type_maps[1:] = []
    self._type_maps[0].clear()
    self.temp_type_map = None
    self.expr_checker.reset()

    # Invariants that must hold between runs. (The original asserted
    # `partial_types == []` twice; one check is sufficient.)
    assert self.inferred_attribute_types is None
    assert self.partial_types == []
    assert self.deferred_nodes == []
    assert len(self.scope.stack) == 1
def check_first_pass(self) -> None:
    """Type check the entire file, but defer functions with unresolved references.

    Unresolved references are forward references to variables
    whose types haven't been inferred yet.  They may occur later
    in the same file or in a different file that's being processed
    later (usually due to an import cycle).

    Deferred functions will be processed by check_second_pass().

    After the top-level definitions are checked, a module-level `__all__`
    with a known type is verified to be a subtype of `Sequence[str]`.
    """
    self.recurse_into_functions = True
    with state.strict_optional_set(self.options.strict_optional):
        self.errors.set_file(
            self.path, self.tree.fullname, scope=self.tscope, options=self.options
        )
        with self.tscope.module_scope(self.tree.fullname):
            with self.enter_partial_types(), self.binder.top_frame_context():
                for d in self.tree.defs:
                    if self.binder.is_unreachable():
                        # Stop silently when unreachable code is not being reported.
                        if not self.should_report_unreachable_issues():
                            break
                        # Report only the first unreachable statement that is not a
                        # no-op for reachability, then stop checking the module body.
                        if not self.is_noop_for_reachability(d):
                            self.msg.unreachable_statement(d)
                            break
                    else:
                        self.accept(d)

            # The module top level itself is never deferred.
            assert not self.current_node_deferred

            all_ = self.globals.get("__all__")
            if all_ is not None and all_.type is not None:
                all_node = all_.node
                assert all_node is not None
                seq_str = self.named_generic_type(
                    "typing.Sequence", [self.named_type("builtins.str")]
                )
                if not is_subtype(all_.type, seq_str):
                    str_seq_s, all_s = format_type_distinctly(
                        seq_str, all_.type, options=self.options
                    )
                    self.fail(
                        message_registry.ALL_MUST_BE_SEQ_STR.format(str_seq_s, all_s), all_node
                    )
def check_second_pass(
    self, todo: Sequence[DeferredNode | FineGrainedDeferredNode] | None = None
) -> bool:
    """Run second or following pass of type checking.

    This goes through deferred nodes, returning True if there were any.

    Args:
        todo: nodes to process; when empty or None, the checker's own
            `deferred_nodes` are processed instead.

    Returns:
        False if there was nothing to process, True otherwise.
    """
    self.recurse_into_functions = True
    with state.strict_optional_set(self.options.strict_optional):
        if not todo and not self.deferred_nodes:
            return False
        self.errors.set_file(
            self.path, self.tree.fullname, scope=self.tscope, options=self.options
        )
        with self.tscope.module_scope(self.tree.fullname):
            self.pass_num += 1
            if not todo:
                todo = self.deferred_nodes
            else:
                # An explicit work list is only expected when nothing is pending.
                assert not self.deferred_nodes
            # Start a fresh list so nodes deferred again during this pass are
            # collected for the next one.
            self.deferred_nodes = []
            done: set[DeferredNodeType | FineGrainedDeferredNodeType] = set()
            for node, active_typeinfo in todo:
                # Each node is processed at most once per pass.
                if node in done:
                    continue
                # This is useful for debugging:
                # print("XXX in pass %d, class %s, function %s" %
                #       (self.pass_num, type_name, node.fullname or node.name))
                done.add(node)
                # Re-enter the class scopes only when the node has an active TypeInfo.
                with self.tscope.class_scope(
                    active_typeinfo
                ) if active_typeinfo else nullcontext():
                    with self.scope.push_class(
                        active_typeinfo
                    ) if active_typeinfo else nullcontext():
                        self.check_partial(node)
        return True
def check_partial(self, node: DeferredNodeType | FineGrainedDeferredNodeType) -> None:
    """Check one deferred node, dispatching on what kind of node it is.

    A whole module only has its top level checked. Anything else is checked
    with recursion into functions enabled; lambdas go through the expression
    checker and every other node through `accept`.
    """
    if isinstance(node, MypyFile):
        self.check_top_level(node)
        return

    self.recurse_into_functions = True
    if isinstance(node, LambdaExpr):
        self.expr_checker.accept(node)
        return
    self.accept(node)
def check_top_level(self, node: MypyFile) -> None:
    """Check the module-level statements of `node` without entering functions."""
    self.recurse_into_functions = False
    with self.enter_partial_types(), self.binder.top_frame_context():
        for statement in node.defs:
            statement.accept(self)

    # Nothing at the top level may have been deferred.
    assert not self.current_node_deferred
    # TODO: Handle __all__
def defer_node(self, node: DeferredNodeType, enclosing_class: TypeInfo | None) -> None:
    """Queue `node` so that it is processed again in the next type-checking pass.

    Args:
        node: function/method being deferred
        enclosing_class: for methods, the class where the method is defined

    NOTE: this can't handle nested functions/methods.
    """
    # Only top-level functions and methods can be deferred, so the entire scope
    # is not frozen: module/class level scope information is all that is needed,
    # and the module-level part is preserved in the TypeChecker instance.
    postponed = DeferredNode(node, enclosing_class)
    self.deferred_nodes.append(postponed)
def handle_cannot_determine_type(self, name: str, context: Context) -> None:
    """Defer the enclosing function if possible, otherwise report an error.

    Deferral happens only before the last pass and only when the innermost
    non-lambda function is a FuncDef. In every other case a "cannot determine
    type" error for `name` is reported at `context`.
    """
    enclosing_func = self.scope.top_non_lambda_function()
    if self.pass_num >= self.last_pass or not isinstance(enclosing_func, FuncDef):
        self.msg.cannot_determine_type(name, context)
        return

    # No error yet, just defer. Lambdas are not deferred because they are
    # coupled to the surrounding function through the binder and the inferred
    # type of the lambda, so it would get messy.
    self.defer_node(enclosing_func, self.scope.enclosing_class())
    # Mark the function so that no additional types are inferred in it during
    # this pass: with at least one unknown type, inferred types could be bogus.
    self.current_node_deferred = True
def accept(self, stmt: Statement) -> None:
    """Type check a statement, turning unexpected exceptions into internal errors."""
    try:
        stmt.accept(self)
    except Exception as exc:
        # Boundary handler: any exception raised while visiting the statement is
        # passed to the internal-error reporter with the current file and line.
        reporter = self.errors
        report_internal_error(exc, reporter.file, stmt.line, reporter, self.options)
def accept_loop(
    self,
    body: Statement,
    else_body: Statement | None = None,
    *,
    exit_condition: Expression | None = None,
) -> None:
    """Repeatedly type check a loop body until the frame doesn't change.

    If exit_condition is set, assume it must be False on exit from the loop.

    Then check the else_body.

    Args:
        body: the loop body, checked one or more times.
        else_body: optional `else` block, checked once after the loop.
        exit_condition: optional loop condition; when given, the type map for
            the condition being false is pushed after the loop.
    """
    # The outer frame accumulates the results of all iterations
    with self.binder.frame_context(can_skip=False, conditional_frame=True):
        while True:
            with self.binder.frame_context(can_skip=True, break_frame=2, continue_frame=1):
                self.accept(body)
            # Stop iterating once leaving the body frame changed nothing.
            if not self.binder.last_pop_changed:
                break
        if exit_condition:
            # Only the map for the "condition is false" case is used here.
            _, else_map = self.find_isinstance_check(exit_condition)
            self.push_type_map(else_map)
        if else_body:
            self.accept(else_body)
#
# Definitions
#
def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
    """Check an overloaded function inside its own function scope.

    Does nothing when recursion into functions is switched off.
    """
    if self.recurse_into_functions:
        with self.tscope.function_scope(defn):
            self._visit_overloaded_func_def(defn)
def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
    """Type check the items and implementation of an overloaded definition.

    Checks every item, verifies that abstract status is consistent across the
    items, checks the implementation, runs the overlap checks for real
    overloads (not properties), fills in `defn.type` when it is missing, and
    finally validates overrides for definitions that belong to a class.
    """
    num_abstract = 0
    if not defn.items:
        # In this case we have already complained about none of these being
        # valid overloads.
        return None
    if len(defn.items) == 1:
        self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, defn)

    if defn.is_property:
        # HACK: Infer the type of the property.
        assert isinstance(defn.items[0], Decorator)
        self.visit_decorator(defn.items[0])
    for fdef in defn.items:
        assert isinstance(fdef, Decorator)
        if defn.is_property:
            self.check_func_item(fdef.func, name=fdef.func.name, allow_empty=True)
        else:
            # Perform full check for real overloads to infer type of all decorated
            # overload variants.
            self.visit_decorator_inner(fdef, allow_empty=True)
        if fdef.func.abstract_status in (IS_ABSTRACT, IMPLICITLY_ABSTRACT):
            num_abstract += 1
    # Either none or all of the items must be abstract.
    if num_abstract not in (0, len(defn.items)):
        self.fail(message_registry.INCONSISTENT_ABSTRACT_OVERLOAD, defn)
    if defn.impl:
        defn.impl.accept(self)
    if not defn.is_property:
        self.check_overlapping_overloads(defn)
        if defn.type is None:
            # Build the overloaded type from the items whose callable type
            # could be extracted.
            item_types = []
            for item in defn.items:
                assert isinstance(item, Decorator)
                item_type = self.extract_callable_type(item.var.type, item)
                if item_type is not None:
                    item_types.append(item_type)
            if item_types:
                defn.type = Overloaded(item_types)
    # Check override validity after we analyzed current definition.
    if defn.info:
        found_method_base_classes = self.check_method_override(defn)
        # An empty (but not None) result means no overridable base method was found.
        if (
            defn.is_explicit_override
            and not found_method_base_classes
            and found_method_base_classes is not None
        ):
            self.msg.no_overridable_method(defn.name, defn)
        self.check_explicit_override_decorator(defn, found_method_base_classes, defn.impl)
        self.check_inplace_operator_method(defn)
    return None
def extract_callable_type(self, inner_type: Type | None, ctx: Context) -> CallableType | None:
    """Return the callable signature a caller of an overload item would see.

    A missing type or `Any` yields None without an error. A callable type is
    returned as is. For an instance, the type of its `__call__` member is
    returned when that is a callable. In every remaining case a "not callable"
    error is reported at `ctx` and None is returned.
    """
    proper = get_proper_type(inner_type)
    if proper is None or isinstance(proper, AnyType):
        return None

    if isinstance(proper, CallableType):
        return proper

    if isinstance(proper, Instance):
        call_member = get_proper_type(
            analyze_member_access(
                name="__call__",
                typ=proper,
                context=ctx,
                is_lvalue=False,
                is_super=False,
                is_operator=True,
                msg=self.msg,
                original_type=proper,
                chk=self,
            )
        )
        if isinstance(call_member, CallableType):
            return call_member

    self.msg.not_callable(proper, ctx)
    return None
def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
    """Report problems between overload items, and between items and the implementation.

    For every pair of items (earlier, later) with overlapping argument counts,
    report a later item that can never match, or an unsafe overlap between the
    two signatures (skipped for `__get__` defined in a class). Each item is
    also compared against the implementation signature, when one is available.
    """
    # At this point we should have set the impl already, and all remaining
    # items are decorators

    if self.msg.errors.file in self.msg.errors.ignored_files:
        # This is a little hacky, however, the quadratic check here is really expensive, this
        # method has no side effects, so we should skip it if we aren't going to report
        # anything. In some other places we swallow errors in stubs, but this error is very
        # useful for stubs!
        return

    # Compute some info about the implementation (if it exists) for use below
    impl_type: CallableType | None = None
    if defn.impl:
        if isinstance(defn.impl, FuncDef):
            inner_type: Type | None = defn.impl.type
        elif isinstance(defn.impl, Decorator):
            inner_type = defn.impl.var.type
        else:
            assert False, "Impl isn't the right type"

        # This can happen if we've got an overload with a different
        # decorator or if the implementation is untyped -- we gave up on the types.
        impl_type = self.extract_callable_type(inner_type, defn.impl)

    is_descriptor_get = defn.info and defn.name == "__get__"
    for i, item in enumerate(defn.items):
        assert isinstance(item, Decorator)
        sig1 = self.extract_callable_type(item.var.type, item)
        if sig1 is None:
            continue

        # Compare this item with every later item. `j` counts from 0 within the
        # slice, so the later item's 1-based position is i + j + 2.
        for j, item2 in enumerate(defn.items[i + 1 :]):
            assert isinstance(item2, Decorator)
            sig2 = self.extract_callable_type(item2.var.type, item2)
            if sig2 is None:
                continue

            if not are_argument_counts_overlapping(sig1, sig2):
                continue

            if overload_can_never_match(sig1, sig2):
                self.msg.overloaded_signature_will_never_match(i + 1, i + j + 2, item2.func)
            elif not is_descriptor_get:
                # Note: we force mypy to check overload signatures in strict-optional mode
                # so we don't incorrectly report errors when a user tries typing an overload
                # that happens to have a 'if the argument is None' fallback.
                #
                # For example, the following is fine in strict-optional mode but would throw
                # the unsafe overlap error when strict-optional is disabled:
                #
                #     @overload
                #     def foo(x: None) -> int: ...
                #     @overload
                #     def foo(x: str) -> str: ...
                #
                # See Python 2's map function for a concrete example of this kind of overload.
                current_class = self.scope.active_class()
                type_vars = current_class.defn.type_vars if current_class else []
                with state.strict_optional_set(True):
                    if is_unsafe_overlapping_overload_signatures(sig1, sig2, type_vars):
                        self.msg.overloaded_signatures_overlap(i + 1, i + j + 2, item.func)

        if impl_type is not None:
            assert defn.impl is not None

            # We perform a unification step that's very similar to what
            # 'is_callable_compatible' would have done if we had set
            # 'unify_generics' to True -- the only difference is that
            # we check and see if the impl_type's return value is a
            # *supertype* of the overload alternative, not a *subtype*.
            #
            # This is to match the direction the implementation's return
            # needs to be compatible in.
            if impl_type.variables:
                impl: CallableType | None = unify_generic_callable(
                    # Normalize both before unifying
                    impl_type.with_unpacked_kwargs(),
                    sig1.with_unpacked_kwargs(),
                    ignore_return=False,
                    return_constraint_direction=SUPERTYPE_OF,
                )
                if impl is None:
                    self.msg.overloaded_signatures_typevar_specific(i + 1, defn.impl)
                    continue
            else:
                impl = impl_type

            # Prevent extra noise from inconsistent use of @classmethod by copying
            # the first arg from the method being checked against.
            if sig1.arg_types and defn.info:
                impl = impl.copy_modified(arg_types=[sig1.arg_types[0]] + impl.arg_types[1:])

            # Are the overload alternative's arguments subtypes of the implementation's?
            if not is_callable_compatible(
                impl, sig1, is_compat=is_subtype, ignore_return=True
            ):
                self.msg.overloaded_signatures_arg_specific(i + 1, defn.impl)

            # Are the return types compatible? The check passes when either return
            # type is a subtype of the other; an error is reported only if neither is.
            if not (
                is_subtype(sig1.ret_type, impl.ret_type)
                or is_subtype(impl.ret_type, sig1.ret_type)
            ):
                self.msg.overloaded_signatures_ret_specific(i + 1, defn.impl)
# Here's the scoop about generators and coroutines.
#
# There are two kinds of generators: classic generators (functions
# with `yield` or `yield from` in the body) and coroutines
# (functions declared with `async def`). The latter are specified
# in PEP 492 and only available in Python >= 3.5.
#
# Classic generators can be parameterized with three types:
# - ty is the Yield type (the type of y in `yield y`)
# - tc is the type reCeived by yield (the type of c in `c = yield`).
# - tr is the Return type (the type of r in `return r`)
#
# A classic generator must define a return type that's either
# `Generator[ty, tc, tr]`, Iterator[ty], or Iterable[ty] (or
# object or Any). If tc/tr are not given, both are None.
#
# A coroutine must define a return type corresponding to tr; the
# other two are unconstrained. The "external" return type (seen
# by the caller) is Awaitable[tr].
#
# In addition, there's the synthetic type AwaitableGenerator: it
# inherits from both Awaitable and Generator and can be used both
# in `yield from` and in `await`. This type is set automatically
# for functions decorated with `@types.coroutine` or
# `@asyncio.coroutine`. Its single parameter corresponds to tr.
#
# PEP 525 adds a new type, the asynchronous generator, which was
# first released in Python 3.6. Async generators are `async def`
# functions that can also `yield` values. They can be parameterized
# with two types, ty and tc, because they cannot return a value.
#
# There are several useful methods, each taking a type t and a
# flag c indicating whether it's for a generator or coroutine:
#
# - is_generator_return_type(t, c) returns whether t is a Generator,
# Iterator, Iterable (if not c), or Awaitable (if c), or
# AwaitableGenerator (regardless of c).
# - is_async_generator_return_type(t) returns whether t is an
# AsyncGenerator.
# - get_generator_yield_type(t, c) returns ty.
# - get_generator_receive_type(t, c) returns tc.
# - get_generator_return_type(t, c) returns tr.
def is_generator_return_type(self, typ: Type, is_coroutine: bool) -> bool:
    """Tell whether `typ` is acceptable as a generator/coroutine return type.

    True if `typ` is a *supertype* of Awaitable (for coroutines) or of
    Generator (otherwise). Also true if it's *exactly* AwaitableGenerator
    (modulo type parameters).
    """
    proper = get_proper_type(typ)
    wildcard = AnyType(TypeOfAny.special_form)
    if is_coroutine:
        # Coroutines are compared against Awaitable[Any].
        reference = self.named_generic_type("typing.Awaitable", [wildcard])
    else:
        # Classic generators are compared against Generator[Any, Any, Any].
        reference = self.named_generic_type("typing.Generator", [wildcard, wildcard, wildcard])
    if is_subtype(reference, proper):
        return True
    return isinstance(proper, Instance) and proper.type.fullname == "typing.AwaitableGenerator"
def is_async_generator_return_type(self, typ: Type) -> bool:
    """Tell whether `typ` is acceptable as an async generator return type.

    True if `typ` is a supertype of AsyncGenerator. False when the
    AsyncGenerator type cannot be looked up.
    """
    try:
        wildcard = AnyType(TypeOfAny.special_form)
        async_gen = self.named_generic_type("typing.AsyncGenerator", [wildcard, wildcard])
    except KeyError:
        # we're running on a version of typing that doesn't have AsyncGenerator yet
        return False
    else:
        return is_subtype(async_gen, typ)
def get_generator_yield_type(self, return_type: Type, is_coroutine: bool) -> Type:
    """Compute the yield type (ty) from a generator's declared return type (t)."""
    proper = get_proper_type(return_type)

    if isinstance(proper, AnyType):
        return AnyType(TypeOfAny.from_another_any, source_any=proper)

    if isinstance(proper, UnionType):
        # Resolve each union member separately and combine the results.
        member_yields = [
            self.get_generator_yield_type(member, is_coroutine) for member in proper.items
        ]
        return make_simplified_union(member_yields)

    generator_like = self.is_generator_return_type(
        proper, is_coroutine
    ) or self.is_async_generator_return_type(proper)
    if not generator_like or not isinstance(proper, Instance):
        # Without a proper Generator (or Awaitable) return type, anything is
        # permissible. The isinstance test also lets the typechecker narrow.
        return AnyType(TypeOfAny.from_error)

    if proper.type.fullname == "typing.Awaitable":
        # Awaitable: ty is Any.
        return AnyType(TypeOfAny.special_form)

    if not proper.args:
        # The declared supertype of Generator has no type parameters (i.e. is
        # `object`), so the yielded values can't be accessed and any type is
        # acceptable. IOW, ty is Any.
        # (However, see https://github.com/python/mypy/issues/1933)
        return AnyType(TypeOfAny.special_form)

    # AwaitableGenerator, Generator, AsyncGenerator, Iterator, or Iterable; ty is args[0].
    # TODO not best fix, better have dedicated yield token
    return proper.args[0]
def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> Type:
    """Compute the type received by `yield` (tc) from a declared generator return type (t)."""
    proper = get_proper_type(return_type)

    if isinstance(proper, AnyType):
        return AnyType(TypeOfAny.from_another_any, source_any=proper)

    if isinstance(proper, UnionType):
        # Resolve each union member separately and combine the results.
        member_receives = [
            self.get_generator_receive_type(member, is_coroutine) for member in proper.items
        ]
        return make_simplified_union(member_receives)

    generator_like = self.is_generator_return_type(
        proper, is_coroutine
    ) or self.is_async_generator_return_type(proper)
    if not generator_like or not isinstance(proper, Instance):
        # Without a proper Generator (or Awaitable) return type, anything is
        # permissible. The isinstance test also lets the typechecker narrow.
        return AnyType(TypeOfAny.from_error)

    fullname = proper.type.fullname
    type_args = proper.args

    if fullname == "typing.Awaitable":
        # Awaitable: tc is Any.
        return AnyType(TypeOfAny.special_form)

    if fullname in ("typing.Generator", "typing.AwaitableGenerator") and len(type_args) >= 3:
        # Generator, AwaitableGenerator: tc is args[1].
        return type_args[1]

    if fullname == "typing.AsyncGenerator" and len(type_args) >= 2:
        # AsyncGenerator: tc is args[1].
        return type_args[1]

    # `return_type` is a supertype of Generator, so callers won't be able to send it
    # values.  IOW, tc is None.
    return NoneType()
def get_coroutine_return_type(self, return_type: Type) -> Type:
    """Extract the return type from a coroutine's declared type.

    `Any` is propagated. Otherwise the type must be an Instance, and its
    third type argument is returned.
    """
    proper = get_proper_type(return_type)
    if not isinstance(proper, AnyType):
        assert isinstance(proper, Instance), "Should only be called on coroutine functions."
        # Note: return type is the 3rd type parameter of Coroutine.
        return proper.args[2]
    return AnyType(TypeOfAny.from_another_any, source_any=proper)
def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Type:
    """Return the type a generator returns (tr), given its declared return type (t).

    Args:
        return_type: The declared return type of the generator function.
        is_coroutine: Forwarded to `is_generator_return_type`.

    Returns:
        The generator's return value type, or `Any` when it cannot be determined
        from the declared type.
    """
    declared = get_proper_type(return_type)

    if isinstance(declared, AnyType):
        return AnyType(TypeOfAny.from_another_any, source_any=declared)

    if isinstance(declared, UnionType):
        # Compute the return type of every union member and recombine them.
        member_return_types = [
            self.get_generator_return_type(member, is_coroutine) for member in declared.items
        ]
        return make_simplified_union(member_return_types)

    if not self.is_generator_return_type(declared, is_coroutine):
        # If the function doesn't have a proper Generator (or
        # Awaitable) return type, anything is permissible.
        return AnyType(TypeOfAny.from_error)

    if not isinstance(declared, Instance):
        # Same as above, but written as a separate check so the typechecker can understand.
        return AnyType(TypeOfAny.from_error)

    fullname = declared.type.fullname

    if fullname == "typing.Awaitable" and len(declared.args) == 1:
        # Awaitable: tr is args[0].
        return declared.args[0]

    if fullname in ("typing.Generator", "typing.AwaitableGenerator") and len(declared.args) >= 3:
        # AwaitableGenerator, Generator: tr is args[2].
        return declared.args[2]

    # Supertype of Generator (Iterator, Iterable, object): tr is any.
    return AnyType(TypeOfAny.special_form)
def visit_func_def(self, defn: FuncDef) -> None:
    """Type check `defn` inside its own function scope.

    Does nothing when `self.recurse_into_functions` is false.
    """
    if self.recurse_into_functions:
        with self.tscope.function_scope(defn):
            self._visit_func_def(defn)
def _visit_func_def(self, defn: FuncDef) -> None:
"""Type check a function definition.

The function item itself is checked first. For definitions that have `defn.info`
set, override-related checks follow. Finally, if `defn.original_def` is set (this
definition replaces an earlier one), the new type is checked against the earlier
definition.
"""
self.check_func_item(defn, name=defn.name)
if defn.info:
if not defn.is_dynamic() and not defn.is_overload and not defn.is_decorated:
# If the definition is the implementation for an
# overload, the legality of the override has already
# been typechecked, and decorated methods will be
# checked when the decorator is.
found_method_base_classes = self.check_method_override(defn)
self.check_explicit_override_decorator(defn, found_method_base_classes)
self.check_inplace_operator_method(defn)
if defn.original_def:
# Override previous definition.
new_type = self.function_type(defn)
if isinstance(defn.original_def, FuncDef):
# Function definition overrides function definition.
# The two signatures must be the same type.
old_type = self.function_type(defn.original_def)
if not is_same_type(new_type, old_type):
self.msg.incompatible_conditional_function_def(defn, old_type, new_type)
else:
# Function definition overrides a variable initialized via assignment or a
# decorated function.
orig_type = defn.original_def.type
if orig_type is None:
# If other branch is unreachable, we don't type check it and so we might
# not have a type for the original definition
return
if isinstance(orig_type, PartialType):
if orig_type.type is None:
# Ah this is a partial type. Give it the type of the function.
orig_def = defn.original_def
# For a decorated function, the variable to update is the decorator's `var`.
if isinstance(orig_def, Decorator):
var = orig_def.var
else:
var = orig_def
partial_types = self.find_partial_types(var)
if partial_types is not None:
# The variable now has a complete type, so drop it from the partial types.
var.type = new_type
del partial_types[var]
else:
# Trying to redefine something like partial empty list as function.
self.fail(message_registry.INCOMPATIBLE_REDEFINITION, defn)
else:
# Record the new type in the binder under a synthetic name expression that
# refers to the original definition, then require the new type to be a
# subtype of the original one.
name_expr = NameExpr(defn.name)
name_expr.node = defn.original_def
self.binder.assign_type(name_expr, new_type, orig_type)
self.check_subtype(
new_type,
orig_type,
defn,
message_registry.INCOMPATIBLE_REDEFINITION,
"redefinition with type",
"original type",
)
def check_func_item(
    self,
    defn: FuncItem,
    type_override: CallableType | None = None,
    name: str | None = None,
    allow_empty: bool = False,
) -> None:
    """Type check a function.

    If type_override is provided, use it as the function type (keeping the line and
    column of the function's own type).

    Args:
        defn: The function item to check.
        type_override: Optional callable type to check the body against.
        name: The function name, if any; selects extra checks for `__exit__` and
            `__post_init__`.
        allow_empty: Forwarded to `check_func_def`.

    Raises:
        RuntimeError: If the function type is not a `CallableType`.
    """
    treat_as_dynamic = defn.is_dynamic() and not type_override
    self.dynamic_funcs.append(treat_as_dynamic)

    with self.enter_partial_types(is_function=True):
        func_type = self.function_type(defn)
        if type_override:
            func_type = type_override.copy_modified(
                line=func_type.line, column=func_type.column
            )
        if not isinstance(func_type, CallableType):
            raise RuntimeError("Not supported")
        with self.enter_attribute_inference_context():
            self.check_func_def(defn, func_type, name, allow_empty)

    self.dynamic_funcs.pop()
    self.current_node_deferred = False

    if name == "__exit__":
        self.check__exit__return_type(defn)
    elif name == "__post_init__" and dataclasses_plugin.is_processed_dataclass(defn.info):
        dataclasses_plugin.check_post_init(self, defn, defn.info)
@contextmanager
def enter_attribute_inference_context(self) -> Iterator[None]:
    """Temporarily replace `self.inferred_attribute_types` with a fresh empty dict.

    The previous mapping is put back when the `with` block exits, including when
    the block raises, so an exception cannot leave the checker pointing at the
    temporary mapping.
    """
    saved_types = self.inferred_attribute_types
    self.inferred_attribute_types = {}
    try:
        yield None
    finally:
        # Restore on both normal and exceptional exit.
        self.inferred_attribute_types = saved_types
def check_func_def(
self, defn: FuncItem, typ: CallableType, name: str | None, allow_empty: bool = False
) -> None:
"""Type check a function definition.

The signature is first expanded over type variables with value restrictions
(see `expand_typevars`); each resulting (item, type) pair is then checked with
a fresh binder. For each pair this validates the signature (special methods,
return type variance, generator return types, explicit self type), the default
argument values, and the body, and finally reports a missing or incompatible
implicit return if the end of the body is reachable.

Args:
defn: The function item being checked.
typ: The callable type to check the function against.
name: The function name, if any; used to select special-method checks.
allow_empty: If true (or if `options.allow_empty_bodies` is set), trivial
bodies are not reported as errors.
"""
# Expand type variables with value restrictions to ordinary types.
expanded = self.expand_typevars(defn, typ)
original_typ = typ
for item, typ in expanded:
# Use a fresh binder for this item; the previous binder is saved here and
# put back into `self.binder` at the bottom of this method.
old_binder = self.binder
self.binder = ConditionalTypeBinder()
with self.binder.top_frame_context():
defn.expanded.append(item)
# We may be checking a function definition or an anonymous
# function. In the first case, set up another reference with the
# precise type.
if isinstance(item, FuncDef):
fdef = item
# Check if __init__ has an invalid return type.
if (
fdef.info
and fdef.name in ("__init__", "__init_subclass__")
and not isinstance(
get_proper_type(typ.ret_type), (NoneType, UninhabitedType)
)
and not self.dynamic_funcs[-1]
):
self.fail(
message_registry.MUST_HAVE_NONE_RETURN_TYPE.format(fdef.name), item
)
# Check validity of __new__ signature
if fdef.info and fdef.name == "__new__":
self.check___new___signature(fdef, typ)
self.check_for_missing_annotations(fdef)
if self.options.disallow_any_unimported:
if fdef.type and isinstance(fdef.type, CallableType):
ret_type = fdef.type.ret_type
if has_any_from_unimported_type(ret_type):
self.msg.unimported_type_becomes_any("Return type", ret_type, fdef)
for idx, arg_type in enumerate(fdef.type.arg_types):
if has_any_from_unimported_type(arg_type):
prefix = f'Argument {idx + 1} to "{fdef.name}"'
self.msg.unimported_type_becomes_any(prefix, arg_type, fdef)
check_for_explicit_any(
fdef.type, self.options, self.is_typeshed_stub, self.msg, context=fdef
)
if name: # Special method names
if defn.info and self.is_reverse_op_method(name):
self.check_reverse_op_method(item, typ, name, defn)
elif name in ("__getattr__", "__getattribute__"):
self.check_getattr_method(typ, defn, name)
elif name == "__setattr__":
self.check_setattr_method(typ, defn)
# Refuse contravariant return type variable
if isinstance(typ.ret_type, TypeVarType):
if typ.ret_type.variance == CONTRAVARIANT:
self.fail(
message_registry.RETURN_TYPE_CANNOT_BE_CONTRAVARIANT, typ.ret_type
)
self.check_unbound_return_typevar(typ)
elif (
isinstance(original_typ.ret_type, TypeVarType) and original_typ.ret_type.values
):
# Since type vars with values are expanded, the return type is changed
# to a raw value. This is a hack to get it back.
self.check_unbound_return_typevar(original_typ)
# Check that Generator functions have the appropriate return type.
if defn.is_generator:
if defn.is_async_generator:
if not self.is_async_generator_return_type(typ.ret_type):
self.fail(
message_registry.INVALID_RETURN_TYPE_FOR_ASYNC_GENERATOR, typ
)
else:
if not self.is_generator_return_type(typ.ret_type, defn.is_coroutine):
self.fail(message_registry.INVALID_RETURN_TYPE_FOR_GENERATOR, typ)
# Fix the type if decorated with `@types.coroutine` or `@asyncio.coroutine`.
if defn.is_awaitable_coroutine:
# Update the return type to AwaitableGenerator.
# (This doesn't exist in typing.py, only in typing.pyi.)
# Below: t is the declared return type, c the coroutine flag, and
# ty/tc/tr the yield, receive and return types derived from t.
t = typ.ret_type
c = defn.is_coroutine
ty = self.get_generator_yield_type(t, c)
tc = self.get_generator_receive_type(t, c)
if c:
tr = self.get_coroutine_return_type(t)
else:
tr = self.get_generator_return_type(t, c)
ret_type = self.named_generic_type(
"typing.AwaitableGenerator", [ty, tc, tr, t]
)
typ = typ.copy_modified(ret_type=ret_type)
defn.type = typ
# Push return type.
self.return_types.append(typ.ret_type)
# Store argument types.
for i in range(len(typ.arg_types)):
arg_type = typ.arg_types[i]
with self.scope.push_function(defn):
# We temporarily push the definition to get the self type as
# visible from *inside* of this function/method.
ref_type: Type | None = self.scope.active_self_type()
# Only the first argument of a non-static method (or of __new__) is
# checked as an explicit self/cls type, and only if it is not *args/**kwargs.
if (
isinstance(defn, FuncDef)
and ref_type is not None
and i == 0
and (not defn.is_static or defn.name == "__new__")
and typ.arg_kinds[0] not in [nodes.ARG_STAR, nodes.ARG_STAR2]
):
if defn.is_class or defn.name == "__new__":
ref_type = mypy.types.TypeType.make_normalized(ref_type)
erased = get_proper_type(erase_to_bound(arg_type))
if not is_subtype(ref_type, erased, ignore_type_params=True):
if (
isinstance(erased, Instance)
and erased.type.is_protocol
or isinstance(erased, TypeType)
and isinstance(erased.item, Instance)
and erased.item.type.is_protocol
):
# We allow the explicit self-type to be not a supertype of
# the current class if it is a protocol. For such cases
# the consistency check will be performed at call sites.
msg = None
elif typ.arg_names[i] in {"self", "cls"}:
msg = message_registry.ERASED_SELF_TYPE_NOT_SUPERTYPE.format(
erased.str_with_options(self.options),
ref_type.str_with_options(self.options),
)
else:
msg = message_registry.MISSING_OR_INVALID_SELF_TYPE
if msg:
self.fail(msg, defn)
elif isinstance(arg_type, TypeVarType):
# Refuse covariant parameter type variables
# TODO: check recursively for inner type variables
if (
arg_type.variance == COVARIANT
and defn.name not in ("__init__", "__new__", "__post_init__")
and not is_private(defn.name) # private methods are not inherited
):
# Report at the argument type, or at the whole function type when
# the argument type has a negative line number.
ctx: Context = arg_type
if ctx.line < 0:
ctx = typ
self.fail(message_registry.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, ctx)
# Need to store arguments again for the expanded item.
store_argument_type(item, i, typ, self.named_generic_type)
# Type check initialization expressions.
body_is_trivial = is_trivial_body(defn.body)
self.check_default_args(item, body_is_trivial)
# Type check body in a new scope.
with self.binder.top_frame_context():
# Copy some type narrowings from an outer function when it seems safe enough
# (i.e. we can't find an assignment that might change the type of the
# variable afterwards).
new_frame: Frame | None = None
for frame in old_binder.frames:
for key, narrowed_type in frame.types.items():
key_var = extract_var_from_literal_hash(key)
if key_var is not None and not self.is_var_redefined_in_outer_context(
key_var, defn.line
):
# It seems safe to propagate the type narrowing to a nested scope.
# The extra frame is pushed lazily, on the first propagated narrowing.
if new_frame is None:
new_frame = self.binder.push_frame()
new_frame.types[key] = narrowed_type
self.binder.declarations[key] = old_binder.declarations[key]
with self.scope.push_function(defn):
# We suppress reachability warnings for empty generator functions
# (return; yield) which have a "yield" that's unreachable by definition
# since it's only there to promote the function into a generator function.
#
# We also suppress reachability warnings when we use TypeVars with value
# restrictions: we only want to report a warning if a certain statement is
# marked as being suppressed in *all* of the expansions, but we currently
# have no good way of doing this.
#
# TODO: Find a way of working around this limitation
if _is_empty_generator_function(item) or len(expanded) >= 2:
self.binder.suppress_unreachable_warnings()
self.accept(item.body)
unreachable = self.binder.is_unreachable()
if new_frame is not None:
# Pop the frame that was pushed above for the propagated narrowings.
self.binder.pop_frame(True, 0)
if not unreachable:
# The end of the body is reachable, i.e. the function can implicitly
# return None. Work out the type that a plain `return` must satisfy.
if defn.is_generator or is_named_instance(
self.return_types[-1], "typing.AwaitableGenerator"
):
return_type = self.get_generator_return_type(
self.return_types[-1], defn.is_coroutine
)
elif defn.is_coroutine:
return_type = self.get_coroutine_return_type(self.return_types[-1])
else:
return_type = self.return_types[-1]
return_type = get_proper_type(return_type)
allow_empty = allow_empty or self.options.allow_empty_bodies
show_error = (
not body_is_trivial
or
# Allow empty bodies for abstract methods, overloads, in tests and stubs.
(
not allow_empty
and not (
isinstance(defn, FuncDef) and defn.abstract_status != NOT_ABSTRACT
)
and not self.is_stub
)
)
# Ignore plugin generated methods, these usually don't need any bodies.
if defn.info is not FUNC_NO_INFO and (
defn.name not in defn.info.names or defn.info.names[defn.name].plugin_generated
):
show_error = False
# Ignore also definitions that appear in `if TYPE_CHECKING: ...` blocks.
# These can't be called at runtime anyway (similar to plugin-generated).
if isinstance(defn, FuncDef) and defn.is_mypy_only:
show_error = False
# We want to minimize the fallout from checking empty bodies
# that was absent in many mypy versions.
if body_is_trivial and is_subtype(NoneType(), return_type):
show_error = False
# A trivial body in a class whose metaclass derives from abc.ABCMeta gets an
# extra note (EMPTY_BODY_ABSTRACT) alongside the error.
may_be_abstract = (
body_is_trivial
and defn.info is not FUNC_NO_INFO
and defn.info.metaclass_type is not None
and defn.info.metaclass_type.type.has_base("abc.ABCMeta")
)
if self.options.warn_no_return:
if (
not self.current_node_deferred
and not isinstance(return_type, (NoneType, AnyType))
and show_error
):
# Control flow fell off the end of a function that was
# declared to return a non-None type.
if isinstance(return_type, UninhabitedType):
# This is a NoReturn function
msg = message_registry.INVALID_IMPLICIT_RETURN
else:
msg = message_registry.MISSING_RETURN_STATEMENT
# Trivial bodies are reported under the EMPTY_BODY error code instead.
if body_is_trivial:
msg = msg._replace(code=codes.EMPTY_BODY)
self.fail(msg, defn)
if may_be_abstract:
self.note(message_registry.EMPTY_BODY_ABSTRACT, defn)
elif show_error:
msg = message_registry.INCOMPATIBLE_RETURN_VALUE_TYPE
if body_is_trivial:
msg = msg._replace(code=codes.EMPTY_BODY)
# similar to code in check_return_stmt
if (
not self.check_subtype(
subtype_label="implicitly returns",
subtype=NoneType(),
supertype_label="expected",
supertype=return_type,
context=defn,
msg=msg,
)
and may_be_abstract
):
self.note(message_registry.EMPTY_BODY_ABSTRACT, defn)
# Undo the return type push and restore the binder saved at the top.
self.return_types.pop()
self.binder = old_binder
def is_var_redefined_in_outer_context(self, v: Var, after_line: int) -> bool:
    """Can the variable be assigned to at module top level or outer function?

    Note that this doesn't do a full CFG analysis but uses a line number based
    heuristic that isn't correct in some (rare) cases.

    Returns True when there is no enclosing function, or when some enclosing
    `FuncDef` has its last assignment to `v` at or after `after_line`.
    """
    enclosing = self.tscope.outer_functions()
    if not enclosing:
        # Top-level function -- outer context is top level, and we can't reason about
        # globals
        return True
    return any(
        isinstance(func, FuncDef)
        and find_last_var_assignment_line(func.body, v) >= after_line
        for func in enclosing
    )
def check_unbound_return_typevar(self, typ: CallableType) -> None:
    """Fails when the return typevar is not defined in arguments.

    Only applies when the return type is a type variable listed in `typ.variables`.
    If its upper bound is not `builtins.object`, a note suggesting the upper bound
    is also emitted.
    """
    ret = typ.ret_type
    if not isinstance(ret, TypeVarType) or ret not in typ.variables:
        return

    # Collect the type variables that occur anywhere in the argument types.
    collector = CollectArgTypeVarTypes()
    for arg_type in typ.arg_types:
        arg_type.accept(collector)
    if ret in collector.arg_types:
        return

    self.fail(message_registry.UNBOUND_TYPEVAR, ret, code=TYPE_VAR)
    bound = get_proper_type(ret.upper_bound)
    bound_is_object = isinstance(bound, Instance) and bound.type.fullname == "builtins.object"
    if not bound_is_object:
        self.note(
            "Consider using the upper bound "
            f"{format_type(ret.upper_bound, self.options)} instead",
            context=ret,
        )
def check_default_args(self, item: FuncItem, body_is_trivial: bool) -> None:
    """Check each default argument value of `item` against its declared argument type.

    Arguments without a default are skipped, and so are `...` defaults when the
    function body is trivial.
    """
    for arg in item.arguments:
        default = arg.initializer
        if default is None:
            continue
        if body_is_trivial and isinstance(default, EllipsisExpr):
            continue

        name = arg.variable.name
        if name.startswith("__tuple_arg_"):
            # Drop the 12-character "__tuple_arg_" prefix from the reported name.
            target = f"tuple argument {name[12:]}"
        else:
            target = f'argument "{name}"'
        msg = "Incompatible default for " + target

        # A literal `None` default with implicit Optional disabled gets migration notes.
        is_implicit_none = (
            not self.options.implicit_optional
            and isinstance(default, NameExpr)
            and default.fullname == "builtins.None"
        )
        notes = (
            [
                "PEP 484 prohibits implicit Optional. "
                "Accordingly, mypy has changed its default to no_implicit_optional=True",
                "Use https://github.com/hauntsaninja/no_implicit_optional to automatically "
                "upgrade your codebase",
            ]
            if is_implicit_none
            else None
        )

        self.check_simple_assignment(
            arg.variable.type,
            default,
            context=default,
            msg=ErrorMessage(msg, code=codes.ASSIGNMENT),
            lvalue_name="argument",
            rvalue_name="default",
            notes=notes,
        )
def is_forward_op_method(self, method_name: str) -> bool:
"""Return whether `method_name` is contained in `operators.reverse_op_methods`."""
return method_name in operators.reverse_op_methods
def is_reverse_op_method(self, method_name: str) -> bool:
"""Return whether `method_name` is contained in `operators.reverse_op_method_set`."""
return method_name in operators.reverse_op_method_set
def check_for_missing_annotations(self, fdef: FuncItem) -> None:
"""Report functions whose signature is missing annotations, as configured.

Errors are only emitted when `disallow_untyped_defs` is set, or when
`disallow_incomplete_defs` is set and the function has at least one explicit
annotation. In typeshed stubs they additionally require `warn_incomplete_stub`.
"""
# Check for functions with unspecified/not fully specified types.
def is_unannotated_any(t: Type) -> bool:
# True only for a proper `Any` type whose kind is "unannotated".
if not isinstance(t, ProperType):
return False
return isinstance(t, AnyType) and t.type_of_any == TypeOfAny.unannotated
# True if at least one argument type or the return type was annotated explicitly.
has_explicit_annotation = isinstance(fdef.type, CallableType) and any(
not is_unannotated_any(t) for t in fdef.type.arg_types + [fdef.type.ret_type]
)
show_untyped = not self.is_typeshed_stub or self.options.warn_incomplete_stub
check_incomplete_defs = self.options.disallow_incomplete_defs and has_explicit_annotation
if show_untyped and (self.options.disallow_untyped_defs or check_incomplete_defs):
if fdef.type is None and self.options.disallow_untyped_defs:
# Completely unannotated function. With no arguments, or with only a single
# argument named "self"/"cls", only a return type is asked for.
if not fdef.arguments or (
len(fdef.arguments) == 1
and (fdef.arg_names[0] == "self" or fdef.arg_names[0] == "cls")
):
self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef)
if not has_return_statement(fdef) and not fdef.is_generator:
self.note(
'Use "-> None" if function does not return a value',
fdef,
code=codes.NO_UNTYPED_DEF,
)
else:
self.fail(message_registry.FUNCTION_TYPE_EXPECTED, fdef)
elif isinstance(fdef.type, CallableType):
# Partially annotated function: check the return type (looking inside
# generator and coroutine return types), then the argument types.
ret_type = get_proper_type(fdef.type.ret_type)
if is_unannotated_any(ret_type):
self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef)
elif fdef.is_generator:
if is_unannotated_any(
self.get_generator_return_type(ret_type, fdef.is_coroutine)
):
self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef)
elif fdef.is_coroutine and isinstance(ret_type, Instance):
if is_unannotated_any(self.get_coroutine_return_type(ret_type)):
self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef)
if any(is_unannotated_any(t) for t in fdef.type.arg_types):
self.fail(message_registry.ARGUMENT_TYPE_EXPECTED, fdef)
def check___new___signature(self, fdef: FuncDef, typ: CallableType) -> None:
    """Check the return type of a `__new__` method after binding `cls`.

    For a metaclass the return type must be a subtype of `self.type_type()`.
    Otherwise it must be an instance-like type (or Any) that is a subtype of the
    class itself.
    """
    self_type = fill_typevars_with_any(fdef.info)
    bound_type = bind_self(typ, self_type, is_classmethod=True)
    returned = bound_type.ret_type

    # Check that __new__ (after binding cls) returns an instance
    # type (or any).
    if fdef.info.is_metaclass():
        # This is a metaclass, so it must return a new unrelated type.
        required = self.type_type()
    elif not isinstance(
        get_proper_type(returned), (AnyType, Instance, TupleType, UninhabitedType)
    ):
        self.fail(
            message_registry.NON_INSTANCE_NEW_TYPE.format(format_type(returned, self.options)),
            fdef,
        )
        return
    else:
        # And that it returns a subtype of the class
        required = self_type

    self.check_subtype(
        returned,
        required,
        fdef,
        message_registry.INVALID_NEW_TYPE,
        "returns",
        "but must return a subtype of",
    )
def check_reverse_op_method(
self, defn: FuncItem, reverse_type: CallableType, reverse_name: str, context: Context
) -> None:
"""Check a reverse operator method such as __radd__."""
# This used to check for some very obscure scenario. It now
# just validates the signature and decides whether it's worth
# calling check_overlapping_op_methods().
assert defn.info
# First check for a valid signature
# (it must be compatible with a callable taking two positional arguments).
method_type = CallableType(
[AnyType(TypeOfAny.special_form), AnyType(TypeOfAny.special_form)],
[nodes.ARG_POS, nodes.ARG_POS],
[None, None],
AnyType(TypeOfAny.special_form),
self.named_type("builtins.function"),
)
if not is_subtype(reverse_type, method_type):
self.msg.invalid_signature(reverse_type, context)
return
if reverse_name in ("__eq__", "__ne__"):
# These are defined for all objects => can't cause trouble.
return
# With 'Any' or 'object' return type we are happy, since any possible
# return value is valid.
ret_type = get_proper_type(reverse_type.ret_type)
if isinstance(ret_type, AnyType):
return
if isinstance(ret_type, Instance):
if ret_type.type.fullname == "builtins.object":
return
# If the first parameter is *args, model the method as taking two positional
# arguments that both have the star argument's type.
if reverse_type.arg_kinds[0] == ARG_STAR:
reverse_type = reverse_type.copy_modified(
arg_types=[reverse_type.arg_types[0]] * 2,
arg_kinds=[ARG_POS] * 2,
arg_names=[reverse_type.arg_names[0], "_"],
)
assert len(reverse_type.arg_types) >= 2
forward_name = operators.normal_from_reverse_op[reverse_name]
# The second argument of the reverse method is the left operand, i.e. the type
# the forward method is looked up on. Reduce it to something with members:
# upper bound for type variables, fallback for tuples/callables/TypedDicts/literals.
forward_inst = get_proper_type(reverse_type.arg_types[1])
if isinstance(forward_inst, TypeVarType):
forward_inst = get_proper_type(forward_inst.upper_bound)
elif isinstance(forward_inst, TupleType):
forward_inst = tuple_fallback(forward_inst)
elif isinstance(forward_inst, (FunctionLike, TypedDictType, LiteralType)):
forward_inst = forward_inst.fallback
# For Type[C], use the metaclass of C when it has one.
if isinstance(forward_inst, TypeType):
item = forward_inst.item
if isinstance(item, Instance):
opt_meta = item.type.metaclass_type
if opt_meta is not None:
forward_inst = opt_meta
def has_readable_member(typ: UnionType | Instance, name: str) -> bool:
# For a union, every relevant item must have the member.
# TODO: Deal with attributes of TupleType etc.
if isinstance(typ, Instance):
return typ.type.has_readable_member(name)
return all(
(isinstance(x, UnionType) and has_readable_member(x, name))
or (isinstance(x, Instance) and x.type.has_readable_member(name))
for x in get_proper_types(typ.relevant_items())
)
# Nothing to compare against if the left operand type has no forward method.
if not (
isinstance(forward_inst, (Instance, UnionType))
and has_readable_member(forward_inst, forward_name)
):
return
forward_base = reverse_type.arg_types[1]
forward_type = self.expr_checker.analyze_external_member_access(
forward_name, forward_base, context=defn
)
self.check_overlapping_op_methods(
reverse_type,
reverse_name,
defn.info,
forward_type,
forward_name,
forward_base,
context=defn,
)
def check_overlapping_op_methods(
    self,
    reverse_type: CallableType,
    reverse_name: str,
    reverse_class: TypeInfo,
    forward_type: Type,
    forward_name: str,
    forward_base: Type,
    context: Context,
) -> None:
    """Check for overlapping method and reverse method signatures.

    This function assumes that:

    -   The reverse method has valid argument count and kinds.
    -   If the reverse operator method accepts some argument of type
        X, the forward operator method also belongs to class X.

        For example, if we have the reverse operator `A.__radd__(B)`, then the
        corresponding forward operator must have the type `B.__add__(...)`.
    """
    # Note: Suppose we have two operator methods "A.__rOP__(B) -> R1" and
    # "B.__OP__(C) -> R2". We check if these two methods are unsafely overlapping
    # by using the following algorithm:
    #
    # 1. Rewrite "B.__OP__(C) -> R2" to "temp1(B, C) -> R2"
    #
    # 2. Rewrite "A.__rOP__(B) -> R1" to "temp2(B, A) -> R1"
    #
    # 3. Treat temp1 and temp2 as if they were both variants in the same
    #    overloaded function. (This mirrors how the Python runtime calls
    #    operator methods: we first try __OP__, then __rOP__.)
    #
    #    If the first signature is unsafely overlapping with the second,
    #    report an error.
    #
    # 4. However, if temp1 shadows temp2 (e.g. the __rOP__ method can never
    #    be called), do NOT report an error.
    #
    #    This behavior deviates from how we handle overloads -- many of the
    #    modules in typeshed seem to define __OP__ methods that shadow the
    #    corresponding __rOP__ method.
    #
    # Note: we do not attempt to handle unsafe overlaps related to multiple
    # inheritance. (This is consistent with how we handle overloads: we also
    # do not try checking unsafe overlaps due to multiple inheritance there.)
    for candidate in flatten_nested_unions([forward_type]):
        candidate = get_proper_type(candidate)
        if isinstance(candidate, CallableType):
            signatures = [candidate]
        elif isinstance(candidate, Overloaded):
            signatures = candidate.items
        else:
            # Neither a callable nor an overload: complain unless it is Any.
            if not isinstance(candidate, AnyType):
                self.msg.forward_operator_not_callable(forward_name, context)
            continue

        for signature in signatures:
            if self.is_unsafe_overlapping_op(signature, forward_base, reverse_type):
                self.msg.operator_method_signatures_overlap(
                    reverse_class, reverse_name, forward_base, forward_name, context
                )
def is_unsafe_overlapping_op(
    self, forward_item: CallableType, forward_base: Type, reverse_type: CallableType
) -> bool:
    """Return whether a forward operator signature unsafely overlaps a reverse one.

    Args:
        forward_item: One signature of the forward operator method.
        forward_base: The type the forward method was looked up on (left operand).
        reverse_type: The signature of the reverse operator method.
    """
    # TODO: check argument kinds?
    if not forward_item.arg_types:
        # Not a valid operator method -- can't succeed anyway.
        return False

    # Erase the type if necessary to make sure we don't have a single
    # TypeVar in forward_tweaked. (Having a function signature containing
    # just a single TypeVar can lead to unpredictable behavior.)
    left_erased = forward_base
    if isinstance(forward_base, TypeVarType):
        left_erased = erase_to_bound(forward_base)

    # Construct normalized function signatures corresponding to the
    # operator methods. The first argument is the left operand and the
    # second argument is the right operand -- we switch the order of
    # the arguments of the reverse method.
    two_positional = [nodes.ARG_POS] * 2
    forward_tweaked = forward_item.copy_modified(
        arg_types=[left_erased, forward_item.arg_types[0]],
        arg_kinds=two_positional,
        arg_names=[None] * 2,
    )
    reverse_tweaked = reverse_type.copy_modified(
        arg_types=[reverse_type.arg_types[1], reverse_type.arg_types[0]],
        arg_kinds=[nodes.ARG_POS] * 2,
        arg_names=[None] * 2,
    )

    right_erased = reverse_type.arg_types[0]
    if isinstance(right_erased, TypeVarType):
        right_erased = erase_to_bound(right_erased)

    if is_same_type(right_erased, left_erased):
        return False
    # The signature belonging to the more specific operand type goes first.
    if is_subtype(right_erased, left_erased):
        first, second = reverse_tweaked, forward_tweaked
    else:
        first, second = forward_tweaked, reverse_tweaked

    current_class = self.scope.active_class()
    if current_class:
        type_vars = current_class.defn.type_vars
    else:
        type_vars = []
    return is_unsafe_overlapping_overload_signatures(first, second, type_vars)
def check_inplace_operator_method(self, defn: FuncBase) -> None:
    """Check an inplace operator method such as __iadd__.

    They cannot arbitrarily overlap with __add__.
    """
    inplace_name = defn.name
    if inplace_name not in operators.inplace_operator_methods:
        return

    inplace_type = bind_self(self.function_type(defn))
    owner = defn.info
    # "__iadd__" -> "__add__": replace the leading "__i" with "__".
    plain_name = f"__{inplace_name[3:]}"
    if not owner.has_readable_member(plain_name):
        return

    plain_type = get_proper_type(
        self.expr_checker.analyze_external_member_access(
            plain_name, fill_typevars(owner), defn
        )
    )
    # TODO overloads
    # Anything that is not function-like is reported as incompatible as well.
    incompatible = not isinstance(plain_type, FunctionLike) or not is_more_general_arg_prefix(
        inplace_type, plain_type
    )
    if incompatible:
        self.msg.signatures_incompatible(inplace_name, plain_name, defn)
def check_getattr_method(self, typ: Type, context: Context, name: str) -> None:
    """Check the signature of a `__getattr__` / `__getattribute__` definition.

    At module scope `__getattribute__` is rejected and `__getattr__` must accept a
    single `str`. Inside a class the method must accept `(self, str)`. In any other
    scope nothing is checked.
    """
    expected_args: list[Type]
    if len(self.scope.stack) == 1:
        # module scope
        if name == "__getattribute__":
            self.fail(message_registry.MODULE_LEVEL_GETATTRIBUTE, context)
            return
        # __getattr__ is fine at the module level as of Python 3.7 (PEP 562). We could
        # show an error for Python < 3.7, but that would be annoying in code that supports
        # both 3.7 and older versions.
        expected_args = [self.named_type("builtins.str")]
    elif self.scope.active_class():
        expected_args = [AnyType(TypeOfAny.special_form), self.named_type("builtins.str")]
    else:
        return

    arity = len(expected_args)
    method_type = CallableType(
        expected_args,
        [nodes.ARG_POS] * arity,
        [None] * arity,
        AnyType(TypeOfAny.special_form),
        self.named_type("builtins.function"),
    )
    if not is_subtype(typ, method_type):
        self.msg.invalid_signature_for_special_method(typ, context, name)
def check_setattr_method(self, typ: Type, context: Context) -> None:
    """Check that a `__setattr__` defined in a class accepts `(self, str, value) -> None`.

    Nothing is checked outside of a class.
    """
    if self.scope.active_class():
        expected_args: list[Type] = [
            AnyType(TypeOfAny.special_form),
            self.named_type("builtins.str"),
            AnyType(TypeOfAny.special_form),
        ]
        expected = CallableType(
            expected_args,
            [nodes.ARG_POS] * 3,
            [None] * 3,
            NoneType(),
            self.named_type("builtins.function"),
        )
        if not is_subtype(typ, expected):
            self.msg.invalid_signature_for_special_method(typ, context, "__setattr__")
def check_slots_definition(self, typ: Type, context: Context) -> None:
    """Check the type of __slots__.

    The type must be a subtype of `str | Iterable[str]`.
    """
    string = self.named_type("builtins.str")
    iterable_of_strings = self.named_generic_type("typing.Iterable", [string])
    allowed = UnionType([string, iterable_of_strings])
    self.check_subtype(
        typ,
        allowed,
        context,
        message_registry.INVALID_TYPE_FOR_SLOTS,
        "actual type",
        "expected type",
        code=codes.ASSIGNMENT,
    )
def check_match_args(self, var: Var, typ: Type, context: Context) -> None:
    """Check that __match_args__ contains literal strings.

    Emits a note (not an error) when the value is not a tuple whose items are all
    string literals. Nothing is checked outside of a class.
    """
    if not self.scope.active_class():
        return
    proper = get_proper_type(typ)
    is_literal_string_tuple = isinstance(proper, TupleType) and all(
        is_string_literal(entry) for entry in proper.items
    )
    if is_literal_string_tuple:
        return
    self.msg.note(
        "__match_args__ must be a tuple containing string literals for checking "
        "of match statements to work",
        context,
        code=codes.LITERAL_REQ,
    )
def expand_typevars(
    self, defn: FuncItem, typ: CallableType
) -> list[tuple[FuncItem, CallableType]]:
    """Expand type variables with value restrictions into concrete variants.

    Returns one `(function, type)` pair for each combination of values of the
    value-restricted type variables (from the function and, for methods, from the
    class). Returns just `[(defn, typ)]` when there are none, or when the `mypyc`
    or `inspections` option is set.
    """
    # TODO use generator
    tvars = list(typ.variables)
    if defn.info:
        # Class type variables
        tvars += defn.info.defn.type_vars or []
    # TODO(PEP612): audit for paramspec
    subst: list[list[tuple[TypeVarId, Type]]] = [
        [(tvar.id, value) for value in tvar.values]
        for tvar in tvars
        if isinstance(tvar, TypeVarType) and tvar.values
    ]
    # Make a copy of the function to check for each combination of
    # value restricted type variables. (Except when running mypyc,
    # where we need one canonical version of the function.)
    if not subst or self.options.mypyc or self.options.inspections:
        return [(defn, typ)]

    result: list[tuple[FuncItem, CallableType]] = []
    for combination in itertools.product(*subst):
        mapping = dict(combination)
        result.append((expand_func(defn, mapping), expand_type(typ, mapping)))
    return result
def check_explicit_override_decorator(
    self,
    defn: FuncDef | OverloadedFuncDef,
    found_method_base_classes: list[TypeInfo] | None,
    context: Context | None = None,
) -> None:
    """Report a method that overrides a base class method without being marked as an override.

    Args:
        defn: The method definition.
        found_method_base_classes: Base classes that define an attribute with the
            same name; nothing is reported when this is empty or None.
        context: Where to report the error; defaults to `defn`.

    `__init__` and `__new__` are exempt.
    """
    if not found_method_base_classes:
        return
    if defn.is_explicit_override or defn.name in ("__init__", "__new__"):
        return
    first_base = found_method_base_classes[0]
    self.msg.explicit_override_decorator_missing(
        defn.name, first_base.fullname, context or defn
    )
def check_method_override(
    self, defn: FuncDef | OverloadedFuncDef | Decorator
) -> list[TypeInfo] | None:
    """Check if function definition is compatible with base classes.

    This may defer the method if a signature is not available in at least one base class.
    Return ``None`` if that happens.

    Return a list of base classes which contain an attribute with the method name.
    """
    # Check against definitions in base classes (every MRO entry after the first).
    bases_with_attribute: list[TypeInfo] = []
    for ancestor in defn.info.mro[1:]:
        outcome = self.check_method_or_accessor_override_for_base(defn, ancestor)
        if outcome is None:
            # Node was deferred, we will have another attempt later.
            return None
        if outcome:
            bases_with_attribute.append(ancestor)
    return bases_with_attribute
def check_method_or_accessor_override_for_base(
self, defn: FuncDef | OverloadedFuncDef | Decorator, base: TypeInfo
) -> bool | None:
"""Check if method definition is compatible with a base class.
Return ``None`` if the node was deferred because one of the corresponding
superclass nodes is not ready.
Return ``True`` if an attribute with the method name was found in the base class.
"""
found_base_method = False
if base:
name = defn.name
# Only names defined directly in `base.names` are considered here.
base_attr = base.names.get(name)
if base_attr:
# First, check if we override a final (always an error, even with Any types).
if is_final_node(base_attr.node):
self.msg.cant_override_final(name, base.name, defn)
# Second, final can't override anything writeable independently of types.
if defn.is_final:
self.check_if_final_var_override_writable(name, base_attr.node, defn)
found_base_method = True
# Check the type of override.
if name not in ("__init__", "__new__", "__init_subclass__", "__post_init__"):
# Check method override
# (__init__, __new__, __init_subclass__ and __post_init__ are special).
# A true result from the helper means the node was deferred.
if self.check_method_override_for_base_with_name(defn, name, base):
return None
if name in operators.inplace_operator_methods:
# Figure out the name of the corresponding operator method.
method = "__" + name[3:]
# An inplace operator method such as __iadd__ might not be
# always introduced safely if a base class defined __add__.
# TODO can't come up with an example where this is
# necessary; now it's "just in case"
if self.check_method_override_for_base_with_name(defn, method, base):
return None
return found_base_method
def check_method_override_for_base_with_name(
    self, defn: FuncDef | OverloadedFuncDef | Decorator, name: str, base: TypeInfo
) -> bool:
    """Check if overriding an attribute `name` of `base` with `defn` is valid.

    Arguments:
        defn: the overriding definition in the subclass
        name: the attribute name looked up in `base` (it can differ from
            `defn.name`; the caller also passes a derived operator name)
        base: the base class whose attribute is being overridden

    Return True if the supertype node was not analysed yet, and `defn` was deferred.
    Return False in every other case, including when `base` has no such attribute;
    incompatibilities are reported as errors rather than through the return value.
    """
    base_attr = base.names.get(name)
    if base_attr:
        # The name of the method is defined in the base class.

        # Point errors at the 'def' line (important for backward compatibility
        # of type ignores).
        if not isinstance(defn, Decorator):
            context = defn
        else:
            context = defn.func

        # Construct the type of the overriding method.
        # TODO: this logic is much less complete than similar one in checkmember.py
        if isinstance(defn, (FuncDef, OverloadedFuncDef)):
            typ: Type = self.function_type(defn)
            override_class_or_static = defn.is_class or defn.is_static
            override_class = defn.is_class
        else:
            # Decorator: use the type stored on the decorated variable.
            assert defn.var.is_ready
            assert defn.var.type is not None
            typ = defn.var.type
            override_class_or_static = defn.func.is_class or defn.func.is_static
            override_class = defn.func.is_class
        typ = get_proper_type(typ)
        if isinstance(typ, FunctionLike) and not is_static(context):
            typ = bind_self(typ, self.scope.active_self_type(), is_classmethod=override_class)
        # Map the overridden method type to subtype context so that
        # it can be checked for compatibility.
        original_type = get_proper_type(base_attr.type)
        original_node = base_attr.node
        # `original_type` can be partial if (e.g.) it is originally an
        # instance variable from an `__init__` block that becomes deferred.
        if original_type is None or isinstance(original_type, PartialType):
            if self.pass_num < self.last_pass:
                # If there are passes left, defer this node until next pass,
                # otherwise try reconstructing the method type from available information.
                self.defer_node(defn, defn.info)
                return True
            elif isinstance(original_node, (FuncDef, OverloadedFuncDef)):
                original_type = self.function_type(original_node)
            elif isinstance(original_node, Decorator):
                original_type = self.function_type(original_node.func)
            elif isinstance(original_node, Var):
                # Super type can define method as an attribute.
                # See https://github.com/python/mypy/issues/10134

                # We also check that sometimes `original_node.type` is None.
                # This is the case when we use something like `__hash__ = None`.
                if original_node.type is not None:
                    original_type = get_proper_type(original_node.type)
                else:
                    original_type = NoneType()
            else:
                # Will always fail to typecheck below, since we know the node is a method
                original_type = NoneType()
        # Record whether the overridden definition is a class or static method.
        if isinstance(original_node, (FuncDef, OverloadedFuncDef)):
            original_class_or_static = original_node.is_class or original_node.is_static
        elif isinstance(original_node, Decorator):
            fdef = original_node.func
            original_class_or_static = fdef.is_class or fdef.is_static
        else:
            original_class_or_static = False  # a variable can't be class or static

        if isinstance(original_type, FunctionLike):
            original_type = self.bind_and_map_method(base_attr, original_type, defn.info, base)
            if original_node and is_property(original_node):
                # Compare properties by the type they produce, not as callables.
                original_type = get_property_type(original_type)

        if isinstance(typ, FunctionLike) and is_property(defn):
            typ = get_property_type(typ)
            if (
                isinstance(original_node, Var)
                and not original_node.is_final
                and (not original_node.is_property or original_node.is_settable_property)
                and isinstance(defn, Decorator)
            ):
                # We only give an error where no other similar errors will be given.
                if not isinstance(original_type, AnyType):
                    self.msg.fail(
                        "Cannot override writeable attribute with read-only property",
                        # Give an error on function line to match old behaviour.
                        defn.func,
                        code=codes.OVERRIDE,
                    )

        if isinstance(original_type, AnyType) or isinstance(typ, AnyType):
            # Any on either side is compatible with everything.
            pass
        elif isinstance(original_type, FunctionLike) and isinstance(typ, FunctionLike):
            # Check that the types are compatible.
            # TODO overloaded signatures
            self.check_override(
                typ,
                original_type,
                defn.name,
                name,
                base.name,
                original_class_or_static,
                override_class_or_static,
                context,
            )
        elif is_equivalent(original_type, typ):
            # Assume invariance for a non-callable attribute here. Note
            # that this doesn't affect read-only properties which can have
            # covariant overrides.
            pass
        elif (
            original_node
            and not self.is_writable_attribute(original_node)
            and is_subtype(typ, original_type)
        ):
            # If the attribute is read-only, allow covariance
            pass
        else:
            self.msg.signature_incompatible_with_supertype(
                defn.name, name, base.name, context, original=original_type, override=typ
            )
    return False
def bind_and_map_method(
    self, sym: SymbolTableNode, typ: FunctionLike, sub_info: TypeInfo, super_info: TypeInfo
) -> FunctionLike:
    """Bind self-type and map type variables for a method.

    Arguments:
        sym: a symbol that points to method definition
        typ: method type on the definition
        sub_info: class where the method is used
        super_info: class where the method was defined

    Returns:
        `typ` mapped from `super_info` to `sub_info`. When `sym` is a
        non-static function, overload or decorator node, the mapped type is
        also bound to the active self type; otherwise it is only mapped.
    """
    if isinstance(sym.node, (FuncDef, OverloadedFuncDef, Decorator)) and not is_static(
        sym.node
    ):
        if isinstance(sym.node, Decorator):
            is_class_method = sym.node.func.is_class
        else:
            is_class_method = sym.node.is_class

        mapped_typ = cast(FunctionLike, map_type_from_supertype(typ, sub_info, super_info))
        active_self_type = self.scope.active_self_type()
        if isinstance(mapped_typ, Overloaded) and active_self_type:
            # If we have an overload, filter to overloads that match the self type.
            # This avoids false positives for concrete subclasses of generic classes,
            # see testSelfTypeOverrideCompatibility for an example.
            filtered_items = []
            for item in mapped_typ.items:
                if not item.arg_types:
                    # There is no self argument to match against: keep the item and
                    # skip the check below, which would index into an empty list.
                    filtered_items.append(item)
                    continue
                item_arg = item.arg_types[0]
                if isinstance(item_arg, TypeVarType):
                    item_arg = item_arg.upper_bound
                if is_subtype(active_self_type, item_arg):
                    filtered_items.append(item)
            # If we don't have any filtered_items, maybe it's always a valid override
            # of the superclass? However if you get to that point you're in murky type
            # territory anyway, so we just preserve the type and have the behaviour match
            # that of older versions of mypy.
            if filtered_items:
                mapped_typ = Overloaded(filtered_items)

        return bind_self(mapped_typ, active_self_type, is_class_method)
    else:
        return cast(FunctionLike, map_type_from_supertype(typ, sub_info, super_info))
def get_op_other_domain(self, tp: FunctionLike) -> Type | None:
    """Return the type of the first argument of `tp` if it is positional.

    For an overload, the result is the simplified union of the per-item
    results; ``None`` is returned when no item contributes one.
    """
    if isinstance(tp, CallableType):
        first_is_positional = bool(tp.arg_kinds) and tp.arg_kinds[0] == ARG_POS
        return tp.arg_types[0] if first_is_positional else None
    if isinstance(tp, Overloaded):
        domains = [domain for domain in map(self.get_op_other_domain, tp.items) if domain]
        return make_simplified_union(domains) if domains else None
    assert False, "Need to check all FunctionLike subtypes here"
def check_override(
    self,
    override: FunctionLike,
    original: FunctionLike,
    name: str,
    name_in_super: str,
    supertype: str,
    original_class_or_static: bool,
    override_class_or_static: bool,
    node: Context,
) -> None:
    """Check a method override with given signatures.

    Arguments:
        override: The signature of the overriding method.
        original: The signature of the original supertype method.
        name: The name of the overriding definition. Besides error messages,
            it decides whether the forward operator domain check applies and
            whether the override is exempt because the name is private.
        name_in_super: The name of the overridden attribute in the supertype.
            Only used for generating error messages.
        supertype: The name of the supertype. Only used for generating
            error messages.
        original_class_or_static: Whether the supertype method is a class
            method or static method.
        override_class_or_static: Whether the overriding method is a class
            method or static method.
        node: The context used when reporting errors.
    """
    # Use boolean variable to clarify code.
    fail = False
    op_method_wider_note = False
    if not is_subtype(override, original, ignore_pos_arg_names=True):
        fail = True
    elif isinstance(override, Overloaded) and self.is_forward_op_method(name):
        # Operator method overrides cannot extend the domain, as
        # this could be unsafe with reverse operator methods.
        original_domain = self.get_op_other_domain(original)
        override_domain = self.get_op_other_domain(override)
        if (
            original_domain
            and override_domain
            and not is_subtype(override_domain, original_domain)
        ):
            fail = True
            op_method_wider_note = True
    if isinstance(override, FunctionLike):
        if original_class_or_static and not override_class_or_static:
            # A class/static method can't be overridden by an instance method.
            fail = True
        elif isinstance(original, CallableType) and isinstance(override, CallableType):
            # An override must keep a type guard that the original declares.
            if original.type_guard is not None and override.type_guard is None:
                fail = True

    if is_private(name):
        # Private names are never reported as incompatible overrides.
        fail = False

    if fail:
        emitted_msg = False

        # Normalize signatures, so we get better diagnostics.
        if isinstance(override, (CallableType, Overloaded)):
            override = override.with_unpacked_kwargs()
        if isinstance(original, (CallableType, Overloaded)):
            original = original.with_unpacked_kwargs()

        if (
            isinstance(override, CallableType)
            and isinstance(original, CallableType)
            and len(override.arg_types) == len(original.arg_types)
            and override.min_args == original.min_args
        ):
            # Give more detailed messages for the common case of both
            # signatures having the same number of arguments and no
            # overloads.

            # override might have its own generic function type
            # variables. If an argument or return type of override
            # does not have the correct subtyping relationship
            # with the original type even after these variables
            # are erased, then it is definitely an incompatibility.

            override_ids = override.type_var_ids()
            type_name = None
            if isinstance(override.definition, FuncDef):
                type_name = override.definition.info.name

            def erase_override(t: Type) -> Type:
                return erase_typevars(t, ids_to_erase=override_ids)

            for i in range(len(override.arg_types)):
                # Arguments are checked contravariantly: original must fit override.
                if not is_subtype(
                    original.arg_types[i], erase_override(override.arg_types[i])
                ):
                    arg_type_in_super = original.arg_types[i]

                    if isinstance(node, FuncDef):
                        # Offset by the bound arguments to index the matching
                        # argument node of the definition.
                        context: Context = node.arguments[i + len(override.bound_args)]
                    else:
                        context = node
                    self.msg.argument_incompatible_with_supertype(
                        i + 1,
                        name,
                        type_name,
                        name_in_super,
                        arg_type_in_super,
                        supertype,
                        context,
                        secondary_context=node,
                    )
                    emitted_msg = True

            if not is_subtype(erase_override(override.ret_type), original.ret_type):
                self.msg.return_type_incompatible_with_supertype(
                    name, name_in_super, supertype, original.ret_type, override.ret_type, node
                )
                emitted_msg = True
        elif isinstance(override, Overloaded) and isinstance(original, Overloaded):
            # Give a more detailed message in the case where the user is trying to
            # override an overload, and the subclass's overload is plausible, except
            # that the order of the variants are wrong.
            #
            # For example, if the parent defines the overload f(int) -> int and f(str) -> str
            # (in that order), and if the child swaps the two and does f(str) -> str and
            # f(int) -> int
            order = []
            for child_variant in override.items:
                # Record the index of the first parent variant this child variant matches.
                for i, parent_variant in enumerate(original.items):
                    if is_subtype(child_variant, parent_variant):
                        order.append(i)
                        break

            if len(order) == len(original.items) and order != sorted(order):
                self.msg.overload_signature_incompatible_with_supertype(
                    name, name_in_super, supertype, node
                )
                emitted_msg = True

        if not emitted_msg:
            # Fall back to generic incompatibility message.
            self.msg.signature_incompatible_with_supertype(
                name, name_in_super, supertype, node, original=original, override=override
            )
        if op_method_wider_note:
            self.note(
                "Overloaded operator methods can't have wider argument types in overrides",
                node,
                code=codes.OVERRIDE,
            )
def check__exit__return_type(self, defn: FuncItem) -> None:
    """Report an `__exit__` declared to return bool that only ever returns False.

    With a bool return type, mypy assumes a with statement may "swallow"
    exceptions. If every return statement is a literal False that is not
    the case, and reachability inference would be wrong.
    """
    signature = defn.type
    if not (signature and isinstance(signature, CallableType)):
        return

    if not has_bool_item(get_proper_type(signature.ret_type)):
        return

    returns = all_return_statements(defn)
    if not returns:
        return

    for ret in returns:
        value = ret.expr
        if not (isinstance(value, NameExpr) and value.fullname == "builtins.False"):
            # At least one return is something other than a literal False.
            return
    self.msg.incorrect__exit__return(defn)
def visit_class_def(self, defn: ClassDef) -> None:
    """Type check a class definition.

    Reports inheritance from final classes, checks the class body with a
    fresh binder, then runs the class-level checks: class keywords, multiple
    inheritance, metaclass compatibility, deletable finals, class decorators,
    type variable variance against the bases, protocol variance and enum rules.
    """
    typ = defn.info
    for base in typ.mro[1:]:
        if base.is_final:
            self.fail(message_registry.CANNOT_INHERIT_FROM_FINAL.format(base.name), defn)
    with self.tscope.class_scope(defn.info), self.enter_partial_types(is_class=True):
        # The class body is checked with its own binder; the previous one is
        # restored right after the body has been processed.
        old_binder = self.binder
        self.binder = ConditionalTypeBinder()
        with self.binder.top_frame_context():
            with self.scope.push_class(defn.info):
                self.accept(defn.defs)
        self.binder = old_binder
        if not (defn.info.typeddict_type or defn.info.tuple_type or defn.info.is_enum):
            # Class keywords are only checked for a normal class, i.e. one that
            # has no typeddict_type or tuple_type and is not an enum.
            self.check_init_subclass(defn)
        if not defn.has_incompatible_baseclass:
            # Otherwise we've already found errors; more errors are not useful
            self.check_multiple_inheritance(typ)
        self.check_metaclass_compatibility(typ)
        self.check_final_deletable(typ)

        if defn.decorators:
            sig: Type = type_object_type(defn.info, self.named_type)
            # Decorators are applied in reverse order.
            for decorator in reversed(defn.decorators):
                if isinstance(decorator, CallExpr) and isinstance(
                    decorator.analyzed, PromoteExpr
                ):
                    # _promote is a special type checking related construct.
                    continue

                dec = self.expr_checker.accept(decorator)
                temp = self.temp_node(sig, context=decorator)
                fullname = None
                if isinstance(decorator, RefExpr):
                    fullname = decorator.fullname or None

                # TODO: Figure out how to have clearer error messages.
                # (e.g. "class decorator must be a function that accepts a type.")
                # Abstract calls are allowed while checking the decorator call;
                # the previous setting is restored afterwards.
                old_allow_abstract_call = self.allow_abstract_call
                self.allow_abstract_call = True
                sig, _ = self.expr_checker.check_call(
                    dec, [temp], [nodes.ARG_POS], defn, callable_name=fullname
                )
                self.allow_abstract_call = old_allow_abstract_call
            # TODO: Apply the sig to the actual TypeInfo so we can handle decorators
            # that completely swap out the type.  (e.g. Callable[[Type[A]], Type[B]])
    if typ.defn.type_vars:
        # A non-invariant type variable passed to a base must have the same
        # variance as the type variable declared by that base.
        for base_inst in typ.bases:
            for base_tvar, base_decl_tvar in zip(
                base_inst.args, base_inst.type.defn.type_vars
            ):
                if (
                    isinstance(base_tvar, TypeVarType)
                    and base_tvar.variance != INVARIANT
                    and isinstance(base_decl_tvar, TypeVarType)
                    and base_decl_tvar.variance != base_tvar.variance
                ):
                    self.fail(
                        f'Variance of TypeVar "{base_tvar.name}" incompatible '
                        "with variance in parent type",
                        context=defn,
                        code=codes.TYPE_VAR,
                    )

    if typ.is_protocol and typ.defn.type_vars:
        self.check_protocol_variance(defn)
    if not defn.has_incompatible_baseclass and defn.info.is_enum:
        self.check_enum(defn)
def check_final_deletable(self, typ: TypeInfo) -> None:
    """Report attributes of `typ` that are listed as deletable but declared final."""
    # These checks are only for mypyc. Only perform some checks that are easier
    # to implement here than in mypyc.
    for attr_name in typ.deletable_attributes:
        sym = typ.names.get(attr_name)
        if not sym:
            continue
        target = sym.node
        if isinstance(target, Var) and target.is_final:
            self.fail(message_registry.CANNOT_MAKE_DELETABLE_FINAL, target)
def check_init_subclass(self, defn: ClassDef) -> None:
    """Check that keywords in a class definition are valid arguments for __init_subclass__().

    In this example:
        1   class Base:
        2       def __init_subclass__(cls, thing: int):
        3           pass
        4   class Child(Base, thing=5):
        5       def __init_subclass__(cls):
        6           pass
        7   Child()

    Base.__init_subclass__(thing=5) is called at line 4. This is what we simulate here.
    Child.__init_subclass__ is never called.
    """
    metaclass = defn.info.metaclass_type
    if metaclass and metaclass.type.fullname not in ("builtins.type", "abc.ABCMeta"):
        # We can't safely check situations when both __init_subclass__ and a custom
        # metaclass are present.
        return

    # At runtime, only Base.__init_subclass__ will be called, so the current
    # class itself is skipped. Only the first base defining the hook matters;
    # all other bases have already been checked.
    for ancestor in defn.info.mro[1:]:
        if "__init_subclass__" in ancestor.names:
            break
    else:
        return

    # 'metaclass' keyword is consumed by the rest of the type machinery,
    # and is never passed to __init_subclass__ implementations
    passed = [(kw, value) for kw, value in defn.keywords.items() if kw != "metaclass"]
    arg_names: list[str | None] = [kw for kw, _ in passed]
    args = [value for _, value in passed]

    # Build a synthetic `<ancestor>.__init_subclass__(**keywords)` call and check it.
    receiver = NameExpr(defn.name)
    receiver.node = ancestor
    call_expr = CallExpr(
        MemberExpr(receiver, "__init_subclass__"), args, [ARG_NAMED] * len(args), arg_names
    )
    call_expr.line = defn.line
    call_expr.column = defn.column
    call_expr.end_line = defn.end_line
    self.expr_checker.accept(call_expr, allow_none_return=True, always_allow_any=True)
def check_enum(self, defn: ClassDef) -> None:
    """Run the enum-specific checks for an enum class definition."""
    info = defn.info
    assert info.is_enum
    if info.fullname not in ENUM_BASES:
        for sym in info.names.values():
            member = sym.node
            if not isinstance(member, Var):
                continue
            if member.has_explicit_value and member.name == "__members__":
                # `__members__` will always be overwritten by `Enum` and is considered
                # read-only so we disallow assigning a value to it
                self.fail(message_registry.ENUM_MEMBERS_ATTR_WILL_BE_OVERRIDEN, member)

    # The class itself and `object` are not needed here.
    for ancestor in info.mro[1:-1]:
        if ancestor.is_enum and ancestor.fullname not in ENUM_BASES:
            self.check_final_enum(defn, ancestor)

    self.check_enum_bases(defn)
    self.check_enum_new(defn)
def check_final_enum(self, defn: ClassDef, base: TypeInfo) -> None:
    """Report an error on `defn` if the enum `base` already has members.

    At most one error is reported, however many members `base` defines.
    """
    if any(self.is_final_enum_value(sym) for sym in base.names.values()):
        self.fail(f'Cannot extend enum with existing members: "{base.name}"', defn)
def is_final_enum_value(self, sym: SymbolTableNode) -> bool:
    """Tell whether `sym` counts as an enum member that prevents subclassing."""
    node = sym.node
    if isinstance(node, (FuncBase, Decorator)):
        return False  # A method is fine
    if not isinstance(node, Var):
        return True  # Can be a class or anything else

    # Now, only `Var` is left, we need to check:
    # 1. Private name like in `__prop = 1`
    # 2. Dunder name like `__hash__ = some_hasher`
    # 3. Sunder name like `_order_ = 'a, b, c'`
    # 4. If it is a method / descriptor like in `method = classmethod(func)`
    var_name = node.name
    is_exempt = (
        is_private(var_name)
        or is_dunder(var_name)
        or is_sunder(var_name)
        # TODO: make sure that `x = @class/staticmethod(func)`
        # and `x = property(prop)` both work correctly.
        # Now they are incorrectly counted as enum members.
        or isinstance(get_proper_type(node.type), FunctionLike)
    )
    if is_exempt:
        return False
    return self.is_stub or node.has_explicit_value
def check_enum_bases(self, defn: ClassDef) -> None:
    """Report a non-enum base class that is listed after an enum base.

    Non-enum mixins cannot appear after enum bases; this is disallowed at runtime:

        class Foo: ...
        class Bar(enum.Enum, Foo): ...

    But any number of enum mixins can appear in a class definition
    (even if multiple enum bases define __new__). So this is fine:

        class Foo(enum.Enum):
            def __new__(cls, val): ...
        class Bar(enum.Enum):
            def __new__(cls, val): ...
        class Baz(int, Foo, Bar, enum.Flag): ...
    """
    remaining = iter(defn.info.bases)
    # Consume bases up to and including the first enum base.
    first_enum: Instance | None = next((b for b in remaining if b.type.is_enum), None)
    if first_enum is None:
        return
    # Every base left in the iterator comes after the first enum base.
    for later in remaining:
        if not later.type.is_enum:
            self.fail(
                f'No non-enum mixin classes are allowed after "{first_enum.str_with_options(self.options)}"',
                defn,
            )
            break
def check_enum_new(self, defn: ClassDef) -> None:
    """Report every base after the first that brings its own `__new__`.

    A non-enum base is a candidate when it has a `__new__` other than
    `builtins.object.__new__`. An enum base is a candidate when any non-enum
    class in its MRO (excluding itself and the last entry) has one.
    """

    def defines_custom_new(info: TypeInfo) -> bool:
        sym = info.get("__new__")
        if not sym or not sym.node:
            return False
        return sym.node.fullname != "builtins.object.__new__"

    seen_candidate = False
    for base in defn.info.bases:
        if base.type.is_enum:
            # If we have an `Enum`, then we need to check all its bases.
            is_candidate = any(
                not ancestor.is_enum and defines_custom_new(ancestor)
                for ancestor in base.type.mro[1:-1]
            )
        else:
            is_candidate = defines_custom_new(base.type)

        if not is_candidate:
            continue
        if seen_candidate:
            self.fail(
                "Only a single data type mixin is allowed for Enum subtypes, "
                'found extra "{}"'.format(base.str_with_options(self.options)),
                defn,
            )
        else:
            seen_candidate = True
def check_protocol_variance(self, defn: ClassDef) -> None:
    """Check that protocol definition is compatible with declared
    variances of type variables.

    Note that we also prohibit declaring protocol classes as invariant
    if they are actually covariant/contravariant, since this may break
    transitivity of subtyping, see PEP 544.
    """
    info = defn.info
    object_type = Instance(info.mro[-1], [])
    tvars = info.defn.type_vars
    positions = range(len(tvars))
    for index, tvar in enumerate(tvars):
        # Probe one type variable at a time: it is set to the top type in `widest`
        # and to the uninhabited type in `narrowest`; all others are Any.
        up_args: list[Type] = [
            object_type if pos == index else AnyType(TypeOfAny.special_form)
            for pos in positions
        ]
        down_args: list[Type] = [
            UninhabitedType() if pos == index else AnyType(TypeOfAny.special_form)
            for pos in positions
        ]
        widest = Instance(info, up_args)
        narrowest = Instance(info, down_args)
        # TODO: add advanced variance checks for recursive protocols
        if is_subtype(narrowest, widest, ignore_declared_variance=True):
            expected = COVARIANT
        elif is_subtype(widest, narrowest, ignore_declared_variance=True):
            expected = CONTRAVARIANT
        else:
            expected = INVARIANT
        if not isinstance(tvar, TypeVarType) or expected == tvar.variance:
            continue
        self.msg.bad_proto_variance(tvar.variance, tvar.name, expected, defn)
def check_multiple_inheritance(self, typ: TypeInfo) -> None:
    """Check for multiple inheritance related errors."""
    if len(typ.bases) < 2:
        # No multiple inheritance.
        return

    # Verify that inherited attributes are compatible.
    ancestors = typ.mro[1:]
    own_names = typ.names.keys()
    for position, earlier in enumerate(ancestors, start=1):
        later_bases = ancestors[position:]
        # Attributes defined in both the type and base are skipped.
        # Normal checks for attribute compatibility should catch any problems elsewhere.
        inherited_only = earlier.names.keys() - own_names
        for attr in inherited_only:
            if is_private(attr):
                continue
            for later in later_bases:
                # We only need to check compatibility of attributes from classes not
                # in a subclass relationship. For subclasses, normal (single inheritance)
                # checks suffice (these are implemented elsewhere).
                if attr in later.names and later not in earlier.mro:
                    self.check_compatibility(attr, earlier, later, typ)
def determine_type_of_member(self, sym: SymbolTableNode) -> Type | None:
    """Return the type of the member `sym` refers to, or None if it can't be determined.

    The symbol's own type is used when it is set; otherwise the type is
    derived from the kind of node the symbol points to.
    """
    declared = sym.type
    if declared is not None:
        return declared

    node = sym.node
    if isinstance(node, FuncBase):
        return self.function_type(node)
    if isinstance(node, TypeInfo):
        if node.typeddict_type:
            # We special-case TypedDict, because they don't define any constructor.
            return self.expr_checker.typeddict_callable(node)
        return type_object_type(node, self.named_type)
    if isinstance(node, TypeVarExpr):
        # Use of TypeVars is rejected in an expression/runtime context, so
        # we don't need to check supertype compatibility for them.
        return AnyType(TypeOfAny.special_form)
    if isinstance(node, TypeAlias):
        with self.msg.filter_errors():
            # Suppress any errors, they will be given when analyzing the corresponding node.
            # Here we may have incorrect options and location context.
            return self.expr_checker.alias_type_in_runtime_context(node, ctx=node)
    # TODO: handle more node kinds here.
    return None
def check_compatibility(
    self, name: str, base1: TypeInfo, base2: TypeInfo, ctx: TypeInfo
) -> None:
    """Check if attribute name in base1 is compatible with base2 in multiple inheritance.

    Assume base1 comes before base2 in the MRO, and that base1 and base2 don't have
    a direct subclass relationship (i.e., the compatibility requirement only derives from
    multiple inheritance).

    This check verifies that a definition taken from base1 (and mapped to the current
    class ctx), is type compatible with the definition taken from base2 (also mapped), so
    that unsafe subclassing like this can be detected:
        class A(Generic[T]):
            def foo(self, x: T) -> None: ...

        class B:
            def foo(self, x: str) -> None: ...

        class C(B, A[int]): ...  # this is unsafe because...

        x: A[int] = C()
        x.foo  # ...runtime type is (str) -> None, while static type is (int) -> None

    Arguments:
        name: the attribute defined in both `base1` and `base2`
        base1: the base class that comes first in the MRO
        base2: the base class that comes later in the MRO
        ctx: the class being checked; also used as the error context
    """
    if name in ("__init__", "__new__", "__init_subclass__"):
        # __init__ and friends can be incompatible -- it's a special case.
        return
    first = base1.names[name]
    second = base2.names[name]
    first_type = get_proper_type(self.determine_type_of_member(first))
    second_type = get_proper_type(self.determine_type_of_member(second))

    # start with the special case that Instance can be a subtype of FunctionLike
    call = None
    if isinstance(first_type, Instance):
        call = find_member("__call__", first_type, first_type, is_operator=True)
    if call and isinstance(second_type, FunctionLike):
        second_sig = self.bind_and_map_method(second, second_type, ctx, base2)
        ok = is_subtype(call, second_sig, ignore_pos_arg_names=True)
    elif isinstance(first_type, FunctionLike) and isinstance(second_type, FunctionLike):
        if first_type.is_type_obj() and second_type.is_type_obj():
            # For class objects only check the subtype relationship of the classes,
            # since we allow incompatible overrides of '__init__'/'__new__'
            ok = is_subtype(
                left=fill_typevars_with_any(first_type.type_object()),
                right=fill_typevars_with_any(second_type.type_object()),
            )
        else:
            # First bind/map method types when necessary.
            first_sig = self.bind_and_map_method(first, first_type, ctx, base1)
            second_sig = self.bind_and_map_method(second, second_type, ctx, base2)
            ok = is_subtype(first_sig, second_sig, ignore_pos_arg_names=True)
    elif first_type and second_type:
        # Neither of the callable cases above applied: require equivalent types.
        if isinstance(first.node, Var):
            first_type = expand_self_type(first.node, first_type, fill_typevars(ctx))
        if isinstance(second.node, Var):
            second_type = expand_self_type(second.node, second_type, fill_typevars(ctx))
        ok = is_equivalent(first_type, second_type)
        if not ok:
            second_node = base2[name].node
            if (
                isinstance(second_type, FunctionLike)
                and second_node is not None
                and is_property(second_node)
            ):
                # The second definition is a property: compare against the type
                # it produces and accept a subtype.
                second_type = get_property_type(second_type)
                ok = is_subtype(first_type, second_type)
    else:
        # At least one type is unknown: report that and skip the compatibility error.
        if first_type is None:
            self.msg.cannot_determine_type_in_base(name, base1.name, ctx)
        if second_type is None:
            self.msg.cannot_determine_type_in_base(name, base2.name, ctx)
        ok = True
    # Final attributes can never be overridden, but can override
    # non-final read-only attributes.
    if is_final_node(second.node):
        self.msg.cant_override_final(name, base2.name, ctx)
    if is_final_node(first.node):
        self.check_if_final_var_override_writable(name, second.node, ctx)
    # Some attributes like __slots__ and __deletable__ are special, and the type can
    # vary across class hierarchy.
    if isinstance(second.node, Var) and second.node.allow_incompatible_override:
        ok = True
    if not ok:
        self.msg.base_class_definitions_incompatible(name, base1, base2, ctx)
def check_metaclass_compatibility(self, typ: TypeInfo) -> None:
    """Ensures that metaclasses of all parent types are compatible."""
    is_exempt = (
        typ.is_metaclass()
        or typ.is_protocol
        or typ.is_named_tuple
        or typ.is_enum
        or typ.typeddict_type is not None
    )
    if is_exempt:
        return  # Reasonable exceptions from this check

    # Collect the metaclasses of the ancestors, ignoring plain `builtins.type`.
    inherited_metas = []
    for ancestor in typ.mro[1:-1]:
        meta = ancestor.metaclass_type
        if meta and not is_named_instance(meta, "builtins.type"):
            inherited_metas.append(meta)
    if not inherited_metas:
        return

    own_meta = typ.metaclass_type
    if own_meta is not None and all(is_subtype(own_meta, meta) for meta in inherited_metas):
        return
    self.fail(
        "Metaclass conflict: the metaclass of a derived class must be "
        "a (non-strict) subclass of the metaclasses of all its bases",
        typ,
    )
def visit_import_from(self, node: ImportFrom) -> None:
    """Check an ImportFrom node by delegating to check_import()."""
    self.check_import(node)
def visit_import_all(self, node: ImportAll) -> None:
    """Check an ImportAll node by delegating to check_import()."""
    self.check_import(node)
def visit_import(self, node: Import) -> None:
    """Check an Import node by delegating to check_import()."""
    self.check_import(node)
def check_import(self, node: ImportBase) -> None:
    """Check each assignment attached to an import node.

    For every entry in `node.assignments`, the imported name is checked
    against the type of the first lvalue as a simple assignment.
    """
    for assignment in node.assignments:
        target_type, _, __ = self.check_lvalue(assignment.lvalues[0])
        if target_type is None:
            # TODO: This is broken.
            target_type = AnyType(TypeOfAny.special_form)
        imported = assignment.rvalue
        assert isinstance(imported, NameExpr)
        self.check_simple_assignment(
            target_type,
            imported,
            node,
            msg=message_registry.INCOMPATIBLE_IMPORT_OF.format(imported.name),
            lvalue_name="local name",
            rvalue_name="imported name",
        )
#
# Statements
#
def visit_block(self, b: Block) -> None:
    """Check the statements of a block, stopping once the code becomes unreachable.

    The first unreachable statement that is not a no-op is reported, but
    only when unreachable issues should be reported at all.
    """
    if b.is_unreachable:
        # This block was marked as being unreachable during semantic analysis.
        # It turns out any blocks marked in this way are *intentionally* marked
        # as unreachable -- so we don't display an error.
        self.binder.unreachable()
        return

    for stmt in b.body:
        if not self.binder.is_unreachable():
            self.accept(stmt)
            continue
        # From here on the statement is unreachable.
        if not self.should_report_unreachable_issues():
            break
        if not self.is_noop_for_reachability(stmt):
            self.msg.unreachable_statement(stmt)
            break
def should_report_unreachable_issues(self) -> bool:
    """Tell whether unreachable code should currently be reported.

    Reporting requires a checked function with the warn_unreachable option
    on, a node that is not deferred, and a binder that has not suppressed
    the warning.
    """
    if not self.in_checked_function() or not self.options.warn_unreachable:
        return False
    return not (
        self.current_node_deferred or self.binder.is_unreachable_warning_suppressed()
    )
def is_noop_for_reachability(self, s: Statement) -> bool:
    """Returns 'true' if the given statement either throws an error of some kind
    or is a no-op.

    We use this function while handling the '--warn-unreachable' flag. When
    that flag is present, we normally report an error on any unreachable statement.
    But if that statement is just something like a 'pass' or a just-in-case 'assert False',
    reporting an error would be annoying.
    """
    if isinstance(s, AssertStmt) and is_false_literal(s.expr):
        return True
    if isinstance(s, (RaiseStmt, PassStmt)):
        return True
    if not isinstance(s, ExpressionStmt):
        return False

    expr = s.expr
    if isinstance(expr, EllipsisExpr):
        return True
    if not isinstance(expr, CallExpr):
        return False

    # A call counts as a no-op here only when its type is uninhabited; any
    # errors produced while inferring that type are filtered out.
    with self.expr_checker.msg.filter_errors():
        call_type = get_proper_type(
            self.expr_checker.accept(expr, allow_none_return=True, always_allow_any=True)
        )
    return isinstance(call_type, UninhabitedType)
def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
    """Type check an assignment statement.

    Handle all kinds of assignment statements (simple, indexed, multiple).

    The last lvalue is checked against the rvalue first. For a chained
    assignment the remaining lvalues are then checked against a temporary
    node carrying the rvalue's type.
    """
    # Avoid type checking type aliases in stubs to avoid false
    # positives about modern type syntax available in stubs such
    # as X | Y.
    if not (s.is_alias_def and self.is_stub):
        with self.enter_final_context(s.is_final_def):
            self.check_assignment(s.lvalues[-1], s.rvalue, s.type is None, s.new_syntax)

    if s.is_alias_def:
        self.check_type_alias_rvalue(s)

    if (
        s.type is not None
        and self.options.disallow_any_unimported
        and has_any_from_unimported_type(s.type)
    ):
        if isinstance(s.lvalues[-1], TupleExpr):
            # This is a multiple assignment. Instead of figuring out which type is problematic,
            # give a generic error message.
            self.msg.unimported_type_becomes_any(
                "A type on this line", AnyType(TypeOfAny.special_form), s
            )
        else:
            self.msg.unimported_type_becomes_any("Type of variable", s.type, s)
    check_for_explicit_any(s.type, self.options, self.is_typeshed_stub, self.msg, context=s)

    if len(s.lvalues) > 1:
        # Chained assignment (e.g. x = y = ...).
        # Make sure that rvalue type will not be reinferred.
        if not self.has_type(s.rvalue):
            self.expr_checker.accept(s.rvalue)
        rvalue = self.temp_node(self.lookup_type(s.rvalue), s)
        for lv in s.lvalues[:-1]:
            with self.enter_final_context(s.is_final_def):
                self.check_assignment(lv, rvalue, s.type is None)

    self.check_final(s)
    if (
        s.is_final_def
        and s.type
        and not has_no_typevars(s.type)
        and self.scope.active_class() is not None
    ):
        # A final declared in a class body can't have a type that has type variables.
        self.fail(message_registry.DEPENDENT_FINAL_IN_CLASS_BODY, s)

    if s.unanalyzed_type and not self.in_checked_function():
        self.msg.annotation_in_unchecked_function(context=s)
def check_type_alias_rvalue(self, s: AssignmentStmt) -> None:
    """Infer the type of an alias definition's rvalue and store it for the last lvalue."""
    inferred = self.expr_checker.accept(s.rvalue)
    target = s.lvalues[-1]
    self.store_type(target, inferred)
def check_assignment(
self,
lvalue: Lvalue,
rvalue: Expression,
infer_lvalue_type: bool = True,
new_syntax: bool = False,
) -> None:
"""Type check a single assignment: lvalue = rvalue."""
if isinstance(lvalue, (TupleExpr, ListExpr)):
self.check_assignment_to_multiple_lvalues(
lvalue.items, rvalue, rvalue, infer_lvalue_type
)
else:
self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue, "=")
lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue)
# If we're assigning to __getattr__ or similar methods, check that the signature is
# valid.
if isinstance(lvalue, NameExpr) and lvalue.node:
name = lvalue.node.name
if name in ("__setattr__", "__getattribute__", "__getattr__"):
# If an explicit type is given, use that.
if lvalue_type:
signature = lvalue_type
else:
signature = self.expr_checker.accept(rvalue)
if signature:
if name == "__setattr__":
self.check_setattr_method(signature, lvalue)
else:
self.check_getattr_method(signature, lvalue, name)
if name == "__slots__":
typ = lvalue_type or self.expr_checker.accept(rvalue)
self.check_slots_definition(typ, lvalue)
if name == "__match_args__" and inferred is not None:
typ = self.expr_checker.accept(rvalue)
self.check_match_args(inferred, typ, lvalue)
if name == "__post_init__":
if dataclasses_plugin.is_processed_dataclass(self.scope.active_class()):
self.fail(message_registry.DATACLASS_POST_INIT_MUST_BE_A_FUNCTION, rvalue)
# Defer PartialType's super type checking.
if (
isinstance(lvalue, RefExpr)
and not (isinstance(lvalue_type, PartialType) and lvalue_type.type is None)
and not (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__")
):
if self.check_compatibility_all_supers(lvalue, lvalue_type, rvalue):
# We hit an error on this line; don't check for any others
return
if isinstance(lvalue, MemberExpr) and lvalue.name == "__match_args__":
self.fail(message_registry.CANNOT_MODIFY_MATCH_ARGS, lvalue)
if lvalue_type:
if isinstance(lvalue_type, PartialType) and lvalue_type.type is None:
# Try to infer a proper type for a variable with a partial None type.
rvalue_type = self.expr_checker.accept(rvalue)
if isinstance(get_proper_type(rvalue_type), NoneType):
# This doesn't actually provide any additional information -- multiple
# None initializers preserve the partial None type.
return
var = lvalue_type.var
if is_valid_inferred_type(rvalue_type, is_lvalue_final=var.is_final):
partial_types = self.find_partial_types(var)
if partial_types is not None:
if not self.current_node_deferred:
# Partial type can't be final, so strip any literal values.
rvalue_type = remove_instance_last_known_values(rvalue_type)
inferred_type = make_simplified_union([rvalue_type, NoneType()])
self.set_inferred_type(var, lvalue, inferred_type)
else:
var.type = None
del partial_types[var]
lvalue_type = var.type
else:
# Try to infer a partial type. No need to check the return value, as
# an error will be reported elsewhere.
self.infer_partial_type(lvalue_type.var, lvalue, rvalue_type)
# Handle None PartialType's super type checking here, after it's resolved.
if isinstance(lvalue, RefExpr) and self.check_compatibility_all_supers(
lvalue, lvalue_type, rvalue
):
# We hit an error on this line; don't check for any others
return
elif (
is_literal_none(rvalue)
and isinstance(lvalue, NameExpr)
and isinstance(lvalue.node, Var)
and lvalue.node.is_initialized_in_class
and not new_syntax
):
# Allow None's to be assigned to class variables with non-Optional types.
rvalue_type = lvalue_type
elif (
isinstance(lvalue, MemberExpr) and lvalue.kind is None
): # Ignore member access to modules
instance_type = self.expr_checker.accept(lvalue.expr)
rvalue_type, lvalue_type, infer_lvalue_type = self.check_member_assignment(
instance_type, lvalue_type, rvalue, context=rvalue
)
else:
# Hacky special case for assigning a literal None
# to a variable defined in a previous if
# branch. When we detect this, we'll go back and
# make the type optional. This is somewhat
# unpleasant, and a generalization of this would
# be an improvement!
if (
is_literal_none(rvalue)
and isinstance(lvalue, NameExpr)
and lvalue.kind == LDEF
and isinstance(lvalue.node, Var)
and lvalue.node.type
and lvalue.node in self.var_decl_frames
and not isinstance(get_proper_type(lvalue_type), AnyType)
):
decl_frame_map = self.var_decl_frames[lvalue.node]
# Check if the nearest common ancestor frame for the definition site
# and the current site is the enclosing frame of an if/elif/else block.
has_if_ancestor = False
for frame in reversed(self.binder.frames):
if frame.id in decl_frame_map:
has_if_ancestor = frame.conditional_frame
break
if has_if_ancestor:
lvalue_type = make_optional_type(lvalue_type)
self.set_inferred_type(lvalue.node, lvalue, lvalue_type)
rvalue_type = self.check_simple_assignment(lvalue_type, rvalue, context=rvalue)
# Special case: only non-abstract non-protocol classes can be assigned to
# variables with explicit type Type[A], where A is protocol or abstract.
p_rvalue_type = get_proper_type(rvalue_type)
p_lvalue_type = get_proper_type(lvalue_type)
if (
isinstance(p_rvalue_type, CallableType)
and p_rvalue_type.is_type_obj()
and (
p_rvalue_type.type_object().is_abstract
or p_rvalue_type.type_object().is_protocol
)
and isinstance(p_lvalue_type, TypeType)
and isinstance(p_lvalue_type.item, Instance)
and (
p_lvalue_type.item.type.is_abstract or p_lvalue_type.item.type.is_protocol
)
):
self.msg.concrete_only_assign(p_lvalue_type, rvalue)
return
if rvalue_type and infer_lvalue_type and not isinstance(lvalue_type, PartialType):
# Don't use type binder for definitions of special forms, like named tuples.
if not (isinstance(lvalue, NameExpr) and lvalue.is_special_form):
self.binder.assign_type(lvalue, rvalue_type, lvalue_type, False)
elif index_lvalue:
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)
if inferred:
type_context = self.get_variable_type_context(inferred)
rvalue_type = self.expr_checker.accept(rvalue, type_context=type_context)
if not (
inferred.is_final
or (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__")
):
rvalue_type = remove_instance_last_known_values(rvalue_type)
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)
self.check_assignment_to_slots(lvalue)
# (type, operator) tuples for augmented assignments supported with partial types.
# Each entry pairs the full name of a generic type with a binary operator; it is
# consulted by try_infer_partial_generic_type_from_assignment() to decide whether
# an augmented assignment may refine a partial generic type.
partial_type_augmented_ops: Final = {("builtins.list", "+"), ("builtins.set", "|")}
def get_variable_type_context(self, inferred: Var) -> Type | None:
    """Derive a type context for an inferred variable from its base classes.

    Every class in the MRO of ``inferred.info`` (except the class itself) is
    asked for the type it gives to the same name. Partial types and variables
    flagged with ``invalid_partial_type`` are ignored.

    Returns the most derived of the collected types, or None when nothing was
    collected or when two collected types are unrelated by subtyping.
    """
    candidates = []
    info = inferred.info
    if info:
        for base in info.mro[1:]:
            base_type, base_node = self.lvalue_type_from_base(inferred, base)
            if not base_type:
                continue
            if isinstance(base_node, Var) and base_node.invalid_partial_type:
                continue
            if isinstance(base_type, PartialType):
                continue
            candidates.append(base_type)
    if not candidates:
        return None
    # Use most derived supertype as type context if available.
    best = candidates[0]
    for current in candidates:
        if is_proper_subtype(current, best):
            best = current
        elif not is_subtype(best, current):
            # Multiple incompatible candidates, cannot use any of them as context.
            return None
    return best
def try_infer_partial_generic_type_from_assignment(
    self, lvalue: Lvalue, rvalue: Expression, op: str
) -> None:
    """Try to replace a partial generic type with a precise one on assignment.

    'op' is '=' for normal assignment and a binary operator ('+', ...) for
    augmented assignment.

    Example where this happens:

        x = []
        if foo():
            x = [1]  # Infer List[int] as type of 'x'
    """
    var = None
    if (
        isinstance(lvalue, NameExpr)
        and isinstance(lvalue.node, Var)
        and isinstance(lvalue.node.type, PartialType)
    ):
        var = lvalue.node
    elif isinstance(lvalue, MemberExpr):
        var = self.expr_checker.get_partial_self_var(lvalue)
    if var is None:
        return

    partial = var.type
    assert isinstance(partial, PartialType)
    generic_info = partial.type
    if generic_info is None:
        # Partial None types are not handled by this method.
        return
    # Return if this is an unsupported augmented assignment.
    if op != "=" and (generic_info.fullname, op) not in self.partial_type_augmented_ops:
        return
    # TODO: some logic here duplicates the None partial type counterpart
    #       inlined in check_assignment(), see #8043.
    partial_types = self.find_partial_types(var)
    if partial_types is None:
        return

    inferred = get_proper_type(self.expr_checker.accept(rvalue))
    if isinstance(inferred, Instance):
        # Only accept an instance of the very same generic type whose arguments
        # are precise enough to be used as an inferred type.
        if inferred.type == generic_info and is_valid_inferred_type(inferred):
            var.type = inferred
            del partial_types[var]
    elif isinstance(inferred, AnyType):
        var.type = fill_typevars_with_any(generic_info)
        del partial_types[var]
def check_compatibility_all_supers(
    self, lvalue: RefExpr, lvalue_type: Type | None, rvalue: Expression
) -> bool:
    """Check an assignment to a class-level variable against its base classes.

    Two passes are made over the MRO of the class that owns the variable
    (excluding the class itself):

    1. ClassVar and Final override rules, via
       check_compatibility_classvar_super() and
       check_compatibility_final_super(). A failure here stops this pass but
       does not affect the return value.
    2. Type compatibility with the definition in each base, via
       check_compatibility_super().

    Returns True if check_compatibility_super() rejected the assignment for
    some base class, and False otherwise (including when the lvalue is not a
    class variable with at least one base).
    """
    lvalue_node = lvalue.node
    # Check if we are a class variable with at least one base class
    if (
        isinstance(lvalue_node, Var)
        and lvalue.kind in (MDEF, None)
        and len(lvalue_node.info.bases) > 0  # None for Vars defined via self
    ):
        for base in lvalue_node.info.mro[1:]:
            tnode = base.names.get(lvalue_node.name)
            if tnode is not None:
                if not self.check_compatibility_classvar_super(lvalue_node, base, tnode.node):
                    # Show only one error per variable
                    break
                if not self.check_compatibility_final_super(lvalue_node, base, tnode.node):
                    # Show only one error per variable
                    break
        direct_bases = lvalue_node.info.direct_base_classes()
        # Used below to stop the type check once the last direct base has been
        # found compatible.
        last_immediate_base = direct_bases[-1] if direct_bases else None
        for base in lvalue_node.info.mro[1:]:
            # The type of "__slots__" and some other attributes usually doesn't need to
            # be compatible with a base class. We'll still check the type of "__slots__"
            # against "object" as an exception.
            if lvalue_node.allow_incompatible_override and not (
                lvalue_node.name == "__slots__" and base.fullname == "builtins.object"
            ):
                continue
            if is_private(lvalue_node.name):
                continue
            base_type, base_node = self.lvalue_type_from_base(lvalue_node, base)
            # A partial type in the base gives nothing to compare against.
            if isinstance(base_type, PartialType):
                base_type = None
            if base_type:
                assert base_node is not None
                if not self.check_compatibility_super(
                    lvalue, lvalue_type, rvalue, base, base_type, base_node
                ):
                    # Only show one error per variable; even if other
                    # base classes are also incompatible
                    return True
                if base is last_immediate_base:
                    # At this point, the attribute was found to be compatible with all
                    # immediate parents.
                    break
    return False
def check_compatibility_super(
    self,
    lvalue: RefExpr,
    lvalue_type: Type | None,
    rvalue: Expression,
    base: TypeInfo,
    base_type: Type,
    base_node: Node,
) -> bool:
    """Check one assignment against the type a single base class gives the name.

    The type compared with ``base_type`` is the declared ``lvalue_type`` when
    there is one, otherwise the type of ``rvalue`` inferred with ``base_type``
    as type context.

    Returns the result of check_subtype() for that comparison, or True when
    there is no type to compare.
    """
    lvalue_node = lvalue.node
    assert isinstance(lvalue_node, Var)
    # Do not check whether the rvalue is compatible if the
    # lvalue had a type defined; this is handled by other
    # parts, and all we have to worry about in that case is
    # that lvalue is compatible with the base class.
    compare_node = None
    if lvalue_type:
        compare_type = lvalue_type
        compare_node = lvalue.node
    else:
        compare_type = self.expr_checker.accept(rvalue, base_type)
        if isinstance(rvalue, NameExpr):
            compare_node = rvalue.node
            # For a decorated function, look at the underlying function node.
            if isinstance(compare_node, Decorator):
                compare_node = compare_node.func
    base_type = get_proper_type(base_type)
    compare_type = get_proper_type(compare_type)
    if compare_type:
        if isinstance(base_type, CallableType) and isinstance(compare_type, CallableType):
            base_static = is_node_static(base_node)
            compare_static = is_node_static(compare_node)
            # In case compare_static is unknown, also check
            # if 'definition' is set. The most common case for
            # this is with TempNode(), where we lose all
            # information about the real rvalue node (but only get
            # the rvalue type)
            if compare_static is None and compare_type.definition:
                compare_static = is_node_static(compare_type.definition)
            # Compare against False, as is_node_static can return None
            if base_static is False and compare_static is False:
                # Class-level function objects and classmethods become bound
                # methods: the former to the instance, the latter to the
                # class
                base_type = bind_self(base_type, self.scope.active_self_type())
                compare_type = bind_self(compare_type, self.scope.active_self_type())
            # If we are a static method, ensure to also tell the
            # lvalue it now contains a static method
            if base_static and compare_static:
                lvalue_node.is_staticmethod = True
        return self.check_subtype(
            compare_type,
            base_type,
            rvalue,
            message_registry.INCOMPATIBLE_TYPES_IN_ASSIGNMENT,
            "expression has type",
            f'base class "{base.name}" defined the type as',
        )
    return True
def lvalue_type_from_base(
    self, expr_node: Var, base: TypeInfo
) -> tuple[Type | None, Node | None]:
    """Look up the type that a single base class gives to a variable's name.

    Only the names defined directly in 'base' are consulted. The result is a
    (type, node) pair for the matching symbol, or (None, None) if 'base' has
    no symbol with that name or the symbol has no type.

    For properties, the type returned is the type of the property value
    rather than the callable type of the getter.
    """
    symbol = base.names.get(expr_node.name)
    if not symbol:
        return None, None

    node = symbol.node
    found_type = symbol.type
    if isinstance(node, Var) and found_type is not None:
        found_type = expand_self_type(node, found_type, fill_typevars(expr_node.info))
    if isinstance(node, Decorator):
        # Use the decorated function itself and its own type.
        node = node.func
        found_type = node.type

    if not found_type:
        return None, None

    if not has_no_typevars(found_type):
        # Express the base's type variables in terms of the current class.
        self_type = self.scope.active_self_type()
        assert self_type is not None, "Internal error: base lookup outside class"
        instance = tuple_fallback(self_type) if isinstance(self_type, TupleType) else self_type
        itype = map_instance_to_supertype(instance, base)
        found_type = expand_type_by_instance(found_type, itype)

    found_type = get_proper_type(found_type)
    if (
        isinstance(found_type, CallableType)
        and isinstance(node, FuncDef)
        and node.is_property
    ):
        # If we are a property, return the Type of the return
        # value, not the Callable
        found_type = get_proper_type(found_type.ret_type)
    if (
        isinstance(found_type, FunctionLike)
        and isinstance(node, OverloadedFuncDef)
        and node.is_property
    ):
        # Same for properties with setter
        found_type = found_type.items[0].ret_type
    return found_type, node
def check_compatibility_classvar_super(
    self, node: Var, base: TypeInfo, base_node: Node | None
) -> bool:
    """Check that a variable and the base variable it overrides agree on ClassVar.

    Reports an error and returns False when exactly one of the two is a class
    variable. Returns True otherwise, including when the base node is not a
    variable at all.
    """
    if not isinstance(base_node, Var):
        return True
    overriding_is_classvar = node.is_classvar
    overridden_is_classvar = base_node.is_classvar
    if overriding_is_classvar and not overridden_is_classvar:
        template = message_registry.CANNOT_OVERRIDE_INSTANCE_VAR
    elif not overriding_is_classvar and overridden_is_classvar:
        template = message_registry.CANNOT_OVERRIDE_CLASS_VAR
    else:
        return True
    self.fail(template.format(base.name), node)
    return False
def check_compatibility_final_super(
    self, node: Var, base: TypeInfo, base_node: Node | None
) -> bool:
    """Check if an assignment overrides a final attribute in a base class.

    This only checks situations where either a node in base class is not a variable
    but a final method, or where override is explicitly declared as final.
    In these cases we give a more detailed error message. In addition, we check that
    a final variable doesn't override writeable attribute, which is not safe.

    Other situations are checked in `check_final()`.

    Returns False only when the "cannot override final" error was reported.
    """
    if not isinstance(base_node, (Var, FuncBase, Decorator)):
        return True

    # Give this error only for explicit override attempt with `Final`, or
    # if we are overriding a final method with variable.
    # Other override attempts will be flagged as assignment to constant
    # in `check_final()`.
    explicit_final_override = base_node.is_final and (
        node.is_final or not isinstance(base_node, Var)
    )
    if explicit_final_override:
        self.msg.cant_override_final(node.name, base.name, node)
        return False

    # Names listed in ENUM_SPECIAL_PROPS and bases listed in ENUM_BASES are
    # exempt from the writable-override check.
    if (
        node.is_final
        and base.fullname not in ENUM_BASES
        and node.name not in ENUM_SPECIAL_PROPS
    ):
        self.check_if_final_var_override_writable(node.name, base_node, node)
    return True
def check_if_final_var_override_writable(
    self, name: str, base_node: Node | None, ctx: Context
) -> None:
    """Report a final variable that overrides a writeable attribute.

    This is done to prevent situations like this:

        class C:
            attr = 1
        class D(C):
            attr: Final = 2

        x: C = D()
        x.attr = 3  # Oops!

    A missing base node is treated as writable.
    """
    if base_node and not self.is_writable_attribute(base_node):
        return
    self.msg.final_cant_override_writable(name, ctx)
def get_final_context(self) -> bool:
    """Check whether we are currently checking a final declaration.

    Returns the flag maintained by enter_final_context().
    """
    return self._is_final_def
@contextmanager
def enter_final_context(self, is_final_def: bool) -> Iterator[None]:
    """Record, for the duration of the block, whether a final declaration is checked.

    The previous value of the flag is put back on exit, even if the body raises.
    """
    previous, self._is_final_def = self._is_final_def, is_final_def
    try:
        yield
    finally:
        self._is_final_def = previous
def check_final(self, s: AssignmentStmt | OperatorAssignmentStmt | AssignmentExpr) -> None:
    """Check if this assignment does not assign to a final attribute.

    This function performs the check only for name assignments at module
    and class scope. The assignments to `obj.attr` and `Cls.attr` are checked
    in checkmember.py.
    """
    # Collect the assignment targets; the attribute that holds them depends on
    # the kind of statement.
    if isinstance(s, AssignmentStmt):
        lvs = self.flatten_lvalues(s.lvalues)
    elif isinstance(s, AssignmentExpr):
        lvs = [s.target]
    else:
        lvs = [s.lvalue]
    # Only an AssignmentStmt can be a final declaration.
    is_final_decl = s.is_final_def if isinstance(s, AssignmentStmt) else False
    if is_final_decl and self.scope.active_class():
        lv = lvs[0]
        assert isinstance(lv, RefExpr)
        if lv.node is not None:
            assert isinstance(lv.node, Var)
            if (
                lv.node.final_unset_in_class
                and not lv.node.final_set_in_init
                # It is OK to skip initializer in stub files.
                and not self.is_stub
                # Avoid extra error messages, if there is no type in Final[...],
                # then we already reported the error about missing r.h.s.
                and isinstance(s, AssignmentStmt)
                and s.type is not None
            ):
                self.msg.final_without_value(s)
    for lv in lvs:
        if isinstance(lv, RefExpr) and isinstance(lv.node, Var):
            name = lv.node.name
            cls = self.scope.active_class()
            if cls is not None:
                # These additional checks exist to give more error messages
                # even if the final attribute was overridden with a new symbol
                # (which is itself an error)...
                for base in cls.mro[1:]:
                    sym = base.names.get(name)
                    # We only give this error if base node is variable,
                    # overriding final method will be caught in
                    # `check_compatibility_final_super()`.
                    if sym and isinstance(sym.node, Var):
                        if sym.node.is_final and not is_final_decl:
                            self.msg.cant_assign_to_final(name, sym.node.info is None, s)
                            # ...but only once
                            break
            # Re-assignment of a name that is itself final.
            if lv.node.is_final and not is_final_decl:
                self.msg.cant_assign_to_final(name, lv.node.info is None, s)
def check_assignment_to_slots(self, lvalue: Lvalue) -> None:
    """Report an attribute assignment that is not permitted by `__slots__`.

    Nothing is reported unless the target is a member expression on an
    instance whose class has slots, the attribute is not one of them, no class
    in the MRO other than the last one defines `__setattr__`, and the
    attribute's definition is not otherwise assignable (see
    is_assignable_slot()).
    """
    if not isinstance(lvalue, MemberExpr):
        return

    receiver = get_proper_type(self.expr_checker.accept(lvalue.expr))
    if not isinstance(receiver, Instance):
        return
    info = receiver.type
    if info.slots is None:
        return  # Slots do not exist, we can allow any assignment
    if lvalue.name in info.slots:
        return  # We are assigning to an existing slot

    # When type has `__setattr__` defined, we can assign any dynamic value.
    # We exclude object, because it always has `__setattr__`.
    if any(klass.names.get("__setattr__") is not None for klass in info.mro[:-1]):
        return

    definition = info.get(lvalue.name)
    if definition is None:
        # We don't want to duplicate
        # `"SomeType" has no attribute "some_attr"`
        # error twice.
        return
    if self.is_assignable_slot(lvalue, definition.type):
        return

    message = message_registry.NAME_NOT_IN_SLOTS.format(lvalue.name, info.fullname)
    self.fail(message, lvalue)
def is_assignable_slot(self, lvalue: Lvalue, typ: Type | None) -> bool:
    """Tell whether an attribute missing from `__slots__` may still be assigned.

    'typ' is the type of the attribute's definition. A union is assignable
    only if every one of its items is.
    """
    if getattr(lvalue, "node", None):
        return False  # This is a definition

    proper = get_proper_type(typ)
    if proper is None or isinstance(proper, AnyType):
        return True  # Any can be literally anything, like `@property`
    if isinstance(proper, Instance):
        # When working with instances, we need to know if they contain
        # `__set__` special method. Like `@property` does.
        # This makes assigning to properties possible,
        # even without extra slot spec.
        return proper.type.get("__set__") is not None
    if isinstance(proper, FunctionLike):
        return True  # Can be a property, or some other magic
    if isinstance(proper, UnionType):
        for member in proper.items:
            if not self.is_assignable_slot(lvalue, member):
                return False
        return True
    return False
def check_assignment_to_multiple_lvalues(
    self,
    lvalues: list[Lvalue],
    rvalue: Expression,
    context: Context,
    infer_lvalue_type: bool = True,
) -> None:
    """Check an assignment of 'rvalue' to several targets (e.g. `a, b = ...`).

    When 'rvalue' is a tuple or list expression its items are paired with the
    targets and each pair is checked with check_assignment(). Any other
    expression is delegated to check_multi_assignment().
    """
    if isinstance(rvalue, (TupleExpr, ListExpr)):
        # Recursively go into Tuple or List expression rhs instead of
        # using the type of rhs, because this allowed more fine grained
        # control in cases like: a, b = [int, str] where rhs would get
        # type List[object]
        rvalues: list[Expression] = []
        # Item type of the starred iterable(s) kept in 'rvalues', if any.
        iterable_type: Type | None = None
        # Index in rvalue.items of the most recently kept starred iterable.
        last_idx: int | None = None
        for idx_rval, rval in enumerate(rvalue.items):
            if isinstance(rval, StarExpr):
                typs = get_proper_type(self.expr_checker.accept(rval.expr))
                if isinstance(typs, TupleType):
                    # A starred tuple is expanded into one temporary node per item.
                    rvalues.extend([TempNode(typ) for typ in typs.items])
                elif self.type_is_iterable(typs) and isinstance(typs, Instance):
                    # Starred iterables are kept only if they share one item type
                    # and sit next to each other in the rvalue.
                    if iterable_type is not None and iterable_type != self.iterable_item_type(
                        typs, rvalue
                    ):
                        self.fail(message_registry.CONTIGUOUS_ITERABLE_EXPECTED, context)
                    else:
                        if last_idx is None or last_idx + 1 == idx_rval:
                            rvalues.append(rval)
                            last_idx = idx_rval
                            iterable_type = self.iterable_item_type(typs, rvalue)
                        else:
                            self.fail(message_registry.CONTIGUOUS_ITERABLE_EXPECTED, context)
                else:
                    self.fail(message_registry.ITERABLE_TYPE_EXPECTED.format(typs), context)
            else:
                rvalues.append(rval)
        # Locate the first and last starred iterable that remain in 'rvalues'.
        iterable_start: int | None = None
        iterable_end: int | None = None
        for i, rval in enumerate(rvalues):
            if isinstance(rval, StarExpr):
                typs = get_proper_type(self.expr_checker.accept(rval.expr))
                if self.type_is_iterable(typs) and isinstance(typs, Instance):
                    if iterable_start is None:
                        iterable_start = i
                    iterable_end = i
        if (
            iterable_start is not None
            and iterable_end is not None
            and iterable_type is not None
        ):
            # Replace the starred iterables with as many temporary nodes of the
            # item type as are needed to match the number of lvalues.
            iterable_num = iterable_end - iterable_start + 1
            rvalue_needed = len(lvalues) - (len(rvalues) - iterable_num)
            if rvalue_needed > 0:
                rvalues = (
                    rvalues[0:iterable_start]
                    + [TempNode(iterable_type) for i in range(rvalue_needed)]
                    + rvalues[iterable_end + 1 :]
                )
        if self.check_rvalue_count_in_assignment(lvalues, len(rvalues), context):
            # Index of the starred lvalue, or len(lvalues) if there is none.
            star_index = next(
                (i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), len(lvalues)
            )
            left_lvs = lvalues[:star_index]
            star_lv = (
                cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None
            )
            right_lvs = lvalues[star_index + 1 :]
            left_rvs, star_rvs, right_rvs = self.split_around_star(
                rvalues, star_index, len(lvalues)
            )
            lr_pairs = list(zip(left_lvs, left_rvs))
            if star_lv:
                # The starred target receives a list expression built from the
                # rvalues that fall to it.
                rv_list = ListExpr(star_rvs)
                rv_list.set_line(rvalue)
                lr_pairs.append((star_lv.expr, rv_list))
            lr_pairs.extend(zip(right_lvs, right_rvs))
            for lv, rv in lr_pairs:
                self.check_assignment(lv, rv, infer_lvalue_type)
    else:
        self.check_multi_assignment(lvalues, rvalue, context, infer_lvalue_type)
def check_rvalue_count_in_assignment(
    self, lvalues: list[Lvalue], rvalue_count: int, context: Context
) -> bool:
    """Check that 'rvalue_count' values can be unpacked into 'lvalues'.

    With a starred target, at least len(lvalues) - 1 values are required;
    without one, the counts must match exactly. Reports an error and returns
    False on a mismatch, returns True otherwise.
    """
    has_star = False
    for target in lvalues:
        if isinstance(target, StarExpr):
            has_star = True
            break

    if has_star:
        expected = len(lvalues) - 1
        mismatch = expected > rvalue_count
    else:
        expected = len(lvalues)
        mismatch = rvalue_count != expected

    if mismatch:
        self.msg.wrong_number_values_to_unpack(rvalue_count, expected, context)
        return False
    return True
def check_multi_assignment(
    self,
    lvalues: list[Lvalue],
    rvalue: Expression,
    context: Context,
    infer_lvalue_type: bool = True,
    rv_type: Type | None = None,
    undefined_rvalue: bool = False,
) -> None:
    """Check the assignment of one rvalue to a number of lvalues.

    The rvalue type is 'rv_type' if given, otherwise it is inferred from
    'rvalue'. The work is then dispatched on the kind of that type.
    """
    # Infer the type of an ordinary rvalue expression.
    # TODO: maybe elsewhere; redundant.
    source_type = get_proper_type(rv_type or self.expr_checker.accept(rvalue))

    if isinstance(source_type, TypeVarLikeType):
        source_type = get_proper_type(source_type.upper_bound)

    if isinstance(source_type, UnionType):
        # If this is an Optional type in non-strict Optional code, unwrap it.
        remaining = source_type.relevant_items()
        if len(remaining) == 1:
            source_type = get_proper_type(remaining[0])

    if isinstance(source_type, AnyType):
        # Every target is assigned an Any derived from the rvalue's Any.
        for target in lvalues:
            if isinstance(target, StarExpr):
                target = target.expr
            any_node = self.temp_node(
                AnyType(TypeOfAny.from_another_any, source_any=source_type), context
            )
            self.check_assignment(target, any_node, infer_lvalue_type)
        return

    if isinstance(source_type, TupleType):
        self.check_multi_assignment_from_tuple(
            lvalues, rvalue, source_type, context, undefined_rvalue, infer_lvalue_type
        )
        return

    if isinstance(source_type, UnionType):
        self.check_multi_assignment_from_union(
            lvalues, rvalue, source_type, context, infer_lvalue_type
        )
        return

    if isinstance(source_type, Instance) and source_type.type.fullname == "builtins.str":
        self.msg.unpacking_strings_disallowed(context)
        return

    self.check_multi_assignment_from_iterable(lvalues, source_type, context, infer_lvalue_type)
def check_multi_assignment_from_union(
    self,
    lvalues: list[Expression],
    rvalue: Expression,
    rvalue_type: UnionType,
    context: Context,
    infer_lvalue_type: bool,
) -> None:
    """Check assignment to multiple lvalue targets when rvalue type is a Union[...].

    For example:

        t: Union[Tuple[int, int], Tuple[str, str]]
        x, y = t
        reveal_type(x)  # Union[int, str]

    The idea in this case is to process the assignment for every item of the union.
    Important note: the types are collected in two places, 'union_types' contains
    inferred types for first assignments, 'assignments' contains the narrowed types
    for binder.

    Partial types are disabled while the items are processed. The previous value
    of 'self.no_partial_types' is restored on exit, so that a nested union
    assignment (or an exception) cannot re-enable partial types for an enclosing
    call that is still in progress.
    """
    saved_no_partial_types = self.no_partial_types
    self.no_partial_types = True
    try:
        # The lvalue expressions are not modified below, so flatten them once.
        flat_lvalues = self.flatten_lvalues(lvalues)
        # One list of collected types per flattened lvalue.
        transposed: tuple[list[Type], ...] = tuple([] for _ in flat_lvalues)
        # Notify binder that we want to defer bindings and instead collect types.
        with self.binder.accumulate_type_assignments() as assignments:
            for item in rvalue_type.items:
                # Type check the assignment separately for each union item and collect
                # the inferred lvalue types for each union item.
                self.check_multi_assignment(
                    lvalues,
                    rvalue,
                    context,
                    infer_lvalue_type=infer_lvalue_type,
                    rv_type=item,
                    undefined_rvalue=True,
                )
                for t, lv in zip(transposed, flat_lvalues):
                    # We can access _type_maps directly since temporary type maps are
                    # only created within expressions.
                    t.append(self._type_maps[0].pop(lv, AnyType(TypeOfAny.special_form)))
        union_types = tuple(make_simplified_union(col) for col in transposed)
        for expr, items in assignments.items():
            # Bind a union of types collected in 'assignments' to every expression.
            if isinstance(expr, StarExpr):
                expr = expr.expr

            # TODO: See todo in binder.py, ConditionalTypeBinder.assign_type
            # It's unclear why the 'declared_type' param is sometimes 'None'
            clean_items: list[tuple[Type, Type]] = []
            for item_type, declared_type in items:
                assert declared_type is not None
                clean_items.append((item_type, declared_type))

            types, declared_types = zip(*clean_items)
            self.binder.assign_type(
                expr,
                make_simplified_union(list(types)),
                make_simplified_union(list(declared_types)),
                False,
            )
        for union, lv in zip(union_types, flat_lvalues):
            # Properly store the inferred types.
            _1, _2, inferred = self.check_lvalue(lv)
            if inferred:
                self.set_inferred_type(inferred, lv, union)
            else:
                self.store_type(lv, union)
    finally:
        self.no_partial_types = saved_no_partial_types
def flatten_lvalues(self, lvalues: list[Expression]) -> list[Expression]:
    """Return the assignment targets in 'lvalues' as one flat list.

    The items of a tuple or list target are flattened recursively and placed
    before the tuple or list target itself, which is kept in the result as
    well. A starred target is replaced by the expression it wraps.
    """
    flattened: list[Expression] = []
    for target in lvalues:
        if isinstance(target, (TupleExpr, ListExpr)):
            flattened += self.flatten_lvalues(target.items)
        # Unwrap StarExpr, since it is unwrapped by other helpers.
        flattened.append(target.expr if isinstance(target, StarExpr) else target)
    return flattened
def check_multi_assignment_from_tuple(
    self,
    lvalues: list[Lvalue],
    rvalue: Expression,
    rvalue_type: TupleType,
    context: Context,
    undefined_rvalue: bool,
    infer_lvalue_type: bool = True,
) -> None:
    """Check assignment to multiple lvalue targets when the rvalue type is a tuple.

    Unless 'undefined_rvalue' is set, 'rvalue' is inferred a second time with
    the lvalue types as type context. Each tuple item is then assigned to its
    target through a temporary node; a starred target receives a list
    expression of the items that fall to it.

    Nothing is checked if the number of items does not fit the targets.
    """
    if self.check_rvalue_count_in_assignment(lvalues, len(rvalue_type.items), context):
        # Index of the starred lvalue, or len(lvalues) if there is none.
        star_index = next(
            (i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), len(lvalues)
        )
        left_lvs = lvalues[:star_index]
        star_lv = cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None
        right_lvs = lvalues[star_index + 1 :]
        if not undefined_rvalue:
            # Infer rvalue again, now in the correct type context.
            lvalue_type = self.lvalue_type_for_inference(lvalues, rvalue_type)
            reinferred_rvalue_type = get_proper_type(
                self.expr_checker.accept(rvalue, lvalue_type)
            )
            if isinstance(reinferred_rvalue_type, UnionType):
                # If this is an Optional type in non-strict Optional code, unwrap it.
                relevant_items = reinferred_rvalue_type.relevant_items()
                if len(relevant_items) == 1:
                    reinferred_rvalue_type = get_proper_type(relevant_items[0])
            # Still a union after unwrapping: handle it item by item instead.
            if isinstance(reinferred_rvalue_type, UnionType):
                self.check_multi_assignment_from_union(
                    lvalues, rvalue, reinferred_rvalue_type, context, infer_lvalue_type
                )
                return
            if isinstance(reinferred_rvalue_type, AnyType):
                # We can get Any if the current node is
                # deferred. Doing more inference in deferred nodes
                # is hard, so give up for now.  We can also get
                # here if reinferring types above changes the
                # inferred return type for an overloaded function
                # to be ambiguous.
                return
            assert isinstance(reinferred_rvalue_type, TupleType)
            rvalue_type = reinferred_rvalue_type
        left_rv_types, star_rv_types, right_rv_types = self.split_around_star(
            rvalue_type.items, star_index, len(lvalues)
        )
        for lv, rv_type in zip(left_lvs, left_rv_types):
            self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type)
        if star_lv:
            list_expr = ListExpr(
                [self.temp_node(rv_type, context) for rv_type in star_rv_types]
            )
            list_expr.set_line(context)
            self.check_assignment(star_lv.expr, list_expr, infer_lvalue_type)
        for lv, rv_type in zip(right_lvs, right_rv_types):
            self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type)
def lvalue_type_for_inference(self, lvalues: list[Lvalue], rvalue_type: TupleType) -> Type:
    """Build a tuple type to use as type context when re-inferring a tuple rvalue.

    Each item is the type of the matching lvalue when check_lvalue() yields a
    non-partial type for it, and the corresponding item of 'rvalue_type'
    otherwise. A starred lvalue contributes one item per rvalue item it takes.
    """
    target_count = len(lvalues)
    star_index = next(
        (i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), target_count
    )
    star_lv = cast(StarExpr, lvalues[star_index]) if star_index != target_count else None
    before_star = lvalues[:star_index]
    after_star = lvalues[star_index + 1 :]
    left_rv_types, star_rv_types, right_rv_types = self.split_around_star(
        rvalue_type.items, star_index, target_count
    )

    context_items: list[Type] = []

    def usable_lvalue_type(lv: Expression) -> Type | None:
        # The lvalue's own type, unless it is missing or only partial.
        declared, _, _ = self.check_lvalue(lv)
        if declared and not isinstance(declared, PartialType):
            return declared
        return None

    def collect(lvs: list[Expression], rv_types: list[Type]) -> None:
        for lv, rv_type in zip(lvs, rv_types):
            declared = usable_lvalue_type(lv)
            if declared is not None:
                context_items.append(declared)
            else:  # index lvalue
                # TODO Figure out more precise type context, probably
                #      based on the type signature of the _set method.
                context_items.append(rv_type)

    collect(before_star, left_rv_types)

    if star_lv:
        declared = usable_lvalue_type(star_lv.expr)
        if declared is not None:
            context_items.extend([declared] * len(star_rv_types))
        else:  # index lvalue
            # TODO Figure out more precise type context, probably
            #      based on the type signature of the _set method.
            context_items.extend(star_rv_types)

    collect(after_star, right_rv_types)

    return TupleType(context_items, self.named_type("builtins.tuple"))
def split_around_star(
    self, items: list[T], star_index: int, length: int
) -> tuple[list[T], list[T], list[T]]:
    """Split a list of items in three to match another list of length 'length'
    that contains a starred expression at 'star_index' in the following way:

    star_index = 2, length = 5 (i.e., [a,b,*,c,d]), items = [1,2,3,4,5,6,7]
    returns: ([1,2], [3,4,5], [6,7])

    Callers also pass star_index == length when the other list has no starred
    expression. In that case nothing belongs to the star or to its right, so
    the result is (items[:star_index], [], []).
    """
    if star_index >= length:
        # No starred expression. Without this guard 'nr_right_of_star' would
        # be -1 and the "right" part would be the unrelated slice items[1:].
        return items[:star_index], [], []
    nr_right_of_star = length - star_index - 1
    # A negative index counts from the end; -0 would select from the start
    # instead, hence the special case.
    right_index = -nr_right_of_star if nr_right_of_star != 0 else len(items)
    left = items[:star_index]
    star = items[star_index:right_index]
    right = items[right_index:]
    return left, star, right
def type_is_iterable(self, type: Type) -> bool:
    """Tell whether 'type' is a subtype of typing.Iterable[Any].

    For a callable that is a type object, the callable's fallback type is
    tested instead.
    """
    candidate = get_proper_type(type)
    if isinstance(candidate, CallableType) and candidate.is_type_obj():
        candidate = candidate.fallback
    iterable_of_any = self.named_generic_type(
        "typing.Iterable", [AnyType(TypeOfAny.special_form)]
    )
    return is_subtype(candidate, iterable_of_any)
def check_multi_assignment_from_iterable(
    self,
    lvalues: list[Lvalue],
    rvalue_type: Type,
    context: Context,
    infer_lvalue_type: bool = True,
) -> None:
    """Check assignment to multiple lvalue targets from a general iterable.

    Each plain target is assigned the iterable's item type and a starred
    target is assigned a builtins.list of that item type. An error is reported
    if 'rvalue_type' is not iterable.
    """
    rvalue_type = get_proper_type(rvalue_type)
    supported = self.type_is_iterable(rvalue_type) and isinstance(
        rvalue_type, (Instance, CallableType, TypeType, Overloaded)
    )
    if not supported:
        self.msg.type_not_iterable(rvalue_type, context)
        return

    item_type = self.iterable_item_type(rvalue_type, context)
    for lv in lvalues:
        if isinstance(lv, StarExpr):
            assigned_type = self.named_generic_type("builtins.list", [item_type])
            target = lv.expr
        else:
            assigned_type = item_type
            target = lv
        self.check_assignment(
            target, self.temp_node(assigned_type, context), infer_lvalue_type
        )
def check_lvalue(self, lvalue: Lvalue) -> tuple[Type | None, IndexExpr | None, Var | None]:
    """Analyze an assignment target.

    Returns a triple (lvalue type, index lvalue, inferred variable) in which
    at most one element is not None:

    * for a definition, the variable whose type is to be inferred;
    * for an index expression, the expression itself;
    * otherwise, the type of the target, if one could be determined.
    """
    is_var_definition = self.is_definition(lvalue) and (
        not isinstance(lvalue, NameExpr) or isinstance(lvalue.node, Var)
    )
    if is_var_definition:
        if isinstance(lvalue, NameExpr):
            assert isinstance(lvalue.node, Var)
            return None, None, lvalue.node
        assert isinstance(lvalue, MemberExpr)
        self.expr_checker.accept(lvalue.expr)
        return None, None, lvalue.def_var

    if isinstance(lvalue, IndexExpr):
        return None, lvalue, None

    if isinstance(lvalue, MemberExpr):
        member_type = self.expr_checker.analyze_ordinary_member_access(lvalue, True)
        self.store_type(lvalue, member_type)
        return member_type, None, None

    if isinstance(lvalue, NameExpr):
        name_type = self.expr_checker.analyze_ref_expr(lvalue, lvalue=True)
        self.store_type(lvalue, name_type)
        return name_type, None, None

    if isinstance(lvalue, (TupleExpr, ListExpr)):
        # This type will be used as a context for further inference of rvalue,
        # we put Uninhabited if there is no information available from lvalue.
        item_types = [
            self.check_lvalue(sub_expr)[0] or UninhabitedType() for sub_expr in lvalue.items
        ]
        return TupleType(item_types, self.named_type("builtins.tuple")), None, None

    if isinstance(lvalue, StarExpr):
        starred_type, _, _ = self.check_lvalue(lvalue.expr)
        return starred_type, None, None

    return self.expr_checker.accept(lvalue), None, None
def is_definition(self, s: Lvalue) -> bool:
    """Tell whether the assignment target 's' defines a new variable."""
    if isinstance(s, NameExpr):
        if s.is_inferred_def:
            return True
        # If the node type is not defined, this must the first assignment
        # that we process => this is a definition, even though the semantic
        # analyzer did not recognize this as such. This can arise in code
        # that uses isinstance checks, if type checking of the primary
        # definition is skipped due to an always False type check.
        target = s.node
        return isinstance(target, Var) and target.type is None
    if isinstance(s, MemberExpr):
        return s.is_inferred_def
    return False
def infer_variable_type(
    self, name: Var, lvalue: Lvalue, init_type: Type, context: Context
) -> None:
    """Derive the type of an initialized variable from its initializer type."""
    if isinstance(init_type, DeletedType):
        self.msg.deleted_as_rvalue(init_type, context)
        return

    if not (
        is_valid_inferred_type(init_type, is_lvalue_final=name.is_final)
        or self.no_partial_types
    ):
        # The initializer type is not specific enough for full inference, but
        # a partial type may still work and be refined later. Assignments such
        # as 'x = []', where the item type is unknown, produce partial types.
        if not self.infer_partial_type(name, lvalue, init_type):
            self.msg.need_annotation_for_var(name, context, self.options.python_version)
            self.set_inference_error_fallback_type(name, lvalue, init_type)
        return

    conflicts_with_earlier_inference = (
        isinstance(lvalue, MemberExpr)
        and self.inferred_attribute_types is not None
        and lvalue.def_var
        and lvalue.def_var in self.inferred_attribute_types
        and not is_same_type(self.inferred_attribute_types[lvalue.def_var], init_type)
    )
    if conflicts_with_earlier_inference:
        # The same attribute was inferred with several, inconsistent types.
        self.msg.need_annotation_for_var(name, context, self.options.python_version)
        name.type = AnyType(TypeOfAny.from_error)
        return

    # Regular case: generalize the initializer type (strip away function
    # names etc.) and record it as the type of the target.
    self.set_inferred_type(name, lvalue, strip_type(init_type))
def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool:
    """Try to give `name` a partial type derived from `init_type`.

    Returns True when a partial type was stored for the variable (and the
    variable was registered in the innermost partial-types scope), False when
    the initializer type does not qualify.
    """
    init_type = get_proper_type(init_type)
    if isinstance(init_type, NoneType):
        partial_type = PartialType(None, name)
    elif not isinstance(init_type, Instance) or not isinstance(lvalue, RefExpr):
        return False
    else:
        fullname = init_type.type.fullname
        if fullname in (
            "builtins.list",
            "builtins.set",
            "builtins.dict",
            "collections.OrderedDict",
        ) and all(
            isinstance(arg, (NoneType, UninhabitedType))
            for arg in get_proper_types(init_type.args)
        ):
            partial_type = PartialType(init_type.type, name)
        elif fullname == "collections.defaultdict":
            key_arg = get_proper_type(init_type.args[0])
            value_arg = get_proper_type(init_type.args[1])
            if not (
                isinstance(key_arg, (NoneType, UninhabitedType))
                and self.is_valid_defaultdict_partial_value_type(value_arg)
            ):
                return False
            value_arg = erase_type(value_arg)
            assert isinstance(value_arg, Instance)
            partial_type = PartialType(init_type.type, name, value_arg)
        else:
            return False

    self.set_inferred_type(name, lvalue, partial_type)
    self.partial_types[-1].map[name] = lvalue
    return True
def is_valid_defaultdict_partial_value_type(self, t: ProperType) -> bool:
    """Check if t can be used as the basis for a partial defaultdict value type.

    Examples:

      * t is 'int' --> True
      * t is 'list[<nothing>]' --> True
      * t is 'dict[...]' --> False (only generic types with a single type
        argument supported)
    """
    if not isinstance(t, Instance):
        return False

    arg_count = len(t.args)
    if arg_count == 0:
        return True
    if arg_count != 1:
        return False

    only_arg = get_proper_type(t.args[0])
    if self.options.new_type_inference:
        permitted = (UninhabitedType, NoneType)
    else:
        # Legacy inference logic may leak TypeVars, so tolerate them here.
        permitted = (UninhabitedType, NoneType, TypeVarType)
    return isinstance(only_arg, permitted)
def set_inferred_type(self, var: Var, lvalue: Lvalue, type: Type) -> None:
    """Record an inferred variable type.

    The type is stored on the variable node and on the expression node that
    refers to it (lvalue). Nothing happens when var is None (or otherwise
    falsy) or when the current node is deferred.
    """
    if not var or self.current_node_deferred:
        return

    var.type = type
    var.is_inferred = True
    if var not in self.var_decl_frames:
        # Used for the hack to improve optional type inference in conditionals
        self.var_decl_frames[var] = {frame.id for frame in self.binder.frames}
    if (
        isinstance(lvalue, MemberExpr)
        and self.inferred_attribute_types is not None
        and lvalue.def_var is not None
    ):
        # Remember the inferred attribute type so consistency can be checked later.
        self.inferred_attribute_types[lvalue.def_var] = type
    self.store_type(lvalue, type)
def set_inference_error_fallback_type(self, var: Var, lvalue: Lvalue, type: Type) -> None:
    """Store the best known type for a variable whose type inference failed.

    When a program ignores the inference error, the variable should still get
    some inferred type so that it can be used later on. Example:

      x = []  # type: ignore
      x.append(1)   # Should be ok!

    This is done by giving x a valid type (replacing inferred <nothing> with Any).
    """
    self.set_inferred_type(var, lvalue, self.inference_error_fallback_type(type))
def inference_error_fallback_type(self, type: Type) -> Type:
    """Return `type` after a SetNothingToAny pass and type variable erasure."""
    without_nothing = type.accept(SetNothingToAny())
    # Inference can leak type variables (see
    # https://github.com/python/mypy/issues/5738), so erase them as well.
    return erase_typevars(without_nothing)
def simple_rvalue(self, rvalue: Expression) -> bool:
    """Returns True for expressions for which inferred type should not depend on context.

    A False result does not prove the opposite: some context-independent
    expressions are not recognized. The function exists purely as a
    performance optimization.
    """
    if isinstance(rvalue, (IntExpr, StrExpr, BytesExpr, FloatExpr, RefExpr)):
        return True
    if not isinstance(rvalue, CallExpr):
        return False

    callee = rvalue.callee
    if not (isinstance(callee, RefExpr) and isinstance(callee.node, FuncBase)):
        return False

    # A call to a function counts as simple when none of its signatures
    # declares type variables.
    signature = callee.node.type
    if isinstance(signature, CallableType):
        return not signature.variables
    if isinstance(signature, Overloaded):
        return all(not item.variables for item in signature.items)
    return False
def check_simple_assignment(
    self,
    lvalue_type: Type | None,
    rvalue: Expression,
    context: Context,
    msg: ErrorMessage = message_registry.INCOMPATIBLE_TYPES_IN_ASSIGNMENT,
    lvalue_name: str = "variable",
    rvalue_name: str = "expression",
    *,
    notes: list[str] | None = None,
) -> Type:
    """Type check assigning `rvalue` to a target of type `lvalue_type`.

    The rvalue is inferred using `lvalue_type` as type context and, when an
    lvalue type is given, checked to be a subtype of it; `msg`, `lvalue_name`,
    `rvalue_name` and `notes` customize the reported error. Returns the type
    inferred for the rvalue.
    """
    if self.is_stub and isinstance(rvalue, EllipsisExpr):
        # '...' is always a valid initializer in a stub.
        return AnyType(TypeOfAny.special_form)

    permit_any = lvalue_type is not None and not isinstance(
        get_proper_type(lvalue_type), AnyType
    )
    rvalue_type = self.expr_checker.accept(rvalue, lvalue_type, always_allow_any=permit_any)

    retry_without_context = (
        isinstance(get_proper_type(lvalue_type), UnionType)
        # Literal types have their own special logic (for better errors).
        and not isinstance(get_proper_type(rvalue_type), LiteralType)
        and not self.simple_rvalue(rvalue)
    )
    if retry_without_context:
        # Infer the r.h.s. once more in an empty context and prefer that
        # result if it is narrower. This is not done unconditionally because
        # of the possible perf impact and to partially preserve the old
        # behavior. It helps with various practical examples, see e.g.
        # testOptionalTypeNarrowedByGenericCall.
        with self.msg.filter_errors() as watcher, self.local_type_map() as scratch_types:
            contextless_type = self.expr_checker.accept(
                rvalue, None, always_allow_any=permit_any
            )
        is_narrower = (
            not watcher.has_new_errors()
            # Any is special cased in the binder, so leave it alone.
            and not isinstance(get_proper_type(contextless_type), AnyType)
            and is_valid_inferred_type(contextless_type)
            and is_proper_subtype(contextless_type, rvalue_type)
        )
        if is_narrower:
            rvalue_type = contextless_type
            self.store_types(scratch_types)

    if isinstance(rvalue_type, DeletedType):
        self.msg.deleted_as_rvalue(rvalue_type, context)

    if isinstance(lvalue_type, DeletedType):
        self.msg.deleted_as_lvalue(lvalue_type, context)
    elif lvalue_type:
        # The types are passed as given (not as proper types) to preserve
        # original aliases for error messages when possible.
        self.check_subtype(
            rvalue_type,
            lvalue_type,
            context,
            msg,
            f"{rvalue_name} has type",
            f"{lvalue_name} has type",
            notes=notes,
        )
    return rvalue_type
def check_member_assignment(
self, instance_type: Type, attribute_type: Type, rvalue: Expression, context: Context
) -> tuple[Type, Type, bool]:
"""Type member assignment.
This defers to check_simple_assignment, unless the member expression
is a descriptor, in which case this checks descriptor semantics as well.
Return the inferred rvalue_type, inferred lvalue_type, and whether to use the binder
for this assignment.
Note: this method exists here and not in checkmember.py, because we need to take
care about interaction between binder and __set__().
"""
# Normalize both types with get_proper_type() before the isinstance() checks below.
instance_type = get_proper_type(instance_type)
attribute_type = get_proper_type(attribute_type)
# Descriptors don't participate in class-attribute access
if (isinstance(instance_type, FunctionLike) and instance_type.is_type_obj()) or isinstance(
instance_type, TypeType
):
rvalue_type = self.check_simple_assignment(attribute_type, rvalue, context)
return rvalue_type, attribute_type, True
# Only Instance attribute types get the descriptor treatment; everything else
# is checked as a plain assignment.
if not isinstance(attribute_type, Instance):
# TODO: support __set__() for union types.
rvalue_type = self.check_simple_assignment(attribute_type, rvalue, context)
return rvalue_type, attribute_type, True
# Member-access context shared by analyze_descriptor_access() and
# analyze_decorator_or_funcbase_access() below.
mx = MemberContext(
is_lvalue=False,
is_super=False,
is_operator=False,
original_type=instance_type,
context=context,
self_type=None,
msg=self.msg,
chk=self,
)
# get_type is returned as the lvalue type by every remaining exit path.
get_type = analyze_descriptor_access(attribute_type, mx)
if not attribute_type.type.has_readable_member("__set__"):
# If there is no __set__, we type-check that the assigned value matches
# the return type of __get__. This doesn't match the python semantics,
# (which allow you to override the descriptor with any value), but preserves
# the type of accessing the attribute (even after the override).
rvalue_type = self.check_simple_assignment(get_type, rvalue, context)
return rvalue_type, get_type, True
# "__set__" is a readable member at this point; if get_method() still finds
# nothing it is reported as not callable (presumably it is not a method).
dunder_set = attribute_type.type.get_method("__set__")
if dunder_set is None:
self.fail(
message_registry.DESCRIPTOR_SET_NOT_CALLABLE.format(
attribute_type.str_with_options(self.options)
),
context,
)
return AnyType(TypeOfAny.from_error), get_type, False
# Analyze "__set__" as accessed on the descriptor type, then expand the result
# by the descriptor instance mapped to the class that defines "__set__".
bound_method = analyze_decorator_or_funcbase_access(
defn=dunder_set,
itype=attribute_type,
info=attribute_type.type,
self_type=attribute_type,
name="__set__",
mx=mx,
)
typ = map_instance_to_supertype(attribute_type, dunder_set.info)
dunder_set_type = expand_type_by_instance(bound_method, typ)
callable_name = self.expr_checker.method_fullname(attribute_type, "__set__")
# The call being modelled is __set__(instance, value): two positional arguments.
dunder_set_type = self.expr_checker.transform_callee_type(
callable_name,
dunder_set_type,
[TempNode(instance_type, context=context), rvalue],
[nodes.ARG_POS, nodes.ARG_POS],
context,
object_type=attribute_type,
)
# For non-overloaded setters, the result should be type-checked like a regular assignment.
# Hence, we first only try to infer the type by using the rvalue as type context.
type_context = rvalue
# Errors from this first call are filtered; it is only used to obtain the
# inferred "__set__" signature.
with self.msg.filter_errors():
_, inferred_dunder_set_type = self.expr_checker.check_call(
dunder_set_type,
[TempNode(instance_type, context=context), type_context],
[nodes.ARG_POS, nodes.ARG_POS],
context,
object_type=attribute_type,
callable_name=callable_name,
)
# And now we in fact type check the call, to show errors related to wrong arguments
# count, etc., replacing the type context for non-overloaded setters only.
inferred_dunder_set_type = get_proper_type(inferred_dunder_set_type)
if isinstance(inferred_dunder_set_type, CallableType):
type_context = TempNode(AnyType(TypeOfAny.special_form), context=context)
self.expr_checker.check_call(
dunder_set_type,
[TempNode(instance_type, context=context), type_context],
[nodes.ARG_POS, nodes.ARG_POS],
context,
object_type=attribute_type,
callable_name=callable_name,
)
# In the following cases, a message already will have been recorded in check_call.
if (not isinstance(inferred_dunder_set_type, CallableType)) or (
len(inferred_dunder_set_type.arg_types) < 2
):
return AnyType(TypeOfAny.from_error), get_type, False
# The second parameter of the inferred "__set__" signature is the accepted value type.
set_type = inferred_dunder_set_type.arg_types[1]
# Special case: if the rvalue_type is a subtype of both '__get__' and '__set__' types,
# and '__get__' type is narrower than '__set__', then we invoke the binder to narrow type
# by this assignment. Technically, this is not safe, but in practice this is
# what a user expects.
rvalue_type = self.check_simple_assignment(set_type, rvalue, context)
infer = is_subtype(rvalue_type, get_type) and is_subtype(get_type, set_type)
# When the binder is not used, the "__set__" value type is returned in place
# of the inferred rvalue type.
return rvalue_type if infer else set_type, get_type, infer
def check_indexed_assignment(
    self, lvalue: IndexExpr, rvalue: Expression, context: Context
) -> None:
    """Type check indexed assignment base[index] = rvalue.

    The lvalue argument is the base[index] expression. The assignment is
    checked as a call to the base's "__setitem__" with the index and the
    rvalue as arguments.
    """
    self.try_infer_partial_type_from_indexed_assignment(lvalue, rvalue)

    base_type = get_proper_type(self.expr_checker.accept(lvalue.base))
    setitem_type = self.expr_checker.analyze_external_member_access(
        "__setitem__", base_type, lvalue
    )
    lvalue.method_type = setitem_type

    call_args = [lvalue.index, rvalue]
    call_kinds = [nodes.ARG_POS, nodes.ARG_POS]
    call_result = self.expr_checker.check_method_call(
        "__setitem__", base_type, setitem_type, call_args, call_kinds, context
    )
    result_type = get_proper_type(call_result[0])
    # An unambiguous Uninhabited result marks the following code unreachable.
    if isinstance(result_type, UninhabitedType) and not result_type.ambiguous:
        self.binder.unreachable()
def try_infer_partial_type_from_indexed_assignment(
    self, lvalue: IndexExpr, rvalue: Expression
) -> None:
    """Complete a partial dict-like type from an assignment `base[index] = rvalue`.

    If the base refers to a variable whose type is still a partial dict,
    OrderedDict or defaultdict, and valid key and value types can be inferred,
    the variable gets the full generic type and stops being tracked as partial.
    """
    # TODO: Should we share some of this with try_infer_partial_type?
    base = lvalue.base
    var = None
    if isinstance(base, RefExpr) and isinstance(base.node, Var):
        var = base.node
    elif isinstance(base, MemberExpr):
        var = self.expr_checker.get_partial_self_var(base)

    if not isinstance(var, Var) or not isinstance(var.type, PartialType):
        return

    partial_info = var.type.type
    if partial_info is None:
        return  # The partial type is None.
    partial_types = self.find_partial_types(var)
    if partial_types is None:
        return

    typename = partial_info.fullname
    if typename not in (
        "builtins.dict",
        "collections.OrderedDict",
        "collections.defaultdict",
    ):
        return

    # TODO: Don't infer things twice.
    key_type = self.expr_checker.accept(lvalue.index)
    value_type = self.expr_checker.accept(rvalue)
    if not (is_valid_inferred_type(key_type) and is_valid_inferred_type(value_type)):
        return
    if self.current_node_deferred:
        return

    # A partial defaultdict may already carry a value type; the assigned
    # value has to be equivalent to it.
    conflicting_default_value = (
        typename == "collections.defaultdict"
        and var.type.value_type is not None
        and not is_equivalent(value_type, var.type.value_type)
    )
    if conflicting_default_value:
        return

    var.type = self.named_generic_type(typename, [key_type, value_type])
    del partial_types[var]
def type_requires_usage(self, typ: Type) -> tuple[str, ErrorCode] | None:
    """Tell whether a value of type `typ` must not be left unused.

    The classic example is an unused coroutine. When usage is required, the
    result is a note to attach to the error message together with the error
    code to report; otherwise the result is None.
    """
    proper = get_proper_type(typ)
    if not isinstance(proper, Instance):
        return None

    # Coroutines and generic awaitables use different error codes:
    # the coroutine check is on by default, the generic awaitable one is not.
    note = "Are you missing an await?"
    info = proper.type
    if info.fullname == "typing.Coroutine":
        return (note, UNUSED_COROUTINE)
    if info.get("__await__") is not None:
        return (note, UNUSED_AWAITABLE)
    return None
def visit_expression_stmt(self, s: ExpressionStmt) -> None:
    """Type check an expression statement and flag values that must be used."""
    expr_type = self.expr_checker.accept(s.expr, allow_none_return=True, always_allow_any=True)
    usage_requirement = self.type_requires_usage(expr_type)
    if not usage_requirement:
        return

    note_text, code = usage_requirement
    message = message_registry.TYPE_MUST_BE_USED.format(format_type(expr_type, self.options))
    self.fail(message, s, code=code)
    self.note(note_text, s, code=code)
def visit_return_stmt(self, s: ReturnStmt) -> None:
    """Type check a return statement."""
    self.check_return_stmt(s)
    # Whatever follows a return statement in the same block cannot be reached.
    self.binder.unreachable()
def check_return_stmt(self, s: ReturnStmt) -> None:
# Check the return statement s against the return type recorded for the
# function currently on top of the scope stack. The checks below read
# attributes of that function, so they only apply when there is one.
defn = self.scope.top_function()
if defn is not None:
# Pick the type a returned value is compared with: for generators and
# coroutines it is derived from self.return_types[-1] by dedicated helpers.
if defn.is_generator:
return_type = self.get_generator_return_type(
self.return_types[-1], defn.is_coroutine
)
elif defn.is_coroutine:
return_type = self.get_coroutine_return_type(self.return_types[-1])
else:
return_type = self.return_types[-1]
return_type = get_proper_type(return_type)
is_lambda = isinstance(self.scope.top_function(), LambdaExpr)
if isinstance(return_type, UninhabitedType):
# Avoid extra error messages for failed inference in lambdas
if not is_lambda or not return_type.ambiguous:
self.fail(message_registry.NO_RETURN_EXPECTED, s)
return
if s.expr:
declared_none_return = isinstance(return_type, NoneType)
declared_any_return = isinstance(return_type, AnyType)
# This controls whether or not we allow a function call that
# returns None as the expression of this return statement.
# E.g. `return f()` for some `f` that returns None. We allow
# this only if we're in a lambda or in a function that returns
# `None` or `Any`.
allow_none_func_call = is_lambda or declared_none_return or declared_any_return
# Return with a value.
# The returned expression is inferred with the expected return type as type context.
typ = get_proper_type(
self.expr_checker.accept(
s.expr, return_type, allow_none_return=allow_none_func_call
)
)
# Reported only after the expression itself has been type checked above.
if defn.is_async_generator:
self.fail(message_registry.RETURN_IN_ASYNC_GENERATOR, s)
return
# Returning a value of type Any is always fine.
if isinstance(typ, AnyType):
# (Unless you asked to be warned in that case, and the
# function is not declared to return Any)
if (
self.options.warn_return_any
and not self.current_node_deferred
and not is_proper_subtype(AnyType(TypeOfAny.special_form), return_type)
and not (
defn.name in BINARY_MAGIC_METHODS
and is_literal_not_implemented(s.expr)
)
and not (
isinstance(return_type, Instance)
and return_type.type.fullname == "builtins.object"
)
and not is_lambda
):
self.msg.incorrectly_returning_any(return_type, s)
return
# Disallow return expressions in functions declared to return
# None, subject to two exceptions below.
if declared_none_return:
# Lambdas are allowed to have None returns.
# Functions returning a value of type None are allowed to have a None return.
if is_lambda or isinstance(typ, NoneType):
return
self.fail(message_registry.NO_RETURN_VALUE_EXPECTED, s)
else:
# General case: the returned value's type must be a subtype of the expected type.
self.check_subtype(
subtype_label="got",
subtype=typ,
supertype_label="expected",
supertype=return_type,
context=s.expr,
outer_context=s,
msg=message_registry.INCOMPATIBLE_RETURN_VALUE_TYPE,
)
else:
# Empty returns are valid in Generators with Any typed returns, but not in
# coroutines.
if (
defn.is_generator
and not defn.is_coroutine
and isinstance(return_type, AnyType)
):
return
# A bare "return" is accepted when the expected type is None or Any.
if isinstance(return_type, (NoneType, AnyType)):
return
# The missing value is only reported when self.in_checked_function() holds.
if self.in_checked_function():
self.fail(message_registry.RETURN_VALUE_EXPECTED, s)
def visit_if_stmt(self, s: IfStmt) -> None:
    """Type check an if statement."""
    # The outer frame accumulates what is known from earlier if/elif
    # conditions having been false. Fall-through to the original frame is
    # handled explicitly in each block.
    with self.binder.frame_context(can_skip=False, conditional_frame=True, fall_through=0):
        for condition, branch in zip(s.expr, s.body):
            condition_type = get_proper_type(self.expr_checker.accept(condition))
            if isinstance(condition_type, DeletedType):
                self.msg.deleted_as_rvalue(condition_type, s)

            taken_map, skipped_map = self.find_isinstance_check(condition)

            # XXX Issue a warning if condition is always False?
            with self.binder.frame_context(can_skip=True, fall_through=2):
                self.push_type_map(taken_map)
                self.accept(branch)

            # XXX Issue a warning if condition is always True?
            self.push_type_map(skipped_map)

        with self.binder.frame_context(can_skip=False, fall_through=2):
            if s.else_body:
                self.accept(s.else_body)
def visit_while_stmt(self, s: WhileStmt) -> None:
    """Type check a while statement."""
    # The loop is checked as a synthetic `if` statement (condition plus body,
    # no else branch) that is handed to the generic loop machinery.
    loop_as_if = IfStmt([s.expr], [s.body], None)
    loop_as_if.set_line(s)
    self.accept_loop(loop_as_if, s.else_body, exit_condition=s.expr)
def visit_operator_assignment_stmt(self, s: OperatorAssignmentStmt) -> None:
    """Type check an operator assignment statement, e.g. x += 1."""
    target = s.lvalue
    self.try_infer_partial_generic_type_from_assignment(target, s.rvalue, s.op)

    if isinstance(target, MemberExpr):
        # Going through visit_member_expr allows additional errors for
        # assignments to read-only or final attributes.
        target_type = self.expr_checker.visit_member_expr(target, True)
    else:
        target_type = self.expr_checker.accept(target)

    inplace, method = infer_operator_assignment_method(target_type, s.op)
    if not inplace:
        # There is no __ifoo__, treat as x = x <foo> y
        binary_expr = OpExpr(s.op, target, s.rvalue)
        binary_expr.set_line(s)
        self.check_assignment(
            lvalue=target, rvalue=binary_expr, infer_lvalue_type=True, new_syntax=False
        )
    else:
        # There is __ifoo__, treat as x = x.__ifoo__(y)
        result_type, _ = self.expr_checker.check_op(method, target_type, s.rvalue, s)
        if not is_subtype(result_type, target_type):
            self.msg.incompatible_operator_assignment(s.op, s)

    self.check_final(s)
def visit_assert_stmt(self, s: AssertStmt) -> None:
    """Type check an assert statement and apply the narrowing it implies."""
    condition = s.expr
    self.expr_checker.accept(condition)

    # Asserting a non-empty tuple expression is reported as malformed.
    if isinstance(condition, TupleExpr) and len(condition.items) > 0:
        self.fail(message_registry.MALFORMED_ASSERT, s)

    # When the assertion is an isinstance-style check, the narrowed types
    # apply to the code that follows; the message expression is analyzed
    # under the types of the failing case.
    passed_map, failed_map = self.find_isinstance_check(condition)
    if s.msg is not None:
        self.expr_checker.analyze_cond_branch(failed_map, s.msg, None)
    self.push_type_map(passed_map)
def visit_raise_stmt(self, s: RaiseStmt) -> None:
    """Type check a raise statement."""
    raised = s.expr
    if raised:
        self.type_check_raise(raised, s)
    # The `from` part is checked as optional (see type_check_raise).
    cause = s.from_expr
    if cause:
        self.type_check_raise(cause, s, optional=True)
    self.binder.unreachable()
def type_check_raise(self, e: Expression, s: RaiseStmt, optional: bool = False) -> None:
    """Check that expression `e` is acceptable in the raise statement `s`.

    The expression has to be a BaseException instance or class; with
    `optional` set, None is accepted too.
    """
    raised_type = get_proper_type(self.expr_checker.accept(e))
    if isinstance(raised_type, DeletedType):
        self.msg.deleted_as_rvalue(raised_type, e)
        return

    base_exception = self.named_type("builtins.BaseException")
    # `optional` covers the `x` part of `raise e from x`, where
    # `raise e from None` is allowed.
    accepted = [
        base_exception,
        TypeType(base_exception),
        *([NoneType()] if optional else []),
    ]
    self.check_subtype(
        raised_type, UnionType.make_union(accepted), s, message_registry.INVALID_EXCEPTION
    )

    if isinstance(raised_type, FunctionLike):
        # https://github.com/python/mypy/issues/11089
        self.expr_checker.check_call(raised_type, [], [], e)
def visit_try_stmt(self, s: TryStmt) -> None:
    """Type check a try statement."""
    # The enclosing frame receives the result when the try/except falls
    # through. The frame opened here collects every state in which the try
    # block may have exited abnormally (by exception, return, break, etc.)
    with self.binder.frame_context(can_skip=False, fall_through=0):
        # Exception handlers and the else clause can exit abnormally as well
        # as the try body, and the finally clause runs in *all* cases. An
        # outer try frame is therefore needed to catch the intermediate
        # states of an exception raised in an except or else clause. As an
        # optimization it is only created when a finally clause exists.
        self.visit_try_without_finally(s, try_frame=bool(s.finally_body))
        if s.finally_body:
            # First pass: the finally body must be type safe on all
            # abnormal exit paths.
            self.accept(s.finally_body)

    # Second pass, for the more restricted set of states that can fall
    # through. The finally clause is checked twice because it can be reached
    # either by the try clause falling off the end or by an abnormal exit;
    # afterwards control flow either continues after the whole try statement
    # or the exception/return/etc. is processed and control flow escapes.
    # The clause has to type check in both contexts, but only the types
    # resulting from the fall-through context affect the type state of the
    # code following the try statement.
    if s.finally_body and not self.binder.is_unreachable():
        self.accept(s.finally_body)
def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
"""Type check a try statement, ignoring the finally block.
On entry, the top frame should receive all flow that exits the
try block abnormally (i.e., such that the else block does not
execute), and its parent should receive all flow that exits
the try block normally.
"""
# This frame will run the else block if the try fell through.
# In that case, control flow continues to the parent of what
# was the top frame on entry.
with self.binder.frame_context(can_skip=False, fall_through=2, try_frame=try_frame):
# This frame receives exit via exception, and runs exception handlers
with self.binder.frame_context(can_skip=False, conditional_frame=True, fall_through=2):
# Finally, the body of the try statement
with self.binder.frame_context(can_skip=False, fall_through=2, try_frame=True):
self.accept(s.body)
# s.handlers, s.types and s.vars are indexed in parallel: entry i holds the
# handler body, its exception type expression (falsy when absent) and the
# variable bound by the handler (falsy when absent).
for i in range(len(s.handlers)):
with self.binder.frame_context(can_skip=True, fall_through=4):
typ = s.types[i]
if typ:
t = self.check_except_handler_test(typ, s.is_star)
var = s.vars[i]
if var:
# To support local variables, we make this a definition line,
# causing assignment to set the variable's type.
var.is_inferred_def = True
self.check_assignment(var, self.temp_node(t, var))
self.accept(s.handlers[i])
# Read the variable again: the lookup above only happens for handlers
# that have an exception type expression.
var = s.vars[i]
if var:
# Exception variables are deleted.
# Unfortunately, this doesn't let us detect usage before the
# try/except block.
source = var.name
if isinstance(var.node, Var):
var.node.type = DeletedType(source=source)
self.binder.cleanse(var)
if s.else_body:
self.accept(s.else_body)
def check_except_handler_test(self, n: Expression, is_star: bool) -> Type:
    """Type check an exception handler test clause."""

    def reject() -> Type:
        # Shared failure path: report the error and fall back to the default type.
        self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
        return self.default_exception_type(is_star)

    handler_type = self.expr_checker.accept(n)
    caught: list[Type] = []
    for candidate in get_proper_types(self.get_types_from_except_handler(handler_type, n)):
        if isinstance(candidate, AnyType):
            caught.append(candidate)
            continue

        if isinstance(candidate, FunctionLike):
            first_item = candidate.items[0]
            if not first_item.is_type_obj():
                return reject()
            exc_type = erase_typevars(first_item.ret_type)
        elif isinstance(candidate, TypeType):
            exc_type = candidate.item
        else:
            return reject()

        if not is_subtype(exc_type, self.named_type("builtins.BaseException")):
            return reject()
        caught.append(exc_type)

    if not is_star:
        return make_simplified_union(caught)

    # except*: an exception group type is reported as invalid and replaced by
    # Any before the caught types are wrapped in an exception group.
    group_items: list[Type] = []
    for exc in caught:
        if is_proper_subtype(exc, self.named_type("builtins.BaseExceptionGroup")):
            self.fail(message_registry.INVALID_EXCEPTION_GROUP, n)
            group_items.append(AnyType(TypeOfAny.from_error))
        else:
            group_items.append(exc)
    return self.wrap_exception_group(group_items)
def default_exception_type(self, is_star: bool) -> Type:
    """Exception type to return in case of a previous type error."""
    error_any = AnyType(TypeOfAny.from_error)
    if not is_star:
        return error_any
    # For except*, the Any is wrapped in an ExceptionGroup.
    return self.named_generic_type("builtins.ExceptionGroup", [error_any])
def wrap_exception_group(self, types: Sequence[Type]) -> Type:
    """Transform except* variable type into an appropriate exception group."""
    arg = make_simplified_union(types)
    # ExceptionGroup is used only when everything caught is an Exception.
    base = (
        "builtins.ExceptionGroup"
        if is_subtype(arg, self.named_type("builtins.Exception"))
        else "builtins.BaseExceptionGroup"
    )
    return self.named_generic_type(base, [arg])
def get_types_from_except_handler(self, typ: Type, n: Expression) -> list[Type]:
    """Helper for check_except_handler_test to retrieve handler types."""
    typ = get_proper_type(typ)
    if isinstance(typ, TupleType):
        return typ.items
    if isinstance(typ, UnionType):
        # Flatten the handler types of every relevant union item.
        collected: list[Type] = []
        for member in typ.relevant_items():
            collected.extend(self.get_types_from_except_handler(member, n))
        return collected
    if is_named_instance(typ, "builtins.tuple"):
        # variadic tuple
        return [typ.args[0]]
    return [typ]
def visit_for_stmt(self, s: ForStmt) -> None:
    """Type check a for statement."""
    analyze = (
        self.analyze_async_iterable_item_type if s.is_async else self.analyze_iterable_item_type
    )
    iterator_type, item_type = analyze(s.expr)
    # Record the inferred types on the statement node.
    s.inferred_item_type = item_type
    s.inferred_iterator_type = iterator_type
    # The index variables are inferred only when no index type was given.
    self.analyze_index_variables(s.index, item_type, s.index_type is None, s)
    self.accept_loop(s.body, s.else_body)
def analyze_async_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]:
    """Analyse async iterable expression and return iterator and iterator item types."""
    checker = self.expr_checker
    iterable_type = checker.accept(expr)
    # `async for` goes through __aiter__() and then awaits __anext__().
    iterator_type = checker.check_method_call_by_name(
        "__aiter__", iterable_type, [], [], expr
    )[0]
    awaitable_type = checker.check_method_call_by_name(
        "__anext__", iterator_type, [], [], expr
    )[0]
    item_type = checker.check_awaitable_expr(
        awaitable_type, expr, message_registry.INCOMPATIBLE_TYPES_IN_ASYNC_FOR
    )
    return iterator_type, item_type
def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]:
    """Analyse iterable expression and return iterator and iterator item types."""
    checker = self.expr_checker
    iterable_type = get_proper_type(checker.accept(expr))
    iterator_type = checker.check_method_call_by_name(
        "__iter__", iterable_type, [], [], expr
    )[0]

    # range(...) over native ints yields that native int type.
    native_int_type = self.analyze_range_native_int_type(expr)
    if native_int_type:
        return iterator_type, native_int_type

    is_plain_tuple = (
        isinstance(iterable_type, TupleType)
        and iterable_type.partial_fallback.type.fullname == "builtins.tuple"
    )
    if not is_plain_tuple:
        # Non-tuple iterable.
        item_type = checker.check_method_call_by_name(
            "__next__", iterator_type, [], [], expr
        )[0]
        return iterator_type, item_type

    # Plain tuple: the item type is the join of all the tuple's item types.
    joined: Type = UninhabitedType()
    for tuple_item in iterable_type.items:
        joined = join_types(joined, tuple_item)
    return iterator_type, joined
def analyze_iterable_item_type_without_expression(
    self, type: Type, context: Context
) -> tuple[Type, Type]:
    """Analyse iterable type and return iterator and iterator item types.

    This is the counterpart of analyze_iterable_item_type for callers that
    only have a type; `context` is used for error reporting.

    Only tuple types whose fallback is plain ``builtins.tuple`` take the
    shortcut of joining the item types. Any other tuple type (a tuple
    subclass or named tuple, which may define its own ``__iter__``) is
    handled like every other iterable, through ``__next__`` of the iterator
    returned by ``__iter__``. This matches analyze_iterable_item_type.
    """
    echk = self.expr_checker
    iterable = get_proper_type(type)
    iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], context)[0]

    if (
        isinstance(iterable, TupleType)
        and iterable.partial_fallback.type.fullname == "builtins.tuple"
    ):
        joined: Type = UninhabitedType()
        for item in iterable.items:
            joined = join_types(joined, item)
        return iterator, joined

    # Non-tuple iterable, or a tuple type with a non-builtin fallback.
    return (
        iterator,
        echk.check_method_call_by_name("__next__", iterator, [], [], context)[0],
    )
def analyze_range_native_int_type(self, expr: Expression) -> Type | None:
    """Try to infer native int item type from arguments to range(...).

    For example, return i64 if the expression is "range(0, i64(n))".

    Return None if unsuccessful.
    """
    is_positional_range_call = (
        isinstance(expr, CallExpr)
        and isinstance(expr.callee, RefExpr)
        and expr.callee.fullname == "builtins.range"
        and 1 <= len(expr.args) <= 3
        and all(kind == ARG_POS for kind in expr.arg_kinds)
    )
    if not is_positional_range_call:
        return None

    found: Type | None = None
    consistent = True
    for arg in expr.args:
        arg_type = get_proper_type(self.lookup_type(arg))
        if not (
            isinstance(arg_type, Instance)
            and arg_type.type.fullname in MYPYC_NATIVE_INT_NAMES
        ):
            continue
        if found is None:
            found = arg_type
        elif arg_type != found:
            # Arguments with different native int types give no usable answer.
            consistent = False
    return found if consistent and found else None
def analyze_container_item_type(self, typ: Type) -> Type | None:
    """Check if a type is a nominal container or a union of such.

    Return the corresponding container item type.
    """
    typ = get_proper_type(typ)
    if isinstance(typ, UnionType):
        # Combine the item types of those union members that are containers.
        member_item_types: list[Type] = [
            found
            for found in (self.analyze_container_item_type(member) for member in typ.items)
            if found
        ]
        return UnionType.make_union(member_item_types)
    if isinstance(typ, Instance) and typ.type.has_base("typing.Container"):
        container_info = self.named_type("typing.Container").type
        as_container = map_instance_to_supertype(typ, container_info)
        assert len(as_container.args) == 1
        return as_container.args[0]
    if isinstance(typ, TupleType):
        # Tuples are analyzed through their fallback instance.
        return self.analyze_container_item_type(tuple_fallback(typ))
    return None
def analyze_index_variables(
    self, index: Expression, item_type: Type, infer_lvalue_type: bool, context: Context
) -> None:
    """Type check or infer for loop or list comprehension index vars."""
    # The index is checked as an assignment from a temporary node that
    # carries the item type.
    item_node = self.temp_node(item_type, context)
    self.check_assignment(index, item_node, infer_lvalue_type)
def visit_del_stmt(self, s: DelStmt) -> None:
    """Type check a del statement."""
    target = s.expr
    if isinstance(target, IndexExpr):
        # `del base[index]` is checked as the call `base.__delitem__(index)`,
        # built from synthetic nodes positioned at the statement.
        method_ref = MemberExpr(target.base, "__delitem__")
        method_ref.line = s.line
        method_ref.column = s.column
        call = CallExpr(method_ref, [target.index], [nodes.ARG_POS], [None])
        call.line = s.line
        call.column = s.column
        self.expr_checker.accept(call, allow_none_return=True)
        return

    target.accept(self.expr_checker)
    # Every deleted name gets a DeletedType recorded in the binder.
    for element in flatten(target):
        if not isinstance(element, NameExpr):
            continue
        self.binder.assign_type(
            element, DeletedType(source=element.name), get_declaration(element), False
        )
def visit_decorator(self, e: Decorator) -> None:
    """Type check a decorated function unless it is marked @no_type_check."""
    opted_out = any(
        isinstance(decorator, RefExpr) and decorator.fullname == "typing.no_type_check"
        for decorator in e.decorators
    )
    if opted_out:
        # The function is not analyzed; its variable simply becomes Any.
        e.var.type = AnyType(TypeOfAny.special_form)
        e.var.is_ready = True
        return
    self.visit_decorator_inner(e)
def visit_decorator_inner(self, e: Decorator, allow_empty: bool = False) -> None:
    """Type check a decorated function and infer its decorated type.

    The function itself is checked first (when recursing into functions).
    Each decorator is then applied as a call, innermost first, and the
    resulting type is stored on ``e.var``. Property checks follow. When
    ``allow_empty`` is set, the error about a single overload is suppressed
    and the override and constructor checks at the end are skipped.
    """
    func = e.func
    if self.recurse_into_functions:
        with self.tscope.function_scope(func):
            self.check_func_item(func, name=func.name, allow_empty=allow_empty)

    # Apply the decorators from the innermost one outwards; the resulting
    # signature may differ from the declared one.
    sig: Type = self.function_type(func)
    for decorator in reversed(e.decorators):
        if refers_to_fullname(decorator, OVERLOAD_NAMES):
            if not allow_empty:
                self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, e)
            continue
        decorator_type = self.expr_checker.accept(decorator)
        sig_node = self.temp_node(sig, context=e)
        callable_name = (decorator.fullname or None) if isinstance(decorator, RefExpr) else None
        # For an expression like @b.a where b is an object, look up the type of b
        # so that it can be passed to the method hook in the plugins.
        receiver_type: Type | None = None
        if (
            callable_name is None
            and isinstance(decorator, MemberExpr)
            and self.has_type(decorator.expr)
        ):
            receiver_type = self.lookup_type(decorator.expr)
            callable_name = self.expr_checker.method_fullname(receiver_type, decorator.name)
        self.check_for_untyped_decorator(func, decorator_type, decorator)
        sig, _ = self.expr_checker.check_call(
            decorator_type,
            [sig_node],
            [nodes.ARG_POS],
            e,
            callable_name=callable_name,
            object_type=receiver_type,
        )
    self.check_untyped_after_decorator(sig, func)
    sig = set_callable_name(sig, func)
    e.var.type = sig
    e.var.is_ready = True

    if func.is_property:
        if isinstance(sig, CallableType):
            required_count = sum(1 for kind in sig.arg_kinds if kind.is_required())
            if required_count > 1:
                self.msg.fail("Too many arguments for property", e)
        self.check_incompatible_property_override(e)

    # For overloaded functions the override was already checked for the overload as a whole.
    if allow_empty:
        return

    if func.info and not func.is_dynamic() and not e.is_overload:
        overridden_bases = self.check_method_override(e)
        # An empty (but not None) result with an explicit override marker is an error.
        nothing_overridden = overridden_bases is not None and not overridden_bases
        if func.is_explicit_override and nothing_overridden:
            self.msg.no_overridable_method(func.name, func)
        self.check_explicit_override_decorator(func, overridden_bases)

    is_constructor = func.info and func.name in ("__init__", "__new__")
    if is_constructor and e.type:
        if not isinstance(get_proper_type(e.type), (FunctionLike, AnyType)):
            self.fail(message_registry.BAD_CONSTRUCTOR_TYPE, e)
def check_for_untyped_decorator(
    self, func: FuncDef, dec_type: Type, dec_expr: Expression
) -> None:
    """Report a typed function that is wrapped by an untyped decorator.

    Only active when the ``disallow_untyped_decorators`` option is set.
    """
    if not self.options.disallow_untyped_decorators:
        return
    if is_typed_callable(func.type) and is_untyped_decorator(dec_type):
        self.msg.typed_function_untyped_decorator(func.name, dec_expr)
def check_incompatible_property_override(self, e: Decorator) -> None:
    """Report a read-only property that overrides a read-write property.

    Nothing is checked when the property is settable or the function has no
    class info. Otherwise every class after the first in the MRO is inspected
    for an overloaded property of the same name whose first item is settable.
    """
    if e.var.is_settable_property or not e.func.info:
        return
    prop_name = e.func.name
    for ancestor in e.func.info.mro[1:]:
        inherited = ancestor.names.get(prop_name)
        if not inherited:
            continue
        base_node = inherited.node
        if not (isinstance(base_node, OverloadedFuncDef) and base_node.is_property):
            continue
        first_item = cast(Decorator, base_node.items[0])
        if first_item.var.is_settable_property:
            self.fail(message_registry.READ_ONLY_PROPERTY_OVERRIDES_READ_WRITE, e)
def visit_with_stmt(self, s: WithStmt) -> None:
    """Type check a (possibly async) ``with`` statement and its body.

    Every context manager item is checked, and the return type of its exit
    method is used to decide whether the body must be treated like the body
    of a ``try`` block whose exceptions may be suppressed.
    """
    may_swallow_exceptions = False
    for item_expr, item_target in zip(s.expr, s.target):
        check_item = self.check_async_with_item if s.is_async else self.check_with_item
        raw_exit_type = check_item(item_expr, item_target, s.unanalyzed_type is None)

        # Based on the return type, determine if this context manager 'swallows'
        # exceptions or not. We determine this using a heuristic based on the
        # return type of the __exit__ method -- see the discussion in
        # https://github.com/python/mypy/issues/7214 and the section about context managers
        # in https://github.com/python/typeshed/blob/main/CONTRIBUTING.md#conventions
        # for more details.
        exit_type = get_proper_type(raw_exit_type)
        if is_literal_type(exit_type, "builtins.bool", False):
            continue

        if is_literal_type(exit_type, "builtins.bool", True):
            may_swallow_exceptions = True
        elif (
            isinstance(exit_type, Instance)
            and exit_type.type.fullname == "builtins.bool"
            and state.strict_optional
        ):
            # Note: if strict-optional is disabled, this bool instance
            # could actually be an Optional[bool].
            may_swallow_exceptions = True

    # When exceptions may be suppressed, treat this 'with' block in the same way
    # we'd treat a 'try: BODY; except: pass' block. This means control flow can
    # continue after the 'with' even if the 'with' block immediately returns.
    body_frame = (
        self.binder.frame_context(can_skip=True, try_frame=True)
        if may_swallow_exceptions
        else nullcontext()
    )
    with body_frame:
        self.accept(s.body)
def check_untyped_after_decorator(self, typ: Type, func: FuncDef) -> None:
    """Report a decorated function whose resulting type contains ``Any``.

    Only active when the ``disallow_any_decorated`` option is set and the
    file being checked is not a stub.
    """
    enabled = self.options.disallow_any_decorated and not self.is_stub
    if enabled and mypy.checkexpr.has_any_type(typ):
        self.msg.untyped_decorated_function(typ, func)
def check_async_with_item(
    self, expr: Expression, target: Expression | None, infer_lvalue_type: bool
) -> Type:
    """Type check one item of an ``async with`` statement.

    The awaited result of ``__aenter__`` is assigned to ``target`` (if any).
    ``__aexit__`` is then checked with three ``Any`` arguments and the type of
    its awaited result is returned.
    """
    checker = self.expr_checker
    manager_type = checker.accept(expr)
    entered, _ = checker.check_method_call_by_name("__aenter__", manager_type, [], [], expr)
    entered = checker.check_awaitable_expr(
        entered, expr, message_registry.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AENTER
    )
    if target:
        self.check_assignment(target, self.temp_node(entered, expr), infer_lvalue_type)
    exc_arg = self.temp_node(AnyType(TypeOfAny.special_form), expr)
    exit_result, _ = checker.check_method_call_by_name(
        "__aexit__", manager_type, [exc_arg, exc_arg, exc_arg], [nodes.ARG_POS] * 3, expr
    )
    return checker.check_awaitable_expr(
        exit_result, expr, message_registry.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT
    )
def check_with_item(
    self, expr: Expression, target: Expression | None, infer_lvalue_type: bool
) -> Type:
    """Type check one item of a ``with`` statement.

    The result of ``__enter__`` is assigned to ``target`` (if any).
    ``__exit__`` is then checked with three ``Any`` arguments and its return
    type is returned.
    """
    checker = self.expr_checker
    manager_type = checker.accept(expr)
    entered, _ = checker.check_method_call_by_name("__enter__", manager_type, [], [], expr)
    if target:
        self.check_assignment(target, self.temp_node(entered, expr), infer_lvalue_type)
    exc_arg = self.temp_node(AnyType(TypeOfAny.special_form), expr)
    exit_result, _ = checker.check_method_call_by_name(
        "__exit__", manager_type, [exc_arg, exc_arg, exc_arg], [nodes.ARG_POS] * 3, expr
    )
    return exit_result
def visit_break_stmt(self, s: BreakStmt) -> None:
    """Record a ``break`` statement with the binder."""
    binder = self.binder
    binder.handle_break()
def visit_continue_stmt(self, s: ContinueStmt) -> None:
    """Record a ``continue`` statement with the binder."""
    binder = self.binder
    binder.handle_continue()
def visit_match_stmt(self, s: MatchStmt) -> None:
# Type check a match statement: check the subject, infer the types of
# capture variables from all patterns, then check every case body with the
# type maps produced by its pattern (and guard, if present) pushed.
with self.binder.frame_context(can_skip=False, fall_through=0):
subject_type = get_proper_type(self.expr_checker.accept(s.subject))
# A subject whose type is DeletedType is reported as a use of a deleted value.
if isinstance(subject_type, DeletedType):
self.msg.deleted_as_rvalue(subject_type, s)
# We infer types of patterns twice. The first pass is used
# to infer the types of capture variables. The type of a
# capture variable may depend on multiple patterns (it
# will be a union of all capture types). This pass ignores
# guard expressions.
pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns]
type_maps: list[TypeMap] = [t.captures for t in pattern_types]
inferred_types = self.infer_variable_types_from_type_maps(type_maps)
# The second pass narrows down the types and type checks bodies.
for p, g, b in zip(s.patterns, s.guards, s.bodies):
# The subject type is looked up through the binder again for every case;
# presumably this picks up narrowing from else maps pushed after earlier
# cases -- NOTE(review): confirm against narrow_type_from_binder.
current_subject_type = self.expr_checker.narrow_type_from_binder(
s.subject, subject_type
)
pattern_type = self.pattern_checker.accept(p, current_subject_type)
with self.binder.frame_context(can_skip=True, fall_through=2):
# For an unreachable body or an uninhabited pattern type, push a None
# type map and start from an empty else map.
if b.is_unreachable or isinstance(
get_proper_type(pattern_type.type), UninhabitedType
):
self.push_type_map(None)
else_map: TypeMap = {}
else:
pattern_map, else_map = conditional_types_to_typemaps(
s.subject, pattern_type.type, pattern_type.rest_type
)
# Drop captures that conflict with the types inferred in the first pass
# before pushing the capture map.
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
self.push_type_map(pattern_map)
self.push_type_map(pattern_type.captures)
if g is not None:
with self.binder.frame_context(can_skip=False, fall_through=3):
gt = get_proper_type(self.expr_checker.accept(g))
if isinstance(gt, DeletedType):
self.msg.deleted_as_rvalue(gt, s)
guard_map, guard_else_map = self.find_isinstance_check(g)
# Combine the pattern's else map with the guard's else map.
else_map = or_conditional_maps(else_map, guard_else_map)
# If the guard narrowed the subject, copy the narrowed types over
if isinstance(p, AsPattern):
case_target = p.pattern or p.name
if isinstance(case_target, NameExpr):
for type_map in (guard_map, else_map):
if not type_map:
continue
# Iterate over a snapshot of the keys because the map is
# modified inside the loop.
for expr in list(type_map):
if not (
isinstance(expr, NameExpr)
and expr.fullname == case_target.fullname
):
continue
type_map[s.subject] = type_map[expr]
self.push_type_map(guard_map)
self.accept(b)
else:
self.accept(b)
# The else map is pushed after the frame of this case has been closed.
self.push_type_map(else_map)
# This is needed due to a quirk in frame_context. Without it types will stay narrowed
# after the match.
with self.binder.frame_context(can_skip=False, fall_through=2):
pass
def infer_variable_types_from_type_maps(self, type_maps: list[TypeMap]) -> dict[Var, Type]:
    """Infer the types of capture variables from a list of capture type maps.

    Captures are grouped by variable. A variable that already has a type keeps
    it, provided the captured type is a subtype of it (an incompatible capture
    is reported through ``check_subtype``). A variable without a previous type
    is inferred, at its first occurrence, as the union of all captured types.
    ``None`` type maps are ignored.
    """
    captures_by_var: dict[Var, list[tuple[NameExpr, Type]]] = defaultdict(list)
    for type_map in type_maps:
        if type_map is None:
            continue
        for expr, typ in type_map.items():
            if not isinstance(expr, NameExpr):
                continue
            var_node = expr.node
            assert isinstance(var_node, Var)
            captures_by_var[var_node].append((expr, typ))

    inferred: dict[Var, Type] = {}
    for var, captures in captures_by_var.items():
        declared_before = False
        captured_types: list[Type] = []
        for expr, typ in captures:
            captured_types.append(typ)
            previous_type, _, _ = self.check_lvalue(expr)
            if previous_type is None:
                continue
            declared_before = True
            compatible = self.check_subtype(
                typ,
                previous_type,
                expr,
                msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE,
                subtype_label="pattern captures type",
                supertype_label="variable has type",
            )
            if compatible:
                inferred[var] = previous_type
        if declared_before:
            continue

        # Infer the union type at the first occurrence
        union_type = UnionType.make_union(captured_types)
        first_occurrence = captures[0][0]
        inferred[var] = union_type
        self.infer_variable_type(var, first_occurrence, union_type, first_occurrence)
    return inferred
def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: dict[Var, Type]) -> None:
    """Remove captures that disagree with the inferred variable types.

    An entry keyed by a name expression is deleted from ``type_map`` in place
    when its variable has no inferred type or its captured type is not a
    subtype of the inferred one. An empty or ``None`` map is left untouched.
    """
    if not type_map:
        return
    # Iterate over a snapshot because entries are deleted during the loop.
    for expr, typ in tuple(type_map.items()):
        if not isinstance(expr, NameExpr):
            continue
        var_node = expr.node
        assert isinstance(var_node, Var)
        compatible = var_node in inferred_types and is_subtype(typ, inferred_types[var_node])
        if not compatible:
            del type_map[expr]
def make_fake_typeinfo(
    self,
    curr_module_fullname: str,
    class_gen_name: str,
    class_short_name: str,
    bases: list[Instance],
) -> tuple[ClassDef, TypeInfo]:
    """Build a fake ClassDef and TypeInfo together.

    The ClassDef is named ``class_short_name`` and has an empty body; its full
    name is ``curr_module_fullname`` joined with ``class_gen_name``. The
    TypeInfo gets ``bases`` as its bases, with its MRO and metaclass type
    computed from them.
    """
    # The ClassDef is full of lies and doesn't actually contain a body.
    # Use format_bare to generate a nice name for error messages.
    class_def = ClassDef(class_short_name, Block([]))
    class_def.fullname = f"{curr_module_fullname}.{class_gen_name}"

    # We skip fully filling out a handful of TypeInfo fields because they
    # should be irrelevant for a generated type like this:
    # is_protocol, protocol_members, is_abstract
    type_info = TypeInfo(SymbolTable(), class_def, curr_module_fullname)
    class_def.info = type_info
    type_info.bases = bases
    calculate_mro(type_info)
    type_info.metaclass_type = type_info.calculate_metaclass_type()
    return class_def, type_info
def intersect_instances(
    self, instances: tuple[Instance, Instance], errors: list[tuple[str, str]]
) -> Instance | None:
    """Try creating an ad-hoc intersection of the given instances.

    Note that this function does *not* try to create a full-fledged
    intersection type. Instead, it returns an instance of a new ad-hoc
    subclass of the given instances.

    This is mainly useful when you need a way of representing some
    theoretical subclass of the instances the user may be trying to use:
    the generated intersection can serve as a placeholder.

    This function will create a fresh subclass every time you call it,
    even if you pass in the exact same arguments. So calling
    `self.intersect_instances((inst_1, inst_2), errors)` twice will result
    in instances of two distinct subclasses of inst_1 and inst_2.

    This is by design: we want each ad-hoc intersection to be unique since
    they're supposed to represent some other unknown subclass.

    Returns None if creating the subclass is impossible (e.g. due to
    MRO errors or incompatible signatures); in that case a
    (pretty names, reason) pair is appended to ``errors``. If we do
    successfully create a subclass, its TypeInfo will automatically be
    added to the global scope.
    """
    module = self.scope.stack[0]
    assert isinstance(module, MypyFile)

    # First, retry narrowing while allowing promotions (they are disabled by default
    # for isinstance() checks, etc). This way we will still type-check branches like
    #     x: complex = 1
    #     if isinstance(x, int):
    #         ...
    first, second = instances
    if is_proper_subtype(first, second, ignore_promotions=False):
        return first
    if is_proper_subtype(second, first, ignore_promotions=False):
        return second

    def _get_base_classes(instances_: tuple[Instance, Instance]) -> list[Instance]:
        # An existing intersection contributes its own bases instead of itself.
        flattened: list[Instance] = []
        for inst in instances_:
            flattened.extend(inst.type.bases if inst.type.is_intersection else [inst])
        return flattened

    def _make_fake_typeinfo_and_full_name(
        base_classes_: list[Instance], curr_module_: MypyFile
    ) -> tuple[TypeInfo, str]:
        joined_names = pretty_seq([base.type.name for base in base_classes_], "and")
        short_name = f"<subclass of {joined_names}>"
        unique_name = gen_unique_name(short_name, curr_module_.names)
        _, fake_info = self.make_fake_typeinfo(
            curr_module_.fullname, unique_name, short_name, base_classes_
        )
        return fake_info, unique_name

    base_classes = _get_base_classes(instances)
    # We use the pretty_names_list for error messages but can't
    # use it for the real name that goes into the symbol table
    # because it can have dots in it.
    pretty_names_list = pretty_seq(
        format_type_distinctly(*base_classes, options=self.options, bare=True), "and"
    )
    try:
        info, full_name = _make_fake_typeinfo_and_full_name(base_classes, module)
        with self.msg.filter_errors() as local_errors:
            self.check_multiple_inheritance(info)
        if local_errors.has_new_errors():
            # "class A(B, C)" unsafe, now check "class A(C, B)":
            base_classes = _get_base_classes(instances[::-1])
            info, full_name = _make_fake_typeinfo_and_full_name(base_classes, module)
            with self.msg.filter_errors() as local_errors:
                self.check_multiple_inheritance(info)
        info.is_intersection = True
    except MroError:
        errors.append((pretty_names_list, "inconsistent method resolution order"))
        return None
    if local_errors.has_new_errors():
        errors.append((pretty_names_list, "incompatible method signatures"))
        return None

    module.names[full_name] = SymbolTableNode(GDEF, info)
    merged_extra_attrs = instances[0].extra_attrs or instances[1].extra_attrs
    return Instance(info, [], extra_attrs=merged_extra_attrs)
def intersect_instance_callable(self, typ: Instance, callable_type: CallableType) -> Instance:
    """Create a fake type representing the intersection of an Instance and a CallableType.

    It operates by creating a bare-minimum dummy TypeInfo that
    subclasses ``typ`` and adds a ``__call__`` method matching
    ``callable_type``.
    """
    # In order for this to work in incremental mode, the type we generate needs to
    # have a valid fullname and a corresponding entry in a symbol table. We generate
    # a unique name inside the symbol table of the current module.
    module = self.scope.stack[0]
    assert isinstance(module, MypyFile)
    unique_name = gen_unique_name(f"<callable subtype of {typ.type.name}>", module.names)

    # Synthesize a fake TypeInfo
    display_name = format_type_bare(typ, self.options)
    class_def, fake_info = self.make_fake_typeinfo(
        module.fullname, unique_name, display_name, [typ]
    )

    # Build up a fake FuncDef so we can populate the symbol table.
    call_def = FuncDef("__call__", [], Block([]), callable_type)
    call_def._fullname = class_def.fullname + ".__call__"
    call_def.info = fake_info
    fake_info.names["__call__"] = SymbolTableNode(MDEF, call_def)

    module.names[unique_name] = SymbolTableNode(GDEF, fake_info)
    return Instance(fake_info, [], extra_attrs=typ.extra_attrs)
def make_fake_callable(self, typ: Instance) -> Instance:
    """Produce a new type that makes ``typ`` callable with a generic callable type.

    The callable accepts ``*args`` and ``**kwargs`` of type ``Any`` and
    returns ``Any``.
    """
    function_fallback = self.named_type("builtins.function")
    star_arg_types = [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)]
    generic_callable = CallableType(
        star_arg_types,
        [nodes.ARG_STAR, nodes.ARG_STAR2],
        [None, None],
        ret_type=AnyType(TypeOfAny.explicit),
        fallback=function_fallback,
        is_ellipsis_args=True,
    )
    return self.intersect_instance_callable(typ, generic_callable)
def partition_by_callable(
self, typ: Type, unsound_partition: bool
) -> tuple[list[Type], list[Type]]:
"""Partitions a type into callable subtypes and uncallable subtypes.
Thus, given:
`callables, uncallables = partition_by_callable(type)`
If we assert `callable(type)` then `type` has type Union[*callables], and
If we assert `not callable(type)` then `type` has type Union[*uncallables]
If unsound_partition is set, assume that anything that is not
clearly callable is in fact not callable. Otherwise we generate a
new subtype that *is* callable.
Guaranteed to not return [], [].
"""
typ = get_proper_type(typ)
# Function-like types and type objects go only into the callable partition.
if isinstance(typ, (FunctionLike, TypeType)):
return [typ], []
# Any goes into both partitions.
if isinstance(typ, AnyType):
return [typ], [typ]
# None goes only into the uncallable partition.
if isinstance(typ, NoneType):
return [], [typ]
if isinstance(typ, UnionType):
# Partition every union item separately and concatenate the results.
callables = []
uncallables = []
for subtype in typ.items:
# Use unsound_partition when handling unions in order to
# allow the expected type discrimination.
subcallables, subuncallables = self.partition_by_callable(
subtype, unsound_partition=True
)
callables.extend(subcallables)
uncallables.extend(subuncallables)
return callables, uncallables
if isinstance(typ, TypeVarType):
# We could do better probably?
# Refine the type variable's bound as our type in the case that
# callable() is true. This unfortunately loses the information that
# the type is a type variable in that branch.
# This matches what is done for isinstance, but it may be possible to
# do better.
# If it is possible for the false branch to execute, return the original
# type to avoid losing type information.
callables, uncallables = self.partition_by_callable(
erase_to_union_or_bound(typ), unsound_partition
)
uncallables = [typ] if uncallables else []
return callables, uncallables
# A TupleType is callable if its fallback is, but needs special handling
# when we dummy up a new type.
ityp = typ
if isinstance(typ, TupleType):
ityp = tuple_fallback(typ)
if isinstance(ityp, Instance):
method = ityp.type.get_method("__call__")
if method and method.type:
# Partition the type of the __call__ method itself.
callables, uncallables = self.partition_by_callable(
method.type, unsound_partition=False
)
if callables and not uncallables:
# Only consider the type callable if its __call__ method is
# definitely callable.
return [typ], []
if not unsound_partition:
# Generate a fake callable subtype for the callable partition; the
# original type goes into the uncallable partition.
fake = self.make_fake_callable(ityp)
if isinstance(typ, TupleType):
# Attach a tuple type with the original items and the fake
# instance as fallback, and return that tuple type instead.
fake.type.tuple_type = TupleType(typ.items, fake)
return [fake.type.tuple_type], [typ]
return [fake], [typ]
if unsound_partition:
return [], [typ]
else:
# We don't know how properly make the type callable.
return [typ], [typ]
def conditional_callable_type_map(
    self, expr: Expression, current_type: Type | None
) -> tuple[TypeMap, TypeMap]:
    """Return the type maps for ``expr`` being callable and not being callable.

    Args:
        expr: The expression passed to ``callable()``.
        current_type: The current type of ``expr``, or None if unknown.

    Returns:
        A 2-tuple. The first element maps ``expr`` to the restricted type it
        would have if it were callable; the second maps ``expr`` to the type
        it would have if it weren't callable. Both maps are empty when the
        type is unknown or ``Any``. If only one of the two partitions is
        non-empty, the map of that branch is empty and the other branch is
        None.
    """
    if not current_type:
        return {}, {}

    if isinstance(get_proper_type(current_type), AnyType):
        return {}, {}

    callables, uncallables = self.partition_by_callable(current_type, unsound_partition=False)

    if callables and uncallables:
        # Both partitions are non-empty here, so both maps narrow the expression.
        callable_map: TypeMap = {expr: UnionType.make_union(callables)}
        uncallable_map: TypeMap = {expr: UnionType.make_union(uncallables)}
        return callable_map, uncallable_map

    if callables:
        return {}, None

    return None, {}
def conditional_types_for_iterable(
    self, item_type: Type, iterable_type: Type
) -> tuple[Type | None, Type | None]:
    """Narrow the type of `iterable_type` based on the type of `item_type`.

    For now, we only support narrowing unions of TypedDicts based on the left
    operand being literal string(s). The result is a pair of the type for the
    branch where the item is in the iterable and the type for the branch where
    it is not; an element is None when no candidate type remains for it.
    """
    in_types: list[Type] = []
    not_in_types: list[Type] = []

    iterable_type = get_proper_type(iterable_type)
    if isinstance(iterable_type, UnionType):
        candidates = get_proper_types(iterable_type.relevant_items())
    else:
        candidates = [iterable_type]

    literal_keys = try_getting_str_literals_from_type(item_type)

    for candidate in candidates:
        if not (literal_keys and isinstance(candidate, TypedDictType)):
            # Not narrowable: the candidate stays possible in both branches.
            in_types.append(candidate)
            not_in_types.append(candidate)
            continue
        for key in literal_keys:
            always_present = key in candidate.required_keys
            maybe_present = (
                always_present or key in candidate.items or not candidate.is_final
            )
            if maybe_present:
                in_types.append(candidate)
            if not always_present:
                not_in_types.append(candidate)

    if_type = UnionType.make_union(in_types) if in_types else None
    else_type = UnionType.make_union(not_in_types) if not_in_types else None
    return if_type, else_type
def _is_truthy_type(self, t: ProperType) -> bool:
    """Return whether values of type ``t`` are always truthy.

    This holds for function-like types, for instances (other than
    ``builtins.object``) whose class has no readable ``__bool__`` or
    ``__len__`` member, and for unions whose items are all truthy.
    """
    if isinstance(t, Instance):
        info = t.type
        if (
            bool(info)
            and not info.has_readable_member("__bool__")
            and not info.has_readable_member("__len__")
            and info.fullname != "builtins.object"
        ):
            return True
    if isinstance(t, FunctionLike):
        return True
    if isinstance(t, UnionType):
        return all(self._is_truthy_type(item) for item in get_proper_types(t.items))
    return False
def _check_for_truthy_type(self, t: Type, expr: Expression) -> None:
    """Report an error if ``expr`` has a type whose values are always truthy.

    Nothing is reported when strict optional checking is disabled or when
    ``_is_truthy_type`` does not consider the type always truthy.
    """
    if not state.strict_optional:
        return  # if everything can be None, all bets are off

    t = get_proper_type(t)
    if not self._is_truthy_type(t):
        return

    def describe_expr() -> str:
        # Describe the expression together with its formatted type.
        shown_type = format_type(t, self.options)
        if isinstance(expr, MemberExpr):
            return f'Member "{expr.name}" has type {shown_type}'
        if isinstance(expr, RefExpr) and expr.fullname:
            return f'"{expr.fullname}" has type {shown_type}'
        if not isinstance(expr, CallExpr):
            return f"Expression has type {shown_type}"
        callee = expr.callee
        if isinstance(callee, MemberExpr):
            return f'"{callee.name}" returns {shown_type}'
        if isinstance(callee, RefExpr) and callee.fullname:
            return f'"{callee.fullname}" returns {shown_type}'
        return f"Call returns {shown_type}"

    def name_or_type() -> str:
        if isinstance(expr, (NameExpr, MemberExpr)):
            return f'"{expr.name}"'
        # return type if expr has no name
        return format_type(t, self.options)

    if isinstance(t, FunctionLike):
        message = message_registry.FUNCTION_ALWAYS_TRUE.format(name_or_type())
    elif isinstance(t, UnionType):
        message = message_registry.TYPE_ALWAYS_TRUE_UNIONTYPE.format(describe_expr())
    elif isinstance(t, Instance) and t.type.fullname == "typing.Iterable":
        # Name a Collection with the same type arguments in the message.
        _, collection_info = self.make_fake_typeinfo("typing", "Collection", "Collection", [])
        described = describe_expr()
        collection = format_type(Instance(collection_info, t.args), self.options)
        message = message_registry.ITERABLE_ALWAYS_TRUE.format(described, collection)
    else:
        message = message_registry.TYPE_ALWAYS_TRUE.format(describe_expr())
    self.fail(message, expr)
def find_type_equals_check(
    self, node: ComparisonExpr, expr_indices: list[int]
) -> tuple[TypeMap, TypeMap]:
    """Narrow types based on any checks of the type ``type(x) == T``

    Args:
        node: The node that might contain the comparison
        expr_indices: The list of indices of expressions in ``node`` that are being
            compared

    Returns:
        The type maps for the comparison being true and being false. Both are
        empty when there is no ``type(...)`` call among the operands or when
        more than one type is compared against.
    """

    def is_type_call(expr: CallExpr) -> bool:
        """Is expr a call to type with one argument?"""
        return refers_to_fullname(expr.callee, "builtins.type") and len(expr.args) == 1

    def combine_maps(list_maps: list[TypeMap]) -> TypeMap:
        """Combine all typemaps in list_maps into one typemap"""
        merged = {}
        for type_map in list_maps:
            if type_map is not None:
                merged.update(type_map)
        return merged

    # exprs that are being passed into type
    exprs_in_type_calls: list[Expression] = []
    # type that is being compared to type(expr)
    type_being_compared: list[TypeRange] | None = None
    # whether the type being compared to is final
    is_final = False

    for index in expr_indices:
        operand = node.operands[index]
        if isinstance(operand, CallExpr) and is_type_call(operand):
            exprs_in_type_calls.append(operand.args[0])
            continue

        current_type = self.get_isinstance_type(operand)
        if current_type is None:
            continue
        if type_being_compared is not None:
            # It doesn't really make sense to have several types being
            # compared to the output of type (like type(x) == int == str)
            # because whether that's true is solely dependent on what the
            # types being compared are, so we don't try to narrow types any
            # further because we can't really get any information about the
            # type of x from that check
            return {}, {}
        if isinstance(operand, RefExpr) and isinstance(operand.node, TypeInfo):
            is_final = operand.node.is_final
        type_being_compared = current_type

    if not exprs_in_type_calls:
        return {}, {}

    if_maps: list[TypeMap] = []
    else_maps: list[TypeMap] = []
    for type_arg in exprs_in_type_calls:
        narrowed_if, narrowed_else = self.conditional_types_with_intersection(
            self.lookup_type(type_arg), type_being_compared, type_arg
        )
        arg_if_map, arg_else_map = conditional_types_to_typemaps(
            type_arg, narrowed_if, narrowed_else
        )
        if_maps.append(arg_if_map)
        else_maps.append(arg_else_map)

    if_map = combine_maps(if_maps)
    # type(x) == T is only true when x has the same type as T, meaning
    # that it can be false if x is an instance of a subclass of T. That means
    # we can't do any narrowing in the else case unless T is final, in which
    # case T can't be subclassed
    else_map = combine_maps(else_maps) if is_final else {}
    return if_map, else_map
def find_isinstance_check(
    self, node: Expression, *, in_boolean_context: bool = True
) -> tuple[TypeMap, TypeMap]:
    """Find any isinstance checks (within a chain of ands).

    This includes implicit and explicit checks for None, calls to callable
    and TypeGuard functions.

    The result is a pair of maps: variables to their types if the condition
    is true, and variables to their types if the condition is false. If
    either element of the pair is None, that particular branch can never
    occur.

    May return {}, {}.
    Can return None, None in situations involving NoReturn.
    """
    raw_if_map, raw_else_map = self.find_isinstance_check_helper(
        node, in_boolean_context=in_boolean_context
    )
    return (
        self.propagate_up_typemap_info(raw_if_map),
        self.propagate_up_typemap_info(raw_else_map),
    )
def find_isinstance_check_helper(
self, node: Expression, *, in_boolean_context: bool = True
) -> tuple[TypeMap, TypeMap]:
if is_true_literal(node):
return {}, None
if is_false_literal(node):
return None, {}
if isinstance(node, CallExpr) and len(node.args) != 0:
expr = collapse_walrus(node.args[0])
if refers_to_fullname(node.callee, "builtins.isinstance"):
if len(node.args) != 2: # the error will be reported elsewhere
return {}, {}
if literal(expr) == LITERAL_TYPE:
return conditional_types_to_typemaps(
expr,
*self.conditional_types_with_intersection(
self.lookup_type(expr), self.get_isinstance_type(node.args[1]), expr
),
)
elif refers_to_fullname(node.callee, "builtins.issubclass"):
if len(node.args) != 2: # the error will be reported elsewhere
return {}, {}
if literal(expr) == LITERAL_TYPE:
return self.infer_issubclass_maps(node, expr)
elif refers_to_fullname(node.callee, "builtins.callable"):
if len(node.args) != 1: # the error will be reported elsewhere
return {}, {}
if literal(expr) == LITERAL_TYPE:
vartype = self.lookup_type(expr)
return self.conditional_callable_type_map(expr, vartype)
elif refers_to_fullname(node.callee, "builtins.hasattr"):
if len(node.args) != 2: # the error will be reported elsewhere
return {}, {}
attr = try_getting_str_literals(node.args[1], self.lookup_type(node.args[1]))
if literal(expr) == LITERAL_TYPE and attr and len(attr) == 1:
return self.hasattr_type_maps(expr, self.lookup_type(expr), attr[0])
elif isinstance(node.callee, RefExpr):
if node.callee.type_guard is not None:
# TODO: Follow *args, **kwargs
if node.arg_kinds[0] != nodes.ARG_POS:
# the first argument might be used as a kwarg
called_type = get_proper_type(self.lookup_type(node.callee))
assert isinstance(called_type, (CallableType, Overloaded))
# *assuming* the overloaded function is correct, there's a couple cases:
# 1) The first argument has different names, but is pos-only. We don't
# care about this case, the argument must be passed positionally.
# 2) The first argument allows keyword reference, therefore must be the
# same between overloads.
name = called_type.items[0].arg_names[0]
if name in node.arg_names:
idx = node.arg_names.index(name)
# we want the idx-th variable to be narrowed
expr = collapse_walrus(node.args[idx])
else:
self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node)
return {}, {}
if literal(expr) == LITERAL_TYPE:
# Note: we wrap the target type, so that we can special case later.
# Namely, for isinstance() we use a normal meet, while TypeGuard is
# considered "always right" (i.e. even if the types are not overlapping).
# Also note that a care must be taken to unwrap this back at read places
# where we use this to narrow down declared type.
return {expr: TypeGuardedType(node.callee.type_guard)}, {}
elif isinstance(node, ComparisonExpr):
# Step 1: Obtain the types of each operand and whether or not we can
# narrow their types. (For example, we shouldn't try narrowing the
# types of literal string or enum expressions).
operands = [collapse_walrus(x) for x in node.operands]
operand_types = []
narrowable_operand_index_to_hash = {}
for i, expr in enumerate(operands):
if not self.has_type(expr):
return {}, {}
expr_type = self.lookup_type(expr)
operand_types.append(expr_type)
if (
literal(expr) == LITERAL_TYPE
and not is_literal_none(expr)
and not self.is_literal_enum(expr)
):
h = literal_hash(expr)
if h is not None:
narrowable_operand_index_to_hash[i] = h
# Step 2: Group operands chained by either the 'is' or '==' operands
# together. For all other operands, we keep them in groups of size 2.
# So the expression:
#
# x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8
#
# ...is converted into the simplified operator list:
#
# [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]),
# ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])]
#
# We group identity/equality expressions so we can propagate information
# we discover about one operand across the entire chain. We don't bother
# handling 'is not' and '!=' chains in a special way: those are very rare
# in practice.
simplified_operator_list = group_comparison_operands(
node.pairwise(), narrowable_operand_index_to_hash, {"==", "is"}
)
# Step 3: Analyze each group and infer more precise type maps for each
# assignable operand, if possible. We combine these type maps together
# in the final step.
partial_type_maps = []
for operator, expr_indices in simplified_operator_list:
if operator in {"is", "is not", "==", "!="}:
# is_valid_target:
# Controls which types we're allowed to narrow exprs to. Note that
# we cannot use 'is_literal_type_like' in both cases since doing
# 'x = 10000 + 1; x is 10001' is not always True in all Python
# implementations.
#
# coerce_only_in_literal_context:
# If true, coerce types into literal types only if one or more of
# the provided exprs contains an explicit Literal type. This could
# technically be set to any arbitrary value, but it seems being liberal
# with narrowing when using 'is' and conservative when using '==' seems
# to break the least amount of real-world code.
#
# should_narrow_by_identity:
# Set to 'false' only if the user defines custom __eq__ or __ne__ methods
# that could cause identity-based narrowing to produce invalid results.
if operator in {"is", "is not"}:
is_valid_target: Callable[[Type], bool] = is_singleton_type
coerce_only_in_literal_context = False
should_narrow_by_identity = True
else:
def is_exactly_literal_type(t: Type) -> bool:
return isinstance(get_proper_type(t), LiteralType)
def has_no_custom_eq_checks(t: Type) -> bool:
return not custom_special_method(
t, "__eq__", check_all=False
) and not custom_special_method(t, "__ne__", check_all=False)
is_valid_target = is_exactly_literal_type
coerce_only_in_literal_context = True
expr_types = [operand_types[i] for i in expr_indices]
should_narrow_by_identity = all(map(has_no_custom_eq_checks, expr_types))
if_map: TypeMap = {}
else_map: TypeMap = {}
if should_narrow_by_identity:
if_map, else_map = self.refine_identity_comparison_expression(
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash.keys(),
is_valid_target,
coerce_only_in_literal_context,
)
# Strictly speaking, we should also skip this check if the objects in the expr
# chain have custom __eq__ or __ne__ methods. But we (maybe optimistically)
# assume nobody would actually create a custom object that considers itself
# equal to None.
if if_map == {} and else_map == {}:
if_map, else_map = self.refine_away_none_in_comparison(
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash.keys(),
)
# If we haven't been able to narrow types yet, we might be dealing with an
# explicit type(x) == some_type check
if if_map == {} and else_map == {}:
if_map, else_map = self.find_type_equals_check(node, expr_indices)
elif operator in {"in", "not in"}:
assert len(expr_indices) == 2
left_index, right_index = expr_indices
item_type = operand_types[left_index]
iterable_type = operand_types[right_index]
if_map, else_map = {}, {}
if left_index in narrowable_operand_index_to_hash:
# We only try and narrow away 'None' for now
if is_overlapping_none(item_type):
collection_item_type = get_proper_type(
builtin_item_type(iterable_type)
)
if (
collection_item_type is not None
and not is_overlapping_none(collection_item_type)
and not (
isinstance(collection_item_type, Instance)
and collection_item_type.type.fullname == "builtins.object"
)
and is_overlapping_erased_types(item_type, collection_item_type)
):
if_map[operands[left_index]] = remove_optional(item_type)
if right_index in narrowable_operand_index_to_hash:
if_type, else_type = self.conditional_types_for_iterable(
item_type, iterable_type
)
expr = operands[right_index]
if if_type is None:
if_map = None
else:
if_map[expr] = if_type
if else_type is None:
else_map = None
else:
else_map[expr] = else_type
else:
if_map = {}
else_map = {}
if operator in {"is not", "!=", "not in"}:
if_map, else_map = else_map, if_map
partial_type_maps.append((if_map, else_map))
return reduce_conditional_maps(partial_type_maps)
elif isinstance(node, AssignmentExpr):
if_map = {}
else_map = {}
if_assignment_map, else_assignment_map = self.find_isinstance_check(node.target)
if if_assignment_map is not None:
if_map.update(if_assignment_map)
if else_assignment_map is not None:
else_map.update(else_assignment_map)
if_condition_map, else_condition_map = self.find_isinstance_check(
node.value, in_boolean_context=False
)
if if_condition_map is not None:
if_map.update(if_condition_map)
if else_condition_map is not None:
else_map.update(else_condition_map)
return (
(None if if_assignment_map is None or if_condition_map is None else if_map),
(None if else_assignment_map is None or else_condition_map is None else else_map),
)
elif isinstance(node, OpExpr) and node.op == "and":
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
# (e1 and e2) is true if both e1 and e2 are true,
# and false if at least one of e1 and e2 is false.
return (
and_conditional_maps(left_if_vars, right_if_vars),
or_conditional_maps(left_else_vars, right_else_vars),
)
elif isinstance(node, OpExpr) and node.op == "or":
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
# (e1 or e2) is true if at least one of e1 or e2 is true,
# and false if both e1 and e2 are false.
return (
or_conditional_maps(left_if_vars, right_if_vars),
and_conditional_maps(left_else_vars, right_else_vars),
)
elif isinstance(node, UnaryExpr) and node.op == "not":
left, right = self.find_isinstance_check(node.expr)
return right, left
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
# respectively
original_vartype = self.lookup_type(node)
if in_boolean_context:
# We don't check `:=` values in expressions like `(a := A())`,
# because they produce two error messages.
self._check_for_truthy_type(original_vartype, node)
vartype = try_expanding_sum_type_to_union(original_vartype, "builtins.bool")
if_type = true_only(vartype)
else_type = false_only(vartype)
if_map = {node: if_type} if not isinstance(if_type, UninhabitedType) else None
else_map = {node: else_type} if not isinstance(else_type, UninhabitedType) else None
return if_map, else_map
def propagate_up_typemap_info(self, new_types: TypeMap) -> TypeMap:
    """Extend a narrowing map with refinements for parent lookup expressions.

    'new_types' maps expressions to their freshly narrowed types. For every
    entry we ask 'refine_parent_types' whether the narrowing also tells us
    something about the expressions the entry was looked up from, and add
    those parent refinements to the result.

    For example, suppose "A = Tuple[int, int]" and "B = Tuple[str, str]", that
    'foo' has type Union[A, B], and that 'new_types' says 'foo[0]' is an 'int'.
    Since A[0] is an int and B[0] is not, 'foo' can be refined from
    Union[A, B] to just A.

    A parent expression that already has its own entry in 'new_types' keeps
    that entry; the proposed refinement is discarded.

    Returns None when 'new_types' is None. Otherwise returns a new dict that
    contains every entry of 'new_types' plus any parent refinements.
    """
    if new_types is None:
        return None

    refined: dict[Expression, Type] = {}
    for child_expr, child_type in new_types.items():
        # The narrowing we were given is always part of the result.
        refined[child_expr] = child_type

        # Walk up the lookup chain and record whatever we learn about parents.
        parent_refinements = self.refine_parent_types(child_expr, child_type)
        for parent_expr, parent_type in parent_refinements.items():
            # An explicit narrowing for the parent always wins.
            # TODO: Consider picking the narrower type instead of always discarding this?
            if parent_expr not in new_types:
                refined[parent_expr] = parent_type
    return refined
def refine_parent_types(self, expr: Expression, expr_type: Type) -> Mapping[Expression, Type]:
    """Checks if the given expr is a 'lookup operation' into a union and iteratively refines
    the parent types based on the 'expr_type'.

    For example, if 'expr' is an expression like 'a.b.c.d', we'll potentially return refined
    types for expressions 'a', 'a.b', and 'a.b.c'.

    A 'lookup operation' here is one of:

    - a member access (MemberExpr),
    - an index expression whose index type yields string literals (replayed against
      TypedDictType parents), or
    - an index expression whose index type yields int literals (replayed against
      TupleType parents).

    For more details about how we use the expr_type to refine the parent types of
    lookup_expr, see the docstring in 'propagate_up_typemap_info'.

    Returns a mapping from each refined parent expression to its narrowed type. The mapping
    is empty if nothing could be refined; 'expr' itself is never included.
    """
    output: dict[Expression, Type] = {}

    # Note: parent_expr and parent_type are progressively refined as we crawl up the
    # parent lookup chain.
    while True:
        # First, check if this expression is one that's attempting to
        # "lookup" some key in the parent type. If so, save the parent type
        # and create function that will try replaying the same lookup
        # operation against arbitrary types.
        if isinstance(expr, MemberExpr):
            parent_expr = collapse_walrus(expr.expr)
            parent_type = self.lookup_type_or_none(parent_expr)
            member_name = expr.name

            def replay_lookup(new_parent_type: ProperType) -> Type | None:
                # Errors raised by the replayed access are captured by the filter;
                # any new error means this candidate parent has no such member.
                with self.msg.filter_errors() as w:
                    member_type = analyze_member_access(
                        name=member_name,
                        typ=new_parent_type,
                        context=parent_expr,
                        is_lvalue=False,
                        is_super=False,
                        is_operator=False,
                        msg=self.msg,
                        original_type=new_parent_type,
                        chk=self,
                        in_literal_context=False,
                    )
                if w.has_new_errors():
                    return None
                else:
                    return member_type

        elif isinstance(expr, IndexExpr):
            parent_expr = collapse_walrus(expr.base)
            parent_type = self.lookup_type_or_none(parent_expr)

            # Without a known index type we cannot replay the lookup at all.
            index_type = self.lookup_type_or_none(expr.index)
            if index_type is None:
                return output

            str_literals = try_getting_str_literals_from_type(index_type)
            if str_literals is not None:
                # Refactoring these two indexing replay functions is surprisingly
                # tricky -- see https://github.com/python/mypy/pull/7917, which
                # was blocked by https://github.com/mypyc/mypyc/issues/586
                def replay_lookup(new_parent_type: ProperType) -> Type | None:
                    if not isinstance(new_parent_type, TypedDictType):
                        return None
                    try:
                        assert str_literals is not None
                        member_types = [new_parent_type.items[key] for key in str_literals]
                    except KeyError:
                        # At least one of the keys is missing from this TypedDict.
                        return None
                    return make_simplified_union(member_types)

            else:
                int_literals = try_getting_int_literals_from_type(index_type)
                if int_literals is not None:

                    def replay_lookup(new_parent_type: ProperType) -> Type | None:
                        if not isinstance(new_parent_type, TupleType):
                            return None
                        try:
                            assert int_literals is not None
                            member_types = [new_parent_type.items[key] for key in int_literals]
                        except IndexError:
                            # At least one index is out of range for this tuple.
                            return None
                        return make_simplified_union(member_types)

                else:
                    # The index is neither a string nor an int literal: not a lookup
                    # we know how to replay.
                    return output
        else:
            # Not a lookup expression, so the chain ends here.
            return output

        # If we somehow didn't previously derive the parent type, abort completely
        # with what we have so far: something went wrong at an earlier stage.
        if parent_type is None:
            return output

        # We currently only try refining the parent type if it's a Union.
        # If not, there's no point in trying to refine any further parents
        # since we have no further information we can use to refine the lookup
        # chain, so we end early as an optimization.
        parent_type = get_proper_type(parent_type)
        if not isinstance(parent_type, UnionType):
            return output

        # Take each element in the parent union and replay the original lookup procedure
        # to figure out which parents are compatible.
        new_parent_types = []
        for item in flatten_nested_unions(parent_type.items):
            member_type = replay_lookup(get_proper_type(item))
            if member_type is None:
                # We were unable to obtain the member type. So, we give up on refining this
                # parent type entirely and abort.
                return output

            if is_overlapping_types(member_type, expr_type):
                new_parent_types.append(item)

        # If none of the parent types overlap (if we derived an empty union), something
        # went wrong. We should never hit this case, but deriving the uninhabited type or
        # reporting an error both seem unhelpful. So we abort.
        if not new_parent_types:
            return output

        # Record the refined parent and continue one level further up the chain,
        # now treating the parent as the expression being narrowed.
        expr = parent_expr
        expr_type = output[parent_expr] = make_simplified_union(new_parent_types)
def refine_identity_comparison_expression(
    self,
    operands: list[Expression],
    operand_types: list[Type],
    chain_indices: list[int],
    narrowable_operand_indices: AbstractSet[int],
    is_valid_target: Callable[[ProperType], bool],
    coerce_only_in_literal_context: bool,
) -> tuple[TypeMap, TypeMap]:
    """Produce conditional type maps refining expressions by an identity/equality comparison.

    'operands' and 'operand_types' describe every operand of the overall comparison
    expression, while 'chain_indices' selects the operands that take part in this
    particular identity/equality chain. For the expression

        a <= b is c is d <= e

    'operands' and 'operand_types' have length 5 and 'chain_indices' is [1, 2, 3].

    'narrowable_operand_indices' is the set of operand indices whose types we are
    allowed to refine, i.e. the only operands that may appear in the returned maps.

    Rather than narrowing every operand to the meet of the chain (which causes too
    many problems in real-world code), one operand type accepted by 'is_valid_target'
    is chosen as the target and the other narrowable operands are refined against it.

    'coerce_only_in_literal_context' makes the coercion of candidate targets to
    Literal types conditional on at least one chain operand already being
    literal-like or an enum instance; unconditional coercion is sometimes too
    aggressive a narrowing.

    Returns an (if_map, else_map) pair. ({}, {}) means nothing could be inferred;
    (None, {}) means the chain contains two different targets, so the 'if' branch
    is unreachable.
    """
    if coerce_only_in_literal_context:

        def is_literal_context(typ: Type) -> bool:
            proper = get_proper_type(typ)
            return is_literal_type_like(proper) or (
                isinstance(proper, Instance) and proper.type.is_enum
            )

        coerce_candidates = any(is_literal_context(operand_types[i]) for i in chain_indices)
    else:
        coerce_candidates = True

    # Pass 1: find the (single) target type and remember which operands carry it.
    target: Type | None = None
    candidate_indices = []
    for index in chain_indices:
        candidate = operand_types[index]
        if coerce_candidates:
            candidate = coerce_to_literal(candidate)
        if not is_valid_target(get_proper_type(candidate)):
            continue
        if target and not is_same_type(target, candidate):
            # Two different targets can never be identical/equal to each other,
            # so the 'if' branch must be unreachable.
            return None, {}
        target = candidate
        candidate_indices.append(index)

    # Without any valid target there is nothing we know how to infer.
    if target is None:
        return {}, {}

    # The operand chosen as the target is not itself refined below, so prefer one we
    # were never allowed to refine anyway. If every candidate is narrowable, fall back
    # to the rightmost candidate, for consistency with other narrowing logic that
    # biases towards the rightmost item.
    #
    # The choice does not affect the 'if' branch: all candidates share the same type,
    # otherwise we would have returned early above. It can affect which operand is
    # narrowed to <uninhabited> in the 'else' branch, but mypy marks the whole branch
    # unreachable as soon as any operand is narrowed to <uninhabited>, so that
    # difference is moot as well.
    pivot_index = next(
        (i for i in reversed(candidate_indices) if i not in narrowable_operand_indices),
        candidate_indices[-1],
    )

    # Enum and bool literals belong to "sum types" that can be expanded into a union
    # of their members, which lets the 'else' branch exclude the matched member.
    target = get_proper_type(target)
    sum_type_fullname = None
    if isinstance(target, LiteralType) and (
        target.is_enum_literal() or isinstance(target.value, bool)
    ):
        sum_type_fullname = target.fallback.type.fullname

    target_ranges = [TypeRange(target, is_upper_bound=False)]

    # Pass 2: refine every other narrowable operand in the chain against the target.
    per_operand_maps = []
    for index in chain_indices:
        # Refining the target against itself would make the 'else' branch look
        # unreachable, which is rarely what the user means: the target is normally a
        # fixed sentinel that the *other* expressions are compared against. Operands
        # we are not permitted to narrow are skipped too.
        if index == pivot_index or index not in narrowable_operand_indices:
            continue

        operand = operands[index]
        operand_type = coerce_to_literal(operand_types[index])
        if sum_type_fullname is not None:
            operand_type = try_expanding_sum_type_to_union(operand_type, sum_type_fullname)

        # 'conditional_types' is used directly instead of
        # 'self.conditional_types_with_intersection' on purpose: ad-hoc intersections
        # are only computed when working with pure instances.
        outcome = conditional_types(operand_type, target_ranges)
        per_operand_maps.append(conditional_types_to_typemaps(operand, *outcome))

    return reduce_conditional_maps(per_operand_maps)
def refine_away_none_in_comparison(
    self,
    operands: list[Expression],
    operand_types: list[Type],
    chain_indices: list[int],
    narrowable_operand_indices: AbstractSet[int],
) -> tuple[TypeMap, TypeMap]:
    """Produces conditional type maps refining away None in an identity/equality chain.

    'operands' and 'operand_types' describe every operand of the overall comparison,
    'chain_indices' selects the operands belonging to this identity/equality chain and
    'narrowable_operand_indices' is the set of operand indices we may refine. See the
    docstring of 'refine_identity_comparison_expression' for more details.

    If the chain mixes operands that may be None with operands that cannot be, each
    narrowable operand of the chain that may be None, and whose type overlaps one of
    the non-optional types, has None removed in the 'if' map.

    Returns an (if_map, else_map) pair; the 'else' map is always empty, and both are
    empty when the chain is not such a mixture.
    """
    non_optional_types = [
        operand_types[i] for i in chain_indices if not is_overlapping_none(operand_types[i])
    ]

    # Make sure we have a mixture of optional and non-optional types.
    if not non_optional_types or len(non_optional_types) == len(chain_indices):
        return {}, {}

    if_map = {}
    # Only operands that take part in *this* chain are known to compare equal to a
    # non-optional operand. 'narrowable_operand_indices' covers the whole comparison
    # expression, so iterating over it directly (as this method used to) could strip
    # None from an operand that belongs to a different operator group, e.g. 'z' in
    # 'x == y < z'.
    for i in chain_indices:
        if i not in narrowable_operand_indices:
            continue
        expr_type = operand_types[i]
        if not is_overlapping_none(expr_type):
            continue
        if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
            if_map[operands[i]] = remove_optional(expr_type)

    return if_map, {}
#
# Helpers
#
# Overload 1: the message is a plain string, optionally paired with an error code.
@overload
def check_subtype(
    self,
    subtype: Type,
    supertype: Type,
    context: Context,
    msg: str,
    subtype_label: str | None = None,
    supertype_label: str | None = None,
    *,
    notes: list[str] | None = None,
    code: ErrorCode | None = None,
    outer_context: Context | None = None,
) -> bool:
    ...


# Overload 2: the message is an ErrorMessage, so this overload takes no separate
# 'code' argument.
@overload
def check_subtype(
    self,
    subtype: Type,
    supertype: Type,
    context: Context,
    msg: ErrorMessage,
    subtype_label: str | None = None,
    supertype_label: str | None = None,
    *,
    notes: list[str] | None = None,
    outer_context: Context | None = None,
) -> bool:
    ...


def check_subtype(
    self,
    subtype: Type,
    supertype: Type,
    context: Context,
    msg: str | ErrorMessage,
    subtype_label: str | None = None,
    supertype_label: str | None = None,
    *,
    notes: list[str] | None = None,
    code: ErrorCode | None = None,
    outer_context: Context | None = None,
) -> bool:
    """Generate an error if the subtype is not compatible with supertype.

    Returns True (and reports nothing) if 'subtype' is a subtype of 'supertype';
    otherwise reports 'msg' at 'context' and returns False.

    If 'subtype_label' and/or 'supertype_label' are given, the formatted types are
    appended to the message as "(<label> <type>, ...)". 'notes' are emitted as notes
    after the error. 'code' is only used when 'msg' is a plain string. 'outer_context',
    if given, is used instead of 'context' when building the inferred-type note.
    """
    if is_subtype(subtype, supertype, options=self.options):
        return True

    # Normalize to an ErrorMessage so the rest of the method can rely on msg.code.
    if isinstance(msg, str):
        msg = ErrorMessage(msg, code=code)

    if self.msg.prefer_simple_messages():
        self.fail(msg, context)  # Fast path -- skip all fancy logic
        return False

    # Keep the unexpanded types around: they are the ones shown to the user, while
    # the proper types are used for the isinstance checks below.
    orig_subtype = subtype
    subtype = get_proper_type(subtype)
    orig_supertype = supertype
    supertype = get_proper_type(supertype)
    if self.msg.try_report_long_tuple_assignment_error(
        subtype, supertype, context, msg, subtype_label, supertype_label
    ):
        return False
    extra_info: list[str] = []
    note_msg = ""
    notes = notes or []
    if subtype_label is not None or supertype_label is not None:
        subtype_str, supertype_str = format_type_distinctly(
            orig_subtype, orig_supertype, options=self.options
        )
        if subtype_label is not None:
            extra_info.append(subtype_label + " " + subtype_str)
        if supertype_label is not None:
            extra_info.append(supertype_label + " " + supertype_str)
        note_msg = make_inferred_type_note(
            outer_context or context, subtype, supertype, supertype_str
        )
        if isinstance(subtype, Instance) and isinstance(supertype, Instance):
            notes = append_invariance_notes(notes, subtype, supertype)
    if extra_info:
        msg = msg.with_additional_msg(" (" + ", ".join(extra_info) + ")")

    # The error itself comes first, followed by every note, all under the same code.
    self.fail(msg, context)
    for note in notes:
        self.msg.note(note, context, code=msg.code)
    if note_msg:
        self.note(note_msg, context, code=msg.code)
    self.msg.maybe_note_concatenate_pos_args(subtype, supertype, context, code=msg.code)
    # Incompatibility with a protocol: let the message builder report the details.
    if (
        isinstance(supertype, Instance)
        and supertype.type.is_protocol
        and isinstance(subtype, (CallableType, Instance, TupleType, TypedDictType))
    ):
        self.msg.report_protocol_problems(subtype, supertype, context, code=msg.code)
    # An instance was given where a callable is expected: show its __call__, if any.
    if isinstance(supertype, CallableType) and isinstance(subtype, Instance):
        call = find_member("__call__", subtype, subtype, is_operator=True)
        if call:
            self.msg.note_call(subtype, call, context, code=msg.code)
    # A callable was given where a callback protocol is expected: show the
    # protocol's __call__ when the callable is not compatible with it.
    if isinstance(subtype, (CallableType, Overloaded)) and isinstance(supertype, Instance):
        if supertype.type.is_protocol and "__call__" in supertype.type.protocol_members:
            call = find_member("__call__", supertype, subtype, is_operator=True)
            assert call is not None
            if not is_subtype(subtype, call, options=self.options):
                self.msg.note_call(supertype, call, context, code=msg.code)
    self.check_possible_missing_await(subtype, supertype, context)
    return False
def get_precise_awaitable_type(self, typ: Type, local_errors: ErrorWatcher) -> Type | None:
    """Return X if 'typ' implements Awaitable[X] for some non-Any X, else None.

    None is returned for partial types, when checking the awaitable expression
    raises KeyError or reports new errors, and when the awaited type turns out to
    be Any or unbound.

    This method must be called in context of local_errors.
    """
    # Partial types are special, ignore them here.
    if isinstance(get_proper_type(typ), PartialType):
        return None

    try:
        awaited = self.expr_checker.check_awaitable_expr(
            typ, Context(), "", ignore_binder=True
        )
    except KeyError:
        # This is a hack to speed up tests by not including Awaitable in all typing stubs.
        return None

    imprecise = local_errors.has_new_errors() or isinstance(
        get_proper_type(awaited), (AnyType, UnboundType)
    )
    return None if imprecise else awaited
@contextmanager
def checking_await_set(self) -> Iterator[None]:
    """Mark that a missing-await check is in progress for the duration of the block.

    'checking_missing_await' is set to True on entry and put back to the value it
    had before on exit, even if the block raises. Restoring the previous value
    (rather than unconditionally resetting to False) keeps an already-active guard
    in place if this context manager is ever entered while the flag is set.
    """
    previous = self.checking_missing_await
    self.checking_missing_await = True
    try:
        yield
    finally:
        self.checking_missing_await = previous
def check_possible_missing_await(
    self, subtype: Type, supertype: Type, context: Context
) -> None:
    """Check if the given type becomes a subtype when awaited.

    If awaiting 'subtype' yields a precise type that is compatible with
    'supertype', a "possible missing await" note is emitted at 'context'.
    """
    # check_subtype() calls back into this method, so bail out if we are already
    # in the middle of such a check to avoid infinite recursion.
    if self.checking_missing_await:
        return

    # Everything reported while probing is filtered out; only the final note,
    # emitted after the filter is closed, is reported.
    with self.checking_await_set(), self.msg.filter_errors() as local_errors:
        awaited = self.get_precise_awaitable_type(subtype, local_errors)
        awaiting_fixes_it = awaited is not None and self.check_subtype(
            awaited, supertype, context, msg=message_registry.INCOMPATIBLE_TYPES
        )

    if awaiting_fixes_it:
        self.msg.possible_missing_await(context)
def contains_none(self, t: Type) -> bool:
    """Return True if None occurs in 't', directly or nested.

    The search descends into union items, tuple items and the type arguments
    of instances.
    """
    proper = get_proper_type(t)
    if isinstance(proper, NoneType):
        return True
    if isinstance(proper, (UnionType, TupleType)):
        return any(self.contains_none(item) for item in proper.items)
    if isinstance(proper, Instance):
        # An instance without type arguments trivially contains no None.
        return any(self.contains_none(arg) for arg in proper.args)
    return False
def named_type(self, name: str) -> Instance:
    """Return an instance type with given name and implicit Any type args.

    For example, named_type('builtins.object') produces the 'object' type.
    The name is assumed to refer to a type; if it resolves to a type alias,
    the alias must target an instance, whose class is used instead.
    """
    target = self.lookup_qualified(name).node
    if isinstance(target, TypeAlias):
        assert isinstance(target.target, Instance)  # type: ignore[misc]
        target = target.target.type
    assert isinstance(target, TypeInfo)

    # One implicit Any per declared type variable of the class.
    implicit_args = [AnyType(TypeOfAny.from_omitted_generics)] * len(target.defn.type_vars)
    return Instance(target, implicit_args)
def named_generic_type(self, name: str, args: list[Type]) -> Instance:
    """Return an instance with the given name and type arguments.

    Assume that the number of arguments is correct. Assume that
    the name refers to a compatible generic type.

    'last known value' information is stripped from every argument first.
    """
    info = self.lookup_typeinfo(name)
    cleaned_args = [remove_instance_last_known_values(arg) for arg in args]
    # TODO: assert len(args) == len(info.defn.type_vars)
    return Instance(info, cleaned_args)
def lookup_typeinfo(self, fullname: str) -> TypeInfo:
    """Return the TypeInfo that 'fullname' resolves to.

    The name is assumed to refer to a class; this is enforced with an assertion.
    """
    info = self.lookup_qualified(fullname).node
    assert isinstance(info, TypeInfo)
    return info
def type_type(self) -> Instance:
    """Return the instance type for 'builtins.type'."""
    builtin_type = self.named_type("builtins.type")
    return builtin_type
def str_type(self) -> Instance:
    """Return the instance type for 'builtins.str'."""
    builtin_str = self.named_type("builtins.str")
    return builtin_str
def store_type(self, node: Expression, typ: Type) -> None:
    """Record 'typ' as the type of 'node' in the innermost type map."""
    innermost = self._type_maps[-1]
    innermost[node] = typ
def has_type(self, node: Expression) -> bool:
    """Return True if any active type map has an entry for 'node'."""
    for type_map in reversed(self._type_maps):
        if node in type_map:
            return True
    return False
def lookup_type_or_none(self, node: Expression) -> Type | None:
    """Return the stored type of 'node', or None if no type map has an entry for it.

    Type maps are searched from the innermost (most recently pushed) outwards.
    """
    for type_map in reversed(self._type_maps):
        if node not in type_map:
            continue
        return type_map[node]
    return None
def lookup_type(self, node: Expression) -> Type:
    """Return the stored type of 'node'.

    Type maps are searched from the innermost outwards; an entry whose value is
    None is skipped. Raises KeyError(node) if no map provides a type.
    """
    for type_map in reversed(self._type_maps):
        known = type_map.get(node)
        if known is None:
            continue
        return known
    raise KeyError(node)
def store_types(self, d: dict[Expression, Type]) -> None:
    """Merge every entry of 'd' into the innermost type map."""
    innermost = self._type_maps[-1]
    innermost.update(d)
@contextmanager
def local_type_map(self) -> Iterator[dict[Expression, Type]]:
    """Store inferred types into a temporary type map (returned).

    This can be used to perform type checking "experiments" without
    affecting exported types (which are used by mypyc).

    The temporary map is pushed on entry and always popped on exit, including
    when the body raises, so a failed experiment cannot leave its scratch map
    on the stack.
    """
    temp_type_map: dict[Expression, Type] = {}
    self._type_maps.append(temp_type_map)
    try:
        yield temp_type_map
    finally:
        self._type_maps.pop()
def in_checked_function(self) -> bool:
    """Should we type-check the current function?

    - Yes if --check-untyped-defs is set.
    - Yes outside functions.
    - Yes in annotated functions.
    - No otherwise.
    """
    # A truthy last entry of 'dynamic_funcs' is what makes the answer "no";
    # an empty stack never does.
    return self.options.check_untyped_defs or not (
        self.dynamic_funcs and self.dynamic_funcs[-1]
    )
def lookup(self, name: str) -> SymbolTableNode:
    """Look up a definition from the symbol table with the given name.

    Module globals are consulted first, then the names of the module bound to
    '__builtins__' (if any). Raises KeyError if the name is found in neither.
    """
    if name in self.globals:
        return self.globals[name]

    builtins_sym = self.globals.get("__builtins__", None)
    if builtins_sym:
        assert isinstance(builtins_sym.node, MypyFile)
        builtin_names = builtins_sym.node.names
        if name in builtin_names:
            return builtin_names[name]
    raise KeyError(f"Failed lookup: {name}")
def lookup_qualified(self, name: str) -> SymbolTableNode:
    """Look up a possibly dotted name and return its symbol table node.

    A name without dots is delegated to 'lookup'. Otherwise the first component
    selects a module from 'self.modules', every intermediate component must name a
    nested module, and the last component is looked up in the final module's names.

    Raises KeyError if the last component is missing; for two-part names in
    'builtins' or 'typing' the message hints at a missing test fixture.
    """
    if "." not in name:
        return self.lookup(name)

    parts = name.split(".")
    namespace = self.modules[parts[0]]
    # Descend through the intermediate modules (everything but first and last part).
    for component in parts[1:-1]:
        sym = namespace.names.get(component)
        assert sym is not None, "Internal error: attempted lookup of unknown name"
        assert isinstance(sym.node, MypyFile)
        namespace = sym.node

    last = parts[-1]
    if last in namespace.names:
        return namespace.names[last]

    if len(parts) == 2 and parts[0] in ("builtins", "typing"):
        fullname = ".".join(parts)
        if fullname in SUGGESTED_TEST_FIXTURES:
            suggestion = ", e.g. add '[{} fixtures/{}]' to your test".format(
                parts[0], SUGGESTED_TEST_FIXTURES[fullname]
            )
        else:
            suggestion = ""
        raise KeyError(
            "Could not find builtin symbol '{}' (If you are running a "
            "test case, use a fixture that "
            "defines this symbol{})".format(last, suggestion)
        )

    msg = "Failed qualified lookup: '{}' (fullname = '{}')."
    raise KeyError(msg.format(last, name))
@contextmanager
def enter_partial_types(
    self, *, is_function: bool = False, is_class: bool = False
) -> Iterator[None]:
    """Enter a new scope for collecting partial types.

    Also report errors for (some) variables which still have partial
    types, i.e. we couldn't infer a complete type.

    The new scope is local if it is a function scope or if the enclosing
    partial-type scope is local.
    """
    enclosing_is_local = bool(self.partial_types) and self.partial_types[-1].is_local
    is_local = enclosing_is_local or is_function
    self.partial_types.append(PartialTypeScope({}, is_function, is_local))
    yield

    # Don't complain about not being able to infer partials if it is
    # at the toplevel (with allow_untyped_globals) or if it is in an
    # untyped function being checked with check_untyped_defs.
    lenient = (self.options.allow_untyped_globals and not is_local) or (
        self.options.check_untyped_defs and self.dynamic_funcs and self.dynamic_funcs[-1]
    )

    unresolved, _, _ = self.partial_types.pop()
    if self.current_node_deferred:
        return

    for var, context in unresolved.items():
        # If we require local partial types, there are a few exceptions where
        # we fall back to inferring just "None" as the type from a None initializer:
        #
        # 1. If all happens within a single function this is acceptable, since only
        #    the topmost function is a separate target in fine-grained incremental mode.
        #    We primarily want to avoid "splitting" partial types across targets.
        #
        # 2. A None initializer in the class body if the attribute is defined in a base
        #    class is fine, since the attribute is already defined and it's currently okay
        #    to vary the type of an attribute covariantly. The None type will still be
        #    checked for compatibility with base classes elsewhere. Without this exception
        #    mypy could require an annotation for an attribute that already has been
        #    declared in a base class, which would be bad.
        none_is_acceptable = (
            not self.options.local_partial_types
            or is_function
            or (is_class and self.is_defined_in_base_class(var))
        )
        is_none_partial = isinstance(var.type, PartialType) and var.type.type is None
        if none_is_acceptable and is_none_partial and not lenient:
            var.type = NoneType()
            continue

        # Report each variable at most once, and never in lenient mode.
        if var not in self.partial_reported and not lenient:
            self.msg.need_annotation_for_var(var, context, self.options.python_version)
            self.partial_reported.add(var)
        # Replace whatever partial type is left, remembering whether that changed it.
        if var.type:
            fixed = fixup_partial_type(var.type)
            var.invalid_partial_type = fixed != var.type
            var.type = fixed
def handle_partial_var_type(
    self, typ: PartialType, is_lvalue: bool, node: Var, context: Context
) -> Type:
    """Handle a reference to a partial type through a var.

    (Used by checkexpr and checkmember.)

    Returns the type to use for the reference: the partial type itself or None
    for an in-scope 'None' partial type, otherwise the fixed-up partial type.
    """
    in_scope, is_local, partial_types = self.find_partial_types_in_all_scopes(node)

    if typ.type is None and in_scope:
        # 'None' partial type. It has a well-defined type. In an lvalue context
        # we want to preserve the knowledge of it being a partial type.
        return typ if is_lvalue else NoneType()

    if partial_types is not None and not self.current_node_deferred:
        if not in_scope:
            # Defer the node -- we might get a better type in the outer scope
            self.handle_cannot_determine_type(node.name, context)
        else:
            # Report at the location recorded when the partial type was created.
            context = partial_types[node]
            if is_local or not self.options.allow_untyped_globals:
                self.msg.need_annotation_for_var(node, context, self.options.python_version)
                self.partial_reported.add(node)
    return fixup_partial_type(typ)
def is_defined_in_base_class(self, var: Var) -> bool:
    """Return whether 'var' may already be defined by a base class.

    False if the variable has no class info. Otherwise true if the class falls
    back to Any, or if any class after the first in its MRO has an entry for
    the variable's name.
    """
    owner = var.info
    if not owner:
        return False
    return owner.fallback_to_any or any(
        ancestor.get(var.name) is not None for ancestor in owner.mro[1:]
    )
def find_partial_types(self, var: Var) -> dict[Var, Context] | None:
    """Look for an active partial type scope containing variable.

    A scope is active if assignments in the current context can refine a partial
    type originally defined in the scope. This is affected by the local_partial_types
    configuration option.

    Returns the scope's variable map, or None if there is no active scope.
    """
    is_active, _, scope_map = self.find_partial_types_in_all_scopes(var)
    return scope_map if is_active else None
def find_partial_types_in_all_scopes(
    self, var: Var
) -> tuple[bool, bool, dict[Var, Context] | None]:
    """Look for partial type scope containing variable.

    Scopes are searched from the innermost outwards.

    Return tuple (is the scope active, is the scope a local scope, scope), or
    (False, False, None) if no scope contains the variable.
    """
    for scope in reversed(self.partial_types):
        if var not in scope.map:
            continue

        # This is an ugly hack to make partial generic self attributes behave
        # as if --local-partial-types is always on (because it used to be like this).
        forced_local = (
            isinstance(var.type, PartialType) and var.type.type is not None and var.info
        )
        # All scopes within the outermost function are active. Scopes out of
        # the outermost function are inactive to allow local reasoning (important
        # for fine-grained incremental mode).
        restrict_to_same_locality = forced_local or self.options.local_partial_types
        is_active = (
            not restrict_to_same_locality
            or scope.is_local == self.partial_types[-1].is_local
        )
        return is_active, scope.is_local, scope.map
    return False, False, None
def temp_node(self, t: Type, context: Context | None = None) -> TempNode:
    """Create a temporary node with the given, fixed type.

    'context', if given, is passed on to the new node.
    """
    node = TempNode(t, context=context)
    return node
def fail(
    self, msg: str | ErrorMessage, context: Context, *, code: ErrorCode | None = None
) -> None:
    """Produce an error message.

    An ErrorMessage supplies both the text and the error code, in which case the
    'code' argument is ignored; for a plain string 'code' is used as given.
    """
    if isinstance(msg, ErrorMessage):
        text, error_code = msg.value, msg.code
    else:
        text, error_code = msg, code
    self.msg.fail(text, context, code=error_code)
def note(
    self,
    msg: str | ErrorMessage,
    context: Context,
    offset: int = 0,
    *,
    code: ErrorCode | None = None,
) -> None:
    """Produce a note.

    An ErrorMessage supplies both the text and the error code, in which case the
    'code' argument is ignored; for a plain string 'code' is used as given.
    'offset' is forwarded to the message builder in both cases (it used to be
    silently dropped for ErrorMessage arguments).
    """
    if isinstance(msg, ErrorMessage):
        self.msg.note(msg.value, context, offset=offset, code=msg.code)
        return
    self.msg.note(msg, context, offset=offset, code=code)
def iterable_item_type(
    self, it: Instance | CallableType | TypeType | Overloaded, context: Context
) -> Type:
    """Return the item type produced by iterating over 'it'.

    For instances, the type argument of the 'typing.Iterable' supertype is used
    unless it is Any; in every other case the item type is determined structurally.
    """
    if isinstance(it, Instance):
        as_iterable = map_instance_to_supertype(it, self.lookup_typeinfo("typing.Iterable"))
        nominal_item = as_iterable.args[0]
        # This relies on 'map_instance_to_supertype' returning 'Iterable[Any]'
        # in case there is no explicit base class.
        if not isinstance(get_proper_type(nominal_item), AnyType):
            return nominal_item

    # Try also structural typing.
    structural = self.analyze_iterable_item_type_without_expression(it, context)
    return structural[1]
def function_type(self, func: FuncBase) -> FunctionLike:
    """Return the type of ``func``, using ``builtins.function`` as the fallback."""
    fallback = self.named_type("builtins.function")
    return function_type(func, fallback)
def push_type_map(self, type_map: TypeMap) -> None:
    """Apply a type map to the binder.

    A ``None`` map marks the current binder frame as unreachable; otherwise
    every (expression, type) pair is recorded in the binder.
    """
    if type_map is not None:
        for narrowed_expr, narrowed_type in type_map.items():
            self.binder.put(narrowed_expr, narrowed_type)
        return
    self.binder.unreachable()
def infer_issubclass_maps(self, node: CallExpr, expr: Expression) -> tuple[TypeMap, TypeMap]:
    """Infer type restrictions for an expression in an issubclass call.

    ``expr`` is the expression being narrowed and ``node.args[1]`` supplies the
    class(es) tested against. Returns the (if-branch, else-branch) type maps;
    two empty maps mean no narrowing is attempted.
    """
    subject = self.lookup_type(expr)
    type_ranges = self.get_isinstance_type(node.args[1])
    if isinstance(subject, TypeVarType):
        subject = subject.upper_bound
    subject = get_proper_type(subject)

    if isinstance(subject, UnionType):
        member_items = []
        for member in get_proper_types(subject.items):
            if not isinstance(member, TypeType):
                # This is an error that should be reported earlier
                # if we reach here, we refuse to do any type inference.
                return {}, {}
            member_items.append(member.item)
        subject = UnionType(member_items)
    elif isinstance(subject, TypeType):
        subject = subject.item
    elif isinstance(subject, Instance) and subject.type.is_metaclass():
        subject = self.named_type("builtins.object")
    else:
        # Any other object whose type we don't know precisely
        # for example, Any or a custom metaclass.
        return {}, {}  # unknown type

    yes_type, no_type = self.conditional_types_with_intersection(subject, type_ranges, expr)
    if_map, else_map = conditional_types_to_typemaps(expr, yes_type, no_type)
    # Both maps are expressed in terms of instances; wrap them back into Type[...].
    return convert_to_typetype(if_map), convert_to_typetype(else_map)
@overload
def conditional_types_with_intersection(
    self,
    expr_type: Type,
    type_ranges: list[TypeRange] | None,
    ctx: Context,
    default: None = None,
) -> tuple[Type | None, Type | None]:
    ...

@overload
def conditional_types_with_intersection(
    self, expr_type: Type, type_ranges: list[TypeRange] | None, ctx: Context, default: Type
) -> tuple[Type, Type]:
    ...

def conditional_types_with_intersection(
    self,
    expr_type: Type,
    type_ranges: list[TypeRange] | None,
    ctx: Context,
    default: Type | None = None,
) -> tuple[Type | None, Type | None]:
    """Narrow 'expr_type' by 'type_ranges', falling back to ad-hoc intersections.

    The result of 'conditional_types' is returned as-is unless its first
    element is an UninhabitedType and 'type_ranges' is not None. In that case
    every item of 'expr_type' (the relevant union items, or the type itself) is
    intersected with every range item via 'intersect_instances'. The fallback
    is abandoned, and the initial result returned, as soon as a range item is
    not an Instance or is an upper bound, or an expression item is not an
    Instance.

    Returns a (yes type, no type) pair. When the fallback runs to completion
    the no type is 'expr_type'; the yes type is the simplified union of the
    intersections, or UninhabitedType if none could be built. In the latter
    case the collected reasons are reported at 'ctx' if
    'should_report_unreachable_issues()' is true.
    """
    initial_types = conditional_types(expr_type, type_ranges, default)
    # For some reason, doing "yes_map, no_map = conditional_types_to_typemaps(...)"
    # doesn't work: mypyc will decide that 'yes_map' is of type None if we try.
    yes_type: Type | None = initial_types[0]
    no_type: Type | None = initial_types[1]
    if not isinstance(get_proper_type(yes_type), UninhabitedType) or type_ranges is None:
        return yes_type, no_type
    # If conditional_types was unable to successfully narrow the expr_type
    # using the type_ranges and concluded if-branch is unreachable, we try
    # computing it again using a different algorithm that tries to generate
    # an ad-hoc intersection between the expr_type and the type_ranges.
    proper_type = get_proper_type(expr_type)
    if isinstance(proper_type, UnionType):
        possible_expr_types = get_proper_types(proper_type.relevant_items())
    else:
        possible_expr_types = [proper_type]
    # Collect the intersection targets; only precise Instance ranges qualify.
    possible_target_types = []
    for tr in type_ranges:
        item = get_proper_type(tr.item)
        if not isinstance(item, Instance) or tr.is_upper_bound:
            return yes_type, no_type
        possible_target_types.append(item)
    out = []
    # (types, reason) pairs accumulated by intersect_instances for failed attempts.
    errors: list[tuple[str, str]] = []
    for v in possible_expr_types:
        if not isinstance(v, Instance):
            return yes_type, no_type
        for t in possible_target_types:
            intersection = self.intersect_instances((v, t), errors)
            if intersection is None:
                continue
            out.append(intersection)
    if not out:
        # Only report errors if no element in the union worked.
        if self.should_report_unreachable_issues():
            for types, reason in errors:
                self.msg.impossible_intersection(types, reason, ctx)
        return UninhabitedType(), expr_type
    new_yes_type = make_simplified_union(out)
    return new_yes_type, expr_type
def is_writable_attribute(self, node: Node) -> bool:
    """Check whether the attribute described by ``node`` can be assigned to.

    A ``Var`` is writable unless it is a property without a setter. An
    overloaded property is writable when its first item is a settable
    property. Every other node is treated as read-only.
    """
    if isinstance(node, Var):
        read_only_property = node.is_property and not node.is_settable_property
        return not read_only_property
    if isinstance(node, OverloadedFuncDef) and node.is_property:
        getter = node.items[0]
        assert isinstance(getter, Decorator)
        return getter.var.is_settable_property
    return False
def get_isinstance_type(self, expr: Expression) -> list[TypeRange] | None:
    """Translate the second argument of isinstance/issubclass into type ranges.

    ``X | Y`` expressions are handled by combining the ranges of both operands.
    Returns None when any component is not recognised as a type, or when no
    ranges were produced at all.
    """
    if isinstance(expr, OpExpr) and expr.op == "|":
        lhs_ranges = self.get_isinstance_type(expr.left)
        rhs_ranges = self.get_isinstance_type(expr.right)
        if lhs_ranges is not None and rhs_ranges is not None:
            return lhs_ranges + rhs_ranges
        return None

    ranges: list[TypeRange] = []
    for candidate in get_proper_types(flatten_types(self.lookup_type(expr))):
        if isinstance(candidate, FunctionLike) and candidate.is_type_obj():
            # Type variables may be present -- erase them, which is the best
            # we can do (outside disallowing them here).
            erased = erase_typevars(candidate.items[0].ret_type)
            ranges.append(TypeRange(erased, is_upper_bound=False))
        elif isinstance(candidate, TypeType):
            # Type[A] means "any type that is a subtype of A" rather than "precisely type A"
            # we indicate this by setting is_upper_bound flag
            ranges.append(TypeRange(candidate.item, is_upper_bound=True))
        elif isinstance(candidate, Instance) and candidate.type.fullname == "builtins.type":
            # Use the last entry of the MRO of 'type' as the upper bound.
            mro_tail = Instance(candidate.type.mro[-1], [])
            ranges.append(TypeRange(mro_tail, is_upper_bound=True))
        elif isinstance(candidate, AnyType):
            ranges.append(TypeRange(candidate, is_upper_bound=False))
        else:
            # we didn't see an actual type, but rather a variable with unknown value
            return None

    # An empty result can happen if someone has empty tuple as 2nd argument to
    # isinstance; strictly speaking, we should return UninhabitedType but for
    # simplicity we simply refuse to do any type inference for now.
    return ranges or None
def is_literal_enum(self, n: Expression) -> bool:
    """Return True if this expression (in the current type context) is an Enum literal.

    For example, given

        class Foo(Enum):
            A = 1
            B = 2

    and an expression 'Foo' that refers to that enum, the expression 'Foo.A'
    is a literal enum. After 'a = Foo.A', the variable 'a' is *not* one.

    Expressions like 'Foo.A' are occasionally special-cased and treated as a
    single primitive unit, for the same reasons 'True', 'False' or 'None'
    sometimes are.
    """
    if not (isinstance(n, MemberExpr) and isinstance(n.expr, NameExpr)):
        return False

    owner_type = self.lookup_type_or_none(n.expr)
    attr_type = self.lookup_type_or_none(n)
    if owner_type is None or attr_type is None:
        return False

    owner = get_proper_type(owner_type)
    attr = get_proper_type(coerce_to_literal(attr_type))
    if not isinstance(owner, FunctionLike) or not isinstance(attr, LiteralType):
        return False
    if not owner.is_type_obj():
        return False
    if not attr.is_enum_literal():
        return False
    # The literal must belong to the very class named on the left of the dot.
    return attr.fallback.type == owner.type_object()
def add_any_attribute_to_type(self, typ: Type, name: str) -> Type:
    """Inject an extra attribute ``name`` of type Any, using fallbacks.

    Handles instances, tuples, callables, ``Type[Instance]``, type variables
    and unions; any other type is returned unchanged.
    """
    proper = get_proper_type(typ)
    any_type = AnyType(TypeOfAny.unannotated)

    if isinstance(proper, Instance):
        augmented = proper.copy_with_extra_attr(name, any_type)
        # For instances, we erase the possible module name, so that restrictions
        # become anonymous types.ModuleType instances, allowing hasattr() to
        # have effect on modules.
        assert augmented.extra_attrs is not None
        augmented.extra_attrs.mod_name = None
        return augmented

    if isinstance(proper, TupleType):
        tuple_fallback = proper.partial_fallback.copy_with_extra_attr(name, any_type)
        return proper.copy_modified(fallback=tuple_fallback)

    if isinstance(proper, CallableType):
        callable_fallback = proper.fallback.copy_with_extra_attr(name, any_type)
        return proper.copy_modified(fallback=callable_fallback)

    if isinstance(proper, TypeType) and isinstance(proper.item, Instance):
        return TypeType.make_normalized(self.add_any_attribute_to_type(proper.item, name))

    if isinstance(proper, TypeVarType):
        new_bound = self.add_any_attribute_to_type(proper.upper_bound, name)
        new_values = [self.add_any_attribute_to_type(value, name) for value in proper.values]
        return proper.copy_modified(upper_bound=new_bound, values=new_values)

    if isinstance(proper, UnionType):
        # Items that already have the attribute are kept as they are.
        with_attr, without_attr = self.partition_union_by_attr(proper, name)
        patched = [self.add_any_attribute_to_type(item, name) for item in without_attr]
        return make_simplified_union(with_attr + patched)

    return typ
def hasattr_type_maps(
    self, expr: Expression, source_type: Type, name: str
) -> tuple[TypeMap, TypeMap]:
    """Simple support for hasattr() checks.

    The logic is essentially:
      * In the if branch, types that already have a valid attribute are kept
        as is; for the others an attribute of type `Any` is injected.
      * In the else branch, types that already have a valid attribute are
        removed, while the rest is kept.
    """
    if self.has_valid_attribute(source_type, name):
        return {expr: source_type}, {}

    proper_source = get_proper_type(source_type)
    if isinstance(proper_source, UnionType):
        _, lacking = self.partition_union_by_attr(proper_source, name)
        if_map = {expr: self.add_any_attribute_to_type(proper_source, name)}
        else_map = {expr: make_simplified_union(lacking)}
        return if_map, else_map

    augmented = self.add_any_attribute_to_type(proper_source, name)
    if augmented != proper_source:
        return {expr: augmented}, {}
    # The attribute could not be injected: nothing is learned on either branch.
    return {}, {}
def partition_union_by_attr(
    self, source_type: UnionType, name: str
) -> tuple[list[Type], list[Type]]:
    """Split the items of a union by whether they have a valid attribute ``name``.

    Returns (items with the attribute, items without it), each in the original
    item order.
    """
    having: list[Type] = []
    lacking: list[Type] = []
    for member in source_type.items:
        bucket = having if self.has_valid_attribute(member, name) else lacking
        bucket.append(member)
    return having, lacking
def has_valid_attribute(self, typ: Type, name: str) -> bool:
    """Check whether looking up attribute ``name`` on ``typ`` succeeds without errors.

    ``Any`` is never considered to have a valid attribute. The lookup runs with
    errors filtered, so nothing is reported to the user.
    """
    p_typ = get_proper_type(typ)
    if isinstance(p_typ, AnyType):
        return False

    # Presence of module_symbol_table means this check will skip ModuleType.__getattr__
    module_symbol_table = None
    if isinstance(p_typ, Instance) and p_typ.extra_attrs and p_typ.extra_attrs.mod_name:
        module_symbol_table = p_typ.type.names

    with self.msg.filter_errors() as watcher:
        analyze_member_access(
            name,
            typ,
            TempNode(AnyType(TypeOfAny.special_form)),
            False,
            False,
            False,
            self.msg,
            original_type=typ,
            chk=self,
            # This is not a real attribute lookup so don't mess with deferring nodes.
            no_deferral=True,
            module_symbol_table=module_symbol_table,
        )
    lookup_failed = watcher.has_new_errors()
    return not lookup_failed
def get_expression_type(self, node: Expression, type_context: Type | None = None) -> Type:
    """Return the type of ``node`` as computed by the expression checker."""
    checker = self.expr_checker
    return checker.accept(node, type_context=type_context)
class CollectArgTypeVarTypes(TypeTraverserVisitor):
    """Type traverser that gathers every visited TypeVarType into a set."""

    def __init__(self) -> None:
        # Type variables seen so far; filled in by visit_type_var().
        self.arg_types: set[TypeVarType] = set()

    def visit_type_var(self, t: TypeVarType) -> None:
        """Record ``t`` among the collected type variables."""
        collected = self.arg_types
        collected.add(t)
@overload
def conditional_types(
    current_type: Type, proposed_type_ranges: list[TypeRange] | None, default: None = None
) -> tuple[Type | None, Type | None]:
    ...


@overload
def conditional_types(
    current_type: Type, proposed_type_ranges: list[TypeRange] | None, default: Type
) -> tuple[Type, Type]:
    ...


def conditional_types(
    current_type: Type, proposed_type_ranges: list[TypeRange] | None, default: Type | None = None
) -> tuple[Type | None, Type | None]:
    """Takes in the current type and a proposed type of an expression.

    Returns a 2-tuple: The first element is the proposed type, if the expression
    can be the proposed type. The second element is the type it would hold
    if it was not the proposed type, if any. UninhabitedType means unreachable.
    None means no new information can be inferred. If default is set it is returned
    instead.

    If 'proposed_type_ranges' is None or empty, the result is
    (current_type, default).
    """
    if proposed_type_ranges:
        if len(proposed_type_ranges) == 1:
            # A single enum-literal or bool-literal target: expand the current
            # type into a union first so one member can be narrowed away.
            target = proposed_type_ranges[0].item
            target = get_proper_type(target)
            if isinstance(target, LiteralType) and (
                target.is_enum_literal() or isinstance(target.value, bool)
            ):
                enum_name = target.fallback.type.fullname
                current_type = try_expanding_sum_type_to_union(current_type, enum_name)
        proposed_items = [type_range.item for type_range in proposed_type_ranges]
        proposed_type = make_simplified_union(proposed_items)
        if isinstance(proposed_type, AnyType):
            # We don't really know much about the proposed type, so we shouldn't
            # attempt to narrow anything. Instead, we broaden the expr to Any to
            # avoid false positives
            return proposed_type, default
        elif not any(
            type_range.is_upper_bound for type_range in proposed_type_ranges
        ) and is_proper_subtype(current_type, proposed_type, ignore_promotions=True):
            # Expression is always of one of the types in proposed_type_ranges
            return default, UninhabitedType()
        elif not is_overlapping_types(
            current_type, proposed_type, prohibit_none_typevar_overlap=True, ignore_promotions=True
        ):
            # Expression is never of any type in proposed_type_ranges
            return UninhabitedType(), default
        else:
            # we can only restrict when the type is precise, not bounded
            proposed_precise_type = UnionType.make_union(
                [
                    type_range.item
                    for type_range in proposed_type_ranges
                    if not type_range.is_upper_bound
                ]
            )
            remaining_type = restrict_subtype_away(current_type, proposed_precise_type)
            return proposed_type, remaining_type
    else:
        # An isinstance check, but we don't understand the type
        return current_type, default
def conditional_types_to_typemaps(
    expr: Expression, yes_type: Type | None, no_type: Type | None
) -> tuple[TypeMap, TypeMap]:
    """Convert a (yes type, no type) pair into (if-branch, else-branch) type maps.

    For each of the two types: an UninhabitedType yields a None map, None
    yields an empty map, and any other type yields ``{expr: type}`` where
    ``expr`` has first been passed through ``collapse_walrus``.
    """
    expr = collapse_walrus(expr)
    branch_maps: list[TypeMap] = []
    for branch_type in (yes_type, no_type):
        proper = get_proper_type(branch_type)
        if proper is None:
            branch_maps.append({})
        elif isinstance(proper, UninhabitedType):
            branch_maps.append(None)
        else:
            assert branch_type is not None
            branch_maps.append({expr: branch_type})
    if_map, else_map = branch_maps
    return if_map, else_map
def gen_unique_name(base: str, table: SymbolTable) -> str:
    """Return a name that is not in ``table``, derived from ``base``.

    ``base`` itself is returned when it is free; otherwise the smallest
    positive integer suffix that yields an unused name is appended.
    """
    if base not in table:
        return base
    suffix = 1
    while True:
        candidate = base + str(suffix)
        if candidate not in table:
            return candidate
        suffix += 1
def is_true_literal(n: Expression) -> bool:
    """Return True if ``n`` is the 'True' keyword or a non-zero integer literal."""
    if refers_to_fullname(n, "builtins.True"):
        return True
    return isinstance(n, IntExpr) and n.value != 0
def is_false_literal(n: Expression) -> bool:
    """Return True if ``n`` is the 'False' keyword or the integer literal zero."""
    if refers_to_fullname(n, "builtins.False"):
        return True
    return isinstance(n, IntExpr) and n.value == 0
def is_literal_none(n: Expression) -> bool:
    """Return True if ``n`` is a name expression referring to 'builtins.None'."""
    if not isinstance(n, NameExpr):
        return False
    return n.fullname == "builtins.None"
def is_literal_not_implemented(n: Expression) -> bool:
    """Return True if ``n`` is a name expression referring to 'builtins.NotImplemented'."""
    if not isinstance(n, NameExpr):
        return False
    return n.fullname == "builtins.NotImplemented"
def _is_empty_generator_function(func: FuncItem) -> bool:
    """Check whether a function's body is exactly 'return; yield'.

    The yield is there only to promote the function into a generator function.
    Both the return and the yield may carry either no value or a literal None.
    """
    statements = func.body.body
    if len(statements) != 2:
        return False
    first, second = statements

    if not isinstance(first, ReturnStmt):
        return False
    if not (first.expr is None or is_literal_none(first.expr)):
        return False

    if not isinstance(second, ExpressionStmt):
        return False
    yielded = second.expr
    if not isinstance(yielded, YieldExpr):
        return False
    return yielded.expr is None or is_literal_none(yielded.expr)
def builtin_item_type(tp: Type) -> Type | None:
    """Get the item type of a builtin container.

    Returns None if 'tp' is not one of the recognised containers (fixed-length
    tuples and TypedDicts are included) or if the container is not
    parameterized (like List or List[Any]). This is used to narrow optional
    types in situations like this:

        x: Optional[int]
        if x in (1, 2, 3):
            x + 42  # OK

    Note: this is only OK for built-in containers, where we know the behavior
    of __contains__.
    """
    tp = get_proper_type(tp)

    if isinstance(tp, Instance):
        known_containers = (
            "builtins.list",
            "builtins.tuple",
            "builtins.dict",
            "builtins.set",
            "builtins.frozenset",
            "_collections_abc.dict_keys",
            "typing.KeysView",
        )
        if tp.type.fullname not in known_containers:
            return None
        if not tp.args:
            # TODO: fix tuple in lib-stub/builtins.pyi (it should be generic).
            return None
        first_arg = tp.args[0]
        if isinstance(get_proper_type(first_arg), AnyType):
            return None
        return first_arg

    if isinstance(tp, TupleType):
        if any(isinstance(it, AnyType) for it in get_proper_types(tp.items)):
            return None
        return make_simplified_union(tp.items)  # this type is not externally visible

    if isinstance(tp, TypedDictType):
        # TypedDict always has non-optional string keys. Find the key type from the Mapping
        # base class.
        for base in tp.fallback.type.mro:
            if base.fullname == "typing.Mapping":
                return map_instance_to_supertype(tp.fallback, base).args[0]
        assert False, "No Mapping base class found for TypedDict fallback"

    return None
def and_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap:
    """Combine what is learned from the truth of e1 and of e2 into (e1 and e2).

    Returns None if either map is None, since then one of the conditions can
    never be true.
    """
    if m1 is None or m2 is None:
        return None

    # Both conditions can be true; anything we learn from either condition's
    # truth is valid. If the same expression's type is refined by both
    # conditions, we somewhat arbitrarily give precedence to m2. (In the
    # future, we could use an intersection type.)
    combined = m2.copy()
    refined_by_m2 = {literal_hash(right_expr) for right_expr in m2}
    for left_expr, left_type in m1.items():
        if literal_hash(left_expr) in refined_by_m2:
            continue
        combined[left_expr] = left_type
    return combined
def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap:
    """Combine what is learned from the truth of e1 and of e2 into (e1 or e2).

    If one map is None (that condition can never be true) the other map is
    returned unchanged.
    """
    if m1 is None:
        return m2
    if m2 is None:
        return m1

    # Both conditions can be true. Combine information about expressions whose
    # type is refined by both conditions. (We do not learn anything about
    # expressions whose type is refined by only one condition.)
    merged: dict[Expression, Type] = {}
    for left_expr, left_type in m1.items():
        left_key = literal_hash(left_expr)
        for right_expr, right_type in m2.items():
            if left_key == literal_hash(right_expr):
                merged[left_expr] = make_simplified_union([left_type, right_type])
    return merged
def reduce_conditional_maps(type_maps: list[tuple[TypeMap, TypeMap]]) -> tuple[TypeMap, TypeMap]:
    """Reduce a list of (if, else) TypeMap pairs into a single pair.

    All of the if TypeMaps are "and"-ed together and all of the else TypeMaps
    are "or"-ed together. For example, given the input:

        [
            ({x: TypeIfX, shared: TypeIfShared1}, {x: TypeElseX, shared: TypeElseShared1}),
            ({y: TypeIfY, shared: TypeIfShared2}, {y: TypeElseY, shared: TypeElseShared2}),
        ]

    ...the output is:

        (
            {x: TypeIfX, y: TypeIfY, shared: PseudoIntersection[TypeIfShared1, TypeIfShared2]},
            {shared: Union[TypeElseShared1, TypeElseShared2]},
        )

    ...where "PseudoIntersection[X, Y] == Y" because mypy doesn't understand
    intersections yet, so the right expr's type is arbitrarily picked.

    Only the shared expression is retained in the 'else' case because it is
    not known whether x or y was refined -- only that one of the two was.
    An empty input yields two empty maps.
    """
    if not type_maps:
        return {}, {}
    if len(type_maps) == 1:
        return type_maps[0]

    combined_if, combined_else = type_maps[0]
    for next_if, next_else in type_maps[1:]:
        combined_if = and_conditional_maps(combined_if, next_if)
        combined_else = or_conditional_maps(combined_else, next_else)
    return combined_if, combined_else
def convert_to_typetype(type_map: TypeMap) -> TypeMap:
    """Wrap every type in ``type_map`` into a normalized ``TypeType``.

    A None map stays None. If any entry (after replacing a type variable by its
    upper bound) is neither a union nor an instance, an empty map is returned.
    """
    if type_map is None:
        return None

    wrapped: dict[Expression, Type] = {}
    for expr, typ in type_map.items():
        checked = typ.upper_bound if isinstance(typ, TypeVarType) else typ
        # TODO: should we only allow unions of instances as per PEP 484?
        if not isinstance(get_proper_type(checked), (UnionType, Instance)):
            # unknown type; error was likely reported earlier
            return {}
        wrapped[expr] = TypeType.make_normalized(typ)
    return wrapped
def flatten(t: Expression) -> list[Expression]:
    """Flatten a nested sequence of tuple/list expressions into one list of nodes.

    Star expressions are replaced by the flattening of their operand.
    """
    if isinstance(t, StarExpr):
        return flatten(t.expr)
    if not isinstance(t, (TupleExpr, ListExpr)):
        return [t]
    leaves: list[Expression] = []
    for item in t.items:
        leaves.extend(flatten(item))
    return leaves
def flatten_types(t: Type) -> list[Type]:
    """Flatten a nested sequence of tuple types into one list of types.

    A ``TupleType`` contributes the flattening of each of its items, a
    ``builtins.tuple`` instance contributes its first type argument, and any
    other type contributes itself.
    """
    t = get_proper_type(t)
    if isinstance(t, TupleType):
        leaves: list[Type] = []
        for item in t.items:
            leaves.extend(flatten_types(item))
        return leaves
    if is_named_instance(t, "builtins.tuple"):
        return [t.args[0]]
    return [t]
def expand_func(defn: FuncItem, map: dict[TypeVarId, Type]) -> FuncItem:
    """Return a copy of ``defn`` produced by ``TypeTransformVisitor`` for ``map``."""
    expanded = TypeTransformVisitor(map).node(defn)
    assert isinstance(expanded, FuncItem)
    return expanded
class TypeTransformVisitor(TransformVisitor):
    """Transform visitor that expands types using a type-variable mapping."""

    def __init__(self, map: dict[TypeVarId, Type]) -> None:
        super().__init__()
        # Substitution applied to every type handled by type().
        self.map = map

    def type(self, type: Type) -> Type:
        """Return ``type`` expanded with the stored mapping."""
        substitutions = self.map
        return expand_type(type, substitutions)
def are_argument_counts_overlapping(t: CallableType, s: CallableType) -> bool:
    """Can a single call match both t and s, based just on positional argument counts?"""
    # A call must supply at least as many arguments as the stricter minimum...
    fewest_allowed = max(t.min_args, s.min_args)
    # ...and no more than the tighter of the two positional maxima.
    most_allowed = min(t.max_possible_positional_args(), s.max_possible_positional_args())
    return not fewest_allowed > most_allowed
def is_unsafe_overlapping_overload_signatures(
    signature: CallableType, other: CallableType, class_type_vars: list[TypeVarLikeType]
) -> bool:
    """Check if two overloaded signatures are unsafely overlapping or partially overlapping.

    We consider two functions 's' and 't' to be unsafely overlapping if both
    of the following are true:

    1.  s's parameters are all more precise or partially overlapping with t's
    2.  s's return type is NOT a subtype of t's.

    Assumes that 'signature' appears earlier in the list of overload
    alternatives than 'other' and that their argument counts are overlapping.

    'class_type_vars' is passed to 'detach_callable' for both signatures
    before they are compared.
    """
    # Try detaching callables from the containing class so that all TypeVars
    # are treated as being free.
    #
    # This lets us identify cases where the two signatures use completely
    # incompatible types -- e.g. see the testOverloadingInferUnionReturnWithMixedTypevars
    # test case.
    signature = detach_callable(signature, class_type_vars)
    other = detach_callable(other, class_type_vars)
    # Note: We repeat this check twice in both directions due to a slight
    # asymmetry in 'is_callable_compatible'. When checking for partial overlaps,
    # we attempt to unify 'signature' and 'other' both against each other.
    #
    # If 'signature' cannot be unified with 'other', we end early. However,
    # if 'other' cannot be modified with 'signature', the function continues
    # using the older version of 'other'.
    #
    # This discrepancy is unfortunately difficult to get rid of, so we repeat the
    # checks twice in both directions for now.
    #
    # Note that we ignore possible overlap between type variables and None. This
    # is technically unsafe, but unsafety is tiny and this prevents some common
    # use cases like:
    #     @overload
    #     def foo(x: None) -> None: ..
    #     @overload
    #     def foo(x: T) -> Foo[T]: ...
    #
    # The second call swaps the operands, so its return check and
    # 'check_args_covariantly' flag are flipped to match.
    return is_callable_compatible(
        signature,
        other,
        is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none,
        is_compat_return=lambda l, r: not is_subtype_no_promote(l, r),
        ignore_return=False,
        check_args_covariantly=True,
        allow_partial_overlap=True,
        no_unify_none=True,
    ) or is_callable_compatible(
        other,
        signature,
        is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none,
        is_compat_return=lambda l, r: not is_subtype_no_promote(r, l),
        ignore_return=False,
        check_args_covariantly=False,
        allow_partial_overlap=True,
        no_unify_none=True,
    )
def detach_callable(typ: CallableType, class_type_vars: list[TypeVarLikeType]) -> CallableType:
    """Ensure that the callable's type variables are 'detached' and independent of the context.

    A callable normally keeps track of the type variables it uses within its
    'variables' field. However, if the callable is from a method and that
    method is using a class type variable, the callable will not keep track of
    that type variable since it belongs to the class.

    This function returns a copy of the callable whose 'variables' are extended
    with those entries of 'class_type_vars' that occur in its argument or
    return types. The caller can then unify on all type variables whether the
    callable is originally from the class or not. If 'class_type_vars' is
    empty the callable is returned unchanged.
    """
    if not class_type_vars:
        # Fast path, nothing to update.
        return typ

    used_type_vars = set()
    for component in typ.arg_types + [typ.ret_type]:
        used_type_vars.update(get_type_vars(component))

    extra_variables = [tv for tv in class_type_vars if tv in used_type_vars]
    return typ.copy_modified(variables=list(typ.variables) + extra_variables)
def overload_can_never_match(signature: CallableType, other: CallableType) -> bool:
    """Check if the 'other' method can never be matched due to 'signature'.

    This can happen if signature's parameters are all strictly broader than
    other's parameters.

    Assumes that both signatures have overlapping argument counts.
    """
    # The extra erasure is needed to prevent spurious errors
    # in situations where an `Any` overload is used as a fallback
    # for an overload with type variables. The spurious error appears
    # because the type variables turn into `Any` during unification in
    # the below subtype check and (surprisingly?) `is_proper_subtype(Any, Any)`
    # returns `True`.
    # TODO: find a cleaner solution instead of this ad-hoc erasure.
    erased_variables = {}
    for tvar in signature.variables:
        erased_variables[tvar.id] = erase_def_to_union_or_bound(tvar)
    exp_signature = expand_type(signature, erased_variables)
    return is_callable_compatible(
        exp_signature, other, is_compat=is_more_precise, ignore_return=True
    )
def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool:
    """Does t have wider arguments than s?

    Two plain callables are compared with 'is_callable_compatible'. Otherwise
    both must have the same number of items, which are then compared pairwise
    with 'is_same_arg_prefix'. A plain callable 't' never matches a
    non-callable 's'.
    """
    # TODO should an overload with additional items be allowed to be more
    # general than one with fewer items (or just one item)?
    if isinstance(t, CallableType):
        if not isinstance(s, CallableType):
            return False
        return is_callable_compatible(t, s, is_compat=is_proper_subtype, ignore_return=True)

    if (
        isinstance(t, FunctionLike)
        and isinstance(s, FunctionLike)
        and len(t.items) == len(s.items)
    ):
        return all(
            is_same_arg_prefix(t_item, s_item) for t_item, s_item in zip(t.items, s.items)
        )
    return False
def is_same_arg_prefix(t: CallableType, s: CallableType) -> bool:
    """Check argument compatibility of ``t`` and ``s`` using ``is_same_type``.

    Return types and positional argument names are ignored.
    """
    same_prefix = is_callable_compatible(
        t,
        s,
        is_compat=is_same_type,
        ignore_return=True,
        check_args_covariantly=True,
        ignore_pos_arg_names=True,
    )
    return same_prefix
def infer_operator_assignment_method(typ: Type, operator: str) -> tuple[bool, str]:
    """Determine if operator assignment on given value type is in-place, and the method name.

    For example, if operator is '+', return (True, '__iadd__') or (False, '__add__')
    depending on which method is supported by the type.
    """
    typ = get_proper_type(typ)
    regular_method = operators.op_methods[operator]
    if isinstance(typ, Instance) and operator in operators.ops_with_inplace_method:
        # '__add__' -> '__iadd__': insert 'i' after the leading underscores.
        inplace_method = "__i" + regular_method[2:]
        if typ.type.has_readable_member(inplace_method):
            return True, inplace_method
    return False, regular_method
def is_valid_inferred_type(typ: Type, is_lvalue_final: bool = False) -> bool:
    """Is an inferred type valid and needs no further refinement?

    Examples of invalid types include the None type (when we are not assigning
    None to a final lvalue), an uninhabited type, or any type with a component
    rejected by InvalidInferredTypes, such as List[<uninhabited>] from
    ambiguous inference.
    """
    proper = get_proper_type(typ)
    if isinstance(proper, UninhabitedType):
        return False
    if isinstance(proper, NoneType):
        # If the lvalue is final, we may immediately infer NoneType when the
        # initializer is None.
        #
        # If not, we want to defer making this decision. The final inferred
        # type could either be NoneType or an Optional type, depending on
        # the context. This resolution happens in leave_partial_types when
        # we pop a partial types scope.
        return is_lvalue_final
    has_invalid_component = typ.accept(InvalidInferredTypes())
    return not has_invalid_component
class InvalidInferredTypes(BoolTypeQuery):
    """Find type components that are not valid for an inferred type.

    These include <Erased> type, and any <nothing> types resulting from failed
    (ambiguous) type inference.
    """

    def __init__(self) -> None:
        super().__init__(ANY_STRATEGY)

    def visit_erased_type(self, t: ErasedType) -> bool:
        # This can happen inside a lambda.
        return True

    def visit_uninhabited_type(self, t: UninhabitedType) -> bool:
        # Only uninhabited types flagged as ambiguous are reported.
        is_ambiguous = t.ambiguous
        return is_ambiguous

    def visit_type_var(self, t: TypeVarType) -> bool:
        # This is needed to prevent leaking into partial types during
        # multi-step type inference.
        type_var_id = t.id
        return type_var_id.is_meta_var()
class SetNothingToAny(TypeTranslator):
    """Replace all ambiguous <nothing> types with Any (to avoid spurious extra errors)."""

    def visit_uninhabited_type(self, t: UninhabitedType) -> Type:
        # Unambiguous <nothing> types are left untouched.
        if not t.ambiguous:
            return t
        return AnyType(TypeOfAny.from_error)

    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
        # Target of the alias cannot be an ambiguous <nothing>, so we just
        # replace the arguments.
        translated_args = [arg.accept(self) for arg in t.args]
        return t.copy_modified(args=translated_args)
def is_node_static(node: Node | None) -> bool | None:
    """Find out if a node describes a static function method.

    Returns the node's static flag for a ``FuncDef`` or a ``Var``, and None
    for anything else (including a None node).
    """
    verdict: bool | None = None
    if isinstance(node, FuncDef):
        verdict = node.is_static
    elif isinstance(node, Var):
        verdict = node.is_staticmethod
    return verdict
class CheckerScope:
    """Stack of the module, classes and functions that are currently being checked."""

    # We keep two stacks combined, to maintain the relative order
    stack: list[TypeInfo | FuncItem | MypyFile]

    def __init__(self, module: MypyFile) -> None:
        # The module is always the bottom entry of the stack.
        self.stack = [module]

    def top_function(self) -> FuncItem | None:
        """Return the innermost function (lambdas included) on the stack, if any."""
        for e in reversed(self.stack):
            if isinstance(e, FuncItem):
                return e
        return None

    def top_non_lambda_function(self) -> FuncItem | None:
        """Return the innermost function on the stack that is not a lambda, if any."""
        for e in reversed(self.stack):
            if isinstance(e, FuncItem) and not isinstance(e, LambdaExpr):
                return e
        return None

    def active_class(self) -> TypeInfo | None:
        """Return the class on top of the stack, or None if the top is not a class."""
        if isinstance(self.stack[-1], TypeInfo):
            return self.stack[-1]
        return None

    def enclosing_class(self) -> TypeInfo | None:
        """Is there a class *directly* enclosing this function?"""
        top = self.top_function()
        assert top, "This method must be called from inside a function"
        index = self.stack.index(top)
        assert index, "CheckerScope stack must always start with a module"
        enclosing = self.stack[index - 1]
        if isinstance(enclosing, TypeInfo):
            return enclosing
        return None

    def active_self_type(self) -> Instance | TupleType | None:
        """An instance or tuple type representing the current class.

        This returns None unless we are in class body or in a method.
        In particular, inside a function nested in method this returns None.
        """
        info = self.active_class()
        if not info and self.top_function():
            info = self.enclosing_class()
        if info:
            return fill_typevars(info)
        return None

    @contextmanager
    def push_function(self, item: FuncItem) -> Iterator[None]:
        """Keep ``item`` on the stack for the duration of the ``with`` block."""
        self.stack.append(item)
        try:
            yield
        finally:
            # Pop even if the body raised, so the stack never keeps a stale entry.
            self.stack.pop()

    @contextmanager
    def push_class(self, info: TypeInfo) -> Iterator[None]:
        """Keep ``info`` on the stack for the duration of the ``with`` block."""
        self.stack.append(info)
        try:
            yield
        finally:
            # Pop even if the body raised, so the stack never keeps a stale entry.
            self.stack.pop()
# Type parameters of DisjointDict: the key type and the value type.
TKey = TypeVar("TKey")
TValue = TypeVar("TValue")
class DisjointDict(Generic[TKey, TValue]):
"""An variation of the union-find algorithm/data structure where instead of keeping
track of just disjoint sets, we keep track of disjoint dicts -- keep track of multiple
Set[Key] -> Set[Value] mappings, where each mapping's keys are guaranteed to be disjoint.
This data structure is currently used exclusively by 'group_comparison_operands' below
to merge chains of '==' and 'is' comparisons when two or more chains use the same expression
in best-case O(n), where n is the number of operands.
Specifically, the `add_mapping()` function and `items()` functions will take on average
O(k + v) and O(n) respectively, where k and v are the number of keys and values we're adding
for a given chain. Note that k <= n and v <= n.
We hit these average/best-case scenarios for most user code: e.g. when the user has just
a single chain like 'a == b == c == d == ...' or multiple disjoint chains like
'a==b < c==d < e==f < ...'. (Note that a naive iterative merging would be O(n^2) for
the latter case).
In comparison, this data structure will make 'group_comparison_operands' have a worst-case
runtime of O(n*log(n)): 'add_mapping()' and 'items()' are worst-case O(k*log(n) + v) and
O(k*log(n)) respectively. This happens only in the rare case where the user keeps repeatedly
making disjoint mappings before merging them in a way that persistently dodges the path
compression optimization in '_lookup_root_id', which would end up constructing a single
tree of height log_2(n). This makes root lookups no longer amortized constant time when we
finally call 'items()'.
"""
def __init__(self) -> None:
# Each key maps to a unique ID
self._key_to_id: dict[TKey, int] = {}
# Each id points to the parent id, forming a forest of upwards-pointing trees. If the
# current id already is the root, it points to itself. We gradually flatten these trees
# as we perform root lookups: eventually all nodes point directly to its root.
self._id_to_parent_id: dict[int, int] = {}
# Each root id in turn maps to the set of values.
self._root_id_to_values: dict[int, set[TValue]] = {}
def add_mapping(self, keys: set[TKey], values: set[TValue]) -> None:
"""Adds a 'Set[TKey] -> Set[TValue]' mapping. If there already exists a mapping
containing one or more of the given keys, we merge the input mapping with the old one.
Note that the given set of keys must be non-empty -- otherwise, nothing happens.
"""
if not keys:
return
subtree_roots = [self._lookup_or_make_root_id(key) for key in keys]
new_root = subtree_roots[0]
root_values = self._root_id_to_values[new_root]
root_values.update(values)
for subtree_root in subtree_roots[1:]:
if subtree_root == new_root or subtree_root not in self._root_id_to_values:
continue
self._id_to_parent_id[subtree_root] = new_root
root_values.update(self._root_id_to_values.pop(subtree_root))
def items(self) -> list[tuple[set[TKey], set[TValue]]]:
"""Returns all disjoint mappings in key-value pairs."""
root_id_to_keys: dict[int, set[TKey]] = {}
for key in self._key_to_id:
root_id = self._lookup_root_id(key)
if root_id not in root_id_to_keys:
root_id_to_keys[root_id] = set()
root_id_to_keys[root_id].add(key)
output = []
for root_id, keys in root_id_to_keys.items():
output.append((keys, self._root_id_to_values[root_id]))
return output
def _lookup_or_make_root_id(self, key: TKey) -> int:
if key in self._key_to_id:
return self._lookup_root_id(key)
else:
new_id = len(self._key_to_id)
self._key_to_id[key] = new_id
self._id_to_parent_id[new_id] = new_id
self._root_id_to_values[new_id] = set()
return new_id
def _lookup_root_id(self, key: TKey) -> int:
i = self._key_to_id[key]
while i != self._id_to_parent_id[i]:
# Optimization: make keys directly point to their grandparents to speed up
# future traversals. This prevents degenerate trees of height n from forming.
new_parent = self._id_to_parent_id[self._id_to_parent_id[i]]
self._id_to_parent_id[i] = new_parent
i = new_parent
return i
def group_comparison_operands(
    pairwise_comparisons: Iterable[tuple[str, Expression, Expression]],
    operand_to_literal_hash: Mapping[int, Key],
    operators_to_group: set[str],
) -> list[tuple[str, list[int]]]:
    """Group comparison operands that are chained by an operator in 'operators_to_group'.

    Operands joined by any other operator stay in groups of size 2.

    For example, take the comparison expression:

        x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8

    Obtained pairwise (e.g. via ComparisonExpr's 'pairwise()' method), the input is:

        [('==', x0, x1), ('==', x1, x2), ('<', x2, x3), ('<', x3, x4),
         ('is', x4, x5), ('is', x5, x6), ('is not', x6, x7), ('is not', x7, x8)]

    With `operators_to_group` equal to {'==', 'is'}, the resulting "simplified operator
    list" is:

        [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]),
         ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])]

    Note that (a) the output holds operand *indices* rather than the operand expressions
    and (b) operands in a consecutive chain of '==' or 'is' end up in one group.

    Two such chains are combined when they contain operands with the same literal hash
    (e.g. they are assignable and correspond to the same expression). For example, given:

        same == x < y == same

    ...where 'operand_to_literal_hash' holds the same value for indices 0 and 3, the
    output is:

        [("==", [0, 1, 2, 3]), ("<", [1, 2])]

    If 'operand_to_literal_hash' has no entry for them, the output is instead:

        [("==", [0, 1]), ("<", [1, 2]), ("==", [2, 3])]

    This function is currently only used to assist with type-narrowing refinements and
    lives in a helper so that it can be unit tested.
    """
    mergeable: dict[str, DisjointDict[Key, int]] = {op: DisjointDict() for op in operators_to_group}

    # Pass 1: split the comparisons into consecutive chains, remembering for each chain
    # its operator, its operand indices and the literal hashes seen among its operands.
    chains: list[tuple[str, set[int], set[Key]]] = []
    chain_operator: str | None = None
    chain_indices: set[int] = set()
    chain_hashes: set[Key] = set()
    for left_index, (op, _left_expr, _right_expr) in enumerate(pairwise_comparisons):
        if chain_operator is None:
            chain_operator = op

        if chain_indices and (op != chain_operator or op not in operators_to_group):
            chains.append((chain_operator, chain_indices, chain_hashes))
            chain_operator = op
            chain_indices = set()
            chain_hashes = set()

        # 'left_index' is the index of the left operand, so the right one follows it.
        right_index = left_index + 1
        chain_indices.update((left_index, right_index))

        # Hashes are only collected for operators whose chains may be combined.
        if op in operators_to_group:
            for operand_index in (left_index, right_index):
                operand_hash = operand_to_literal_hash.get(operand_index)
                if operand_hash is not None:
                    chain_hashes.add(operand_hash)

    if chain_operator is not None:
        chains.append((chain_operator, chain_indices, chain_hashes))

    # Pass 2: a chain without hashes can be emitted right away. A chain with hashes is
    # deferred, because it may need to be merged with other chains sharing a hash.
    grouped: list[tuple[str, list[int]]] = []
    for op, indices, hashes in chains:
        if hashes:
            mergeable[op].add_mapping(hashes, indices)
        else:
            grouped.append((op, sorted(indices)))

    # Emit the deferred chains, now merged wherever they shared an expression.
    for op, merged in mergeable.items():
        grouped.extend((op, sorted(indices)) for _hashes, indices in merged.items())

    # For stability, order the output by the first operand index of each group.
    grouped.sort(key=lambda entry: entry[1][0])
    return grouped
def is_typed_callable(c: Type | None) -> bool:
    """Return True if 'c' is a callable type that is not entirely unannotated.

    The result is False when 'c' is missing or is not a CallableType, and also when
    every argument type and the return type are Any types of kind
    TypeOfAny.unannotated.
    """
    proper = get_proper_type(c)
    if not (proper and isinstance(proper, CallableType)):
        return False
    signature_types = get_proper_types(proper.arg_types + [proper.ret_type])
    return any(
        not (isinstance(t, AnyType) and t.type_of_any == TypeOfAny.unannotated)
        for t in signature_types
    )
def is_untyped_decorator(typ: Type | None) -> bool:
    """Return True if a decorator with type 'typ' should be treated as untyped.

    A missing type counts as untyped. A CallableType is untyped when
    is_typed_callable() rejects it. An Overloaded type is untyped when any of its
    items is. For an Instance the decision is made from its '__call__' method; an
    Instance without '__call__' is not considered untyped. Any other type is untyped.
    """
    proper = get_proper_type(typ)
    if not proper:
        return True

    if isinstance(proper, CallableType):
        return not is_typed_callable(proper)

    if isinstance(proper, Instance):
        call_method = proper.type.get_method("__call__")
        if not call_method:
            return False

        if isinstance(call_method, Decorator):
            # A decorated '__call__': untyped if either the underlying function or
            # the decorated variable has an untyped type.
            if is_untyped_decorator(call_method.func.type):
                return True
            return is_untyped_decorator(call_method.var.type)

        call_type = call_method.type
        if isinstance(call_type, Overloaded):
            return any(is_untyped_decorator(item) for item in call_type.items)
        return not is_typed_callable(call_type)

    if isinstance(proper, Overloaded):
        return any(is_untyped_decorator(item) for item in proper.items)

    return True
def is_static(func: FuncBase | Decorator) -> bool:
    """Return the 'is_static' flag of a function, looking through Decorator wrappers."""
    node = func
    # A Decorator wraps the function it decorates; unwrap until the function is reached.
    while isinstance(node, Decorator):
        node = node.func
    if isinstance(node, FuncBase):
        return node.is_static
    assert False, f"Unexpected func type: {type(node)}"
def is_property(defn: SymbolNode) -> bool:
    """Return True if 'defn' is a property definition.

    That is the case for a Decorator whose function has 'is_property' set, and for an
    OverloadedFuncDef whose first item is such a Decorator.
    """
    if isinstance(defn, Decorator):
        return defn.func.is_property
    is_decorated_overload = (
        isinstance(defn, OverloadedFuncDef)
        and defn.items
        and isinstance(defn.items[0], Decorator)
    )
    if is_decorated_overload:
        return defn.items[0].func.is_property
    return False
def get_property_type(t: ProperType) -> ProperType:
    """Return the value type of a property whose getter has type 't'.

    For a CallableType this is its return type, for an Overloaded type the return type
    of its first item; any other type is returned unchanged.
    """
    if isinstance(t, CallableType):
        getter = t
    elif isinstance(t, Overloaded):
        getter = t.items[0]
    else:
        return t
    return get_proper_type(getter.ret_type)
def is_subtype_no_promote(left: Type, right: Type) -> bool:
    """Return the result of is_subtype(left, right) called with ignore_promotions=True."""
    return is_subtype(left, right, ignore_promotions=True)
def is_overlapping_types_no_promote_no_uninhabited_no_none(left: Type, right: Type) -> bool:
    """Return whether two types overlap, using the settings for unsafe overload checks.

    This calls is_overlapping_types() with ignore_promotions, ignore_uninhabited and
    prohibit_none_typevar_overlap all set to True.
    """
    # For the purpose of unsafe overload checks we consider list[<nothing>] and list[int]
    # non-overlapping. This is consistent with how we treat list[int] and list[str] as
    # non-overlapping, even though [] belongs to both. This also prevents false positives
    # for failed type inference during unification.
    return is_overlapping_types(
        left,
        right,
        ignore_promotions=True,
        ignore_uninhabited=True,
        prohibit_none_typevar_overlap=True,
    )
def is_private(node_name: str) -> bool:
    """Check if node is private to class definition.

    A name is private when it starts with two underscores but does not end with two.
    """
    if not node_name.startswith("__"):
        return False
    return not node_name.endswith("__")
def is_string_literal(typ: Type) -> bool:
    """Return True if exactly one string literal value can be extracted from 'typ'."""
    literal_values = try_getting_str_literals_from_type(typ)
    if literal_values is None:
        return False
    return len(literal_values) == 1
def has_bool_item(typ: ProperType) -> bool:
    """Return True if type is 'bool' or a union with a 'bool' item."""
    if is_named_instance(typ, "builtins.bool"):
        return True
    if not isinstance(typ, UnionType):
        return False
    for item in typ.items:
        if is_named_instance(item, "builtins.bool"):
            return True
    return False
def collapse_walrus(e: Expression) -> Expression:
    """Return the assignment target of an AssignmentExpr, or 'e' itself otherwise.

    Only the outermost target is extracted: for code like `x := (y := z)` no attempt is
    made to pull out all the targets. Narrowing those could be supported if that sort
    of code turns out to be common.
    """
    return e.target if isinstance(e, AssignmentExpr) else e
def find_last_var_assignment_line(n: Node, v: Var) -> int:
    """Return the highest line number of a potential assignment to 'v' within 'n'.

    This supports local and global variables.

    The result is -1 if no assignment was found.
    """
    assignment_finder = VarAssignVisitor(v)
    n.accept(assignment_finder)
    return assignment_finder.last_line
class VarAssignVisitor(TraverserVisitor):
    """Traverser that records the highest line on which a variable is an assignment target.

    After a traversal, 'last_line' holds the largest line number of a NameExpr that
    refers to the tracked variable and was visited in an assignment-target position,
    or -1 if there was none.
    """

    def __init__(self, v: Var) -> None:
        # Highest line seen so far with an assignment to the variable; -1 means none.
        self.last_line = -1
        # True only while the nodes being visited are assignment targets.
        self.lvalue = False
        # The variable whose assignments are searched for.
        self.var_node = v

    def _accept_as_lvalues(self, targets: Iterable[Node]) -> None:
        """Visit each of 'targets' with the lvalue flag set, then clear the flag."""
        self.lvalue = True
        for target in targets:
            target.accept(self)
        self.lvalue = False

    def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
        # NOTE(review): only the targets are visited here; the right-hand side is not
        # traversed, so an assignment expression nested in it would not be recorded.
        # Confirm that this is intended.
        self._accept_as_lvalues(s.lvalues)

    def visit_name_expr(self, e: NameExpr) -> None:
        if not self.lvalue:
            return
        if e.node is self.var_node:
            self.last_line = max(self.last_line, e.line)

    def visit_member_expr(self, e: MemberExpr) -> None:
        # Names nested inside a member expression are not assignment targets
        # themselves, so the flag is switched off while descending.
        outer_lvalue = self.lvalue
        self.lvalue = False
        super().visit_member_expr(e)
        self.lvalue = outer_lvalue

    def visit_index_expr(self, e: IndexExpr) -> None:
        # Same as for member expressions: nested names are not assignment targets.
        outer_lvalue = self.lvalue
        self.lvalue = False
        super().visit_index_expr(e)
        self.lvalue = outer_lvalue

    def visit_with_stmt(self, s: WithStmt) -> None:
        # Entries of 's.target' may be None; those are skipped.
        # NOTE(review): only the targets and the body are visited here, not the
        # context expressions. Confirm that this is intended.
        self._accept_as_lvalues(target for target in s.target if target is not None)
        s.body.accept(self)

    def visit_for_stmt(self, s: ForStmt) -> None:
        # NOTE(review): the index, the body and the else body are visited here, but
        # not the iterated expression. Confirm that this is intended.
        self._accept_as_lvalues((s.index,))
        s.body.accept(self)
        if s.else_body:
            s.else_body.accept(self)

    def visit_assignment_expr(self, e: AssignmentExpr) -> None:
        self._accept_as_lvalues((e.target,))
        # The value is visited as well, with the flag cleared.
        e.value.accept(self)

    def visit_as_pattern(self, p: AsPattern) -> None:
        if p.pattern is not None:
            p.pattern.accept(self)
        if p.name is not None:
            # The name bound by the pattern is visited as an assignment target.
            self._accept_as_lvalues((p.name,))

    def visit_starred_pattern(self, p: StarredPattern) -> None:
        if p.capture is not None:
            # The capture of a starred pattern is visited as an assignment target.
            self._accept_as_lvalues((p.capture,))