Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use namespaces for function type variables #17311

Merged
merged 9 commits into from
Jun 5, 2024
4 changes: 3 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2171,7 +2171,9 @@ def bind_and_map_method(
def get_op_other_domain(self, tp: FunctionLike) -> Type | None:
if isinstance(tp, CallableType):
if tp.arg_kinds and tp.arg_kinds[0] == ARG_POS:
return tp.arg_types[0]
# For generic methods, domain comparison is tricky, as a first
# approximation erase all remaining type variables to bounds.
return erase_typevars(tp.arg_types[0], {v.id for v in tp.variables})
return None
elif isinstance(tp, Overloaded):
raw_items = [self.get_op_other_domain(it) for it in tp.items]
Expand Down
13 changes: 7 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
Expand Down Expand Up @@ -4933,7 +4934,7 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
tv = TypeVarType(
"T",
"T",
id=-1,
id=TypeVarId(-1, namespace="<lst>"),
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
Expand Down Expand Up @@ -5164,15 +5165,15 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
kt = TypeVarType(
"KT",
"KT",
id=-1,
id=TypeVarId(-1, namespace="<dict>"),
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
vt = TypeVarType(
"VT",
"VT",
id=-2,
id=TypeVarId(-2, namespace="<dict>"),
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
Expand Down Expand Up @@ -5564,7 +5565,7 @@ def check_generator_or_comprehension(
tv = TypeVarType(
"T",
"T",
id=-1,
id=TypeVarId(-1, namespace="<genexp>"),
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
Expand All @@ -5591,15 +5592,15 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type:
ktdef = TypeVarType(
"KT",
"KT",
id=-1,
id=TypeVarId(-1, namespace="<dict>"),
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
)
vtdef = TypeVarType(
"VT",
"VT",
id=-2,
id=TypeVarId(-2, namespace="<dict>"),
values=[],
upper_bound=self.object_type(),
default=AnyType(TypeOfAny.from_omitted_generics),
Expand Down
2 changes: 1 addition & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> l
new_unpack: Type
if isinstance(var_arg_type, Instance):
# we have something like Unpack[Tuple[Any, ...]]
new_unpack = var_arg
new_unpack = UnpackType(var_arg.type.accept(self))
elif isinstance(var_arg_type, TupleType):
# We have something like Unpack[Tuple[Unpack[Ts], X1, X2]]
expanded_tuple = var_arg_type.accept(self)
Expand Down
31 changes: 31 additions & 0 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Sequence, overload

import mypy.typeops
from mypy.expandtype import expand_type
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY
from mypy.state import state
Expand Down Expand Up @@ -36,6 +37,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
Expand Down Expand Up @@ -718,7 +720,35 @@ def is_similar_callables(t: CallableType, s: CallableType) -> bool:
)


def update_callable_ids(c: CallableType, ids: list[TypeVarId]) -> CallableType:
tv_map = {}
tvs = []
for tv, new_id in zip(c.variables, ids):
new_tv = tv.copy_modified(id=new_id)
tvs.append(new_tv)
tv_map[tv.id] = new_tv
return expand_type(c, tv_map).copy_modified(variables=tvs)


def match_generic_callables(t: CallableType, s: CallableType) -> tuple[CallableType, CallableType]:
# The case where we combine/join/meet similar callables, situation where both are generic
# requires special care. A more principled solution may involve unify_generic_callable(),
# but it would have two problems:
# * This adds risk of infinite recursion: e.g. join -> unification -> solver -> join
# * Using unification is an incorrect thing for meets, as it "widens" the types
# Finally, this effectively falls back to an old behaviour before namespaces were added to
# type variables, and it worked relatively well.
max_len = max(len(t.variables), len(s.variables))
min_len = min(len(t.variables), len(s.variables))
if min_len == 0:
return t, s
new_ids = [TypeVarId.new(meta_level=0) for _ in range(max_len)]
# Note: this relies on variables being in order they appear in function definition.
return update_callable_ids(t, new_ids), update_callable_ids(s, new_ids)


def join_similar_callables(t: CallableType, s: CallableType) -> CallableType:
t, s = match_generic_callables(t, s)
arg_types: list[Type] = []
for i in range(len(t.arg_types)):
arg_types.append(safe_meet(t.arg_types[i], s.arg_types[i]))
Expand Down Expand Up @@ -771,6 +801,7 @@ def safe_meet(t: Type, s: Type) -> Type:


def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType:
t, s = match_generic_callables(t, s)
arg_types: list[Type] = []
for i in range(len(t.arg_types)):
arg_types.append(safe_join(t.arg_types[i], s.arg_types[i]))
Expand Down
3 changes: 2 additions & 1 deletion mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,8 +1024,9 @@ def default(self, typ: Type) -> ProperType:


def meet_similar_callables(t: CallableType, s: CallableType) -> CallableType:
from mypy.join import safe_join
from mypy.join import match_generic_callables, safe_join

t, s = match_generic_callables(t, s)
arg_types: list[Type] = []
for i in range(len(t.arg_types)):
arg_types.append(safe_join(t.arg_types[i], s.arg_types[i]))
Expand Down
76 changes: 57 additions & 19 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
TypeOfAny,
TypeStrVisitor,
TypeType,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UnboundType,
Expand Down Expand Up @@ -2502,14 +2503,16 @@ def format_literal_value(typ: LiteralType) -> str:
return typ.value_repr()

if isinstance(typ, TypeAliasType) and typ.is_recursive:
# TODO: find balance here, str(typ) doesn't support custom verbosity, and may be
# too verbose for user messages, OTOH it nicely shows structure of recursive types.
if verbosity < 2:
type_str = typ.alias.name if typ.alias else "<alias (unfixed)>"
if typ.alias is None:
type_str = "<alias (unfixed)>"
else:
if verbosity >= 2 or (fullnames and typ.alias.fullname in fullnames):
type_str = typ.alias.fullname
else:
type_str = typ.alias.name
if typ.args:
type_str += f"[{format_list(typ.args)}]"
return type_str
return str(typ)
return type_str

# TODO: always mention type alias names in errors.
typ = get_proper_type(typ)
Expand Down Expand Up @@ -2550,9 +2553,15 @@ def format_literal_value(typ: LiteralType) -> str:
return f"Unpack[{format(typ.type)}]"
elif isinstance(typ, TypeVarType):
# This is similar to non-generic instance types.
fullname = scoped_type_var_name(typ)
if verbosity >= 2 or (fullnames and fullname in fullnames):
return fullname
return typ.name
elif isinstance(typ, TypeVarTupleType):
# This is similar to non-generic instance types.
fullname = scoped_type_var_name(typ)
if verbosity >= 2 or (fullnames and fullname in fullnames):
return fullname
return typ.name
elif isinstance(typ, ParamSpecType):
# Concatenate[..., P]
Expand All @@ -2563,6 +2572,7 @@ def format_literal_value(typ: LiteralType) -> str:

return f"[{args}, **{typ.name_with_suffix()}]"
else:
# TODO: better disambiguate ParamSpec name clashes.
return typ.name_with_suffix()
elif isinstance(typ, TupleType):
# Prefer the name of the fallback class (if not tuple), as it's more informative.
Expand Down Expand Up @@ -2680,29 +2690,51 @@ def format_literal_value(typ: LiteralType) -> str:
return "object"


def collect_all_instances(t: Type) -> list[Instance]:
"""Return all instances that `t` contains (including `t`).
def collect_all_named_types(t: Type) -> list[Type]:
"""Return all instances/aliases/type variables that `t` contains (including `t`).

This is similar to collect_all_inner_types from typeanal but only
returns instances and will recurse into fallbacks.
"""
visitor = CollectAllInstancesQuery()
visitor = CollectAllNamedTypesQuery()
t.accept(visitor)
return visitor.instances
return visitor.types


class CollectAllInstancesQuery(TypeTraverserVisitor):
class CollectAllNamedTypesQuery(TypeTraverserVisitor):
def __init__(self) -> None:
self.instances: list[Instance] = []
self.types: list[Type] = []

def visit_instance(self, t: Instance) -> None:
self.instances.append(t)
self.types.append(t)
super().visit_instance(t)

def visit_type_alias_type(self, t: TypeAliasType) -> None:
if t.alias and not t.is_recursive:
t.alias.target.accept(self)
super().visit_type_alias_type(t)
get_proper_type(t).accept(self)
else:
self.types.append(t)
super().visit_type_alias_type(t)

def visit_type_var(self, t: TypeVarType) -> None:
self.types.append(t)
super().visit_type_var(t)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
self.types.append(t)
super().visit_type_var_tuple(t)

def visit_param_spec(self, t: ParamSpecType) -> None:
self.types.append(t)
super().visit_param_spec(t)


def scoped_type_var_name(t: TypeVarLikeType) -> str:
if not t.id.namespace:
return t.name
# TODO: support rare cases when both TypeVar name and namespace suffix coincide.
*_, suffix = t.id.namespace.split(".")
return f"{t.name}@{suffix}"


def find_type_overlaps(*types: Type) -> set[str]:
Expand All @@ -2713,8 +2745,14 @@ def find_type_overlaps(*types: Type) -> set[str]:
"""
d: dict[str, set[str]] = {}
for type in types:
for inst in collect_all_instances(type):
d.setdefault(inst.type.name, set()).add(inst.type.fullname)
for t in collect_all_named_types(type):
if isinstance(t, ProperType) and isinstance(t, Instance):
d.setdefault(t.type.name, set()).add(t.type.fullname)
elif isinstance(t, TypeAliasType) and t.alias:
d.setdefault(t.alias.name, set()).add(t.alias.fullname)
else:
assert isinstance(t, TypeVarLikeType)
d.setdefault(t.name, set()).add(scoped_type_var_name(t))
for shortname in d.keys():
if f"typing.{shortname}" in TYPES_FOR_UNIMPORTED_HINTS:
d[shortname].add(f"typing.{shortname}")
Expand All @@ -2732,7 +2770,7 @@ def format_type(
"""
Convert a type to a relatively short string suitable for error messages.

`verbosity` is a coarse grained control on the verbosity of the type
`verbosity` is a coarse-grained control on the verbosity of the type

This function returns a string appropriate for unmodified use in error
messages; this means that it will be quoted in most cases. If
Expand All @@ -2748,7 +2786,7 @@ def format_type_bare(
"""
Convert a type to a relatively short string suitable for error messages.

`verbosity` is a coarse grained control on the verbosity of the type
`verbosity` is a coarse-grained control on the verbosity of the type
`fullnames` specifies a set of names that should be printed in full

This function will return an unquoted string. If a caller doesn't need to
Expand Down
17 changes: 9 additions & 8 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
Type,
TypeOfAny,
TypeType,
TypeVarId,
TypeVarType,
UninhabitedType,
UnionType,
Expand Down Expand Up @@ -807,25 +808,25 @@ def _add_order(ctx: mypy.plugin.ClassDefContext, adder: MethodAdder) -> None:
# AT = TypeVar('AT')
# def __lt__(self: AT, other: AT) -> bool
# This way comparisons with subclasses will work correctly.
fullname = f"{ctx.cls.info.fullname}.{SELF_TVAR_NAME}"
tvd = TypeVarType(
SELF_TVAR_NAME,
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
id=-1,
fullname,
# Namespace is patched per-method below.
id=TypeVarId(-1, namespace=""),
values=[],
upper_bound=object_type,
default=AnyType(TypeOfAny.from_omitted_generics),
)
self_tvar_expr = TypeVarExpr(
SELF_TVAR_NAME,
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
[],
object_type,
AnyType(TypeOfAny.from_omitted_generics),
SELF_TVAR_NAME, fullname, [], object_type, AnyType(TypeOfAny.from_omitted_generics)
)
ctx.cls.info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)

args = [Argument(Var("other", tvd), tvd, None, ARG_POS)]
for method in ["__lt__", "__le__", "__gt__", "__ge__"]:
namespace = f"{ctx.cls.info.fullname}.{method}"
tvd = tvd.copy_modified(id=TypeVarId(tvd.id.raw_id, namespace=namespace))
args = [Argument(Var("other", tvd), tvd, None, ARG_POS)]
adder.add_method(method, args, bool_type, self_type=tvd, tvd=tvd)


Expand Down
5 changes: 3 additions & 2 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
TupleType,
Type,
TypeOfAny,
TypeVarId,
TypeVarType,
UninhabitedType,
UnionType,
Expand Down Expand Up @@ -314,8 +315,8 @@ def transform(self) -> bool:
obj_type = self._api.named_type("builtins.object")
order_tvar_def = TypeVarType(
SELF_TVAR_NAME,
info.fullname + "." + SELF_TVAR_NAME,
id=-1,
f"{info.fullname}.{SELF_TVAR_NAME}",
id=TypeVarId(-1, namespace=f"{info.fullname}.{method_name}"),
values=[],
upper_bound=obj_type,
default=AnyType(TypeOfAny.from_omitted_generics),
Expand Down
14 changes: 12 additions & 2 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import mypy.checker
import mypy.plugin
from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, Var
from mypy.plugins.common import add_method_to_class
from mypy.types import (
AnyType,
Expand Down Expand Up @@ -151,12 +151,22 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param]
actual_types = [a for param in ctx.arg_types[1:] for a in param]

# Create a valid context for various ad-hoc inspections in check_call().
call_expr = CallExpr(
callee=ctx.args[0][0],
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
analyzed=ctx.context.analyzed if isinstance(ctx.context, CallExpr) else None,
)
call_expr.set_line(ctx.context)

_, bound = ctx.api.expr_checker.check_call(
callee=defaulted,
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
context=defaulted,
context=call_expr,
)
bound = get_proper_type(bound)
if not isinstance(bound, CallableType):
Expand Down
Loading
Loading