From a4f4ffed8b6dc20366a5069ee168be1565486f22 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Fri, 1 Nov 2019 21:41:40 -0700 Subject: [PATCH] [suggest] Support refining existing type annotations (#7838) --- mypy/suggestions.py | 121 ++++++++++++++++++- test-data/unit/fine-grained-suggest.test | 141 +++++++++++++++++++++++ 2 files changed, 256 insertions(+), 6 deletions(-) diff --git a/mypy/suggestions.py b/mypy/suggestions.py index f01dc9bdf7c6..7bb4a583d0cf 100644 --- a/mypy/suggestions.py +++ b/mypy/suggestions.py @@ -288,6 +288,12 @@ def get_trivial_type(self, fdef: FuncDef) -> CallableType: AnyType(TypeOfAny.special_form), self.builtin_type('builtins.function')) + def get_starting_type(self, fdef: FuncDef) -> CallableType: + if isinstance(fdef.type, CallableType): + return fdef.type + else: + return self.get_trivial_type(fdef) + def get_args(self, is_method: bool, base: CallableType, defaults: List[Optional[Type]], callsites: List[Callsite], @@ -356,11 +362,12 @@ def get_guesses(self, is_method: bool, base: CallableType, defaults: List[Option """ options = self.get_args(is_method, base, defaults, callsites, uses) options = [self.add_adjustments(tps) for tps in options] - return [base.copy_modified(arg_types=list(x)) for x in itertools.product(*options)] + return [refine_callable(base, base.copy_modified(arg_types=list(x))) + for x in itertools.product(*options)] def get_callsites(self, func: FuncDef) -> Tuple[List[Callsite], List[str]]: """Find all call sites of a function.""" - new_type = self.get_trivial_type(func) + new_type = self.get_starting_type(func) collector_plugin = SuggestionPlugin(func.fullname()) @@ -413,7 +420,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature: with strict_optional_set(graph[mod].options.strict_optional): guesses = self.get_guesses( is_method, - self.get_trivial_type(node), + self.get_starting_type(node), self.get_default_arg_types(graph[mod], node), callsites, uses, @@ -432,7 +439,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature: else: ret_types = [NoneType()] - guesses = [best.copy_modified(ret_type=t) for t in ret_types] + guesses = [best.copy_modified(ret_type=refine_type(best.ret_type, t)) for t in ret_types] guesses = self.filter_options(guesses, is_method) best, errors = self.find_best(node, guesses) @@ -593,8 +600,9 @@ def try_type(self, func: FuncDef, typ: ProperType) -> List[str]: """ old = func.unanalyzed_type # During reprocessing, unanalyzed_type gets copied to type (by aststrip). - # We don't modify type because it isn't necessary and it - # would mess up the snapshotting. + # We set type to None to ensure that the type always changes during + # reprocessing. + func.type = None func.unanalyzed_type = typ try: res = self.fgmanager.trigger(func.fullname()) @@ -682,6 +690,8 @@ def score_type(self, t: Type, arg_pos: bool) -> int: if isinstance(t, UnionType): if any(isinstance(x, AnyType) for x in t.items): return 20 + if any(has_any_type(x) for x in t.items): + return 15 if not is_optional(t): return 10 if isinstance(t, CallableType) and (has_any_type(t) or is_tricky_callable(t)): @@ -840,6 +850,105 @@ def count_errors(msgs: List[str]) -> int: return len([x for x in msgs if ' error: ' in x]) +def refine_type(ti: Type, si: Type) -> Type: + """Refine `ti` by replacing Anys in it with information taken from `si` + + This basically works by, when the types have the same structure, + traversing both of them in parallel and replacing Any on the left + with whatever the type on the right is. If the types don't have the + same structure (or aren't supported), the left type is chosen. + + For example: + refine(Any, T) = T, for all T + refine(float, int) = float + refine(List[Any], List[int]) = List[int] + refine(Dict[int, Any], Dict[Any, int]) = Dict[int, int] + refine(Tuple[int, Any], Tuple[Any, int]) = Tuple[int, int] + + refine(Callable[[Any], Any], Callable[[int], int]) = Callable[[int], int] + refine(Callable[..., int], Callable[[int, float], Any]) = Callable[[int, float], int] + + refine(Optional[Any], int) = Optional[int] + refine(Optional[Any], Optional[int]) = Optional[int] + refine(Optional[Any], Union[int, str]) = Optional[Union[int, str]] + refine(Optional[List[Any]], List[int]) = List[int] + + """ + t = get_proper_type(ti) + s = get_proper_type(si) + + if isinstance(t, AnyType): + return s + + if isinstance(t, Instance) and isinstance(s, Instance) and t.type == s.type: + return t.copy_modified(args=[refine_type(ta, sa) for ta, sa in zip(t.args, s.args)]) + + if ( + isinstance(t, TupleType) + and isinstance(s, TupleType) + and t.partial_fallback == s.partial_fallback + and len(t.items) == len(s.items) + ): + return t.copy_modified(items=[refine_type(ta, sa) for ta, sa in zip(t.items, s.items)]) + + if isinstance(t, CallableType) and isinstance(s, CallableType): + return refine_callable(t, s) + + if isinstance(t, UnionType): + return refine_union(t, s) + + # TODO: Refining of builtins.tuple, Type? + + return t + + +def refine_union(t: UnionType, s: ProperType) -> Type: + """Refine a union type based on another type. + + This is done by refining every component of the union against the + right hand side type (or every component of its union if it is + one). If an element of the union is succesfully refined, we drop it + from the union in favor of the refined versions. + """ + rhs_items = s.items if isinstance(s, UnionType) else [s] + + new_items = [] + for lhs in t.items: + refined = False + for rhs in rhs_items: + new = refine_type(lhs, rhs) + if new != lhs: + new_items.append(new) + refined = True + if not refined: + new_items.append(lhs) + + # Turn strict optional on when simplifying the union since we + # don't want to drop Nones. + with strict_optional_set(True): + return make_simplified_union(new_items) + + +def refine_callable(t: CallableType, s: CallableType) -> CallableType: + """Refine a callable based on another. + + See comments for refine_type. + """ + if t.fallback != s.fallback: + return t + + if t.is_ellipsis_args and not is_tricky_callable(s): + return s.copy_modified(ret_type=refine_type(t.ret_type, s.ret_type)) + + if is_tricky_callable(t) or t.arg_kinds != s.arg_kinds: + return t + + return t.copy_modified( + arg_types=[refine_type(ta, sa) for ta, sa in zip(t.arg_types, s.arg_types)], + ret_type=refine_type(t.ret_type, s.ret_type), + ) + + T = TypeVar('T') diff --git a/test-data/unit/fine-grained-suggest.test b/test-data/unit/fine-grained-suggest.test index 8968641eaaaf..c2b5a4cba562 100644 --- a/test-data/unit/fine-grained-suggest.test +++ b/test-data/unit/fine-grained-suggest.test @@ -906,3 +906,144 @@ Command 'suggest' is only valid after a 'check' command (that produces no parse == foo.py:4: error: unexpected EOF while parsing -- ) + +[case testSuggestRefine] +# suggest: foo.foo +# suggest: foo.spam +# suggest: foo.eggs +# suggest: foo.take_l +# suggest: foo.union +# suggest: foo.callable1 +# suggest: foo.callable2 +# suggest: foo.optional1 +# suggest: foo.optional2 +# suggest: foo.optional3 +# suggest: foo.optional4 +# suggest: foo.optional5 +# suggest: foo.dict1 +# suggest: foo.tuple1 +[file foo.py] +from typing import Any, List, Union, Callable, Optional, Set, Dict, Tuple + +def bar(): + return 10 + +def foo(x: int, y): + return x + y + +foo(bar(), 10) + +def spam(x: int, y: Any) -> Any: + return x + y + +spam(bar(), 20) + +def eggs(x: int) -> List[Any]: + a = [x] + return a + +def take_l(x: List[Any]) -> Any: + return x[0] + +test = [10, 20] +take_l(test) + +def union(x: Union[int, str]): + pass + +union(10) + +def add1(x: float) -> int: + pass + +def callable1(f: Callable[[int], Any]): + return f(10) + +callable1(add1) + +def callable2(f: Callable[..., Any]): + return f(10) + +callable2(add1) + +def optional1(x: Optional[Any]): + pass + +optional1(10) + +def optional2(x: Union[None, int, Any]): + if x is None: + pass + elif isinstance(x, str): + pass + else: + add1(x) + +optional2(10) +optional2('test') + +def optional3(x: Optional[List[Any]]): + assert not x + return x[0] + +optional3(test) + +set_test = {1, 2} + +def optional4(x: Union[Set[Any], List[Any]]): + pass + +optional4(test) +optional4(set_test) + +def optional5(x: Optional[Any]): + pass + +optional5(10) +optional5(None) + +def dict1(d: Dict[int, Any]): + pass + +d: Dict[Any, int] +dict1(d) + +def tuple1(d: Tuple[int, Any]): + pass + +t: Tuple[Any, int] +tuple1(t) + +[builtins fixtures/isinstancelist.pyi] +[out] +(int, int) -> int +(int, int) -> int +(int) -> foo.List[int] +(foo.List[int]) -> int +(Union[int, str]) -> None +(Callable[[int], int]) -> int +(Callable[[float], int]) -> int +(Optional[int]) -> None +(Union[None, int, str]) -> None +(Optional[foo.List[int]]) -> int +(Union[foo.Set[int], foo.List[int]]) -> None +(Optional[int]) -> None +(foo.Dict[int, int]) -> None +(Tuple[int, int]) -> None +== + +[case testSuggestRefine2] +# suggest: foo.optional5 +[file foo.py] +from typing import Optional, Any + +def optional5(x: Optional[Any]): + pass + +optional5(10) +optional5(None) + +[builtins fixtures/isinstancelist.pyi] +[out] +(Optional[int]) -> None +==