Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make HF conversion automatically add missing imports #1241

Merged
merged 7 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 12 additions & 18 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
except Exception as e:
raise e

import logging

from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
Expand All @@ -62,31 +64,23 @@
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.mpt.configuration_mpt import MPTConfig
from llmfoundry.models.utils.act_ckpt import (
build_act_ckpt_mod_to_blocks,
check_mapping_blocks_overlap,
pass_on_block_idx,
)
from llmfoundry.models.utils.config_moe_args import config_moe_args
from llmfoundry.models.utils.mpt_param_count import (
mpt_get_active_params,
mpt_get_total_params,
)

# NOTE: All utils are imported directly even if unused so that
# HuggingFace can detect all the needed files to copy into its modules folder.
# Otherwise, certain modules are missing.
# Import the fcs and param_init_fns here so that the recursive code creating the files for hf checkpoints can find them
# These are the exceptions because fc.py and param_init_fns.py are not imported in any other place in the import tree
# isort: off
from llmfoundry.models.utils.meta_init_context import \
init_empty_weights # type: ignore (see note)
from llmfoundry.models.utils.param_init_fns import (
generic_param_init_fn_, # type: ignore (see note)
)
from llmfoundry.models.layers.ffn import resolve_ffn_act_fn # type: ignore (see note)
from llmfoundry.models.layers.fc import fcs # type: ignore (see note)

from llmfoundry.models.utils.act_ckpt import (
pass_on_block_idx,
build_act_ckpt_mod_to_blocks,
check_mapping_blocks_overlap,
)

import logging
from llmfoundry.models.layers.fc import fcs # type: ignore
from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_ # type: ignore
# isort: on

log = logging.getLogger(__name__)

Expand Down
114 changes: 113 additions & 1 deletion llmfoundry/utils/huggingface_hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import ast
import importlib
import json
import os
from typing import Optional, Sequence

Expand Down Expand Up @@ -139,6 +140,80 @@ def process_file(
return new_files_to_process


def get_all_relative_imports(file_path: str) -> set[str]:
    """Collect the names of all single-level relative imports in a file.

    Only ``from .<module> import ...`` statements (level 1, with an explicit
    module name) are counted; bare ``from . import x`` and deeper relative
    imports are ignored.

    Args:
        file_path (str): Path of the Python source file to scan.

    Returns:
        set[str]: The module names that are relatively imported.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        parsed = ast.parse(f.read())

    return {
        node.module
        for node in ast.walk(parsed)
        if isinstance(node, ast.ImportFrom) and node.level == 1 and
        node.module is not None
    }


def add_relative_imports(
    file_path: str,
    relative_imports_to_add: set[str],
) -> None:
    """Insert ``from .<module> import <name>`` statements into a file.

    For every module name in ``relative_imports_to_add``, the first function
    or class defined in the sibling ``<module>.py`` file is located and an
    import of that name is inserted near the top of ``file_path`` (past any
    ``from __future__ import`` lines). When the target module defines neither
    a function nor a class, a wildcard import is used instead. The file is
    rewritten in place from the modified AST.

    Args:
        file_path (str): The file to add imports to.
        relative_imports_to_add (set[str]): The relative module names to add.
    """
    # All candidate modules live next to the file being edited.
    folder = os.path.dirname(file_path)

    with open(file_path, 'r', encoding='utf-8') as f:
        module_tree = ast.parse(f.read())

    for module_name in relative_imports_to_add:
        target_path = os.path.join(folder, module_name + '.py')
        # Inspect the target module to find one concrete name to import.
        with open(target_path, 'r', encoding='utf-8') as f:
            target_tree = ast.parse(f.read())

        # Import the first function or class the target module defines;
        # fall back to a wildcard when it defines neither.
        imported_name = '*'
        for candidate in ast.walk(target_tree):
            if isinstance(candidate, (ast.FunctionDef, ast.ClassDef)):
                imported_name = candidate.name
                break

        # Equivalent of `from .<module_name> import <imported_name>`.
        new_import = ast.ImportFrom(
            module=module_name,
            names=[ast.alias(name=imported_name, asname=None)],
            level=1,
        )

        # Insert near the top of the file, but past the from __future__ import
        module_tree.body.insert(2, new_import)

    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(ast.unparse(module_tree))


def edit_files_for_hf_compatibility(
folder: str,
flatten_imports_prefix: Sequence[str] = ('llmfoundry',),
Expand All @@ -158,9 +233,27 @@ def edit_files_for_hf_compatibility(
remove_imports_prefix (Sequence[str], optional): Sequence of prefixes to remove. Takes precedence over flattening.
Defaults to ('composer', 'omegaconf', 'llmfoundry.metrics', 'llmfoundry.utils.builders').
"""
listed_dir = os.listdir(folder)

# Try to acquire the config file to determine which python file is the entrypoint file
config_file_exists = 'config.json' in listed_dir
# NOTE(review): config.json is opened unconditionally even though
# `config_file_exists` was just computed — if the file is absent this raises
# FileNotFoundError before the flag is ever checked. The open/json.load should
# be guarded by `if config_file_exists:` — confirm and fix upstream.
with open(os.path.join(folder, 'config.json'), 'r') as _f:
config = json.load(_f)

# If the config file exists, the entrypoint files would be specified in the auto map
entrypoint_files = set()
if config_file_exists:
# `auto_map` values look like "<module>.<ClassName>"; keep the module part.
for key, value in config.get('auto_map', {}).items():
# Only keep the modeling entrypoints, e.g. AutoModelForCausalLM
if 'model' not in key.lower():
continue
split_path = value.split('.')
if len(split_path) > 1:
entrypoint_files.add(split_path[0] + '.py')

files_to_process = [
os.path.join(folder, filename)
for filename in os.listdir(folder)
for filename in listed_dir
if filename.endswith('.py')
]
files_processed_and_queued = set(files_to_process)
Expand All @@ -178,3 +271,22 @@ def edit_files_for_hf_compatibility(
if file not in files_processed_and_queued:
files_to_process.append(file)
files_processed_and_queued.add(file)

# For each entrypoint, determine which imports are missing, and add them
# This is because HF does not recursively search imports when determining
# which files to copy into its modules cache
# Candidate modules are the basenames (without .py) of every processed file.
all_relative_imports = {
os.path.splitext(os.path.basename(f))[0]
for f in files_processed_and_queued
}
for entrypoint in entrypoint_files:
existing_relative_imports = get_all_relative_imports(
os.path.join(folder, entrypoint),
)
# Add in self so we don't create a circular import
existing_relative_imports.add(os.path.splitext(entrypoint)[0])
# Whatever the entrypoint does not already import gets injected so that
# HF's module copier pulls those files along with the entrypoint.
missing_relative_imports = all_relative_imports - existing_relative_imports
add_relative_imports(
os.path.join(folder, entrypoint),
missing_relative_imports,
)
Loading