Skip to content

Commit

Permalink
[refactor] break ONNX utils/onnx.py into multiple files (#353)
Browse files Browse the repository at this point in the history
* [refactor] break ONNX utils/onnx.py into multiple files

breaking out `onnx.py` into separate util files for analysis and
external data saving as the file is getting quite long and the
functionalities are separate.  Keeping under same `sparsezoo.utils.onnx`
namespace so all existing imports should be in tact

**test_plan:**
existing unit tests

* fix utils/onnx init file
  • Loading branch information
bfineran committed Aug 29, 2023
1 parent 1f94aa9 commit dc603b1
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 164 deletions.
19 changes: 19 additions & 0 deletions src/sparsezoo/utils/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
# isort: skip_file

from .analysis import *
from .external_data import *
165 changes: 1 addition & 164 deletions src/sparsezoo/utils/onnx.py → src/sparsezoo/utils/onnx/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Any, Dict, Optional, Tuple, Union

import numpy
import onnx
from onnx import ModelProto, NodeProto, numpy_helper
from onnx import NodeProto, numpy_helper
from onnx.helper import get_attribute_value

from sparsezoo.utils import ONNXGraph
from sparsezoo.utils.helpers import clean_path


_LOGGER = logging.getLogger(__name__)

__all__ = [
"onnx_includes_external_data",
"save_onnx",
"validate_onnx",
"load_model",
"get_layer_and_op_counts",
"get_node_four_block_sparsity",
"get_node_num_four_block_zeros_and_size",
Expand All @@ -49,145 +39,8 @@
"group_four_block",
"extract_node_id",
"get_node_attributes",
"EXTERNAL_ONNX_DATA_NAME",
]

EXTERNAL_ONNX_DATA_NAME = "model.data"


def onnx_includes_external_data(model: ModelProto) -> bool:
"""
Check whether the ModelProto in memory includes the external
data or not.
If the model.onnx does not contain the external data, then the
initializers of the model are pointing to the external data file
(they are not empty)
:param model: the ModelProto to check
:return True if the model was loaded with external data, False otherwise.
"""

initializers = model.graph.initializer

is_data_saved_to_disk = any(
initializer.external_data for initializer in initializers
)
is_data_included_in_model = not is_data_saved_to_disk

return is_data_included_in_model


def save_onnx(
model: ModelProto,
model_path: str,
external_data_file: Optional[str] = None,
) -> bool:
"""
Save model to the given path.
If the model's size is larger than the maximum protobuf size:
- it will be saved with external data
If the model's size is smaller than the maximum protobuf size:
- and the user nevertheless specifies 'external_data_file',
the model will be saved with external data.
:param model: The model to save.
:param model_path: The path to save the model to.
:param external_data_file: The optional name save the external data to. Must be
relative to the directory `model` is saved to. If the model is too
large to be saved as a single protobuf, and this argument is None,
the external data file will be coerced to take the default name
specified in the variable EXTERNAL_ONNX_DATA_NAME
:return True if the model was saved with external data, False otherwise.
"""
if external_data_file is not None:
_LOGGER.debug(f"Saving with external data: {external_data_file}")
_check_for_old_external_data(
model_path=model_path, external_data_file=external_data_file
)
onnx.save(
model,
model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=external_data_file,
)
return True

if model.ByteSize() > onnx.checker.MAXIMUM_PROTOBUF:
external_data_file = external_data_file or EXTERNAL_ONNX_DATA_NAME
_LOGGER.warning(
"The ONNX model is too large to be saved as a single protobuf. "
f"Saving with external data: {external_data_file}"
)
_check_for_old_external_data(
model_path=model_path, external_data_file=external_data_file
)
onnx.save(
model,
model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=external_data_file,
)
return True

onnx.save(model, model_path)
return False


def validate_onnx(model: Union[str, ModelProto]):
"""
Validate that a file at a given path is a valid ONNX model.
Raises a ValueError if not a valid ONNX model.
:param model: the model proto or path to the model
ONNX file to check for validation
"""
try:
onnx_model = load_model(model)
if onnx_model.ByteSize() > onnx.checker.MAXIMUM_PROTOBUF:
if isinstance(model, str):
onnx.checker.check_model(model)
else:
_LOGGER.warning(
"Attempting to validate an in-memory ONNX model with "
f"size > {onnx.checker.MAXIMUM_PROTOBUF} bytes."
"`validate_onnx` skipped, as large ONNX models cannot "
"be validated in-memory. To validate this model, save "
"it to disk and call `validate_onnx` on the file path."
)
return
onnx.checker.check_model(onnx_model)
except Exception as err:
if not onnx_includes_external_data(model):
_LOGGER.warning(
"Attempting to validate an in-memory ONNX model "
"that has been loaded without external data. "
"This is currently not supported by the ONNX checker. "
"The validation will be skipped."
)
return
raise ValueError(f"Invalid onnx model: {err}")


def load_model(model: Union[str, ModelProto]) -> ModelProto:
"""
Load an ONNX model from an onnx model file path. If a ModelProto
is given, then it is returned.
:param model: the model proto or path to the model ONNX file to check for loading
:return: the loaded ONNX ModelProto
"""
if isinstance(model, ModelProto):
return model

if isinstance(model, str):
return onnx.load(clean_path(model))

raise ValueError(f"unknown type given for model: {type(model)}")


def get_node_attributes(node: NodeProto) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -585,19 +438,3 @@ def _get_node_input(
return node.input[index]
else:
return default


def _check_for_old_external_data(model_path: str, external_data_file: str):
old_external_data_file = os.path.join(
os.path.dirname(model_path), external_data_file
)
if os.path.exists(old_external_data_file):
_LOGGER.warning(
f"Attempting to save external data for a model: {model_path} "
f"to a directory:{os.path.dirname(model_path)} "
f"that already contains external data file: {external_data_file}. "
"The external data file will be overwritten."
)
os.remove(old_external_data_file)

return
Loading

0 comments on commit dc603b1

Please sign in to comment.