Skip to content

Commit

Permalink
[suggest] Support refining existing type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
msullivan committed Oct 31, 2019
1 parent 22a5a4f commit d9f19dd
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 6 deletions.
40 changes: 34 additions & 6 deletions mypy/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from mypy.checkexpr import has_any_type

from mypy.join import join_type_list
from mypy.meet import meet_types
from mypy.sametypes import is_same_type
from mypy.typeops import make_simplified_union

Expand Down Expand Up @@ -240,6 +241,12 @@ def get_trivial_type(self, fdef: FuncDef) -> CallableType:
AnyType(TypeOfAny.special_form),
self.builtin_type('builtins.function'))

def get_starting_type(self, fdef: FuncDef) -> CallableType:
if isinstance(fdef.type, CallableType):
return fdef.type
else:
return self.get_trivial_type(fdef)

def get_args(self, is_method: bool,
base: CallableType, defaults: List[Optional[Type]],
callsites: List[Callsite]) -> List[List[Type]]:
Expand Down Expand Up @@ -294,11 +301,12 @@ def get_guesses(self, is_method: bool, base: CallableType, defaults: List[Option
"""
options = self.get_args(is_method, base, defaults, callsites)
options = [self.add_adjustments(tps) for tps in options]
return [base.copy_modified(arg_types=list(x)) for x in itertools.product(*options)]
return [merge_callables(base, base.copy_modified(arg_types=list(x)))
for x in itertools.product(*options)]

def get_callsites(self, func: FuncDef) -> Tuple[List[Callsite], List[str]]:
"""Find all call sites of a function."""
new_type = self.get_trivial_type(func)
new_type = self.get_starting_type(func)

collector_plugin = SuggestionPlugin(func.fullname())

Expand Down Expand Up @@ -350,7 +358,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
with strict_optional_set(graph[mod].options.strict_optional):
guesses = self.get_guesses(
is_method,
self.get_trivial_type(node),
self.get_starting_type(node),
self.get_default_arg_types(graph[mod], node),
callsites)
guesses = self.filter_options(guesses, is_method)
Expand All @@ -367,7 +375,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
else:
ret_types = [NoneType()]

guesses = [best.copy_modified(ret_type=t) for t in ret_types]
guesses = [merge_callables(best, best.copy_modified(ret_type=t)) for t in ret_types]
guesses = self.filter_options(guesses, is_method)
best, errors = self.find_best(node, guesses)

Expand Down Expand Up @@ -528,8 +536,9 @@ def try_type(self, func: FuncDef, typ: ProperType) -> List[str]:
"""
old = func.unanalyzed_type
# During reprocessing, unanalyzed_type gets copied to type (by aststrip).
# We don't modify type because it isn't necessary and it
# would mess up the snapshotting.
# We set type to None to ensure that the type always changes during
# reprocessing.
func.type = None
func.unanalyzed_type = typ
try:
res = self.fgmanager.trigger(func.fullname())
Expand Down Expand Up @@ -778,6 +787,25 @@ def count_errors(msgs: List[str]) -> int:
T = TypeVar('T')


def merge_callables(t: CallableType, s: CallableType) -> CallableType:
"""Merge two callable types in a way that prefers dropping Anys.
This is implemented by doing a meet on both the arguments and the return type,
since meet(t, Any) == t.
This won't do perfectly with complex compound types (like
callables nested inside), but it does pretty well.
"""

# We don't want to ever squash away optionals while doing this, so set
# strict optional to be true always
with strict_optional_set(True):
arg_types = [] # type: List[Type]
for i in range(len(t.arg_types)):
arg_types.append(meet_types(t.arg_types[i], s.arg_types[i]))
return t.copy_modified(arg_types=arg_types, ret_type=meet_types(t.ret_type, s.ret_type))


def dedup(old: List[T]) -> List[T]:
new = [] # type: List[T]
for x in old:
Expand Down
37 changes: 37 additions & 0 deletions test-data/unit/fine-grained-suggest.test
Original file line number Diff line number Diff line change
Expand Up @@ -829,3 +829,40 @@ Command 'suggest' is only valid after a 'check' command (that produces no parse
==
foo.py:4: error: unexpected EOF while parsing
-- )

[case testSuggestRefine]
# suggest: foo.foo
# suggest: foo.spam
# suggest: foo.eggs
# suggest: foo.take_l
[file foo.py]
from typing import Any, List

def bar():
return 10

def foo(x: int, y):
return x + y

def spam(x: int, y: Any) -> Any:
return x + y

def eggs(x: int) -> List[Any]:
a = [x]
return a

def take_l(x: List[Any]) -> Any:
return x[0]


foo(bar(), 10)
spam(bar(), 20)
test = [10, 20]
take_l(test)
[builtins fixtures/isinstancelist.pyi]
[out]
(int, int) -> int
(int, int) -> int
(int) -> foo.List[int]
(foo.List[int]) -> int
==

0 comments on commit d9f19dd

Please sign in to comment.