Experiment w/ the improved torch onnx exporter #1940

Draft · wants to merge 2 commits into main
2 changes: 2 additions & 0 deletions optimum/exporters/onnx/__main__.py
@@ -34,6 +34,8 @@
 
 if is_torch_available():
     import torch
+    import torch_onnx
+    torch_onnx.patch_torch(error_report=True, profile=True, dump_exported_program=True)
 
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
 
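For context, a minimal sketch of how this patching is meant to work, assuming the experimental torch_onnx package is installed (the toy model and file name below are illustrative, not from this PR): after patch_torch(), ordinary torch.onnx.export() calls are routed through the new torch.export-based exporter, and the extra flags request an error report, profiling data, and a dump of the ExportedProgram for debugging.

    import torch
    import torch_onnx

    # Reroute torch.onnx.export through the experimental exporter and ask for
    # debugging artifacts (error report, profile, dumped ExportedProgram).
    torch_onnx.patch_torch(error_report=True, profile=True, dump_exported_program=True)

    # Any later export call now exercises the patched path.
    model = torch.nn.Linear(4, 2)
    torch.onnx.export(
        model,
        (torch.randn(1, 4),),
        "linear.onnx",
        input_names=["x"],
        output_names=["y"],
    )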
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/convert.py
@@ -580,7 +580,7 @@ def remap(value):
             f=output.as_posix(),
             input_names=input_names,
             output_names=output_names,
-            dynamic_axes=dynamix_axes,
+            # dynamic_axes=dynamix_axes,
             do_constant_folding=do_constant_folding,
             opset_version=opset,
         )
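To make the disabled argument concrete: dynamic_axes maps input/output names to the axes that should remain symbolic in the exported graph, and the torch.export-based exporter derives dynamic shapes differently, which is presumably why the argument is commented out here. A hypothetical value (tensor names illustrative, not taken from this PR):

    # Hypothetical dynamic-axes mapping as passed to the TorchScript exporter:
    # axis 0 of each tensor is a symbolic batch size, axis 1 a symbolic length.
    dynamix_axes = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "sequence_length"},
    }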
45 changes: 31 additions & 14 deletions optimum/exporters/onnx/model_patcher.py
@@ -141,13 +141,18 @@ def __init__(
 
         allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
 
+        # Workaround https://github.com/pytorch/pytorch/issues/122649.
+        @torch._dynamo.assume_constant_result
+        def _config_outputs():
+            return config.outputs
+
         @functools.wraps(self.orig_forward)
         def patched_forward(*args, **kwargs):
             signature = inspect.signature(self.orig_forward)
             args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
 
             outputs = self.orig_forward(*args, **kwargs)
-
+            config_outputs = _config_outputs()
             # This code block handles different cases of the filterd_outputs input to align it with the expected
             # format of outputs. It is common for the output type of a model to vary, such as tensor, list,
             # tuple, etc. For Transformers models, the output is encapsulated in a ModelOutput object that
@@ -159,25 +164,25 @@ def patched_forward(*args, **kwargs):
                 for name, value in outputs.items():
                     onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                     if (
-                        onnx_output_name in config.outputs
+                        onnx_output_name in config_outputs
                         or (allow_past_in_outputs and name.startswith("past_key_values"))
-                        or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
+                        or any(key.startswith(onnx_output_name) for key in config_outputs.keys())
                     ):
                         filterd_outputs[name] = value
             elif isinstance(outputs, (list, tuple)):
-                outputs_list = list(config.outputs.keys())
+                outputs_list = list(config_outputs.keys())
                 dict(zip(outputs_list, outputs))
             else:
-                if len(config.outputs) > 1:
-                    num_outputs = len(config.outputs)
-                    outputs_str = ", ".join(config.outputs.keys())
+                if len(config_outputs) > 1:
+                    num_outputs = len(config_outputs)
+                    outputs_str = ", ".join(config_outputs.keys())
                     raise ValueError(
-                        f"config.outputs should have only one outputs, but it has {num_outputs} keys: {outputs_str}"
+                        f"config_outputs should have only one outputs, but it has {num_outputs} keys: {outputs_str}"
                     )
                 else:
-                    name = list(config.outputs.keys())[0]
+                    name = list(config_outputs.keys())[0]
                     filterd_outputs[name] = outputs
-                name = list(config.outputs.keys())[0]
+                name = list(config_outputs.keys())[0]
                 filterd_outputs[name] = outputs
             return filterd_outputs
 
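The workaround above uses torch._dynamo.assume_constant_result, which makes dynamo call the decorated function once and bake its return value into the traced graph as a constant rather than tracing into it; this sidesteps pytorch/pytorch#122649, where tracing the closure over the ONNX config fails. A self-contained sketch of the same trick (toy names, not from this PR):

    import torch

    CONFIG_OUTPUTS = {"logits": {0: "batch_size"}}  # stand-in for config.outputs

    # Dynamo treats the result as a compile-time constant instead of tracing
    # the dict lookup, so compilation does not graph-break here.
    @torch._dynamo.assume_constant_result
    def _config_outputs():
        return CONFIG_OUTPUTS

    @torch.compile(fullgraph=True)
    def forward(x):
        outputs = _config_outputs()
        return x * 2 if "logits" in outputs else x

    print(forward(torch.ones(2)))  # tensor([2., 2.])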
@@ -223,21 +228,27 @@ def __init__(
         if model.config.model_type == "pix2struct" and allow_past_in_outputs:
             model.config.text_config.use_cache = True
 
+        # Workaround https://github.com/pytorch/pytorch/issues/122649.
+        @torch._dynamo.assume_constant_result
+        def _config_outputs():
+            return config.outputs
+
         @functools.wraps(self.orig_forward)
         def patched_forward(*args, **kwargs):
             signature = inspect.signature(self.orig_forward)
             args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
 
             outputs = self.orig_forward(*args, **kwargs)
+            config_outputs = _config_outputs()
 
             # Filter out cross attention past key values output from the decoder using KV cache, as they are constants.
             filterd_outputs = {}
             for name, value in outputs.items():
                 onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                 if (
-                    onnx_output_name in config.outputs
+                    onnx_output_name in config_outputs
                     or (allow_past_in_outputs and name.startswith("past_key_values"))
-                    or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
+                    or any(key.startswith(onnx_output_name) for key in config_outputs.keys())
                 ):
                     if name != "past_key_values":
                         if self.real_config._behavior == "decoder" and name == "encoder_last_hidden_state":
@@ -473,6 +484,11 @@ def __init__(
 
         allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
 
+        # Workaround https://github.com/pytorch/pytorch/issues/122649.
+        @torch._dynamo.assume_constant_result
+        def _config_outputs():
+            return config.outputs
+
         @functools.wraps(self.orig_forward)
         def patched_forward(*args, **kwargs):
             model_kwargs = self.model_kwargs
@@ -484,14 +500,15 @@ def patched_forward(*args, **kwargs):
             args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=model_kwargs)
 
             outputs = self.orig_forward(*args, **kwargs)
+            config_outputs = _config_outputs()
 
             filterd_outputs = {}
             for name, value in outputs.items():
                 onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
                 if (
-                    onnx_output_name in config.outputs
+                    onnx_output_name in config_outputs
                     or (allow_past_in_outputs and name.startswith("past_key_values"))
-                    or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
+                    or any(key.startswith(onnx_output_name) for key in config_outputs.keys())
                 ):
                     filterd_outputs[name] = value
             return filterd_outputs
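Stepping back, the filtering loop shared by all three patched forwards can be read in isolation. A plain-Python sketch with made-up names (the real keys come from the ONNX export config) showing which outputs survive:

    # Keep an output if its ONNX-mapped name is declared by the config, if it is
    # a KV-cache entry and the config uses past key/values, or if a declared
    # output name starts with the mapped name.
    torch_to_onnx_output_map = {"last_hidden_state": "logits"}      # illustrative
    config_outputs = {"logits": {}, "past_key_values.0.key": {}}    # illustrative
    allow_past_in_outputs = True

    outputs = {"last_hidden_state": "<tensor>", "attentions": "<tensor>"}

    filterd_outputs = {}
    for name, value in outputs.items():
        onnx_output_name = torch_to_onnx_output_map.get(name, name)
        if (
            onnx_output_name in config_outputs
            or (allow_past_in_outputs and name.startswith("past_key_values"))
            or any(key.startswith(onnx_output_name) for key in config_outputs)
        ):
            filterd_outputs[name] = value

    print(filterd_outputs)  # {'last_hidden_state': '<tensor>'}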