Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dpr<>bart config for RAG #8808

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/transformers/models/albert/modeling_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def __init__(self, config):

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = config.position_embedding_type
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
Expand Down Expand Up @@ -268,7 +268,7 @@ def __init__(self, config):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pruned_heads = set()

self.position_embedding_type = config.position_embedding_type
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/bert/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __init__(self, config):

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = config.position_embedding_type
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add a comment that mentions backward compatibility ?


def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
Expand Down Expand Up @@ -225,7 +225,7 @@ def __init__(self, config):
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = config.position_embedding_type
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/models/dpr/configuration_dpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ class DPRConfig(PretrainedConfig):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.
projection_dim (:obj:`int`, `optional`, defaults to 0):
Dimension of the projection for the context and question encoders. If it is set to zero (default), then no
projection is done.
Expand All @@ -93,6 +100,7 @@ def __init__(
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
position_embedding_type="absolute",
projection_dim: int = 0,
**kwargs
):
Expand All @@ -112,3 +120,4 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.projection_dim = projection_dim
self.position_embedding_type = position_embedding_type
4 changes: 2 additions & 2 deletions src/transformers/models/electra/modeling_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def __init__(self, config):

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = config.position_embedding_type
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
Expand Down Expand Up @@ -214,7 +214,7 @@ def __init__(self, config):
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = config.position_embedding_type
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/layoutlm/modeling_layoutlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __init__(self, config):
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = config.position_embedding_type
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/roberta/modeling_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(self, config):

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = config.position_embedding_type
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

# End copy
self.padding_idx = config.pad_token_id
Expand Down Expand Up @@ -162,7 +162,7 @@ def __init__(self, config):
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = config.position_embedding_type
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
Expand Down
25 changes: 12 additions & 13 deletions tests/test_modeling_dpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
if is_torch_available():
import torch

from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
from transformers import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
from transformers.models.dpr.modeling_dpr import (
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
Expand Down Expand Up @@ -104,7 +104,8 @@ def prepare_config_and_inputs(self):
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)

config = BertConfig(
Copy link
Contributor Author

@patrickvonplaten patrickvonplaten Nov 26, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is dangerous when new params are added to Bert, it'll lead to silent errors

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch

config = DPRConfig(
projection_dim=self.projection_dim,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
Expand All @@ -115,14 +116,12 @@ def prepare_config_and_inputs(self):
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
)
config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())

return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

def create_and_check_dpr_context_encoder(
def create_and_check_context_encoder(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = DPRContextEncoder(config=config)
Expand All @@ -133,7 +132,7 @@ def create_and_check_dpr_context_encoder(
result = model(input_ids)
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))

def create_and_check_dpr_question_encoder(
def create_and_check_question_encoder(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = DPRQuestionEncoder(config=config)
Expand All @@ -144,7 +143,7 @@ def create_and_check_dpr_question_encoder(
result = model(input_ids)
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))

def create_and_check_dpr_reader(
def create_and_check_reader(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = DPRReader(config=config)
Expand Down Expand Up @@ -199,17 +198,17 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()

def test_dpr_context_encoder_model(self):
def test_context_encoder_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dpr_context_encoder(*config_and_inputs)
self.model_tester.create_and_check_context_encoder(*config_and_inputs)

def test_dpr_question_encoder_model(self):
def test_question_encoder_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dpr_question_encoder(*config_and_inputs)
self.model_tester.create_and_check_question_encoder(*config_and_inputs)

def test_dpr_reader_model(self):
def test_reader_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dpr_reader(*config_and_inputs)
self.model_tester.create_and_check_reader(*config_and_inputs)

@slow
def test_model_from_pretrained(self):
Expand Down