Fix FA2 integration (huggingface#28142)
* fix fa2

* fix FA2 for popular models

* improve warning and add Younes as co-author

Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix the warning

* Add Tip

* typo fix

* nit

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
3 people authored and staghado committed Jan 15, 2024
1 parent 816ed2b commit 2810abc
Showing 7 changed files with 15 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/model_doc/llama2.md
@@ -67,6 +67,8 @@ come in several checkpoints they each contain a part of each weight of the model

- The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string.

+- When using Flash Attention 2 via `attn_implementation="flash_attention_2"`, don't pass `torch_dtype` to the `from_pretrained` class method; use Automatic Mixed-Precision training instead. When using `Trainer`, simply set either `fp16` or `bf16` to `True`. Otherwise, make sure you are using `torch.autocast`. This is required because Flash Attention only supports the `fp16` and `bf16` data types (see the sketch after this file's diff).


## Resources

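A minimal sketch of the tip above, assuming a CUDA machine with `flash-attn` installed; the checkpoint and prompt are illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Note: no torch_dtype here; the weights load in float32 and AMP supplies bf16 at runtime.
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="flash_attention_2"
).to("cuda")

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")

# Flash Attention 2 only runs in fp16/bf16, so wrap the forward pass in autocast.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = model.generate(**inputs, max_new_tokens=20)

print(tokenizer.decode(output[0], skip_special_tokens=True))
```

With `Trainer`, setting `bf16=True` (or `fp16=True`) in `TrainingArguments` enables the same autocast behaviour during training.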
6 changes: 3 additions & 3 deletions src/transformers/modeling_utils.py
@@ -1419,9 +1419,9 @@ def _check_and_enable_flash_attn_2(
"You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
)
elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(
f"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed {torch_dtype}, this might lead to"
" unexpected behaviour."
logger.warning(
"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. "
"No dtype was provided, you should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator."
)

# The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
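To illustrate the behaviour change, a sketch of the newly tolerated call; the checkpoint name is illustrative, and before this commit the unsupported dtype raised an error:

```python
import torch
from transformers import AutoModelForCausalLM

# After this commit, an unsupported dtype only triggers the warning above,
# leaving the user free to apply Automatic Mixed-Precision at runtime.
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",                      # illustrative checkpoint
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float32,               # unsupported by FA2: warns, no longer raises
)
```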
2 changes: 2 additions & 0 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -620,6 +620,8 @@ def forward(
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = self.query_key_value.weight.dtype

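The added branch prefers the active autocast dtype over the weight dtype, so float32 checkpoints resolve the right target under AMP. A standalone sketch of that logic (assumes a CUDA build, since `torch.autocast` disables itself when `device_type="cuda"` has no backing device):

```python
import torch

layer = torch.nn.Linear(8, 8)  # float32 weights, like an unconverted checkpoint

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # Mirrors the diff: inside autocast, the autocast dtype wins over the weight dtype.
    if torch.is_autocast_enabled():
        target_dtype = torch.get_autocast_gpu_dtype()
    else:
        target_dtype = layer.weight.dtype

print(target_dtype)  # torch.bfloat16 under autocast; torch.float32 otherwise
```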
2 changes: 2 additions & 0 deletions src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -378,6 +378,8 @@ def forward(
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = self.c_attn.weight.dtype

2 changes: 2 additions & 0 deletions src/transformers/models/llama/modeling_llama.py
@@ -531,6 +531,8 @@ def forward(
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = self.q_proj.weight.dtype

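Downstream of this hunk (code not shown in the diff), attention states that arrive in float32, for example after an upcasting layer norm, are cast to `target_dtype` before the Flash Attention call. A simplified sketch, with the cast and the `cast_for_fa2` helper assumed from that surrounding context:

```python
import torch

def cast_for_fa2(q, k, v, target_dtype):
    # Assumed from the surrounding forward pass: float32 inputs are cast
    # down to the resolved target dtype before flash attention runs.
    if q.dtype == torch.float32:
        q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
    return q, k, v

q = torch.randn(1, 8, 16, 64)  # float32 by default
k, v = torch.randn_like(q), torch.randn_like(q)
q, k, v = cast_for_fa2(q, k, v, torch.bfloat16)
print(q.dtype)  # torch.bfloat16
```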
2 changes: 2 additions & 0 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -431,6 +431,8 @@ def forward(
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = self.q_proj.weight.dtype

2 changes: 2 additions & 0 deletions src/transformers/models/mixtral/modeling_mixtral.py
@@ -479,6 +479,8 @@ def forward(
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = self.q_proj.weight.dtype

