From 1bc11868ebbb8bca0b0c6156fb100e1d5c4c7888 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 15 Nov 2023 23:47:08 -0500 Subject: [PATCH] allow overriding of model_config parameters from the YML (#853) * allow overriding of model_config parameters from the YML * remove old logging, update readme * move the updating of model config to the load_model_config function * add warning for deprecated rope_scaling in the root of the YML config --- README.md | 12 +++++--- src/axolotl/utils/config.py | 3 ++ src/axolotl/utils/models.py | 57 ++++++++++++++----------------------- 3 files changed, 32 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index c3480c640..5024d88c9 100644 --- a/README.md +++ b/README.md @@ -489,6 +489,14 @@ is_llama_derived_model: # Please note that if you set this to true, `padding_side` will be set to "left" by default is_mistral_derived_model: +# optional overrides to the base model configuration +model_config: + # RoPE Scaling https://github.com/huggingface/transformers/pull/24653 + rope_scaling: + type: # linear | dynamic + factor: # float + + # Whether you are training a 4-bit GPTQ quantized model gptq: true gptq_groupsize: 128 # group size @@ -756,10 +764,6 @@ landmark_attention: # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py # LLaMA only xpos_rope: -# RoPE Scaling https://github.com/huggingface/transformers/pull/24653 -rope_scaling: - type: # linear | dynamic - factor: # float # Resume from a specific checkpoint dir resume_from_checkpoint: diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 81660ae65..d2db92a63 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -369,6 +369,9 @@ def validate_config(cfg): "If you want to full finetune, please turn off load_in_8bit and load_in_4bit." ) + if cfg.rope_scaling: + LOG.warning("`rope_scaling` should now be be a key under `model_config`") + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8848e9503..f90d003ac 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -17,7 +17,6 @@ AutoTokenizer, BitsAndBytesConfig, GPTQConfig, - LlamaConfig, PreTrainedModel, PreTrainedTokenizerBase, ) @@ -32,9 +31,14 @@ def load_model_config(cfg): model_config_name = cfg.base_model_config or cfg.base_model trust_remote_code = cfg.trust_remote_code is True - return AutoConfig.from_pretrained( + model_config = AutoConfig.from_pretrained( model_config_name, trust_remote_code=trust_remote_code ) + if cfg.model_config: + for key, val in cfg.model_config.items(): + setattr(model_config, key, val) + + return model_config def load_tokenizer(cfg): @@ -51,7 +55,7 @@ def load_tokenizer(cfg): if cfg.tokenizer_type: tokenizer_cls = getattr(transformers, cfg.tokenizer_type) - tokenizer_config = cfg.tokenizer_config or cfg.base_model_config + tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model tokenizer = tokenizer_cls.from_pretrained( tokenizer_config, trust_remote_code=cfg.trust_remote_code or False, @@ -110,7 +114,6 @@ def load_model( Load a model for a given configuration and tokenizer. """ base_model = cfg.base_model - base_model_config = cfg.base_model_config model_type = cfg.model_type model_config = load_model_config(cfg) @@ -238,16 +241,9 @@ def load_model( if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: from transformers import LlamaForCausalLM - config_kwargs = {} - if cfg.rope_scaling: - config_kwargs["rope_scaling"] = cfg.rope_scaling - config = LlamaConfig.from_pretrained( - base_model_config, - **config_kwargs, - ) model = LlamaForCausalLM.from_pretrained( base_model, - config=config, + config=model_config, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, **model_kwargs, @@ -305,66 +301,55 @@ def load_model( if cfg.gptq: model = AutoModelForCausalLM.from_pretrained( base_model, + config=model_config, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) else: model = getattr(transformers, model_type).from_pretrained( base_model, + config=model_config, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) else: - config = AutoConfig.from_pretrained( - base_model, - trust_remote_code=cfg.trust_remote_code or False, - ) # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this # when training starts if ( - hasattr(config, "max_seq_len") - and config.max_seq_len - and cfg.sequence_len > config.max_seq_len + hasattr(model_config, "max_seq_len") + and model_config.max_seq_len + and cfg.sequence_len > model_config.max_seq_len ): - config.max_seq_len = cfg.sequence_len + model_config.max_seq_len = cfg.sequence_len LOG.warning(f"increasing context length to {cfg.sequence_len}") elif ( - hasattr(config, "max_sequence_length") - and config.max_sequence_length - and cfg.sequence_len > config.max_sequence_length + hasattr(model_config, "max_sequence_length") + and model_config.max_sequence_length + and cfg.sequence_len > model_config.max_sequence_length ): - config.max_sequence_length = cfg.sequence_len + model_config.max_sequence_length = cfg.sequence_len LOG.warning(f"increasing context length to {cfg.sequence_len}") if cfg.gptq: model = AutoModelForCausalLM.from_pretrained( base_model, - config=config, + config=model_config, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) else: model = AutoModelForCausalLM.from_pretrained( base_model, - config=config, + config=model_config, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) except Exception as err: # pylint: disable=broad-exception-caught - LOG.error( - "Exception raised attempting to load model, retrying with AutoModelForCausalLM" - ) LOG.exception(err) - model = AutoModelForCausalLM.from_pretrained( - base_model, - load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, - load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, - ) + raise err embeddings_len = ( math.ceil(len(tokenizer) / 32) * 32