From 8652974d8798ad08c73846c7bf8ebbd14dc02b02 Mon Sep 17 00:00:00 2001 From: Martin DeMello Date: Sat, 8 Jan 2022 02:18:10 -0800 Subject: [PATCH] Support relative imports in AddImportsVisitor. (#585) * Support relative imports in AddImportsVisitor. * Adds an Import dataclass to represent a single imported object * Refactors AddImportsVisitor to pass around Import objects * Separates out the main logic in get_absolute_module_for_import so that it can be used to resolve relative module names outside of a cst.Import node * Resolves relative module names in AddImportsVisitor if we have a current module name set. Fixes #578 --- libcst/codemod/visitors/__init__.py | 2 + libcst/codemod/visitors/_add_imports.py | 51 +-- .../visitors/_apply_type_annotations.py | 3 +- libcst/codemod/visitors/_imports.py | 43 +++ .../visitors/tests/test_add_imports.py | 313 +++++++++++++++--- libcst/helpers/__init__.py | 2 + libcst/helpers/_statement.py | 20 +- 7 files changed, 358 insertions(+), 76 deletions(-) create mode 100644 libcst/codemod/visitors/_imports.py diff --git a/libcst/codemod/visitors/__init__.py b/libcst/codemod/visitors/__init__.py index a14165052..1cbbd2c8c 100644 --- a/libcst/codemod/visitors/__init__.py +++ b/libcst/codemod/visitors/__init__.py @@ -12,6 +12,7 @@ GatherNamesFromStringAnnotationsVisitor, ) from libcst.codemod.visitors._gather_unused_imports import GatherUnusedImportsVisitor +from libcst.codemod.visitors._imports import ImportItem from libcst.codemod.visitors._remove_imports import RemoveImportsVisitor __all__ = [ @@ -22,5 +23,6 @@ "GatherImportsVisitor", "GatherNamesFromStringAnnotationsVisitor", "GatherUnusedImportsVisitor", + "ImportItem", "RemoveImportsVisitor", ] diff --git a/libcst/codemod/visitors/_add_imports.py b/libcst/codemod/visitors/_add_imports.py index 248d3838f..64131dd60 100644 --- a/libcst/codemod/visitors/_add_imports.py +++ b/libcst/codemod/visitors/_add_imports.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -# + from collections import defaultdict from typing import Dict, List, Optional, Sequence, Set, Tuple, Union @@ -11,6 +11,7 @@ from libcst.codemod._context import CodemodContext from libcst.codemod._visitor import ContextAwareTransformer from libcst.codemod.visitors._gather_imports import GatherImportsVisitor +from libcst.codemod.visitors._imports import ImportItem from libcst.helpers import get_absolute_module_for_import @@ -63,7 +64,7 @@ class AddImportsVisitor(ContextAwareTransformer): @staticmethod def _get_imports_from_context( context: CodemodContext, - ) -> List[Tuple[str, Optional[str], Optional[str]]]: + ) -> List[ImportItem]: imports = context.scratch.get(AddImportsVisitor.CONTEXT_KEY, []) if not isinstance(imports, list): raise Exception("Logic error!") @@ -75,6 +76,7 @@ def add_needed_import( module: str, obj: Optional[str] = None, asname: Optional[str] = None, + relative: int = 0, ) -> None: """ Schedule an import to be added in a future invocation of this class by @@ -96,64 +98,73 @@ def add_needed_import( if module == "__future__" and obj is None: raise Exception("Cannot import __future__ directly!") imports = AddImportsVisitor._get_imports_from_context(context) - imports.append((module, obj, asname)) + imports.append(ImportItem(module, obj, asname, relative)) context.scratch[AddImportsVisitor.CONTEXT_KEY] = imports def __init__( self, context: CodemodContext, - imports: Sequence[Tuple[str, Optional[str], Optional[str]]] = (), + imports: Sequence[ImportItem] = (), ) -> None: # Allow for instantiation from either a context (used when multiple transforms # get chained) or from a direct instantiation. super().__init__(context) - imps: List[Tuple[str, Optional[str], Optional[str]]] = [ + imps: List[ImportItem] = [ *AddImportsVisitor._get_imports_from_context(context), *imports, ] # Verify that the imports are valid - for module, obj, alias in imps: - if module == "__future__" and obj is None: + for imp in imps: + if imp.module == "__future__" and imp.obj_name is None: raise Exception("Cannot import __future__ directly!") - if module == "__future__" and alias is not None: + if imp.module == "__future__" and imp.alias is not None: raise Exception("Cannot import __future__ objects with aliases!") + # Resolve relative imports if we have a module name + imps = [imp.resolve_relative(self.context.full_module_name) for imp in imps] + # List of modules we need to ensure are imported self.module_imports: Set[str] = { - module for (module, obj, alias) in imps if obj is None and alias is None + imp.module for imp in imps if imp.obj_name is None and imp.alias is None } # List of modules we need to check for object imports on from_imports: Set[str] = { - module for (module, obj, alias) in imps if obj is not None and alias is None + imp.module for imp in imps if imp.obj_name is not None and imp.alias is None } # Mapping of modules we're adding to the object they should import self.module_mapping: Dict[str, Set[str]] = { module: { - o for (m, o, n) in imps if m == module and o is not None and n is None + imp.obj_name + for imp in imps + if imp.module == module + and imp.obj_name is not None + and imp.alias is None } for module in sorted(from_imports) } # List of aliased modules we need to ensure are imported self.module_aliases: Dict[str, str] = { - module: alias - for (module, obj, alias) in imps - if obj is None and alias is not None + imp.module: imp.alias + for imp in imps + if imp.obj_name is None and imp.alias is not None } # List of modules we need to check for object imports on from_imports_aliases: Set[str] = { - module - for (module, obj, alias) in imps - if obj is not None and alias is not None + imp.module + for imp in imps + if imp.obj_name is not None and imp.alias is not None } # Mapping of modules we're adding to the object with alias they should import self.alias_mapping: Dict[str, List[Tuple[str, str]]] = { module: [ - (o, n) - for (m, o, n) in imps - if m == module and o is not None and n is not None + (imp.obj_name, imp.alias) + for imp in imps + if imp.module == module + and imp.obj_name is not None + and imp.alias is not None ] for module in sorted(from_imports_aliases) } diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index d29b6c9f4..8a4fccfe6 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -12,6 +12,7 @@ from libcst.codemod._visitor import ContextAwareTransformer from libcst.codemod.visitors._add_imports import AddImportsVisitor from libcst.codemod.visitors._gather_imports import GatherImportsVisitor +from libcst.codemod.visitors._imports import ImportItem from libcst.helpers import get_full_name_for_node from libcst.metadata import PositionProvider, QualifiedNameProvider @@ -416,7 +417,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: tree_with_imports = AddImportsVisitor( context=self.context, imports=( - [("__future__", "annotations", None)] + [ImportItem("__future__", "annotations", None)] if self.use_future_annotations else () ), diff --git a/libcst/codemod/visitors/_imports.py b/libcst/codemod/visitors/_imports.py new file mode 100644 index 000000000..5a703112e --- /dev/null +++ b/libcst/codemod/visitors/_imports.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, replace +from typing import Optional + +from libcst.helpers import get_absolute_module + + +@dataclass(frozen=True) +class ImportItem: + """Representation of individual import items for codemods.""" + + module_name: str + obj_name: Optional[str] = None + alias: Optional[str] = None + relative: int = 0 + + def __post_init__(self) -> None: + if self.module_name is None: + object.__setattr__(self, "module_name", "") + elif self.module_name.startswith("."): + mod = self.module_name.lstrip(".") + rel = self.relative + len(self.module_name) - len(mod) + object.__setattr__(self, "module_name", mod) + object.__setattr__(self, "relative", rel) + + @property + def module(self) -> str: + return "." * self.relative + self.module_name + + def resolve_relative(self, base_module: Optional[str]) -> "ImportItem": + """Return an ImportItem with an absolute module name if possible.""" + mod = self + # `import ..a` -> `from .. import a` + if mod.relative and mod.obj_name is None: + mod = replace(mod, module_name="", obj_name=mod.module_name) + if base_module is None: + return mod + m = get_absolute_module(base_module, mod.module_name or None, self.relative) + return mod if m is None else replace(mod, module_name=m, relative=0) diff --git a/libcst/codemod/visitors/tests/test_add_imports.py b/libcst/codemod/visitors/tests/test_add_imports.py index 4e410a14b..6a88b3358 100644 --- a/libcst/codemod/visitors/tests/test_add_imports.py +++ b/libcst/codemod/visitors/tests/test_add_imports.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # from libcst.codemod import CodemodContext, CodemodTest -from libcst.codemod.visitors import AddImportsVisitor +from libcst.codemod.visitors import AddImportsVisitor, ImportItem class TestAddImportsCodemod(CodemodTest): @@ -55,7 +55,7 @@ def bar() -> int: return 5 """ - self.assertCodemod(before, after, [("a.b.c", None, None)]) + self.assertCodemod(before, after, [ImportItem("a.b.c", None, None)]) def test_dont_add_module_simple(self) -> None: """ @@ -81,7 +81,7 @@ def bar() -> int: return 5 """ - self.assertCodemod(before, after, [("a.b.c", None, None)]) + self.assertCodemod(before, after, [ImportItem("a.b.c", None, None)]) def test_add_module_alias_simple(self) -> None: """ @@ -105,7 +105,7 @@ def bar() -> int: return 5 """ - self.assertCodemod(before, after, [("a.b.c", None, "d")]) + self.assertCodemod(before, after, [ImportItem("a.b.c", None, "d")]) def test_dont_add_module_alias_simple(self) -> None: """ @@ -131,7 +131,7 @@ def bar() -> int: return 5 """ - self.assertCodemod(before, after, [("a.b.c", None, "d")]) + self.assertCodemod(before, after, [ImportItem("a.b.c", None, "d")]) def test_add_module_complex(self) -> None: """ @@ -167,11 +167,11 @@ def bar() -> int: before, after, [ - ("a.b.c", None, None), - ("defg.hi", None, None), - ("argparse", None, None), - ("jkl", None, "h"), - ("i.j", None, "k"), + ImportItem("a.b.c", None, None), + ImportItem("defg.hi", None, None), + ImportItem("argparse", None, None), + ImportItem("jkl", None, "h"), + ImportItem("i.j", None, "k"), ], ) @@ -197,7 +197,7 @@ def bar() -> int: return 5 """ - self.assertCodemod(before, after, [("a.b.c", "D", None)]) + self.assertCodemod(before, after, [ImportItem("a.b.c", "D", None)]) def test_add_object_alias_simple(self) -> None: """ @@ -221,7 +221,7 @@ def bar() -> int: return 5 """ - self.assertCodemod(before, after, [("a.b.c", "D", "E")]) + self.assertCodemod(before, after, [ImportItem("a.b.c", "D", "E")]) def test_add_future(self) -> None: """ @@ -250,7 +250,9 @@ def bar() -> int: return 5 """ - self.assertCodemod(before, after, [("__future__", "dummy_feature", None)]) + self.assertCodemod( + before, after, [ImportItem("__future__", "dummy_feature", None)] + ) def test_dont_add_object_simple(self) -> None: """ @@ -276,7 +278,7 @@ def bar() -> int: return 5 """ - self.assertCodemod(before, after, [("a.b.c", "D", None)]) + self.assertCodemod(before, after, [ImportItem("a.b.c", "D", None)]) def test_dont_add_object_alias_simple(self) -> None: """ @@ -302,7 +304,7 @@ def bar() -> int: return 5 """ - self.assertCodemod(before, after, [("a.b.c", "D", "E")]) + self.assertCodemod(before, after, [ImportItem("a.b.c", "D", "E")]) def test_add_object_modify_simple(self) -> None: """ @@ -328,7 +330,7 @@ def bar() -> int: return 5 """ - self.assertCodemod(before, after, [("a.b.c", "D", None)]) + self.assertCodemod(before, after, [ImportItem("a.b.c", "D", None)]) def test_add_object_alias_modify_simple(self) -> None: """ @@ -354,7 +356,7 @@ def bar() -> int: return 5 """ - self.assertCodemod(before, after, [("a.b.c", "D", "_")]) + self.assertCodemod(before, after, [ImportItem("a.b.c", "D", "_")]) def test_add_object_modify_complex(self) -> None: """ @@ -387,17 +389,17 @@ def bar() -> int: before, after, [ - ("a.b.c", "D", None), - ("a.b.c", "F", None), - ("a.b.c", "G", "H"), - ("d.e.f", "Foo", None), - ("g.h.i", "Z", None), - ("g.h.i", "X", None), - ("d.e.f", "Bar", None), - ("d.e.f", "Baz", "Qux"), - ("g.h.i", "Y", None), - ("g.h.i", "V", "W"), - ("a.b.c", "F", None), + ImportItem("a.b.c", "D", None), + ImportItem("a.b.c", "F", None), + ImportItem("a.b.c", "G", "H"), + ImportItem("d.e.f", "Foo", None), + ImportItem("g.h.i", "Z", None), + ImportItem("g.h.i", "X", None), + ImportItem("d.e.f", "Bar", None), + ImportItem("d.e.f", "Baz", "Qux"), + ImportItem("g.h.i", "Y", None), + ImportItem("g.h.i", "V", "W"), + ImportItem("a.b.c", "F", None), ], ) @@ -440,18 +442,18 @@ def bar() -> int: before, after, [ - ("a.b.c", "D", None), - ("a.b.c", "F", None), - ("d.e.f", "Foo", None), - ("sys", None, None), - ("g.h.i", "Z", None), - ("g.h.i", "X", None), - ("d.e.f", "Bar", None), - ("g.h.i", "Y", None), - ("foo", None, None), - ("a.b.c", "F", None), - ("bar", None, "baz"), - ("qux", None, "quux"), + ImportItem("a.b.c", "D", None), + ImportItem("a.b.c", "F", None), + ImportItem("d.e.f", "Foo", None), + ImportItem("sys", None, None), + ImportItem("g.h.i", "Z", None), + ImportItem("g.h.i", "X", None), + ImportItem("d.e.f", "Bar", None), + ImportItem("g.h.i", "Y", None), + ImportItem("foo", None, None), + ImportItem("a.b.c", "F", None), + ImportItem("bar", None, "baz"), + ImportItem("qux", None, "quux"), ], ) @@ -481,7 +483,7 @@ def bar() -> int: return 5 """ - self.assertCodemod(before, after, [("a.b.c", "D", None)]) + self.assertCodemod(before, after, [ImportItem("a.b.c", "D", None)]) def test_add_import_preserve_doctring_multiples(self) -> None: """ @@ -511,7 +513,9 @@ def bar() -> int: """ self.assertCodemod( - before, after, [("a.b.c", "D", None), ("argparse", None, None)] + before, + after, + [ImportItem("a.b.c", "D", None), ImportItem("argparse", None, None)], ) def test_strict_module_no_imports(self) -> None: @@ -532,7 +536,7 @@ class Foo: pass """ - self.assertCodemod(before, after, [("argparse", None, None)]) + self.assertCodemod(before, after, [ImportItem("argparse", None, None)]) def test_strict_module_with_imports(self) -> None: """ @@ -556,7 +560,7 @@ class Foo: pass """ - self.assertCodemod(before, after, [("argparse", None, None)]) + self.assertCodemod(before, after, [ImportItem("argparse", None, None)]) def test_dont_add_relative_object_simple(self) -> None: """ @@ -585,7 +589,7 @@ def bar() -> int: self.assertCodemod( before, after, - [("a.b.c", "D", None)], + [ImportItem("a.b.c", "D", None)], context_override=CodemodContext(full_module_name="a.b.foobar"), ) @@ -616,7 +620,7 @@ def bar() -> int: self.assertCodemod( before, after, - [("a.b.c", "D", None)], + [ImportItem("a.b.c", "D", None)], context_override=CodemodContext(full_module_name="a.b.foobar"), ) @@ -634,7 +638,220 @@ def test_import_order(self) -> None: self.assertCodemod( before, after, - [("a", "f", None), ("a", "g", "y"), ("a", "c", None), ("a", "d", "x")], + [ + ImportItem("a", "f", None), + ImportItem("a", "g", "y"), + ImportItem("a", "c", None), + ImportItem("a", "d", "x"), + ], + context_override=CodemodContext(full_module_name="a.b.foobar"), + ) + + def test_add_explicit_relative(self) -> None: + """ + Should add a relative import from .. . + """ + + before = """ + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + after = """ + from .. import a + + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + + self.assertCodemod( + before, + after, + [ImportItem("a", None, None, 2)], + ) + + def test_add_explicit_relative_alias(self) -> None: + """ + Should add a relative import from .. . + """ + + before = """ + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + after = """ + from .. import a as foo + + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + + self.assertCodemod( + before, + after, + [ImportItem("a", None, "foo", 2)], + ) + + def test_add_explicit_relative_object_simple(self) -> None: + """ + Should add a relative import. + """ + + before = """ + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + after = """ + from ..a import B + + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + + self.assertCodemod( + before, + after, + [ImportItem("a", "B", None, 2)], + ) + + def test_dont_add_explicit_relative_object_simple(self) -> None: + """ + Should not add object as an import since it exists. + """ + + before = """ + from ..c import D + + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + after = """ + from ..c import D + + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + + self.assertCodemod( + before, + after, + [ImportItem("c", "D", None, 2)], + context_override=CodemodContext(full_module_name="a.b.foobar"), + ) + + def test_add_object_explicit_relative_modify_simple(self) -> None: + """ + Should modify existing import to add new object. + """ + + before = """ + from ..c import E, F + + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + after = """ + from ..c import D, E, F + + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + + self.assertCodemod( + before, + after, + [ImportItem("c", "D", None, 2)], + context_override=CodemodContext(full_module_name="a.b.foobar"), + ) + + def test_add_object_resolve_explicit_relative_modify_simple(self) -> None: + """ + Should merge a relative new module with an absolute existing one. + """ + + before = """ + from ..c import E, F + + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + after = """ + from ..c import D, E, F + + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + + self.assertCodemod( + before, + after, + [ImportItem("c", "D", None, 2)], + context_override=CodemodContext(full_module_name="a.b.foobar"), + ) + + def test_add_object_resolve_dotted_relative_modify_simple(self) -> None: + """ + Should merge a relative new module with an absolute existing one. + """ + + before = """ + from ..c import E, F + + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + after = """ + from ..c import D, E, F + + def foo() -> None: + pass + + def bar() -> int: + return 5 + """ + + self.assertCodemod( + before, + after, + [ImportItem("..c", "D", None)], context_override=CodemodContext(full_module_name="a.b.foobar"), ) @@ -655,6 +872,6 @@ def test_import_in_docstring_module(self) -> None: self.assertCodemod( before, after, - [("__future__", "annotations", None)], + [ImportItem("__future__", "annotations", None)], context_override=CodemodContext(full_module_name="a.b.foobar"), ) diff --git a/libcst/helpers/__init__.py b/libcst/helpers/__init__.py index 3e23a6d90..ccd12c728 100644 --- a/libcst/helpers/__init__.py +++ b/libcst/helpers/__init__.py @@ -5,6 +5,7 @@ # from libcst.helpers._statement import ( + get_absolute_module, get_absolute_module_for_import, get_absolute_module_for_import_or_raise, ) @@ -21,6 +22,7 @@ from libcst.helpers.module import insert_header_comments __all__ = [ + "get_absolute_module", "get_absolute_module_for_import", "get_absolute_module_for_import_or_raise", "get_full_name_for_node", diff --git a/libcst/helpers/_statement.py b/libcst/helpers/_statement.py index 0d21e2252..f62a5eb87 100644 --- a/libcst/helpers/_statement.py +++ b/libcst/helpers/_statement.py @@ -9,14 +9,9 @@ from libcst.helpers.expression import get_full_name_for_node -def get_absolute_module_for_import( - current_module: Optional[str], import_node: cst.ImportFrom +def get_absolute_module( + current_module: Optional[str], module_name: Optional[str], num_dots: int ) -> Optional[str]: - # First, let's try to grab the module name, regardless of relative status. - module = import_node.module - module_name = get_full_name_for_node(module) if module is not None else None - # Now, get the relative import location if it exists. - num_dots = len(import_node.relative) if num_dots == 0: # This is an absolute import, so the module is correct. return module_name @@ -43,6 +38,17 @@ def get_absolute_module_for_import( return base_module if len(base_module) > 0 else None +def get_absolute_module_for_import( + current_module: Optional[str], import_node: cst.ImportFrom +) -> Optional[str]: + # First, let's try to grab the module name, regardless of relative status. + module = import_node.module + module_name = get_full_name_for_node(module) if module is not None else None + # Now, get the relative import location if it exists. + num_dots = len(import_node.relative) + return get_absolute_module(current_module, module_name, num_dots) + + def get_absolute_module_for_import_or_raise( current_module: Optional[str], import_node: cst.ImportFrom ) -> str: