Onnx granite #2043

Open · gabe-l-hart wants to merge 6 commits into main from OnnxGranite
Conversation

@gabe-l-hart commented Oct 8, 2024

What does this PR do?

This PR adds support for models using IBM's GraniteForCausalLM architecture when converting to ONNX. The key changes are:

  • Allow users to opt into using transformers>=4.45 for onnx conversions (no longer needed)
  • Add "granite" to model configs and tasks
  • Add "granite" as a model_type that uses grouped attention

NOTE: I encountered an issue very similar to the one discussed in #1835. The root cause for me was the need to add "granite" to the list of models requiring Grouped Query Attention in modeling_decoder.py. I don't believe that is the root cause of #1835, since "llama" is already in that list, but it is likely a similar issue: the inference module using num_attention_heads where num_key_value_heads is needed.
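
As an illustration of the failure mode (a minimal sketch, not the actual modeling_decoder.py code): if the past key/value buffers are shaped with num_attention_heads instead of num_key_value_heads, GQA models such as "granite" end up with a mismatched cache shape.

# Illustrative sketch only; optimum's real logic lives in modeling_decoder.py.
def past_kv_shape(config, batch_size: int, past_seq_len: int) -> tuple[int, int, int, int]:
    """Shape of one past key/value tensor for a single decoder layer."""
    head_dim = config.hidden_size // config.num_attention_heads
    # GQA models cache only num_key_value_heads heads; using num_attention_heads
    # here is exactly the mistake that breaks "granite"-style configs.
    num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
    return (batch_size, num_kv_heads, past_seq_len, head_dim)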

Rationale

This PR specifically addresses the "GraniteForCausalLM" architecture for IBM's forthcoming Granite family of models. The current ibm/PowerLM-3b model uses this architecture and can serve as a placeholder for testing until the new models are released. The one exception is that the PowerLM model has num_attention_heads and num_key_value_heads set to match (i.e. no Grouped Query Attention), whereas the new models will use GQA (hence the change to ensure GQA is handled for "granite" at inference time).
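
For reference, the head counts can be checked straight from the hub config (a quick sketch, assuming the ibm/PowerLM-3b repo is reachable):

from transformers import AutoConfig

# PowerLM-3b currently has num_attention_heads == num_key_value_heads (no GQA);
# the newer Granite models are expected to set num_key_value_heads lower.
cfg = AutoConfig.from_pretrained("ibm/PowerLM-3b")
print(cfg.num_attention_heads, cfg.num_key_value_heads)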

Testing

When testing locally, I had the following dependency versions:

onnx==1.16.2
onnxruntime==1.19.2
torch==2.4.1
torchvision==0.19.1
transformers==4.45.2

To test the conversion, I did the following:

optimum-cli export onnx \
  --model $HOME/models/powerlm-3b \
  $HOME/models/powerlm-3b-onnx \
  --task text-generation-with-past
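
As an optional sanity check on the exported artifact (paths follow the command above; "model.onnx" is the file name optimum writes by default), the ONNX checker can be run before doing a full generation comparison:

import os

import onnx

# Passing the path (rather than a loaded ModelProto) lets the checker handle
# models stored with external data / weights larger than 2GB.
onnx_path = os.path.expanduser("~/models/powerlm-3b-onnx/model.onnx")
onnx.checker.check_model(onnx_path)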

To evaluate the output side-by-side with the source model, I used the following script:

side_by_side.py
"""
Simple function to run and time pre and post optimized models
"""

# Standard
from datetime import timedelta
import argparse
import os
import time

# Third Party
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

def maybe_to_device(inputs: dict[str, torch.Tensor], device: str | None):
    """Send inputs to the device if desired"""
    if device:
        for k, v in inputs.items():
            inputs[k] = v.to(device)

def run_and_time(
    label: str,
    model_path: str,
    model_class: type[ORTModelForCausalLM] | type[AutoModelForCausalLM],
    prompt: str,
    device: str | None,
    **kwargs,
):
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = model_class.from_pretrained(model_path, device_map=device)
    load_end_time = time.time()
    inputs = tokenizer(prompt, return_tensors="pt")

    tok_end_time = time.time()
    maybe_to_device(inputs, device)
    outputs = model.generate(**inputs, **kwargs)
    gen_end_time = time.time()
    res = tokenizer.decode(outputs[0])
    end_time = time.time()
    print(f"------ {label} ------")
    print(res)
    print(f"Total Time: {timedelta(seconds=end_time-start_time)}")
    print(f"Load Time: {timedelta(seconds=load_end_time-start_time)}")
    print(f"Generate Time: {timedelta(seconds=gen_end_time-tok_end_time)}")
    print()

# Defaults
home = os.getenv("HOME")
assert home, "Need $HOME!"
orig_model_path = f"{home}/models/PowerLM-3b"
onnx_model_path = f"{home}/models/PowerLM-3b-onnx-O4"
prompt = "Write a code to find the maximum value in a list of numbers."
device = "cuda" if torch.cuda.is_available() else None

def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--prompt", "-p", default=prompt)
    parser.add_argument("--raw-model", "-r", type=str, default=None)
    parser.add_argument("--onnx-model", "-o", type=str, default=None)
    parser.add_argument("--device", "-d", type=str, default=device)
    args = parser.parse_args()
    if args.raw_model:
        run_and_time("Transformers", args.raw_model, AutoModelForCausalLM, args.prompt, args.device, max_new_tokens=100)
    if args.onnx_model:
        run_and_time("ONNX", args.onnx_model, ORTModelForCausalLM, args.prompt, args.device, max_new_tokens=100)


if __name__ == "__main__":
    main()
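
For example, with the paths used above (adjust to your local setup), the script can be run as:

python side_by_side.py \
  --raw-model $HOME/models/powerlm-3b \
  --onnx-model $HOME/models/powerlm-3b-onnx \
  --prompt "Write a code to find the maximum value in a list of numbers."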

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • (N/A) Did you make sure to update the documentation with your changes?
    • This is a model addition and there is no model-specific documentation that I can find
  • (N/A) Did you write any new necessary tests?
    • This is a model addition and there are no model-specific tests that I can find

Who can review?

Commits (6): branch OnnxGranite, each Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@IlyasMoutawwakil (Member)

Can you add this architecture to the test suite using a small tiny-random model and push it to the Hub? (You can create it with GraniteForCausalLM.from_config().)
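
A minimal sketch of creating such a tiny-random checkpoint (all sizes and the target repo id are illustrative placeholders, not the values hf-internal-testing ended up using):

from transformers import GraniteConfig, GraniteForCausalLM

# Tiny, randomly initialized config; every size below is a placeholder.
config = GraniteConfig(
    vocab_size=1024,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,  # keep KV heads < attention heads to exercise the GQA path
    max_position_embeddings=512,
)
model = GraniteForCausalLM(config)
model.save_pretrained("tiny-random-GraniteForCausalLM")
# model.push_to_hub("<org>/tiny-random-GraniteForCausalLM")  # requires hub write access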

@xenova (Contributor) commented Oct 11, 2024

I've created a tiny-random model: https://huggingface.co/hf-internal-testing/tiny-random-GraniteForCausalLM.

Unfortunately, ONNX export currently fails

optimum-cli export onnx -m hf-internal-testing/tiny-random-GraniteForCausalLM /tmp/

with the following error:

See log
Framework not specified. Using pt to export the model.
Automatic task detection to text-generation-with-past (possible synonyms are: causal-lm-with-past).
Using the export variant default. Available variants are:
    - default: The default ONNX variant.

***** Exporting submodel 1/1: GraniteForCausalLM *****
Using framework PyTorch: 2.4.1+cu121
Overriding 1 configuration item(s)
	- use_cache -> True
We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)
/usr/local/lib/python3.10/dist-packages/transformers/cache_utils.py:447: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
/usr/local/lib/python3.10/dist-packages/transformers/models/granite/modeling_granite.py:982: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if sequence_length != 1:
/usr/local/lib/python3.10/dist-packages/transformers/cache_utils.py:432: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  elif len(self.key_cache[layer_idx]) == 0:  # fills previously skipped layers; checking for tensor causes errors
Traceback (most recent call last):
  File "/usr/local/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/optimum/commands/optimum_cli.py", line 208, in main
    service.run()
  File "/usr/local/lib/python3.10/dist-packages/optimum/commands/export/onnx.py", line 265, in run
    main_export(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/__main__.py", line 374, in main_export
    onnx_export_from_model(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 1188, in onnx_export_from_model
    _, onnx_outputs = export_models(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 782, in export_models
    export(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 887, in export
    export_output = export_pytorch(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 583, in export_pytorch
    onnx_export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 551, in export
    _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1648, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1174, in _model_to_graph
    graph = _optimize_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 714, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1997, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_helper.py", line 292, in wrapper
    return fn(g, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py", line 177, in scaled_dot_product_attention
    query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 93, in op
    return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 244, in _add_op
    inputs = [_const_if_tensor(graph_context, arg) for arg in args]
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 244, in <listcomp>
    inputs = [_const_if_tensor(graph_context, arg) for arg in args]
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 276, in _const_if_tensor
    return _add_op(graph_context, "onnx::Constant", value_z=arg)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 252, in _add_op
    node = _create_node(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 312, in _create_node
    _add_attribute(node, key, value, aten=aten)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py", line 363, in _add_attribute
    return getattr(node, f"{kind}_")(name, value)
TypeError: z_(): incompatible function arguments. The following argument types are supported:
    1. (self: torch._C.Node, arg0: str, arg1: torch.Tensor) -> torch._C.Node

Invoked with: %427 : Tensor = onnx::Constant(), scope: transformers.models.granite.modeling_granite.GraniteForCausalLM::/transformers.models.granite.modeling_granite.GraniteModel::model/transformers.models.granite.modeling_granite.GraniteDecoderLayer::layers.0/transformers.models.granite.modeling_granite.GraniteSdpaAttention::self_attn
, 'value', 1.0 
(Occurred when translating scaled_dot_product_attention).

Investigating why 🤔

@xenova (Contributor) commented Oct 11, 2024

Turns out it is a known (and fixed) bug: pytorch/pytorch#135615

Upgrading to torch nightly fixes it 👍
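
For reference, a nightly build can be installed from PyTorch's nightly wheel index (CUDA 12.1 variant shown; adjust the index URL for your platform):

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121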

@IlyasMoutawwakil (Member)

Upgrading to torch nightly fixes it 👍

Oooh great! It's the error I've been seeing when exporting CLIP with SDPA.

@gabe-l-hart (Author)

Thanks for all the quick reviews! I realized I left out a lot of context in the PR description itself, so I will update it with details of how I tested this locally.

@xenova (Contributor) commented Oct 12, 2024

And just to confirm: I tested the ONNX model with Transformers.js, and it matches the Python version exactly 👍
This PR should be good to merge once we add the tiny-random test (https://huggingface.co/hf-internal-testing/tiny-random-GraniteForCausalLM). Not exactly sure how to handle the minimum PyTorch version though (cc @IlyasMoutawwakil).

@echarlaix (Collaborator)

And just to confirm: I tested the ONNX model with Transformers.js, and it matches the Python version exactly 👍

Perfect, thanks a lot @xenova

This PR should be good to merge once we add the tiny-random test (https://huggingface.co/hf-internal-testing/tiny-random-GraniteForCausalLM). Not exactly sure how to handle the minimum PyTorch version though (cc @IlyasMoutawwakil).

For the minimum PyTorch version, it can be set with MIN_TORCH_VERSION:

MIN_TORCH_VERSION = version.parse("1.11")
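
A minimal sketch of how this could be pinned for the granite export config (class name, base class, and the version number are assumptions for illustration; the real minimum would be whichever torch release includes the SDPA export fix):

from packaging import version

from optimum.exporters.onnx.model_configs import LlamaOnnxConfig

class GraniteOnnxConfig(LlamaOnnxConfig):  # assumed name/base, for illustration only
    # pytorch/pytorch#135615 is fixed in newer torch builds; pin a version that has it.
    MIN_TORCH_VERSION = version.parse("2.5")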
