misc fixes to add gptq tests #621

Merged (2 commits) on Sep 22, 2023
Changes from all commits
6 changes: 5 additions & 1 deletion src/axolotl/utils/bench.py
@@ -19,7 +19,11 @@ def deco(func):
def wrapper(*args, **kwargs):
device = kwargs.get("device", args[0] if args else None)

if not torch.cuda.is_available() or device == "auto" or device == "cpu":
if (
not torch.cuda.is_available()
or device == "auto"
or torch.device(device).type == "cpu"
):
return default_value

return func(*args, **kwargs)
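The bench.py guard previously compared the device argument against the literal string "cpu", which misses equivalent forms such as "cpu:0" or a torch.device object. Below is a minimal sketch of the new behavior; the decorator and the gpu_memory_gib function are stand-ins for illustration, not axolotl's actual names.

import torch

def cuda_only(default_value=0.0):
    # Return a default instead of calling the wrapped CUDA query when CUDA is
    # unavailable or the requested device is not a CUDA device.
    def deco(func):
        def wrapper(*args, **kwargs):
            device = kwargs.get("device", args[0] if args else None)
            if (
                not torch.cuda.is_available()
                or device == "auto"
                or torch.device(device).type == "cpu"
            ):
                return default_value
            return func(*args, **kwargs)
        return wrapper
    return deco

@cuda_only()
def gpu_memory_gib(device):
    return torch.cuda.memory_allocated(device) / 1024**3

# "cpu:0" and torch.device("cpu") both normalize to device type "cpu",
# so they take the early return instead of reaching the CUDA call.
print(gpu_memory_gib("cpu:0"))
print(gpu_memory_gib(torch.device("cpu")))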
35 changes: 23 additions & 12 deletions src/axolotl/utils/models.py
@@ -10,6 +10,7 @@
import transformers
from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig, prepare_model_for_kbit_training
from peft.tuners.lora import QuantLinear
from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
@@ -309,16 +310,26 @@ def load_model(
):
config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
except Exception as err: # pylint: disable=broad-exception-caught
LOG.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
@@ -466,10 +477,10 @@ def load_llama_adapter(model, cfg):


def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, cls):
if isinstance(module, cls) or "Linear" in module.__class__.__name__:
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])

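find_all_linear_names now also treats peft's GPTQ QuantLinear, and any module whose class name contains "Linear", as a LoRA target, so quantized linear layers are collected alongside regular ones. A minimal sketch of that matching pattern on a toy model; plain torch.nn.Linear stands in for the bitsandbytes and GPTQ classes here.

import torch

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(16, 16)
        self.v_proj = torch.nn.Linear(16, 16)
        self.norm = torch.nn.LayerNorm(16)

model = torch.nn.ModuleDict(
    {"layer0": ToyBlock(), "lm_head": torch.nn.Linear(16, 32)}
)

cls = (torch.nn.Linear,)  # axolotl's tuple also includes the bnb 4/8-bit layers and QuantLinear
lora_module_names = set()
for name, module in model.named_modules():
    # isinstance covers the known classes; the class-name check catches other
    # "*Linear" implementations such as GPTQ quantized linears.
    if isinstance(module, cls) or "Linear" in module.__class__.__name__:
        names = name.split(".")
        lora_module_names.add(names[0] if len(names) == 1 else names[-1])

lora_module_names.discard("lm_head")  # output heads are usually left out of LoRA targets
print(sorted(lora_module_names))  # ['q_proj', 'v_proj']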
1 change: 1 addition & 0 deletions src/axolotl/utils/trainer.py
@@ -675,6 +675,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
and cfg.val_set_size > 0
and cfg.save_steps
and cfg.eval_steps
and cfg.save_steps % cfg.eval_steps == 0
)
or False,
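The added cfg.eval_steps check protects the modulo on the next line: with load_best_model_at_end requested but eval_steps unset, the old expression reached save_steps % cfg.eval_steps and failed. A standalone illustration of the short-circuit (not axolotl's trainer code):

def should_load_best(load_best_model_at_end, early_stopping_patience,
                     val_set_size, save_steps, eval_steps):
    return bool(
        (load_best_model_at_end is not False or early_stopping_patience)
        and val_set_size > 0
        and save_steps
        and eval_steps  # new guard: short-circuits before the modulo below
        and save_steps % eval_steps == 0
    )

# eval_steps=None used to reach `save_steps % None` and raise a TypeError;
# with the guard the expression simply evaluates to False.
print(should_load_best(True, None, 0.1, 10, None))  # False
print(should_load_best(True, None, 0.1, 10, 5))     # True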
58 changes: 56 additions & 2 deletions tests/e2e/test_lora_llama.py
@@ -6,6 +6,7 @@
import os
import tempfile
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -24,6 +25,7 @@ class TestLoraLlama(unittest.TestCase):

def test_lora(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
@@ -51,7 +53,7 @@ def test_lora(self):
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
@@ -62,9 +64,11 @@
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()

def test_lora_packing(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
@@ -94,7 +98,7 @@ def test_lora_packing(self):
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
@@ -105,3 +109,53 @@
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()

def test_lora_gptq(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"gptq": True,
"gptq_disable_exllama": True,
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"save_steps": 0.5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
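Each of these tests now keeps a reference to its temporary output directory and asserts that the LoRA adapter file was actually written, instead of only checking that train() returns. The pattern boiled down to its essentials; the write step is simulated here, since in the real tests it is produced by training.

import tempfile
from pathlib import Path

output_dir = tempfile.mkdtemp()

# train(...) would run here and save the adapter; simulate the artifact for the sketch
(Path(output_dir) / "adapter_model.bin").write_bytes(b"")

assert (Path(output_dir) / "adapter_model.bin").exists()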
14 changes: 8 additions & 6 deletions tests/e2e/test_phi.py
@@ -31,9 +31,9 @@ def test_ft(self):
"trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 2048,
"sequence_len": 512,
"sample_packing": False,
"load_in_8bit": True,
"load_in_8bit": False,
"adapter": None,
"val_set_size": 0.1,
"special_tokens": {
@@ -55,8 +55,9 @@
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"bf16": True,
}
)
normalize_config(cfg)
@@ -74,9 +75,9 @@ def test_ft_packed(self):
"trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 2048,
"sequence_len": 512,
"sample_packing": True,
"load_in_8bit": True,
"load_in_8bit": False,
"adapter": None,
"val_set_size": 0.1,
"special_tokens": {
@@ -98,8 +99,9 @@
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"bf16": True,
}
)
normalize_config(cfg)