Skip to content

Commit

Permalink
fix yolo and text_encoder tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastien Ehrhardt committed May 3, 2024
1 parent 1215075 commit 62f8830
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,10 @@ def get_pretrained_model_and_inputs(self):
return model, inputs

def get_vision_text_model(self, vision_config, text_config):
# Run in eager mode if we are in slow mode
if _run_slow_tests:
vision_config._attn_implementation = "eager"

vision_model = ViTModel(vision_config).eval()
text_model = BertModel(text_config).eval()
return vision_model, text_model
Expand Down Expand Up @@ -397,6 +401,10 @@ def check_vision_text_output_attention(
)

def get_vision_text_model(self, vision_config, text_config):
# Run in eager mode if we are in slow mode
if _run_slow_tests:
vision_config._attn_implementation = "eager"

vision_model = DeiTModel(vision_config).eval()
text_model = RobertaModel(text_config).eval()
return vision_model, text_model
Expand Down
1 change: 1 addition & 0 deletions tests/models/yolos/test_modeling_yolos.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def test_model(self):
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
config._attn_implementation = "eager"

# in YOLOS, the seq_len is different
seq_len = self.model_tester.expected_seq_len
Expand Down

0 comments on commit 62f8830

Please sign in to comment.