Skip to content

Commit

Permalink
ONNX export for custom architectures & models with custom modeling co…
Browse files Browse the repository at this point in the history
…de (#1166)

* support custom architectures

* fix doc

* add test

* fix doc

* fix typo

* fix tests

* fix style

* fix tests

* fix test
  • Loading branch information
fxmarty authored Jul 6, 2023
1 parent d5b4636 commit 94afbdf
Show file tree
Hide file tree
Showing 8 changed files with 482 additions and 94 deletions.
108 changes: 106 additions & 2 deletions docs/source/exporters/onnx/usage_guides/export_a_model.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -242,17 +242,19 @@ You can then pass one of these tasks to the `--task` argument in the `optimum-cl

## Custom export of Transformers models

### Customize the export of official Transformers models

Optimum allows for advanced users a finer-grained control over the configuration for the ONNX export. This is especially useful if you would like to export models with different keyword arguments, for example using `output_attentions=True` or `output_hidden_states=True`.

To support these use cases, [~exporters.__main__.main_export] supports two arguments: `model_kwargs` and `custom_onnx_configs`, which are used in the following fashion:
To support these use cases, [`~exporters.main_export`] supports two arguments: `model_kwargs` and `custom_onnx_configs`, which are used in the following fashion:

* `model_kwargs` allows to override some of the default arguments to the models `forward`, in practice as `model(**reference_model_inputs, **model_kwargs)`.
* `custom_onnx_configs` should be a `Dict[str, OnnxConfig]`, mapping from the submodel name (usually `model`, `encoder_model`, `decoder_model`, or `decoder_model_with_past` - [reference](https://github.com/huggingface/optimum/blob/main/optimum/exporters/onnx/constants.py)) to a custom ONNX configuration for the given submodel.

A complete example is given below, allowing to export models with `output_attentions=True`.

```python
from optimum.exporters.onnx.__main__ import main_export
from optimum.exporters.onnx import main_export
from optimum.exporters.onnx.model_configs import WhisperOnnxConfig
from transformers import AutoConfig

Expand Down Expand Up @@ -317,3 +319,105 @@ main_export(
custom_onnx_configs=custom_onnx_configs
)
```

### Customize the export of Transformers models with custom modeling

Optimum supports the export of Transformers models with custom modeling that use [`trust_remote_code=True`](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoModel.from_pretrained.trust_remote_code), not officially supported in the Transormers library but usable with its functionality as [pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) and [generation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.generate).

Examples of such models are [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) and [mosaicml/mpt-30b](https://huggingface.co/mosaicml/mpt-30b).

To export custom models, a dictionary `custom_onnx_configs` needs to be passed to [`~optimum.exporters.onnx.main_export`], with the ONNX config definition for all the subparts of the model to export (for example, encoder and decoder subparts). The example below allows to export `mosaicml/mpt-7b` model:

```python
from optimum.exporters.onnx import main_export

from transformers import AutoConfig

from optimum.exporters.onnx.config import TextDecoderOnnxConfig
from optimum.utils import NormalizedTextConfig, DummyPastKeyValuesGenerator
from typing import Dict


class MPTDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
"""
MPT swaps the two last dimensions for the key cache compared to usual transformers
decoder models, thus the redefinition here.
"""
def generate(self, input_name: str, framework: str = "pt"):
past_key_shape = (
self.batch_size,
self.num_attention_heads,
self.hidden_size // self.num_attention_heads,
self.sequence_length,
)
past_value_shape = (
self.batch_size,
self.num_attention_heads,
self.sequence_length,
self.hidden_size // self.num_attention_heads,
)
return [
(
self.random_float_tensor(past_key_shape, framework=framework),
self.random_float_tensor(past_value_shape, framework=framework),
)
for _ in range(self.num_layers)
]

class CustomMPTOnnxConfig(TextDecoderOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (MPTDummyPastKeyValuesGenerator,) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DUMMY_PKV_GENERATOR_CLASS = MPTDummyPastKeyValuesGenerator

DEFAULT_ONNX_OPSET = 14 # aten::tril operator requires opset>=14
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
hidden_size="d_model",
num_layers="n_layers",
num_attention_heads="n_heads"
)

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
"""
Adapted from https://github.com/huggingface/optimum/blob/v1.9.0/optimum/exporters/onnx/base.py#L625
"""
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
name = "present"

for i in range(self._normalized_config.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 3: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name}


model_id = "/home/fxmarty/hf_internship/optimum/tiny-mpt-random-remote-code"
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

onnx_config = CustomMPTOnnxConfig(
config=config,
task="text-generation",
use_past_in_inputs=False,
use_present_in_outputs=True,
)
onnx_config_with_past = CustomMPTOnnxConfig(config, task="text-generation", use_past=True)

custom_onnx_configs = {
"decoder_model": onnx_config,
"decoder_with_past_model": onnx_config_with_past,
}

main_export(
model_id,
output="mpt_onnx",
task="text-generation-with-past",
trust_remote_code=True,
custom_onnx_configs=custom_onnx_configs,
no_post_process=True,
)
```

Moreover, the advanced argument `fn_get_submodels` to `main_export` allows to customize how the submodels are extracted in case the model needs to be exported in several submodels. Examples of such functions can be [consulted here](link to utils.py relevant code once merged).
Loading

0 comments on commit 94afbdf

Please sign in to comment.