allow overriding of model_config parameters from the YML #853

Merged: 4 commits, Nov 16, 2023
12 changes: 8 additions & 4 deletions README.md
@@ -460,6 +460,14 @@ is_llama_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model:

# optional overrides to the base model configuration
model_config:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float


# Whether you are training a 4-bit GPTQ quantized model
gptq: true
gptq_groupsize: 128 # group size
@@ -726,10 +734,6 @@ landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# LLaMA only
xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float

# Resume from a specific checkpoint dir
resume_from_checkpoint:
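For context, each key documented under `model_config` above is set directly as an attribute on the Hugging Face config object before the model is loaded. The sketch below shows roughly what the loader does with such an override; the model id and the rope_scaling values are illustrative placeholders, not defaults from this PR.

from transformers import AutoConfig

# Overrides as they would appear under `model_config` in the YML
# ("some-org/llama-base" and the values below are placeholders for illustration)
overrides = {"rope_scaling": {"type": "linear", "factor": 2.0}}

config = AutoConfig.from_pretrained("some-org/llama-base")
for key, val in overrides.items():
    # each top-level key under `model_config` lands directly on the config object
    setattr(config, key, val)

print(config.rope_scaling)  # {'type': 'linear', 'factor': 2.0}

This mirrors the setattr loop added to load_model_config in src/axolotl/utils/models.py further down.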
3 changes: 3 additions & 0 deletions src/axolotl/utils/config.py
@@ -369,6 +369,9 @@ def validate_config(cfg):
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)

if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be a key under `model_config`")

# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
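The new warning fires when `rope_scaling` is still defined at the top level of the YML rather than under `model_config`. A minimal migration sketch, assuming PyYAML is installed and the config lives in a file named config.yml (both assumptions, not part of this PR):

import yaml

with open("config.yml", encoding="utf-8") as fin:
    cfg = yaml.safe_load(fin)

# move a legacy top-level rope_scaling block under model_config
if "rope_scaling" in cfg:
    cfg.setdefault("model_config", {})["rope_scaling"] = cfg.pop("rope_scaling")

with open("config.yml", "w", encoding="utf-8") as fout:
    yaml.safe_dump(cfg, fout, sort_keys=False)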
57 changes: 21 additions & 36 deletions src/axolotl/utils/models.py
Expand Up @@ -17,7 +17,6 @@
AutoTokenizer,
BitsAndBytesConfig,
GPTQConfig,
LlamaConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
)
@@ -32,9 +31,14 @@
def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code = cfg.trust_remote_code is True
return AutoConfig.from_pretrained(
model_config = AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
if cfg.model_config:
for key, val in cfg.model_config.items():
setattr(model_config, key, val)

return model_config


def load_tokenizer(cfg):
@@ -51,7 +55,7 @@ def load_tokenizer(cfg):
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)

tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
@@ -110,7 +114,6 @@ def load_model(
Load a model for a given configuration and tokenizer.
"""
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
model_config = load_model_config(cfg)

@@ -238,16 +241,9 @@ def load_model(
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM

config_kwargs = {}
if cfg.rope_scaling:
config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained(
base_model_config,
**config_kwargs,
)
model = LlamaForCausalLM.from_pretrained(
base_model,
config=config,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
@@ -305,66 +301,55 @@ def load_model(
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = getattr(transformers, model_type).from_pretrained(
base_model,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
config = AutoConfig.from_pretrained(
base_model,
trust_remote_code=cfg.trust_remote_code or False,
)
# Shouldn't be a problem most of the time; will obviously error when training starts
# if the model doesn't support this
if (
hasattr(config, "max_seq_len")
and config.max_seq_len
and cfg.sequence_len > config.max_seq_len
hasattr(model_config, "max_seq_len")
and model_config.max_seq_len
and cfg.sequence_len > model_config.max_seq_len
):
config.max_seq_len = cfg.sequence_len
model_config.max_seq_len = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
elif (
hasattr(config, "max_sequence_length")
and config.max_sequence_length
and cfg.sequence_len > config.max_sequence_length
hasattr(model_config, "max_sequence_length")
and model_config.max_sequence_length
and cfg.sequence_len > model_config.max_sequence_length
):
config.max_sequence_length = cfg.sequence_len
model_config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
except Exception as err: # pylint: disable=broad-exception-caught
LOG.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
)
LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
raise err

embeddings_len = (
math.ceil(len(tokenizer) / 32) * 32
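Taken together, the change drops the Llama-specific LlamaConfig path: load_model_config applies any `model_config` overrides onto the AutoConfig, and the from_pretrained calls touched in this diff now receive that same object via config=model_config. A condensed sketch of the resulting flow, using a placeholder model id and illustrative rope_scaling values:

from transformers import AutoConfig, AutoModelForCausalLM

base_model = "some-org/llama-base"  # placeholder model id, not from this PR

config = AutoConfig.from_pretrained(base_model)
config.rope_scaling = {"type": "linear", "factor": 2.0}  # illustrative override from the YML

# the overridden config is passed straight through to the model load
model = AutoModelForCausalLM.from_pretrained(base_model, config=config)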