
Add GPTQ Quantization #1216

Merged: 57 commits merged into huggingface:main on Aug 10, 2023

Conversation

@SunMarc (Member) commented on Jul 20, 2023

What does this PR do?

This PR adds the possibility to perform GPTQ quantization. I tried to be as generic as possible so that it supports any kind of model (whether it is a Transformers model or not). The backend relies on the auto_gptq library, where we use the GPTQ class and the QuantLinear class. With this API you can do conversion, saving, and loading of quantized weights. We have a dependency on accelerate to save the weights and load the quantized weights efficiently, without allocating more memory than needed.

Quantize model:

model_name = "bigscience/bloom-1b7"
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.gptq import GPTQQuantizer, load_quantized_model
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype = torch.float16, device_map = "auto")
quantizer = GPTQQuantizer(bits=4, dataset="c4", group_size=128, desc_act=False)
quantized_model = quantizer.quantize_model(quantized_model, tokenizer)
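
As a quick check after quantization (illustrative only, not part of the PR description), you can print the model to confirm that the linear layers were replaced by QuantLinear modules:

# Illustrative: inspect the quantized model to confirm linear layers were swapped for QuantLinear
print(quantized_model)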

Save model:

save_folder = "bloom-1b7-quantized_optimum"
quantizer.save(quantized_model, save_folder)

Convert the model and load the quantized weights:

from accelerate import init_empty_weights

# Instantiate the model structure on the meta device (no memory allocated)
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
empty_model.tie_weights()
# Replace the linear layers and load the quantized weights from save_folder
model_from_saved = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto")
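
As a quick sanity check (not part of this PR; just an illustrative snippet using the standard transformers generation API), you can exercise the reloaded model like this:

# Illustrative only: run a short generation with the reloaded quantized model
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model_from_saved.device)
output = model_from_saved.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))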

PS: In a follow-up PR, I will use this API to integrate GPTQ quantization into transformers, so that GPTQ-quantized models are taken into account, just like what we did for bitsandbytes.
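
For illustration only (the exact transformers-side API will be defined in that follow-up PR; the GPTQConfig name and arguments below are hypothetical at this stage), the usage could mirror the bitsandbytes-style quantization_config flow:

# Hypothetical follow-up API, sketched here only to show the intent of the transformers integration
from transformers import AutoModelForCausalLM, GPTQConfig

gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=gptq_config, device_map="auto")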

TODO:

  • Doc
  • CI
  • CPU offload (for now, if some modules are offloaded to the CPU via device_map, quantization will not work; see the sketch after this list)
  • exllama integration
  • script to calculate the perplexity, maybe in a new PR ...
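
As a rough illustration of the CPU-offload limitation mentioned above (the helper name and the exact check are mine, not the PR's), the quantizer could guard against offloaded modules roughly like this:

# Hypothetical guard: refuse to quantize when the device_map offloads modules to CPU or disk
def check_no_cpu_offload(device_map):
    offloaded = [name for name, device in device_map.items() if device in ("cpu", "disk")]
    if offloaded:
        raise ValueError(f"Quantization does not support CPU/disk offload yet; offloaded modules: {offloaded}")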

@younesbelkada (Contributor) left a comment:

Thanks a lot for this work and integration! Left some initial comments before @fxmarty's review, with respect to a potential transformers integration later on.

Comment on lines +218 to +219
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed to quantize model.")
A reviewer (Member) commented:

Do we actually need a GPU to quantize the model? I don't have the details off the top of my head, but is it a strong requirement?

The same reviewer (Member) added:

(Not considering the actual inference process, which I understand requires a GPU; maybe quantization/calibration can be done on CPU too?)

@SunMarc (Member, Author) replied:

I don't think it is a strong requirement. I need to test, but you can have a look at the fasterquant method of the GPTQ class. It looks like it uses CUDA to speed up some operations (cholesky, cholesky_inverse, matmul), and I can see that it calls torch.cuda.synchronize().
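
As a quick way to probe that point (an illustrative check only, not taken from the auto_gptq code), the linear-algebra ops in question also run on CPU in PyTorch, just more slowly:

# Illustrative: cholesky / cholesky_inverse exist on CPU too, so the GPU may mainly be a speed concern
import time
import torch

H = torch.randn(512, 512)
H = H @ H.T + 512 * torch.eye(512)  # make the matrix positive definite
start = time.time()
L = torch.linalg.cholesky(H)
H_inv = torch.cholesky_inverse(L)
print(f"CPU cholesky + cholesky_inverse on a 512x512 matrix: {time.time() - start:.4f}s")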

@fxmarty (Contributor) left a comment:

It is in good shape, thank you for working on it!

I still think that using the exllama kernel would be very interesting (for reference: https://github.com/fxmarty/q4f16-gemm-gemv-benchmark & huggingface/text-generation-inference#553 (comment); the second one is with Triton, which I see is not used here), but integrating with AutoGPTQ is already a nice step (making sure that cuda-old is used when possible).

Is there a hard requirement on accelerate? If so, why?

I think it would be good to add documentation in this PR as well, and CI.

It is also not clear to me what should go in accelerate and what in optimum, if this PR puts a hard requirement on accelerate. Why not put it in accelerate directly? Why is bitsandbytes in accelerate and this in optimum? I think it can be quite confusing to users.

SunMarc and others added 10 commits July 21, 2023 09:47
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
@younesbelkada (Contributor) left a comment:

Amazing work @SunMarc! All this looks great! Left a few nits, but the overall structure looks really nice. Thanks a lot for working on this!
Let's wait for @fxmarty's review before merging this PR :D

SunMarc and others added 3 commits July 31, 2023 10:29
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
@fxmarty (Contributor) left a comment:

Looks great! I left a few questions/style comments

SunMarc and others added 2 commits August 1, 2023 09:56
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
@SunMarc SunMarc requested a review from fxmarty August 1, 2023 14:46
@fxmarty (Contributor) left a comment:

LGTM, thank you for adding it!

Just left some taste comments / answers to the threads above. I believe this breaks with transformers 4.31 now, but maybe that's fine.

@fxmarty (Contributor) commented on Aug 3, 2023:

Can you run make style?

@SunMarc (Member, Author) commented on Aug 8, 2023:

> I believe this breaks with transformers 4.31 now, but maybe that's fine

I've changed it so that we install from source. I will change it back after the release of transformers.

@SunMarc SunMarc requested a review from fxmarty August 9, 2023 21:46
@regisss (Contributor) left a comment:

Very nice PR @SunMarc 🔥 🚀
Just commenting on the title of the new doc section.
Also, could you add a bullet point for GPTQ in this section please?

@SunMarc SunMarc merged commit 9f2943e into huggingface:main Aug 10, 2023
60 of 67 checks passed