diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index 43f5ed03f..24cd50a7e 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -8,6 +8,8 @@ from typing import Dict, List, Optional, Sequence, Set, Tuple, Union import libcst as cst +import libcst.matchers as m + from libcst.codemod._context import CodemodContext from libcst.codemod._visitor import ContextAwareTransformer from libcst.codemod.visitors._add_imports import AddImportsVisitor @@ -55,6 +57,18 @@ def _is_set(x: Union[None, cst.CSTNode, cst.MaybeSentinel]) -> bool: return x is not None and x != cst.MaybeSentinel.DEFAULT +def _get_string_value(node: cst.SimpleString) -> str: + s = node.value + c = s[-1] + return s[s.index(c) : -1] + + +def _find_generic_base(node: cst.ClassDef) -> Optional[cst.Arg]: + for b in node.bases: + if m.matches(b.value, m.Subscript(value=m.Name("Generic"))): + return b + + @dataclass(frozen=True) class FunctionKey: name: str @@ -80,7 +94,7 @@ class FunctionAnnotation: returns: Optional[cst.Annotation] -class TypeCollector(cst.CSTVisitor): +class TypeCollector(m.MatcherDecoratableVisitor): """ Collect type annotations from a stub module. """ @@ -91,6 +105,7 @@ class TypeCollector(cst.CSTVisitor): ) def __init__(self, existing_imports: Set[str], context: CodemodContext) -> None: + super().__init__() # Qualifier for storing the canonical name of the current function. self.qualifier: List[str] = [] # Store the annotations. @@ -99,6 +114,9 @@ def __init__(self, existing_imports: Set[str], context: CodemodContext) -> None: self.existing_imports: Set[str] = existing_imports self.class_definitions: Dict[str, cst.ClassDef] = {} self.context = context + self.current_assign: Optional[cst.Assign] = None # used to collect typevars + self.typevars: Dict[str, cst.Assign] = {} + self.annotation_names: Set[str] = set() def visit_ClassDef(self, node: cst.ClassDef) -> None: self.qualifier.append(node.name.value) @@ -153,6 +171,29 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> bool: def leave_AnnAssign(self, original_node: cst.AnnAssign) -> None: self.qualifier.pop() + def visit_Assign(self, node: cst.Assign) -> None: + self.current_assign = node + + def leave_Assign(self, original_node: cst.Assign) -> None: + self.current_assign = None + + @m.call_if_inside(m.Assign()) + @m.visit(m.Call(func=m.Name("TypeVar"))) + def record_typevar(self, node: cst.Call) -> None: + # pyre-ignore current_assign is never None here + name = get_full_name_for_node(self.current_assign.targets[0].target) + if name: + # pyre-ignore current_assign is never None here + self.typevars[name] = self.current_assign + self._handle_qualification_and_should_qualify("typing.TypeVar") + self.current_assign = None + + def leave_Module(self, original_node: cst.Module) -> None: + # Filter out unused typevars + self.typevars = { + k: v for k, v in self.typevars.items() if k in self.annotation_names + } + def _get_unique_qualified_name(self, node: cst.CSTNode) -> str: name = None names = [q.name for q in self.get_metadata(QualifiedNameProvider, node)] @@ -194,7 +235,7 @@ def _module_and_target(self, qualified_name: str) -> Tuple[str, str]: def _handle_qualification_and_should_qualify(self, qualified_name: str) -> bool: """ - Basd on a qualified name and the existing module imports, record that + Based on a qualified name and the existing module imports, record that we need to add an import if necessary and return whether or not we should use the qualified name due to a preexisting import. """ @@ -227,6 +268,7 @@ def _handle_NameOrAttribute( dequalified_node, ) = self._get_qualified_name_and_dequalified_node(node) should_qualify = self._handle_qualification_and_should_qualify(qualified_name) + self.annotation_names.add(qualified_name) if should_qualify: return node else: @@ -239,6 +281,8 @@ def _handle_Index(self, slice: cst.Index) -> cst.Index: elif isinstance(value, cst.Attribute): return slice.with_changes(value=self._handle_NameOrAttribute(value)) else: + if isinstance(value, cst.SimpleString): + self.annotation_names.add(_get_string_value(value)) return slice def _handle_Subscript(self, node: cst.Subscript) -> cst.Subscript: @@ -279,6 +323,7 @@ def _handle_Subscript(self, node: cst.Subscript) -> cst.Subscript: def _handle_Annotation(self, annotation: cst.Annotation) -> cst.Annotation: node = annotation.annotation if isinstance(node, cst.SimpleString): + self.annotation_names.add(_get_string_value(node)) return annotation elif isinstance(node, cst.Subscript): return cst.Annotation(annotation=self._handle_Subscript(node)) @@ -309,6 +354,7 @@ class Annotations: ) attribute_annotations: Dict[str, cst.Annotation] = field(default_factory=dict) class_definitions: Dict[str, cst.ClassDef] = field(default_factory=dict) + typevars: Dict[str, cst.Assign] = field(default_factory=dict) @dataclass @@ -318,6 +364,7 @@ class AnnotationCounts: parameter_annotations: int = 0 return_annotations: int = 0 classes_added: int = 0 + typevars_and_generics_added: int = 0 def any_changes_applied(self) -> bool: return ( @@ -326,6 +373,7 @@ def any_changes_applied(self) -> bool: + self.parameter_annotations + self.return_annotations + self.classes_added + + self.typevars_and_generics_added ) > 0 @@ -397,6 +445,10 @@ def __init__( # only made changes to the imports. self.annotation_counts: AnnotationCounts = AnnotationCounts() + # We use this to collect typevars, to avoid importing existing ones from the pyi file + self.current_assign: Optional[cst.Assign] = None + self.typevars: Dict[str, cst.Assign] = {} + @staticmethod def store_stub_in_context( context: CodemodContext, @@ -463,6 +515,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: self.annotations.function_annotations.update(visitor.function_annotations) self.annotations.attribute_annotations.update(visitor.attribute_annotations) self.annotations.class_definitions.update(visitor.class_definitions) + self.annotations.typevars.update(visitor.typevars) tree_with_imports = AddImportsVisitor( context=self.context, @@ -722,7 +775,16 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: def leave_ClassDef( self, original_node: cst.ClassDef, updated_node: cst.ClassDef ) -> cst.ClassDef: + cls_name = ".".join(self.qualifier) self.qualifier.pop() + definition = self.annotations.class_definitions.get(cls_name) + if definition: + b1 = _find_generic_base(definition) + b2 = _find_generic_base(updated_node) + if b1 and not b2: + new_bases = list(updated_node.bases) + [b1] + self.annotation_counts.typevars_and_generics_added += 1 + return updated_node.with_changes(bases=new_bases) return updated_node def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: @@ -756,10 +818,29 @@ def leave_FunctionDef( return updated_node.with_changes(params=new_parameters) return updated_node + def visit_Assign(self, node: cst.Assign) -> None: + self.current_assign = node + + @m.call_if_inside(m.Assign()) + @m.visit(m.Call(func=m.Name("TypeVar"))) + def record_typevar(self, node: cst.Call) -> None: + # pyre-ignore current_assign is never None here + name = get_full_name_for_node(self.current_assign.targets[0].target) + if name: + # Preserve the whole node, even though we currently just use the + # name, so that we can match bounds and variance at some point and + # determine if two typevars with the same name are indeed the same. + + # pyre-ignore current_assign is never None here + self.typevars[name] = self.current_assign + self.current_assign = None + def leave_Assign( self, original_node: cst.Assign, updated_node: cst.Assign ) -> Union[cst.Assign, cst.AnnAssign]: + self.current_assign = None + if len(original_node.targets) > 1: for assign in original_node.targets: target = assign.target @@ -787,8 +868,17 @@ def leave_Module( for name, definition in self.annotations.class_definitions.items() if name not in self.visited_classes ] - if not self.toplevel_annotations and not fresh_class_definitions: + + # NOTE: The entire change will also be abandoned if + # self.annotation_counts is all 0s, so if adding any new category make + # sure to record it there. + if not ( + self.toplevel_annotations + or fresh_class_definitions + or self.annotations.typevars + ): return updated_node + toplevel_statements = [] # First, find the insertion point for imports statements_before_imports, statements_after_imports = self._split_module( @@ -806,6 +896,18 @@ def leave_Module( ) toplevel_statements.append(cst.SimpleStatementLine([annotated_assign])) + # TypeVar definitions could be scattered through the file, so do not + # attempt to put new ones with existing ones, just add them at the top. + typevars = { + k: v for k, v in self.annotations.typevars.items() if k not in self.typevars + } + if typevars: + for var, stmt in typevars.items(): + toplevel_statements.append(cst.Newline()) + toplevel_statements.append(stmt) + self.annotation_counts.typevars_and_generics_added += 1 + toplevel_statements.append(cst.Newline()) + self.annotation_counts.classes_added = len(fresh_class_definitions) toplevel_statements.extend(fresh_class_definitions) diff --git a/libcst/codemod/visitors/tests/test_apply_type_annotations.py b/libcst/codemod/visitors/tests/test_apply_type_annotations.py index 150e996aa..8689afbd2 100644 --- a/libcst/codemod/visitors/tests/test_apply_type_annotations.py +++ b/libcst/codemod/visitors/tests/test_apply_type_annotations.py @@ -821,6 +821,120 @@ def foo() -> None: def test_adding_typed_dicts(self, stub: str, before: str, after: str) -> None: self.run_simple_test_case(stub=stub, before=before, after=after) + @data_provider( + { + "insert_new_TypeVar_not_in_source_file": ( + """ + from typing import Dict, TypeVar + + _KT = TypeVar('_KT') + _VT = TypeVar('_VT') + + class UserDict(Dict[_KT, _VT]): + def __init__(self, initialdata: Dict[_KT, _VT] = ...): ... + """, + """ + class UserDict: + def __init__(self, initialdata = None): + pass + """, + """ + from typing import Dict, TypeVar + + _KT = TypeVar('_KT') + _VT = TypeVar('_VT') + + class UserDict: + def __init__(self, initialdata: Dict[_KT, _VT] = None): + pass + """, + ), + "insert_only_used_TypeVar_not_already_in_source": ( + """ + from typing import Dict, TypeVar + + K = TypeVar('K') + V = TypeVar('V') + X = TypeVar('X') + + class UserDict(Dict[K, V]): + def __init__(self, initialdata: Dict[K, V] = ...): ... + """, + """ + from typing import TypeVar + + V = TypeVar('V') + + class UserDict: + def __init__(self, initialdata = None): + pass + + def f(x: V) -> V: + pass + """, + """ + from typing import Dict, TypeVar + + K = TypeVar('K') + + V = TypeVar('V') + + class UserDict: + def __init__(self, initialdata: Dict[K, V] = None): + pass + + def f(x: V) -> V: + pass + """, + ), + "insert_Generic_base_class": ( + """ + from typing import TypeVar + + T = TypeVar('T') + X = TypeVar('X') + + class B(A, Generic[T]): + def f(self, x: T) -> T: ... + """, + """ + from typing import TypeVar + + V = TypeVar('V') + + def f(x: V) -> V: + pass + + class A: + pass + + class B(A): + def f(self, x): + pass + """, + """ + from typing import TypeVar + + T = TypeVar('T') + + V = TypeVar('V') + + def f(x: V) -> V: + pass + + class A: + pass + + class B(A, Generic[T]): + def f(self, x: T) -> T: + pass + """, + ), + } + ) + def test_adding_typevars(self, stub: str, before: str, after: str) -> None: + self.run_simple_test_case(stub=stub, before=before, after=after) + @data_provider( { "required_positional_only_args": (