Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exllama kernels support for AWQ models #28634

Merged
merged 16 commits into from
Mar 5, 2024

Conversation

IlyasMoutawwakil
Copy link
Member

What does this PR do?

Following casper-hansen/AutoAWQ#313
ExllamaV2 offers up to 2x speedup compared to GEMM, while also compatible with AMD ROCm.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@SunMarc and @younesbelkada

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making exllama kernel compatible with AWQ models ! This will make AWQ so much faster ! I've left a few minor comments.

src/transformers/integrations/awq.py Outdated Show resolved Hide resolved
src/transformers/modeling_utils.py Outdated Show resolved Hide resolved
src/transformers/utils/quantization_config.py Outdated Show resolved Hide resolved
src/transformers/utils/quantization_config.py Outdated Show resolved Hide resolved
IlyasMoutawwakil and others added 2 commits January 24, 2024 12:45
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
@IlyasMoutawwakil
Copy link
Member Author

I guess all points are addressed.
@casper-hansen when is 0.1.9 planned ?

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for adding the ex-llama v2 support ! 🔥
Let's add autoawq==0.1.9 in the Dockerfile:

RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl
- @casperhansen could you confirm the 0.1.9 is planned sometime soon?
Do you know also if this feature is supported on NVIDIA T4 GPUs? If that's the case can you add a simple generation test in: https://github.com/huggingface/transformers/blob/main/tests/quantization/autoawq/test_awq.py
Thanks !

@casper-hansen
Copy link

casper-hansen commented Jan 26, 2024

@younesbelkada The next release will be 0.2.0 🤗. For T4 support, I have not tested it. If AutoGPTQ supports T4 with ExLlama v1 and v2 kernels, AutoAWQ should too as the kernels are the same.

EDIT: To answer the timeline question. There is no set-in-stone plan for the next release. PRs to be merged before release include AMD support, Marlin support, Qwen2 support, and hopefully PEFT support. I expect this could be done in <1-2 weeks.

@younesbelkada
Copy link
Contributor

Awesome! Per my understanding ex-llama + AutoGPTQ should be supported on T4 so it should be all good !
Let me know whenever you have some progress for the PEFT support so that I'll dive in to add AWQ + PEFT support directly in PEFT

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@IlyasMoutawwakil - #26610 being merged would you be happy to transfer the logic inside transformers/src/quantizers/quantizer_awq.py's post-processing method?

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks very much @IlyasMoutawwakil ! I left one suggestion - what do you think?


# default values for exllamav2 from
# https://github.com/AutoGPTQ/AutoGPTQ/blob/6ba14f17ef73c161c2c4707cbf0b41e569a9c6dd/auto_gptq/nn_modules/qlinear/qlinear_exllamav2.py#L171
model = exllamav2_post_init(model, max_input_len=2048, max_batch_size=8)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

couldn't we make max_input_len configurable through AwqConfig - wdyt?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Marc suggested we leave it as is for now #28634 (comment)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh okay ! I think it would makes sense to directly expose a exllama_config I think - wdyt @SunMarc ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it would make more sense to expose it in a exllama_config !

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess in another PR right ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm I think it should be better to add it now and not leave the main branch with hardcoded config values, it shouldn't be super complex as you can just copy over the existing logic in GptqConfig right?

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks already really nice thanks to the integration refactor!

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, clean ! Let's merge this PR right after the next release of autoawq
@casper-hansen do you have any ETA for the next release?

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks again !

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that @casper-hansen already made the release of awq with exllama kernel. Can you check that everything works fine with the latest release @IlyasMoutawwakil ? Then, we are good to merge !

@IlyasMoutawwakil
Copy link
Member Author

@SunMarc on it!

@IlyasMoutawwakil
Copy link
Member Author

IlyasMoutawwakil commented Feb 23, 2024

works on rocm5.6 with torch 2.2

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig

quantization_config = AwqConfig(version="exllama")
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Mistral-7B-Instruct-v0.1-AWQ",
    quantization_config=quantization_config,
    device_map="auto" or torch.device("cuda"),
)

input_ids = torch.randint(0, 100, (1, 128), dtype=torch.long, device="cuda")
output = model(input_ids)
print(output.logits)

tokenizer = AutoTokenizer.from_pretrained("TheBloke/Mistral-7B-Instruct-v0.1-AWQ")
input_ids = tokenizer.encode("How to make a cake", return_tensors="pt").to(model.device)
output = model.generate(input_ids, do_sample=True, max_length=50, pad_token_id=50256)
print(tokenizer.decode(output[0], skip_special_tokens=True))

The device_map is mandatory since exllamav2_post_init scratch space tensors allocation needs to check which qweights are on which cuda devices. I believe this is a requirement in GPTQ as well.

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks LGTM with one nit !

docker/transformers-all-latest-gpu/Dockerfile Outdated Show resolved Hide resolved
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for adding exllama support 🔥

@ArthurZucker ArthurZucker merged commit 4fc708f into huggingface:main Mar 5, 2024
19 of 21 checks passed
damithsenanayake pushed a commit to damithsenanayake/transformers that referenced this pull request Mar 7, 2024
* added exllama kernels support for awq models

* doc

* style

* Update src/transformers/modeling_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* refactor

* moved exllama post init to after device dispatching

* bump autoawq version

* added exllama test

* style

* configurable exllama kernels

* copy exllama_config from gptq

* moved exllama version check to post init

* moved to quantization dockerfile

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
itazap pushed a commit that referenced this pull request May 14, 2024
* added exllama kernels support for awq models

* doc

* style

* Update src/transformers/modeling_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* refactor

* moved exllama post init to after device dispatching

* bump autoawq version

* added exllama test

* style

* configurable exllama kernels

* copy exllama_config from gptq

* moved exllama version check to post init

* moved to quantization dockerfile

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants