SeamlessM4Tv2ConformerEncoder does not behave as expected if gradient checkpointing is enabled #31028

Closed
2 of 4 tasks
anferico opened this issue May 25, 2024 · 8 comments · Fixed by #31945
@anferico
Contributor

anferico commented May 25, 2024

System Info

  • transformers version: 4.42.0.dev0
  • Platform: Linux-5.4.0-172-generic-x86_64-with-glibc2.17
  • Python version: 3.8.19
  • Huggingface_hub version: 0.23.1
  • Safetensors version: 0.4.3
  • Accelerate version: 0.30.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): 2.13.1 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.7.0 (cpu)
  • Jax version: 0.4.13
  • JaxLib version: 0.4.13
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@ArthurZucker @sanchit

Proposed fix:

class SeamlessM4Tv2ConformerEncoder(...):
    [...]
    def forward(...):
        [...]
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        output_attentions,    # <---------- Add this parameter
                        conv_attention_mask,  # <---------- Add this parameter         
                    )
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        output_attentions=output_attentions,
                        conv_attention_mask=conv_attention_mask,
                    )
                hidden_states = layer_outputs[0]
        [...]
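
To illustrate why the two extra arguments matter, here is a minimal torch-free sketch. `toy_checkpoint` and `ToyLayer` are hypothetical stand-ins, not transformers APIs. The sketch assumes that the checkpointing function forwards its positional arguments to the wrapped callable unchanged, and that the layer's `output_attentions` parameter defaults to `False` and sits in the position used in the fix above.

```python
# Toy stand-ins for illustration only; these are not the real transformers classes.


def toy_checkpoint(fn, *args):
    """Hypothetical stand-in for self._gradient_checkpointing_func.

    Assumption: it forwards exactly the positional arguments it receives.
    """
    return fn(*args)


class ToyLayer:
    # Assumed parameter order, mirroring the positional order in the proposed fix.
    def __call__(self, hidden_states, attention_mask=None,
                 output_attentions=False, conv_attention_mask=None):
        # Attention weights are only returned when explicitly requested.
        attn_weights = "attn" if output_attentions else None
        return hidden_states, attn_weights


layer = ToyLayer()

# Before the fix: output_attentions is never forwarded, so the default (False) applies.
before = toy_checkpoint(layer.__call__, "h", None)

# After the fix: the flag and the conv mask are forwarded positionally.
after = toy_checkpoint(layer.__call__, "h", None, True, None)

print(before[1])  # None
print(after[1])   # attn
```

Because the arguments are passed positionally, their order has to match the layer's signature, which is why `output_attentions` comes before `conv_attention_mask` in the fix.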

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Train a model that has transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2.SeamlessM4Tv2ConformerEncoder as a submodule
  2. Enable gradient checkpointing while training
  3. When calling SeamlessM4Tv2ConformerEncoder.forward(), pass output_attentions=True and return_dict=True. For example:
    encoder: SeamlessM4Tv2ConformerEncoder = ...
    output = encoder(..., output_attentions=True, return_dict=True)

Expected behavior

output.attentions should be a tuple of non-None tensors, one per encoder layer. Instead, the actual result is output.attentions = (None, None, ..., None).
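
A small check along these lines can confirm the behavior. `missing_attention_layers` is a hypothetical helper, not part of transformers, and it only assumes that `output.attentions` is a sequence with one entry per encoder layer.

```python
from typing import List, Sequence


def missing_attention_layers(attentions: Sequence[object]) -> List[int]:
    """Return the indices of layers whose attention entry is None."""
    return [i for i, attn in enumerate(attentions) if attn is None]
```

With the bug present, `missing_attention_layers(output.attentions)` would list every layer index; with the expected behavior it would return an empty list.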

@amyeroberts
Collaborator

@anferico Thanks for raising! I think you pinged the wrong Sanchit - cc @sanchit-gandhi

@ArthurZucker
Collaborator

cc @ylacombe as well! 🤗

@ylacombe
Contributor

ylacombe commented Jun 5, 2024

Hey @anferico, nice catch, would you like to open a PR to fix this?

Note that Seamless training is not supported yet in transformers, though.

@anferico
Contributor Author

anferico commented Jun 5, 2024

@ylacombe sure, I'll open a PR 👍 No worries about the support for training, as I actually have a use case where I just take the speech encoder out of SeamlessM4Tv2 and employ it in a larger model architecture

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@anferico
Contributor Author

anferico commented Jul 5, 2024

Will open a PR soon, sorry for the delay

@ylacombe
Contributor

ylacombe commented Jul 8, 2024

No worries @anferico, don't hesitate to ping me once it's done!

@anferico
Contributor Author

@ylacombe PR opened (#31945)! I also wanted to draw your attention to another issue (#31946) I opened regarding the speech encoder of SeamlessM4Tv2. It would be great if you could check it out 🙏🏼
