Add LoRA support to HQQ Quantization #1618
Conversation
Thanks a lot for adding this PR. Would you be so kind as to add some tests? It should be enough to copy and paste some existing tests and slightly adjust them for HQQ, see e.g. this one.
@BenjaminBossan is this okay? Sorry if my code is kind of a mess.
Looks great, thanks, not messy at all -- or at least, not messier than the rest of the PEFT code ;-) I'll give this a test tomorrow and will be back with a proper review.
Thanks a lot @fahadh4ilyas for this contribution. It already looks quite good and feature-complete, and from my testing it also seems to generally work. There are still a few corners that need improvement. Please check my comments. On top of that, I have some more general comments:

**Code style**

Could you please always run `make style` on your changes before committing?

**DoRA**

Thanks for also taking DoRA into consideration. However, it currently fails because DoRA searches for `self.get_base_layer().weight`. We can add a check there to search for `W_q` instead. Below is the diff that worked for me:
```diff
modified src/peft/tuners/lora/layer.py
@@ -183,9 +183,11 @@ class LoraLayer(BaseTunerLayer):
         lora_A = self.lora_A[adapter_name]
         lora_B = self.lora_B[adapter_name]
         scaling = self.scaling[adapter_name]
-        with gather_params_ctx(self.get_base_layer()):
-            weight = self.get_base_layer().weight
-            quant_state = getattr(self.get_base_layer(), "state", None)
+        base_layer = self.get_base_layer()
+        with gather_params_ctx(base_layer):
+            # W_q is for HQQ
+            weight = base_layer.W_q if hasattr(base_layer, "W_q") else base_layer.weight
+            quant_state = getattr(base_layer, "state", None)
             weight = dequantize_bnb_weight(weight, state=quant_state)  # no-op if not bnb
         if weight.data.ndim == 4:  # For handling LoRAs applied to Conv2Ds.
             lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1))
@@ -212,7 +214,9 @@ class LoraLayer(BaseTunerLayer):
         """
         lora_weight = lora_B.weight @ lora_A.weight
         magnitude = self.lora_magnitude_vector[active_adapter]
-        weight = self.get_base_layer().weight
+        base_layer = self.get_base_layer()
+        # W_q is for HQQ
+        weight = base_layer.W_q if hasattr(base_layer, "W_q") else base_layer.weight
         quant_state = getattr(self.get_base_layer(), "state", None)
         weight = dequantize_bnb_weight(weight, state=quant_state)  # no-op if not bnb
         weight = weight.to(x.dtype)
```
Additionally, right now we only dequantize bnb weights for DoRA, but this needs to be extended to check for HQQ weights; check the function `dequantize_bnb_weight` in `utils/integrations.py`. Of course, if we add HQQ there, the function should also be renamed to something more generic. This new function could then also be used inside the `unmerge` method.
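One possible shape for such a generic helper (a sketch only; the name `dequantize_module_weight` is hypothetical, not the final PEFT API) is to duck-type on the attributes each quantization backend exposes:

```python
def dequantize_module_weight(module):
    """Return the dequantized weight of a possibly-quantized layer.

    Duck-typed dispatch:
    - HQQ layers pack the weight as `W_q` and expose `dequantize()`
    - bitsandbytes layers keep `weight` plus a quant `state`
    - plain layers just have `weight`
    """
    if hasattr(module, "W_q"):  # HQQ
        return module.dequantize()
    weight = module.weight
    quant_state = getattr(module, "state", None)
    if quant_state is not None:  # bitsandbytes
        # reuse the existing bnb helper for the actual dequantization
        from peft.utils.integrations import dequantize_bnb_weight
        return dequantize_bnb_weight(weight, state=quant_state)
    return weight  # already a dense tensor
```

This keeps the call sites in `layer.py` free of backend-specific `hasattr` checks and could be reused by `unmerge` as well.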
**Calculating the number of params**

Do we need to update `get_nb_trainable_parameters`? I don't know how HQQ packs the weights, but I get very different results when I print this for a normal model vs. an HQQ model. For instance, using the unit test you added:

```python
>>> # LoRA normal model
>>> model.print_trainable_parameters()
trainable params: 2,252,800 || all params: 1,102,301,184 || trainable%: 0.20437245579516677
>>> # HQQ LoRA model
>>> model.print_trainable_parameters()
trainable params: 2,252,800 || all params: 617,859,072 || trainable%: 0.36461389046335796
```
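A plausible explanation (an assumption on my part, not verified against HQQ internals): quantized backends pack several low-bit values into each stored element, so summing `numel()` over parameters undercounts the logical weight count. A toy calculation with made-up layer sizes:

```python
# Toy illustration (hypothetical sizes): a 4-bit quantizer can pack two
# weights per uint8 element, so numel() on the packed tensor reports
# roughly half the logical number of weights (ignoring scales/zero-points
# and unquantized modules such as embeddings).
out_features, in_features = 4096, 4096
logical_weights = out_features * in_features  # what fp16 storage reports
packed_elements = logical_weights // 2        # two 4-bit values per byte
print(logical_weights, packed_elements)
```

If that is what is happening here, `get_nb_trainable_parameters` would need to scale the element count by the packing factor to report comparable totals.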
**Docs**

Let's add an entry to the docs here.
**Docker**

To ensure that HQQ is run in our GPU tests, we need to install it in the corresponding Dockerfile. Could you please add it here? For me, for some reason, my PyTorch version was downgraded when I installed HQQ, but maybe that's just an issue with my personal environment. Did you observe the same thing?
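If the downgrade turns out to be reproducible, one way to guard against it in the Dockerfile is a sketch like the following (exact torch pin depends on the base image, so treat this as an assumption, not the final install steps):

```shell
# Install hqq for the GPU tests; if its dependencies downgrade torch,
# re-assert the torch build the image is supposed to ship.
pip install hqq
pip install --force-reinstall --no-deps torch
```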
**Additional testing**

I wrote an additional test to check merging/unmerging/unloading/saving+loading. Could you please add it?
```python
def test_hqq_lora_model_outputs(self):
    # check that the outputs generated by HQQ with LoRA are similar to those without HQQ
    from hqq.core.quantize import BaseQuantizeConfig
    from hqq.engine.hf import HQQModelForCausalLM

    device = "cuda"
    compute_dtype = torch.float16
    # first load the model without HQQ
    model = AutoModelForCausalLM.from_pretrained(
        self.causal_lm_model_id,
        device_map=device,
        torch_dtype=compute_dtype,
    )
    config = LoraConfig(
        target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM",
        init_lora_weights=False,
    )
    torch.manual_seed(0)
    model = get_peft_model(model, config).eval()
    inputs = self.tokenizer("The meaning of unit tests is", return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output_normal = model(**inputs).logits
    assert torch.isfinite(output_normal).all()

    del model
    gc.collect()
    torch.cuda.empty_cache()

    # now load with HQQ
    model = HQQModelForCausalLM.from_pretrained(
        self.causal_lm_model_id,
        device_map=device,
        torch_dtype=compute_dtype,
    )
    quant_config = BaseQuantizeConfig(nbits=4, group_size=64)
    model.quantize_model(quant_config=quant_config, compute_dtype=compute_dtype, device=device)
    torch.manual_seed(0)
    model = get_peft_model(model, config).eval()
    with torch.inference_mode():
        output_hqq = model(**inputs).logits

    # check that outputs of HQQ are highly correlated; there are outliers, so don't check for equality
    cc_matrix = torch.corrcoef(torch.stack((output_normal.flatten(), output_hqq.flatten())))
    assert cc_matrix.min() > 0.97

    # check that outputs are the same after merging
    model.merge_adapter()
    with torch.inference_mode():
        output_merged = model(**inputs).logits
    cc_matrix = torch.corrcoef(torch.stack((output_normal.flatten(), output_merged.flatten())))
    assert cc_matrix.min() > 0.97

    # check that outputs are the same after unmerging
    model.unmerge_adapter()
    with torch.inference_mode():
        output_unmerged = model(**inputs).logits
    cc_matrix = torch.corrcoef(torch.stack((output_normal.flatten(), output_unmerged.flatten())))
    assert cc_matrix.min() > 0.97

    # check that the results are the same after saving and loading
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        del model
        gc.collect()
        torch.cuda.empty_cache()

        model = HQQModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            device_map=device,
            torch_dtype=compute_dtype,
        )
        quant_config = BaseQuantizeConfig(nbits=4, group_size=64)
        model.quantize_model(quant_config=quant_config, compute_dtype=compute_dtype, device=device)
        model = PeftModel.from_pretrained(model, tmp_dir)
        with torch.inference_mode():
            output_loaded = model(**inputs).logits

        # for loading, we expect high precision, so check for equality and not just correlation
        atol, rtol = 1e-6, 1e-6
        assert torch.allclose(output_hqq, output_loaded, atol=atol, rtol=rtol)

    # check that outputs are the same after merge_and_unload
    model = model.merge_and_unload()
    with torch.inference_mode():
        output_merged_unloaded = model(**inputs).logits
    cc_matrix = torch.corrcoef(torch.stack((output_normal.flatten(), output_merged_unloaded.flatten())))
    assert cc_matrix.min() > 0.97
```
**Quantization error**

While working on the additional tests, I plotted the LoRA outputs of the normal model vs. the HQQ model used in the test. They seem to be highly correlated, but there are big deviations for individual data points:

```python
>>> torch.testing.assert_close(output_normal, output_hqq)
*** AssertionError: Tensor-likes are not close!

Mismatched elements: 223632 / 224000 (99.8%)
Greatest absolute difference: 3.56396484375 at index (0, 3, 1141) (up to 1e-05 allowed)
Greatest relative difference: 41549.37890625 at index (0, 0, 15916) (up to 1.3e-06 allowed)
```

Probably this is fine; I just wanted to check whether it is in line with expectations.
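This is also why the test above compares correlation rather than elementwise closeness: small noise plus a few large outliers barely moves the Pearson correlation, while any `allclose`-style check fails immediately. A small synthetic illustration (NumPy, made-up data, not the actual model outputs):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(10_000)
b = a + 0.01 * rng.standard_normal(10_000)  # small noise everywhere
b[0] += 5.0                                 # plus one large outlier

cc = np.corrcoef(a, b)
print(cc.min())            # still very close to 1.0 despite the outlier
print(np.allclose(a, b))   # False: the noise alone breaks elementwise checks
```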
@BenjaminBossan I'm done fixing it. Could you please test it again?
Thanks for the additional changes. This looks pretty good already; there are only a few small things, please check my comments.
Furthermore, could you please ensure that `make style` was run on your code? Also make sure that the ruff version is 0.2.2.
Finally, it would be great to have an entry in the docs for this new quantization method.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@fahadh4ilyas Could you please run `make style`?
@BenjaminBossan it's done. I think I have to get used to running `make style`.
You may want to set up a pre-commit hook as explained here.
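For reference, the usual setup looks roughly like this (a sketch; it assumes the repository ships a `.pre-commit-config.yaml` wiring up the style checks, so see the contributing guide for the authoritative steps):

```shell
pip install pre-commit
pre-commit install          # registers the hook in .git/hooks/pre-commit
pre-commit run --all-files  # optional one-off run over the whole repo
```

After `pre-commit install`, the style checks run automatically on every `git commit`.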
Thanks a lot for adding HQQ. From my point of view, this looks good to be merged. Waiting for @younesbelkada's review.
Nice work! Looks very clean overall! I just have minor comments:
1. Can you update the tests so that they leverage `AutoModelForCausalLM` once huggingface/transformers#29637 is merged?
2. Can you also add a few lines in the documentation mentioning HQQ?
3. Could you add `hqq` to the Docker image we use for testing: https://github.com/huggingface/peft/blob/main/docker/peft-gpu/Dockerfile
Thanks a lot!
Any progress so far? Amazing work on the HQQ support; we want to give this feature a try.
With huggingface/transformers#29637 being merged, let's not forget about this PR. There are some merge conflicts due to the EETQ PR, but they should be easy to fix.
@BenjaminBossan and @younesbelkada please check the current update. Is it okay if the test checks the compatibility of transformers with HQQ first?
Thanks so much @fahadh4ilyas!
Great work, thank you! I just left one comment about the tests.
I also made huggingface/transformers#30632, which we can merge after this PR gets merged.
Awesome work @fahadh4ilyas 🙏!
Looks good, thanks! One last comment left: can you make sure the tests you designed pass with transformers built from main?
Sorry for my stupid mistake. Here is the revision.
Great work, thanks!