diff --git a/deepspeed/zero3_bf16.json b/deepspeed/zero3_bf16.json
new file mode 100644
index 000000000..42d10b6bd
--- /dev/null
+++ b/deepspeed/zero3_bf16.json
@@ -0,0 +1,39 @@
+{
+  "zero_optimization": {
+    "stage": 3,
+    "overlap_comm": true,
+    "contiguous_gradients": true,
+    "sub_group_size": 0,
+    "reduce_bucket_size": "auto",
+    "stage3_prefetch_bucket_size": "auto",
+    "stage3_param_persistence_threshold": "auto",
+    "stage3_max_live_parameters": 0,
+    "stage3_max_reuse_distance": 0,
+    "stage3_gather_16bit_weights_on_model_save": true
+  },
+  "bf16": {
+    "enabled": true
+  },
+  "fp16": {
+    "enabled": "auto",
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 32,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "optimizer": {
+    "type": "AdamW",
+    "params": {
+      "lr": "auto",
+      "betas": "auto",
+      "eps": "auto",
+      "weight_decay": "auto"
+    }
+  },
+  "gradient_accumulation_steps": "auto",
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "wall_clock_breakdown": false
+}
diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml
index 6e080e226..15df88f96 100644
--- a/examples/mistral/mixtral.yml
+++ b/examples/mistral/mixtral.yml
@@ -14,6 +14,15 @@ dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
 output_dir: ./qlora-out
 
+## You can optionally freeze the entire model and unfreeze a subset of parameters
+unfrozen_parameters:
+# - lm_head.*
+# - model.embed_tokens.*
+# - model.layers.2[0-9]+.block_sparse_moe.gate.*
+# - model.layers.2[0-9]+.block_sparse_moe.experts.*
+# - model.layers.3[0-9]+.block_sparse_moe.gate.*
+# - model.layers.3[0-9]+.block_sparse_moe.experts.*
+
 adapter: qlora
 lora_model_dir:
 
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 1e6fbc320..81307b6b9 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -22,8 +22,8 @@
 
 def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
-    print_axolotl_text_art()
     parsed_cfg = load_cfg(config, **kwargs)
+    print_axolotl_text_art()
     check_accelerate_default_config()
     check_user_token()
     parser = transformers.HfArgumentParser((TrainerCliArgs))
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 022d230cb..b65d1455f 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -18,6 +18,7 @@
 from axolotl.logging_config import configure_logging
 from axolotl.monkeypatch import neft_embeddings
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.freeze import freeze_parameters_except
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 
@@ -78,6 +79,9 @@ def train(
     )
     resume_from_checkpoint = cfg.resume_from_checkpoint
 
+    if cfg.unfrozen_parameters:
+        freeze_parameters_except(model, cfg.unfrozen_parameters)
+
     trainer = setup_trainer(
         cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
     )
diff --git a/src/axolotl/utils/freeze.py b/src/axolotl/utils/freeze.py
new file mode 100644
index 000000000..05beda1ca
--- /dev/null
+++ b/src/axolotl/utils/freeze.py
@@ -0,0 +1,38 @@
+"""
+module to freeze/unfreeze parameters by name
+"""
+import logging
+import re
+
+from axolotl.utils.distributed import is_main_process
+
+LOG = logging.getLogger("axolotl.utils.freeze")
+
+
+def freeze_parameters_except(model, regex_patterns):
+    """
+    Freezes all layers of the given model except for the layers that match given regex patterns.
+    Periods in the patterns are treated as literal periods, not as wildcard characters.
+
+    Parameters:
+    - model (nn.Module): The PyTorch model to be modified.
+    - regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
+
+    Returns:
+    None; the model is modified in place.
+    """
+    # Escape periods and compile the regex patterns
+    compiled_patterns = [
+        re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
+    ]
+
+    # First, freeze all parameters in the model
+    for param in model.parameters():
+        param.requires_grad = False
+
+    # Unfreeze layers that match the regex patterns
+    for name, param in model.named_parameters():
+        if any(pattern.match(name) for pattern in compiled_patterns):
+            if is_main_process():
+                LOG.debug(f"unfreezing {name}")
+            param.requires_grad = True
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 41a3582ea..251256262 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -21,6 +21,7 @@
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
+from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
@@ -285,6 +286,9 @@ def load_model(
     model_kwargs["max_memory"] = cfg.max_memory
     model_kwargs["torch_dtype"] = cfg.torch_dtype
 
+    if is_deepspeed_zero3_enabled():
+        del model_kwargs["device_map"]
+
     if cfg.model_revision:
         model_kwargs["revision"] = cfg.model_revision
     if cfg.gptq:
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 590861cc0..f046dd7be 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -276,6 +276,7 @@ def prepare_optim_env(cfg):
         setup_fsdp_envs(cfg)
     elif cfg.deepspeed:
         os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+        os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
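Note (not part of the patch): the new freeze utility is easiest to follow with a toy example. The sketch below reimplements the same freeze-then-unfreeze logic from src/axolotl/utils/freeze.py without the rank-zero logging, and applies it to a made-up TinyModel whose parameter names loosely mirror the commented patterns in examples/mistral/mixtral.yml; the model and the pattern list are illustrative only, since the real call site passes cfg.unfrozen_parameters inside train().

import re

import torch.nn as nn


class TinyBlock(nn.Module):
    """Toy stand-in for a transformer block with a MoE gate and one expert."""

    def __init__(self):
        super().__init__()
        self.gate = nn.Linear(8, 2)
        self.expert = nn.Linear(8, 8)


class TinyModel(nn.Module):
    """Toy model whose parameter names loosely mirror the mixtral.yml patterns."""

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(16, 8)
        self.layers = nn.ModuleList([TinyBlock() for _ in range(4)])
        self.lm_head = nn.Linear(8, 16)


def freeze_parameters_except(model, regex_patterns):
    # Same logic as src/axolotl/utils/freeze.py, minus the distributed logging.
    compiled = [re.compile(p.replace(".", "\\.")) for p in regex_patterns]
    for param in model.parameters():
        param.requires_grad = False  # freeze everything first
    for name, param in model.named_parameters():
        if any(pattern.match(name) for pattern in compiled):
            param.requires_grad = True  # then unfreeze the matching names


model = TinyModel()
freeze_parameters_except(model, ["lm_head.*", "layers.[2-3].gate.*"])
for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

Running this prints requires_grad=True only for lm_head and the gates of layers 2-3; the embedding and the experts stay frozen, which is the same behaviour cfg.unfrozen_parameters triggers in train() above.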
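On the DeepSpeed side, the patch touches three places: the new zero3_bf16.json, the prepare_optim_env change that points Accelerate at that file, and the load_model change that drops device_map, since Accelerate's device_map dispatch cannot be combined with ZeRO-3 parameter sharding. A rough sketch (not from the patch) of that runtime flow, with cfg reduced to a plain dict for illustration:

import os

# Import path as used in the patch; newer transformers releases expose this
# helper from transformers.integrations.deepspeed instead.
from transformers.deepspeed import is_deepspeed_zero3_enabled

cfg = {"deepspeed": "deepspeed/zero3_bf16.json", "device_map": "auto"}

# prepare_optim_env: these are the environment variables Accelerate reads to
# enable its DeepSpeed plugin and locate the JSON config.
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg["deepspeed"]

# load_model: ZeRO-3 partitions parameters itself, so the device_map kwarg
# must not be forwarded to from_pretrained when a ZeRO-3 config is active.
model_kwargs = {"device_map": cfg["device_map"]}
if is_deepspeed_zero3_enabled():
    del model_kwargs["device_map"]

The "auto" entries in zero3_bf16.json are the usual Hugging Face convention: they are resolved from the training arguments when the trainer builds its DeepSpeed plugin, so batch size, learning rate, and bucket sizes stay defined in one place.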