Skip to content

Commit

Permalink
Merge in TypeVars and Generic base classes in ApplyTypeAnnotationVisi…
Browse files Browse the repository at this point in the history
…tor (#596)

* Tracks TypeVars that are used in type annotations in the pyi file, and
  adds their Assign statements to the merged file.
* Adds Generic[T] as a base class if needed.
  • Loading branch information
martindemello authored Jan 14, 2022
1 parent 5f22b6c commit e03ed43
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 3 deletions.
108 changes: 105 additions & 3 deletions libcst/codemod/visitors/_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union

import libcst as cst
import libcst.matchers as m

from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._add_imports import AddImportsVisitor
Expand Down Expand Up @@ -55,6 +57,18 @@ def _is_set(x: Union[None, cst.CSTNode, cst.MaybeSentinel]) -> bool:
return x is not None and x != cst.MaybeSentinel.DEFAULT


def _get_string_value(node: cst.SimpleString) -> str:
s = node.value
c = s[-1]
return s[s.index(c) : -1]


def _find_generic_base(node: cst.ClassDef) -> Optional[cst.Arg]:
for b in node.bases:
if m.matches(b.value, m.Subscript(value=m.Name("Generic"))):
return b


@dataclass(frozen=True)
class FunctionKey:
name: str
Expand All @@ -80,7 +94,7 @@ class FunctionAnnotation:
returns: Optional[cst.Annotation]


class TypeCollector(cst.CSTVisitor):
class TypeCollector(m.MatcherDecoratableVisitor):
"""
Collect type annotations from a stub module.
"""
Expand All @@ -91,6 +105,7 @@ class TypeCollector(cst.CSTVisitor):
)

def __init__(self, existing_imports: Set[str], context: CodemodContext) -> None:
super().__init__()
# Qualifier for storing the canonical name of the current function.
self.qualifier: List[str] = []
# Store the annotations.
Expand All @@ -99,6 +114,9 @@ def __init__(self, existing_imports: Set[str], context: CodemodContext) -> None:
self.existing_imports: Set[str] = existing_imports
self.class_definitions: Dict[str, cst.ClassDef] = {}
self.context = context
self.current_assign: Optional[cst.Assign] = None # used to collect typevars
self.typevars: Dict[str, cst.Assign] = {}
self.annotation_names: Set[str] = set()

def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.qualifier.append(node.name.value)
Expand Down Expand Up @@ -153,6 +171,29 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
def leave_AnnAssign(self, original_node: cst.AnnAssign) -> None:
self.qualifier.pop()

def visit_Assign(self, node: cst.Assign) -> None:
self.current_assign = node

def leave_Assign(self, original_node: cst.Assign) -> None:
self.current_assign = None

@m.call_if_inside(m.Assign())
@m.visit(m.Call(func=m.Name("TypeVar")))
def record_typevar(self, node: cst.Call) -> None:
# pyre-ignore current_assign is never None here
name = get_full_name_for_node(self.current_assign.targets[0].target)
if name:
# pyre-ignore current_assign is never None here
self.typevars[name] = self.current_assign
self._handle_qualification_and_should_qualify("typing.TypeVar")
self.current_assign = None

def leave_Module(self, original_node: cst.Module) -> None:
# Filter out unused typevars
self.typevars = {
k: v for k, v in self.typevars.items() if k in self.annotation_names
}

def _get_unique_qualified_name(self, node: cst.CSTNode) -> str:
name = None
names = [q.name for q in self.get_metadata(QualifiedNameProvider, node)]
Expand Down Expand Up @@ -194,7 +235,7 @@ def _module_and_target(self, qualified_name: str) -> Tuple[str, str]:

def _handle_qualification_and_should_qualify(self, qualified_name: str) -> bool:
"""
Basd on a qualified name and the existing module imports, record that
Based on a qualified name and the existing module imports, record that
we need to add an import if necessary and return whether or not we
should use the qualified name due to a preexisting import.
"""
Expand Down Expand Up @@ -227,6 +268,7 @@ def _handle_NameOrAttribute(
dequalified_node,
) = self._get_qualified_name_and_dequalified_node(node)
should_qualify = self._handle_qualification_and_should_qualify(qualified_name)
self.annotation_names.add(qualified_name)
if should_qualify:
return node
else:
Expand All @@ -239,6 +281,8 @@ def _handle_Index(self, slice: cst.Index) -> cst.Index:
elif isinstance(value, cst.Attribute):
return slice.with_changes(value=self._handle_NameOrAttribute(value))
else:
if isinstance(value, cst.SimpleString):
self.annotation_names.add(_get_string_value(value))
return slice

def _handle_Subscript(self, node: cst.Subscript) -> cst.Subscript:
Expand Down Expand Up @@ -279,6 +323,7 @@ def _handle_Subscript(self, node: cst.Subscript) -> cst.Subscript:
def _handle_Annotation(self, annotation: cst.Annotation) -> cst.Annotation:
node = annotation.annotation
if isinstance(node, cst.SimpleString):
self.annotation_names.add(_get_string_value(node))
return annotation
elif isinstance(node, cst.Subscript):
return cst.Annotation(annotation=self._handle_Subscript(node))
Expand Down Expand Up @@ -309,6 +354,7 @@ class Annotations:
)
attribute_annotations: Dict[str, cst.Annotation] = field(default_factory=dict)
class_definitions: Dict[str, cst.ClassDef] = field(default_factory=dict)
typevars: Dict[str, cst.Assign] = field(default_factory=dict)


@dataclass
Expand All @@ -318,6 +364,7 @@ class AnnotationCounts:
parameter_annotations: int = 0
return_annotations: int = 0
classes_added: int = 0
typevars_and_generics_added: int = 0

def any_changes_applied(self) -> bool:
return (
Expand All @@ -326,6 +373,7 @@ def any_changes_applied(self) -> bool:
+ self.parameter_annotations
+ self.return_annotations
+ self.classes_added
+ self.typevars_and_generics_added
) > 0


Expand Down Expand Up @@ -397,6 +445,10 @@ def __init__(
# only made changes to the imports.
self.annotation_counts: AnnotationCounts = AnnotationCounts()

# We use this to collect typevars, to avoid importing existing ones from the pyi file
self.current_assign: Optional[cst.Assign] = None
self.typevars: Dict[str, cst.Assign] = {}

@staticmethod
def store_stub_in_context(
context: CodemodContext,
Expand Down Expand Up @@ -463,6 +515,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:
self.annotations.function_annotations.update(visitor.function_annotations)
self.annotations.attribute_annotations.update(visitor.attribute_annotations)
self.annotations.class_definitions.update(visitor.class_definitions)
self.annotations.typevars.update(visitor.typevars)

tree_with_imports = AddImportsVisitor(
context=self.context,
Expand Down Expand Up @@ -722,7 +775,16 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
def leave_ClassDef(
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
) -> cst.ClassDef:
cls_name = ".".join(self.qualifier)
self.qualifier.pop()
definition = self.annotations.class_definitions.get(cls_name)
if definition:
b1 = _find_generic_base(definition)
b2 = _find_generic_base(updated_node)
if b1 and not b2:
new_bases = list(updated_node.bases) + [b1]
self.annotation_counts.typevars_and_generics_added += 1
return updated_node.with_changes(bases=new_bases)
return updated_node

def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
Expand Down Expand Up @@ -756,10 +818,29 @@ def leave_FunctionDef(
return updated_node.with_changes(params=new_parameters)
return updated_node

def visit_Assign(self, node: cst.Assign) -> None:
self.current_assign = node

@m.call_if_inside(m.Assign())
@m.visit(m.Call(func=m.Name("TypeVar")))
def record_typevar(self, node: cst.Call) -> None:
# pyre-ignore current_assign is never None here
name = get_full_name_for_node(self.current_assign.targets[0].target)
if name:
# Preserve the whole node, even though we currently just use the
# name, so that we can match bounds and variance at some point and
# determine if two typevars with the same name are indeed the same.

# pyre-ignore current_assign is never None here
self.typevars[name] = self.current_assign
self.current_assign = None

def leave_Assign(
self, original_node: cst.Assign, updated_node: cst.Assign
) -> Union[cst.Assign, cst.AnnAssign]:

self.current_assign = None

if len(original_node.targets) > 1:
for assign in original_node.targets:
target = assign.target
Expand Down Expand Up @@ -787,8 +868,17 @@ def leave_Module(
for name, definition in self.annotations.class_definitions.items()
if name not in self.visited_classes
]
if not self.toplevel_annotations and not fresh_class_definitions:

# NOTE: The entire change will also be abandoned if
# self.annotation_counts is all 0s, so if adding any new category make
# sure to record it there.
if not (
self.toplevel_annotations
or fresh_class_definitions
or self.annotations.typevars
):
return updated_node

toplevel_statements = []
# First, find the insertion point for imports
statements_before_imports, statements_after_imports = self._split_module(
Expand All @@ -806,6 +896,18 @@ def leave_Module(
)
toplevel_statements.append(cst.SimpleStatementLine([annotated_assign]))

# TypeVar definitions could be scattered through the file, so do not
# attempt to put new ones with existing ones, just add them at the top.
typevars = {
k: v for k, v in self.annotations.typevars.items() if k not in self.typevars
}
if typevars:
for var, stmt in typevars.items():
toplevel_statements.append(cst.Newline())
toplevel_statements.append(stmt)
self.annotation_counts.typevars_and_generics_added += 1
toplevel_statements.append(cst.Newline())

self.annotation_counts.classes_added = len(fresh_class_definitions)
toplevel_statements.extend(fresh_class_definitions)

Expand Down
114 changes: 114 additions & 0 deletions libcst/codemod/visitors/tests/test_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,120 @@ def foo() -> None:
def test_adding_typed_dicts(self, stub: str, before: str, after: str) -> None:
self.run_simple_test_case(stub=stub, before=before, after=after)

@data_provider(
{
"insert_new_TypeVar_not_in_source_file": (
"""
from typing import Dict, TypeVar
_KT = TypeVar('_KT')
_VT = TypeVar('_VT')
class UserDict(Dict[_KT, _VT]):
def __init__(self, initialdata: Dict[_KT, _VT] = ...): ...
""",
"""
class UserDict:
def __init__(self, initialdata = None):
pass
""",
"""
from typing import Dict, TypeVar
_KT = TypeVar('_KT')
_VT = TypeVar('_VT')
class UserDict:
def __init__(self, initialdata: Dict[_KT, _VT] = None):
pass
""",
),
"insert_only_used_TypeVar_not_already_in_source": (
"""
from typing import Dict, TypeVar
K = TypeVar('K')
V = TypeVar('V')
X = TypeVar('X')
class UserDict(Dict[K, V]):
def __init__(self, initialdata: Dict[K, V] = ...): ...
""",
"""
from typing import TypeVar
V = TypeVar('V')
class UserDict:
def __init__(self, initialdata = None):
pass
def f(x: V) -> V:
pass
""",
"""
from typing import Dict, TypeVar
K = TypeVar('K')
V = TypeVar('V')
class UserDict:
def __init__(self, initialdata: Dict[K, V] = None):
pass
def f(x: V) -> V:
pass
""",
),
"insert_Generic_base_class": (
"""
from typing import TypeVar
T = TypeVar('T')
X = TypeVar('X')
class B(A, Generic[T]):
def f(self, x: T) -> T: ...
""",
"""
from typing import TypeVar
V = TypeVar('V')
def f(x: V) -> V:
pass
class A:
pass
class B(A):
def f(self, x):
pass
""",
"""
from typing import TypeVar
T = TypeVar('T')
V = TypeVar('V')
def f(x: V) -> V:
pass
class A:
pass
class B(A, Generic[T]):
def f(self, x: T) -> T:
pass
""",
),
}
)
def test_adding_typevars(self, stub: str, before: str, after: str) -> None:
self.run_simple_test_case(stub=stub, before=before, after=after)

@data_provider(
{
"required_positional_only_args": (
Expand Down

0 comments on commit e03ed43

Please sign in to comment.