
Merge pull request #802 from OptimalScale/yizhenjia-qwen-support
Merge LoRA and base model
research4pan authored Apr 30, 2024
2 parents 2c317d2 + b28fc04 commit ffc527e
Showing 4 changed files with 102 additions and 3 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -180,6 +180,17 @@ cd data && ./download.sh alpaca && cd -
> --output_model_path output_models/finetuned_llama2_7b_lora \
>```
> </details>
>
> <details><summary>Merge LoRA Weight</summary>
>
>Merge the LoRA weights and the base model into a single model using:
>```sh
>./scripts/run_merge_lora.sh \
> --model_name_or_path Qwen/Qwen1.5-1.8B \
> --lora_model_path output_models/lora \
> --output_model_path output_models/lora_merged
>```
></details>
### Inference
After finetuning, you can run the following command to chat with the model.
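For reference, the merge that the new README section describes amounts to folding the LoRA deltas back into the base weights. A minimal sketch of the same operation using the PEFT API directly (the script itself goes through LMFlow's wrapper; the model name and paths are simply reused from the example above):

```python
# Minimal sketch of a LoRA merge with the PEFT API (illustration only;
# scripts/run_merge_lora.sh goes through LMFlow's AutoModel wrapper instead).
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-1.8B", torch_dtype="auto")
model = PeftModel.from_pretrained(base, "output_models/lora")

merged = model.merge_and_unload()  # fold the LoRA deltas into the base weights
merged.save_pretrained("output_models/lora_merged")

# Save the tokenizer alongside the merged weights so the output dir is self-contained.
AutoTokenizer.from_pretrained("Qwen/Qwen1.5-1.8B").save_pretrained("output_models/lora_merged")
```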
28 changes: 27 additions & 1 deletion examples/merge_lora.py
@@ -23,12 +23,30 @@

@dataclass
class MergeLoraArguments:
    device: str = field(
        default='cpu',
        metadata={
            "help": "device to merge model on",
        },
    )
    ds_config: str = field(
        default='configs/ds_config_eval.json',
        metadata={
            "help": "deepspeed config file path",
        },
    )
    output_model_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "output merged full model path"
        },
    )
    local_rank: Optional[int] = field(
        default=-1,
        metadata={
            "help": "local rank for deepspeed",
        },
    )
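Since MergeLoraArguments is a plain dataclass, HfArgumentParser turns each field into a CLI flag automatically. A self-contained sketch of that mechanism (MiniArgs is a made-up stand-in, not LMFlow code):

```python
# Toy demonstration of how HfArgumentParser maps dataclass fields to CLI flags.
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class MiniArgs:  # hypothetical stand-in for MergeLoraArguments
    device: str = field(default="cpu", metadata={"help": "device to merge model on"})
    ds_config: str = field(default="configs/ds_config_eval.json",
                           metadata={"help": "deepspeed config file path"})

parser = HfArgumentParser(MiniArgs)
(args,) = parser.parse_args_into_dataclasses(args=["--device", "cpu"])
print(args.device, args.ds_config)  # cpu configs/ds_config_eval.json
```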


def main():
@@ -37,9 +55,17 @@ def main():
        model_args, merge_lora_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, merge_lora_args = parser.parse_args_into_dataclasses()

    if merge_lora_args.device == 'gpu':
        raise NotImplementedError('Merging LoRA weight using GPU not supported yet. Please use cpu.')

    model_args.use_lora = True
    model = AutoModel.get_model(model_args, tune_strategy='none')
    model = AutoModel.get_model(
        model_args,
        tune_strategy='none',
        device=merge_lora_args.device,
        ds_config=merge_lora_args.ds_config
    )
    model.merge_lora_weights()
    model.save(merge_lora_args.output_model_path, save_full_model=True)
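Note that main() also accepts a single JSON file through the parse_json_file branch above. A hedged example of driving the merge that way; the device/ds_config/output_model_path keys come from MergeLoraArguments in this diff, while model_name_or_path and lora_model_path are assumed to be fields of LMFlow's ModelArguments:

```python
# Sketch: invoking examples/merge_lora.py with a JSON config instead of CLI flags.
import json
import subprocess

config = {
    "model_name_or_path": "Qwen/Qwen1.5-1.8B",   # assumed ModelArguments field
    "lora_model_path": "output_models/lora",     # assumed ModelArguments field
    "output_model_path": "output_models/lora_merged",
    "device": "cpu",
    "ds_config": "configs/ds_config_eval.json",
}
with open("merge_lora.json", "w") as f:
    json.dump(config, f, indent=2)

# sys.argv[1] must end with ".json" to take the parse_json_file branch.
subprocess.run(["python", "examples/merge_lora.py", "merge_lora.json"], check=True)
```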

55 changes: 55 additions & 0 deletions scripts/run_merge_lora.sh
@@ -0,0 +1,55 @@
#!/bin/bash

# Parses arguments
model_name_or_path=gpt2
lora_model_path=output_models/lora
output_model_path=output_models/merge_lora
device=cpu

# DeepSpeed arguments (only used when device is gpu)
deepspeed_args="--master_port=11000"

while [[ $# -ge 1 ]]; do
  key="$1"
  case ${key} in
    --model_name_or_path)
      model_name_or_path="$2"
      shift
      ;;
    --lora_model_path)
      lora_model_path="$2"
      shift
      ;;
    --output_model_path)
      output_model_path="$2"
      shift
      ;;
    --device)
      device="$2"
      shift
      ;;
    --deepspeed_args)
      deepspeed_args="$2"
      shift
      ;;
    *)
      echo "error: unknown option \"${key}\"" 1>&2
      exit 1
      ;;
  esac
  shift
done


if [ "${device}" == "cpu" ]; then
  python examples/merge_lora.py \
    --model_name_or_path ${model_name_or_path} \
    --lora_model_path ${lora_model_path} \
    --output_model_path ${output_model_path} \
    --device ${device} \
    --ds_config configs/ds_config_eval.json
elif [ "${device}" == "gpu" ]; then
  echo "Error: Merging LoRA weights using gpu is not supported yet. Please use cpu." 1>&2
  exit 1
else
  echo "Error: Unknown device \"${device}\"" 1>&2
  exit 1
fi
11 changes: 9 additions & 2 deletions src/lmflow/models/hf_decoder_model.py
@@ -202,6 +202,7 @@ def __init__(
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        logger.debug(f"torch_dtype on init: {torch_dtype}")

        config_kwargs = {
            "cache_dir": model_args.cache_dir,
@@ -327,7 +328,7 @@ def __init__(
        if peft_model_id is not None:
            self.backend_model = PeftModel.from_pretrained(
                self.backend_model,
                peft_model_id,
            )
            self.tokenizer.padding_side = "left"
        else:
@@ -849,7 +850,13 @@ def save(self, dir, save_full_model=False, *args, **kwargs):
"""
self.get_tokenizer().save_pretrained(dir)
if save_full_model and self.model_args.use_lora:
self.backend_model_full.save_pretrained(dir)
save_dtype = (
torch.float16
if self.model_args.torch_dtype in ["auto", None]
else getattr(torch, self.model_args.torch_dtype)
)
self.backend_model_full.to(dtype=save_dtype).save_pretrained(dir)
logger.warning(f"Save full model with dtype: {save_dtype}")
else:
self.get_backend_model().save_pretrained(dir)

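The save-time dtype fallback above is easy to sanity-check in isolation. A small sketch of the same resolution logic (resolve_save_dtype is an illustrative helper, not part of LMFlow):

```python
# Stand-alone check of the dtype fallback used in save(): "auto"/None -> float16,
# anything else resolved via getattr on torch (e.g. "bfloat16" -> torch.bfloat16).
import torch

def resolve_save_dtype(torch_dtype):
    # illustrative helper mirroring the expression in save() above
    return (
        torch.float16
        if torch_dtype in ["auto", None]
        else getattr(torch, torch_dtype)
    )

assert resolve_save_dtype("auto") is torch.float16
assert resolve_save_dtype(None) is torch.float16
assert resolve_save_dtype("bfloat16") is torch.bfloat16
```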
