Commit 0f10080

be more robust about checking embedding modules for lora finetunes (#1074) [skip ci]

* be more robust about checking embedding modules for lora finetunes

* update dynamic error message
winglian committed Jan 10, 2024
1 parent ead34c5 commit 0f10080
Showing 4 changed files with 104 additions and 30 deletions.
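
For orientation, a minimal sketch of the situation this commit targets (hypothetical config, using only keys that appear in the tests added below): training a LoRA/QLoRA adapter while adding new tokens requires the model's embedding and output layers to be listed in `lora_modules_to_save`, and the correct layer names depend on the model architecture.

from axolotl.utils.dict import DictDefault

# Llama-style models expose `embed_tokens` and `lm_head`; phi-msft models use
# `embd` and `lm_head.linear`, which the previous hard-coded check rejected.
cfg = DictDefault(
    {
        "adapter": "qlora",
        "tokens": ["<|imstart|>"],
        "lora_modules_to_save": ["embed_tokens", "lm_head"],  # llama-style layer names
    }
)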
18 changes: 4 additions & 14 deletions src/axolotl/utils/config.py
@@ -151,6 +151,10 @@ def normalize_config(cfg):


def validate_config(cfg):
"""
This is a "pre-validation" step that handles the yaml configuration before we have any
information about the model architecture
"""
if is_torch_bf16_gpu_available():
if not cfg.bf16 and not cfg.bfloat16:
LOG.info("bf16 support detected, but not enabled for this configuration.")
@@ -443,20 +447,6 @@ def validate_config(cfg):
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
raise ValueError("neftune_noise_alpha must be > 0.0")

if (
cfg.adapter
and cfg.tokens
and (
not cfg.lora_modules_to_save
or not all(
x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
)
)
):
raise ValueError(
"lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
)

if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
raise ValueError(
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
12 changes: 12 additions & 0 deletions src/axolotl/utils/lora_embeddings.py
@@ -0,0 +1,12 @@
"""
helpers for lora embeddings
"""


def get_linear_embedding_layers(model_type):
    """
    returns the linear embedding layers needed for loras, dependent on the model arch
    """
    if model_type == "phi-msft":
        return ["embd", "lm_head.linear"]
    return ["lm_head", "embed_tokens"]
34 changes: 27 additions & 7 deletions src/axolotl/utils/models.py
@@ -2,7 +2,7 @@
import logging
import math
import os
from typing import Any, Optional, Tuple # noqa: F401
from typing import Any, Optional, Tuple, Union # noqa: F401

import addict
import bitsandbytes as bnb
@@ -28,12 +28,16 @@
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault
from axolotl.utils.lora_embeddings import get_linear_embedding_layers

LOG = logging.getLogger("axolotl")


def check_model_config(cfg: DictDefault, model_config: AutoConfig):
quant_config_exists = hasattr(model_config, "quantization_config")
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
quant_config_exists = (
hasattr(model_config, "quantization_config")
and model_config.quantization_config
)
quant_config_method_is_gptq = (
quant_config_exists
and "quant_method" in model_config.quantization_config
@@ -52,6 +56,20 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
)

lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
if (
cfg.adapter
and cfg.tokens
and (
not cfg.lora_modules_to_save
or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save)
)
):
lora_modules_to_save = ", ".join(map(lambda x: f"`{x}`", lora_modules_to_save))
raise ValueError(
f"`lora_modules_to_save` not properly set when adding new tokens. Please include {lora_modules_to_save} in `lora_modules_to_save`."
)


def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
@@ -139,6 +157,7 @@ def load_tokenizer(cfg):
setattr(tokenizer, attr_name, "<|endoftext|>")

if cfg.special_tokens:
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
for k, val in cfg.special_tokens.items():
# check if new special token is not already in tokenizer and
# is adapter training to make sure lora_modules_to_save is set
@@ -149,14 +168,15 @@
and (
not cfg.lora_modules_to_save
or not all(
x in cfg.lora_modules_to_save
for x in ["embed_tokens", "lm_head"]
x in cfg.lora_modules_to_save for x in lora_modules_to_save
)
)
and (model_config.model_type in ["llama", "mistral", "mixtral"])
):
lora_modules_to_save = ", ".join(
[f"`{x}`" for x in lora_modules_to_save]
)
raise ValueError(
"Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
f"Please set lora_modules_to_save to {lora_modules_to_save} when using an adapter and changing the special tokens."
)

tokenizer.add_special_tokens(
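A hedged sketch of how the relocated check behaves, mirroring the llama test case added below (not an authoritative API reference):

from axolotl.utils.dict import DictDefault
from axolotl.utils.models import check_model_config

model_config = DictDefault({"model_type": "llama"})

# Adding tokens with an adapter but no lora_modules_to_save raises the dynamic error.
bad_cfg = DictDefault(
    {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
)
try:
    check_model_config(bad_cfg, model_config)
except ValueError as err:
    print(err)  # asks for `lm_head` and `embed_tokens` on llama-type models

# Listing the architecture-appropriate layers passes the check.
good_cfg = DictDefault(
    {
        "adapter": "qlora",
        "load_in_4bit": True,
        "tokens": ["<|imstart|>"],
        "lora_modules_to_save": ["embed_tokens", "lm_head"],
    }
)
check_model_config(good_cfg, model_config)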
70 changes: 61 additions & 9 deletions tests/test_validation.py
@@ -10,12 +10,13 @@

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import check_model_config
from axolotl.utils.wandb_ import setup_wandb_env_vars


class ValidationTest(unittest.TestCase):
class BaseValidation(unittest.TestCase):
"""
Test the validation module
Base validation module to setup the log capture
"""

_caplog: Optional[pytest.LogCaptureFixture] = None
@@ -24,6 +24,12 @@ class ValidationTest(unittest.TestCase):
def inject_fixtures(self, caplog):
self._caplog = caplog


class ValidationTest(BaseValidation):
"""
Test the validation module
"""

def test_load_4bit_deprecate(self):
cfg = DictDefault(
{
@@ -687,16 +694,23 @@ def test_warmup_step_no_conflict(self):

validate_config(cfg)

def test_add_tokens_adapter(self):

class ValidationCheckModelConfig(BaseValidation):
"""
Test the validation for the config when the model config is available
"""

def test_llama_add_tokens_adapter(self):
cfg = DictDefault(
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
)
model_config = DictDefault({"model_type": "llama"})

with pytest.raises(
ValueError,
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
):
validate_config(cfg)
check_model_config(cfg, model_config)

cfg = DictDefault(
{
@@ -709,9 +723,9 @@ def test_add_tokens_adapter(self):

with pytest.raises(
ValueError,
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
):
validate_config(cfg)
check_model_config(cfg, model_config)

cfg = DictDefault(
{
@@ -722,10 +736,48 @@
}
)

validate_config(cfg)
check_model_config(cfg, model_config)

def test_phi2_add_tokens_adapter(self):
cfg = DictDefault(
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
)
model_config = DictDefault({"model_type": "phi-msft"})

with pytest.raises(
ValueError,
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
):
check_model_config(cfg, model_config)

cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
"tokens": ["<|imstart|>"],
"lora_modules_to_save": ["embed_tokens", "lm_head"],
}
)

with pytest.raises(
ValueError,
match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
):
check_model_config(cfg, model_config)

cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
"tokens": ["<|imstart|>"],
"lora_modules_to_save": ["embd", "lm_head.linear"],
}
)

check_model_config(cfg, model_config)


class ValidationWandbTest(ValidationTest):
class ValidationWandbTest(BaseValidation):
"""
Validation test for wandb
"""
