Skip to content

Commit

Permalink
fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
lyaronskaya authored and Sebastien Ehrhardt committed Apr 27, 2024
1 parent d324cd0 commit 6468319
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
13 changes: 10 additions & 3 deletions src/transformers/models/vit/modeling_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,16 @@ def forward(
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)

context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, head_mask,
self.attention_probs_dropout_prob if self.training else 0.0, is_causal=False, scale=None)

context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
Expand Down
1 change: 0 additions & 1 deletion tests/models/vit/test_modeling_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
require_torch,
require_torch_accelerator,
require_torch_fp16,
require_torch_sdpa,
require_vision,
slow,
torch_device,
Expand Down

0 comments on commit 6468319

Please sign in to comment.