Skip to content

Commit

Permalink
Support exporting > 2Gb transformer models (#1514)
Browse files Browse the repository at this point in the history
* initial commit

* initial commit

* Delete helpers.py

* cleanup

* fix an error in the logic

* focus on opt when it comes to tasks

* initial commit

* Delete model.py

* cleanup

* Apply suggestions from code review
  • Loading branch information
dbogunowicz committed May 11, 2023
1 parent 5a8a333 commit b96a89a
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/sparseml/transformers/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,13 @@
__all__ = ["export_transformer_to_onnx", "load_task_model"]

MODEL_ONNX_NAME = "model.onnx"
DEPLOYMENT_FILES: List[str] = [
EXTERNAL_ONNX_DATA_NAME = "model.data"
MANDATORY_DEPLOYMENT_FILES: List[str] = [
MODEL_ONNX_NAME,
"tokenizer.json",
"tokenizer_config.json",
"config.json",
]
OPTIONAL_DEPLOYMENT_FILES: List[str] = [EXTERNAL_ONNX_DATA_NAME, "tokenizer.json"]

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -403,7 +404,9 @@ def create_deployment_folder(

if deployment_files is None:
# set deployment files to default values
deployment_files = copy.deepcopy(DEPLOYMENT_FILES)
deployment_files = copy.deepcopy(
MANDATORY_DEPLOYMENT_FILES + OPTIONAL_DEPLOYMENT_FILES
)
if onnx_file_name != MODEL_ONNX_NAME:
# replace the default onnx model name with the custom one
deployment_files[deployment_files.index(MODEL_ONNX_NAME)] = onnx_file_name
Expand All @@ -418,6 +421,12 @@ def create_deployment_folder(
expected_file_path = os.path.join(training_directory, file_name)
deployment_file_path = os.path.join(deployment_folder_dir, file_name)
if not os.path.exists(expected_file_path):
if file_name in OPTIONAL_DEPLOYMENT_FILES:
_LOGGER.warning(
f"Optional file {file_name} not found in {training_directory}. "
f"Skipping copying to deployment folder."
)
continue
raise ValueError(
f"Attempting to copy {file_name} file from {expected_file_path},"
f"but the file does not exits. Make sure that {training_directory} "
Expand All @@ -426,6 +435,9 @@ def create_deployment_folder(
if file_name == MODEL_ONNX_NAME:
# moving onnx file from training to deployment directory
shutil.move(expected_file_path, deployment_file_path)
elif file_name == EXTERNAL_ONNX_DATA_NAME:
# moving external onnx tensors from training to deployment directory
shutil.move(expected_file_path, deployment_file_path)
else:
# copying remaining `deployment_files` from training to deployment directory
shutil.copyfile(expected_file_path, deployment_file_path)
Expand Down

0 comments on commit b96a89a

Please sign in to comment.