Audit the onnx pathways to make them robust against >2Gb models (#1540)
* initial commit

* fix the bad rebase
dbogunowicz committed May 11, 2023
1 parent 2f86f08 commit 5a8a333
Showing 43 changed files with 154 additions and 173 deletions.
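Background, for context: a single protobuf message is capped at 2GB, so onnx.save and onnx.checker.check_model on an in-memory proto fail for larger models. This commit routes loading, saving, and validation through sparsezoo helpers, which presumably handle the large-model case via ONNX's external-data format. A minimal sketch of that mechanism using only the stock onnx API (file names are illustrative; this is not the sparsezoo implementation):

import onnx

# onnx.load follows external-data references automatically when the
# sidecar files sit next to the .onnx file.
model = onnx.load("big_model.onnx")

# onnx.save(model, path) raises once the serialized proto exceeds 2GB;
# saving with external data moves tensor payloads out of the proto.
onnx.save_model(
    model,
    "big_model_saved.onnx",
    save_as_external_data=True,       # write large initializers outside the proto
    all_tensors_to_one_file=True,     # a single sidecar file, not one per tensor
    location="big_model_saved.data",  # sidecar path, relative to the model file
)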
5 changes: 3 additions & 2 deletions src/sparseml/exporters/onnx_to_deepsparse.py
@@ -21,6 +21,7 @@

from sparseml.exporters import transforms as sparseml_transforms
from sparseml.exporters.base_exporter import BaseExporter
+ from sparsezoo import save_onnx


class ONNXToDeepsparse(BaseExporter):
@@ -109,7 +110,7 @@ def post_validate(self, model: onnx.ModelProto) -> onnx.ModelProto:

def export(self, pre_transforms_model: onnx.ModelProto, file_path: str):
if self.export_input_model or os.getenv("SAVE_PREQAT_ONNX", False):
- onnx.save(pre_transforms_model, file_path.replace(".onnx", ".preqat.onnx"))
+ save_onnx(pre_transforms_model, file_path.replace(".onnx", ".preqat.onnx"))

post_transforms_model: onnx.ModelProto = self.apply(pre_transforms_model)
- onnx.save(post_transforms_model, file_path)
+ save_onnx(post_transforms_model, file_path)
9 changes: 5 additions & 4 deletions src/sparseml/exporters/transforms/onnx_transform.py
@@ -20,7 +20,8 @@

from sparseml.exporters.transforms import BaseTransform
from sparseml.exporters.transforms.utils import MatchResult
- from sparseml.onnx.utils import ONNXGraph, check_load_model, validate_onnx_file
+ from sparseml.onnx.utils import ONNXGraph
+ from sparsezoo.utils import load_model, validate_onnx


__all__ = ["OnnxTransform"]
@@ -80,8 +81,8 @@ def pre_validate(self, model: Union[ModelProto, str]) -> ModelProto:
f"Invalid model type: {type(model)}. "
"Must be a string (path to the .onnx file) or ONNX ModelProto"
)
- model = check_load_model(model)
- validate_onnx_file(model)
+ model = load_model(model)
+ validate_onnx(model)
self._nodes_to_delete.clear()
self._nodes_to_add.clear()
self._num_matches = 0
@@ -102,5 +103,5 @@ def post_validate(self, model: ModelProto) -> ModelProto:
graph = ONNXGraph(model)
graph.delete_unused_initializers()
graph.sort_nodes_topologically()
- validate_onnx_file(model)
+ validate_onnx(model)
return model
4 changes: 2 additions & 2 deletions src/sparseml/onnx/optim/analyzer_model.py
@@ -27,7 +27,6 @@
from sparseml.onnx.utils import (
NodeShape,
calculate_flops,
- check_load_model,
extract_node_id,
extract_node_shapes,
get_kernel_shape,
@@ -38,6 +37,7 @@
is_prunable_node,
)
from sparseml.utils import clean_path, create_parent_dirs
+ from sparsezoo.utils import load_model


__all__ = ["NodeAnalyzer", "ModelAnalyzer"]
@@ -358,7 +358,7 @@ def __init__(
raise ValueError("model or nodes must be None, both cannot be passed")

if model is not None:
- model = check_load_model(model)
+ model = load_model(model)
node_shapes = extract_node_shapes(model)
self._nodes = [
NodeAnalyzer(
9 changes: 4 additions & 5 deletions src/sparseml/onnx/optim/quantization/calibration.py
@@ -25,6 +25,7 @@
import onnx

from sparseml.onnx.utils import ORTModelRunner, fold_conv_bns, get_node_output_nodes
+ from sparsezoo.utils import save_onnx, validate_onnx


__all__ = ["CalibrationSession"]
@@ -68,7 +69,7 @@ def __init__(
suffix=".onnx", delete=True
)
self._augmented_model_path = self._augmented_model_tmp_file.name
- onnx.save(self._model_augmented, self._augmented_model_path)
+ save_onnx(self._model_augmented, self._augmented_model_path)

self._sessions = {} # batch_size -> session
self._quantization_thresholds = {} # Dict[node.name, Tuple(min_val, max_val)]
@@ -101,13 +102,11 @@ def _optimize_model(self) -> Union[str, None]:
if model_optimized is None:
# no optimization performed, skip the rest of this block
raise Exception()
- onnx.checker.check_model(
-     model_optimized
- )  # should raise exception if broken
+ validate_onnx(model_optimized)  # should raise exception if broken
optimized_model_path = tempfile.NamedTemporaryFile(
suffix=".onnx", delete=False
)
- onnx.save(model_optimized, optimized_model_path.name)
+ save_onnx(model_optimized, optimized_model_path.name)
self._model = model_optimized
print("Optimization successful")
return optimized_model_path.name
src/sparseml/onnx/optim/quantization/quantize_model_post_training.py
@@ -24,6 +24,7 @@
from sparseml.onnx.optim.quantization.calibration import CalibrationSession
from sparseml.onnx.optim.quantization.quantize import QuantizationMode, quantize
from sparseml.onnx.utils import DataLoader, quantize_resnet_identity_add_inputs
+ from sparsezoo.utils import save_onnx


__all__ = ["quantize_model_post_training"]
@@ -105,4 +106,4 @@ def quantize_model_post_training(
if output_model_path is None:
return calibrated_quantized_model
else:
- onnx.save(calibrated_quantized_model, output_model_path)
+ save_onnx(calibrated_quantized_model, output_model_path)
6 changes: 3 additions & 3 deletions src/sparseml/onnx/optim/sensitivity_pruning.py
@@ -30,7 +30,6 @@
DeepSparseAnalyzeModelRunner,
DeepSparseModelRunner,
ORTModelRunner,
- check_load_model,
extract_node_id,
get_node_params,
get_prunable_nodes,
@@ -46,6 +45,7 @@
default_pruning_sparsities_perf,
)
from sparseml.utils import flatten_iterable
+ from sparsezoo.utils import load_model


_LOGGER = logging.getLogger(__name__)
@@ -142,7 +142,7 @@ def pruning_loss_sens_magnitude_iter(
:return: the analysis results for the model with an additional layer at each
iteration along with a float representing the iteration progress
"""
- model = check_load_model(model)
+ model = load_model(model)
prunable = get_prunable_nodes(model)
analysis = PruningLossSensitivityAnalysis()
num_layers = len(prunable)
@@ -251,7 +251,7 @@ def pruning_loss_sens_one_shot_iter(
:return: the sensitivity results for every node that is prunable,
yields update at each layer along with iteration progress
"""
- model = check_load_model(model)
+ model = load_model(model)
prunable_nodes = get_prunable_nodes(model)
analysis = PruningLossSensitivityAnalysis()
num_updates = len(prunable_nodes) * len(sparsity_levels) + 1
4 changes: 2 additions & 2 deletions src/sparseml/onnx/utils/data.py
@@ -25,13 +25,13 @@
from onnx import ModelProto

from sparseml.onnx.utils.helpers import (
- check_load_model,
extract_shape,
get_numpy_dtype,
model_inputs,
model_outputs,
)
from sparseml.utils import NumpyArrayBatcher, load_labeled_data
+ from sparsezoo.utils import load_model


__all__ = ["DataLoader"]
@@ -171,7 +171,7 @@ def from_model_random(
and outputs, typically the batch dimension
:return: the created DataLoader instance with the random data
"""
- model = check_load_model(model)
+ model = load_model(model)
inputs = model_inputs(model)
outputs = model_outputs(model)
data_shapes = OrderedDict(
54 changes: 6 additions & 48 deletions src/sparseml/onnx/utils/helpers.py
@@ -28,14 +28,12 @@
from onnx.helper import get_attribute_value, make_empty_tensor_value_info

from sparseml.onnx.base import require_onnxruntime
- from sparseml.utils import clean_path
+ from sparsezoo.utils import load_model, save_onnx


_LOGGER = logging.getLogger(__name__)

__all__ = [
"validate_onnx_file",
"check_load_model",
"extract_node_id",
"get_node_by_id",
"get_nodes_by_input_id",
@@ -78,46 +76,6 @@
]


- def validate_onnx_file(path: str):
-     """
-     Validate that a file at a given path is a valid ONNX model
-     :param path: the path of the file to validate
-     :raise ValueError: if not a valid ONNX model
-     """
-     try:
-         onnx_model = check_load_model(path)
-
-         if onnx_model.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF:
-             onnx.checker.check_model(onnx_model)
-         else:
-             _LOGGER.warning(
-                 "onnx check_model skipped as model exceeds maximum protobuf size of 2GB"
-             )
-
-         if not onnx_model.opset_import:
-             raise ValueError("could not parse opset_import")
-     except Exception as err:
-         raise ValueError(f"Invalid onnx model: {err}")
-
-
- def check_load_model(model: Union[str, ModelProto]) -> ModelProto:
-     """
-     Load an ONNX model from a given file path if supplied.
-     If already a model proto, then returns.
-     :param model: the model proto or path to the model ONNX file to check for loading
-     :return: the loaded ONNX ModelProto
-     """
-     if isinstance(model, ModelProto):
-         return model
-
-     if isinstance(model, str):
-         return onnx.load(clean_path(model))
-
-     raise ValueError(f"unknown type given for model: {type(model)}")


def extract_node_id(node: NodeProto) -> str:
"""
Get the node id for a given node from an ONNX model.
@@ -915,7 +873,7 @@ def get_prunable_nodes(model: Union[str, ModelProto]) -> List[Any]:
:param model: the model proto loaded from the ONNX file
:return: a list of nodes from the model proto
"""
- model = check_load_model(model)
+ model = load_model(model)
prunable_nodes = []

for node in model.graph.node:
@@ -951,7 +909,7 @@ def onnx_nodes_sparsities(
:return: a tuple containing the overall sparsity measurement for the model,
each conv or gemm node found in the model
"""
- model = check_load_model(model)
+ model = load_model(model)
node_inp_sparsities = OrderedDict() # type: Dict[str, SparsityMeasurement]
params_count = 0
params_zero_count = 0
@@ -991,7 +949,7 @@ def model_inputs(model: Union[str, ModelProto]) -> List:
to get the model inputs for
:return: the input to the model
"""
- model = check_load_model(model)
+ model = load_model(model)
inputs_all = [node.name for node in model.graph.input]
inputs_init = [node.name for node in model.graph.initializer]
input_names = list(set(inputs_all) - set(inputs_init))
@@ -1009,7 +967,7 @@ def model_outputs(model: Union[str, ModelProto]) -> List:
to get the model outputs for
:return: the output from the model
"""
- model = check_load_model(model)
+ model = load_model(model)
outputs = [node for node in model.graph.output]

return outputs
@@ -1272,4 +1230,4 @@ def override_model_input_shape(model: Union[str, onnx.ModelProto], shape: List[int]):
set_tensor_dim_shape(model.graph.input[0], dim, dim_size)

if model_path:
- onnx.save(model, model_path)
+ save_onnx(model, model_path)
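
With the two local helpers removed from helpers.py, every call site in this file follows one pattern through the sparsezoo helpers. A sketch of the round trip, with signatures inferred only from the usage visible in this diff (the comments describe expected, not verified, behavior):

from sparsezoo.utils import load_model, save_onnx, validate_onnx

model = load_model("model.onnx")  # accepts a path or an already-loaded ModelProto
validate_onnx(model)              # raises if invalid; presumably skips the in-memory
                                  # protobuf check once a model crosses the 2GB cap
# ... apply graph transforms here ...
save_onnx(model, "model_out.onnx")  # presumably falls back to external-data saving
                                    # when the proto would exceed the 2GB cap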
6 changes: 3 additions & 3 deletions src/sparseml/onnx/utils/model.py
@@ -34,13 +34,13 @@
from sparseml.onnx.utils.data import DataLoader
from sparseml.onnx.utils.graph_editor import override_model_batch_size
from sparseml.onnx.utils.helpers import (
- check_load_model,
extract_node_id,
get_node_by_id,
get_prunable_node_from_foldable,
is_foldable_node,
)
from sparsezoo import File, Model
+ from sparsezoo.utils import load_model


try:
@@ -464,7 +464,7 @@ def __init__(
import onnxruntime # import protected by @require_onnxruntime()

super().__init__(loss)
- self._model = check_load_model(model)
+ self._model = load_model(model)

if batch_size is not None:
override_model_batch_size(self._model, batch_size)
@@ -712,7 +712,7 @@ def correct_nm_analyze_model_node_ids(nm_result: Dict, model: Union[str, ModelProto]):
:param model: the onnx model proto or path to the onnx file that the
nm_result was for
"""
- model = check_load_model(model)
+ model = load_model(model)

for layer in nm_result["layer_info"]:
node_id = (
4 changes: 2 additions & 2 deletions src/sparseml/openpifpaf/export.py
@@ -19,12 +19,12 @@
import os
from typing import Optional

- import onnx
import torch

import openpifpaf
from sparseml.pytorch.optim.manager import ScheduledModifierManager
from sparseml.pytorch.utils import ModuleExporter
+ from sparsezoo.utils import validate_onnx


LOG = logging.getLogger(__name__)
@@ -115,7 +115,7 @@ def export(
input_names=["input_batch"],
output_names=[meta.name for meta in datamodule.head_metas],
)
- onnx.checker.check_model(os.path.join(save_dir, name))
+ validate_onnx(os.path.join(save_dir, name))
exporter.create_deployment_folder()


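One note on the openpifpaf hunk: the replaced call already passed a file path rather than an in-memory proto, and the stock checker treats the two differently, so the swap mainly centralizes the size handling in validate_onnx. A sketch of the distinction, using only the public onnx API:

import onnx

# Passing a ModelProto serializes it into one protobuf message, which
# raises for models over 2GB:
#     onnx.checker.check_model(large_model_proto)  # fails past 2GB

# Passing a path lets the checker validate from disk, external data
# included, which is the only way to check a model larger than 2GB:
onnx.checker.check_model("path/to/model.onnx")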
src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py
@@ -39,6 +39,7 @@
swap_node_output,
update_model_param,
)
+ from sparsezoo.utils import save_onnx


__all__ = [
@@ -1593,7 +1594,7 @@ def quantize_torch_qat_export(
graph.delete_unused_initializers()

if output_file_path:
- onnx.save(model, output_file_path)
+ save_onnx(model, output_file_path)

return model

@@ -1718,7 +1719,7 @@ def skip_onnx_input_quantize(
raise RuntimeError(optim_error_message)

if output_file_path:
- onnx.save(model, output_file_path)
+ save_onnx(model, output_file_path)


def _propagate_mobilebert_embedding_quantization(model: ModelProto):
3 changes: 2 additions & 1 deletion src/sparseml/pytorch/torch_to_onnx_exporter.py
@@ -30,6 +30,7 @@
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparseml.pytorch.utils.helpers import tensors_module_forward, tensors_to_device
from sparseml.pytorch.utils.model import is_parallel_model
+ from sparsezoo.utils import save_onnx


_LOGGER = logging.getLogger(__name__)
@@ -101,7 +102,7 @@ def post_validate(self, model: onnx.ModelProto) -> onnx.ModelProto:

def export(self, pre_transforms_model: torch.nn.Module, file_path: str):
post_transforms_model: onnx.ModelProto = self.apply(pre_transforms_model)
- onnx.save(post_transforms_model, file_path)
+ save_onnx(post_transforms_model, file_path)


class _TorchOnnxExport(BaseTransform):
