Add LoRA support to HQQ Quantization #1618

Merged
39 commits, merged May 3, 2024
Changes shown from 36 commits

Commits
e48b8ac
Add HQQ Lora
fahadh4ilyas Apr 3, 2024
b412495
fix error weight load
fahadh4ilyas Apr 3, 2024
f6595b4
Remove unused
fahadh4ilyas Apr 3, 2024
12f202a
Add quantized lora
fahadh4ilyas Apr 3, 2024
7abdd1b
fix make HQQLinear
fahadh4ilyas Apr 3, 2024
021114a
Fix dtype
fahadh4ilyas Apr 3, 2024
38e4a3a
Revert back quantize lora
fahadh4ilyas Apr 3, 2024
65435c6
Add prepare training for hqq quantization
fahadh4ilyas Apr 3, 2024
fb43017
Forget revert hqq
fahadh4ilyas Apr 3, 2024
fe2f25c
Remove warnings
fahadh4ilyas Apr 3, 2024
91d8d99
Other ways to check hqq quantization
fahadh4ilyas Apr 3, 2024
5eb8cc1
Add unit test for training
fahadh4ilyas Apr 3, 2024
59aa632
change bfloat16 to float16
fahadh4ilyas Apr 3, 2024
bc706a8
Fix load weight when applied dora
fahadh4ilyas Apr 5, 2024
a7ee47f
Move import hqq inside if clause
fahadh4ilyas Apr 5, 2024
e1e7675
Naming using CamelCase
fahadh4ilyas Apr 5, 2024
db11108
Remove unused function and fix naming convention
fahadh4ilyas Apr 5, 2024
19fc4c5
Pop offload_meta
fahadh4ilyas Apr 5, 2024
46511e0
Add use_dora params
fahadh4ilyas Apr 5, 2024
aec43b7
Remove confusing comments
fahadh4ilyas Apr 5, 2024
872d8cf
Additional test for checking output from HQQ
fahadh4ilyas Apr 5, 2024
97abfff
Add license notice
fahadh4ilyas Apr 5, 2024
0b757ed
Add parameter decorator
fahadh4ilyas Apr 5, 2024
11ca9a6
Redundant calling get_base_layer
fahadh4ilyas Apr 5, 2024
00be028
do make style
fahadh4ilyas Apr 5, 2024
1496650
Remove unused comments
fahadh4ilyas Apr 5, 2024
e533c73
Move dispatch_hqq out of if clause
fahadh4ilyas Apr 5, 2024
fabe859
make style all scripts
fahadh4ilyas Apr 5, 2024
caee0eb
Add comment for explanation
fahadh4ilyas Apr 23, 2024
a0328ff
Mention HQQ to docs
fahadh4ilyas Apr 23, 2024
0d41431
Merge branch 'main' into hqq-lora
fahadh4ilyas Apr 23, 2024
55d7f66
Add HQQ to Dockerfile
fahadh4ilyas Apr 23, 2024
0e0d3ac
Fix styling
fahadh4ilyas Apr 23, 2024
18aaa42
Merge branch 'main' into hqq-lora
fahadh4ilyas May 3, 2024
26f2cd6
Styling scripts
fahadh4ilyas May 3, 2024
4a73a3c
Comply with transformers HQQ integration
fahadh4ilyas May 3, 2024
5db3577
Test fully using transformers
fahadh4ilyas May 3, 2024
a818023
Add comments handling HQQ
fahadh4ilyas May 3, 2024
53f2701
Fix naming problem
fahadh4ilyas May 3, 2024
4 changes: 4 additions & 0 deletions docker/peft-gpu/Dockerfile
@@ -70,6 +70,10 @@ RUN source activate peft && \
RUN source activate peft && \
    pip install aqlm[gpu]>=1.0.2

# Add HQQ for quantization testing
RUN source activate peft && \
    pip install hqq

RUN source activate peft && \
    pip freeze | grep transformers

28 changes: 28 additions & 0 deletions docs/source/developer_guides/quantization.md
@@ -164,6 +164,34 @@ config = LoraConfig(
model = get_peft_model(model, config)
```

## HQQ quantization

Models quantized with Half-Quadratic Quantization ([HQQ](https://mobiusml.github.io/hqq_blog/)) support LoRA adapter tuning. To tune a quantized model, you'll need to install the `hqq` library with: `pip install hqq`.

```py
from hqq.engine.hf import HQQModelForCausalLM

quantized_model = HQQModelForCausalLM.from_quantized(save_dir_or_hfhub, device='cuda')

peft_config = LoraConfig(...)

quantized_model = get_peft_model(quantized_model, peft_config)
```

Alternatively, you can use a Transformers version that is compatible with HQQ (e.g. by installing the latest release from PyPI or installing from source):

```python
from transformers import HqqConfig, AutoModelForCausalLM

quant_config = HqqConfig(nbits=4, group_size=64)

quantized_model = AutoModelForCausalLM.from_pretrained(save_dir_or_hfhub, device_map='cuda', quantization_config=quant_config)

peft_config = LoraConfig(...)

quantized_model = get_peft_model(quantized_model, peft_config)
```
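
Once the model is wrapped with `get_peft_model`, it can be fine-tuned like any other PEFT model. The following is a minimal training sketch, assuming a tokenized `train_dataset` and a `data_collator` prepared elsewhere; the variable names and hyperparameters are illustrative, not prescribed by this integration:

```python
from transformers import Trainer, TrainingArguments

# Assumes `quantized_model` already has LoRA adapters attached via get_peft_model,
# and that `train_dataset` / `data_collator` exist (hypothetical names).
training_args = TrainingArguments(
    output_dir="hqq-lora-out",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    num_train_epochs=1,
    fp16=True,
    logging_steps=10,
)

trainer = Trainer(
    model=quantized_model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)
trainer.train()

# Only the LoRA adapter weights are saved; the HQQ base weights stay quantized.
quantized_model.save_pretrained("hqq-lora-adapter")
```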

## Next steps

If you're interested in learning more about quantization, the following may be helpful:
5 changes: 5 additions & 0 deletions src/peft/import_utils.py
@@ -82,3 +82,8 @@ def is_auto_awq_available():
@lru_cache
def is_eetq_available():
return importlib.util.find_spec("eetq") is not None


@lru_cache
def is_hqq_available():
return importlib.util.find_spec("hqq") is not None
247 changes: 247 additions & 0 deletions src/peft/tuners/lora/hqq.py
@@ -0,0 +1,247 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
import warnings
from typing import Any, Optional

import torch

from peft.import_utils import is_hqq_available
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from peft.utils.other import transpose

from .layer import LoraLayer


if is_hqq_available():
    from hqq.core.quantize import HQQLinear

    class HqqLoraLinear(torch.nn.Module, LoraLayer):
        # Lora implemented in a dense layer
        def __init__(
            self,
            base_layer: torch.nn.Module,
            adapter_name: str,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            init_lora_weights: bool = True,
            use_rslora: bool = False,
            use_dora: bool = False,
            **kwargs,
        ) -> None:
            super().__init__()
            LoraLayer.__init__(self, base_layer)
            self.fan_in_fan_out = False

            self._active_adapter = adapter_name
            self.update_layer(
                adapter_name,
                r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                init_lora_weights=init_lora_weights,
                use_rslora=use_rslora,
                use_dora=use_dora,
            )

        def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
            """
            Merge the active adapter weights into the base weights

            Args:
                safe_merge (`bool`, *optional*):
                    If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                    before merging the weights. This is useful if you want to check if the merge operation will produce
                    NaNs. Defaults to `False`.
                adapter_names (`list[str]`, *optional*):
                    The list of adapter names that should be merged. If None, all active adapters will be merged.
                    Defaults to `None`.
            """
            adapter_names = check_adapters_to_merge(self, adapter_names)
            if not adapter_names:
                # no adapter to merge
                return

            for active_adapter in adapter_names:
                if active_adapter not in self.lora_A.keys():
                    continue

                layer = self.get_base_layer()
                quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta}
                lora_data = self.get_delta_weight(active_adapter)

                output = layer.dequantize()
                if not self.use_dora[active_adapter]:
                    w_data = output + lora_data
                else:
                    # handle dora
                    # since output already includes scaling, set it to 1 here
                    weight_norm = self._get_weight_norm(output, lora_data, scaling=1).detach()
                    # We need to cache weight_norm because it has to be based on the original weights. We
                    # cannot calculate it on the fly based on the merged weights when unmerging because its a
                    # different value
                    self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
                    dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
                    w_data = dora_factor.view(-1, 1) * (output + lora_data)

                if safe_merge and not torch.isfinite(w_data).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )
                new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device)
                quant_config.pop("offload_meta", None)
                new_hqq_layer.quantize(w_data, **quant_config)
                self.base_layer = new_hqq_layer
                self.merged_adapters.append(active_adapter)

        def unmerge(self) -> None:
            """
            This method unmerges all merged adapter layers from the base weights.
            """
            if not self.merged:
                warnings.warn("Already unmerged. Nothing to do.")
                return

            while len(self.merged_adapters) > 0:
                active_adapter = self.merged_adapters.pop()
                if active_adapter not in self.lora_A.keys():
                    continue

                lora_data = self.get_delta_weight(active_adapter)
                layer = self.get_base_layer()
                quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta}
                output = layer.dequantize()

                if not self.use_dora[active_adapter]:
                    w_data = output - lora_data
                else:
                    weight_norm = self._cache_pop(f"{active_adapter}-weight_norm")
                    dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
                    w_data = output.data / dora_factor.view(-1, 1) - lora_data

                new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device)
                quant_config.pop("offload_meta", None)
                new_hqq_layer.quantize(w_data, **quant_config)
                self.base_layer = new_hqq_layer

        def get_delta_weight(self, adapter):
            return (
                transpose(
                    self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
                    False,
                )
                * self.scaling[adapter]
            )

        def _mixed_batch_forward(
            self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
        ) -> torch.Tensor:
            # This is a special method that handles the case when users pass the argument `adapter_names`. This is an
            # extra argument that allows mixing different adapters in the same batch at inference time.
            result = self.base_layer(x, *args, **kwargs)

            unique_adapters = set(adapter_names)
            sub_batch_indices_list = []
            for adapter in unique_adapters:
                sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])

            for i, active_adapter in enumerate(unique_adapters):
                if active_adapter == "__base__":
                    continue
                if active_adapter not in self.lora_A.keys():
                    continue

                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]

                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
                    compute_dtype = lora_A.weight.dtype
                    if x.dtype != compute_dtype:
                        x = x.to(compute_dtype)

                # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
                # layer output
                sub_batch = x[sub_batch_indices_list[i]]
                output = lora_B(lora_A(dropout(sub_batch))) * scaling
                if requires_conversion:
                    output = output.to(expected_dtype)
                result[sub_batch_indices_list[i]] += output

            return result

        def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            self._check_forward_args(x, *args, **kwargs)
            adapter_names = kwargs.pop("adapter_names", None)

            if self.disable_adapters:
                if self.merged:
                    self.unmerge()
                result = self.base_layer(x, *args, **kwargs)
            elif adapter_names is not None:
                result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
            elif self.merged:
                result = self.base_layer(x, *args, **kwargs)
            else:
                result = self.base_layer(x, *args, **kwargs)

                for active_adapter in self.active_adapters:
                    if active_adapter not in self.lora_A.keys():
                        continue
                    lora_A = self.lora_A[active_adapter]
                    lora_B = self.lora_B[active_adapter]
                    dropout = self.lora_dropout[active_adapter]
                    scaling = self.scaling[active_adapter]

                    requires_conversion = not torch.is_autocast_enabled()
                    if requires_conversion:
                        expected_dtype = result.dtype
                        compute_dtype = lora_A.weight.dtype
                        if x.dtype != compute_dtype:
                            x = x.to(compute_dtype)

                    if not self.use_dora[active_adapter]:
                        output = lora_B(lora_A(dropout(x))) * scaling
                    else:
                        output = self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)
                    if requires_conversion:
                        output = output.to(expected_dtype)

                    result = result + output

            return result

        def __repr__(self) -> str:
            rep = super().__repr__()
            return "lora." + rep


def dispatch_hqq(target: torch.nn.Module, adapter_name: str, **kwargs):
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if is_hqq_available() and isinstance(target_base_layer, HQQLinear):
        new_module = HqqLoraLinear(target_base_layer, adapter_name, **kwargs)

    return new_module
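
For context, here is a rough sketch of how the merge path above is typically exercised. It assumes `peft_model` is a PEFT LoRA model wrapping an HQQ-quantized base model (the variable name is illustrative); `merge()` dequantizes the HQQ base weights, adds the LoRA delta, and re-quantizes the result into a fresh `HQQLinear`, so the usual PEFT merge entry points work unchanged:

```python
# Option 1: merge the adapters in place, then undo it again
# (unmerge() re-quantizes the restored, delta-subtracted weights).
peft_model.merge_adapter()
peft_model.unmerge_adapter()

# Option 2: merge and strip the LoRA wrappers entirely, with a NaN check on the merged weights.
merged_model = peft_model.merge_and_unload(safe_merge=True)
```
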
23 changes: 17 additions & 6 deletions src/peft/tuners/lora/layer.py
@@ -80,6 +80,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
        elif base_layer.__class__.__name__ == "EetqLinear":
            # Eetq layers
            in_features, out_features = base_layer.in_features, base_layer.out_features
        elif hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear":
            # HQQ layers
            in_features, out_features = base_layer.in_features, base_layer.out_features
        else:
            raise ValueError(f"Unsupported layer type {type(base_layer)}")

@@ -191,9 +194,13 @@ def dora_init(self, adapter_name: str) -> None:

        scaling = self.scaling[adapter_name]
        with gather_params_ctx(self.get_base_layer().parameters()):
            weight = self.get_base_layer().weight
            quant_state = getattr(self.get_base_layer(), "state", None)
            weight = dequantize_bnb_weight(weight, state=quant_state)  # no-op if not bnb
            base_layer = self.get_base_layer()
            if hasattr(base_layer, "W_q"):  # For handling HQQ quantized weight
                weight = base_layer.dequantize()
            else:
                weight = base_layer.weight
                quant_state = getattr(base_layer, "state", None)
                weight = dequantize_bnb_weight(weight, state=quant_state)  # no-op if not bnb
            if weight.data.ndim == 4:  # For handling LoRAs applied to Conv2Ds.
                lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1))
                lora_weight = lora_weight.reshape(weight.shape)
@@ -223,9 +230,13 @@ def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter):
"""
lora_weight = lora_B.weight @ lora_A.weight
magnitude = self.lora_magnitude_vector[active_adapter]
weight = self.get_base_layer().weight
quant_state = getattr(self.get_base_layer(), "state", None)
weight = dequantize_bnb_weight(weight, state=quant_state) # no-op if not bnb
base_layer = self.get_base_layer()
if hasattr(base_layer, "W_q"):
            weight = base_layer.dequantize()
        else:
            weight = base_layer.weight
            quant_state = getattr(base_layer, "state", None)
            weight = dequantize_bnb_weight(weight, state=quant_state)  # no-op if not bnb
        weight = weight.to(x.dtype)
        weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
        # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
19 changes: 17 additions & 2 deletions src/peft/tuners/lora/model.py
@@ -50,6 +50,7 @@
from .config import LoraConfig
from .eetq import dispatch_eetq
from .gptq import dispatch_gptq
from .hqq import dispatch_hqq
from .layer import Conv2d, LoraLayer, dispatch_default
from .tp_layer import dispatch_megatron

@@ -248,7 +249,13 @@ def _replace_module(self, parent, child_name, new_module, child):
        # dispatch to correct device
        for name, module in new_module.named_modules():
            if (self.prefix in name) or ("ranknum" in name):
                weight = child.qweight if hasattr(child, "qweight") else child.weight
                weight = (
                    child.qweight
                    if hasattr(child, "qweight")
                    else child.W_q
                    if hasattr(child, "W_q")
                    else child.weight
                )
                module.to(weight.device)

    def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
@@ -290,7 +297,15 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
            dispatchers.append(dispatch_bnb_4bit)

        dispatchers.extend(
            [dispatch_eetq, dispatch_aqlm, dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default]
            [
                dispatch_eetq,
                dispatch_aqlm,
                dispatch_awq,
                dispatch_gptq,
                dispatch_hqq,
                dispatch_megatron,
                dispatch_default,
            ]
        )

        new_module = None