Refactor the quantization modification logic #2233

Merged
merged 11 commits into main from feature/damian/modifications_bug on Apr 29, 2024

Conversation

dbogunowicz
Contributor

@dbogunowicz dbogunowicz commented Apr 9, 2024

Feature Description

The sparseml.transformers.sparsification.modification package contains a set of modifications applied to some transformer models to make them compatible with our quantization flows.

This PR moves the modification step away from SparseAutoModel and into the QuantizationModifier. This means we only modify the model when necessary, i.e. when we apply the quantization structure. This helps us avoid clashes with the new behavior in transformers, where models are by default initialized with SDPA attention.

Notable changes

  • moving the high-level modification logic to sparseml.modifiers.quantization, keeping only the transformers-specific logic in the original directory
  • only the models supported by the new, "post-refactor" quantization modifiers "participate" in the modification (see the sketch after this list). If the user wishes to initialize one of the old models, the modification happens on the initialization of SparseAutoModel.
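The dispatch described above can be pictured with a minimal sketch: a registry of transformers-specific modifications that the quantization modifier consults right before applying the quantization structure. Everything below except the modify_model(...) name is an illustrative assumption, not the actual sparseml API.

```python
# Hypothetical sketch of the refactor described above -- not the actual
# sparseml implementation.

_MODIFICATION_REGISTRY = {}  # model class name -> modification function


def register_modification(model_class_name):
    """Register a transformers-specific modification for one supported model type."""
    def decorator(fn):
        _MODIFICATION_REGISTRY[model_class_name] = fn
        return fn
    return decorator


def modify_model(model):
    """Apply a registered modification, if any.

    Called from the quantization modifier when the quantization structure is
    applied, instead of unconditionally from SparseAutoModel, so unsupported
    models are left untouched at load time.
    """
    modification = _MODIFICATION_REGISTRY.get(type(model).__name__)
    return modification(model) if modification is not None else model
```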

To the best of my knowledge, the failing tests are unrelated to the contents of this PR.

Legacy PR description

Keeping the original PR message (the analysis of the problem that forced me to go down the final path present in this PR) for posterity, as it contains a lot of useful context:

As reported by @Satrat, after upgrading the transformers version we did not see the expected training speedups during, e.g., sparse fine-tuning. It turned out this was caused by the modify_model(...) function being called during the initialization of SparseAutoModelForCausalLM.

Let's explain what was happening using Llama as an example.

As of transformers==4.39.1, the model can be initialized with three types of attention (a sketch of selecting each explicitly follows the list):

  • LlamaSdpaAttention - the default in the current transformers version; it uses the CUDA-optimized torch.nn.functional.scaled_dot_product_attention for faster attention computation
  • LlamaAttention - the default before the transformers upgrade; this is the attention type that uses the torch.matmul method and is therefore the one we modify through modify_model(...).
  • LlamaFlashAttention - irrelevant in the context of this write-up.
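For context, a minimal sketch of how each variant can be requested explicitly at load time; the attn_implementation argument is available in recent transformers releases (around 4.39), and the checkpoint name below is only a placeholder.

```python
# Selecting the attention implementation explicitly (transformers ~4.39).
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint

# Default path: SDPA attention (LlamaSdpaAttention), backed by the fused
# torch.nn.functional.scaled_dot_product_attention kernel.
sdpa_model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="sdpa"
)

# Eager path: LlamaAttention, the torch.matmul-based implementation that the
# modification logic targets.
eager_model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="eager"
)
```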

The current, erroneous behavior was the following:

  1. We were initializing the model with the default SDPA attention.
  2. Because of a misuse of the isinstance() method, we were overriding this attention module's forward method, effectively replacing the SDPA attention's forward method with the original attention class's forward method.
  3. This is what was slowing down the training of the Llama model: unknowingly, we were no longer using the fast CUDA-optimized attention but the traditional, torch-based attention computation.

This PR hardens the modification logic: it uses the more restrictive type() check instead of isinstance() to pick the correct attention type to modify. Now there is no difference in iterations per second when sparse fine-tuning with or without the modify_model(...) function.
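A minimal sketch of the difference, assuming the Llama classes from transformers ~4.39 (the helper below is illustrative, not the actual sparseml code): LlamaSdpaAttention subclasses LlamaAttention, so an isinstance() check matches both classes, while an exact type() comparison matches only the eager attention.

```python
# Illustrative only -- not the actual sparseml code. As of transformers ~4.39,
# LlamaSdpaAttention is a subclass of LlamaAttention.
from transformers.models.llama.modeling_llama import LlamaAttention


def swap_attention_forward(model, modified_forward):
    for module in model.modules():
        # Buggy check: isinstance(module, LlamaAttention) is also True for
        # LlamaSdpaAttention, so the fast SDPA forward gets silently replaced.
        #
        # Hardened check: an exact type match leaves subclasses alone.
        if type(module) is LlamaAttention:
            # Bind the replacement function as a method on this module.
            module.forward = modified_forward.__get__(module, type(module))
```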

bfineran previously approved these changes Apr 15, 2024
@dbogunowicz dbogunowicz changed the title [Fix] Remove hidden issue in modification repo that causes training slowdown Refactor the quantization modification logic Apr 16, 2024
Contributor

@Satrat Satrat left a comment

LGTM, but two main thoughts: can we document this environment variable somewhere with its intended usage? And could we add a test script to this PR that demonstrates the speed issue being fixed?

@bfineran bfineran merged commit 7cd2feb into main Apr 29, 2024
16 of 17 checks passed
@bfineran bfineran deleted the feature/damian/modifications_bug branch April 29, 2024 17:11