Avoid appending to external data when running onnx_save (#320)
* initial commit

* fix docstrings

* fix blunders in logic

---------

Co-authored-by: Michael Goin <michael@neuralmagic.com>
dbogunowicz and mgoin committed Jun 2, 2023
1 parent 960b213 commit 4486de9
Showing 1 changed file with 40 additions and 5 deletions.
45 changes: 40 additions & 5 deletions src/sparsezoo/utils/onnx.py
@@ -13,6 +13,7 @@
# limitations under the License.

import logging
import os
from typing import Any, Dict, Optional, Tuple, Union

import numpy
@@ -48,8 +49,11 @@
    "group_four_block",
    "extract_node_id",
    "get_node_attributes",
    "EXTERNAL_ONNX_DATA_NAME",
]

EXTERNAL_ONNX_DATA_NAME = "model.data"


def onnx_includes_external_data(model: ModelProto) -> bool:
    """
@@ -75,7 +79,10 @@ def onnx_includes_external_data(model: ModelProto) -> bool:


def save_onnx(
    model: ModelProto, model_path: str, external_data_file: Optional[str] = None
    model: ModelProto,
    model_path: str,
    large_model_external_data_file: str = EXTERNAL_ONNX_DATA_NAME,
    external_data_file: Optional[str] = None,
) -> bool:
    """
    Save model to the given path.
@@ -88,11 +95,21 @@ def save_onnx(
    :param model: The model to save.
    :param model_path: The path to save the model to.
    :param large_model_external_data_file: The default name to save the external
        data to if the model is too large to be saved as a single protobuf.
        If:
            - the model is too large to be saved as a single protobuf, AND
            - `external_data_file` is specified,
        then the external data of the model will be saved to `external_data_file`
        instead of `large_model_external_data_file`.
    :param external_data_file: The optional name to save the external data to.
    :return: True if the model was saved with external data, False otherwise.
    """
    if external_data_file is not None:
        _LOGGER.debug(f"Saving with external data: {external_data_file}")
        _check_for_old_external_data(
            model_path=model_path, external_data_file=external_data_file
        )
        onnx.save(
            model,
            model_path,
@@ -104,16 +121,18 @@

    if model.ByteSize() > onnx.checker.MAXIMUM_PROTOBUF:
        _LOGGER.warning(
            "The ONNX model is too large to be saved as a single protobuf."
            "Saving with external data: 'model.data'"
            "The ONNX model is too large to be saved as a single protobuf. "
            f"Saving with external data: {large_model_external_data_file}"
        )
        _check_for_old_external_data(
            model_path=model_path, external_data_file=large_model_external_data_file
        )

        onnx.save(
            model,
            model_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location="model.data",
            location=large_model_external_data_file,
        )
        return True

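For reference, a minimal usage sketch of the updated save_onnx; the model, paths, and variable names below are hypothetical and not part of this commit:

import onnx

from sparsezoo.utils.onnx import EXTERNAL_ONNX_DATA_NAME, save_onnx

# Hypothetical input model; anything whose serialized size exceeds
# onnx.checker.MAXIMUM_PROTOBUF (roughly 2 GB) takes the external-data path.
model = onnx.load("deployment/model.onnx")

# A stale "model.data" already sitting next to the target path is removed
# before saving, so the external data file is rewritten rather than appended to.
saved_with_external_data = save_onnx(model, "export/model.onnx")
if saved_with_external_data:
    print(f"Weights written to export/{EXTERNAL_ONNX_DATA_NAME}")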
@@ -566,3 +585,19 @@ def _get_node_input(
        return node.input[index]
    else:
        return default


def _check_for_old_external_data(model_path: str, external_data_file: str):
    old_external_data_file = os.path.join(
        os.path.dirname(model_path), external_data_file
    )
    if os.path.exists(old_external_data_file):
        _LOGGER.warning(
            f"Attempting to save external data for a model: {model_path} "
            f"to a directory: {os.path.dirname(model_path)} "
            f"that already contains external data file: {external_data_file}. "
            "The external data file will be overwritten."
        )
        os.remove(old_external_data_file)

    return

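The _check_for_old_external_data guard matters because onnx writes external tensors by appending to an existing location file rather than truncating it, which is the appending behavior this commit avoids. A hedged illustration of that failure mode using plain onnx.save, with hypothetical paths and illustrative (not measured) sizes:

import os

import onnx

os.makedirs("export", exist_ok=True)

# Save the same large model into the same directory twice without any cleanup.
for _ in range(2):
    model = onnx.load("deployment/model.onnx")  # hypothetical source model
    onnx.save(
        model,
        "export/model.onnx",
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location="model.data",
    )

# The second save appends the tensors to the existing "model.data", leaving it
# roughly twice the size of the weights; deleting the stale file first, as
# save_onnx now does, keeps it at the expected size.
print(os.path.getsize("export/model.data"))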