Skip to content

Commit

Permalink
Merge pull request #373 from Instagram/annotation-access
Browse files Browse the repository at this point in the history
Handle string annotations in ScopeProvider
  • Loading branch information
Kronuz authored Aug 14, 2020
2 parents eda30f4 + fa15c98 commit c935fcb
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 6 deletions.
39 changes: 38 additions & 1 deletion libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,22 +699,59 @@ def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:

def visit_Call(self, node: cst.Call) -> Optional[bool]:
self.__top_level_attribute_stack.append(None)
if any(
qn.name == "typing.TypeVar"
for qn in self.scope.get_qualified_names_for(node)
):
node.func.visit(self)
self.__in_annotation.add(node)
for arg in node.args[1:]:
arg.visit(self)
return False
return True

def leave_Call(self, original_node: cst.Call) -> None:
self.__top_level_attribute_stack.pop()
self.__in_annotation.discard(original_node)

def visit_Annotation(self, node: cst.Annotation) -> Optional[bool]:
self.__in_annotation.add(node)

def leave_Annotation(self, original_node: cst.Annotation) -> None:
self.__in_annotation.discard(original_node)

def visit_SimpleString(self, node: cst.SimpleString) -> Optional[bool]:
self._handle_string_annotation(node)
return False

def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> Optional[bool]:
self._handle_string_annotation(node)
return False

def _handle_string_annotation(
self, node: Union[cst.SimpleString, cst.ConcatenatedString]
) -> None:
if self.__in_annotation:
value = node.evaluated_value
if value:
mod = cst.parse_module(value)
mod.visit(self)

def visit_Subscript(self, node: cst.Subscript) -> Optional[bool]:
if any(
qn.name == "typing.Literal"
for qn in self.scope.get_qualified_names_for(node.value)
):
node.value.visit(self)
return False
return True

def visit_Name(self, node: cst.Name) -> Optional[bool]:
# not all Name have ExpressionContext
context = self.provider.get_metadata(ExpressionContextProvider, node, None)
if context == ExpressionContext.STORE:
self.scope.record_assignment(node.value, node)
elif context in (ExpressionContext.LOAD, ExpressionContext.DEL):
elif context in (ExpressionContext.LOAD, ExpressionContext.DEL, None):
access = Access(node, self.scope, is_annotation=bool(self.__in_annotation))
self.__deferred_accesses.append(
(access, self.__top_level_attribute_stack[-1])
Expand Down
44 changes: 39 additions & 5 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,22 +1013,56 @@ def g():
def test_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from t import T
def f(t: T):
from typing import Literal, TypeVar
from a import A, B, C, D, E, F
def x(a: A):
pass
def y(b: B):
pass
def z(c: Literal["C"]):
pass
DType = TypeVar("DType", bound=D)
EType = TypeVar("EType", bound="E")
FType = TypeVar("F")
"""
)
imp = ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom
ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.ImportFrom
)
scope = scopes[imp]
assignments = list(scope["T"])
assignment = assignments[0]

assignment = list(scope["A"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)

assignment = list(scope["B"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)

assignment = list(scope["C"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)

assignment = list(scope["D"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)

assignment = list(scope["E"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)

assignment = list(scope["F"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)

def test_node_of_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
Expand Down

0 comments on commit c935fcb

Please sign in to comment.