
Implement fused modules #747

Merged · 23 commits merged into axolotl-ai-cloud:main on Oct 21, 2023

Conversation

@casper-hansen (Collaborator) commented Oct 19, 2023

There are two common ways to fuse layers in Llama/Mistral-type models. Speed and memory were measured on an RTX 3090 with TinyLlama 1B.

  • MLP: Merge gate_proj and up_proj together (see the FusedMLP sketch below)
    • Handle splitting the SwiGLU weights back into the original modules post-training
  • Attention: Merge q, k, v together (see the QKV fusion sketch after this list)
    • Handle replacing the attention module better; currently this only works with sample packing off
    • Fix the loss being much higher
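
A minimal sketch of the QKV-fusion idea, assuming a Llama-style attention block with bias-free q_proj/k_proj/v_proj linears (the class and attribute names here are illustrative, not necessarily the ones added by this PR):

```python
import torch
import torch.nn as nn


class FusedQKV(nn.Module):
    """Replace separate q/k/v projections with a single, wider matmul."""

    def __init__(self, q_proj: nn.Linear, k_proj: nn.Linear, v_proj: nn.Linear):
        super().__init__()
        self.q_size = q_proj.out_features
        self.kv_size = k_proj.out_features
        # Llama-style projections carry no bias.
        self.qkv_proj = nn.Linear(
            q_proj.in_features,
            q_proj.out_features + k_proj.out_features + v_proj.out_features,
            bias=False,
        )
        # Concatenate the pretrained weights along the output dimension.
        with torch.no_grad():
            self.qkv_proj.weight.copy_(
                torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)
            )

    def forward(self, hidden_states: torch.Tensor):
        qkv = self.qkv_proj(hidden_states)
        # Split back into q, k, v for the attention computation.
        return qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
```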

All fusing of layers must happen AFTER the model is loaded, so that the pretrained weights can be copied into the fused modules; the FusedMLP sketch below shows one way to do this.
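
A hedged sketch of the MLP fusion under that constraint, assuming an already-loaded HF Llama-style model whose MLP blocks expose gate_proj, up_proj, down_proj, and act_fn (SiLU); the FusedMLP class and the swap-in loop are illustrative, not the exact code in this PR:

```python
import torch
import torch.nn as nn


class FusedMLP(nn.Module):
    """Compute gate_proj and up_proj with a single matmul (SwiGLU MLP)."""

    def __init__(self, gate_proj: nn.Linear, up_proj: nn.Linear,
                 down_proj: nn.Linear, act_fn):
        super().__init__()
        self.intermediate_size = gate_proj.out_features
        self.gate_up_proj = nn.Linear(
            gate_proj.in_features,
            gate_proj.out_features + up_proj.out_features,
            bias=False,
        )
        # Copy the already-loaded pretrained weights into the fused linear.
        with torch.no_grad():
            self.gate_up_proj.weight.copy_(
                torch.cat([gate_proj.weight, up_proj.weight], dim=0)
            )
        self.down_proj = down_proj
        self.act_fn = act_fn  # SiLU for Llama/Mistral

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).split(self.intermediate_size, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)


# Fuse only after the checkpoint weights are in place (hypothetical usage):
# model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
# for layer in model.model.layers:
#     mlp = layer.mlp
#     layer.mlp = FusedMLP(mlp.gate_proj, mlp.up_proj, mlp.down_proj, mlp.act_fn)
```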

TinyLlama 1.1B - A6000

Conclusion: Fusing the MLP can save roughly 27% of cache memory. Fusing attention seems to do nothing for speed but increases memory by about 1GB.

  1. None fused (Main):
    • Memory: 6.186GB (+15.724GB cache, +0.781GB misc)
    • Speed: 0.53 seconds per step
  2. MLP fused (PR):
    • Memory: 6.188GB (+11.443GB cache, +0.781GB misc)
    • Speed: 0.51 seconds per step
  3. MLP + Attention fused (PR):
    • Memory: 6.618GB (+12.195GB cache, +0.781GB misc)
    • Speed: 0.50 seconds per step

Llama-2-7B - A100

Conclusion: Saves enough memory to load using adamw_torch.

  1. None fused (main):

    • adamw_torch Memory: 37.732GB (+39.764GB cache, +1.366GB misc)
    • adamw_torch_fused Memory: 37.732GB (+14.506GB cache, +1.366GB misc)
    • adamw_bnb_8bit Memory: 25.393GB (+14.494GB cache, +1.366GB misc)
  2. MLP fused (PR):

    • adamw_torch Memory: 37.732GB (+38.813GB cache, +1.366GB misc)
    • adamw_torch_fused Memory: 37.732GB (+14.647GB cache, +1.366GB misc)
    • adamw_bnb_8bit Memory: 25.269GB (+14.137GB cache, +1.366GB misc)
  3. MLP + Attention fused (PR):

    • adamw_bnb_8bit Memory: 31.332GB (+13.752GB cache, +1.366GB misc)
    • adamw_torch Memory: OOM

QLoRA

Currently, it is not compatible with QLoRA, but there is potential to make it so. In bitsandbytes, you can import the 4-bit and 8-bit linears and use them in place of nn.Linear.

https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/nn/modules.py#L258
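
A hedged sketch of what that could look like, using bitsandbytes' Linear4bit as a drop-in for nn.Linear when building a fused gate/up projection (not implemented in this PR; the function name and parameter choices are illustrative):

```python
import torch
import bitsandbytes as bnb


def make_fused_gate_up_4bit(hidden_size: int, intermediate_size: int) -> bnb.nn.Linear4bit:
    # bnb.nn.Linear4bit mirrors nn.Linear's (in_features, out_features) layout,
    # so the fused gate/up projection only needs a wider output dimension.
    return bnb.nn.Linear4bit(
        hidden_size,
        2 * intermediate_size,
        bias=False,
        compute_dtype=torch.bfloat16,
        quant_type="nf4",
    )
```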

@casper-hansen casper-hansen marked this pull request as ready for review October 20, 2023 18:55
@winglian winglian merged commit 15d3a65 into axolotl-ai-cloud:main Oct 21, 2023
4 checks passed
mkeoliya pushed a commit to mkeoliya/axolotl that referenced this pull request Dec 15, 2023
* MLP: Memory saving

* Remove RMSNorm restrictions

* Map packed weights to original

* FusedAttention module

* Simplify code

* Move fused modules

* Fix critical typo

* Split inplace

* Add FFT config

* Add validation of fused arguments

* Add fused arguments to config

* Update docs

* Fix validation logic

* Add fused modules to flash attn

* Only fuse during training

* Remove timing

* Formatting

* Formatting

* Formatting

* chore: lint

* chore: lint

* add e2e tests for fused llama

* no lora for tests

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>