From d9f19dda6c2b1a65e73108dca9b8e1d6650b3078 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Thu, 31 Oct 2019 13:23:57 -0700 Subject: [PATCH] [suggest] Support refining existing type annotations --- mypy/suggestions.py | 40 ++++++++++++++++++++---- test-data/unit/fine-grained-suggest.test | 37 ++++++++++++++++++++++ 2 files changed, 71 insertions(+), 6 deletions(-) diff --git a/mypy/suggestions.py b/mypy/suggestions.py index c1e5750665e3..3c7e45a2cb39 100644 --- a/mypy/suggestions.py +++ b/mypy/suggestions.py @@ -51,6 +51,7 @@ from mypy.checkexpr import has_any_type from mypy.join import join_type_list +from mypy.meet import meet_types from mypy.sametypes import is_same_type from mypy.typeops import make_simplified_union @@ -240,6 +241,12 @@ def get_trivial_type(self, fdef: FuncDef) -> CallableType: AnyType(TypeOfAny.special_form), self.builtin_type('builtins.function')) + def get_starting_type(self, fdef: FuncDef) -> CallableType: + if isinstance(fdef.type, CallableType): + return fdef.type + else: + return self.get_trivial_type(fdef) + def get_args(self, is_method: bool, base: CallableType, defaults: List[Optional[Type]], callsites: List[Callsite]) -> List[List[Type]]: @@ -294,11 +301,12 @@ def get_guesses(self, is_method: bool, base: CallableType, defaults: List[Option """ options = self.get_args(is_method, base, defaults, callsites) options = [self.add_adjustments(tps) for tps in options] - return [base.copy_modified(arg_types=list(x)) for x in itertools.product(*options)] + return [merge_callables(base, base.copy_modified(arg_types=list(x))) + for x in itertools.product(*options)] def get_callsites(self, func: FuncDef) -> Tuple[List[Callsite], List[str]]: """Find all call sites of a function.""" - new_type = self.get_trivial_type(func) + new_type = self.get_starting_type(func) collector_plugin = SuggestionPlugin(func.fullname()) @@ -350,7 +358,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature: with strict_optional_set(graph[mod].options.strict_optional): guesses = self.get_guesses( is_method, - self.get_trivial_type(node), + self.get_starting_type(node), self.get_default_arg_types(graph[mod], node), callsites) guesses = self.filter_options(guesses, is_method) @@ -367,7 +375,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature: else: ret_types = [NoneType()] - guesses = [best.copy_modified(ret_type=t) for t in ret_types] + guesses = [merge_callables(best, best.copy_modified(ret_type=t)) for t in ret_types] guesses = self.filter_options(guesses, is_method) best, errors = self.find_best(node, guesses) @@ -528,8 +536,9 @@ def try_type(self, func: FuncDef, typ: ProperType) -> List[str]: """ old = func.unanalyzed_type # During reprocessing, unanalyzed_type gets copied to type (by aststrip). - # We don't modify type because it isn't necessary and it - # would mess up the snapshotting. + # We set type to None to ensure that the type always changes during + # reprocessing. + func.type = None func.unanalyzed_type = typ try: res = self.fgmanager.trigger(func.fullname()) @@ -778,6 +787,25 @@ def count_errors(msgs: List[str]) -> int: T = TypeVar('T') +def merge_callables(t: CallableType, s: CallableType) -> CallableType: + """Merge two callable types in a way that prefers dropping Anys. + + This is implemented by doing a meet on both the arguments and the return type, + since meet(t, Any) == t. + + This won't do perfectly with complex compound types (like + callables nested inside), but it does pretty well. + """ + + # We don't want to ever squash away optionals while doing this, so set + # strict optional to be true always + with strict_optional_set(True): + arg_types = [] # type: List[Type] + for i in range(len(t.arg_types)): + arg_types.append(meet_types(t.arg_types[i], s.arg_types[i])) + return t.copy_modified(arg_types=arg_types, ret_type=meet_types(t.ret_type, s.ret_type)) + + def dedup(old: List[T]) -> List[T]: new = [] # type: List[T] for x in old: diff --git a/test-data/unit/fine-grained-suggest.test b/test-data/unit/fine-grained-suggest.test index 701974161a89..70680b66bbec 100644 --- a/test-data/unit/fine-grained-suggest.test +++ b/test-data/unit/fine-grained-suggest.test @@ -829,3 +829,40 @@ Command 'suggest' is only valid after a 'check' command (that produces no parse == foo.py:4: error: unexpected EOF while parsing -- ) + +[case testSuggestRefine] +# suggest: foo.foo +# suggest: foo.spam +# suggest: foo.eggs +# suggest: foo.take_l +[file foo.py] +from typing import Any, List + +def bar(): + return 10 + +def foo(x: int, y): + return x + y + +def spam(x: int, y: Any) -> Any: + return x + y + +def eggs(x: int) -> List[Any]: + a = [x] + return a + +def take_l(x: List[Any]) -> Any: + return x[0] + + +foo(bar(), 10) +spam(bar(), 20) +test = [10, 20] +take_l(test) +[builtins fixtures/isinstancelist.pyi] +[out] +(int, int) -> int +(int, int) -> int +(int) -> foo.List[int] +(foo.List[int]) -> int +==