[Mistral] Add Flash Attention-2 support for mistral (#26464)
* add FA-2 support for mistral

* fixup

* add sliding windows

* fixing few nits

* v1 slicing cache - logits do not match

* add comment

* fix bugs

* more mem efficient

* add warning once

* add warning once

* oops

* fixup

* more comments

* copy

* add safety checker

* fixup

* Update src/transformers/models/mistral/modeling_mistral.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* copied from

* up

* raise when padding side is right

* fixup

* add doc + few minor changes

* fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
younesbelkada and ArthurZucker committed Oct 3, 2023
1 parent 1a2e966 commit ae9a344
Showing 5 changed files with 435 additions and 5 deletions.
45 changes: 45 additions & 0 deletions docs/source/en/model_doc/mistral.md
@@ -82,6 +82,51 @@ tokenizer = LlamaTokenizer.from_pretrained("/output/path")
model = MistralForCausalLM.from_pretrained("/output/path")
```

## Combining Mistral and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the official documentation of the [`flash-attn`](https://github.com/Dao-AILab/flash-attention) repository. Make sure as well to load your model in half-precision (e.g. `torch.float16`).

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

>>> prompt = "My favourite condiment is"

>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
>>> model.to(device)

>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"The expected outupt"
```

### Expected speedups

Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the `mistralai/Mistral-7B-v0.1` checkpoint and the Flash Attention 2 version of the model.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/mistral-7b-inference-large-seqlen.png">
</div>

### Sliding window attention

The current implementation supports the sliding window attention mechanism and memory-efficient cache management.
To enable sliding window attention, make sure to have a `flash-attn` version that supports sliding window attention (`>=2.3.0`).
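
You can check the installed version with, for example:

```bash
python -c "import flash_attn; print(flash_attn.__version__)"
```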

The Flash Attention-2 model also uses a more memory-efficient cache slicing mechanism. Following the recommendation of the official Mistral implementation, which relies on a rolling cache, we keep the cache size fixed (`self.config.sliding_window`), support batched generation only with `padding_side="left"`, and use the absolute position of the current token to compute the positional embedding.
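
As an illustration of the cache slicing idea only (the tensor names and shapes below are hypothetical; the actual logic lives in `modeling_mistral.py`), here is a minimal sketch that keeps the key/value cache bounded by the sliding window:

```python
import torch

# Hypothetical cache tensors of shape (batch, num_kv_heads, seq_len, head_dim)
sliding_window = 4096  # mirrors `self.config.sliding_window` for Mistral-7B
past_key = torch.randn(1, 8, 4100, 128)
past_value = torch.randn(1, 8, 4100, 128)

# Keep only the most recent `sliding_window` positions so the cache size stays fixed
if past_key.shape[-2] > sliding_window:
    past_key = past_key[:, :, -sliding_window:, :]
    past_value = past_value[:, :, -sliding_window:, :]
```

For batched generation, remember to set `tokenizer.padding_side = "left"` before tokenizing the batch.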

## The Mistral Team

Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -32,6 +32,7 @@ Make sure to follow the installation guide on the repository mentioned above to
We natively support Flash Attention 2 for the following models:

- Llama
- Mistral
- Falcon

You can request Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens, *which is currently not supported by the `BetterTransformer` API below.*