Consolidate descriptor handling in checkmember.py (#18831)

This is not a pure refactoring, but almost. Right now we are in a weird
situation where we have two inconsistencies:
* `__set__()` is handled in `checker.py` while `__get__()` is handled in
`checkmember.py`
* rules for when to use binder are slightly different between
descriptors and settable properties.

This PR fixes these two things. As a nice bonus we should get free
support for unions in `__set__()`.
diff --git a/mypy/checker.py b/mypy/checker.py
index ac4b247..fb43587 100644
--- a/mypy/checker.py
+++ b/mypy/checker.py
@@ -3170,7 +3170,7 @@
             )
         else:
             self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue, "=")
-            lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue)
+            lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue, rvalue)
             # If we're assigning to __getattr__ or similar methods, check that the signature is
             # valid.
             if isinstance(lvalue, NameExpr) and lvalue.node:
@@ -4263,7 +4263,9 @@
         else:
             self.msg.type_not_iterable(rvalue_type, context)
 
-    def check_lvalue(self, lvalue: Lvalue) -> tuple[Type | None, IndexExpr | None, Var | None]:
+    def check_lvalue(
+        self, lvalue: Lvalue, rvalue: Expression | None = None
+    ) -> tuple[Type | None, IndexExpr | None, Var | None]:
         lvalue_type = None
         index_lvalue = None
         inferred = None
@@ -4281,7 +4283,7 @@
         elif isinstance(lvalue, IndexExpr):
             index_lvalue = lvalue
         elif isinstance(lvalue, MemberExpr):
-            lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue, True)
+            lvalue_type = self.expr_checker.analyze_ordinary_member_access(lvalue, True, rvalue)
             self.store_type(lvalue, lvalue_type)
         elif isinstance(lvalue, NameExpr):
             lvalue_type = self.expr_checker.analyze_ref_expr(lvalue, lvalue=True)
@@ -4552,12 +4554,8 @@
 
         Return the inferred rvalue_type, inferred lvalue_type, and whether to use the binder
         for this assignment.
-
-        Note: this method exists here and not in checkmember.py, because we need to take
-        care about interaction between binder and __set__().
         """
         instance_type = get_proper_type(instance_type)
-        attribute_type = get_proper_type(attribute_type)
         # Descriptors don't participate in class-attribute access
         if (isinstance(instance_type, FunctionLike) and instance_type.is_type_obj()) or isinstance(
             instance_type, TypeType
@@ -4569,8 +4567,8 @@
             get_lvalue_type = self.expr_checker.analyze_ordinary_member_access(
                 lvalue, is_lvalue=False
             )
-            use_binder = is_same_type(get_lvalue_type, attribute_type)
 
+<<<<<<< HEAD
         if not isinstance(attribute_type, Instance):
             # TODO: support __set__() for union types.
             rvalue_type = self.check_simple_assignment(attribute_type, rvalue, context)
@@ -4664,13 +4662,23 @@
             return AnyType(TypeOfAny.from_error), get_type, False
 
         set_type = inferred_dunder_set_type.arg_types[1]
+=======
+>>>>>>> df9ddfcac (Consolidate descriptor handling in checkmember.py (#18831))
         # Special case: if the rvalue_type is a subtype of both '__get__' and '__set__' types,
         # and '__get__' type is narrower than '__set__', then we invoke the binder to narrow type
         # by this assignment. Technically, this is not safe, but in practice this is
         # what a user expects.
+<<<<<<< HEAD
         rvalue_type = self.check_simple_assignment(set_type, rvalue, context)
         infer = is_subtype(rvalue_type, get_type) and is_subtype(get_type, set_type)
         return rvalue_type if infer else set_type, get_type, infer
+=======
+        rvalue_type, _ = self.check_simple_assignment(attribute_type, rvalue, context)
+        infer = is_subtype(rvalue_type, get_lvalue_type) and is_subtype(
+            get_lvalue_type, attribute_type
+        )
+        return rvalue_type if infer else attribute_type, attribute_type, infer
+>>>>>>> df9ddfcac (Consolidate descriptor handling in checkmember.py (#18831))
 
     def check_indexed_assignment(
         self, lvalue: IndexExpr, rvalue: Expression, context: Context
diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py
index 1017009..3fa4df2 100644
--- a/mypy/checkexpr.py
+++ b/mypy/checkexpr.py
@@ -3327,8 +3327,13 @@
         self.chk.warn_deprecated(e.node, e)
         return narrowed
 
-    def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type:
-        """Analyse member expression or member lvalue."""
+    def analyze_ordinary_member_access(
+        self, e: MemberExpr, is_lvalue: bool, rvalue: Expression | None = None
+    ) -> Type:
+        """Analyse member expression or member lvalue.
+
+        An rvalue can be provided optionally to infer better setter type when is_lvalue is True.
+        """
         if e.kind is not None:
             # This is a reference to a module attribute.
             return self.analyze_ref_expr(e)
@@ -3360,6 +3365,7 @@
                 in_literal_context=self.is_literal_context(),
                 module_symbol_table=module_symbol_table,
                 is_self=is_self,
+                rvalue=rvalue,
             )
 
             return member_type
diff --git a/mypy/checkmember.py b/mypy/checkmember.py
index 0994d0d..9b7c802 100644
--- a/mypy/checkmember.py
+++ b/mypy/checkmember.py
@@ -23,6 +23,7 @@
     ArgKind,
     Context,
     Decorator,
+    Expression,
     FuncBase,
     FuncDef,
     IndexExpr,
@@ -101,6 +102,7 @@
         module_symbol_table: SymbolTable | None = None,
         no_deferral: bool = False,
         is_self: bool = False,
+        rvalue: Expression | None = None,
     ) -> None:
         self.is_lvalue = is_lvalue
         self.is_super = is_super
@@ -113,6 +115,9 @@
         self.module_symbol_table = module_symbol_table
         self.no_deferral = no_deferral
         self.is_self = is_self
+        if rvalue is not None:
+            assert is_lvalue
+        self.rvalue = rvalue
 
     def named_type(self, name: str) -> Instance:
         return self.chk.named_type(name)
@@ -139,6 +144,7 @@
             self_type=self.self_type,
             module_symbol_table=self.module_symbol_table,
             no_deferral=self.no_deferral,
+            rvalue=self.rvalue,
         )
         if messages is not None:
             mx.msg = messages
@@ -168,6 +174,7 @@
     module_symbol_table: SymbolTable | None = None,
     no_deferral: bool = False,
     is_self: bool = False,
+    rvalue: Expression | None = None,
 ) -> Type:
     """Return the type of attribute 'name' of 'typ'.
 
@@ -186,11 +193,14 @@
     of 'original_type'. 'original_type' is always preserved as the 'typ' type used in
     the initial, non-recursive call. The 'self_type' is a component of 'original_type'
     to which generic self should be bound (a narrower type that has a fallback to instance).
-    Currently this is used only for union types.
+    Currently, this is used only for union types.
 
-    'module_symbol_table' is passed to this function if 'typ' is actually a module
+    'module_symbol_table' is passed to this function if 'typ' is actually a module,
     and we want to keep track of the available attributes of the module (since they
     are not available via the type object directly)
+
+    'rvalue' can be provided optionally to infer better setter type when is_lvalue is True,
+    most notably this helps for descriptors with overloaded __set__() method.
     """
     mx = MemberContext(
         is_lvalue=is_lvalue,
@@ -204,6 +214,7 @@
         module_symbol_table=module_symbol_table,
         no_deferral=no_deferral,
         is_self=is_self,
+        rvalue=rvalue,
     )
     result = _analyze_member_access(name, typ, mx, override_info)
     possible_literal = get_proper_type(result)
@@ -629,9 +640,7 @@
             msg.cant_assign_to_final(name, attr_assign=True, ctx=ctx)
 
 
-def analyze_descriptor_access(
-    descriptor_type: Type, mx: MemberContext, *, assignment: bool = False
-) -> Type:
+def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type:
     """Type check descriptor access.
 
     Arguments:
@@ -639,7 +648,7 @@
             (the type of ``f`` in ``a.f`` when ``f`` is a descriptor).
         mx: The current member access context.
     Return:
-        The return type of the appropriate ``__get__`` overload for the descriptor.
+        The return type of the appropriate ``__get__/__set__`` overload for the descriptor.
     """
     instance_type = get_proper_type(mx.self_type)
     orig_descriptor_type = descriptor_type
@@ -648,15 +657,24 @@
     if isinstance(descriptor_type, UnionType):
         # Map the access over union types
         return make_simplified_union(
-            [
-                analyze_descriptor_access(typ, mx, assignment=assignment)
-                for typ in descriptor_type.items
-            ]
+            [analyze_descriptor_access(typ, mx) for typ in descriptor_type.items]
         )
     elif not isinstance(descriptor_type, Instance):
         return orig_descriptor_type
 
-    if not descriptor_type.type.has_readable_member("__get__"):
+    if not mx.is_lvalue and not descriptor_type.type.has_readable_member("__get__"):
+        return orig_descriptor_type
+
+    # We do this check first to accommodate for descriptors with only __set__ method.
+    # If there is no __set__, we type-check that the assigned value matches
+    # the return type of __get__. This doesn't match the python semantics,
+    # (which allow you to override the descriptor with any value), but preserves
+    # the type of accessing the attribute (even after the override).
+    if mx.is_lvalue and descriptor_type.type.has_readable_member("__set__"):
+        return analyze_descriptor_assign(descriptor_type, mx)
+
+    if mx.is_lvalue and not descriptor_type.type.has_readable_member("__get__"):
+        # This turned out to be not a descriptor after all.
         return orig_descriptor_type
 
     dunder_get = descriptor_type.type.get_method("__get__")
@@ -713,11 +731,10 @@
         callable_name=callable_name,
     )
 
-    if not assignment:
-        mx.chk.check_deprecated(dunder_get, mx.context)
-        mx.chk.warn_deprecated_overload_item(
-            dunder_get, mx.context, target=inferred_dunder_get_type, selftype=descriptor_type
-        )
+    mx.chk.check_deprecated(dunder_get, mx.context)
+    mx.chk.warn_deprecated_overload_item(
+        dunder_get, mx.context, target=inferred_dunder_get_type, selftype=descriptor_type
+    )
 
     inferred_dunder_get_type = get_proper_type(inferred_dunder_get_type)
     if isinstance(inferred_dunder_get_type, AnyType):
@@ -736,6 +753,79 @@
     return inferred_dunder_get_type.ret_type
 
 
+def analyze_descriptor_assign(descriptor_type: Instance, mx: MemberContext) -> Type:
+    instance_type = get_proper_type(mx.self_type)
+    dunder_set = descriptor_type.type.get_method("__set__")
+    if dunder_set is None:
+        mx.chk.fail(
+            message_registry.DESCRIPTOR_SET_NOT_CALLABLE.format(
+                descriptor_type.str_with_options(mx.msg.options)
+            ),
+            mx.context,
+        )
+        return AnyType(TypeOfAny.from_error)
+
+    bound_method = analyze_decorator_or_funcbase_access(
+        defn=dunder_set,
+        itype=descriptor_type,
+        name="__set__",
+        mx=mx.copy_modified(is_lvalue=False, self_type=descriptor_type),
+    )
+    typ = map_instance_to_supertype(descriptor_type, dunder_set.info)
+    dunder_set_type = expand_type_by_instance(bound_method, typ)
+
+    callable_name = mx.chk.expr_checker.method_fullname(descriptor_type, "__set__")
+    rvalue = mx.rvalue or TempNode(AnyType(TypeOfAny.special_form), context=mx.context)
+    dunder_set_type = mx.chk.expr_checker.transform_callee_type(
+        callable_name,
+        dunder_set_type,
+        [TempNode(instance_type, context=mx.context), rvalue],
+        [ARG_POS, ARG_POS],
+        mx.context,
+        object_type=descriptor_type,
+    )
+
+    # For non-overloaded setters, the result should be type-checked like a regular assignment.
+    # Hence, we first only try to infer the type by using the rvalue as type context.
+    type_context = rvalue
+    with mx.msg.filter_errors():
+        _, inferred_dunder_set_type = mx.chk.expr_checker.check_call(
+            dunder_set_type,
+            [TempNode(instance_type, context=mx.context), type_context],
+            [ARG_POS, ARG_POS],
+            mx.context,
+            object_type=descriptor_type,
+            callable_name=callable_name,
+        )
+
+    # And now we in fact type check the call, to show errors related to wrong arguments
+    # count, etc., replacing the type context for non-overloaded setters only.
+    inferred_dunder_set_type = get_proper_type(inferred_dunder_set_type)
+    if isinstance(inferred_dunder_set_type, CallableType):
+        type_context = TempNode(AnyType(TypeOfAny.special_form), context=mx.context)
+    mx.chk.expr_checker.check_call(
+        dunder_set_type,
+        [TempNode(instance_type, context=mx.context), type_context],
+        [ARG_POS, ARG_POS],
+        mx.context,
+        object_type=descriptor_type,
+        callable_name=callable_name,
+    )
+
+    # Search for possible deprecations:
+    mx.chk.check_deprecated(dunder_set, mx.context)
+    mx.chk.warn_deprecated_overload_item(
+        dunder_set, mx.context, target=inferred_dunder_set_type, selftype=descriptor_type
+    )
+
+    # In the following cases, a message already will have been recorded in check_call.
+    if (not isinstance(inferred_dunder_set_type, CallableType)) or (
+        len(inferred_dunder_set_type.arg_types) < 2
+    ):
+        return AnyType(TypeOfAny.from_error)
+    return inferred_dunder_set_type.arg_types[1]
+
+
 def is_instance_var(var: Var) -> bool:
     """Return if var is an instance variable according to PEP 526."""
     return (
@@ -820,6 +910,7 @@
                     # A property cannot have an overloaded type => the cast is fine.
                     assert isinstance(expanded_signature, CallableType)
                     if var.is_settable_property and mx.is_lvalue and var.setter_type is not None:
+                        # TODO: use check_call() to infer better type, same as for __set__().
                         result = expanded_signature.arg_types[0]
                     else:
                         result = expanded_signature.ret_type
@@ -832,7 +923,7 @@
         result = AnyType(TypeOfAny.special_form)
     fullname = f"{var.info.fullname}.{name}"
     hook = mx.chk.plugin.get_attribute_hook(fullname)
-    if result and not mx.is_lvalue and not implicit:
+    if result and not (implicit or var.info.is_protocol and is_instance_var(var)):
         result = analyze_descriptor_access(result, mx)
     if hook:
         result = hook(
@@ -1106,6 +1197,7 @@
         result = add_class_tvars(
             t, isuper, is_classmethod, is_staticmethod, mx.self_type, original_vars=original_vars
         )
+        # __set__ is not called on class objects.
         if not mx.is_lvalue:
             result = analyze_descriptor_access(result, mx)