Embedding class is replaced when calling resize_token_embeddings #31835

Closed
1 of 4 tasks
TWagner2 opened this issue Jul 8, 2024 · 11 comments · Fixed by #32242
Comments

@TWagner2
Copy link

TWagner2 commented Jul 8, 2024

System Info

  • transformers version: 4.42.3
  • Platform: Linux-5.10.217-205.860.amzn2.x86_64-x86_64-with-glibc2.26
  • Python version: 3.11.9
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.1+cu121 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When calling the resize_token_embeddings method of MBart, the model's embeddings are converted from MBartScaledWordEmbedding to a regular nn.Embedding. This changes the model output because the scaling is lost. Below is a minimal script to demonstrate the issue. In particular, this means that the following series of steps leads to unexpected behavior:

  1. Load model, and resize embeddings for a training task
  2. Train model
  3. Evaluate, results look fine because we both trained and evaluated without scaling in the embeddings
  4. Save the model using save_pretrained and load it again for production (see the sketch after the script below). This will implicitly replace the embedding class with MBartScaledWordEmbedding again.
  5. Run the model. The output is now incorrect because scaling is now included, i.e. the model architecture changed.

I believe other models are also affected; compare this PR: #30410

from transformers.models.mbart import MBartConfig, MBartForCausalLM
import torch

example_input = torch.tensor([1024])

config = MBartConfig(vocab_size=50265, scale_embedding=True)
model = MBartForCausalLM(config)
# This will be an MBartScaledWordEmbedding
print(model.model.decoder.embed_tokens)
old_type = type(model.model.decoder.embed_tokens)
old_output = model.model.decoder.embed_tokens(example_input)
# After resizing, embed_tokens is replaced by a standard nn.Embedding
model.resize_token_embeddings(new_num_tokens=50266)
model.resize_token_embeddings(new_num_tokens=50265)  # resize back so only the embedding class changes, not the vocab size
print(model.model.decoder.embed_tokens)
new_type = type(model.model.decoder.embed_tokens)
new_output = model.model.decoder.embed_tokens(example_input)
print(f"Embedding type remained the same: {old_type is new_type}") # False
print(f"Output is the same: {torch.equal(old_output, new_output)}") # False

Expected behavior

The embedding does not change and keeps the scaling.

@zucchini-nlp
Member

Thanks for reporting the issue!

For context for whoever might encounter this issue: MBartScaledWordEmbedding was introduced for backwards compatibility, as we expected users to pass already scaled embeddings in earlier versions (see comment). IMO the previous behavior was a bit misleading, but we still needed to keep it.
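
For reference, the scaled embedding is essentially a thin nn.Embedding subclass that multiplies the lookup result by a constant scale factor; a minimal sketch of the idea (names here are illustrative, not the exact library code):

import torch
import torch.nn as nn

class ScaledWordEmbeddingSketch(nn.Embedding):
    # A plain nn.Embedding whose forward pass multiplies the lookup by embed_scale.
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # With scale_embedding=True, embed_scale is sqrt(d_model), so the output
        # differs from a plain (unscaled) embedding lookup.
        return super().forward(input_ids) * self.embed_scale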

A solution could be to override resize_token_embeddings for these models to account for a custom embedding module, which would be less breaking. I will add it to my todo list :)

@bayllama
Contributor

@zucchini-nlp @ArthurZucker
Can I take this issue up and raise a PR?

@zucchini-nlp
Member

@bayllama Sure. Looking forward to your PR. If you have any questions or need assistance, feel free to ask 😄

@TWagner2
Author

Thanks a lot for looking at this issue @zucchini-nlp, and thanks @bayllama for taking this up!

A solution could be to override resize_token_embeddings for these models to account for a custom embedding module, which would be less breaking.

If I may make a suggestion, in my opinion this would ideally be solved by making the resize_token_embeddings method in the base class agnostic to the exact type of embedding, instead of overriding the method. Otherwise, one runs the risk of re-introducing this bug whenever a new model is added.

@zucchini-nlp
Member

@TWagner2 yes, it would be great if you can make the resize_token_embeddings method class-agnostic and keep backwards compatibility!

@bayllama
Contributor

Thanks @zucchini-nlp and @TWagner2 for the input. I will try to make the function class-agnostic.

@bayllama
Contributor

Hi @zucchini-nlp and @TWagner2 ,

I took a first stab at making resize_token_embeddings agnostic of the embedding class, making sure it always returns the same custom embedding class in the output. I made a commit on my fork with the changes:

https://github.com/bayllama/transformers/commit/eebabe17909de773aa63e2c57d62974fe7c527be

Description of the changes:

  1. Added **kwargs to the parameters of the custom embedding class's __init__ function. This allows us to pass the superclass parameters to it if needed.
  2. Made padding_idx an Optional[int] and set its default to 0. The reason for this change is that, when creating an instance of this class in modeling_utils.py, we don't want to set a padding_idx.
  3. In modeling_utils.py, changed _get_resized_embeddings so that we capture the type of the input embedding and make sure the output is of the same type (see the sketch after this list).
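
A rough sketch of what change (3) could look like (illustrative only, not the actual commit; the helper name and details are assumptions):

import torch
import torch.nn as nn

def get_resized_embeddings_sketch(old_embeddings: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    # Build the new embedding with the same class as the old one instead of
    # hard-coding nn.Embedding, so subclasses like a scaled word embedding survive.
    old_num_tokens, embedding_dim = old_embeddings.weight.shape
    new_embeddings = type(old_embeddings)(new_num_tokens, embedding_dim, old_embeddings.padding_idx)
    new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)

    # Copy the overlapping rows of the old weight matrix into the new one.
    num_to_copy = min(old_num_tokens, new_num_tokens)
    with torch.no_grad():
        new_embeddings.weight[:num_to_copy, :] = old_embeddings.weight[:num_to_copy, :]

    # Carry over custom attributes such as embed_scale, if present.
    if hasattr(old_embeddings, "embed_scale"):
        new_embeddings.embed_scale = old_embeddings.embed_scale
    return new_embeddings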

Please take a look at the commit URL above and let me know if I am going in the right direction. Also, if this change is fine, I may need to make similar changes to 12 more modeling_{model_name}.py files and open a PR.

Thanks!

@zucchini-nlp
Member

@bayllama can you open a PR with the current changes, please? If we're happy with them, then you can modify the other models the same way.

@bayllama
Contributor

@zucchini-nlp I have made the PR: #31979

I will also need to make changes to the other models that have a custom embedding class to make sure all tests pass. Let me know.

@AbdiHaryadi
Contributor

AbdiHaryadi commented Jul 26, 2024

I also experienced this bug, but with a different model. With Transformers version 4.43.2, the issue has been solved for MBartForCausalLM (as in the previous example), but not for MBartForConditionalGeneration. Here is the slightly changed code:

from transformers.models.mbart import MBartConfig, MBartForConditionalGeneration
import torch

example_input = torch.tensor([1024])

config = MBartConfig(vocab_size=50265, scale_embedding=True)
model = MBartForConditionalGeneration(config)
# This will be an MBartScaledWordEmbedding
print(model.model.decoder.embed_tokens)
old_type = type(model.model.decoder.embed_tokens)
old_output = model.model.decoder.embed_tokens(example_input)
# After resizing, embed_tokens is replaced by a standard nn.Embedding
# model.resize_token_embeddings(new_num_tokens=50266) # This line is optional.
model.resize_token_embeddings(new_num_tokens=50265)
print(model.model.decoder.embed_tokens)
new_type = type(model.model.decoder.embed_tokens)
new_output = model.model.decoder.embed_tokens(example_input)
print(f"Embedding type remained the same: {old_type is new_type}") # False
print(f"Output is the same: {torch.equal(old_output, new_output)}") # False

I've tried to analyze the bug. What I found is that MBartModel.get_input_embeddings returns self.shared, which is of type torch.nn.Embedding, as shown in this code:

class MBartModel(MBartPreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: MBartConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        self.encoder = MBartEncoder(config, self.shared)
        self.decoder = MBartDecoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.shared

    ...

However, that type is inconsistent with MBartEncoder.embed_tokens, which is of type MBartScaledWordEmbedding, as shown in this code (I haven't analyzed MBartDecoder yet):

class MBartEncoder(MBartPreTrainedModel):
    ...
    def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
        ...
        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.embed_tokens = MBartScaledWordEmbedding(
            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
        )
        ...

That seems to be a problem, because MBartModel.get_input_embeddings only returns the standard embedding type, not the actual embedding used by MBartEncoder. Furthermore, resize_token_embeddings uses get_input_embeddings and set_input_embeddings, as shown in the PreTrainedModel._resize_token_embeddings implementation:

class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
    ...

    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
        old_embeddings = self.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
        ...
        self.set_input_embeddings(new_embeddings)

        ...

    ...

But the embedding persisted in MBartEncoder.embed_tokens is replaced by self.shared after resizing, because of the MBartModel.set_input_embeddings method, shown in this code:

class MBartModel(MBartPreTrainedModel):
    ...

    def set_input_embeddings(self, value):
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    ...

I think that, to solve this problem, instead of initializing self.shared in MBartModel as a torch.nn.Embedding, we should just use MBartScaledWordEmbedding. The initialization would be exactly the same as in MBartEncoder.__init__.
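
A sketch of what that change to MBartModel.__init__ could look like inside modeling_mbart.py, mirroring MBartEncoder.__init__ (a sketch of the proposal, not an actual patch):

import math

# Inside transformers/models/mbart/modeling_mbart.py, where MBartScaledWordEmbedding,
# MBartEncoder, MBartDecoder and MBartPreTrainedModel are already defined.
class MBartModel(MBartPreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: MBartConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        # Use the scaled embedding for self.shared so that get_input_embeddings /
        # set_input_embeddings round-trip through the same class as the encoder and decoder.
        self.shared = MBartScaledWordEmbedding(
            vocab_size, config.d_model, padding_idx, embed_scale=embed_scale
        )

        self.encoder = MBartEncoder(config, self.shared)
        self.decoder = MBartDecoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()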

Can I create another pull request for this solution, or is there any issue for this?

@zucchini-nlp
Member

@AbdiHaryadi I see, the embedding layer in MBartModel has to be the scaled embedding as well. Yes, feel free to make a PR and tag me when it's ready for review :)

I also found that these models need to be aligned with the shared embedding; if you can, add them to the PR: BartModel, SeamlessM4TForTextToText, SeamlessM4TForTextToSpeech, SeamlessM4TModel, and the same for seamless-v2.
