Persist embedding type of BART and mBART models after resize (huggingface#32242)

* fix: persist embedding type of MBartConditionalGeneration after resize

* fix: persist embedding type of BartConditionalGeneration after resize
AbdiHaryadi authored and dataKim1201 committed Oct 7, 2024
1 parent 0bfcf0c commit ac60cce
Showing 4 changed files with 28 additions and 2 deletions.
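For context outside the diff, the snippet below mirrors what the new tests assert: with `scale_embedding=True`, the decoder's input embedding should keep its class across `resize_token_embeddings`. It is an illustrative, standalone sketch, not code from the commit; the tiny `BartConfig` values are arbitrary, and it assumes a `transformers` version that includes this fix.

```python
# Illustrative sketch only; mirrors the tests added in this commit.
from transformers import BartConfig, BartForConditionalGeneration

config = BartConfig(
    vocab_size=99,
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    scale_embedding=True,  # embeddings are multiplied by sqrt(d_model)
)
model = BartForConditionalGeneration(config)

old_type = type(model.model.decoder.embed_tokens)
model.resize_token_embeddings(new_num_tokens=config.vocab_size)
new_type = type(model.model.decoder.embed_tokens)

# Before this fix, BartModel.__init__ built `self.shared` as a plain nn.Embedding,
# so resizing could swap the scaled embedding class on the encoder/decoder.
assert old_type is new_type, f"embedding type changed: {old_type} -> {new_type}"
```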
3 changes: 2 additions & 1 deletion src/transformers/models/bart/modeling_bart.py
@@ -1431,7 +1431,8 @@ def __init__(self, config: BartConfig):
         super().__init__(config)
 
         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
-        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        self.shared = BartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
 
         self.encoder = BartEncoder(config, self.shared)
         self.decoder = BartDecoder(config, self.shared)
3 changes: 2 additions & 1 deletion src/transformers/models/mbart/modeling_mbart.py
@@ -1271,7 +1271,8 @@ def __init__(self, config: MBartConfig):
         super().__init__(config)
 
         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
-        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        self.shared = MBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
 
         self.encoder = MBartEncoder(config, self.shared)
         self.decoder = MBartDecoder(config, self.shared)
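Both model files make the same change: the shared input embedding is now created as the model's scaled word-embedding class rather than a plain `nn.Embedding`, so the class (and its `embed_scale` factor) survives `resize_token_embeddings`. As a rough mental model, a scaled word embedding behaves like the simplified sketch below; this is not the library's exact code.

```python
import torch
import torch.nn as nn


class ScaledWordEmbeddingSketch(nn.Embedding):
    """Simplified stand-in for Bart/MBartScaledWordEmbedding: a word embedding
    whose lookups are multiplied by a fixed embed_scale (sqrt(d_model) when
    config.scale_embedding is True, otherwise 1.0)."""

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return super().forward(input_ids) * self.embed_scale
```

Replacing such a module with a bare `nn.Embedding` during a resize would silently drop the scaling factor, which is exactly the regression the new tests below guard against.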
12 changes: 12 additions & 0 deletions tests/models/bart/test_modeling_bart.py
@@ -518,6 +518,18 @@ def test_generate_fp16(self):
     def test_load_save_without_tied_weights(self):
         pass
 
+    def test_resize_embeddings_persists_embeddings_type(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+
+        config.scale_embedding = True
+        model = BartForConditionalGeneration(config)
+        old_type = type(model.model.decoder.embed_tokens)
+
+        model.resize_token_embeddings(new_num_tokens=config.vocab_size)
+
+        new_type = type(model.model.decoder.embed_tokens)
+        self.assertIs(old_type, new_type)
+
 
 def assert_tensors_close(a, b, atol=1e-12, prefix=""):
     """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
12 changes: 12 additions & 0 deletions tests/models/mbart/test_modeling_mbart.py
@@ -375,6 +375,18 @@ def test_ensure_weights_are_shared(self):
     def test_load_save_without_tied_weights(self):
         pass
 
+    def test_resize_embeddings_persists_embeddings_type(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+
+        config.scale_embedding = True
+        model = MBartForConditionalGeneration(config)
+        old_type = type(model.model.decoder.embed_tokens)
+
+        model.resize_token_embeddings(new_num_tokens=config.vocab_size)
+
+        new_type = type(model.model.decoder.embed_tokens)
+        self.assertIs(old_type, new_type)
+
 
 def assert_tensors_close(a, b, atol=1e-12, prefix=""):
     """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
