Optional layers (#8961)
* Apply on BERT and ALBERT

* Update TF Bart

* Add input processing to TF BART

* Add input processing for TF CTRL

* Add input processing to TF Distilbert

* Add input processing to TF DPR

* Add input processing to TF Electra

* Add deprecated arguments

* Add input processing to TF XLM

* remove unused imports

* Add input processing to TF Funnel

* Add input processing to TF GPT2

* Add input processing to TF Longformer

* Add input processing to TF Lxmert

* Apply style

* Add input processing to TF Mobilebert

* Add input processing to TF GPT

* Add input processing to TF Roberta

* Add input processing to TF T5

* Add input processing to TF TransfoXL

* Apply style

* Rebase on master

* Fix wrong model name

* Fix BART

* Apply style

* Put the deprecated warnings in the input processing function

* Remove the unused imports

* Raise an error when len(kwargs)>0

* test ModelOutput instead of TFBaseModelOutput

* Address Patrick's comments

* Address Patrick's comments

* Add boolean processing for the inputs

* Take into account the optional layers

* Add missing/unexpected weights in the other models

* Apply style

* rename parameters

* Apply style

* Remove useless

* Remove useless

* Remove useless

* Update num parameters

* Fix tests

* Address Patrick's comment

* Remove useless attribute
jplu committed Dec 8, 2020
1 parent 9d7d000 commit bf7f79c
Showing 17 changed files with 195 additions and 98 deletions.
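The "optional layers" in the title are layers such as the pooler that some task heads no longer need. As a hedged, minimal sketch of the pattern applied in the ALBERT and BERT diffs below (the class and argument names here are illustrative only, not the library's actual code):

    import tensorflow as tf

    class MainLayerSketch(tf.keras.layers.Layer):
        # Illustrative only: the pooler is built only when requested, so heads that
        # never consume pooled output (MLM, token classification, QA) neither create
        # nor load its weights.
        def __init__(self, hidden_size, add_pooling_layer=True, **kwargs):
            super().__init__(**kwargs)
            self.pooler = (
                tf.keras.layers.Dense(hidden_size, activation="tanh", name="pooler")
                if add_pooling_layer
                else None
            )

        def call(self, sequence_output):
            # sequence_output: (batch, seq_len, hidden_size); pool on the first token.
            pooled_output = self.pooler(sequence_output[:, 0]) if self.pooler is not None else None
            return sequence_output, pooled_output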
48 changes: 32 additions & 16 deletions src/transformers/models/albert/modeling_tf_albert.py
@@ -481,18 +481,22 @@ def call(self, hidden_states):
class TFAlbertMainLayer(tf.keras.layers.Layer):
config_class = AlbertConfig

def __init__(self, config, **kwargs):
def __init__(self, config, add_pooling_layer=True, **kwargs):
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.config = config

self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
self.encoder = TFAlbertTransformer(config, name="encoder")
self.pooler = tf.keras.layers.Dense(
config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="pooler",
self.pooler = (
tf.keras.layers.Dense(
config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="pooler",
)
if add_pooling_layer
else None
)

def get_input_embeddings(self):
@@ -601,7 +605,7 @@ def call(
)

sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output[:, 0])
pooled_output = self.pooler(sequence_output[:, 0]) if self.pooler is not None else None

if not inputs["return_dict"]:
return (
Expand Down Expand Up @@ -807,6 +811,9 @@ def call(
ALBERT_START_DOCSTRING,
)
class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"predictions.decoder.weight"]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
Expand Down Expand Up @@ -914,13 +921,13 @@ def call(self, pooled_output, training: bool):

@add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):

_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions.decoder.weight"]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

self.albert = TFAlbertMainLayer(config, name="albert")
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")

def get_output_embeddings(self):
@@ -1007,6 +1014,10 @@ def call(
ALBERT_START_DOCSTRING,
)
class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"predictions"]
_keys_to_ignore_on_load_missing = [r"dropout"]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@@ -1099,14 +1110,15 @@ def call(
ALBERT_START_DOCSTRING,
)
class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):

_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
_keys_to_ignore_on_load_missing = [r"dropout"]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels

self.albert = TFAlbertMainLayer(config, name="albert")
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
@@ -1193,14 +1205,14 @@ def call(
ALBERT_START_DOCSTRING,
)
class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):

_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels

self.albert = TFAlbertMainLayer(config, name="albert")
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
@@ -1301,6 +1313,10 @@ def call(
ALBERT_START_DOCSTRING,
)
class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
_keys_to_ignore_on_load_missing = [r"dropout"]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

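The _keys_to_ignore_on_load_unexpected / _keys_to_ignore_on_load_missing entries added above are regular expressions matched against weight names when a TF model is loaded from a PyTorch checkpoint. A hedged sketch of how such patterns can be applied (this helper is hypothetical, not the actual loading code in transformers):

    import re

    def drop_authorized(weight_names, patterns):
        # Keep only the names not covered by an authorized pattern.
        return [n for n in weight_names if not any(re.search(p, n) for p in patterns)]

    names = ["albert.pooler.kernel", "predictions.decoder.weight", "albert.encoder.layer_0.kernel"]
    print(drop_authorized(names, [r"pooler", r"predictions.decoder.weight"]))
    # ['albert.encoder.layer_0.kernel']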
6 changes: 2 additions & 4 deletions src/transformers/models/bart/modeling_tf_bart.py
@@ -876,6 +876,8 @@ def call(self, input_ids, use_cache=False):
)
@keras_serializable
class TFBartModel(TFPretrainedBartModel):
base_model_prefix = "model"

def __init__(self, config: BartConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
@@ -1033,10 +1035,6 @@ def get_output_embeddings(self):
BART_START_DOCSTRING,
)
class TFBartForConditionalGeneration(TFPretrainedBartModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
]
_keys_to_ignore_on_load_unexpected = [
r"model.encoder.embed_tokens.weight",
r"model.decoder.embed_tokens.weight",
71 changes: 52 additions & 19 deletions src/transformers/models/bert/modeling_tf_bert.py
@@ -547,7 +547,7 @@ def call(self, pooled_output):
class TFBertMainLayer(tf.keras.layers.Layer):
config_class = BertConfig

def __init__(self, config, **kwargs):
def __init__(self, config, add_pooling_layer=True, **kwargs):
super().__init__(**kwargs)

self.config = config
@@ -558,7 +558,7 @@ def __init__(self, config, **kwargs):
self.return_dict = config.use_return_dict
self.embeddings = TFBertEmbeddings(config, name="embeddings")
self.encoder = TFBertEncoder(config, name="encoder")
self.pooler = TFBertPooler(config, name="pooler")
self.pooler = TFBertPooler(config, name="pooler") if add_pooling_layer else None

def get_input_embeddings(self):
return self.embeddings
@@ -663,7 +663,7 @@ def call(
)

sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

if not inputs["return_dict"]:
return (
@@ -880,6 +880,9 @@ def call(
BERT_START_DOCSTRING,
)
class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"cls.predictions.decoder.weight"]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

@@ -976,9 +979,13 @@ def call(

@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):

_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"cls.seq_relationship",
r"cls.predictions.decoder.weight",
r"nsp___cls",
]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@@ -989,7 +996,7 @@ def __init__(self, config, *inputs, **kwargs):
"bi-directional self-attention."
)

self.bert = TFBertMainLayer(config, name="bert")
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

def get_output_embeddings(self):
@@ -1068,17 +1075,21 @@ def call(


class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):

_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"cls.seq_relationship",
r"cls.predictions.decoder.weight",
r"nsp___cls",
]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

if not config.is_decoder:
logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`")

self.bert = TFBertMainLayer(config, name="bert")
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

def get_output_embeddings(self):
@@ -1165,6 +1176,9 @@ def call(
BERT_START_DOCSTRING,
)
class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"cls.predictions"]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

@@ -1262,6 +1276,10 @@ def call(
BERT_START_DOCSTRING,
)
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
_keys_to_ignore_on_load_missing = [r"dropout"]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

@@ -1353,6 +1371,10 @@ def call(
BERT_START_DOCSTRING,
)
class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
_keys_to_ignore_on_load_missing = [r"dropout"]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

@@ -1477,15 +1499,21 @@ def call(
BERT_START_DOCSTRING,
)
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):

_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"mlm___cls",
r"nsp___cls",
r"cls.predictions",
r"cls.seq_relationship",
]
_keys_to_ignore_on_load_missing = [r"dropout"]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name="bert")
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
@@ -1571,15 +1599,20 @@ def call(
BERT_START_DOCSTRING,
)
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):

_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"pooler"]
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"mlm___cls",
r"nsp___cls",
r"cls.predictions",
r"cls.seq_relationship",
]

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name="bert")
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
3 changes: 3 additions & 0 deletions src/transformers/models/electra/modeling_tf_electra.py
@@ -468,6 +468,9 @@ class TFElectraPreTrainedModel(TFPreTrainedModel):

config_class = ElectraConfig
base_model_prefix = "electra"
# When the model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"generator_lm_head.weight"]
_keys_to_ignore_on_load_missing = [r"dropout"]


@keras_serializable
2 changes: 0 additions & 2 deletions src/transformers/models/funnel/modeling_tf_funnel.py
@@ -1452,7 +1452,6 @@ def call(
return_dict=inputs["return_dict"],
training=inputs["training"],
)

last_hidden_state = outputs[0]
pooled_output = last_hidden_state[:, 0]
logits = self.classifier(pooled_output, training=inputs["training"])
@@ -1735,7 +1734,6 @@ def call(
training=training,
kwargs_call=kwargs,
)
return_dict = inputs["return_dict"] if inputs["return_dict"] is not None else self.funnel.return_dict
outputs = self.funnel(
inputs["input_ids"],
inputs["attention_mask"],
2 changes: 2 additions & 0 deletions src/transformers/models/gpt2/modeling_tf_gpt2.py
@@ -413,6 +413,8 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):

config_class = GPT2Config
base_model_prefix = "transformer"
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [r"h.\d+.attn.bias"]


@dataclass
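As a usage-level illustration (the checkpoint name is only an example, and the exact warning behaviour depends on the library version): loading a TF task head from PyTorch weights should no longer warn about the layers these lists authorize.

    from transformers import TFBertForMaskedLM

    # from_pt=True converts a PyTorch checkpoint to TF weights; the pooler and NSP
    # head present in the checkpoint are now authorized "unexpected" keys for this
    # head, so they are skipped silently instead of being reported.
    model = TFBertForMaskedLM.from_pretrained("bert-base-uncased", from_pt=True)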