Fix strict optional handling in dataclasses (#15571)
There were few cases when someone forgot to call `strict_optional_set()`
in dataclasses plugin, let's move the calls directly to two places where
they are needed for typeops. This may cause a tiny perf regression, but
is much more robust in terms of preventing bugs.
diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py
index 9e05449..efa1338 100644
--- a/mypy/plugins/dataclasses.py
+++ b/mypy/plugins/dataclasses.py
@@ -104,6 +104,7 @@
info: TypeInfo,
kw_only: bool,
is_neither_frozen_nor_nonfrozen: bool,
+ api: SemanticAnalyzerPluginInterface,
) -> None:
self.name = name
self.alias = alias
@@ -116,6 +117,7 @@
self.info = info
self.kw_only = kw_only
self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen
+ self._api = api
def to_argument(self, current_info: TypeInfo) -> Argument:
arg_kind = ARG_POS
@@ -138,7 +140,10 @@
# however this plugin is called very late, so all types should be fully ready.
# Also, it is tricky to avoid eager expansion of Self types here (e.g. because
# we serialize attributes).
- return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
+ with state.strict_optional_set(self._api.options.strict_optional):
+ return expand_type(
+ self.type, {self.info.self_type.id: fill_typevars(current_info)}
+ )
return self.type
def to_var(self, current_info: TypeInfo) -> Var:
@@ -165,13 +170,14 @@
) -> DataclassAttribute:
data = data.copy()
typ = deserialize_and_fixup_type(data.pop("type"), api)
- return cls(type=typ, info=info, **data)
+ return cls(type=typ, info=info, **data, api=api)
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
"""Expands type vars in the context of a subtype when an attribute is inherited
from a generic super type."""
if self.type is not None:
- self.type = map_type_from_supertype(self.type, sub_type, self.info)
+ with state.strict_optional_set(self._api.options.strict_optional):
+ self.type = map_type_from_supertype(self.type, sub_type, self.info)
class DataclassTransformer:
@@ -230,12 +236,11 @@
and ("__init__" not in info.names or info.names["__init__"].plugin_generated)
and attributes
):
- with state.strict_optional_set(self._api.options.strict_optional):
- args = [
- attr.to_argument(info)
- for attr in attributes
- if attr.is_in_init and not self._is_kw_only_type(attr.type)
- ]
+ args = [
+ attr.to_argument(info)
+ for attr in attributes
+ if attr.is_in_init and not self._is_kw_only_type(attr.type)
+ ]
if info.fallback_to_any:
# Make positional args optional since we don't know their order.
@@ -355,8 +360,7 @@
self._add_dataclass_fields_magic_attribute()
if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES:
- with state.strict_optional_set(self._api.options.strict_optional):
- self._add_internal_replace_method(attributes)
+ self._add_internal_replace_method(attributes)
if "__post_init__" in info.names:
self._add_internal_post_init_method(attributes)
@@ -546,8 +550,7 @@
# TODO: We shouldn't be performing type operations during the main
# semantic analysis pass, since some TypeInfo attributes might
# still be in flux. This should be performed in a later phase.
- with state.strict_optional_set(self._api.options.strict_optional):
- attr.expand_typevar_from_subtype(cls.info)
+ attr.expand_typevar_from_subtype(cls.info)
found_attrs[name] = attr
sym_node = cls.info.names.get(name)
@@ -693,6 +696,7 @@
is_neither_frozen_nor_nonfrozen=_has_direct_dataclass_transform_metaclass(
cls.info
),
+ api=self._api,
)
all_attrs = list(found_attrs.values())
diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test
index e4eecc3..abc0f6a 100644
--- a/test-data/unit/pythoneval.test
+++ b/test-data/unit/pythoneval.test
@@ -2094,7 +2094,6 @@
[out]
[case testDataclassReplaceOptional]
-# flags: --strict-optional
from dataclasses import dataclass, replace
from typing import Optional
@@ -2107,5 +2106,18 @@
a2 = replace(a, x=None) # OK
reveal_type(a2)
[out]
-_testDataclassReplaceOptional.py:10: note: Revealed type is "_testDataclassReplaceOptional.A"
-_testDataclassReplaceOptional.py:12: note: Revealed type is "_testDataclassReplaceOptional.A"
+_testDataclassReplaceOptional.py:9: note: Revealed type is "_testDataclassReplaceOptional.A"
+_testDataclassReplaceOptional.py:11: note: Revealed type is "_testDataclassReplaceOptional.A"
+
+[case testDataclassStrictOptionalAlwaysSet]
+from dataclasses import dataclass
+from typing import Callable, Optional
+
+@dataclass
+class Description:
+ name_fn: Callable[[Optional[int]], Optional[str]]
+
+def f(d: Description) -> None:
+ reveal_type(d.name_fn)
+[out]
+_testDataclassStrictOptionalAlwaysSet.py:9: note: Revealed type is "def (Union[builtins.int, None]) -> Union[builtins.str, None]"