From 1f4ee6f274468ba80839fcc77f884976b362e688 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Mon, 22 Jul 2024 20:21:59 +0200 Subject: [PATCH] Add new quant method (#32047) * Add new quant method * update * fix multi-device * add test * add offload * style * style * add simple example * initial doc * docstring * style again * works ? * better docs * switch to non persistant * remove print * fix init * code review --- docs/source/en/_toctree.yml | 2 + docs/source/en/main_classes/quantization.md | 5 + docs/source/en/quantization/fbgemm_fp8.md | 58 ++++ docs/source/en/quantization/overview.md | 1 + src/transformers/__init__.py | 2 + src/transformers/integrations/__init__.py | 2 + src/transformers/integrations/fbgemm_fp8.py | 161 +++++++++++ src/transformers/modeling_utils.py | 12 +- src/transformers/quantizers/auto.py | 11 +- .../quantizers/quantizer_fbgemm_fp8.py | 205 +++++++++++++ src/transformers/testing_utils.py | 8 + src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 5 + src/transformers/utils/quantization_config.py | 32 +++ tests/quantization/fbgemm_fp8/__init__.py | 0 .../fbgemm_fp8/test_fbgemm_fp8.py | 270 ++++++++++++++++++ 16 files changed, 770 insertions(+), 5 deletions(-) create mode 100644 docs/source/en/quantization/fbgemm_fp8.md create mode 100644 src/transformers/integrations/fbgemm_fp8.py create mode 100644 src/transformers/quantizers/quantizer_fbgemm_fp8.py create mode 100644 tests/quantization/fbgemm_fp8/__init__.py create mode 100644 tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 430670aa4364e6..740bb4b0719c61 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -157,6 +157,8 @@ title: EETQ - local: quantization/hqq title: HQQ + - local: quantization/fbgemm_fp8 + title: FBGEMM_FP8 - local: quantization/optimum title: Optimum - local: quantization/contribute diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md index f1e2acdcfe4809..fc5808415cbe5f 100755 --- a/docs/source/en/main_classes/quantization.md +++ b/docs/source/en/main_classes/quantization.md @@ -56,3 +56,8 @@ Learn how to quantize models in the [Quantization](../quantization) guide. ## HqqConfig [[autodoc]] HqqConfig + +## FbgemmFp8Config + +[[autodoc]] FbgemmFp8Config + diff --git a/docs/source/en/quantization/fbgemm_fp8.md b/docs/source/en/quantization/fbgemm_fp8.md new file mode 100644 index 00000000000000..4df194d31be7ca --- /dev/null +++ b/docs/source/en/quantization/fbgemm_fp8.md @@ -0,0 +1,58 @@ + + +# FBGEMM FP8 + +With FBGEMM FP8 quantization method, you can quantize your model in FP8 (W8A8): +- the weights will be quantized in 8bit (FP8) per channel +- the activation will be quantized in 8bit (FP8) per token + +It relies on the [FBGEMM](https://github.com/pytorch/FBGEMM) library which provides efficient low-precision general matrix multiplication for small batch sizes and support for accuracy-loss minimizing techniques such as row-wise quantization and outlier-aware quantization. + +> [!TIP] +> You need a GPU with compute capability>=9 (e.g. H100) + +Before you begin, make sure the following libraries are installed with their latest version: + +```bash +pip install --upgrade accelerate fbgemm-gpu torch +``` + +If you are having issues with fbgemm-gpu and torch library, you might need to install the nighlty release. You can follow the instruction [here](https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries:~:text=found%20here.-,Install%20the%20FBGEMM_GPU%20Package,-Install%20through%20PyTorch) + + +```py +from transformers import FbgemmFp8Config, AutoModelForCausalLM, AutoTokenizer + +model_name = "meta-llama/Meta-Llama-3-8B" +quantization_config = FbgemmFp8Config() +quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config) + +tokenizer = AutoTokenizer.from_pretrained(model_name) +input_text = "What are we having for dinner?" +input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") + +output = quantized_model.generate(**input_ids, max_new_tokens=10) +print(tokenizer.decode(output[0], skip_special_tokens=True)) +``` + +A quantized model can be saved via "saved_pretrained" and be reused again via the "from_pretrained". + +```py +quant_path = "/path/to/save/quantized/model" +model.save_pretrained(quant_path) +model = AutoModelForCausalLM.from_pretrained(quant_path, device_map="auto") +``` \ No newline at end of file diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 6cd13fc894633b..99fc669e49f448 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -55,4 +55,5 @@ Use the table below to help you decide which quantization method to use. | [GPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 2 - 3 - 4 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ | | [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ | | [Quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/quanto | +| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a6d173a0ce4048..bfcf95311ac86b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -934,6 +934,7 @@ "AwqConfig", "BitsAndBytesConfig", "EetqConfig", + "FbgemmFp8Config", "GPTQConfig", "HqqConfig", "QuantoConfig", @@ -5671,6 +5672,7 @@ AwqConfig, BitsAndBytesConfig, EetqConfig, + FbgemmFp8Config, GPTQConfig, HqqConfig, QuantoConfig, diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 9b838bd1608490..4c756a23ae0aa4 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -45,6 +45,7 @@ "unset_hf_deepspeed_config", ], "eetq": ["replace_with_eetq_linear"], + "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"], "ggml": [ "GGUF_CONFIG_MAPPING", "GGUF_TENSOR_MAPPING", @@ -126,6 +127,7 @@ unset_hf_deepspeed_config, ) from .eetq import replace_with_eetq_linear + from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear from .ggml import ( GGUF_CONFIG_MAPPING, GGUF_TENSOR_MAPPING, diff --git a/src/transformers/integrations/fbgemm_fp8.py b/src/transformers/integrations/fbgemm_fp8.py new file mode 100644 index 00000000000000..a0f5b2b76089b9 --- /dev/null +++ b/src/transformers/integrations/fbgemm_fp8.py @@ -0,0 +1,161 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging + + +if is_torch_available(): + import torch + from torch import nn + +if is_accelerate_available(): + from accelerate import init_empty_weights + +if is_fbgemm_gpu_available(): + import fbgemm_gpu.experimental.gen_ai # noqa: F401 + +logger = logging.get_logger(__name__) + + +class FbgemmFp8Linear(torch.nn.Module): + def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn)) + self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype)) + self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False) + + if bias: + self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype)) + else: + self.bias = None + + def forward(self, x): + num_tokens = None + # x_quantized and x_scale are not necessarily on the same device as x, this is an issue. + # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45 + x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( + x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub + ) + # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works + # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device) + + # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight + output = torch.ops.fbgemm.f8f8bf16_rowwise( + x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True + ) + output = output + self.bias if self.bias is not None else output + # Hacky for now, we have the output to the device of x + output = output.to(x.device) + del x_quantized, x_scale + return output + + +def _replace_with_fbgemm_fp8_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, + pre_quantized=False, +): + """ + Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. + """ + if current_key_name is None: + current_key_name = [] + + for name, module in model.named_children(): + current_key_name.append(name) + + if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + with init_empty_weights(include_buffers=True): + in_features = module.in_features + out_features = module.out_features + model._modules[name] = FbgemmFp8Linear( + in_features, + out_features, + module.bias is not None, + ) + has_been_replaced = True + + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + # set non persistant buffer outside of init_empty_weights + model._modules[name].input_scale_ub = torch.tensor( + [quantization_config.activation_scale_ub], dtype=torch.float + ) + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_fbgemm_fp8_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + pre_quantized=pre_quantized, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def replace_with_fbgemm_fp8_linear( + model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False +): + """ + A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules. + This will enable running your models using high performance fp8 kernel from FBGEMM library. + + The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should + be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no + CPU/GPU memory is required to run this function. Each weight will be quantized along the channel. + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`): + Names of the modules to not convert in `FP8Linear`. In practice we keep the `lm_head` in full precision + for numerical stability reasons. + current_key_name (`List[`str`]`, *optional*): + An array to track the current key of the recursion. This is used to check whether the current key (part of + it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or + `disk`). + """ + + modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert + + if quantization_config.modules_to_not_convert is not None: + modules_to_not_convert.extend(quantization_config.modules_to_not_convert) + modules_to_not_convert = list(set(modules_to_not_convert)) + model, has_been_replaced = _replace_with_fbgemm_fp8_linear( + model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized + ) + + if not has_been_replaced: + logger.warning( + "You are loading your model using FP8 quantization but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + + return model diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a2cea6dcdc2483..a20b7d941fbfe6 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -868,7 +868,7 @@ def _load_state_dict_into_meta_model( # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params # in int/uint/bool and not cast them. - if dtype is not None and torch.is_floating_point(param): + if dtype is not None and torch.is_floating_point(param) and param.dtype != torch.float8_e4m3fn: if ( keep_in_fp32_modules is not None and any( @@ -894,7 +894,6 @@ def _load_state_dict_into_meta_model( old_param = getattr(old_param, split) if old_param is None: break - if old_param is not None: if dtype is None: param = param.to(old_param.dtype) @@ -3955,6 +3954,14 @@ def from_pretrained( and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ ): device_map_kwargs["force_hooks"] = True + if ( + hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8 + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + device_map_kwargs["offload_buffers"] = True + if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): dispatch_model(model, **device_map_kwargs) @@ -4105,7 +4112,6 @@ def _fix_key(key): if cls._keys_to_ignore_on_load_unexpected is not None: for pat in cls._keys_to_ignore_on_load_unexpected: unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - if hf_quantizer is not None: missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix) diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index 2c65afa77e282c..40aa86fc37c733 100755 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -20,6 +20,7 @@ AwqConfig, BitsAndBytesConfig, EetqConfig, + FbgemmFp8Config, GPTQConfig, HqqConfig, QuantizationConfigMixin, @@ -31,6 +32,7 @@ from .quantizer_bnb_4bit import Bnb4BitHfQuantizer from .quantizer_bnb_8bit import Bnb8BitHfQuantizer from .quantizer_eetq import EetqHfQuantizer +from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer from .quantizer_gptq import GptqHfQuantizer from .quantizer_hqq import HqqHfQuantizer from .quantizer_quanto import QuantoHfQuantizer @@ -45,6 +47,7 @@ "quanto": QuantoHfQuantizer, "eetq": EetqHfQuantizer, "hqq": HqqHfQuantizer, + "fbgemm_fp8": FbgemmFp8HfQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { @@ -56,6 +59,7 @@ "aqlm": AqlmConfig, "quanto": QuantoConfig, "hqq": HqqConfig, + "fbgemm_fp8": FbgemmFp8Config, } @@ -156,8 +160,11 @@ def merge_quantization_configs( if isinstance(quantization_config, dict): quantization_config = AutoQuantizationConfig.from_dict(quantization_config) - if isinstance(quantization_config, (GPTQConfig, AwqConfig)) and quantization_config_from_args is not None: - # special case for GPTQ / AWQ config collision + if ( + isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config)) + and quantization_config_from_args is not None + ): + # special case for GPTQ / AWQ / FbgemmFp8 config collision loading_attr_dict = quantization_config_from_args.get_loading_attributes() for attr, val in loading_attr_dict.items(): setattr(quantization_config, attr, val) diff --git a/src/transformers/quantizers/quantizer_fbgemm_fp8.py b/src/transformers/quantizers/quantizer_fbgemm_fp8.py new file mode 100644 index 00000000000000..6591a56fce7840 --- /dev/null +++ b/src/transformers/quantizers/quantizer_fbgemm_fp8.py @@ -0,0 +1,205 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from packaging import version + +from .base import HfQuantizer + + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + +from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging +from .quantizers_utils import get_module_from_name + + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +class FbgemmFp8HfQuantizer(HfQuantizer): + """ + FP8 quantization using fbgemm kernels + """ + + requires_parameters_quantization = True + requires_calibration = False + + required_packages = ["fbgemm-gpu", "accelerate"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + self.quantization_config = quantization_config + + def validate_environment(self, *args, **kwargs): + if not is_torch_available() or version.parse(importlib.metadata.version("torch")) < version.parse("2.1.0"): + raise ImportError( + "Using fbgemm fp8 quantization requires torch > 2.1.0" + "Please install the latest version of torch ( pip install --upgrade torch )" + ) + if not is_fbgemm_gpu_available(): + raise ImportError( + "Using fbgemm fp8 quantization requires fbgemm-gpu library" + "Please install the latest version of fbgemm-gpu library by following : https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries" + ) + + if not is_accelerate_available("0.32.2"): + raise ImportError( + "Loading an FP8 quantized model requires accelerate > 0.32.1 (`pip install --upgrade accelerate`)" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("Using FP8 quantized models with fbgemm kernels requires a GPU") + + compute_capability = torch.cuda.get_device_capability() + major, minor = compute_capability + if major < 9: + raise ValueError( + "FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)" + ) + + device_map = kwargs.get("device_map", None) + if device_map is None: + logger.warning_once( + "You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set " + "your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. " + ) + elif device_map is not None: + if ( + not self.pre_quantized + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + raise ValueError( + "You are attempting to load an FP8 model with a device_map that contains a CPU or disk device." + "This is not supported when the model is quantized on the fly. " + "Please use a quantized checkpoint or remove the CPU or disk device from the device_map." + ) + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + torch_dtype = torch.bfloat16 + logger.info( + "Overriding torch_dtype=%s with `torch_dtype=torch.bloat16` due to " + "requirements of `fbgemm-gpu` to enable model loading in fp8. " + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.bfloat16 to remove this warning.", + torch_dtype, + ) + elif torch_dtype == torch.float16: + raise ValueError( + "You cannot use FP8 with torch_dtype=torch.float16." + "We recommend you passing torch_dtype=torch.bfloat16" + ) + return torch_dtype + + def check_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ): + from ..integrations import FbgemmFp8Linear + + module, tensor_name = get_module_from_name(model, param_name) + + if isinstance(module, FbgemmFp8Linear): + if self.pre_quantized or tensor_name == "bias": + if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn: + raise ValueError("Expect quantized weights but got an unquantized weight") + return False + else: + if tensor_name == "weight_scale": + raise ValueError("Expect unquantized weights but got a quantized weight_scale") + return True + return False + + def create_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: Optional[List[str]] = None, + ): + """ + Quantizes weights into weight and weight_scale + """ + new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value) + + module, tensor_name = get_module_from_name(model, param_name) + module._buffers[tensor_name] = new_value.to(target_device) + # to have the right output shape -> (out_features, 1) + module._buffers["weight_scale"] = weight_scale.view(weight_scale.shape[0], 1).to(target_device) + + if unexpected_keys is not None and param_name in unexpected_keys: + unexpected_keys.remove(param_name) + del param_name + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + return model + + def _process_model_before_weight_loading( + self, + model: "PreTrainedModel", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + from ..integrations import get_keys_to_not_convert, replace_with_fbgemm_fp8_linear + + self.modules_to_not_convert = get_keys_to_not_convert(model) + + if self.quantization_config.modules_to_not_convert is not None: + self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert) + + model = replace_with_fbgemm_fp8_linear( + model, + modules_to_not_convert=self.modules_to_not_convert, + quantization_config=self.quantization_config, + pre_quantized=self.pre_quantized, + ) + + model.config.quantization_config = self.quantization_config + + def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: + from ..integrations import FbgemmFp8Linear + + not_missing_keys = [] + for name, module in model.named_modules(): + if isinstance(module, FbgemmFp8Linear): + for missing in missing_keys: + if ( + (name in missing or name in f"{prefix}.{missing}") + and not missing.endswith(".weight") + and not missing.endswith(".bias") + ): + not_missing_keys.append(missing) + return [k for k in missing_keys if k not in not_missing_keys] + + @property + def is_serializable(self): + return True + + @property + def is_trainable(self) -> bool: + return False diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 60ff7815a971ae..edfc9519963bee 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -68,6 +68,7 @@ is_eetq_available, is_essentia_available, is_faiss_available, + is_fbgemm_gpu_available, is_flash_attn_2_available, is_flax_available, is_fsdp_available, @@ -1116,6 +1117,13 @@ def require_quanto(test_case): return unittest.skipUnless(is_quanto_available(), "test requires quanto")(test_case) +def require_fbgemm_gpu(test_case): + """ + Decorator for fbgemm_gpu dependency + """ + return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case) + + def require_phonemizer(test_case): """ Decorator marking a test that requires phonemizer diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 351ab0cf11ffba..efe473a6cdeda2 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -127,6 +127,7 @@ is_eetq_available, is_essentia_available, is_faiss_available, + is_fbgemm_gpu_available, is_flash_attn_2_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index bd14dd8cd7530c..f81b9d3dba41bd 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -98,6 +98,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _av_available = importlib.util.find_spec("av") is not None _bitsandbytes_available = _is_package_available("bitsandbytes") _eetq_available = _is_package_available("eetq") +_fbgemm_gpu_available = _is_package_available("fbgemm_gpu") _galore_torch_available = _is_package_available("galore_torch") _lomo_available = _is_package_available("lomo_optim") # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. @@ -888,6 +889,10 @@ def is_eetq_available(): return _eetq_available +def is_fbgemm_gpu_available(): + return _fbgemm_gpu_available + + def is_levenshtein_available(): return _levenshtein_available diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 506c4db447c7aa..5de8307c3bd79b 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -42,6 +42,7 @@ class QuantizationMethod(str, Enum): QUANTO = "quanto" EETQ = "eetq" HQQ = "hqq" + FBGEMM_FP8 = "fbgemm_fp8" class AWQLinearVersion(str, Enum): @@ -1047,3 +1048,34 @@ def post_init(self): accepted_weights = ["int8"] if self.weights not in accepted_weights: raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}") + + +@dataclass +class FbgemmFp8Config(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using fbgemm fp8 quantization. + + Args: + activation_scale_ub (`float`, *optional*, defaults to 1200.0): + The activation scale upper bound. This is used when quantizing the input activation. + modules_to_not_convert (`list`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have + some modules left in their original precision. + """ + + def __init__( + self, + activation_scale_ub: float = 1200.0, + modules_to_not_convert: Optional[List] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.FBGEMM_FP8 + self.activation_scale_ub = activation_scale_ub + self.modules_to_not_convert = modules_to_not_convert + + def get_loading_attributes(self): + attibutes_dict = copy.deepcopy(self.__dict__) + loading_attibutes = ["activation_scale_ub"] + loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes} + return loading_attibutes_dict diff --git a/tests/quantization/fbgemm_fp8/__init__.py b/tests/quantization/fbgemm_fp8/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py b/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py new file mode 100644 index 00000000000000..61a1eecba8d3df --- /dev/null +++ b/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config, OPTForCausalLM +from transformers.testing_utils import ( + require_accelerate, + require_fbgemm_gpu, + require_read_token, + require_torch_gpu, + require_torch_multi_gpu, + slow, + torch_device, +) +from transformers.utils import is_accelerate_available, is_torch_available + + +if is_torch_available(): + import torch + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +@require_torch_gpu +class FbgemmFp8ConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object + """ + quantization_config = FbgemmFp8Config() + config_to_dict = quantization_config.to_dict() + + for key in config_to_dict: + self.assertEqual(getattr(quantization_config, key), config_to_dict[key]) + + def test_from_dict(self): + """ + Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict + """ + dict = {"modules_to_not_convert": ["lm_head.weight"], "quant_method": "fbgemm_fp8"} + quantization_config = FbgemmFp8Config.from_dict(dict) + + self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert) + self.assertEqual(dict["quant_method"], quantization_config.quant_method) + + +@slow +@require_torch_gpu +@require_fbgemm_gpu +@require_accelerate +@require_read_token +class FbgemmFp8Test(unittest.TestCase): + model_name = "meta-llama/Meta-Llama-3-8B" + + input_text = "What are we having for dinner?" + max_new_tokens = 9 + + EXPECTED_OUTPUT = "What are we having for dinner?\nI'm having a steak and a salad" + + device_map = "cuda" + + offload_device_map = { + "model.embed_tokens": 0, + "model.layers.0": 0, + "model.layers.1": 0, + "model.layers.2": 0, + "model.layers.3": 0, + "model.layers.4": 0, + "model.layers.5": 0, + "model.layers.6": 0, + "model.layers.7": 0, + "model.layers.8": 0, + "model.layers.9": 0, + "model.layers.10": 0, + "model.layers.11": 0, + "model.layers.12": 0, + "model.layers.13": 0, + "model.layers.14": 0, + "model.layers.15": 0, + "model.layers.16": "cpu", + "model.layers.17": "cpu", + "model.layers.18": "cpu", + "model.layers.19": "cpu", + "model.layers.20": "disk", + "model.layers.21": "disk", + "model.layers.22": "disk", + "model.layers.23": "disk", + "model.layers.24": "disk", + "model.layers.25": "disk", + "model.layers.26": "disk", + "model.layers.27": "disk", + "model.layers.28": "disk", + "model.layers.29": "disk", + "model.layers.30": "disk", + "model.layers.31": "disk", + "model.norm": "disk", + "lm_head": "disk", + } + + # called only once for all test in this class + @classmethod + def setUpClass(cls): + """ + Setup quantized model + """ + quantization_config = FbgemmFp8Config() + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) + cls.quantized_model = AutoModelForCausalLM.from_pretrained( + cls.model_name, device_map=cls.device_map, quantization_config=quantization_config + ) + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + def test_quantized_model_conversion(self): + """ + Simple test that checks if the quantized model has been converted properly + """ + + from transformers.integrations import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear + + model_id = "facebook/opt-350m" + config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5") + quantization_config = FbgemmFp8Config() + + with init_empty_weights(): + model = OPTForCausalLM(config) + + nb_linears = 0 + for module in model.modules(): + if isinstance(module, torch.nn.Linear): + nb_linears += 1 + + model = replace_with_fbgemm_fp8_linear(model, quantization_config=quantization_config) + nb_fbgemm_linear = 0 + for module in model.modules(): + if isinstance(module, FbgemmFp8Linear): + nb_fbgemm_linear += 1 + + self.assertEqual(nb_linears - 1, nb_fbgemm_linear) + + with init_empty_weights(): + model = OPTForCausalLM(config) + quantization_config = FbgemmFp8Config(modules_to_not_convert=["fc1"]) + model = replace_with_fbgemm_fp8_linear(model, quantization_config=quantization_config) + nb_fbgemm_linear = 0 + for module in model.modules(): + if isinstance(module, FbgemmFp8Linear): + nb_fbgemm_linear += 1 + + self.assertEqual(nb_linears - 25, nb_fbgemm_linear) + + def test_quantized_model(self): + """ + Simple test that checks if the quantized model is working properly + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_save_pretrained(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_change_loading_attributes(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + + quantization_config = FbgemmFp8Config(activation_scale_ub=1000.0) + + model = AutoModelForCausalLM.from_pretrained( + tmpdirname, device_map=self.device_map, quantization_config=quantization_config + ) + + self.assertEqual(model.model.layers[1].mlp.down_proj.input_scale_ub.item(), 1000.0) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @require_torch_multi_gpu + def test_quantized_model_multi_gpu(self): + """ + Simple test that checks if the quantized model is working properly with multiple GPUs + set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + quantization_config = FbgemmFp8Config() + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, device_map="auto", quantization_config=quantization_config + ) + self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_quantized_model_offload(self): + """ + Simple test that checks if the quantized model returns an error when loading with cpu/disk offloaded + """ + quantization_config = FbgemmFp8Config() + + with self.assertRaisesRegex( + ValueError, "You are attempting to load an FP8 model with a device_map that contains a CPU or disk device." + ): + AutoModelForCausalLM.from_pretrained( + self.model_name, device_map=self.offload_device_map, quantization_config=quantization_config + ) + + def test_save_pretrained_offload(self): + """ + Simple test that checks if the saved quantized model is working properly cpu/disk offload + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.offload_device_map) + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @require_torch_multi_gpu + def test_save_pretrained_multi_gpu(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto") + self.assertTrue(set(model.hf_device_map.values()) == {0, 1}) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)