Skip to content

Commit

Permalink
Support relative imports in AddImportsVisitor. (#585)
Browse files Browse the repository at this point in the history
* Support relative imports in AddImportsVisitor.

* Adds an Import dataclass to represent a single imported object
* Refactors AddImportsVisitor to pass around Import objects
* Separates out the main logic in get_absolute_module_for_import so that
  it can be used to resolve relative module names outside of a cst.Import
  node
* Resolves relative module names in AddImportsVisitor if we have a
  current module name set.

Fixes #578
  • Loading branch information
martindemello authored Jan 8, 2022
1 parent 1337022 commit 8652974
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 76 deletions.
2 changes: 2 additions & 0 deletions libcst/codemod/visitors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
GatherNamesFromStringAnnotationsVisitor,
)
from libcst.codemod.visitors._gather_unused_imports import GatherUnusedImportsVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.codemod.visitors._remove_imports import RemoveImportsVisitor

__all__ = [
Expand All @@ -22,5 +23,6 @@
"GatherImportsVisitor",
"GatherNamesFromStringAnnotationsVisitor",
"GatherUnusedImportsVisitor",
"ImportItem",
"RemoveImportsVisitor",
]
51 changes: 31 additions & 20 deletions libcst/codemod/visitors/_add_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union

Expand All @@ -11,6 +11,7 @@
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_absolute_module_for_import


Expand Down Expand Up @@ -63,7 +64,7 @@ class AddImportsVisitor(ContextAwareTransformer):
@staticmethod
def _get_imports_from_context(
context: CodemodContext,
) -> List[Tuple[str, Optional[str], Optional[str]]]:
) -> List[ImportItem]:
imports = context.scratch.get(AddImportsVisitor.CONTEXT_KEY, [])
if not isinstance(imports, list):
raise Exception("Logic error!")
Expand All @@ -75,6 +76,7 @@ def add_needed_import(
module: str,
obj: Optional[str] = None,
asname: Optional[str] = None,
relative: int = 0,
) -> None:
"""
Schedule an import to be added in a future invocation of this class by
Expand All @@ -96,64 +98,73 @@ def add_needed_import(
if module == "__future__" and obj is None:
raise Exception("Cannot import __future__ directly!")
imports = AddImportsVisitor._get_imports_from_context(context)
imports.append((module, obj, asname))
imports.append(ImportItem(module, obj, asname, relative))
context.scratch[AddImportsVisitor.CONTEXT_KEY] = imports

def __init__(
self,
context: CodemodContext,
imports: Sequence[Tuple[str, Optional[str], Optional[str]]] = (),
imports: Sequence[ImportItem] = (),
) -> None:
# Allow for instantiation from either a context (used when multiple transforms
# get chained) or from a direct instantiation.
super().__init__(context)
imps: List[Tuple[str, Optional[str], Optional[str]]] = [
imps: List[ImportItem] = [
*AddImportsVisitor._get_imports_from_context(context),
*imports,
]

# Verify that the imports are valid
for module, obj, alias in imps:
if module == "__future__" and obj is None:
for imp in imps:
if imp.module == "__future__" and imp.obj_name is None:
raise Exception("Cannot import __future__ directly!")
if module == "__future__" and alias is not None:
if imp.module == "__future__" and imp.alias is not None:
raise Exception("Cannot import __future__ objects with aliases!")

# Resolve relative imports if we have a module name
imps = [imp.resolve_relative(self.context.full_module_name) for imp in imps]

# List of modules we need to ensure are imported
self.module_imports: Set[str] = {
module for (module, obj, alias) in imps if obj is None and alias is None
imp.module for imp in imps if imp.obj_name is None and imp.alias is None
}

# List of modules we need to check for object imports on
from_imports: Set[str] = {
module for (module, obj, alias) in imps if obj is not None and alias is None
imp.module for imp in imps if imp.obj_name is not None and imp.alias is None
}
# Mapping of modules we're adding to the object they should import
self.module_mapping: Dict[str, Set[str]] = {
module: {
o for (m, o, n) in imps if m == module and o is not None and n is None
imp.obj_name
for imp in imps
if imp.module == module
and imp.obj_name is not None
and imp.alias is None
}
for module in sorted(from_imports)
}

# List of aliased modules we need to ensure are imported
self.module_aliases: Dict[str, str] = {
module: alias
for (module, obj, alias) in imps
if obj is None and alias is not None
imp.module: imp.alias
for imp in imps
if imp.obj_name is None and imp.alias is not None
}
# List of modules we need to check for object imports on
from_imports_aliases: Set[str] = {
module
for (module, obj, alias) in imps
if obj is not None and alias is not None
imp.module
for imp in imps
if imp.obj_name is not None and imp.alias is not None
}
# Mapping of modules we're adding to the object with alias they should import
self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {
module: [
(o, n)
for (m, o, n) in imps
if m == module and o is not None and n is not None
(imp.obj_name, imp.alias)
for imp in imps
if imp.module == module
and imp.obj_name is not None
and imp.alias is not None
]
for module in sorted(from_imports_aliases)
}
Expand Down
3 changes: 2 additions & 1 deletion libcst/codemod/visitors/_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._add_imports import AddImportsVisitor
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_full_name_for_node
from libcst.metadata import PositionProvider, QualifiedNameProvider

Expand Down Expand Up @@ -416,7 +417,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:
tree_with_imports = AddImportsVisitor(
context=self.context,
imports=(
[("__future__", "annotations", None)]
[ImportItem("__future__", "annotations", None)]
if self.use_future_annotations
else ()
),
Expand Down
43 changes: 43 additions & 0 deletions libcst/codemod/visitors/_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, replace
from typing import Optional

from libcst.helpers import get_absolute_module


@dataclass(frozen=True)
class ImportItem:
"""Representation of individual import items for codemods."""

module_name: str
obj_name: Optional[str] = None
alias: Optional[str] = None
relative: int = 0

def __post_init__(self) -> None:
if self.module_name is None:
object.__setattr__(self, "module_name", "")
elif self.module_name.startswith("."):
mod = self.module_name.lstrip(".")
rel = self.relative + len(self.module_name) - len(mod)
object.__setattr__(self, "module_name", mod)
object.__setattr__(self, "relative", rel)

@property
def module(self) -> str:
return "." * self.relative + self.module_name

def resolve_relative(self, base_module: Optional[str]) -> "ImportItem":
"""Return an ImportItem with an absolute module name if possible."""
mod = self
# `import ..a` -> `from .. import a`
if mod.relative and mod.obj_name is None:
mod = replace(mod, module_name="", obj_name=mod.module_name)
if base_module is None:
return mod
m = get_absolute_module(base_module, mod.module_name or None, self.relative)
return mod if m is None else replace(mod, module_name=m, relative=0)
Loading

0 comments on commit 8652974

Please sign in to comment.