About the dtype of trainable params #1249

Closed · hiyouga opened this issue Dec 11, 2023 · 19 comments

hiyouga (Contributor) commented Dec 11, 2023

System Info

peft 0.7.0
transformers 4.34.0
torch 2.0.1

Who can help?

@pacman100 @younesbelkada @sayakpaul

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, TaskType, LoraConfig, get_peft_model

tok = AutoTokenizer.from_pretrained("llama2-7b")
model = AutoModelForCausalLM.from_pretrained("llama2-7b", torch_dtype=torch.float16, low_cpu_mem_usage=True)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["embed_tokens", "lm_head"]
)
peft_model = get_peft_model(model, lora_config)

for name, param in peft_model.named_parameters():
    print(name, param.dtype)
"""
base_model.model.model.layers.31.self_attn.q_proj.base_layer.weight torch.float16
base_model.model.model.layers.31.self_attn.q_proj.lora_A.default.weight torch.float16
base_model.model.model.layers.31.self_attn.q_proj.lora_B.default.weight torch.float16
base_model.model.model.layers.31.self_attn.k_proj.weight torch.float16
base_model.model.model.layers.31.self_attn.v_proj.base_layer.weight torch.float16
base_model.model.model.layers.31.self_attn.v_proj.lora_A.default.weight torch.float16
base_model.model.model.layers.31.self_attn.v_proj.lora_B.default.weight torch.float16
...
base_model.model.lm_head.original_module.weight torch.float16
base_model.model.lm_head.modules_to_save.default.weight torch.float16
"""

Expected behavior

If we load the model in half precision and use fp16 mixed precision training, it throws "ValueError: Attempting to unscale FP16 gradients."

Should we manually cast the trainable params to float32?
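
A minimal sketch of the manual cast in question (not part of the original report; it is the approach the discussion below converges on): upcast only the trainable parameters, so that AMP sees fp32 gradients while the frozen base weights stay in float16.

# continues from the reproduction above; only the LoRA weights and modules_to_save are upcast
for param in peft_model.parameters():
    if param.requires_grad:
        param.data = param.data.to(torch.float32)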

BenjaminBossan (Member) commented

Could you please provide the training code and the full error?

hiyouga (Contributor, Author) commented Dec 11, 2023

I used a modified version of https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py

Here is the training code: https://gist.github.com/hiyouga/361bc114960672115446050857895dbb

The major differences are the model loading dtype and the LoRA adapter setup:

L448 torch_dtype=torch.float16
L462 lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
        modules_to_save=["embed_tokens", "lm_head"]
    )
    model = get_peft_model(model, lora_config)

https://www.diffchecker.com/auUsf6ZO/

We ran this script with: CUDA_VISIBLE_DEVICES=0 python run_clm.py --model_name_or_path llama2-7b --low_cpu_mem_usage True --train_file wikipedia.json --block_size 512 --do_train --fp16 --output_dir test

Full error:

[INFO|trainer.py:593] 2023-12-11 19:41:42,620 >> Using auto half precision backend
[INFO|trainer.py:1723] 2023-12-11 19:41:42,780 >> ***** Running training *****
[INFO|trainer.py:1724] 2023-12-11 19:41:42,780 >>   Num examples = 963
[INFO|trainer.py:1725] 2023-12-11 19:41:42,780 >>   Num Epochs = 3
[INFO|trainer.py:1726] 2023-12-11 19:41:42,780 >>   Instantaneous batch size per device = 1
[INFO|trainer.py:1729] 2023-12-11 19:41:42,780 >>   Total train batch size (w. parallel, distributed & accumulation) = 1
[INFO|trainer.py:1730] 2023-12-11 19:41:42,780 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:1731] 2023-12-11 19:41:42,780 >>   Total optimization steps = 2,889
[INFO|trainer.py:1732] 2023-12-11 19:41:42,782 >>   Number of trainable parameters = 1,033,895,936
  0%|                                                                                                        | 0/2889 [00:00<?, ?it/s]Traceback (most recent call last):
  File "xx/training/run_clm.py", line 681, in <module>
    main()
  File "xx/training/run_clm.py", line 629, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "xxlib/python3.10/site-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "xxlib/python3.10/site-packages/transformers/trainer.py", line 1904, in _inner_training_loop
    self.accelerator.clip_grad_norm_(
  File "xxlib/python3.10/site-packages/accelerate/accelerator.py", line 2120, in clip_grad_norm_
    self.unscale_gradients()
  File "xxlib/python3.10/site-packages/accelerate/accelerator.py", line 2083, in unscale_gradients
    self.scaler.unscale_(opt)
  File "xxlib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 284, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
  File "xxlib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 212, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")
ValueError: Attempting to unscale FP16 gradients.

BenjaminBossan (Member) commented

Thanks for providing the example. I could reproduce the error, but it is not related to PEFT: running the unmodified script without PEFT, only with torch_dtype=torch.float16, causes the same issue. I'm not familiar with Trainer, but I assume this is not the correct way to handle mixed precision training. Please check the relevant documentation.
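
For illustration, a minimal plain-PyTorch sketch of the same failure (no PEFT or Trainer involved; assumes a CUDA device): GradScaler refuses to unscale gradients that are stored in float16.

import torch
import torch.nn as nn

model = nn.Linear(8, 8).cuda().half()                # all parameters in fp16
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

loss = model(torch.randn(4, 8, device="cuda", dtype=torch.float16)).float().mean()
scaler.scale(loss).backward()                        # gradients are stored in fp16
scaler.unscale_(optimizer)                           # ValueError: Attempting to unscale FP16 gradients.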

hiyouga (Contributor, Author) commented Dec 12, 2023

Thanks for replying. This problem is also not related to Trainer. Generally, the trainable params need to be in float32 for mixed precision training, but the PEFT adapters default to float16 when the base model is loaded in float16. So we cannot use these adapters directly in fp16 training (we can in bf16 training, which does not use a gradient scaler).

weight = getattr(self.get_base_layer(), "weight", None)
if weight is not None:
    # the layer is already completely initialized, this is an update
    if weight.dtype.is_floating_point or weight.dtype.is_complex:
        self.to(weight.device, dtype=weight.dtype)
    else:
        self.to(weight.device)

Coincidentally, if we use prepare_model_for_kbit_training to perform QLoRA, the adapter weights are cast to float32, which solves the problem.

if not is_gptq_quantized:
    # cast all non INT8 parameters to fp32
    for param in model.parameters():
        if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
            param.data = param.data.to(torch.float32)
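
For context, a sketch of the QLoRA path referred to above (the BitsAndBytesConfig and the bitsandbytes dependency are assumptions, not part of this reproduction): with a quantized base model, prepare_model_for_kbit_training upcasts the remaining fp16/bf16 parameters to float32, and the LoRA weights attached afterwards also end up trainable in float32.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, get_peft_model

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained("llama2-7b", quantization_config=bnb_config)
model = prepare_model_for_kbit_training(model)   # casts non-quantized fp16/bf16 params to fp32
peft_model = get_peft_model(model, lora_config)  # lora_config as defined earlier in this thread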

BenjaminBossan (Member) commented

Yes, you are correct. What I meant is that when using Trainer/accelerate, they should handle the dtypes, so explicitly loading the model in float16 should not be necessary.

Coincidentally, if we use prepare_model_for_kbit_training to perform QLoRA, the adapter weights are cast to float32, which solves the problem.

Yes, good point; this function is not only useful for QLoRA, even though the name might suggest so.

hiyouga (Contributor, Author) commented Dec 12, 2023

I don't think the Trainer can handle the model dtype for LoRA training on its own. By default, the model is loaded in 32-bit precision, which consumes a large amount of GPU memory (e.g. a 7B model requires about 28 GB of VRAM).

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("llama2-7b", low_cpu_mem_usage=True, device_map="cuda")
print(model.dtype)
# torch.float32

However, for PEFT training we want the model to be loaded directly in 16-bit or lower precision; in that case it is necessary to specify the dtype explicitly in order to save GPU memory (e.g. a 7B model then requires about 16 GB of VRAM).

Reproduction

Training code: https://gist.github.com/hiyouga/5b139f3d4d41a6cc49382c9e79e177ea
Diff checker (with the official example): https://www.diffchecker.com/Tp1MqOPq/

Without torch_dtype=float16

CUDA_VISIBLE_DEVICES=0 python run_clm.py --model_name_or_path llama2-7b --low_cpu_mem_usage True --train_file wikipedia.json --block_size 128 --output_dir test --do_train --per_device_train_batch_size 1 --fp16

VRAM used: 33 GB

With torch_dtype=float16

CUDA_VISIBLE_DEVICES=0 python run_clm.py --model_name_or_path llama2-7b --low_cpu_mem_usage True --train_file wikipedia.json --block_size 128 --output_dir test --do_train --per_device_train_batch_size 1 --fp16 --torch_dtype float16

VRAM used: 17 GB

KCFindstr (Contributor) commented

I was able to do float16 finetuning with peft==0.6.2. Are there any dependency changes that might lead to this issue?

hiyouga (Contributor, Author) commented Dec 13, 2023

I was able to do float16 finetuning with peft==0.6.2. Are there any dependency changes that might lead to this issue?

Yes, there is. In PEFT 0.6.2, they used self.weight to determine the dtype and device for the adapter weights. However, the wrapper layer has no weight attribute right after torch.nn.Module.__init__(), so this branch was never taken, and the adapter weights were not cast to the base model's dtype and remained in float32.

https://github.com/huggingface/peft/blob/v0.6.2/src/peft/tuners/lora/layer.py#L63-L89

weight = getattr(self, "weight", None)
if weight is not None:
    # the layer is already completely initialized, this is an update
    if weight.dtype.is_floating_point or weight.dtype.is_complex:
        self.to(weight.device, dtype=weight.dtype)
    else:
        self.to(weight.device)
self.set_adapter(self.active_adapters)

Instead, in PEFT 0.7.0 they introduced self.get_base_layer() and use the base layer's weight for this. Consequently, the adapter weights are now correctly cast to the same dtype as the model, which surfaces the fp16 gradient-unscaling error.

https://github.com/huggingface/peft/blob/v0.7.0/src/peft/tuners/lora/layer.py#L74-L103

weight = getattr(self.get_base_layer(), "weight", None)
if weight is not None:
    # the layer is already completely initialized, this is an update
    if weight.dtype.is_floating_point or weight.dtype.is_complex:
        self.to(weight.device, dtype=weight.dtype)
    else:
        self.to(weight.device)
self.set_adapter(self.active_adapters)

KCFindstr (Contributor) commented

@hiyouga Thanks for the explanation! So the float16 finetuning capability in peft 0.6.2 is actually a bug that got fixed in 0.7.0 🥲

hiyouga (Contributor, Author) commented Dec 13, 2023

A related issue: #1090
I closed it because of the above findings.

BenjaminBossan (Member) commented

Could you check whether applying this to the PEFT model works after loading the model in 16-bit:

def cast_lora_to_float(model):
    for name, mod in model.named_modules():
        if ("lora_" in name) and hasattr(mod, "weight"):
            mod.weight.data = mod.weight.data.float()
        if ("lora_" in name) and hasattr(mod, "bias") and (mod.bias is not None):
            mod.bias.data = mod.bias.data.float()

hiyouga (Contributor, Author) commented Dec 14, 2023

@BenjaminBossan It works when modules_to_save is None, but it does not handle the params in modules_to_save.

BenjaminBossan (Member) commented

Yes, I forgot about that:

import torch.nn as nn

def cast_lora_to_float(model):
    for name, mod in model.named_modules():
        if ("lora_" in name) and hasattr(mod, "weight"):
            mod.weight.data = mod.weight.data.float()
        if ("lora_" in name) and hasattr(mod, "bias") and (mod.bias is not None):
            mod.bias.data = mod.bias.data.float()
        if ("modules_to_save" in name) and isinstance(mod, nn.Linear):
            mod.weight.data = mod.weight.data.float()
            if mod.bias is not None:
                mod.bias.data = mod.bias.data.float()

hiyouga (Contributor, Author) commented Dec 14, 2023

@BenjaminBossan It gives

base_model.model.model.embed_tokens.original_module.weight torch.float16
base_model.model.model.embed_tokens.modules_to_save.default.weight torch.float16
base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight torch.float16
base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight torch.float32
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight torch.float32
base_model.model.model.layers.0.self_attn.k_proj.weight torch.float16
base_model.model.model.layers.0.self_attn.v_proj.base_layer.weight torch.float16
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight torch.float32
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight torch.float32

This is because I use nn.Embedding in modules_to_save; it could also be a norm layer.

BenjaminBossan (Member) commented

Ah yes, please add isinstance checks for the layer types you want to include in modules_to_save. Depending on the type, the attribute names could also differ from weight and bias, but the idea is the same.
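
A sketch of such an extension, covering the module types mentioned in this thread (the exact set of layer types is an assumption; extend it to whatever you use in modules_to_save):

import torch.nn as nn

def cast_modules_to_save_to_float(model):
    # upcast the trainable copies created for modules_to_save; nn.Embedding has no
    # bias, so getattr keeps the loop generic across layer types
    for name, mod in model.named_modules():
        if "modules_to_save" in name and isinstance(mod, (nn.Linear, nn.Embedding, nn.LayerNorm)):
            mod.weight.data = mod.weight.data.float()
            bias = getattr(mod, "bias", None)
            if bias is not None:
                bias.data = bias.data.float()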

hiyouga (Contributor, Author) commented Dec 14, 2023

I prefer to use

for param in filter(lambda p: p.requires_grad, model.parameters()):
    param.data = param.data.to(torch.float32)
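
For completeness, a sketch of where this cast would sit in the modified run_clm.py from above (the placement is an assumption): right after get_peft_model and before the Trainer is constructed, so the optimizer is created over fp32 parameters.

model = get_peft_model(model, lora_config)
# upcast only the trainable parameters; the frozen fp16 base weights keep their memory savings
for param in filter(lambda p: p.requires_grad, model.parameters()):
    param.data = param.data.to(torch.float32)
# ... the Trainer is then created as in the original script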

BenjaminBossan (Member) commented

That should work too. It is a bit more coarse-grained, so there could be model architectures where this modifies data that it shouldn't, but for most cases it should be fine.

hiyouga (Contributor, Author) commented Dec 20, 2023

A similar discussion here: huggingface/transformers#28142
It seems we can't use float16 model loading for PEFT mixed-precision training?

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Jan 9, 2024
Some users ran into the issue of trying to use a model loaded in float16
with mixed precision, e.g. these issues: huggingface#341, huggingface#1249. This PR documents
a workaround to solve the issue.

I also added tests that demonstrate the issue, as well as the
workaround.

Notes

This is not strictly a PEFT issue, but more a general error when using
AMP with float16. Still, since PEFT users encounter this sometimes, it
is useful to document it.

When we discussed this issue in the past, I think we concluded that it's
not as straightforward as PEFT automatically casting the weights to
float32, though I cannot remember anymore what the drawbacks were.

In any case, should we ever add an automatic solution for this in PEFT,
the added test should fail, which alerts us to the fact that we need to
update the documentation.
pacman100 pushed a commit that referenced this issue Jan 10, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
