Skip to content

Commit

Permalink
Add new quant method (huggingface#32047)
Browse files Browse the repository at this point in the history
* Add new quant method

* update

* fix multi-device

* add test

* add offload

* style

* style

* add simple example

* initial doc

* docstring

* style again

* works ?

* better docs

* switch to non persistant

* remove print

* fix init

* code review
  • Loading branch information
SunMarc authored and MHRDYN7 committed Jul 23, 2024
1 parent c1c51c9 commit 1f4ee6f
Show file tree
Hide file tree
Showing 16 changed files with 770 additions and 5 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@
title: EETQ
- local: quantization/hqq
title: HQQ
- local: quantization/fbgemm_fp8
title: FBGEMM_FP8
- local: quantization/optimum
title: Optimum
- local: quantization/contribute
Expand Down
5 changes: 5 additions & 0 deletions docs/source/en/main_classes/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,8 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
## HqqConfig

[[autodoc]] HqqConfig

## FbgemmFp8Config

[[autodoc]] FbgemmFp8Config

58 changes: 58 additions & 0 deletions docs/source/en/quantization/fbgemm_fp8.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# FBGEMM FP8

With FBGEMM FP8 quantization method, you can quantize your model in FP8 (W8A8):
- the weights will be quantized in 8bit (FP8) per channel
- the activation will be quantized in 8bit (FP8) per token

It relies on the [FBGEMM](https://github.com/pytorch/FBGEMM) library which provides efficient low-precision general matrix multiplication for small batch sizes and support for accuracy-loss minimizing techniques such as row-wise quantization and outlier-aware quantization.

> [!TIP]
> You need a GPU with compute capability>=9 (e.g. H100)
Before you begin, make sure the following libraries are installed with their latest version:

```bash
pip install --upgrade accelerate fbgemm-gpu torch
```

If you are having issues with fbgemm-gpu and torch library, you might need to install the nighlty release. You can follow the instruction [here](https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries:~:text=found%20here.-,Install%20the%20FBGEMM_GPU%20Package,-Install%20through%20PyTorch)


```py
from transformers import FbgemmFp8Config, AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = FbgemmFp8Config()
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config)

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

A quantized model can be saved via "saved_pretrained" and be reused again via the "from_pretrained".

```py
quant_path = "/path/to/save/quantized/model"
model.save_pretrained(quant_path)
model = AutoModelForCausalLM.from_pretrained(quant_path, device_map="auto")
```
1 change: 1 addition & 0 deletions docs/source/en/quantization/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,5 @@ Use the table below to help you decide which quantization method to use.
| [GPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 2 - 3 - 4 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto |
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |

2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,7 @@
"AwqConfig",
"BitsAndBytesConfig",
"EetqConfig",
"FbgemmFp8Config",
"GPTQConfig",
"HqqConfig",
"QuantoConfig",
Expand Down Expand Up @@ -5671,6 +5672,7 @@
AwqConfig,
BitsAndBytesConfig,
EetqConfig,
FbgemmFp8Config,
GPTQConfig,
HqqConfig,
QuantoConfig,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"unset_hf_deepspeed_config",
],
"eetq": ["replace_with_eetq_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
"ggml": [
"GGUF_CONFIG_MAPPING",
"GGUF_TENSOR_MAPPING",
Expand Down Expand Up @@ -126,6 +127,7 @@
unset_hf_deepspeed_config,
)
from .eetq import replace_with_eetq_linear
from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
from .ggml import (
GGUF_CONFIG_MAPPING,
GGUF_TENSOR_MAPPING,
Expand Down
161 changes: 161 additions & 0 deletions src/transformers/integrations/fbgemm_fp8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging


if is_torch_available():
import torch
from torch import nn

if is_accelerate_available():
from accelerate import init_empty_weights

if is_fbgemm_gpu_available():
import fbgemm_gpu.experimental.gen_ai # noqa: F401

logger = logging.get_logger(__name__)


class FbgemmFp8Linear(torch.nn.Module):
def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
super().__init__()
self.in_features = in_features
self.out_features = out_features

self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype))
self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)

if bias:
self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype))
else:
self.bias = None

def forward(self, x):
num_tokens = None
# x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
)
# moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
# x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)

# The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
output = torch.ops.fbgemm.f8f8bf16_rowwise(
x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
)
output = output + self.bias if self.bias is not None else output
# Hacky for now, we have the output to the device of x
output = output.to(x.device)
del x_quantized, x_scale
return output


def _replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
pre_quantized=False,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
"""
if current_key_name is None:
current_key_name = []

for name, module in model.named_children():
current_key_name.append(name)

if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)
if not any(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
with init_empty_weights(include_buffers=True):
in_features = module.in_features
out_features = module.out_features
model._modules[name] = FbgemmFp8Linear(
in_features,
out_features,
module.bias is not None,
)
has_been_replaced = True

# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
# set non persistant buffer outside of init_empty_weights
model._modules[name].input_scale_ub = torch.tensor(
[quantization_config.activation_scale_ub], dtype=torch.float
)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fbgemm_fp8_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
pre_quantized=pre_quantized,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced


def replace_with_fbgemm_fp8_linear(
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
):
"""
A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
This will enable running your models using high performance fp8 kernel from FBGEMM library.
The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
Names of the modules to not convert in `FP8Linear`. In practice we keep the `lm_head` in full precision
for numerical stability reasons.
current_key_name (`List[`str`]`, *optional*):
An array to track the current key of the recursion. This is used to check whether the current key (part of
it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
`disk`).
"""

modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
modules_to_not_convert = list(set(modules_to_not_convert))
model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
)

if not has_been_replaced:
logger.warning(
"You are loading your model using FP8 quantization but no linear modules were found in your model."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)

return model
12 changes: 9 additions & 3 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ def _load_state_dict_into_meta_model(

# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
# in int/uint/bool and not cast them.
if dtype is not None and torch.is_floating_point(param):
if dtype is not None and torch.is_floating_point(param) and param.dtype != torch.float8_e4m3fn:
if (
keep_in_fp32_modules is not None
and any(
Expand All @@ -894,7 +894,6 @@ def _load_state_dict_into_meta_model(
old_param = getattr(old_param, split)
if old_param is None:
break

if old_param is not None:
if dtype is None:
param = param.to(old_param.dtype)
Expand Down Expand Up @@ -3955,6 +3954,14 @@ def from_pretrained(
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
):
device_map_kwargs["force_hooks"] = True
if (
hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
device_map_kwargs["offload_buffers"] = True

if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
dispatch_model(model, **device_map_kwargs)

Expand Down Expand Up @@ -4105,7 +4112,6 @@ def _fix_key(key):
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)

Expand Down
11 changes: 9 additions & 2 deletions src/transformers/quantizers/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
AwqConfig,
BitsAndBytesConfig,
EetqConfig,
FbgemmFp8Config,
GPTQConfig,
HqqConfig,
QuantizationConfigMixin,
Expand All @@ -31,6 +32,7 @@
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
Expand All @@ -45,6 +47,7 @@
"quanto": QuantoHfQuantizer,
"eetq": EetqHfQuantizer,
"hqq": HqqHfQuantizer,
"fbgemm_fp8": FbgemmFp8HfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
Expand All @@ -56,6 +59,7 @@
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
"hqq": HqqConfig,
"fbgemm_fp8": FbgemmFp8Config,
}


Expand Down Expand Up @@ -156,8 +160,11 @@ def merge_quantization_configs(
if isinstance(quantization_config, dict):
quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

if isinstance(quantization_config, (GPTQConfig, AwqConfig)) and quantization_config_from_args is not None:
# special case for GPTQ / AWQ config collision
if (
isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config))
and quantization_config_from_args is not None
):
# special case for GPTQ / AWQ / FbgemmFp8 config collision
loading_attr_dict = quantization_config_from_args.get_loading_attributes()
for attr, val in loading_attr_dict.items():
setattr(quantization_config, attr, val)
Expand Down
Loading

0 comments on commit 1f4ee6f

Please sign in to comment.