Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[apply-annotations] Add argument for ignoring existing annotations. #291

Merged
merged 1 commit into from
May 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 42 additions & 13 deletions libcst/codemod/visitors/_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,15 @@ class ApplyTypeAnnotationsVisitor(ContextAwareTransformer):
This is one of the transforms that is available automatically to you when
running a codemod. To use it in this manner, import
:class:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor` and then call the static
:meth:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor.add_stub_to_context` method,
:meth:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor.store_stub_in_context` method,
giving it the current context (found as ``self.context`` for all subclasses of
:class:`~libcst.codemod.Codemod`), the stub module from which you wish to add annotations.

For example, you can store the type annotation ``int`` for ``x`` using::

stub_module = parse_module("x: int = ...")

ApplyTypeAnnotationsVisitor.add_stub_to_context(self.context, stub_module)
ApplyTypeAnnotationsVisitor.store_stub_in_context(self.context, stub_module)

You can apply the type annotation using::

Expand All @@ -223,33 +223,52 @@ class ApplyTypeAnnotationsVisitor(ContextAwareTransformer):
x: int = 1

If the function or attribute already has a type annotation, it will not be overwritten.

To overwrite existing annotations when applying annotations from a stub,
use the keyword argument ``overwrite_existing_annotations=True`` when
constructing the codemod or when calling ``store_stub_in_context``.
pradeep90 marked this conversation as resolved.
Show resolved Hide resolved
"""

CONTEXT_KEY = "ApplyTypeAnnotationsVisitor"

def __init__(
self, context: CodemodContext, annotations: Optional[Annotations] = None
self,
context: CodemodContext,
annotations: Optional[Annotations] = None,
overwrite_existing_annotations: bool = False,
) -> None:
super().__init__(context)
# Qualifier for storing the canonical name of the current function.
self.qualifier: List[str] = []
self.annotations: Annotations = annotations or Annotations()
self.toplevel_annotations: Dict[str, cst.Annotation] = {}
self.visited_classes: Set[str] = set()
self.overwrite_existing_annotations = overwrite_existing_annotations

# We use this to determine the end of the import block so that we can
# insert top-level annotations.
self.import_statements: List[cst.ImportFrom] = []

@staticmethod
def add_stub_to_context(context: CodemodContext, stub: cst.Module) -> None:
def store_stub_in_context(
context: CodemodContext,
stub: cst.Module,
overwrite_existing_annotations: bool = False,
) -> None:
"""
Add a stub module to the :class:`~libcst.codemod.CodemodContext` so
Store a stub module in the :class:`~libcst.codemod.CodemodContext` so
that type annotations from the stub can be applied in a later
invocation of this class.

If the ``overwrite_existing_annotations`` flag is ``True``, the
codemod will overwrite any existing annotations.

If you call this function multiple times, only the last values of
``stub`` and ``overwrite_existing_annotations`` will take effect.
"""
context.scratch.setdefault(ApplyTypeAnnotationsVisitor.CONTEXT_KEY, []).append(
stub
context.scratch[ApplyTypeAnnotationsVisitor.CONTEXT_KEY] = (
stub,
overwrite_existing_annotations,
)

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
Expand All @@ -262,8 +281,14 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:
tree.visit(import_gatherer)
existing_import_names = _get_import_names(import_gatherer.all_imports)

stubs = self.context.scratch.get(ApplyTypeAnnotationsVisitor.CONTEXT_KEY, [])
for stub in stubs:
context_contents = self.context.scratch.get(
ApplyTypeAnnotationsVisitor.CONTEXT_KEY
)
if context_contents:
stub, overwrite_existing_annotations = context_contents
self.overwrite_existing_annotations = (
self.overwrite_existing_annotations or overwrite_existing_annotations
)
visitor = TypeCollector(existing_import_names, self.context)
stub.visit(visitor)
self.annotations.function_annotations.update(visitor.function_annotations)
Expand Down Expand Up @@ -339,7 +364,8 @@ def _update_parameters(
self, annotations: FunctionAnnotation, updated_node: cst.FunctionDef
) -> cst.Parameters:
# Update params and default params with annotations
# don't override existing annotations or default values
# Don't override existing annotations or default values unless asked
# to overwrite existing annotations.
def update_annotation(
parameters: Sequence[cst.Param], annotations: Sequence[cst.Param]
) -> List[cst.Param]:
Expand All @@ -350,7 +376,9 @@ def update_annotation(
parameter_annotations[parameter.name.value] = parameter.annotation
for parameter in parameters:
key = parameter.name.value
if key in parameter_annotations and not parameter.annotation:
if key in parameter_annotations and (
self.overwrite_existing_annotations or not parameter.annotation
):
parameter = parameter.with_changes(
annotation=parameter_annotations[key]
)
Expand Down Expand Up @@ -409,8 +437,9 @@ def leave_FunctionDef(
self.qualifier.pop()
if key in self.annotations.function_annotations:
function_annotation = self.annotations.function_annotations[key]
# Only add new annotation if one doesn't already exist
if not updated_node.returns:
# Only add new annotation if explicitly told to overwrite existing
# annotations or if one doesn't already exist.
if self.overwrite_existing_annotations or not updated_node.returns:
updated_node = updated_node.with_changes(
returns=function_annotation.returns
)
Expand Down
40 changes: 39 additions & 1 deletion libcst/codemod/visitors/tests/test_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,45 @@ def foo() -> typing.Sequence[int]:
)
def test_annotate_functions(self, stub: str, before: str, after: str) -> None:
context = CodemodContext()
ApplyTypeAnnotationsVisitor.add_stub_to_context(
ApplyTypeAnnotationsVisitor.store_stub_in_context(
context, parse_module(textwrap.dedent(stub.rstrip()))
)
self.assertCodemod(before, after, context_override=context)

@data_provider(
(
(
"""
def fully_annotated_with_different_stub(a: bool, b: bool) -> str: ...
""",
"""
def fully_annotated_with_different_stub(a: int, b: str) -> bool:
return 'hello'
""",
"""
def fully_annotated_with_different_stub(a: bool, b: bool) -> str:
return 'hello'
""",
),
)
)
def test_annotate_functions_with_existing_annotations(
self, stub: str, before: str, after: str
) -> None:
context = CodemodContext()
ApplyTypeAnnotationsVisitor.store_stub_in_context(
context, parse_module(textwrap.dedent(stub.rstrip()))
)
# Test setting the overwrite flag on the codemod instance.
self.assertCodemod(
before, after, context_override=context, overwrite_existing_annotations=True
)

# Test setting the flag when storing the stub in the context.
context = CodemodContext()
ApplyTypeAnnotationsVisitor.store_stub_in_context(
context,
parse_module(textwrap.dedent(stub.rstrip())),
overwrite_existing_annotations=True,
)
self.assertCodemod(before, after, context_override=context)