Skip to content

Commit

Permalink
[suggest] Support refining existing type annotations (#7838)
Browse files Browse the repository at this point in the history
  • Loading branch information
msullivan authored Nov 2, 2019
1 parent 38e0f5d commit a4f4ffe
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 6 deletions.
121 changes: 115 additions & 6 deletions mypy/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,12 @@ def get_trivial_type(self, fdef: FuncDef) -> CallableType:
AnyType(TypeOfAny.special_form),
self.builtin_type('builtins.function'))

def get_starting_type(self, fdef: FuncDef) -> CallableType:
if isinstance(fdef.type, CallableType):
return fdef.type
else:
return self.get_trivial_type(fdef)

def get_args(self, is_method: bool,
base: CallableType, defaults: List[Optional[Type]],
callsites: List[Callsite],
Expand Down Expand Up @@ -356,11 +362,12 @@ def get_guesses(self, is_method: bool, base: CallableType, defaults: List[Option
"""
options = self.get_args(is_method, base, defaults, callsites, uses)
options = [self.add_adjustments(tps) for tps in options]
return [base.copy_modified(arg_types=list(x)) for x in itertools.product(*options)]
return [refine_callable(base, base.copy_modified(arg_types=list(x)))
for x in itertools.product(*options)]

def get_callsites(self, func: FuncDef) -> Tuple[List[Callsite], List[str]]:
"""Find all call sites of a function."""
new_type = self.get_trivial_type(func)
new_type = self.get_starting_type(func)

collector_plugin = SuggestionPlugin(func.fullname())

Expand Down Expand Up @@ -413,7 +420,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
with strict_optional_set(graph[mod].options.strict_optional):
guesses = self.get_guesses(
is_method,
self.get_trivial_type(node),
self.get_starting_type(node),
self.get_default_arg_types(graph[mod], node),
callsites,
uses,
Expand All @@ -432,7 +439,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
else:
ret_types = [NoneType()]

guesses = [best.copy_modified(ret_type=t) for t in ret_types]
guesses = [best.copy_modified(ret_type=refine_type(best.ret_type, t)) for t in ret_types]
guesses = self.filter_options(guesses, is_method)
best, errors = self.find_best(node, guesses)

Expand Down Expand Up @@ -593,8 +600,9 @@ def try_type(self, func: FuncDef, typ: ProperType) -> List[str]:
"""
old = func.unanalyzed_type
# During reprocessing, unanalyzed_type gets copied to type (by aststrip).
# We don't modify type because it isn't necessary and it
# would mess up the snapshotting.
# We set type to None to ensure that the type always changes during
# reprocessing.
func.type = None
func.unanalyzed_type = typ
try:
res = self.fgmanager.trigger(func.fullname())
Expand Down Expand Up @@ -682,6 +690,8 @@ def score_type(self, t: Type, arg_pos: bool) -> int:
if isinstance(t, UnionType):
if any(isinstance(x, AnyType) for x in t.items):
return 20
if any(has_any_type(x) for x in t.items):
return 15
if not is_optional(t):
return 10
if isinstance(t, CallableType) and (has_any_type(t) or is_tricky_callable(t)):
Expand Down Expand Up @@ -840,6 +850,105 @@ def count_errors(msgs: List[str]) -> int:
return len([x for x in msgs if ' error: ' in x])


def refine_type(ti: Type, si: Type) -> Type:
"""Refine `ti` by replacing Anys in it with information taken from `si`
This basically works by, when the types have the same structure,
traversing both of them in parallel and replacing Any on the left
with whatever the type on the right is. If the types don't have the
same structure (or aren't supported), the left type is chosen.
For example:
refine(Any, T) = T, for all T
refine(float, int) = float
refine(List[Any], List[int]) = List[int]
refine(Dict[int, Any], Dict[Any, int]) = Dict[int, int]
refine(Tuple[int, Any], Tuple[Any, int]) = Tuple[int, int]
refine(Callable[[Any], Any], Callable[[int], int]) = Callable[[int], int]
refine(Callable[..., int], Callable[[int, float], Any]) = Callable[[int, float], int]
refine(Optional[Any], int) = Optional[int]
refine(Optional[Any], Optional[int]) = Optional[int]
refine(Optional[Any], Union[int, str]) = Optional[Union[int, str]]
refine(Optional[List[Any]], List[int]) = List[int]
"""
t = get_proper_type(ti)
s = get_proper_type(si)

if isinstance(t, AnyType):
return s

if isinstance(t, Instance) and isinstance(s, Instance) and t.type == s.type:
return t.copy_modified(args=[refine_type(ta, sa) for ta, sa in zip(t.args, s.args)])

if (
isinstance(t, TupleType)
and isinstance(s, TupleType)
and t.partial_fallback == s.partial_fallback
and len(t.items) == len(s.items)
):
return t.copy_modified(items=[refine_type(ta, sa) for ta, sa in zip(t.items, s.items)])

if isinstance(t, CallableType) and isinstance(s, CallableType):
return refine_callable(t, s)

if isinstance(t, UnionType):
return refine_union(t, s)

# TODO: Refining of builtins.tuple, Type?

return t


def refine_union(t: UnionType, s: ProperType) -> Type:
"""Refine a union type based on another type.
This is done by refining every component of the union against the
right hand side type (or every component of its union if it is
one). If an element of the union is succesfully refined, we drop it
from the union in favor of the refined versions.
"""
rhs_items = s.items if isinstance(s, UnionType) else [s]

new_items = []
for lhs in t.items:
refined = False
for rhs in rhs_items:
new = refine_type(lhs, rhs)
if new != lhs:
new_items.append(new)
refined = True
if not refined:
new_items.append(lhs)

# Turn strict optional on when simplifying the union since we
# don't want to drop Nones.
with strict_optional_set(True):
return make_simplified_union(new_items)


def refine_callable(t: CallableType, s: CallableType) -> CallableType:
"""Refine a callable based on another.
See comments for refine_type.
"""
if t.fallback != s.fallback:
return t

if t.is_ellipsis_args and not is_tricky_callable(s):
return s.copy_modified(ret_type=refine_type(t.ret_type, s.ret_type))

if is_tricky_callable(t) or t.arg_kinds != s.arg_kinds:
return t

return t.copy_modified(
arg_types=[refine_type(ta, sa) for ta, sa in zip(t.arg_types, s.arg_types)],
ret_type=refine_type(t.ret_type, s.ret_type),
)


T = TypeVar('T')


Expand Down
141 changes: 141 additions & 0 deletions test-data/unit/fine-grained-suggest.test
Original file line number Diff line number Diff line change
Expand Up @@ -906,3 +906,144 @@ Command 'suggest' is only valid after a 'check' command (that produces no parse
==
foo.py:4: error: unexpected EOF while parsing
-- )

[case testSuggestRefine]
# suggest: foo.foo
# suggest: foo.spam
# suggest: foo.eggs
# suggest: foo.take_l
# suggest: foo.union
# suggest: foo.callable1
# suggest: foo.callable2
# suggest: foo.optional1
# suggest: foo.optional2
# suggest: foo.optional3
# suggest: foo.optional4
# suggest: foo.optional5
# suggest: foo.dict1
# suggest: foo.tuple1
[file foo.py]
from typing import Any, List, Union, Callable, Optional, Set, Dict, Tuple

def bar():
return 10

def foo(x: int, y):
return x + y

foo(bar(), 10)

def spam(x: int, y: Any) -> Any:
return x + y

spam(bar(), 20)

def eggs(x: int) -> List[Any]:
a = [x]
return a

def take_l(x: List[Any]) -> Any:
return x[0]

test = [10, 20]
take_l(test)

def union(x: Union[int, str]):
pass

union(10)

def add1(x: float) -> int:
pass

def callable1(f: Callable[[int], Any]):
return f(10)

callable1(add1)

def callable2(f: Callable[..., Any]):
return f(10)

callable2(add1)

def optional1(x: Optional[Any]):
pass

optional1(10)

def optional2(x: Union[None, int, Any]):
if x is None:
pass
elif isinstance(x, str):
pass
else:
add1(x)

optional2(10)
optional2('test')

def optional3(x: Optional[List[Any]]):
assert not x
return x[0]

optional3(test)

set_test = {1, 2}

def optional4(x: Union[Set[Any], List[Any]]):
pass

optional4(test)
optional4(set_test)

def optional5(x: Optional[Any]):
pass

optional5(10)
optional5(None)

def dict1(d: Dict[int, Any]):
pass

d: Dict[Any, int]
dict1(d)

def tuple1(d: Tuple[int, Any]):
pass

t: Tuple[Any, int]
tuple1(t)

[builtins fixtures/isinstancelist.pyi]
[out]
(int, int) -> int
(int, int) -> int
(int) -> foo.List[int]
(foo.List[int]) -> int
(Union[int, str]) -> None
(Callable[[int], int]) -> int
(Callable[[float], int]) -> int
(Optional[int]) -> None
(Union[None, int, str]) -> None
(Optional[foo.List[int]]) -> int
(Union[foo.Set[int], foo.List[int]]) -> None
(Optional[int]) -> None
(foo.Dict[int, int]) -> None
(Tuple[int, int]) -> None
==

[case testSuggestRefine2]
# suggest: foo.optional5
[file foo.py]
from typing import Optional, Any

def optional5(x: Optional[Any]):
pass

optional5(10)
optional5(None)

[builtins fixtures/isinstancelist.pyi]
[out]
(Optional[int]) -> None
==

0 comments on commit a4f4ffe

Please sign in to comment.