GPTQ integration #25062

Merged
merged 50 commits into from
Aug 10, 2023
Changes from 6 commits
Commits
50 commits
d963b97 GTPQ integration (SunMarc, Jul 24, 2023)
93f0d84 Add tests for gptq (SunMarc, Jul 24, 2023)
380baea support for more quantization model (SunMarc, Jul 25, 2023)
810d537 fix style (SunMarc, Jul 25, 2023)
c3f5248 typo (SunMarc, Jul 25, 2023)
fc70ef4 fix method (SunMarc, Jul 25, 2023)
6a04bb8 Update src/transformers/modeling_utils.py (SunMarc, Jul 25, 2023)
271dab6 add dataclass and fix quantization_method (SunMarc, Jul 25, 2023)
992881e fix doc (SunMarc, Jul 26, 2023)
3c2d940 Update tests/quantization/gptq/test_gptq.py (SunMarc, Jul 26, 2023)
9bbb336 Apply suggestions from code review (SunMarc, Jul 26, 2023)
0134c79 modify dataclass (SunMarc, Jul 26, 2023)
a2a7f5d add gtpqconfig import (SunMarc, Jul 26, 2023)
70e1416 fix typo (SunMarc, Jul 26, 2023)
0e2014b fix tests (SunMarc, Jul 26, 2023)
69e3c88 remove dataset as req arg (SunMarc, Jul 26, 2023)
cb46d75 remove tokenizer import (SunMarc, Jul 26, 2023)
9a3cafd add offload cpu quantization test (SunMarc, Jul 26, 2023)
27e9b79 fix check dataset (SunMarc, Jul 26, 2023)
f47ecb4 modify dockerfile (SunMarc, Jul 26, 2023)
19d05d3 protect trainer (SunMarc, Jul 26, 2023)
76dffe2 style (SunMarc, Jul 26, 2023)
0f61037 test for config (SunMarc, Jul 26, 2023)
b0eccd5 add more log (SunMarc, Jul 27, 2023)
2e7a025 overwrite torch_dtype (SunMarc, Jul 27, 2023)
a07126a draft doc (SunMarc, Jul 27, 2023)
c9d3f26 modify quantization_config docstring (SunMarc, Jul 31, 2023)
ecce1da fix class name in docstring (SunMarc, Jul 31, 2023)
2226184 Apply suggestions from code review (SunMarc, Jul 31, 2023)
eff99cb more warning (SunMarc, Jul 31, 2023)
159cf87 fix 8bit kwargs tests (SunMarc, Jul 31, 2023)
98db723 peft compatibility (SunMarc, Jul 31, 2023)
0144760 remove var (SunMarc, Aug 1, 2023)
fd8d70c fix is_gptq_quantized (SunMarc, Aug 1, 2023)
0f96fb2 Merge branch 'main' into gptq_integration (SunMarc, Aug 1, 2023)
be19916 remove is_gptq_quantized (SunMarc, Aug 2, 2023)
9e8f487 Merge remote-tracking branch 'upstream/main' into gptq_integration (SunMarc, Aug 2, 2023)
4b4336e fix wrap (SunMarc, Aug 2, 2023)
42d0049 Update src/transformers/modeling_utils.py (SunMarc, Aug 8, 2023)
a9658e2 Merge remote-tracking branch 'upstream/main' into gptq_integration (SunMarc, Aug 8, 2023)
62aa293 add exllama (SunMarc, Aug 9, 2023)
39137eb skip test (SunMarc, Aug 9, 2023)
f23ce7e Merge remote-tracking branch 'upstream/main' into gptq_integration (SunMarc, Aug 9, 2023)
0b0633b overwrite float16 (SunMarc, Aug 9, 2023)
c3c4a16 style (SunMarc, Aug 9, 2023)
a45b5b0 fix skip test (SunMarc, Aug 9, 2023)
69c8fce Apply suggestions from code review (SunMarc, Aug 10, 2023)
bf98799 fix docsting formatting (SunMarc, Aug 10, 2023)
7adf9cb add doc (SunMarc, Aug 10, 2023)
c93d1d0 better test (SunMarc, Aug 10, 2023)
33 changes: 29 additions & 4 deletions src/transformers/modeling_utils.py
@@ -64,6 +64,7 @@
download_url,
has_file,
is_accelerate_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_offline_mode,
is_optimum_available,
@@ -75,7 +76,7 @@
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled
from .utils.quantization_config import BitsAndBytesConfig
from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
from .utils.versions import require_version_core


@@ -2344,15 +2345,35 @@ def from_pretrained(
else:
model_kwargs = kwargs

# get the quantization method inside the config of the model if it exists
quantization_method = None
if hasattr(config, "quantization_config"):
if config.quantization_config.get("load_in_8bit", False):
quantization_method = QuantizationMethod.BITS_AND_BYTES
else:
quantization_method = config.quantization_config.get("quant_method", None)

quantizer = None
if quantization_method == QuantizationMethod.GPTQ:
if not (is_optimum_available() and is_auto_gptq_available()):
raise ImportError(
"Loading GTPQ quantized model requires optimum library : `pip install optimum` and auto-gptq library 'pip install auto-gptq'"
)
else:
# Need to protect the import
from optimum.gptq import GPTQQuantizer
quantizer = GPTQQuantizer.from_dict(config.quantization_config)
torch_dtype = config.torch_dtype

if is_8bit_serializable and quantization_config is not None and load_in_8bit:
if hasattr(config, "quantization_config"):
if quantization_method == QuantizationMethod.BITS_AND_BYTES:
logger.warning(
"You passed `quantization_config` to `from_pretrained` but the model you're loading already has a"
" `quantization_config` attribute. The `quantization_config` attribute will be overwritten with the"
" one you passed to `from_pretrained`."
)
config.quantization_config = quantization_config
elif is_8bit_serializable and not load_in_8bit and hasattr(config, "quantization_config"):
elif is_8bit_serializable and not load_in_8bit and quantization_method == QuantizationMethod.BITS_AND_BYTES:
quantization_config = config.quantization_config
if isinstance(quantization_config, dict):
quantization_config = BitsAndBytesConfig.from_dict(quantization_config, return_unused_kwargs=False)
@@ -2382,7 +2403,9 @@ def from_pretrained(
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True

elif not is_8bit_serializable and not load_in_8bit and hasattr(config, "quantization_config"):
elif (
not is_8bit_serializable and not load_in_8bit and quantization_method == QuantizationMethod.BITS_AND_BYTES
):
logger.warning(
"Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct"
" `bitsandbytes` version to support int8 serialization. Please install the latest version of `bitsandbytes` with "
@@ -2767,6 +2790,8 @@ def from_pretrained(
"All non-linear modules will be loaded in full precision."
" If you want to load the other modules in other precision, please specify a `torch_dtype` attribute."
)
if quantization_method == QuantizationMethod.GPTQ:
model = quantizer.convert_model(model)

if isinstance(device_map, str):
special_dtypes = {}
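The branching added to `from_pretrained` above can be summarised with a small standalone sketch. This is not the PR's code, only an illustration: a `load_in_8bit` flag marks a bitsandbytes checkpoint, anything else must declare itself through `quant_method`. The helper name and the example dicts below are hypothetical.

```python
# Illustrative sketch only: mirrors the detection branching added above.
from typing import Optional


def detect_quantization_method(quantization_config: Optional[dict]) -> Optional[str]:
    """Return "bitsandbytes", the declared quant_method, or None."""
    if quantization_config is None:
        # config has no `quantization_config` attribute -> model is not quantized
        return None
    if quantization_config.get("load_in_8bit", False):
        # bitsandbytes checkpoints are recognised by their load_in_8bit flag
        return "bitsandbytes"
    # everything else (e.g. GPTQ) must declare itself via `quant_method`
    return quantization_config.get("quant_method", None)


print(detect_quantization_method({"load_in_8bit": True}))               # bitsandbytes
print(detect_quantization_method({"quant_method": "gptq", "bits": 4}))  # gptq
print(detect_quantization_method(None))                                 # None
```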
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
@@ -51,6 +51,7 @@
from .utils import (
is_accelerate_available,
is_apex_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_bs4_available,
is_cython_available,
@@ -770,6 +771,13 @@ def require_optimum(test_case):
return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case)


def require_auto_gptq(test_case):
"""
Decorator for auto_gptq dependency
"""
return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)


def require_phonemizer(test_case):
"""
Decorator marking a test that requires phonemizer
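For context, a typical use of the new decorator looks like the snippet below; the test class and its body are hypothetical, only `require_optimum` and `require_auto_gptq` come from the diff above.

```python
import unittest

from transformers.testing_utils import require_auto_gptq, require_optimum


@require_optimum
@require_auto_gptq
class AutoGPTQImportTest(unittest.TestCase):  # hypothetical example test
    def test_quantizer_import(self):
        # Runs only when optimum and auto-gptq are installed; skipped otherwise.
        from optimum.gptq import GPTQQuantizer  # noqa: F401
```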
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -103,6 +103,7 @@
get_torch_version,
is_accelerate_available,
is_apex_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_bs4_available,
is_coloredlogs_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -98,6 +98,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_onnx_available = _is_package_available("onnx")
_openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
_auto_gptq_available = _is_package_available("auto_gptq")
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
@@ -554,6 +555,10 @@ def is_optimum_available():
return _optimum_available


def is_auto_gptq_available():
return _auto_gptq_available


def is_optimum_neuron_available():
return _optimum_available and _is_package_available("optimum.neuron")

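A hedged sketch of the intended call pattern for the new availability helper: guard any auto-gptq import behind it rather than importing unconditionally. The error message here is illustrative, not the library's.

```python
from transformers.utils import is_auto_gptq_available

if is_auto_gptq_available():
    # Safe to import: the package was found by _is_package_available("auto_gptq").
    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear  # noqa: F401
else:
    raise ImportError("This example needs auto-gptq: pip install auto-gptq")
```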
7 changes: 7 additions & 0 deletions src/transformers/utils/quantization_config.py
@@ -19,6 +19,7 @@
import json
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Union

from packaging import version
@@ -33,6 +34,11 @@
logger = logging.get_logger(__name__)


class QuantizationMethod(Enum):
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gtpq"


@dataclass
class BitsAndBytesConfig:
"""
@@ -97,6 +103,7 @@ def __init__(
bnb_4bit_use_double_quant=False,
**kwargs,
):
self.quant_method = QuantizationMethod.BITS_AND_BYTES
self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit
self.llm_int8_threshold = llm_int8_threshold
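The practical effect of the new `quant_method` field, shown as a small snippet that relies only on the attribute set in `__init__` above:

```python
from transformers import BitsAndBytesConfig
from transformers.utils.quantization_config import QuantizationMethod

config = BitsAndBytesConfig(load_in_8bit=True)
# Every BitsAndBytesConfig now tags itself, so from_pretrained can tell
# bitsandbytes checkpoints apart from GPTQ ones when it reads the saved config.
assert config.quant_method == QuantizationMethod.BITS_AND_BYTES
```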
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file.
145 changes: 145 additions & 0 deletions tests/quantization/gptq/test_gptq.py
@@ -0,0 +1,145 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import (
is_torch_available,
require_accelerate,
require_auto_gptq,
require_optimum,
require_torch_gpu,
require_torch_multi_gpu,
slow,
)


if is_torch_available():
    import torch


@slow
@require_optimum
@require_auto_gptq
@require_torch_gpu
class GTPQTest(unittest.TestCase):
    model_name = "bigscience/bloom-560m"

    input_text = "Hello my name is"
    EXPECTED_OUTPUT = "Hello my name is John and I am a professional photographer. I"

    # this seems a little small considering that we are doing 4bit quant but we have a small model and we don't quantize the embeddings
    EXPECTED_RELATIVE_DIFFERENCE = 1.664253062

    bits = 4
    group_size = 128
    desc_act = False

    dataset = [
        "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
    ]

    device_map = None

    # called only once for all tests in this class
    @classmethod
    def setUpClass(cls):
        from optimum.gptq import GPTQQuantizer

        """
        Setup quantized model
        """
        cls.model_fp16 = AutoModelForCausalLM.from_pretrained(
            cls.model_name, torch_dtype=torch.float16, device_map=cls.device_map
        )
        cls.mem_fp16 = cls.model_fp16.get_memory_footprint()

        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True)
        cls.quantizer = GPTQQuantizer(bits=cls.bits, group_size=cls.group_size, desc_act=cls.desc_act)

        cls.quantized_model = cls.quantizer.quantize_model(cls.model_fp16, cls.tokenizer, cls.dataset)

    def test_memory_footprint(self):
        r"""
        A simple test to check if the model conversion has been done correctly by checking the
        memory footprint of the converted model
        """
        mem_quantized = self.quantized_model.get_memory_footprint()

        self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE)

    def test_quantized_layers_class(self):
        """
        Simple test to check if the model conversion has been done correctly by checking
        the class type of the linear layers of the converted models
        """
        from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

        QuantLinear = dynamically_import_QuantLinear(
            use_triton=False, desc_act=self.desc_act, group_size=self.group_size
        )
        self.assertTrue(self.quantized_model.transformer.h[0].mlp.dense_4h_to_h.__class__ == QuantLinear)

    def check_inference_correctness(self, model):
        r"""
        Test the generation quality of the quantized model and see that we are matching the expected output.
        Given that we are operating on small numbers + the testing model is relatively small, we might not get
        the same output across GPUs. So we'll generate a few tokens (5-10) and check their output.
        """
        # Check that inference pass works on the model
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

        # Check the exactness of the results
        output_parallel = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

        # Get the generation
        self.assertEqual(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_generate_quality(self):
        """
        Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
        """
        if self.device_map is None:
            self.check_inference_correctness(self.quantized_model.to(0))
        else:
            self.check_inference_correctness(self.quantized_model)

    def test_serialization(self):
        """
        Test that the serialization of the model and the loading of the quantized weights work
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.to("cpu").save_pretrained(tmpdirname)
            quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname).to(0)
            self.check_inference_correctness(quantized_model_from_saved)

    @require_accelerate
    def test_serialization_big_model_inference(self):
        """
        Test the serialization of the model and the loading of the quantized weights with big model inference
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.to("cpu").save_pretrained(tmpdirname)
            quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto")
            self.check_inference_correctness(quantized_model_from_saved)


@require_accelerate
@require_torch_multi_gpu
class GTPQTestDeviceMap(GTPQTest):
    device_map = "auto"
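Putting the pieces together, the end-to-end workflow these tests exercise looks roughly like this. The snippet is a sketch, not the PR's code: the output path is a placeholder, and it only reuses calls that appear in the test above (`GPTQQuantizer`, `quantize_model`, `save_pretrained`, `from_pretrained`); the exact optimum API may differ between versions.

```python
import torch
from optimum.gptq import GPTQQuantizer
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "bigscience/bloom-560m"  # same small model the tests use
dataset = ["auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Quantize with optimum's GPTQ quantizer (4-bit, groups of 128 columns), as in setUpClass.
quantizer = GPTQQuantizer(bits=4, group_size=128, desc_act=False)
quantized_model = quantizer.quantize_model(model, tokenizer, dataset)

# Save, then reload through transformers; the new from_pretrained path detects the
# GPTQ quantization_config and converts the linear layers before loading the weights.
quantized_model.to("cpu").save_pretrained("bloom-560m-gptq")  # placeholder output dir
reloaded = AutoModelForCausalLM.from_pretrained("bloom-560m-gptq", device_map="auto")
```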