Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HQQ quantization support #29637

Merged
merged 78 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
78 commits
Select commit Hold shift + click to select a range
bbc68fe
update HQQ transformers integration
mobicham Apr 24, 2024
2a1f224
Merge branch 'huggingface:main' into stable
mobicham Apr 24, 2024
e1e5df6
push import_utils.py
mobicham Apr 24, 2024
0192b03
add force_hooks check in modeling_utils.py
mobicham Apr 24, 2024
823de37
fix | with Optional
mobicham Apr 24, 2024
08d7b8e
force bias as param
mobicham Apr 24, 2024
e1fa6c9
check bias is Tensor
mobicham Apr 24, 2024
6e854ca
force forward for multi-gpu
mobicham Apr 24, 2024
2b9f271
review fixes pass
mobicham Apr 25, 2024
5bb9ca2
remove torch grad()
mobicham Apr 25, 2024
392e7c5
if any key in linear_tags fix
mobicham Apr 25, 2024
20f9ad5
add cpu/disk check
mobicham Apr 25, 2024
3a5679a
isinstance return
mobicham Apr 25, 2024
7a1bbca
add multigpu test + refactor tests
mobicham Apr 25, 2024
65b2887
clean hqq_utils imports in hqq.py
mobicham Apr 25, 2024
bba74cd
clean hqq_utils imports in quantizer_hqq.py
mobicham Apr 25, 2024
de88c2a
delete hqq_utils.py
mobicham Apr 25, 2024
651a586
Delete src/transformers/utils/hqq_utils.py
mobicham Apr 25, 2024
d07ea85
ruff init
mobicham Apr 25, 2024
dedf69e
remove torch.float16 from __init__ in test
mobicham Apr 25, 2024
0edf8a4
refactor test
mobicham Apr 25, 2024
c7ec123
isinstance -> type in quantizer_hqq.py
mobicham Apr 26, 2024
5283ac2
cpu/disk device_map check in quantizer_hqq.py
mobicham Apr 29, 2024
15daeb4
remove type(module) nn.linear check in quantizer_hqq.py
mobicham Apr 29, 2024
bc4bc73
add BaseQuantizeConfig import inside HqqConfig init
mobicham Apr 29, 2024
b54e87b
remove hqq import in hqq.py
mobicham Apr 29, 2024
0f9698a
remove accelerate import from test_hqq.py
mobicham Apr 29, 2024
d31837f
quant config.py doc update
mobicham Apr 29, 2024
b8f792c
add hqqconfig to main_classes doc
mobicham Apr 29, 2024
8b84cb1
Merge branch 'huggingface:main' into stable
mobicham Apr 29, 2024
9a061e5
make style
mobicham Apr 29, 2024
8612282
__init__ fix
mobicham Apr 29, 2024
b786793
ruff __init__
mobicham Apr 29, 2024
e7ba717
skip_modules list
mobicham Apr 29, 2024
3a38f21
hqqconfig format fix
mobicham Apr 29, 2024
9eee213
hqqconfig doc fix
mobicham Apr 29, 2024
03cc8e6
hqqconfig doc fix
mobicham Apr 29, 2024
96bd141
hqqconfig doc fix
mobicham Apr 29, 2024
713d226
hqqconfig doc fix
mobicham Apr 29, 2024
dad9a60
hqqconfig doc fix
mobicham Apr 29, 2024
67c0985
hqqconfig doc fix
mobicham Apr 29, 2024
94c393a
hqqconfig doc fix
mobicham Apr 29, 2024
35fc9f5
hqqconfig doc fix
mobicham Apr 29, 2024
06f6497
hqqconfig doc fix
mobicham Apr 29, 2024
25fde9c
test_hqq.py remove mistral comment
mobicham Apr 30, 2024
ee50516
remove self.using_multi_gpu is False
mobicham Apr 30, 2024
01d798a
torch_dtype default val set and logger.info
mobicham Apr 30, 2024
a909ca8
hqq.py isinstance fix
mobicham May 2, 2024
c466c89
remove torch=None
mobicham May 2, 2024
d522fed
torch_device test_hqq
mobicham May 2, 2024
a09e90f
rename test_hqq
mobicham May 2, 2024
5bdf40f
MODEL_ID in test_hqq
mobicham May 2, 2024
e693d47
quantizer_hqq setattr fix
mobicham May 2, 2024
f5cabe5
quantizer_hqq typo fix
mobicham May 2, 2024
5ede086
imports quantizer_hqq.py
mobicham May 2, 2024
c86000b
isinstance quantizer_hqq
mobicham May 2, 2024
7d3e083
hqq_layer.bias reformat quantizer_hqq
mobicham May 2, 2024
082dfea
Step 2 as comment in quantizer_hqq
mobicham May 2, 2024
667f1ad
prepare_for_hqq_linear() comment
mobicham May 2, 2024
e0cd784
keep_in_fp32_modules fix
mobicham May 2, 2024
5d3b504
HqqHfQuantizer reformat
mobicham May 2, 2024
cc1961c
quantization.md hqqconfig
mobicham May 2, 2024
9aa9e15
quantization.md model example reformat
mobicham May 2, 2024
9273e21
quantization.md # space
mobicham May 2, 2024
f29e7a4
quantization.md space })
mobicham May 2, 2024
5168852
quantization.md space })
mobicham May 2, 2024
0dfe080
quantization_config fix doc
mobicham May 2, 2024
2934052
axis value check in quantization_config
mobicham May 2, 2024
bc7cf4e
format
mobicham May 2, 2024
d33f944
dynamic config explanation
mobicham May 2, 2024
3522f0a
quant config method in quantization.md
mobicham May 2, 2024
cc14c21
remove shard-level progress
mobicham May 2, 2024
1e81036
.cuda fix modeling_utils
mobicham May 2, 2024
ca07f5a
test_hqq fixes
mobicham May 2, 2024
4cc776e
Merge branch 'huggingface:main' into stable
mobicham May 2, 2024
3d777ed
make fix-copies
mobicham May 2, 2024
b808858
Merge branch 'huggingface:main' into stable
mobicham May 2, 2024
5e71139
Merge branch 'huggingface:main' into stable
mobicham May 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docker/transformers-quantization-latest-gpu/Dockerfile
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt
# Add aqlm for quantization testing
RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2

# Add hqq for quantization testing
RUN python3 -m pip install --no-cache-dir hqq

# Add autoawq for quantization testing
# >=v0.2.3 needed for compatibility with torch 2.2.1
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+cu118-cp38-cp38-linux_x86_64.whl
Expand Down
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/quantization.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
## HfQuantizer

[[autodoc]] quantizers.base.HfQuantizer

## HqqConfig

[[autodoc]] HqqConfig
40 changes: 40 additions & 0 deletions docs/source/en/quantization.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -745,3 +745,43 @@ The speed and throughput of fused and unfused modules were also tested with the
<figcaption class="mt-2 text-center text-sm text-gray-500">generate throughput/batch size</figcaption>
</div>
</div>

## HQQ
Half-Quadratic Quantization (HQQ) implements on-the-fly quantization via fast robust optimization. It doesn't require calibration data and can be used to quantize any model.
Please refer to the <a href="https://github.com/mobiusml/hqq/">official package</a> for more details.

For installation, we recommend you use the following approach to get the latest version and build its corresponding CUDA kernels:
```
pip install hqq
```

To quantize a model, you need to create an ```HqqConfig``` as follows:
mobicham marked this conversation as resolved.
Show resolved Hide resolved
``` Python
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

#Linear layers will use the same quantization config
mobicham marked this conversation as resolved.
Show resolved Hide resolved
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default

#Each type of linear layer (referred to as linear tag) will use different quantization parameters
mobicham marked this conversation as resolved.
Show resolved Hide resolved
mobicham marked this conversation as resolved.
Show resolved Hide resolved
q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False}
q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False}
quant_config = HqqConfig(dynamic_config={
'self_attn.q_proj':q4_config,
'self_attn.k_proj':q4_config,
'self_attn.v_proj':q4_config,
'self_attn.o_proj':q4_config,

'mlp.gate_proj':q3_config,
'mlp.up_proj' :q3_config,
'mlp.down_proj':q3_config,
})
mobicham marked this conversation as resolved.
Show resolved Hide resolved
```

Then you simply quantize the model as follows
``` Python
model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda", quantization_config=quant_config)
mobicham marked this conversation as resolved.
Show resolved Hide resolved
```
### Optimized Runtime
HQQ supports various backends, including pure Pytorch and custom dequantization CUDA kernels. These backends are suitable for older gpus and peft/QLoRA training.
For faster inference, HQQ supports 4-bit fused kernels (TorchAO and Marlin), reaching up to 200 tokens/sec on a single 4090.
For more details on how to use the backends, please refer to https://github.com/mobiusml/hqq/?tab=readme-ov-file#backend
Empty file modified docs/source/en/quicktour.md
100644 → 100755
mobicham marked this conversation as resolved.
Show resolved Hide resolved
Empty file.
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,7 @@
"BitsAndBytesConfig",
"EetqConfig",
"GPTQConfig",
"HqqConfig",
"QuantoConfig",
],
}
Expand Down Expand Up @@ -6097,6 +6098,7 @@
BitsAndBytesConfig,
EetqConfig,
GPTQConfig,
HqqConfig,
QuantoConfig,
)

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"unset_hf_deepspeed_config",
],
"eetq": ["replace_with_eetq_linear"],
"hqq": ["prepare_for_hqq_linear"],
"integration_utils": [
"INTEGRATION_TO_CALLBACK",
"AzureMLCallback",
Expand Down Expand Up @@ -113,6 +114,7 @@
unset_hf_deepspeed_config,
)
from .eetq import replace_with_eetq_linear
from .hqq import prepare_for_hqq_linear
from .integration_utils import (
INTEGRATION_TO_CALLBACK,
AzureMLCallback,
Expand Down
123 changes: 123 additions & 0 deletions src/transformers/integrations/hqq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"HQQ (Half-Quadratic Quantization) integration file"

from ..utils import is_hqq_available, is_torch_available, logging


if is_torch_available():
import torch
else:
torch = None
mobicham marked this conversation as resolved.
Show resolved Hide resolved

logger = logging.get_logger(__name__)


# Name all modules inside the model
def autoname_modules(model):
for name, module in model.named_modules():
module.name = name


# Get the linear_tag from a modul name. For example: model.layers.31.self_attn.k_proj -> self_attn.k_proj
def name_to_linear_tag(name):
return ".".join([n for n in name.split(".") if ((n not in ["model", "layers"]) and (not n.isnumeric()))])


# Get all linear tags available
def get_linear_tags(model):
if is_hqq_available():
from hqq.core.quantize import HQQLinear

linear_tags = set()
for name, module in model.named_modules():
if type(module) in [torch.nn.Linear, HQQLinear]:
mobicham marked this conversation as resolved.
Show resolved Hide resolved
linear_tags.add(name_to_linear_tag(name))
return list(linear_tags)


def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_name=None):
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)

if isinstance(module, torch.nn.Linear):
# Get linear tag
linear_tag = name_to_linear_tag(module.name)

# We put the module quant_config into the nn.Linear layer so we can access it later in quantizer_hqq.create_quantized_param()
if linear_tag in patch_params:
if patch_params[linear_tag] is not None:
model._modules[name].quant_config = patch_params[linear_tag]
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)

has_been_replaced = True

if len(list(module.children())) > 0:
_, has_been_replaced = _prepare_for_hqq_linear(
module,
patch_params=patch_params,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)

return model, has_been_replaced


def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_convert=None, has_been_replaced=False):
"""
Prepares nn.Linear layers for HQQ quantization.
Since each layer type can have separate quantization parameters, we need to do the following:
1- tag each module with its neme via autoname_modules()
2- Extract linear_tags (e.g. ['self_attn.q_proj', ...])
3- Map quantization parameters as a dictionary linear_tag -> quant_params as HQQLinear exepects it, this is referred to as patch_params
"""

modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert

# Add name to module
autoname_modules(model)
mobicham marked this conversation as resolved.
Show resolved Hide resolved

# Get linear tags. This allows us to use different quant params to different layer types
linear_tags = get_linear_tags(model)

# Convert quantization_config to layer-wise config
skip_modules = quantization_config.skip_modules
quant_config = quantization_config.to_dict()
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))

if any(key in linear_tags for key in quant_config.keys()):
# If the user doesn't specify a key from get_linear_tags, the layer is not quantized via (key, None)
patch_params = {key: None for key in linear_tags}
patch_params.update(quant_config)
else:
# Same quant_config for all layers
patch_params = {k: quant_config for k in linear_tags}

model, has_been_replaced = _prepare_for_hqq_linear(
model, patch_params=patch_params, has_been_replaced=has_been_replaced
)

# We store quantization config as linear_tag -> hqq quant config
model.config.quantization_config = patch_params

if not has_been_replaced:
logger.warning("No linear modules were found in your model for quantization.")

return model
Empty file modified src/transformers/integrations/integration_utils.py
100644 → 100755
mobicham marked this conversation as resolved.
Show resolved Hide resolved
Empty file.
20 changes: 19 additions & 1 deletion src/transformers/modeling_utils.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss, Identity
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm as tqdm_lib

from .activations import get_activation
from .configuration_utils import PretrainedConfig
Expand Down Expand Up @@ -808,7 +809,13 @@ def _load_state_dict_into_meta_model(
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)

for param_name, param in state_dict.items():
# Show shard-level progress. Useful to monitor quantization progress
quant_show_progress = False
mobicham marked this conversation as resolved.
Show resolved Hide resolved
if hf_quantizer is not None:
if hasattr(hf_quantizer, "show_progress"):
quant_show_progress = hf_quantizer.show_progress
mobicham marked this conversation as resolved.
Show resolved Hide resolved
mobicham marked this conversation as resolved.
Show resolved Hide resolved

for param_name, param in tqdm_lib(state_dict.items(), disable=not quant_show_progress):
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
continue
Expand Down Expand Up @@ -2656,6 +2663,8 @@ def get_memory_footprint(self, return_buffers=True):

@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
raise ValueError("`.to` is not supported for HQQ-quantized models.")
mobicham marked this conversation as resolved.
Show resolved Hide resolved
# Checks if the model has been loaded in 8-bit
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
raise ValueError(
Expand All @@ -2667,6 +2676,8 @@ def cuda(self, *args, **kwargs):

@wraps(torch.nn.Module.to)
def to(self, *args, **kwargs):
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
raise ValueError("`.to` is not supported for HQQ-quantized models.")
# Checks if the model has been loaded in 8-bit
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
raise ValueError(
Expand Down Expand Up @@ -3736,6 +3747,13 @@ def from_pretrained(
}
if "skip_keys" in inspect.signature(dispatch_model).parameters:
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
# For HQQ method we force-set the hooks for single GPU envs
if (
"force_hooks" in inspect.signature(dispatch_model).parameters
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
):
device_map_kwargs["force_hooks"] = True
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
dispatch_model(model, **device_map_kwargs)

Expand Down
Empty file modified src/transformers/quantizers/__init__.py
100644 → 100755
Empty file.
4 changes: 4 additions & 0 deletions src/transformers/quantizers/auto.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
BitsAndBytesConfig,
EetqConfig,
GPTQConfig,
HqqConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
Expand All @@ -31,6 +32,7 @@
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_hqq import HQQHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer


Expand All @@ -42,6 +44,7 @@
"aqlm": AqlmHfQuantizer,
"quanto": QuantoHfQuantizer,
"eetq": EetqHfQuantizer,
"hqq": HQQHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
Expand All @@ -52,6 +55,7 @@
"gptq": GPTQConfig,
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
"hqq": HqqConfig,
}


Expand Down
Loading
Loading