Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support exporting > 2GB transformer models #1514

Merged
merged 14 commits on May 11, 2023
18 changes: 15 additions & 3 deletions src/sparseml/transformers/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,13 @@
__all__ = ["export_transformer_to_onnx", "load_task_model"]

MODEL_ONNX_NAME = "model.onnx"
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
DEPLOYMENT_FILES: List[str] = [
EXTERNAL_ONNX_DATA_NAME = "model.data"
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
MANDATORY_DEPLOYMENT_FILES: List[str] = [
MODEL_ONNX_NAME,
"tokenizer.json",
"tokenizer_config.json",
"config.json",
]
OPTIONAL_DEPLOYMENT_FILES: List[str] = [EXTERNAL_ONNX_DATA_NAME, "tokenizer.json"]

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -403,7 +404,9 @@ def create_deployment_folder(

if deployment_files is None:
# set deployment files to default values
deployment_files = copy.deepcopy(DEPLOYMENT_FILES)
deployment_files = copy.deepcopy(
MANDATORY_DEPLOYMENT_FILES + OPTIONAL_DEPLOYMENT_FILES
)
if onnx_file_name != MODEL_ONNX_NAME:
# replace the default onnx model name with the custom one
deployment_files[deployment_files.index(MODEL_ONNX_NAME)] = onnx_file_name
Expand All @@ -418,6 +421,12 @@ def create_deployment_folder(
expected_file_path = os.path.join(training_directory, file_name)
deployment_file_path = os.path.join(deployment_folder_dir, file_name)
if not os.path.exists(expected_file_path):
if file_name in OPTIONAL_DEPLOYMENT_FILES:
_LOGGER.warning(
f"Optional file {file_name} not found in {training_directory}. "
f"Skipping copying to deployment folder."
)
continue
raise ValueError(
f"Attempting to copy {file_name} file from {expected_file_path},"
f"but the file does not exits. Make sure that {training_directory} "
Expand All @@ -426,6 +435,9 @@ def create_deployment_folder(
if file_name == MODEL_ONNX_NAME:
# moving onnx file from training to deployment directory
shutil.move(expected_file_path, deployment_file_path)
elif file_name == EXTERNAL_ONNX_DATA_NAME:
# moving external onnx tensors from training to deployment directory
shutil.move(expected_file_path, deployment_file_path)
else:
# copying remaining `deployment_files` from training to deployment directory
shutil.copyfile(expected_file_path, deployment_file_path)
Expand Down