Fix dpr<>bart config for RAG (huggingface#8808)
* correct DPR test and BERT position embedding fault
* fix DPR/BERT config problem
* fix LayoutLM
* add the config attribute to DPR as well
patrickvonplaten authored and stas00 committed Dec 2, 2020
1 parent fc3d92a commit 707bfa6
Showing 7 changed files with 30 additions and 22 deletions.
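The core change, repeated across the encoder models below, is to read `position_embedding_type` defensively with `getattr` and fall back to `"absolute"`, so that configs which predate the attribute (such as `DPRConfig` before this commit, which the RAG/DPR path feeds into BERT-style modules) no longer raise an `AttributeError`. A minimal sketch of that access pattern, using hypothetical stand-in config classes rather than the actual library classes:

```python
class LegacyConfig:
    """Hypothetical config that, like the old DPRConfig, never defined
    position_embedding_type."""
    max_position_embeddings = 512


class NewConfig(LegacyConfig):
    """Hypothetical config that does define the attribute."""
    position_embedding_type = "relative_key"


def resolve_position_embedding_type(config):
    # Direct attribute access would raise AttributeError for LegacyConfig;
    # getattr with a default keeps old configs working.
    return getattr(config, "position_embedding_type", "absolute")


print(resolve_position_embedding_type(LegacyConfig()))  # -> "absolute"
print(resolve_position_embedding_type(NewConfig()))     # -> "relative_key"
```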
4 changes: 2 additions & 2 deletions src/transformers/models/albert/modeling_albert.py
@@ -214,7 +214,7 @@ def __init__(self, config):

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
- self.position_embedding_type = config.position_embedding_type
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
@@ -268,7 +268,7 @@ def __init__(self, config):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pruned_heads = set()

- self.position_embedding_type = config.position_embedding_type
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
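For context on the `2 * config.max_position_embeddings - 1` table size visible in the hunk above: with relative position embeddings, the distance between a query position and a key position in a sequence of length L ranges from -(L-1) to +(L-1), i.e. 2L-1 distinct values. A small illustrative sketch of how such a table is typically indexed in BERT-style relative attention (editor's example, not code from this diff):

```python
import torch

max_position_embeddings = 512  # matches the table size 2 * 512 - 1 = 1023 rows
seq_length = 4

position_ids_l = torch.arange(seq_length).view(-1, 1)  # query positions
position_ids_r = torch.arange(seq_length).view(1, -1)  # key positions
distance = position_ids_l - position_ids_r             # values in [-(L-1), L-1]

# Shift into [0, 2 * max_position_embeddings - 2] so the distances can index
# an nn.Embedding(2 * max_position_embeddings - 1, head_size)-style table.
indices = distance + max_position_embeddings - 1
print(indices.shape, indices.min().item(), indices.max().item())
```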
4 changes: 2 additions & 2 deletions src/transformers/models/bert/modeling_bert.py
@@ -178,7 +178,7 @@ def __init__(self, config):

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
- self.position_embedding_type = config.position_embedding_type
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
@@ -225,7 +225,7 @@ def __init__(self, config):
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
- self.position_embedding_type = config.position_embedding_type
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
9 changes: 9 additions & 0 deletions src/transformers/models/dpr/configuration_dpr.py
@@ -71,6 +71,13 @@ class DPRConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+ position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
+     Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
+     :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
+     :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
+     <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
+     `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
+     <https://arxiv.org/abs/2009.13658>`__.
projection_dim (:obj:`int`, `optional`, defaults to 0):
Dimension of the projection for the context and question encoders. If it is set to zero (default), then no
projection is done.
@@ -93,6 +100,7 @@ def __init__(
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
+ position_embedding_type="absolute",
projection_dim: int = 0,
**kwargs
):
@@ -112,3 +120,4 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.projection_dim = projection_dim
+ self.position_embedding_type = position_embedding_type
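With the two additions above, `position_embedding_type` is now both documented and stored on `DPRConfig`, so the downstream BERT-style modules can read it directly. A short usage sketch, assuming a transformers version that includes this commit:

```python
from transformers import DPRConfig

# Default stays "absolute"; the relative variants can now be requested explicitly.
config = DPRConfig(projection_dim=128, position_embedding_type="relative_key")
print(config.position_embedding_type)  # -> "relative_key"
print(config.projection_dim)           # -> 128
```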
4 changes: 2 additions & 2 deletions src/transformers/models/electra/modeling_electra.py
@@ -165,7 +165,7 @@ def __init__(self, config):

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
- self.position_embedding_type = config.position_embedding_type
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
@@ -214,7 +214,7 @@ def __init__(self, config):
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
- self.position_embedding_type = config.position_embedding_type
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
2 changes: 1 addition & 1 deletion src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -146,7 +146,7 @@ def __init__(self, config):
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
- self.position_embedding_type = config.position_embedding_type
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
4 changes: 2 additions & 2 deletions src/transformers/models/roberta/modeling_roberta.py
@@ -83,7 +83,7 @@ def __init__(self, config):

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
- self.position_embedding_type = config.position_embedding_type
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

# End copy
self.padding_idx = config.pad_token_id
@@ -162,7 +162,7 @@ def __init__(self, config):
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
- self.position_embedding_type = config.position_embedding_type
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
25 changes: 12 additions & 13 deletions tests/test_modeling_dpr.py
@@ -26,7 +26,7 @@
if is_torch_available():
import torch

- from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
+ from transformers import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
from transformers.models.dpr.modeling_dpr import (
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -104,7 +104,8 @@ def prepare_config_and_inputs(self):
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)

- config = BertConfig(
+ config = DPRConfig(
+     projection_dim=self.projection_dim,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
@@ -115,14 +116,12 @@ def prepare_config_and_inputs(self):
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
- is_decoder=False,
initializer_range=self.initializer_range,
)
- config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())

return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

- def create_and_check_dpr_context_encoder(
+ def create_and_check_context_encoder(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = DPRContextEncoder(config=config)
@@ -133,7 +132,7 @@ def create_and_check_dpr_context_encoder(
result = model(input_ids)
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))

- def create_and_check_dpr_question_encoder(
+ def create_and_check_question_encoder(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = DPRQuestionEncoder(config=config)
@@ -144,7 +143,7 @@ def create_and_check_dpr_question_encoder(
result = model(input_ids)
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))

- def create_and_check_dpr_reader(
+ def create_and_check_reader(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = DPRReader(config=config)
@@ -199,17 +198,17 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()

- def test_dpr_context_encoder_model(self):
+ def test_context_encoder_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_dpr_context_encoder(*config_and_inputs)
+ self.model_tester.create_and_check_context_encoder(*config_and_inputs)

- def test_dpr_question_encoder_model(self):
+ def test_question_encoder_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_dpr_question_encoder(*config_and_inputs)
+ self.model_tester.create_and_check_question_encoder(*config_and_inputs)

- def test_dpr_reader_model(self):
+ def test_reader_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
- self.model_tester.create_and_check_dpr_reader(*config_and_inputs)
+ self.model_tester.create_and_check_reader(*config_and_inputs)

@slow
def test_model_from_pretrained(self):
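The net effect of the test changes above is that the tester builds a `DPRConfig` directly instead of constructing a `BertConfig` and re-wrapping it, since `DPRConfig` now carries the BERT-style fields plus `projection_dim`. A hedged sketch of the resulting usage (tiny values chosen for a quick smoke test, mirroring the tester's style rather than its exact defaults):

```python
import torch
from transformers import DPRConfig, DPRQuestionEncoder

config = DPRConfig(
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=37,
    projection_dim=0,  # 0 means no projection; pooler_output keeps hidden_size
)
model = DPRQuestionEncoder(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (2, 7))
with torch.no_grad():
    output = model(input_ids)
print(output.pooler_output.shape)  # -> torch.Size([2, 32])
```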

