Skip to content

Commit

Permalink
Template updates (#6914)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger committed Sep 3, 2020
1 parent ea2c6f1 commit 722b580
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 29 deletions.
1 change: 0 additions & 1 deletion templates/adding_a_new_model/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ You can then finish the addition step by adding imports for your classes in the
- [ ] Add your configuration in `configuration_auto.py`.
- [ ] Add your PyTorch and TF 2.0 model respectively in `modeling_auto.py` and `modeling_tf_auto.py`.
- [ ] Add your tokenizer in `tokenization_auto.py`.
- [ ] Add your models and tokenizer to `pipeline.py`.
- [ ] Add a link to your conversion script in the main conversion utility (in `commands/convert.py`)
- [ ] Edit the PyTorch to TF 2.0 conversion script to add your model in the `convert_pytorch_checkpoint_to_tf2.py`
file.
Expand Down
30 changes: 17 additions & 13 deletions templates/adding_a_new_model/tests/test_modeling_tf_xxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import unittest

from transformers import XxxConfig, is_tf_available
from transformers.testing_utils import CACHE_DIR, require_tf, slow

from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow


if is_tf_available():
Expand Down Expand Up @@ -137,7 +137,7 @@ def prepare_config_and_inputs(self):

return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

def create_and_check_xxx_model(
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFXxxModel(config=config)
Expand All @@ -154,15 +154,15 @@ def create_and_check_xxx_model(
)
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

def create_and_check_xxx_for_masked_lm(
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFXxxForMaskedLM(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

def create_and_check_xxx_for_sequence_classification(
def create_and_check_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
Expand All @@ -171,7 +171,7 @@ def create_and_check_xxx_for_sequence_classification(
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

def create_and_check_bert_for_multiple_choice(
def create_and_check_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
Expand All @@ -187,7 +187,7 @@ def create_and_check_bert_for_multiple_choice(
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

def create_and_check_xxx_for_token_classification(
def create_and_check_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
Expand All @@ -196,7 +196,7 @@ def create_and_check_xxx_for_token_classification(
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

def create_and_check_xxx_for_question_answering(
def create_and_check_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFXxxForQuestionAnswering(config=config)
Expand Down Expand Up @@ -226,25 +226,29 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()

def test_xxx_model(self):
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xxx_model(*config_and_inputs)
self.model_tester.create_and_check_model(*config_and_inputs)

def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xxx_for_masked_lm(*config_and_inputs)
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xxx_for_question_answering(*config_and_inputs)
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xxx_for_sequence_classification(*config_and_inputs)
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xxx_for_token_classification(*config_and_inputs)
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

@slow
def test_model_from_pretrained(self):
Expand Down
51 changes: 36 additions & 15 deletions templates/adding_a_new_model/tests/test_modeling_xxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, require_torch_and_cuda, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, require_torch_and_cuda, slow, torch_device


if is_torch_available():
Expand All @@ -29,6 +29,7 @@
AutoTokenizer,
XxxConfig,
XxxForMaskedLM,
XxxForMultipleChoice,
XxxForQuestionAnswering,
XxxForSequenceClassification,
XxxForTokenClassification,
Expand Down Expand Up @@ -126,7 +127,7 @@ def prepare_config_and_inputs(self):

return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

def create_and_check_xxx_model(
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = XxxModel(config=config)
Expand All @@ -138,18 +139,16 @@ def create_and_check_xxx_model(
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

def create_and_check_xxx_for_masked_lm(
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = XxxForMaskedLM(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

def create_and_check_xxx_for_question_answering(
def create_and_check_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = XxxForQuestionAnswering(config=config)
Expand All @@ -165,7 +164,7 @@ def create_and_check_xxx_for_question_answering(
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

def create_and_check_xxx_for_sequence_classification(
def create_and_check_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
Expand All @@ -175,7 +174,7 @@ def create_and_check_xxx_for_sequence_classification(
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

def create_and_check_xxx_for_token_classification(
def create_and_check_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
Expand All @@ -185,6 +184,24 @@ def create_and_check_xxx_for_token_classification(
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

def create_and_check_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = XxxForMultipleChoice(config=config)
model.to(torch_device)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
result = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
Expand Down Expand Up @@ -216,25 +233,29 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()

def test_xxx_model(self):
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xxx_model(*config_and_inputs)
self.model_tester.create_and_check_model(*config_and_inputs)

def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xxx_for_masked_lm(*config_and_inputs)
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xxx_for_question_answering(*config_and_inputs)
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xxx_for_sequence_classification(*config_and_inputs)
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xxx_for_token_classification(*config_and_inputs)
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_multiple_choice(*config_and_inputs)

@slow
def test_lm_outputs_same_as_reference_model(self):
Expand Down

0 comments on commit 722b580

Please sign in to comment.