diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index bcea03c2e..3eb79d359 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -699,9 +699,20 @@ def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]: def visit_Call(self, node: cst.Call) -> Optional[bool]: self.__top_level_attribute_stack.append(None) + if any( + qn.name == "typing.TypeVar" + for qn in self.scope.get_qualified_names_for(node) + ): + node.func.visit(self) + self.__in_annotation.add(node) + for arg in node.args[1:]: + arg.visit(self) + return False + return True def leave_Call(self, original_node: cst.Call) -> None: self.__top_level_attribute_stack.pop() + self.__in_annotation.discard(original_node) def visit_Annotation(self, node: cst.Annotation) -> Optional[bool]: self.__in_annotation.add(node) @@ -709,12 +720,38 @@ def visit_Annotation(self, node: cst.Annotation) -> Optional[bool]: def leave_Annotation(self, original_node: cst.Annotation) -> None: self.__in_annotation.discard(original_node) + def visit_SimpleString(self, node: cst.SimpleString) -> Optional[bool]: + self._handle_string_annotation(node) + return False + + def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> Optional[bool]: + self._handle_string_annotation(node) + return False + + def _handle_string_annotation( + self, node: Union[cst.SimpleString, cst.ConcatenatedString] + ) -> None: + if self.__in_annotation: + value = node.evaluated_value + if value: + mod = cst.parse_module(value) + mod.visit(self) + + def visit_Subscript(self, node: cst.Subscript) -> Optional[bool]: + if any( + qn.name == "typing.Literal" + for qn in self.scope.get_qualified_names_for(node.value) + ): + node.value.visit(self) + return False + return True + def visit_Name(self, node: cst.Name) -> Optional[bool]: # not all Name have ExpressionContext context = self.provider.get_metadata(ExpressionContextProvider, node, None) if context == ExpressionContext.STORE: self.scope.record_assignment(node.value, node) - elif context in (ExpressionContext.LOAD, ExpressionContext.DEL): + elif context in (ExpressionContext.LOAD, ExpressionContext.DEL, None): access = Access(node, self.scope, is_annotation=bool(self.__in_annotation)) self.__deferred_accesses.append( (access, self.__top_level_attribute_stack[-1]) diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index bee13d505..b05ee8314 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -1013,22 +1013,56 @@ def g(): def test_annotation_access(self) -> None: m, scopes = get_scope_metadata_provider( """ - from t import T - def f(t: T): + from typing import Literal, TypeVar + from a import A, B, C, D, E, F + def x(a: A): pass + def y(b: B): + pass + def z(c: Literal["C"]): + pass + DType = TypeVar("DType", bound=D) + EType = TypeVar("EType", bound="E") + FType = TypeVar("F") """ ) imp = ensure_type( - ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom + ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.ImportFrom ) scope = scopes[imp] - assignments = list(scope["T"]) - assignment = assignments[0] + + assignment = list(scope["A"])[0] + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 1) + references = list(assignment.references) + self.assertTrue(references[0].is_annotation) + + assignment = list(scope["B"])[0] + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 1) + references = list(assignment.references) + self.assertTrue(references[0].is_annotation) + + assignment = list(scope["C"])[0] + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 0) + + assignment = list(scope["D"])[0] self.assertIsInstance(assignment, Assignment) self.assertEqual(len(assignment.references), 1) references = list(assignment.references) self.assertTrue(references[0].is_annotation) + assignment = list(scope["E"])[0] + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 1) + references = list(assignment.references) + self.assertTrue(references[0].is_annotation) + + assignment = list(scope["F"])[0] + self.assertIsInstance(assignment, Assignment) + self.assertEqual(len(assignment.references), 0) + def test_node_of_scopes(self) -> None: m, scopes = get_scope_metadata_provider( """