
[Bettertransformer] Transformers 4.41.0 (torch.SDPA-Bert) breaks bettertransformers Bert, but works in Transformers 4.40.2 #1902

Closed
michaelfeil opened this issue Jun 9, 2024 · 7 comments
Labels
bug Something isn't working

Comments

michaelfeil commented Jun 9, 2024

System Info

python 3.11 
Windows / WSL2
poetry venv

Who can help?

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction (minimal, reproducible, runnable)

Install torch==2.3.1 and transformers==4.41.0 (the failing version) or transformers==4.40.2 (the working version).

from optimum.bettertransformer import BetterTransformer
from transformers import AutoModel, AutoTokenizer
import torch
model_name = "michaelfeil/bge-small-en-v1.5"

model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


model = BetterTransformer.transform(model) # COMMENT OUT TO MAKE IT WORK.

model = model.cuda()

for num_sentences in [1, 10, 100]:
    sentences = [f"This is sentence number {i * 10}" for i in range(num_sentences)]
    inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(model.device)
    outputs = model(**inputs)
    print(outputs.last_hidden_state.shape)

Output:

The BetterTransformer implementation does not support padding during training, as the fused kernels do not support attention masks. Beware that passing padded batched data during training may result in unexpected outputs. Please refer to https://huggingface.co/docs/optimum/bettertransformer/overview for more details.
torch.Size([1, 7, 384])
torch.Size([10, 7, 384])
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/michael/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/michael/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/michael/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/michael/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/michael/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/michael/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/michael/infinity/bert.py", line 16, in <module>
    outputs = model(**inputs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 1137, in forward
    encoder_outputs = self.encoder(
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py", line 690, in forward
    layer_outputs = layer_module(
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/michael/infinity/libs/infinity_emb/.venv/lib/python3.10/site-packages/optimum/bettertransformer/models/encoder_models.py", line 300, in forward
    attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
RuntimeError: shape '[100, 8]' is invalid for input of size 6400
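
For what it's worth, the numbers line up with a mask-rank mismatch: with transformers 4.41 the BERT SDPA path appears to hand the encoder layer an already-extended 4D attention mask (presumably batch x 1 x seq x seq, i.e. 100 * 1 * 8 * 8 = 6400 elements), while the BetterTransformer layer still reshapes it as a 2D padding mask of shape (100, 8). A standalone sketch of just that failing reshape (the 4D layout is my assumption; only the shapes come from the error above):

import torch

# 2D padding mask the BetterTransformer encoder layer expects: (batch, seq_len)
mask_2d = torch.ones(100, 8)
print(torch.reshape(mask_2d, (mask_2d.shape[0], mask_2d.shape[-1])).shape)  # torch.Size([100, 8])

# 4D extended mask (batch, 1, seq_len, seq_len) with 100*1*8*8 = 6400 elements:
# the same reshape raises the RuntimeError shown above.
mask_4d = torch.zeros(100, 1, 8, 8)
try:
    torch.reshape(mask_4d, (mask_4d.shape[0], mask_4d.shape[-1]))
except RuntimeError as e:
    print(e)  # shape '[100, 8]' is invalid for input of size 6400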

Solution

Installing dependencies from lock file
Package operations: 0 installs, 1 update, 0 removals
  • Downgrading transformers (4.41.2 -> 4.40.2)

Works:

torch.Size([1, 7, 384])
torch.Size([10, 7, 384])
torch.Size([100, 8, 384])

Expected behavior

BetterTransformer is still ~1.5x faster than torch SDPA, so for now I'm stuck pinning huggingface transformers to <=4.40.2.

@michaelfeil (Author)

hackyon commented Jun 10, 2024

Yea, it looks like BetterTransformer might be expecting a different shape for the attention mask.

Can you try using the "eager" attention implementation with BetterTransformer to see if that fixes things?
model = AutoModel.from_pretrained(model_name, attn_implementation="eager")
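
A minimal sketch of that workaround (model_name as in the repro above; combining it with the BetterTransformer call is my addition, not something verified here):

# Sketch only: force the "eager" attention implementation so BetterTransformer
# sees the module layout it expects, then apply the transform as in the repro.
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModel

model_name = "michaelfeil/bge-small-en-v1.5"
model = AutoModel.from_pretrained(model_name, attn_implementation="eager")
model = BetterTransformer.transform(model)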

@michaelfeil (Author)

Thanks for the fast response.

Eager works, but it's a breaking change if you don't add it!
Do you think BetterTransformer could be patched to handle this?

hackyon commented Jun 11, 2024

Yea, unfortunately I think there might be cause to put the BetterTransformer optimizations directly into Transformers and deprecate BetterTransformer support for BERT. This means adding BERT here:

if hf_config.model_type in ["falcon", "gpt_bigcode", "llama", "whisper"]:

It might be better for you to just skip that BetterTransformer conversion.
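
If the conversion does get deprecated there, a caller-side guard along these lines could decide whether to apply it (hypothetical helper, not part of optimum; the version cutoff is taken from this thread):

# Hypothetical guard: only convert when the installed transformers version is
# one where the BetterTransformer BERT path is known to work.
import transformers
from packaging import version

def maybe_to_bettertransformer(model):
    if version.parse(transformers.__version__) <= version.parse("4.40.2"):
        from optimum.bettertransformer import BetterTransformer
        return BetterTransformer.transform(model)
    return model  # fall back to native SDPA attention on newer releases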

You mentioned that BetterTransformer is still 1.5x faster; where did you get that metric?

michaelfeil commented Jun 11, 2024

@hackyon It might be an unusual benchmark, but it should give a pretty good end-to-end picture of throughput.

https://github.com/michaelfeil/infinity/blob/d325d9ad66ecc8732620c92f11d9119efb1f1afa/libs/infinity_emb/Makefile#L51C1-L56C30

pip install infinity_emb[all]==0.0.42

git clone https://github.com/michaelfeil/infinity
~/infinity/libs/infinity_emb$ make benchmark_embed 

Bettertransformer (eager -> torch._transformer_encoder_layer_fwd)

infinity_emb v2 --model-id BAAI/bge-large-en-v1.5 -> Results: 35-37 sentences per second. (over 2 runs)
infinity_emb v2 --model-id BAAI/bge-small-en-v1.5 -> Results: 263-266 sentences per second. (over 2 runs)

SDPA (w/o Bettertransformer)

infinity_emb v2 --model-id BAAI/bge-large-en-v1.5 --no-bettertransformer -> Results 32-32 sentences per second (2 runs)
infinity_emb v2 --model-id BAAI/bge-small-en-v1.5 --no-bettertransformer -> Results 188-196 sentences per second (2 runs)

Result:

Please don't remove the option to use BetterTransformer! I rely on the BetterTransformer patch for BERT.
But regardless, thanks for your PR in transformers - it might save the world more energy than you will consume in your personal lifetime (GPU hours excluded), no kidding.

fxmarty (Contributor) commented Jun 24, 2024

Hi @michaelfeil, I believe what you are benefiting from is the nested tensor support in BetterTransformer, which allows speedups for batched inference by avoiding padding. This is not integrated in Transformers.

Indeed, breaking changes are not ideal; although SDPA is now supported in Transformers, the nested tensor path above is not.
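
As a toy illustration of what that nested tensor path buys (my own example, not taken from the optimum code): dense batching pads every sequence to the longest one and needs an attention mask, whereas a nested tensor stores each sequence at its own length, so the fused encoder kernels never touch padding tokens.

import torch

# Two "sentences" of different lengths with a 384-dim hidden size.
a = torch.randn(5, 384)
b = torch.randn(9, 384)

# Dense batching: both sequences are padded to length 9 and need a mask.
padded = torch.nn.utils.rnn.pad_sequence([a, b], batch_first=True)
print(padded.shape)  # torch.Size([2, 9, 384])

# Nested tensor: each sequence keeps its own length, no padding is stored.
nt = torch.nested.nested_tensor([a, b])
print(nt.is_nested)                    # True
print([t.shape for t in nt.unbind()])  # [torch.Size([5, 384]), torch.Size([9, 384])]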

@michaelfeil (Author)

"what you are benefiting from is the nested tensor support in BetterTransformer" - yes, and support for torch._transformer_encoder_layer_fwd. Closing this issue, as I can just override attn_implementation if it is going to be replaced by BetterTransformer anyway.
