[Llama + Mistral] Add attention dropout #27315

Merged · 10 commits · Nov 13, 2023
5 changes: 4 additions & 1 deletion src/transformers/models/llama/configuration_llama.py
@@ -95,7 +95,8 @@ class LlamaConfig(PretrainedConfig):
experimental feature, subject to breaking API changes in future versions.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.

attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.

```python
>>> from transformers import LlamaModel, LlamaConfig
@@ -133,6 +134,7 @@ def __init__(
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
@@ -156,6 +158,7 @@ def __init__(
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout

super().__init__(
pad_token_id=pad_token_id,
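For context on how the new config argument is used, here is a minimal usage sketch. The `0.1` value and the small model dimensions are purely illustrative; the default added by this PR is `0.0`.

```python
from transformers import LlamaConfig, LlamaModel

# Small illustrative config so the model is cheap to build;
# attention_dropout is the new argument introduced in this PR.
config = LlamaConfig(
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    attention_dropout=0.1,
)
model = LlamaModel(config)
print(model.config.attention_dropout)  # 0.1
```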
8 changes: 4 additions & 4 deletions src/transformers/models/llama/modeling_llama.py
@@ -278,6 +278,7 @@ class LlamaAttention(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
@@ -292,6 +293,7 @@ def __init__(self, config: LlamaConfig):
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
@@ -404,6 +406,7 @@ def forward(

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -489,10 +492,7 @@ def forward(
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

# TODO: llama does not have dropout in the config??
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout
dropout_rate = 0.0 if not self.training else self.attention_dropout

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
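The eager-path change boils down to a standard dropout applied to the softmaxed attention probabilities before they weight the values. Below is a standalone sketch of that pattern; the tensor names mirror the modeling code, but the shapes and the `p=0.1` rate are illustrative, not the module's real attributes.

```python
import torch
from torch import nn

bsz, num_heads, q_len, kv_len, head_dim = 1, 8, 16, 16, 64
attn_weights = torch.randn(bsz, num_heads, q_len, kv_len)
value_states = torch.randn(bsz, num_heads, kv_len, head_dim)

# upcast to fp32 for the softmax, as in the modeling code
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
# dropout is a no-op when training=False, so inference is unaffected
attn_weights = nn.functional.dropout(attn_weights, p=0.1, training=True)
attn_output = torch.matmul(attn_weights, value_states)  # (bsz, num_heads, q_len, head_dim)
```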
5 changes: 4 additions & 1 deletion src/transformers/models/mistral/configuration_mistral.py
@@ -82,7 +82,8 @@ class MistralConfig(PretrainedConfig):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention window size. If not specified, will default to `4096`.

attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.

```python
>>> from transformers import MistralModel, MistralConfig
@@ -119,6 +120,7 @@ def __init__(
tie_word_embeddings=False,
rope_theta=10000.0,
sliding_window=4096,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
@@ -139,6 +141,7 @@ def __init__(
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout

super().__init__(
pad_token_id=pad_token_id,
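As with Llama, the Mistral config now exposes the same knob; a minimal sketch (again, `0.1` is only an illustrative value):

```python
from transformers import MistralConfig

# attention_dropout defaults to 0.0; set it explicitly to enable dropout during training.
config = MistralConfig(attention_dropout=0.1)
print(config.attention_dropout)  # 0.1
```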
8 changes: 3 additions & 5 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -205,6 +205,7 @@ def __init__(self, config: MistralConfig):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.attention_dropout = config.attention_dropout

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
@@ -284,6 +285,7 @@ def forward(

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -390,11 +392,7 @@ def forward(
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

# TODO: Mistral does not have dropout in the config??
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout
dropout_rate = 0.0 if not self.training else self.attention_dropout

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
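The Flash Attention path now gates the rate on `self.training` instead of hard-coding `0.0`. The hypothetical `TinyAttention` module below is only a stand-in to show how that gating behaves under `train()`/`eval()`; it is not the real attention class.

```python
from torch import nn

class TinyAttention(nn.Module):
    """Illustrative stand-in mirroring the gating used on the Flash Attention path."""

    def __init__(self, attention_dropout: float = 0.1):
        super().__init__()
        self.attention_dropout = attention_dropout

    def dropout_rate(self) -> float:
        # same expression as the diff: no dropout at inference time
        return 0.0 if not self.training else self.attention_dropout

attn = TinyAttention()
attn.train()
print(attn.dropout_rate())  # 0.1
attn.eval()
print(attn.dropout_rate())  # 0.0
```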