Consolidate descriptor handling in checkmember.py (#18831) This is not a pure refactoring, but almost. Right now we are in a weird situation where we have two inconsistencies: * `__set__()` is handled in `checker.py` while `__get__()` is handled in `checkmember.py` * rules for when to use binder are slightly different between descriptors and settable properties. This PR fixes these two things. As a nice bonus we should get free support for unions in `__set__()`.
diff --git a/mypy/checker.py b/mypy/checker.py index ac4b247..fb43587 100644 --- a/mypy/checker.py +++ b/mypy/checker.py
@@ -3170,7 +3170,7 @@ ) else: self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue, "=") - lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue) + lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue, rvalue) # If we're assigning to __getattr__ or similar methods, check that the signature is # valid. if isinstance(lvalue, NameExpr) and lvalue.node: @@ -4263,7 +4263,9 @@ else: self.msg.type_not_iterable(rvalue_type, context) - def check_lvalue(self, lvalue: Lvalue) -> tuple[Type | None, IndexExpr | None, Var | None]: + def check_lvalue( + self, lvalue: Lvalue, rvalue: Expression | None = None + ) -> tuple[Type | None, IndexExpr | None, Var | None]: lvalue_type = None index_lvalue = None inferred = None @@ -4281,7 +4283,7 @@ elif isinstance(lvalue, IndexExpr): index_lvalue = lvalue elif isinstance(lvalue, MemberExpr): - lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue, True) + lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue, True, rvalue) self.store_type(lvalue, lvalue_type) elif isinstance(lvalue, NameExpr): lvalue_type = self.expr_checker.analyze_ref_expr(lvalue, lvalue=True) @@ -4552,12 +4554,8 @@ Return the inferred rvalue_type, inferred lvalue_type, and whether to use the binder for this assignment. - - Note: this method exists here and not in checkmember.py, because we need to take - care about interaction between binder and __set__(). """ instance_type = get_proper_type(instance_type) - attribute_type = get_proper_type(attribute_type) # Descriptors don't participate in class-attribute access if (isinstance(instance_type, FunctionLike) and instance_type.is_type_obj()) or isinstance( instance_type, TypeType @@ -4569,8 +4567,8 @@ get_lvalue_type = self.expr_checker.analyze_ordinary_member_access( lvalue, is_lvalue=False ) - use_binder = is_same_type(get_lvalue_type, attribute_type) +<<<<<<< HEAD if not isinstance(attribute_type, Instance): # TODO: support __set__() for union types. rvalue_type = self.check_simple_assignment(attribute_type, rvalue, context) @@ -4664,13 +4662,23 @@ return AnyType(TypeOfAny.from_error), get_type, False set_type = inferred_dunder_set_type.arg_types[1] +======= +>>>>>>> df9ddfcac (Consolidate descriptor handling in checkmember.py (#18831)) # Special case: if the rvalue_type is a subtype of both '__get__' and '__set__' types, # and '__get__' type is narrower than '__set__', then we invoke the binder to narrow type # by this assignment. Technically, this is not safe, but in practice this is # what a user expects. +<<<<<<< HEAD rvalue_type = self.check_simple_assignment(set_type, rvalue, context) infer = is_subtype(rvalue_type, get_type) and is_subtype(get_type, set_type) return rvalue_type if infer else set_type, get_type, infer +======= + rvalue_type, _ = self.check_simple_assignment(attribute_type, rvalue, context) + infer = is_subtype(rvalue_type, get_lvalue_type) and is_subtype( + get_lvalue_type, attribute_type + ) + return rvalue_type if infer else attribute_type, attribute_type, infer +>>>>>>> df9ddfcac (Consolidate descriptor handling in checkmember.py (#18831)) def check_indexed_assignment( self, lvalue: IndexExpr, rvalue: Expression, context: Context
diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 1017009..3fa4df2 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py
@@ -3327,8 +3327,13 @@ self.chk.warn_deprecated(e.node, e) return narrowed - def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type: - """Analyse member expression or member lvalue.""" + def analyze_ordinary_member_access( + self, e: MemberExpr, is_lvalue: bool, rvalue: Expression | None = None + ) -> Type: + """Analyse member expression or member lvalue. + + An rvalue can be provided optionally to infer better setter type when is_lvalue is True. + """ if e.kind is not None: # This is a reference to a module attribute. return self.analyze_ref_expr(e) @@ -3360,6 +3365,7 @@ in_literal_context=self.is_literal_context(), module_symbol_table=module_symbol_table, is_self=is_self, + rvalue=rvalue, ) return member_type
diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 0994d0d..9b7c802 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py
@@ -23,6 +23,7 @@ ArgKind, Context, Decorator, + Expression, FuncBase, FuncDef, IndexExpr, @@ -101,6 +102,7 @@ module_symbol_table: SymbolTable | None = None, no_deferral: bool = False, is_self: bool = False, + rvalue: Expression | None = None, ) -> None: self.is_lvalue = is_lvalue self.is_super = is_super @@ -113,6 +115,9 @@ self.module_symbol_table = module_symbol_table self.no_deferral = no_deferral self.is_self = is_self + if rvalue is not None: + assert is_lvalue + self.rvalue = rvalue def named_type(self, name: str) -> Instance: return self.chk.named_type(name) @@ -139,6 +144,7 @@ self_type=self.self_type, module_symbol_table=self.module_symbol_table, no_deferral=self.no_deferral, + rvalue=self.rvalue, ) if messages is not None: mx.msg = messages @@ -168,6 +174,7 @@ module_symbol_table: SymbolTable | None = None, no_deferral: bool = False, is_self: bool = False, + rvalue: Expression | None = None, ) -> Type: """Return the type of attribute 'name' of 'typ'. @@ -186,11 +193,14 @@ of 'original_type'. 'original_type' is always preserved as the 'typ' type used in the initial, non-recursive call. The 'self_type' is a component of 'original_type' to which generic self should be bound (a narrower type that has a fallback to instance). - Currently this is used only for union types. + Currently, this is used only for union types. - 'module_symbol_table' is passed to this function if 'typ' is actually a module + 'module_symbol_table' is passed to this function if 'typ' is actually a module, and we want to keep track of the available attributes of the module (since they are not available via the type object directly) + + 'rvalue' can be provided optionally to infer better setter type when is_lvalue is True, + most notably this helps for descriptors with overloaded __set__() method. """ mx = MemberContext( is_lvalue=is_lvalue, @@ -204,6 +214,7 @@ module_symbol_table=module_symbol_table, no_deferral=no_deferral, is_self=is_self, + rvalue=rvalue, ) result = _analyze_member_access(name, typ, mx, override_info) possible_literal = get_proper_type(result) @@ -629,9 +640,7 @@ msg.cant_assign_to_final(name, attr_assign=True, ctx=ctx) -def analyze_descriptor_access( - descriptor_type: Type, mx: MemberContext, *, assignment: bool = False -) -> Type: +def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type: """Type check descriptor access. Arguments: @@ -639,7 +648,7 @@ (the type of ``f`` in ``a.f`` when ``f`` is a descriptor). mx: The current member access context. Return: - The return type of the appropriate ``__get__`` overload for the descriptor. + The return type of the appropriate ``__get__/__set__`` overload for the descriptor. """ instance_type = get_proper_type(mx.self_type) orig_descriptor_type = descriptor_type @@ -648,15 +657,24 @@ if isinstance(descriptor_type, UnionType): # Map the access over union types return make_simplified_union( - [ - analyze_descriptor_access(typ, mx, assignment=assignment) - for typ in descriptor_type.items - ] + [analyze_descriptor_access(typ, mx) for typ in descriptor_type.items] ) elif not isinstance(descriptor_type, Instance): return orig_descriptor_type - if not descriptor_type.type.has_readable_member("__get__"): + if not mx.is_lvalue and not descriptor_type.type.has_readable_member("__get__"): + return orig_descriptor_type + + # We do this check first to accommodate for descriptors with only __set__ method. + # If there is no __set__, we type-check that the assigned value matches + # the return type of __get__. This doesn't match the python semantics, + # (which allow you to override the descriptor with any value), but preserves + # the type of accessing the attribute (even after the override). + if mx.is_lvalue and descriptor_type.type.has_readable_member("__set__"): + return analyze_descriptor_assign(descriptor_type, mx) + + if mx.is_lvalue and not descriptor_type.type.has_readable_member("__get__"): + # This turned out to be not a descriptor after all. return orig_descriptor_type dunder_get = descriptor_type.type.get_method("__get__") @@ -713,11 +731,10 @@ callable_name=callable_name, ) - if not assignment: - mx.chk.check_deprecated(dunder_get, mx.context) - mx.chk.warn_deprecated_overload_item( - dunder_get, mx.context, target=inferred_dunder_get_type, selftype=descriptor_type - ) + mx.chk.check_deprecated(dunder_get, mx.context) + mx.chk.warn_deprecated_overload_item( + dunder_get, mx.context, target=inferred_dunder_get_type, selftype=descriptor_type + ) inferred_dunder_get_type = get_proper_type(inferred_dunder_get_type) if isinstance(inferred_dunder_get_type, AnyType): @@ -736,6 +753,79 @@ return inferred_dunder_get_type.ret_type +def analyze_descriptor_assign(descriptor_type: Instance, mx: MemberContext) -> Type: + instance_type = get_proper_type(mx.self_type) + dunder_set = descriptor_type.type.get_method("__set__") + if dunder_set is None: + mx.chk.fail( + message_registry.DESCRIPTOR_SET_NOT_CALLABLE.format( + descriptor_type.str_with_options(mx.msg.options) + ), + mx.context, + ) + return AnyType(TypeOfAny.from_error) + + bound_method = analyze_decorator_or_funcbase_access( + defn=dunder_set, + itype=descriptor_type, + name="__set__", + mx=mx.copy_modified(is_lvalue=False, self_type=descriptor_type), + ) + typ = map_instance_to_supertype(descriptor_type, dunder_set.info) + dunder_set_type = expand_type_by_instance(bound_method, typ) + + callable_name = mx.chk.expr_checker.method_fullname(descriptor_type, "__set__") + rvalue = mx.rvalue or TempNode(AnyType(TypeOfAny.special_form), context=mx.context) + dunder_set_type = mx.chk.expr_checker.transform_callee_type( + callable_name, + dunder_set_type, + [TempNode(instance_type, context=mx.context), rvalue], + [ARG_POS, ARG_POS], + mx.context, + object_type=descriptor_type, + ) + + # For non-overloaded setters, the result should be type-checked like a regular assignment. + # Hence, we first only try to infer the type by using the rvalue as type context. + type_context = rvalue + with mx.msg.filter_errors(): + _, inferred_dunder_set_type = mx.chk.expr_checker.check_call( + dunder_set_type, + [TempNode(instance_type, context=mx.context), type_context], + [ARG_POS, ARG_POS], + mx.context, + object_type=descriptor_type, + callable_name=callable_name, + ) + + # And now we in fact type check the call, to show errors related to wrong arguments + # count, etc., replacing the type context for non-overloaded setters only. + inferred_dunder_set_type = get_proper_type(inferred_dunder_set_type) + if isinstance(inferred_dunder_set_type, CallableType): + type_context = TempNode(AnyType(TypeOfAny.special_form), context=mx.context) + mx.chk.expr_checker.check_call( + dunder_set_type, + [TempNode(instance_type, context=mx.context), type_context], + [ARG_POS, ARG_POS], + mx.context, + object_type=descriptor_type, + callable_name=callable_name, + ) + + # Search for possible deprecations: + mx.chk.check_deprecated(dunder_set, mx.context) + mx.chk.warn_deprecated_overload_item( + dunder_set, mx.context, target=inferred_dunder_set_type, selftype=descriptor_type + ) + + # In the following cases, a message already will have been recorded in check_call. + if (not isinstance(inferred_dunder_set_type, CallableType)) or ( + len(inferred_dunder_set_type.arg_types) < 2 + ): + return AnyType(TypeOfAny.from_error) + return inferred_dunder_set_type.arg_types[1] + + def is_instance_var(var: Var) -> bool: """Return if var is an instance variable according to PEP 526.""" return ( @@ -820,6 +910,7 @@ # A property cannot have an overloaded type => the cast is fine. assert isinstance(expanded_signature, CallableType) if var.is_settable_property and mx.is_lvalue and var.setter_type is not None: + # TODO: use check_call() to infer better type, same as for __set__(). result = expanded_signature.arg_types[0] else: result = expanded_signature.ret_type @@ -832,7 +923,7 @@ result = AnyType(TypeOfAny.special_form) fullname = f"{var.info.fullname}.{name}" hook = mx.chk.plugin.get_attribute_hook(fullname) - if result and not mx.is_lvalue and not implicit: + if result and not (implicit or var.info.is_protocol and is_instance_var(var)): result = analyze_descriptor_access(result, mx) if hook: result = hook( @@ -1106,6 +1197,7 @@ result = add_class_tvars( t, isuper, is_classmethod, is_staticmethod, mx.self_type, original_vars=original_vars ) + # __set__ is not called on class objects. if not mx.is_lvalue: result = analyze_descriptor_access(result, mx)