diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9990caaeb7a1..c8746602922a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -126,6 +126,7 @@ is_same_type, is_subtype, non_method_protocol_members, + solve_as_subtype, ) from mypy.traverser import ( all_name_and_member_expressions, @@ -6143,29 +6144,45 @@ def is_valid_var_arg(self, typ: Type) -> bool: def is_valid_keyword_var_arg(self, typ: Type) -> bool: """Is a type valid as a **kwargs argument?""" typ = get_proper_type(typ) - return ( - ( - # This is a little ad hoc, ideally we would have a map_instance_to_supertype - # that worked for protocols - isinstance(typ, Instance) - and typ.type.fullname == "builtins.dict" - and is_subtype(typ.args[0], self.named_type("builtins.str")) - ) - or isinstance(typ, ParamSpecType) - or is_subtype( - typ, - self.chk.named_generic_type( - "_typeshed.SupportsKeysAndGetItem", - [self.named_type("builtins.str"), AnyType(TypeOfAny.special_form)], - ), - ) - or is_subtype( - typ, - self.chk.named_generic_type( - "_typeshed.SupportsKeysAndGetItem", [UninhabitedType(), UninhabitedType()] - ), - ) + + # factorize over unions + if isinstance(typ, UnionType): + return all(self.is_valid_keyword_var_arg(item) for item in typ.items) + + if isinstance(typ, AnyType): + return True + + if isinstance(typ, ParamSpecType): + return typ.flavor == ParamSpecFlavor.KWARGS + + # fast path for builtins.dict + if isinstance(typ, Instance) and typ.type.fullname == "builtins.dict": + return is_subtype(typ.args[0], self.named_type("builtins.str")) + + # fast fail if not SupportsKeysAndGetItem[Any, Any] + any_type = AnyType(TypeOfAny.from_omitted_generics) + if not is_subtype( + typ, + self.chk.named_generic_type("_typeshed.SupportsKeysAndGetItem", [any_type, any_type]), + ): + return False + + # Check if 'typ' is a SupportsKeysAndGetItem[T, Any] for some T <: str + # Note: is_subtype(typ, SupportsKeysAndGetItem[str, Any])` is too harsh + # since SupportsKeysAndGetItem is invariant in the key type parameter. + + # create a TypeVar and template type + T = TypeVarType( + "T", + "T", + id=TypeVarId(-1, namespace=""), + values=[], + upper_bound=self.named_type("builtins.str"), + default=any_type, ) + template = self.chk.named_generic_type("_typeshed.SupportsKeysAndGetItem", [T, any_type]) + + return solve_as_subtype(typ, template) is not None def not_ready_callback(self, name: str, context: Context) -> None: """Called when we can't infer the type of a variable because it's not ready yet. diff --git a/mypy/messages.py b/mypy/messages.py index bbcc93ebfb25..3fc876dbbc07 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1401,7 +1401,7 @@ def invalid_var_arg(self, typ: Type, context: Context) -> None: def invalid_keyword_var_arg(self, typ: Type, is_mapping: bool, context: Context) -> None: typ = get_proper_type(typ) - if isinstance(typ, Instance) and is_mapping: + if is_mapping: self.fail("Argument after ** must have string keys", context, code=codes.ARG_TYPE) else: self.fail( diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 350d57a7e4ad..66b25785b4a3 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -8,6 +8,7 @@ import mypy.constraints import mypy.typeops from mypy.checker_state import checker_state +from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints from mypy.erasetype import erase_type from mypy.expandtype import ( expand_self_type, @@ -35,6 +36,7 @@ ) from mypy.options import Options from mypy.state import state +from mypy.typeops import get_all_type_vars from mypy.types import ( MYPYC_NATIVE_INT_NAMES, TUPLE_LIKE_INSTANCE_NAMES, @@ -2311,3 +2313,45 @@ def is_erased_instance(t: Instance) -> bool: elif not isinstance(get_proper_type(arg), AnyType): return False return True + + +def solve_as_subtype(typ: Type, target: Type) -> Type | None: + """Solves type variables in `target` so that `typ` becomes a subtype of `target`. + + Returns: + None: if the mapping is not possible. + Type: the mapped type if the mapping is possible. + + Examples: + (list[int], Iterable[T]) -> Iterable[int] + (list[list[int]], Iterable[list[T]]) -> Iterable[list[int]] + (dict[str, int], Mapping[K, int]) -> Mapping[str, int] + (list[int], Mapping[K, V]) -> None + """ + + # 1. get type vars of target + tvars = get_all_type_vars(target) + + # fast path: if no type vars, just check subtype + if not tvars: + return target if is_subtype(typ, target) else None + + from mypy.solve import solve_constraints + + # 2. determine constraints + constraints: list[Constraint] = infer_constraints(target, typ, SUPERTYPE_OF) + for tvar in tvars: + # need to manually include these because solve_constraints ignores them + # apparently + constraints.append(Constraint(tvar, SUBTYPE_OF, tvar.upper_bound)) + + # 3. solve constraints + solution, _ = solve_constraints(tvars, constraints) + + if None in solution: + return None + + # 4. build resulting Type by substituting type vars with solution + env = {tvar.id: s for tvar, s in zip(tvars, cast("list[Type]", solution))} + target = expand_type(target, env) + return target if is_subtype(typ, target) else None diff --git a/test-data/unit/check-kwargs.test b/test-data/unit/check-kwargs.test index 4099716bcf6b..3caba651f2a1 100644 --- a/test-data/unit/check-kwargs.test +++ b/test-data/unit/check-kwargs.test @@ -572,11 +572,44 @@ main:41: error: Argument 1 to "foo" has incompatible type "**dict[str, str]"; ex [builtins fixtures/dict.pyi] [case testLiteralKwargs] -from typing import Any, Literal -kw: dict[Literal["a", "b"], Any] -def func(a, b): ... -func(**kw) - -badkw: dict[Literal["one", 1], Any] -func(**badkw) # E: Argument after ** must have string keys +from typing import Literal, Mapping, Iterable, Union +def func(a: int, b: int) -> None: ... + +class GOOD_KW: + def keys(self) -> Iterable[Literal["a", "b"]]: ... + def __getitem__(self, key: str) -> int: ... + +class BAD_KW: + def keys(self) -> Iterable[Literal["one", 1]]: ... + def __getitem__(self, key: str) -> int: ... + +def test( + good_kw: GOOD_KW, + bad_kw: BAD_KW, + good_dict: dict[Literal["a", "b"], int], + bad_dict: dict[Literal["one", 1], int], + good_mapping: Mapping[Literal["a", "b"], int], + bad_mapping: Mapping[Literal["one", 1], int], +) -> None: + func(**good_kw) + func(**bad_kw) # E: Argument after ** must have string keys + func(**good_dict) + func(**bad_dict) # E: Argument after ** must have string keys + func(**good_mapping) + func(**bad_mapping) # E: Argument after ** must have string keys + +def test_union( + good_kw: Union[GOOD_KW, dict[str, int]], + bad_kw: Union[BAD_KW, dict[str, int]], + good_dict: Union[dict[Literal["a", "b"], int], dict[str, int]], + bad_dict: Union[dict[Literal["one", 1], int], dict[str, int]], + good_mapping: Union[Mapping[Literal["a", "b"], int], dict[str, int]], + bad_mapping: Union[Mapping[Literal["one", 1], int], dict[str, int]], +) -> None: + func(**good_kw) + func(**bad_kw) # E: Argument after ** must have string keys + func(**good_dict) + func(**bad_dict) # E: Argument after ** must have string keys + func(**good_mapping) + func(**bad_mapping) # E: Argument after ** must have string keys [builtins fixtures/dict.pyi] diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 41a6c5b33cb9..9c11f4041092 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -461,14 +461,16 @@ class C(Generic[P, P2]): self.m1(*args, **kwargs) self.m2(*args, **kwargs) # E: Argument 1 to "m2" of "C" has incompatible type "*P.args"; expected "P2.args" \ # E: Argument 2 to "m2" of "C" has incompatible type "**P.kwargs"; expected "P2.kwargs" - self.m1(*kwargs, **args) # E: Argument 1 to "m1" of "C" has incompatible type "*P.kwargs"; expected "P.args" \ - # E: Argument 2 to "m1" of "C" has incompatible type "**P.args"; expected "P.kwargs" + self.m1(*kwargs, **args) # E: Argument after ** must be a mapping, not "P.args" \ + # E: Argument 1 to "m1" of "C" has incompatible type "*P.kwargs"; expected "P.args" \ + # E: Argument 2 to "m1" of "C" has incompatible type "**P.args"; expected "P.kwargs" self.m3(*args, **kwargs) # E: Argument 1 to "m3" of "C" has incompatible type "*P.args"; expected "int" \ # E: Argument 2 to "m3" of "C" has incompatible type "**P.kwargs"; expected "int" self.m4(*args, **kwargs) # E: Argument 1 to "m4" of "C" has incompatible type "*P.args"; expected "int" \ # E: Argument 2 to "m4" of "C" has incompatible type "**P.kwargs"; expected "int" - self.m1(*args, **args) # E: Argument 2 to "m1" of "C" has incompatible type "**P.args"; expected "P.kwargs" + self.m1(*args, **args) # E: Argument after ** must be a mapping, not "P.args" \ + # E: Argument 2 to "m1" of "C" has incompatible type "**P.args"; expected "P.kwargs" self.m1(*kwargs, **kwargs) # E: Argument 1 to "m1" of "C" has incompatible type "*P.kwargs"; expected "P.args" def m2(self, *args: P2.args, **kwargs: P2.kwargs) -> None: @@ -479,6 +481,10 @@ class C(Generic[P, P2]): def m4(self, x: int) -> None: pass + + + + [builtins fixtures/dict.pyi] [case testParamSpecOverUnannotatedDecorator]