[GPT2] Add SDPA support #31172

Merged on Jun 19, 2024 (16 commits)
58 changes: 58 additions & 0 deletions docs/source/en/model_doc/gpt2.md
@@ -127,6 +127,64 @@ Below is an expected speedup diagram that compares pure inference time between t
<img src="https://huggingface.co/datasets/EduardoPacheco/documentation-images/resolve/main/gpt2_flash_attention_2_speedup.jpg">
</div>


## Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
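As a rough, dependency-free illustration of what the operator computes (not how PyTorch implements it), the sketch below evaluates softmax(QK^T / sqrt(d)) V for a single head in plain Python. The helper name `sdpa_reference` and the toy inputs are made up for this example, and the causal branch assumes equal query and key lengths.

```python
import math


def sdpa_reference(q, k, v, is_causal=False):
    """Single-head scaled dot-product attention on lists of row vectors.

    Illustrative reference only; assumes len(q) == len(k) when is_causal=True.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, q_row in enumerate(q):
        # Scaled dot product of this query with every key.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        if is_causal:
            # Query i may only attend to keys 0..i.
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        # Numerically stable softmax over the key axis.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value rows.
        out.append([sum(w * v_row[c] for w, v_row in zip(weights, v)) for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
causal_out = sdpa_reference(q, k, v, is_causal=True)
full_out = sdpa_reference(q, k, v)
```

With the causal flag set, the first query can only see the first key, so its output is exactly the first value row. The fused SDPA kernels compute the same quantity without materializing the full score matrix in the way the eager path does.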

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to request it explicitly.

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (rtx3080ti-16GB, PyTorch 2.2.1, OS Ubuntu 22.04) using `float16` with
[gpt2-large](https://huggingface.co/openai-community/gpt2-large), we saw the
following speedups during training and inference.

### Training
| Batch size | Seq len | Time per batch (Eager - s) | Time per batch (SDPA - s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|-----------:|--------:|----------------------------:|--------------------------:|------------:|--------------------:|-------------------:|------------------:|
| 1 | 128 | 0.039 | 0.032 | 23.042 | 3482.32 | 3494.62 | -0.352 |
| 1 | 256 | 0.073 | 0.059 | 25.15 | 3546.66 | 3552.6 | -0.167 |
| 1 | 512 | 0.155 | 0.118 | 30.96 | 4230.1 | 3665.59 | 15.4 |
| 1 | 1024 | 0.316 | 0.209 | 50.839 | 8682.26 | 4881.09 | 77.875 |
| 2 | 128 | 0.07 | 0.06 | 15.324 | 3557.8 | 3545.91 | 0.335 |
| 2 | 256 | 0.143 | 0.122 | 16.53 | 3901.5 | 3657.68 | 6.666 |
| 2 | 512 | 0.267 | 0.213 | 25.626 | 7062.21 | 4876.47 | 44.822 |
| 2 | 1024 | OOM | 0.404 | / | OOM | 8096.35 | SDPA does not OOM |
| 4 | 128 | 0.134 | 0.128 | 4.412 | 3675.79 | 3648.72 | 0.742 |
| 4 | 256 | 0.243 | 0.217 | 12.292 | 6129.76 | 4871.12 | 25.839 |
| 4 | 512 | 0.494 | 0.406 | 21.687 | 12466.6 | 8102.64 | 53.858 |
| 4 | 1024 | OOM | 0.795 | / | OOM | 14568.2 | SDPA does not OOM |

### Inference
| Batch size | Seq len | Per token latency Eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem Eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|-----------:|--------:|-----------------------------:|----------------------------:|------------:|---------------:|--------------:|--------------:|
| 1 | 128 | 7.991 | 6.968 | 14.681 | 1685.2 | 1701.32 | -0.947 |
| 1 | 256 | 8.462 | 7.199 | 17.536 | 1745.49 | 1770.78 | -1.428 |
| 1 | 512 | 8.68 | 7.853 | 10.529 | 1907.69 | 1921.29 | -0.708 |
| 1 | 768 | 9.101 | 8.365 | 8.791 | 2032.93 | 2068.12 | -1.701 |
| 2 | 128 | 9.169 | 9.001 | 1.861 | 1803.84 | 1811.4 | -0.418 |
| 2 | 256 | 9.907 | 9.78 | 1.294 | 1907.72 | 1921.44 | -0.714 |
| 2 | 512 | 11.519 | 11.644 | -1.071 | 2176.86 | 2197.75 | -0.951 |
| 2 | 768 | 13.022 | 13.407 | -2.873 | 2464.3 | 2491.06 | -1.074 |
| 4 | 128 | 10.097 | 9.831 | 2.709 | 1942.25 | 1985.13 | -2.16 |
| 4 | 256 | 11.599 | 11.398 | 1.764 | 2177.28 | 2197.86 | -0.937 |
| 4 | 512 | 14.653 | 14.45 | 1.411 | 2753.16 | 2772.57 | -0.7 |
| 4 | 768 | 17.846 | 17.617 | 1.299 | 3327.04 | 3343.97 | -0.506 |
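The tables do not state how the percentage columns were derived. The reported values are consistent with a simple relative difference measured against the SDPA number, as sketched below. The function names are illustrative, and small mismatches in the speedup column likely come from the timings being rounded to three decimals.

```python
def speedup_pct(eager_time, sdpa_time):
    # Relative speedup of SDPA over eager, in percent.
    return (eager_time / sdpa_time - 1.0) * 100.0


def memory_saving_pct(eager_mem, sdpa_mem):
    # Extra memory used by eager relative to SDPA, in percent.
    # Negative values mean SDPA used slightly more memory.
    return (eager_mem / sdpa_mem - 1.0) * 100.0


# Inference row: batch size 1, sequence length 128.
inference_speedup = speedup_pct(7.991, 6.968)
# Training row: batch size 1, sequence length 1024.
training_mem_saving = memory_saving_pct(8682.26, 4881.09)
# Training row: batch size 1, sequence length 128 (SDPA uses slightly more).
small_case_mem_saving = memory_saving_pct(3482.32, 3494.62)
```

Read this way, a memory saving of about 78% means the eager path peaked at roughly 1.78 times the SDPA peak, not that SDPA used 78% less memory.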




## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -201,6 +201,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
140 changes: 130 additions & 10 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -24,11 +24,13 @@
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -43,6 +45,7 @@
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
get_torch_version,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
@@ -558,6 +561,113 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
)


class GPT2SdpaAttention(GPT2Attention):
    """
    GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `GPT2Attention` as the weights of the module stay untouched. The only changes are in the forward pass
    to adapt to the SDPA API.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__
        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
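Why parse versions instead of comparing strings? A minimal sketch using only the standard library: `release_tuple` is a hypothetical, simplified stand-in for `packaging.version.parse` that handles plain numeric releases with an optional local suffix such as `+cu118` (pre-release tags are not handled, which is one reason the PR relies on `packaging`).

```python
def release_tuple(version_string):
    # "2.1.2+cu118" -> (2, 1, 2). Simplified: numeric releases only.
    public = version_string.split("+", 1)[0]
    return tuple(int(part) for part in public.split(".")[:3])


def require_contiguous_qkv(torch_version):
    # Same threshold as the gate in __init__: anything before 2.2.0.
    return release_tuple(torch_version) < (2, 2, 0)


# Plain string comparison orders "2.10.0" before "2.2.0", which is wrong.
naive_string_compare = "2.10.0" < "2.2.0"

needs_workaround_old = require_contiguous_qkv("2.1.2+cu118")
needs_workaround_new = require_contiguous_qkv("2.2.0")
needs_workaround_future = require_contiguous_qkv("2.10.0")
```

Comparing parsed components keeps a hypothetical future `2.10.0` on the correct side of the threshold, whereas the lexicographic comparison would wrongly re-enable the workaround.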

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if output_attentions or head_mask is not None:
            logger.warning_once(
                "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

        bsz, q_len, _ = hidden_states.size()

        # Initial attention projections
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Optional kv caching
        if layer_past is not None:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = None
        if use_cache is True:
            present = (key, value)

        # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
        if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            is_causal=is_causal,
        )

        # Reshape outputs
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.embed_dim)

        # Final projection
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, present, None
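The `is_causal` condition is easy to misread, so here is a plain-Python sketch of it. `use_is_causal` mirrors the expression in `forward`; `top_left_causal` is a hypothetical helper that reproduces the lower-triangular mask PyTorch documents as the equivalent of `is_causal=True`. With a single query and several cached keys, that mask would hide every key except the first, which is presumably why the `q_len > 1` check is there.

```python
def use_is_causal(attention_mask, q_len, is_cross_attention):
    # Mirrors the condition in GPT2SdpaAttention.forward.
    return attention_mask is None and q_len > 1 and not is_cross_attention


def top_left_causal(q_len, kv_len):
    # Boolean equivalent of ones(q_len, kv_len).tril(): query i sees keys 0..i.
    return [[j <= i for j in range(kv_len)] for i in range(q_len)]


# Prompt processing without padding: let SDPA apply causality itself.
prefill = use_is_causal(None, q_len=5, is_cross_attention=False)
# One-token decoding step with a cache: the single query may see all keys.
decode_step = use_is_causal(None, q_len=1, is_cross_attention=False)
# Cross-attention is never causal.
cross = use_is_causal(None, q_len=5, is_cross_attention=True)

# What a top-left aligned causal mask would do to a single query over 4 keys.
single_query_mask = top_left_causal(1, 4)
```

When an explicit `attention_mask` is passed, causality is expected to be encoded in that mask already, so the flag stays off in that case as well.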


class GPT2MLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
@@ -575,10 +685,7 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
return hidden_states


GPT2_ATTENTION_CLASSES = {
"eager": GPT2Attention,
"flash_attention_2": GPT2FlashAttention2,
}
GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention}
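The registry-plus-fallback pattern used here can be sketched without any PyTorch dependency. The class and function names below are hypothetical stand-ins (they are not the GPT2 classes): a dictionary maps the `attn_implementation` string to a class, and the SDPA subclass defers to its parent when asked for something the fused path cannot provide.

```python
class EagerAttention:
    def forward(self, hidden_states, output_attentions=False, head_mask=None):
        # Stand-in for the manual implementation that can return attention weights.
        return ("eager", hidden_states)


class SdpaAttention(EagerAttention):
    def forward(self, hidden_states, output_attentions=False, head_mask=None):
        # The fused path cannot expose attention weights or apply a head mask,
        # so fall back to the parent implementation in those cases.
        if output_attentions or head_mask is not None:
            return super().forward(
                hidden_states, output_attentions=output_attentions, head_mask=head_mask
            )
        return ("sdpa", hidden_states)


ATTENTION_CLASSES = {"eager": EagerAttention, "sdpa": SdpaAttention}


def build_attention(attn_implementation):
    # Look up the class by name, as a block would do at construction time.
    return ATTENTION_CLASSES[attn_implementation]()


layer = build_attention("sdpa")
fast_path = layer.forward("x")
fallback_path = layer.forward("x", output_attentions=True)
```

Because the subclass adds no parameters of its own, the same checkpoint weights load into either class, which is what makes this kind of per-call fallback possible.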


class GPT2Block(nn.Module):
@@ -674,6 +781,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPT2Block"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -1022,11 +1130,24 @@ def forward(
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)

if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds

# Attention mask.
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
Collaborator

good!

if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, -1)
if self._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if 0 in attention_mask else None
elif _use_sdpa:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(batch_size, input_shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
)
else:
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
@@ -1050,7 +1171,11 @@ def forward(
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
if self._attn_implementation != "flash_attention_2":
if _use_sdpa:
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
)
elif self._attn_implementation != "flash_attention_2":
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
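To give a feel for what a 4D mask for SDPA looks like, here is a simplified plain-Python sketch that combines a 2D padding mask with a causal constraint and a cache offset. `causal_padding_mask` is a hypothetical stand-in, not the library helper; the real `_prepare_4d_causal_attention_mask_for_sdpa` works on tensors and handles details (dtype, edge cases) that are not shown in this diff.

```python
NEG_INF = float("-inf")


def causal_padding_mask(padding_mask, q_len, past_len=0):
    """Build an additive mask shaped [batch][1][q_len][kv_len].

    padding_mask: one list per batch item, 1 = real token, 0 = padding.
    Allowed positions get 0.0, masked positions get -inf.
    """
    batch = []
    for row in padding_mask:
        kv_len = len(row)
        if kv_len != past_len + q_len:
            raise ValueError("padding mask must cover cached plus new tokens")
        per_query = []
        for i in range(q_len):
            # Query i sits at absolute position past_len + i.
            last_visible = past_len + i
            per_query.append(
                [0.0 if (j <= last_visible and row[j] == 1) else NEG_INF for j in range(kv_len)]
            )
        # The singleton dimension broadcasts over attention heads.
        batch.append([per_query])
    return batch


# Three new tokens, the last one being padding, no cache.
prefill_mask = causal_padding_mask([[1, 1, 0]], q_len=3)
# One new token on top of two cached tokens.
decode_mask = causal_padding_mask([[1, 1, 1]], q_len=1, past_len=2)
```

The additive form matters because SDPA adds a floating-point `attn_mask` to the attention scores before the softmax, so `-inf` entries end up with zero weight.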
@@ -1061,11 +1186,6 @@ def forward(
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)

if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds

if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
3 changes: 2 additions & 1 deletion tests/models/gpt2/test_modeling_gpt2.py
@@ -832,7 +832,8 @@ def test_gpt2_sample_max_time(self):
start = datetime.datetime.now()
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
Comment on lines -835 to +836
Contributor Author

This is the only test that fails without this modification. I'm not sure whether this is what was intended. Maybe `assertGreater` should be `assertLess` instead?

Contributor

I think this change is totally fine: SDPA is now used, so generation is faster.


@slow
def test_contrastive_search_gpt2(self):