misc fixes to add gptq tests (axolotl-ai-cloud#621)
* misc fixes to add gptq tests

* set bf16 needed for fa2
winglian committed Sep 22, 2023
1 parent 038ee66 commit 89c9297
Showing 5 changed files with 93 additions and 21 deletions.
6 changes: 5 additions & 1 deletion src/axolotl/utils/bench.py
@@ -19,7 +19,11 @@ def deco(func):
         def wrapper(*args, **kwargs):
             device = kwargs.get("device", args[0] if args else None)
 
-            if not torch.cuda.is_available() or device == "auto" or device == "cpu":
+            if (
+                not torch.cuda.is_available()
+                or device == "auto"
+                or torch.device(device).type == "cpu"
+            ):
                 return default_value
 
             return func(*args, **kwargs)
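Note on the bench.py change: the old check only matched the literal string "cpu", so values such as "cpu:0" or a torch.device object slipped through to the CUDA memory queries. Below is a minimal standalone sketch of the guard pattern; the decorator name check_cuda_device and the decorated gpu_memory_usage function are assumed for illustration (only deco, wrapper, and default_value appear in the hunk above), and functools.wraps is an addition here.

import functools

import torch


def check_cuda_device(default_value):
    """Run the wrapped function only when a real CUDA device is targeted."""

    def deco(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            device = kwargs.get("device", args[0] if args else None)
            # torch.device("cpu:0").type == "cpu", so this also catches
            # indexed CPU devices and torch.device objects, not just the
            # bare string "cpu".
            if (
                not torch.cuda.is_available()
                or device == "auto"
                or torch.device(device).type == "cpu"
            ):
                return default_value
            return func(*args, **kwargs)

        return wrapper

    return deco


@check_cuda_device(0.0)  # illustrative usage, not the exact bench.py function
def gpu_memory_usage(device="cuda:0"):
    return torch.cuda.memory_allocated(device) / 1024.0**3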
35 changes: 23 additions & 12 deletions src/axolotl/utils/models.py
@@ -10,6 +10,7 @@
 import transformers
 from optimum.bettertransformer import BetterTransformer
 from peft import PeftConfig, prepare_model_for_kbit_training
+from peft.tuners.lora import QuantLinear
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -309,16 +310,26 @@ def load_model(
         ):
             config.max_sequence_length = cfg.sequence_len
             LOG.warning(f"increasing context length to {cfg.sequence_len}")
-        model = AutoModelForCausalLM.from_pretrained(
-            base_model,
-            config=config,
-            device_map=cfg.device_map,
-            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=cfg.torch_dtype,
-            trust_remote_code=cfg.trust_remote_code or False,
-            **model_kwargs,
-        )
+        if cfg.gptq:
+            model = AutoModelForCausalLM.from_pretrained(
+                base_model,
+                config=config,
+                device_map=cfg.device_map,
+                torch_dtype=cfg.torch_dtype,
+                trust_remote_code=cfg.trust_remote_code or False,
+                **model_kwargs,
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                base_model,
+                config=config,
+                device_map=cfg.device_map,
+                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+                torch_dtype=cfg.torch_dtype,
+                trust_remote_code=cfg.trust_remote_code or False,
+                **model_kwargs,
+            )
     except Exception as err:  # pylint: disable=broad-exception-caught
         LOG.error(
             "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
@@ -466,10 +477,10 @@ def load_llama_adapter(model, cfg):
 
 
 def find_all_linear_names(model):
-    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
+    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
     lora_module_names = set()
     for name, module in model.named_modules():
-        if isinstance(module, cls):
+        if isinstance(module, cls) or "Linear" in module.__class__.__name__:
             names = name.split(".")
             lora_module_names.add(names[0] if len(names) == 1 else names[-1])
 
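The widened match exists because GPTQ models replace nn.Linear with QuantLinear, which the old isinstance check missed, so lora_target_linear could come up empty for GPTQ checkpoints. A standalone sketch of the name-based fallback, renamed find_linear_like_names so it is not mistaken for the exact implementation; the lm_head exclusion is an assumption added here.

import torch.nn as nn


def find_linear_like_names(model: nn.Module) -> list:
    lora_module_names = set()
    for name, module in model.named_modules():
        # Matching on the class name picks up quantized replacements such as
        # QuantLinear that a plain isinstance(module, nn.Linear) check misses.
        if isinstance(module, nn.Linear) or "Linear" in module.__class__.__name__:
            parts = name.split(".")
            lora_module_names.add(parts[0] if len(parts) == 1 else parts[-1])
    # Assumption: the output head is usually not a LoRA target.
    lora_module_names.discard("lm_head")
    return sorted(lora_module_names)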
1 change: 1 addition & 0 deletions src/axolotl/utils/trainer.py
@@ -676,6 +676,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             (cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
             and cfg.val_set_size > 0
             and cfg.save_steps
+            and cfg.eval_steps
             and cfg.save_steps % cfg.eval_steps == 0
         )
         or False,
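The added "and cfg.eval_steps" guards the modulo that follows it: with eval_steps unset, cfg.save_steps % cfg.eval_steps would raise a TypeError (or ZeroDivisionError for 0) before the expression could short-circuit to False. A small self-contained illustration with plain arguments rather than the axolotl cfg object:

def should_load_best_model(
    load_best_model_at_end, early_stopping_patience, val_set_size, save_steps, eval_steps
):
    return (
        (
            (load_best_model_at_end is not False or early_stopping_patience)
            and val_set_size > 0
            and save_steps
            and eval_steps
            and save_steps % eval_steps == 0
        )
        or False
    )


assert should_load_best_model(True, None, 0.1, 10, 5) is True
assert should_load_best_model(True, None, 0.1, 10, None) is False  # no crash on None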
58 changes: 56 additions & 2 deletions tests/e2e/test_lora_llama.py
@@ -6,6 +6,7 @@
 import os
 import tempfile
 import unittest
+from pathlib import Path
 
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -24,6 +25,7 @@ class TestLoraLlama(unittest.TestCase):
 
     def test_lora(self):
         # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
@@ -51,7 +53,7 @@ def test_lora(self):
                 "num_epochs": 2,
                 "micro_batch_size": 8,
                 "gradient_accumulation_steps": 1,
-                "output_dir": tempfile.mkdtemp(),
+                "output_dir": output_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
                 "lr_scheduler": "cosine",
@@ -62,9 +64,11 @@
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "adapter_model.bin").exists()
 
     def test_lora_packing(self):
         # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
@@ -94,7 +98,7 @@ def test_lora_packing(self):
                 "num_epochs": 2,
                 "micro_batch_size": 8,
                 "gradient_accumulation_steps": 1,
-                "output_dir": tempfile.mkdtemp(),
+                "output_dir": output_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
                 "lr_scheduler": "cosine",
@@ -105,3 +109,53 @@
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "adapter_model.bin").exists()
+
+    def test_lora_gptq(self):
+        # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
+        cfg = DictDefault(
+            {
+                "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
+                "base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
+                "model_type": "AutoModelForCausalLM",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "flash_attention": True,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "gptq": True,
+                "gptq_disable_exllama": True,
+                "lora_r": 32,
+                "lora_alpha": 64,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "save_steps": 0.5,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": output_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "adapter_model.bin").exists()
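To run only the new GPTQ end-to-end test locally, something like the following should work; the node id is taken from the class and method names in this diff, and the run assumes a CUDA GPU with auto-gptq and flash-attn installed.

import pytest

if __name__ == "__main__":
    # -x stops on the first failure; pytest.main returns the exit code.
    raise SystemExit(
        pytest.main(["-x", "tests/e2e/test_lora_llama.py::TestLoraLlama::test_lora_gptq"])
    )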
14 changes: 8 additions & 6 deletions tests/e2e/test_phi.py
@@ -31,9 +31,9 @@ def test_ft(self):
                 "trust_remote_code": True,
                 "model_type": "MixFormerSequentialForCausalLM",
                 "tokenizer_type": "AutoTokenizer",
-                "sequence_len": 2048,
+                "sequence_len": 512,
                 "sample_packing": False,
-                "load_in_8bit": True,
+                "load_in_8bit": False,
                 "adapter": None,
                 "val_set_size": 0.1,
                 "special_tokens": {
@@ -55,8 +55,9 @@
                 "gradient_accumulation_steps": 1,
                 "output_dir": tempfile.mkdtemp(),
                 "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
+                "optimizer": "adamw_bnb_8bit",
                 "lr_scheduler": "cosine",
+                "bf16": True,
             }
         )
         normalize_config(cfg)
@@ -74,9 +75,9 @@ def test_ft_packed(self):
                 "trust_remote_code": True,
                 "model_type": "MixFormerSequentialForCausalLM",
                 "tokenizer_type": "AutoTokenizer",
-                "sequence_len": 2048,
+                "sequence_len": 512,
                 "sample_packing": True,
-                "load_in_8bit": True,
+                "load_in_8bit": False,
                 "adapter": None,
                 "val_set_size": 0.1,
                 "special_tokens": {
@@ -98,8 +99,9 @@
                 "gradient_accumulation_steps": 1,
                 "output_dir": tempfile.mkdtemp(),
                 "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
+                "optimizer": "adamw_bnb_8bit",
                 "lr_scheduler": "cosine",
+                "bf16": True,
             }
         )
         normalize_config(cfg)
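The "bf16": True additions track the second line of the commit message ("set bf16 needed for fa2"): FlashAttention kernels only operate on fp16/bf16 tensors, so a run that relies on flash attention cannot stay in fp32. A tiny illustrative check, not axolotl code:

import torch


def assert_flash_attention_dtype(dtype: torch.dtype) -> None:
    # FlashAttention (v2) kernels only accept half-precision inputs.
    if dtype not in (torch.float16, torch.bfloat16):
        raise ValueError(f"flash attention needs fp16 or bf16, got {dtype}")


assert_flash_attention_dtype(torch.bfloat16)  # what bf16: True selects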
