add tests for merging lora and validating the dtype #1512

Open · wants to merge 2 commits into base: main
26 changes: 16 additions & 10 deletions src/axolotl/cli/merge_lora.py
@@ -8,6 +8,7 @@

from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.dict import DictDefault


def do_cli(config: Path = Path("examples/"), **kwargs):
@@ -27,21 +28,26 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
flash_attention=False,
Collaborator: If the above section already sets these properties, is it necessary to set them again below?

**kwargs,
)
cfg = modify_cfg_for_merge(parsed_cfg)

if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir:
parsed_cfg.lora_model_dir = parsed_cfg.output_dir
if not Path(parsed_cfg.lora_model_dir).exists():
do_merge_lora(cfg=cfg, cli_args=parsed_cli_args)


def modify_cfg_for_merge(cfg: DictDefault) -> DictDefault:
if not cfg.lora_model_dir and cfg.output_dir:
cfg.lora_model_dir = cfg.output_dir
if not Path(cfg.lora_model_dir).exists():
raise ValueError(
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
f"Target directory for merge: `{cfg.lora_model_dir}` does not exist."
)

parsed_cfg.load_in_4bit = False
parsed_cfg.load_in_8bit = False
parsed_cfg.flash_attention = False
parsed_cfg.deepspeed = None
parsed_cfg.fsdp = None
cfg.load_in_4bit = False
cfg.load_in_8bit = False
cfg.flash_attention = False
cfg.deepspeed = None
cfg.fsdp = None

do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
return cfg


if __name__ == "__main__":
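Taken together, the hunks above refactor do_cli so the merge-specific config mutation lives in a reusable helper. A sketch of the resulting code, reconstructed from the diff rather than copied verbatim from the PR:

from pathlib import Path

from axolotl.utils.dict import DictDefault


def modify_cfg_for_merge(cfg: DictDefault) -> DictDefault:
    # Fall back to the training output dir when no explicit LoRA dir was given.
    if not cfg.lora_model_dir and cfg.output_dir:
        cfg.lora_model_dir = cfg.output_dir
    if not Path(cfg.lora_model_dir).exists():
        raise ValueError(
            f"Target directory for merge: `{cfg.lora_model_dir}` does not exist."
        )

    # Merging should run without quantization, flash attention, or
    # distributed-training settings.
    cfg.load_in_4bit = False
    cfg.load_in_8bit = False
    cfg.flash_attention = False
    cfg.deepspeed = None
    cfg.fsdp = None

    return cfg


# In do_cli, after load_cfg(...) has built parsed_cfg and parsed_cli_args:
#     cfg = modify_cfg_for_merge(parsed_cfg)
#     do_merge_lora(cfg=cfg, cli_args=parsed_cli_args)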
77 changes: 70 additions & 7 deletions tests/e2e/test_lora_llama.py
@@ -1,13 +1,16 @@
"""
E2E tests for lora llama
"""

import json
import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.cli import do_merge_lora, load_datasets
from axolotl.cli.merge_lora import modify_cfg_for_merge
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
@@ -39,11 +42,6 @@ def test_lora(self, temp_dir):
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
@@ -57,6 +55,7 @@
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 10,
}
)
normalize_config(cfg)
@@ -65,3 +64,67 @@

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

@with_temp_dir
def test_lora_merge(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
Collaborator: Also, this issue can sometimes occur for other model types; for example, a previous Llama merge was fine but a Mistral merge was not. Do we need to test this for other architectures? (A rough parametrization sketch follows the diff below.)
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 10,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
Collaborator: I don't think you need to train a model here; maybe a tiny adapter could be uploaded to the HF Hub and used for the merge?
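That suggestion might look roughly like the sketch below; the adapter repo id is a placeholder, not an existing checkpoint.

# Hypothetical: download a tiny pre-trained adapter instead of training one in the test.
from huggingface_hub import snapshot_download

adapter_dir = snapshot_download("placeholder-org/llama-68m-tiny-lora")  # placeholder repo id
cfg.lora_model_dir = adapter_dir
# The merge steps below would then run unchanged, without calling train().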
assert (Path(temp_dir) / "adapter_model.bin").exists()

cfg.lora_model_dir = cfg.output_dir
cfg.load_in_4bit = False
cfg.load_in_8bit = False
cfg.flash_attention = False
cfg.deepspeed = None
cfg.fsdp = None
Collaborator (on lines +107 to +112): This can be excluded, since modify_cfg_for_merge should already have set it?

Suggested change (remove these lines):
cfg.lora_model_dir = cfg.output_dir
cfg.load_in_4bit = False
cfg.load_in_8bit = False
cfg.flash_attention = False
cfg.deepspeed = None
cfg.fsdp = None


cfg = modify_cfg_for_merge(cfg)
cfg.merge_lora = True
Collaborator: Let's move this setting inside the modify_cfg_for_merge function as well.
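A sketch of that suggestion (illustrative only, not part of this diff): the flag moves into the helper so callers no longer set it themselves.

def modify_cfg_for_merge(cfg: DictDefault) -> DictDefault:
    ...  # existing lora_model_dir fallback, existence check, and overrides
    cfg.merge_lora = True  # moved here, so the test and do_cli can drop this line
    return cfg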


cli_args = TrainerCliArgs(merge_lora=True)

do_merge_lora(cfg=cfg, cli_args=cli_args)
assert (Path(temp_dir) / "merged/pytorch_model.bin").exists()

with open(
Path(temp_dir) / "merged/config.json", "r", encoding="utf-8"
) as f_handle:
config = f_handle.read()
config = json.loads(config)
if is_torch_bf16_gpu_available():
assert config["torch_dtype"] == "bfloat16"
else:
assert config["torch_dtype"] == "float16"
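One possible extra check, not part of this PR: validate the dtype of the merged tensors themselves rather than only the torch_dtype entry in config.json. The weight key below assumes the standard HF Llama state-dict layout.

import torch

expected_dtype = (
    torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
)
state_dict = torch.load(
    Path(temp_dir) / "merged/pytorch_model.bin", map_location="cpu"
)
# Spot-check a representative weight; buffers such as rotary inv_freq may stay fp32.
assert state_dict["model.embed_tokens.weight"].dtype == expected_dtype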
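As for the reviewer's question about covering other architectures, one rough parametrization sketch (shown as a plain pytest function rather than the unittest-style method used in this file; the Mistral repo id is a placeholder):

import pytest

from axolotl.utils.dict import DictDefault

TINY_MODELS = [
    ("JackFram/llama-68m", "LlamaTokenizer"),
    ("placeholder-org/tiny-mistral", "AutoTokenizer"),  # placeholder repo id
]


@pytest.mark.parametrize("base_model,tokenizer_type", TINY_MODELS)
def test_lora_merge_dtype(base_model, tokenizer_type, temp_dir):
    cfg = DictDefault(
        {
            "base_model": base_model,
            "tokenizer_type": tokenizer_type,
            # remaining keys identical to test_lora_merge above
        }
    )
    # train + merge + torch_dtype assertions would follow, as in test_lora_merge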