be more robust about checking embedding modules for lora finetunes #1074

Merged: 2 commits, Jan 10, 2024
18 changes: 4 additions & 14 deletions src/axolotl/utils/config.py
@@ -151,6 +151,10 @@ def normalize_config(cfg):


def validate_config(cfg):
"""
This is a "pre-validation" step that handles the yaml configuration before we have any
information about the model architecture
"""
if is_torch_bf16_gpu_available():
if not cfg.bf16 and not cfg.bfloat16:
LOG.info("bf16 support detected, but not enabled for this configuration.")
@@ -443,20 +447,6 @@ def validate_config(cfg):
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
raise ValueError("neftune_noise_alpha must be > 0.0")

if (
cfg.adapter
and cfg.tokens
and (
not cfg.lora_modules_to_save
or not all(
x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
)
)
):
raise ValueError(
"lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
)

if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
raise ValueError(
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
12 changes: 12 additions & 0 deletions src/axolotl/utils/lora_embeddings.py
@@ -0,0 +1,12 @@
"""
helpers for lora embeddings
"""


def get_linear_embedding_layers(model_type):
"""
returns the linear embedding layers needed for loras, dependent on the model arch
"""
if model_type == "phi-msft":
return ["embd", "lm_head.linear"]
return ["lm_head", "embed_tokens"]
34 changes: 27 additions & 7 deletions src/axolotl/utils/models.py
@@ -2,7 +2,7 @@
import logging
import math
import os
from typing import Any, Optional, Tuple # noqa: F401
from typing import Any, Optional, Tuple, Union # noqa: F401

import addict
import bitsandbytes as bnb
@@ -28,12 +28,16 @@
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault
from axolotl.utils.lora_embeddings import get_linear_embedding_layers

LOG = logging.getLogger("axolotl")


def check_model_config(cfg: DictDefault, model_config: AutoConfig):
quant_config_exists = hasattr(model_config, "quantization_config")
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
quant_config_exists = (
hasattr(model_config, "quantization_config")
and model_config.quantization_config
)
quant_config_method_is_gptq = (
quant_config_exists
and "quant_method" in model_config.quantization_config
@@ -52,6 +56,20 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
)

lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
if (
cfg.adapter
and cfg.tokens
and (
not cfg.lora_modules_to_save
or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save)
)
):
lora_modules_to_save = ", ".join(map(lambda x: f"`{x}`", lora_modules_to_save))
raise ValueError(
f"`lora_modules_to_save` not properly set when adding new tokens. Please include {lora_modules_to_save} in `lora_modules_to_save`."
)


def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
@@ -139,6 +157,7 @@ def load_tokenizer(cfg):
setattr(tokenizer, attr_name, "<|endoftext|>")

if cfg.special_tokens:
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
for k, val in cfg.special_tokens.items():
# check if new special token is not already in tokenizer and
# is adapter training to make sure lora_modules_to_save is set
@@ -149,14 +168,15 @@
and (
not cfg.lora_modules_to_save
or not all(
x in cfg.lora_modules_to_save
for x in ["embed_tokens", "lm_head"]
x in cfg.lora_modules_to_save for x in lora_modules_to_save
)
)
and (model_config.model_type in ["llama", "mistral", "mixtral"])
):
lora_modules_to_save = ", ".join(
[f"`{x}`" for x in lora_modules_to_save]
)
raise ValueError(
"Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
f"Please set lora_modules_to_save to {lora_modules_to_save} when using an adapter and changing the special tokens."
)

tokenizer.add_special_tokens(
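With the check moved into check_model_config, a config that adds new tokens while training an adapter must list the architecture-specific embedding modules in lora_modules_to_save. A minimal sketch of the failure path, mirroring the new tests below (assumes axolotl with this change; the config values are hypothetical):

# Sketch of the new validation path; values here are illustrative only.
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import check_model_config

cfg = DictDefault(
    {
        "adapter": "qlora",
        "load_in_4bit": True,
        "tokens": ["<|imstart|>"],
        # llama-style module names are not enough for phi-msft...
        "lora_modules_to_save": ["embed_tokens", "lm_head"],
    }
)
model_config = DictDefault({"model_type": "phi-msft"})

# ...so this raises ValueError asking for `embd` and `lm_head.linear` instead.
check_model_config(cfg, model_config)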
70 changes: 61 additions & 9 deletions tests/test_validation.py
@@ -10,12 +10,13 @@

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import check_model_config
from axolotl.utils.wandb_ import setup_wandb_env_vars


class ValidationTest(unittest.TestCase):
class BaseValidation(unittest.TestCase):
"""
Test the validation module
Base validation module to setup the log capture
"""

_caplog: Optional[pytest.LogCaptureFixture] = None
@@ -24,6 +25,12 @@ class ValidationTest(unittest.TestCase):
def inject_fixtures(self, caplog):
self._caplog = caplog


class ValidationTest(BaseValidation):
"""
Test the validation module
"""

def test_load_4bit_deprecate(self):
cfg = DictDefault(
{
@@ -687,16 +694,23 @@ def test_warmup_step_no_conflict(self):

validate_config(cfg)

def test_add_tokens_adapter(self):

class ValidationCheckModelConfig(BaseValidation):
"""
Test the validation for the config when the model config is available
"""

def test_llama_add_tokens_adapter(self):
cfg = DictDefault(
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
)
model_config = DictDefault({"model_type": "llama"})

with pytest.raises(
ValueError,
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
):
validate_config(cfg)
check_model_config(cfg, model_config)

cfg = DictDefault(
{
@@ -709,9 +723,9 @@ def test_add_tokens_adapter(self):

with pytest.raises(
ValueError,
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
):
validate_config(cfg)
check_model_config(cfg, model_config)

cfg = DictDefault(
{
@@ -722,10 +736,48 @@
}
)

validate_config(cfg)
check_model_config(cfg, model_config)

def test_phi2_add_tokens_adapter(self):
cfg = DictDefault(
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
)
model_config = DictDefault({"model_type": "phi-msft"})

with pytest.raises(
ValueError,
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
):
check_model_config(cfg, model_config)

cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
"tokens": ["<|imstart|>"],
"lora_modules_to_save": ["embed_tokens", "lm_head"],
}
)

with pytest.raises(
ValueError,
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
):
check_model_config(cfg, model_config)

cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
"tokens": ["<|imstart|>"],
"lora_modules_to_save": ["embd", "lm_head.linear"],
}
)

check_model_config(cfg, model_config)


class ValidationWandbTest(ValidationTest):
class ValidationWandbTest(BaseValidation):
"""
Validation test for wandb
"""