Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support exporting > 2GB transformer models #1514

Merged
merged 14 commits into from
May 11, 2023
5 changes: 3 additions & 2 deletions src/sparseml/exporters/onnx_to_deepsparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from sparseml.exporters import transforms as sparseml_transforms
from sparseml.exporters.base_exporter import BaseExporter
from sparsezoo import save_onnx


class ONNXToDeepsparse(BaseExporter):
Expand Down Expand Up @@ -109,7 +110,7 @@ def post_validate(self, model: onnx.ModelProto) -> onnx.ModelProto:

def export(self, pre_transforms_model: onnx.ModelProto, file_path: str):
if self.export_input_model or os.getenv("SAVE_PREQAT_ONNX", False):
onnx.save(pre_transforms_model, file_path.replace(".onnx", ".preqat.onnx"))
save_onnx(pre_transforms_model, file_path.replace(".onnx", ".preqat.onnx"))

post_transforms_model: onnx.ModelProto = self.apply(pre_transforms_model)
onnx.save(post_transforms_model, file_path)
save_onnx(post_transforms_model, file_path)
9 changes: 5 additions & 4 deletions src/sparseml/exporters/transforms/onnx_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

from sparseml.exporters.transforms import BaseTransform
from sparseml.exporters.transforms.utils import MatchResult
from sparseml.onnx.utils import ONNXGraph, check_load_model, validate_onnx_file
from sparseml.onnx.utils import ONNXGraph
from sparsezoo.utils import load_model, validate_onnx


__all__ = ["OnnxTransform"]
Expand Down Expand Up @@ -80,8 +81,8 @@ def pre_validate(self, model: Union[ModelProto, str]) -> ModelProto:
f"Invalid model type: {type(model)}. "
"Must be a string (path to the .onnx file) or ONNX ModelProto"
)
model = check_load_model(model)
validate_onnx_file(model)
model = load_model(model)
validate_onnx(model)
self._nodes_to_delete.clear()
self._nodes_to_add.clear()
self._num_matches = 0
Expand All @@ -102,5 +103,5 @@ def post_validate(self, model: ModelProto) -> ModelProto:
graph = ONNXGraph(model)
graph.delete_unused_initializers()
graph.sort_nodes_topologically()
validate_onnx_file(model)
validate_onnx(model)
return model
4 changes: 2 additions & 2 deletions src/sparseml/onnx/optim/analyzer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from sparseml.onnx.utils import (
NodeShape,
calculate_flops,
check_load_model,
extract_node_id,
extract_node_shapes,
get_kernel_shape,
Expand All @@ -38,6 +37,7 @@
is_prunable_node,
)
from sparseml.utils import clean_path, create_parent_dirs
from sparsezoo.utils import load_model


__all__ = ["NodeAnalyzer", "ModelAnalyzer"]
Expand Down Expand Up @@ -358,7 +358,7 @@ def __init__(
raise ValueError("model or nodes must be None, both cannot be passed")

if model is not None:
model = check_load_model(model)
model = load_model(model)
node_shapes = extract_node_shapes(model)
self._nodes = [
NodeAnalyzer(
Expand Down
9 changes: 4 additions & 5 deletions src/sparseml/onnx/optim/quantization/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import onnx

from sparseml.onnx.utils import ORTModelRunner, fold_conv_bns, get_node_output_nodes
from sparsezoo.utils import save_onnx, validate_onnx


__all__ = ["CalibrationSession"]
Expand Down Expand Up @@ -68,7 +69,7 @@ def __init__(
suffix=".onnx", delete=True
)
self._augmented_model_path = self._augmented_model_tmp_file.name
onnx.save(self._model_augmented, self._augmented_model_path)
save_onnx(self._model_augmented, self._augmented_model_path)

self._sessions = {} # batch_size -> session
self._quantization_thresholds = {} # Dict[node.name, Tuple(min_val, max_val)]
Expand Down Expand Up @@ -101,13 +102,11 @@ def _optimize_model(self) -> Union[str, None]:
if model_optimized is None:
# no optimization performed, skip the rest of this block
raise Exception()
onnx.checker.check_model(
model_optimized
) # should raise exception if broken
validate_onnx(model_optimized) # should raise exception if broken
optimized_model_path = tempfile.NamedTemporaryFile(
suffix=".onnx", delete=False
)
onnx.save(model_optimized, optimized_model_path.name)
save_onnx(model_optimized, optimized_model_path.name)
self._model = model_optimized
print("Optimization successful")
return optimized_model_path.name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sparseml.onnx.optim.quantization.calibration import CalibrationSession
from sparseml.onnx.optim.quantization.quantize import QuantizationMode, quantize
from sparseml.onnx.utils import DataLoader, quantize_resnet_identity_add_inputs
from sparsezoo.utils import save_onnx


__all__ = ["quantize_model_post_training"]
Expand Down Expand Up @@ -105,4 +106,4 @@ def quantize_model_post_training(
if output_model_path is None:
return calibrated_quantized_model
else:
onnx.save(calibrated_quantized_model, output_model_path)
save_onnx(calibrated_quantized_model, output_model_path)
6 changes: 3 additions & 3 deletions src/sparseml/onnx/optim/sensitivity_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
DeepSparseAnalyzeModelRunner,
DeepSparseModelRunner,
ORTModelRunner,
check_load_model,
extract_node_id,
get_node_params,
get_prunable_nodes,
Expand All @@ -46,6 +45,7 @@
default_pruning_sparsities_perf,
)
from sparseml.utils import flatten_iterable
from sparsezoo.utils import load_model


_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -142,7 +142,7 @@ def pruning_loss_sens_magnitude_iter(
:return: the analysis results for the model with an additional layer at each
iteration along with a float representing the iteration progress
"""
model = check_load_model(model)
model = load_model(model)
prunable = get_prunable_nodes(model)
analysis = PruningLossSensitivityAnalysis()
num_layers = len(prunable)
Expand Down Expand Up @@ -251,7 +251,7 @@ def pruning_loss_sens_one_shot_iter(
:return: the sensitivity results for every node that is prunable,
yields update at each layer along with iteration progress
"""
model = check_load_model(model)
model = load_model(model)
prunable_nodes = get_prunable_nodes(model)
analysis = PruningLossSensitivityAnalysis()
num_updates = len(prunable_nodes) * len(sparsity_levels) + 1
Expand Down
4 changes: 2 additions & 2 deletions src/sparseml/onnx/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
from onnx import ModelProto

from sparseml.onnx.utils.helpers import (
check_load_model,
extract_shape,
get_numpy_dtype,
model_inputs,
model_outputs,
)
from sparseml.utils import NumpyArrayBatcher, load_labeled_data
from sparsezoo.utils import load_model


__all__ = ["DataLoader"]
Expand Down Expand Up @@ -171,7 +171,7 @@ def from_model_random(
and outputs, typically the batch dimension
:return: the created DataLoader instance with the random data
"""
model = check_load_model(model)
model = load_model(model)
inputs = model_inputs(model)
outputs = model_outputs(model)
data_shapes = OrderedDict(
Expand Down
54 changes: 6 additions & 48 deletions src/sparseml/onnx/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,12 @@
from onnx.helper import get_attribute_value, make_empty_tensor_value_info

from sparseml.onnx.base import require_onnxruntime
from sparseml.utils import clean_path
from sparsezoo.utils import load_model, save_onnx


_LOGGER = logging.getLogger(__name__)

__all__ = [
"validate_onnx_file",
"check_load_model",
"extract_node_id",
"get_node_by_id",
"get_nodes_by_input_id",
Expand Down Expand Up @@ -78,46 +76,6 @@
]


def validate_onnx_file(path: Union[str, ModelProto]):
    """
    Validate that a file at a given path is a valid ONNX model

    :param path: the path of the file to validate; an already loaded
        ModelProto is also accepted, since the argument is resolved through
        check_load_model
    :raise ValueError: if not a valid ONNX model; the triggering exception is
        attached as the cause
    """
    try:
        onnx_model = check_load_model(path)

        # check_model is only run while the serialized model is below
        # onnx.checker.MAXIMUM_PROTOBUF; larger models are skipped with a warning
        if onnx_model.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF:
            onnx.checker.check_model(onnx_model)
        else:
            _LOGGER.warning(
                "onnx check_model skipped as model exceeds maximum protobuf size of 2GB"
            )

        if not onnx_model.opset_import:
            raise ValueError("could not parse opset_import")
    except Exception as err:
        # every failure (load, checker, missing opset) is reported as a
        # ValueError; chaining keeps the original traceback for debugging
        raise ValueError(f"Invalid onnx model: {err}") from err


def check_load_model(model: Union[str, ModelProto]) -> ModelProto:
    """
    Resolve the given argument to a loaded ONNX ModelProto.
    A string is treated as a path to an ONNX file and loaded from disk;
    a ModelProto is handed back as is.

    :param model: the model proto or path to the model ONNX file to check for loading
    :return: the loaded ONNX ModelProto
    :raise ValueError: if model is neither a string nor a ModelProto
    """
    if isinstance(model, str):
        # the path is normalized through clean_path before loading
        return onnx.load(clean_path(model))

    if not isinstance(model, ModelProto):
        raise ValueError(f"unknown type given for model: {type(model)}")

    return model


def extract_node_id(node: NodeProto) -> str:
"""
Get the node id for a given node from an ONNX model.
Expand Down Expand Up @@ -915,7 +873,7 @@ def get_prunable_nodes(model: Union[str, ModelProto]) -> List[Any]:
:param model: the model proto loaded from the ONNX file
:return: a list of nodes from the model proto
"""
model = check_load_model(model)
model = load_model(model)
prunable_nodes = []

for node in model.graph.node:
Expand Down Expand Up @@ -951,7 +909,7 @@ def onnx_nodes_sparsities(
:return: a tuple containing the overall sparsity measurement for the model,
each conv or gemm node found in the model
"""
model = check_load_model(model)
model = load_model(model)
node_inp_sparsities = OrderedDict() # type: Dict[str, SparsityMeasurement]
params_count = 0
params_zero_count = 0
Expand Down Expand Up @@ -991,7 +949,7 @@ def model_inputs(model: Union[str, ModelProto]) -> List:
to get the model inputs for
:return: the input to the model
"""
model = check_load_model(model)
model = load_model(model)
inputs_all = [node.name for node in model.graph.input]
inputs_init = [node.name for node in model.graph.initializer]
input_names = list(set(inputs_all) - set(inputs_init))
Expand All @@ -1009,7 +967,7 @@ def model_outputs(model: Union[str, ModelProto]) -> List:
to get the model outputs for
:return: the output from the model
"""
model = check_load_model(model)
model = load_model(model)
outputs = [node for node in model.graph.output]

return outputs
Expand Down Expand Up @@ -1272,4 +1230,4 @@ def override_model_input_shape(model: Union[str, onnx.ModelProto], shape: List[i
set_tensor_dim_shape(model.graph.input[0], dim, dim_size)

if model_path:
onnx.save(model, model_path)
save_onnx(model, model_path)
6 changes: 3 additions & 3 deletions src/sparseml/onnx/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@
from sparseml.onnx.utils.data import DataLoader
from sparseml.onnx.utils.graph_editor import override_model_batch_size
from sparseml.onnx.utils.helpers import (
check_load_model,
extract_node_id,
get_node_by_id,
get_prunable_node_from_foldable,
is_foldable_node,
)
from sparsezoo import File, Model
from sparsezoo.utils import load_model


try:
Expand Down Expand Up @@ -464,7 +464,7 @@ def __init__(
import onnxruntime # import protected by @require_onnxruntime()

super().__init__(loss)
self._model = check_load_model(model)
self._model = load_model(model)

if batch_size is not None:
override_model_batch_size(self._model, batch_size)
Expand Down Expand Up @@ -712,7 +712,7 @@ def correct_nm_analyze_model_node_ids(nm_result: Dict, model: Union[str, ModelPr
:param model: the onnx model proto or path to the onnx file that the
nm_result was for
"""
model = check_load_model(model)
model = load_model(model)

for layer in nm_result["layer_info"]:
node_id = (
Expand Down
4 changes: 2 additions & 2 deletions src/sparseml/openpifpaf/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
import os
from typing import Optional

import onnx
import torch

import openpifpaf
from sparseml.pytorch.optim.manager import ScheduledModifierManager
from sparseml.pytorch.utils import ModuleExporter
from sparsezoo.utils import validate_onnx


LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -115,7 +115,7 @@ def export(
input_names=["input_batch"],
output_names=[meta.name for meta in datamodule.head_metas],
)
onnx.checker.check_model(os.path.join(save_dir, name))
validate_onnx(os.path.join(save_dir, name))
exporter.create_deployment_folder()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
swap_node_output,
update_model_param,
)
from sparsezoo.utils import save_onnx


__all__ = [
Expand Down Expand Up @@ -1593,7 +1594,7 @@ def quantize_torch_qat_export(
graph.delete_unused_initializers()

if output_file_path:
onnx.save(model, output_file_path)
save_onnx(model, output_file_path)

return model

Expand Down Expand Up @@ -1718,7 +1719,7 @@ def skip_onnx_input_quantize(
raise RuntimeError(optim_error_message)

if output_file_path:
onnx.save(model, output_file_path)
save_onnx(model, output_file_path)


def _propagate_mobilebert_embedding_quantization(model: ModelProto):
Expand Down
3 changes: 2 additions & 1 deletion src/sparseml/pytorch/torch_to_onnx_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparseml.pytorch.utils.helpers import tensors_module_forward, tensors_to_device
from sparseml.pytorch.utils.model import is_parallel_model
from sparsezoo.utils import save_onnx


_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -101,7 +102,7 @@ def post_validate(self, model: onnx.ModelProto) -> onnx.ModelProto:

def export(self, pre_transforms_model: torch.nn.Module, file_path: str):
post_transforms_model: onnx.ModelProto = self.apply(pre_transforms_model)
onnx.save(post_transforms_model, file_path)
save_onnx(post_transforms_model, file_path)


class _TorchOnnxExport(BaseTransform):
Expand Down
Loading