Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support relative imports in AddImportsVisitor. #585

Merged
merged 8 commits into from
Jan 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions libcst/codemod/visitors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
GatherNamesFromStringAnnotationsVisitor,
)
from libcst.codemod.visitors._gather_unused_imports import GatherUnusedImportsVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.codemod.visitors._remove_imports import RemoveImportsVisitor

__all__ = [
Expand All @@ -22,5 +23,6 @@
"GatherImportsVisitor",
"GatherNamesFromStringAnnotationsVisitor",
"GatherUnusedImportsVisitor",
"ImportItem",
"RemoveImportsVisitor",
]
51 changes: 31 additions & 20 deletions libcst/codemod/visitors/_add_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union

Expand All @@ -11,6 +11,7 @@
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_absolute_module_for_import


Expand Down Expand Up @@ -63,7 +64,7 @@ class AddImportsVisitor(ContextAwareTransformer):
@staticmethod
def _get_imports_from_context(
context: CodemodContext,
) -> List[Tuple[str, Optional[str], Optional[str]]]:
) -> List[ImportItem]:
imports = context.scratch.get(AddImportsVisitor.CONTEXT_KEY, [])
if not isinstance(imports, list):
raise Exception("Logic error!")
Expand All @@ -75,6 +76,7 @@ def add_needed_import(
module: str,
obj: Optional[str] = None,
asname: Optional[str] = None,
relative: int = 0,
) -> None:
"""
Schedule an import to be added in a future invocation of this class by
Expand All @@ -96,64 +98,73 @@ def add_needed_import(
if module == "__future__" and obj is None:
raise Exception("Cannot import __future__ directly!")
imports = AddImportsVisitor._get_imports_from_context(context)
imports.append((module, obj, asname))
imports.append(ImportItem(module, obj, asname, relative))
context.scratch[AddImportsVisitor.CONTEXT_KEY] = imports

def __init__(
self,
context: CodemodContext,
imports: Sequence[Tuple[str, Optional[str], Optional[str]]] = (),
imports: Sequence[ImportItem] = (),
) -> None:
# Allow for instantiation from either a context (used when multiple transforms
# get chained) or from a direct instantiation.
super().__init__(context)
imps: List[Tuple[str, Optional[str], Optional[str]]] = [
imps: List[ImportItem] = [
*AddImportsVisitor._get_imports_from_context(context),
*imports,
]

# Verify that the imports are valid
for module, obj, alias in imps:
if module == "__future__" and obj is None:
for imp in imps:
if imp.module == "__future__" and imp.obj_name is None:
raise Exception("Cannot import __future__ directly!")
if module == "__future__" and alias is not None:
if imp.module == "__future__" and imp.alias is not None:
raise Exception("Cannot import __future__ objects with aliases!")

# Resolve relative imports if we have a module name
imps = [imp.resolve_relative(self.context.full_module_name) for imp in imps]

# List of modules we need to ensure are imported
self.module_imports: Set[str] = {
module for (module, obj, alias) in imps if obj is None and alias is None
imp.module for imp in imps if imp.obj_name is None and imp.alias is None
}

# List of modules we need to check for object imports on
from_imports: Set[str] = {
module for (module, obj, alias) in imps if obj is not None and alias is None
imp.module for imp in imps if imp.obj_name is not None and imp.alias is None
}
# Mapping of modules we're adding to the object they should import
self.module_mapping: Dict[str, Set[str]] = {
module: {
o for (m, o, n) in imps if m == module and o is not None and n is None
imp.obj_name
for imp in imps
if imp.module == module
and imp.obj_name is not None
and imp.alias is None
}
for module in sorted(from_imports)
}

# List of aliased modules we need to ensure are imported
self.module_aliases: Dict[str, str] = {
module: alias
for (module, obj, alias) in imps
if obj is None and alias is not None
imp.module: imp.alias
for imp in imps
if imp.obj_name is None and imp.alias is not None
}
# List of modules we need to check for object imports on
from_imports_aliases: Set[str] = {
module
for (module, obj, alias) in imps
if obj is not None and alias is not None
imp.module
for imp in imps
if imp.obj_name is not None and imp.alias is not None
}
# Mapping of modules we're adding to the object with alias they should import
self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {
module: [
(o, n)
for (m, o, n) in imps
if m == module and o is not None and n is not None
(imp.obj_name, imp.alias)
for imp in imps
if imp.module == module
and imp.obj_name is not None
and imp.alias is not None
]
for module in sorted(from_imports_aliases)
}
Expand Down
3 changes: 2 additions & 1 deletion libcst/codemod/visitors/_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._add_imports import AddImportsVisitor
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_full_name_for_node
from libcst.metadata import PositionProvider, QualifiedNameProvider

Expand Down Expand Up @@ -416,7 +417,7 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:
tree_with_imports = AddImportsVisitor(
context=self.context,
imports=(
[("__future__", "annotations", None)]
[ImportItem("__future__", "annotations", None)]
if self.use_future_annotations
else ()
),
Expand Down
43 changes: 43 additions & 0 deletions libcst/codemod/visitors/_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, replace
from typing import Optional

from libcst.helpers import get_absolute_module


@dataclass(frozen=True)
class ImportItem:
"""Representation of individual import items for codemods."""

module_name: str
obj_name: Optional[str] = None
alias: Optional[str] = None
relative: int = 0

def __post_init__(self) -> None:
if self.module_name is None:
object.__setattr__(self, "module_name", "")
elif self.module_name.startswith("."):
mod = self.module_name.lstrip(".")
rel = self.relative + len(self.module_name) - len(mod)
object.__setattr__(self, "module_name", mod)
object.__setattr__(self, "relative", rel)

@property
def module(self) -> str:
return "." * self.relative + self.module_name
martindemello marked this conversation as resolved.
Show resolved Hide resolved

def resolve_relative(self, base_module: Optional[str]) -> "ImportItem":
"""Return an ImportItem with an absolute module name if possible."""
mod = self
# `import ..a` -> `from .. import a`
if mod.relative and mod.obj_name is None:
mod = replace(mod, module_name="", obj_name=mod.module_name)
if base_module is None:
return mod
m = get_absolute_module(base_module, mod.module_name or None, self.relative)
return mod if m is None else replace(mod, module_name=m, relative=0)
Loading