Skip to content

Commit

Permalink
Use QualifiedNameProvider to handle stub types (#536)
Browse files Browse the repository at this point in the history
The existing TypeCollector visitor logic attempted to
fold actual imports from stubs together with the module
we were annotating, and separately do nice things with the
names of types so that we could parse stubs written either
with various sorts of proper imports *or* stubs written
using bare fully-qualified type names (which isn't
actually legal python, but is easy to produce from automated
tools like `pyre infer`).

In this commit I simplify things in principle - meaning the
data flow is simpler, although the code is still similarly
complex - by using `QualifiedNameProvider` plus a fallback
to `get_full_name_for_node` to handle all cases via
fully-qualified names, so that the way a stub chooses to
lay out its imports is no longer relevant to how we will
understand it.

As a result, we can scrap a whole test suite where we
were understanding edge cases in the import handling, and
moreover one of the weird unsupported edge cases is now
well supported.

The tests got simpler because some edge cases no longer
matter (the whole imports test is no longer relevant),
and a couple of weird edge cases were fixed.

I ran tests with
```
python -m unittest libcst.codemod.visitors.tests.test_apply_type_annotations.TestApplyAnnotationsVisitor
```

I tried to make this change minimal in that I preserve the
existing data flow, so that it's easy to review. But it's worth
considering whether to follow up with a diff where we change
the TypeAnnotationCollector into a *transform* rather than a
*visitor*, because that would allow us to scrap quite a bit
of logic - all we would need to know is a couple of bits
of context from higher up in the tree and we could process
Names and Attributes without needing all this recursion.
  • Loading branch information
stroxler authored Oct 28, 2021
1 parent 1f169b8 commit 3743c70
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 107 deletions.
190 changes: 124 additions & 66 deletions libcst/codemod/visitors/_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union

import libcst as cst
from libcst import matchers as m
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._add_imports import AddImportsVisitor
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.helpers import get_full_name_for_node
from libcst.metadata import QualifiedNameProvider, PositionProvider


NameOrAttribute = Union[cst.Name, cst.Attribute]
NAME_OR_ATTRIBUTE = (cst.Name, cst.Attribute)


def _get_import_alias_names(import_aliases: Sequence[cst.ImportAlias]) -> Set[str]:
Expand Down Expand Up @@ -50,6 +54,11 @@ class TypeCollector(cst.CSTVisitor):
Collect type annotations from a stub module.
"""

METADATA_DEPENDENCIES = (
PositionProvider,
QualifiedNameProvider,
)

def __init__(self, existing_imports: Set[str], context: CodemodContext) -> None:
# Qualifier for storing the canonical name of the current function.
self.qualifier: List[str] = []
Expand All @@ -62,7 +71,23 @@ def __init__(self, existing_imports: Set[str], context: CodemodContext) -> None:

def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.qualifier.append(node.name.value)
self.class_definitions[node.name.value] = node
new_bases = []
for base in node.bases:
value = base.value
if isinstance(value, NAME_OR_ATTRIBUTE):
new_value = self._handle_NameOrAttribute(value)
elif isinstance(base.value, cst.Subscript):
new_value = self._handle_Subscript(value)
else:
start = self.get_metadata(PositionProvider, node).start
raise ValueError(
"Invalid type used as base class in stub file at "
+ f"{start.line}:{start.column}. Only subscripts, names, and "
+ "attributes are valid base classes for static typing."
)
new_bases.append(base.with_changes(value=new_value))

self.class_definitions[node.name.value] = node.with_changes(bases=new_bases)

def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self.qualifier.pop()
Expand All @@ -71,11 +96,9 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
self.qualifier.append(node.name.value)
returns = node.returns
return_annotation = (
self._create_import_from_annotation(returns)
if returns is not None
else None
self._handle_Annotation(annotation=returns) if returns is not None else None
)
parameter_annotations = self._import_parameter_annotations(node.params)
parameter_annotations = self._handle_Parameters(node.params)
self.function_annotations[".".join(self.qualifier)] = FunctionAnnotation(
parameters=parameter_annotations, returns=return_annotation
)
Expand All @@ -90,102 +113,137 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
name = get_full_name_for_node(node.target)
if name is not None:
self.qualifier.append(name)
annotation_value = self._create_import_from_annotation(node.annotation)
annotation_value = self._handle_Annotation(annotation=node.annotation)
self.attribute_annotations[".".join(self.qualifier)] = annotation_value
return True

def leave_AnnAssign(self, original_node: cst.AnnAssign) -> None:
self.qualifier.pop()

def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
module = node.module
names = node.names

# module is None for relative imports like `from .. import foo`.
# We ignore these for now.
if module is None or isinstance(names, cst.ImportStar):
return
module_name = get_full_name_for_node(module)
if module_name is not None:
for import_name in _get_import_alias_names(names):
AddImportsVisitor.add_needed_import(
self.context, module_name, import_name
)
def _get_unique_qualified_name(self, node: cst.CSTNode) -> str:
names = [q.name for q in self.get_metadata(QualifiedNameProvider, node)]
if len(names) == 0:
# we hit this branch if the stub is directly using a fully
# qualified name, which is not technically valid python but is
# convenient to allow.
return get_full_name_for_node(node)
elif len(names) == 1:
return names[0]
else:
start = self.get_metadata(PositionProvider, node).start
raise ValueError(
"Could not resolve a unique qualified name for type "
+ f"{get_full_name_for_node(node)} at {start.line}:{start.column}. "
+ f"Candidate names were: {names!r}"
)

def _add_annotation_to_imports(
self, annotation: cst.Attribute
def _get_qualified_name_and_dequalified_node(
self,
node: Union[cst.Name, cst.Attribute],
) -> Tuple[str, Union[cst.Name, cst.Attribute]]:
qualified_name = self._get_unique_qualified_name(node)
dequalified_node = node.attr if isinstance(node, cst.Attribute) else node
return qualified_name, dequalified_node

def _handle_qualification_and_should_qualify(self, qualified_name: str) -> bool:
"""
Basd on a qualified name and the existing module imports, record that
we need to add an import if necessary and return whether or not we
should use the qualified name due to a preexisting import.
"""
split_name = qualified_name.split(".")
if len(split_name) > 1 and qualified_name not in self.existing_imports:
module, target = ".".join(split_name[:-1]), split_name[-1]
if module == "builtins":
return False
elif module in self.existing_imports:
return True
else:
AddImportsVisitor.add_needed_import(self.context, module, target)
return False

# Handler functions.
#
# Each of these does one of two things, possibly recursively, over some
# valid CST node for a static type:
# - process the qualified name and ensure we will add necessary imports
# - dequalify the node

def _handle_NameOrAttribute(
self,
node: NameOrAttribute,
) -> Union[cst.Name, cst.Attribute]:
key = get_full_name_for_node(annotation.value)
if key is not None:
# Don't attempt to re-import existing imports.
if key in self.existing_imports:
return annotation
import_name = get_full_name_for_node(annotation.attr)
if import_name is not None:
AddImportsVisitor.add_needed_import(self.context, key, import_name)
return annotation.attr
(
qualified_name,
dequalified_node,
) = self._get_qualified_name_and_dequalified_node(node)
should_qualify = self._handle_qualification_and_should_qualify(qualified_name)
if should_qualify:
return node
else:
return dequalified_node

def _handle_Index(self, slice: cst.Index, node: cst.Subscript) -> cst.Subscript:
value = slice.value
if isinstance(value, cst.Subscript):
new_slice = slice.with_changes(value=self._handle_Subscript(value))
return node.with_changes(slice=new_slice)
return slice.with_changes(value=self._handle_Subscript(value))
elif isinstance(value, cst.Attribute):
new_slice = slice.with_changes(value=self._add_annotation_to_imports(value))
return node.with_changes(slice=new_slice)
return slice.with_changes(value=self._handle_NameOrAttribute(value))
else:
return node
return slice

def _handle_Subscript(self, node: cst.Subscript) -> cst.Subscript:
value = node.value
if isinstance(value, NAME_OR_ATTRIBUTE):
new_node = node.with_changes(value=self._handle_NameOrAttribute(value))
else:
raise ValueError("Expected any indexed type to have")
if self._get_unique_qualified_name(node) in ("Type", "typing.Type"):
# Note: we are intentionally not handling qualification of
# anything inside `Type` because it's common to have nested
# classes, which we cannot currently distinguish from classes
# coming from other modules, appear here.
return new_node
slice = node.slice
if m.matches(node.value, m.Name(value="Type")):
return node
if isinstance(slice, list):
if isinstance(slice, tuple):
new_slice = []
for item in slice:
value = item.slice.value
if isinstance(value, cst.Attribute):
name = self._add_annotation_to_imports(item.slice.value)
if isinstance(value, NAME_OR_ATTRIBUTE):
name = self._handle_NameOrAttribute(item.slice.value)
new_index = item.slice.with_changes(value=name)
new_slice.append(item.with_changes(slice=new_index))
else:
if isinstance(item.slice, cst.Index) and not isinstance(
item.slice.value, cst.Name
):
if isinstance(item.slice, cst.Index):
new_index = item.slice.with_changes(
value=self._handle_Index(item.slice, item)
)
item = item.with_changes(slice=new_index, comma=None)
item = item.with_changes(slice=new_index)
new_slice.append(item)
return node.with_changes(slice=new_slice)
return new_node.with_changes(slice=tuple(new_slice))
elif isinstance(slice, cst.Index):
return self._handle_Index(slice, node)
new_slice = self._handle_Index(slice)
return new_node.with_changes(slice=new_slice)
else:
return node

def _create_import_from_annotation(self, returns: cst.Annotation) -> cst.Annotation:
annotation = returns.annotation
if isinstance(annotation, cst.Attribute):
attr = self._add_annotation_to_imports(annotation)
return cst.Annotation(annotation=attr)
if isinstance(annotation, cst.Subscript):
value = annotation.value
if m.matches(value, m.Name(value="Type")):
return returns
return cst.Annotation(annotation=self._handle_Subscript(annotation))
return new_node

def _handle_Annotation(self, annotation: cst.Annotation) -> cst.Annotation:
node = annotation.annotation
if isinstance(node, cst.SimpleString):
return annotation
elif isinstance(node, cst.Subscript):
return cst.Annotation(annotation=self._handle_Subscript(node))
else:
return returns
return cst.Annotation(annotation=self._handle_NameOrAttribute(node))

def _import_parameter_annotations(
self, parameters: cst.Parameters
) -> cst.Parameters:
def _handle_Parameters(self, parameters: cst.Parameters) -> cst.Parameters:
def update_annotations(parameters: Sequence[cst.Param]) -> List[cst.Param]:
updated_parameters = []
for parameter in list(parameters):
annotation = parameter.annotation
if annotation is not None:
parameter = parameter.with_changes(
annotation=self._create_import_from_annotation(annotation)
annotation=self._handle_Annotation(annotation=annotation)
)
updated_parameters.append(parameter)
return updated_parameters
Expand Down Expand Up @@ -321,7 +379,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:
self.overwrite_existing_annotations or overwrite_existing_annotations
)
visitor = TypeCollector(existing_import_names, self.context)
stub.visit(visitor)
cst.MetadataWrapper(stub).visit(visitor)
self.annotations.function_annotations.update(visitor.function_annotations)
self.annotations.attribute_annotations.update(visitor.attribute_annotations)
self.annotations.class_definitions.update(visitor.class_definitions)
Expand Down
45 changes: 4 additions & 41 deletions libcst/codemod/visitors/tests/test_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,45 +56,6 @@ def run_test_case_with_flags(
)
self.assertCodemod(before, after, context_override=context)

@data_provider(
{
"supported_cases": (
"""
from __future__ import annotations
from foo import Foo
from baz import Baz
""",
"""
from foo import Bar
import bar
""",
"""
from __future__ import annotations
from foo import Foo, Bar
import bar
from baz import Baz
""",
),
"unsupported_cases": (
"""
from Foo import foo as bar
import foo
from .. import baz
from boo import *
""",
"""
""",
# This is a bug, it would be better to just ignor aliased
# imports than to add them incorrectly.
"""
from Foo import bar
""",
),
}
)
def test_merge_module_imports(self, stub: str, before: str, after: str) -> None:
self.run_simple_test_case(stub=stub, before=before, after=after)

@data_provider(
{
"simple": (
Expand Down Expand Up @@ -361,7 +322,7 @@ def foo(x: int) -> Optional[Example]:
pass
""",
),
"UNSUPPORTED_add_imports_for_generics": (
"add_imports_for_generics": (
"""
def foo(x: int) -> typing.Optional[Example]: ...
""",
Expand All @@ -370,7 +331,9 @@ def foo(x: int):
pass
""",
"""
def foo(x: int) -> typing.Optional[Example]:
from typing import Optional
def foo(x: int) -> Optional[Example]:
pass
""",
),
Expand Down

0 comments on commit 3743c70

Please sign in to comment.