From d4147515739cf85bc934a587e2a00660d5de4865 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 13 Dec 2023 09:04:57 -0500
Subject: [PATCH 1/5] add check for zero3

---
 src/axolotl/utils/models.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 41a3582ea..251256262 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -21,6 +21,7 @@
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
+from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
@@ -285,6 +286,9 @@ def load_model(
     model_kwargs["max_memory"] = cfg.max_memory
     model_kwargs["torch_dtype"] = cfg.torch_dtype
 
+    if is_deepspeed_zero3_enabled():
+        del model_kwargs["device_map"]
+
     if cfg.model_revision:
         model_kwargs["revision"] = cfg.model_revision
     if cfg.gptq:

From 7d2ec9b6be6d5f90855da103e35a8c56aa0e8634 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 13 Dec 2023 09:21:16 -0500
Subject: [PATCH 2/5] freeze parameters

---
 src/axolotl/train.py        |  4 ++++
 src/axolotl/utils/freeze.py | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 36 insertions(+)
 create mode 100644 src/axolotl/utils/freeze.py

diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 022d230cb..b65d1455f 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -18,6 +18,7 @@
 from axolotl.logging_config import configure_logging
 from axolotl.monkeypatch import neft_embeddings
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.freeze import freeze_parameters_except
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 
@@ -78,6 +79,9 @@ def train(
         )
     resume_from_checkpoint = cfg.resume_from_checkpoint
 
+    if cfg.unfrozen_parameters:
+        freeze_parameters_except(model, cfg.unfrozen_parameters)
+
     trainer = setup_trainer(
         cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
     )
diff --git a/src/axolotl/utils/freeze.py b/src/axolotl/utils/freeze.py
new file mode 100644
index 000000000..fca4b9dd6
--- /dev/null
+++ b/src/axolotl/utils/freeze.py
@@ -0,0 +1,32 @@
+"""
+module to freeze/unfreeze parameters by name
+"""
+import re
+
+
+def freeze_parameters_except(model, regex_patterns):
+    """
+    Freezes all layers of the given model except for the layers that match given regex patterns.
+    Periods in the patterns are treated as literal periods, not as wildcard characters.
+
+    Parameters:
+    - model (nn.Module): The PyTorch model to be modified.
+    - regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
+
+    Returns:
+    None; the model is modified in place.
+    """
+    # Escape periods and compile the regex patterns
+    compiled_patterns = [
+        re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
+    ]
+
+    # First, freeze all parameters in the model
+    for param in model.parameters():
+        param.requires_grad = False
+
+    # Unfreeze layers that match the regex patterns
+    for name, child in model.named_children():
+        if any(pattern.match(name) for pattern in compiled_patterns):
+            for param in child.parameters():
+                param.requires_grad = True

From ad9538ab9495e742b61ceb2f8871f8e3269edb8a Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 13 Dec 2023 11:44:19 -0500
Subject: [PATCH 3/5] fixes for deepspeed loading

---
 deepspeed/zero3_bf16.json    | 39 ++++++++++++++++++++++++++++++++++++
 src/axolotl/cli/train.py     |  2 +-
 src/axolotl/utils/trainer.py |  1 +
 3 files changed, 41 insertions(+), 1 deletion(-)
 create mode 100644 deepspeed/zero3_bf16.json

diff --git a/deepspeed/zero3_bf16.json b/deepspeed/zero3_bf16.json
new file mode 100644
index 000000000..42d10b6bd
--- /dev/null
+++ b/deepspeed/zero3_bf16.json
@@ -0,0 +1,39 @@
+{
+  "zero_optimization": {
+    "stage": 3,
+    "overlap_comm": true,
+    "contiguous_gradients": true,
+    "sub_group_size": 0,
+    "reduce_bucket_size": "auto",
+    "stage3_prefetch_bucket_size": "auto",
+    "stage3_param_persistence_threshold": "auto",
+    "stage3_max_live_parameters": 0,
+    "stage3_max_reuse_distance": 0,
+    "stage3_gather_16bit_weights_on_model_save": true
+  },
+  "bf16": {
+    "enabled": true
+  },
+  "fp16": {
+    "enabled": "auto",
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 32,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "optimizer": {
+    "type": "AdamW",
+    "params": {
+      "lr": "auto",
+      "betas": "auto",
+      "eps": "auto",
+      "weight_decay": "auto"
+    }
+  },
+  "gradient_accumulation_steps": "auto",
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "wall_clock_breakdown": false
+}
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 1e6fbc320..81307b6b9 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -22,8 +22,8 @@
 
 def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
-    print_axolotl_text_art()
     parsed_cfg = load_cfg(config, **kwargs)
+    print_axolotl_text_art()
     check_accelerate_default_config()
     check_user_token()
     parser = transformers.HfArgumentParser((TrainerCliArgs))
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 590861cc0..f046dd7be 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -276,6 +276,7 @@ def prepare_optim_env(cfg):
         setup_fsdp_envs(cfg)
     elif cfg.deepspeed:
         os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+        os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):

From 355c1e3d8ac00c8e8d8c2ceaecb770e82e0160fe Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 13 Dec 2023 13:23:46 -0500
Subject: [PATCH 4/5] fix model parameter check

---
 src/axolotl/utils/freeze.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/utils/freeze.py b/src/axolotl/utils/freeze.py
index fca4b9dd6..0cb21512b 100644
--- a/src/axolotl/utils/freeze.py
+++ b/src/axolotl/utils/freeze.py
@@ -26,7 +26,6 @@ def freeze_parameters_except(model, regex_patterns):
         param.requires_grad = False
 
     # Unfreeze layers that match the regex patterns
-    for name, child in model.named_children():
+    for name, param in model.named_parameters():
         if any(pattern.match(name) for pattern in compiled_patterns):
-            for param in child.parameters():
-                param.requires_grad = True
+            param.requires_grad = True

From ec12e91df46fa49697bfbdf37713f95634d0b091 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 13 Dec 2023 13:37:58 -0500
Subject: [PATCH 5/5] unfrozen parameters in example mixtral and logging when unfreezing

---
 examples/mistral/mixtral.yml | 9 +++++++++
 src/axolotl/utils/freeze.py  | 7 +++++++
 2 files changed, 16 insertions(+)

diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml
index 6e080e226..15df88f96 100644
--- a/examples/mistral/mixtral.yml
+++ b/examples/mistral/mixtral.yml
@@ -14,6 +14,15 @@ dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
 output_dir: ./qlora-out
 
+## You can optionally freeze the entire model and unfreeze a subset of parameters
+unfrozen_parameters:
+# - lm_head.*
+# - model.embed_tokens.*
+# - model.layers.2[0-9]+.block_sparse_moe.gate.*
+# - model.layers.2[0-9]+.block_sparse_moe.experts.*
+# - model.layers.3[0-9]+.block_sparse_moe.gate.*
+# - model.layers.3[0-9]+.block_sparse_moe.experts.*
+
 adapter: qlora
 lora_model_dir:

diff --git a/src/axolotl/utils/freeze.py b/src/axolotl/utils/freeze.py
index 0cb21512b..05beda1ca 100644
--- a/src/axolotl/utils/freeze.py
+++ b/src/axolotl/utils/freeze.py
@@ -1,8 +1,13 @@
 """
 module to freeze/unfreeze parameters by name
 """
+import logging
 import re
 
+from axolotl.utils.distributed import is_main_process
+
+LOG = logging.getLogger("axolotl.utils.freeze")
+
 
 def freeze_parameters_except(model, regex_patterns):
     """
@@ -28,4 +33,6 @@ def freeze_parameters_except(model, regex_patterns):
     # Unfreeze layers that match the regex patterns
     for name, param in model.named_parameters():
         if any(pattern.match(name) for pattern in compiled_patterns):
+            if is_main_process():
+                LOG.debug(f"unfreezing {name}")
             param.requires_grad = True
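
The sketch below is not part of the patch series; it is one way to sanity-check the freeze_parameters_except helper added above, assuming this branch of axolotl (plus torch) is installed. The toy module and the pattern list are illustrative only, not taken from axolotl.

# Illustrative only: a toy module standing in for a real model.
import torch.nn as nn

from axolotl.utils.freeze import freeze_parameters_except


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(8, 4)
        self.gate = nn.Linear(4, 4)
        self.lm_head = nn.Linear(4, 8)


model = ToyModel()
# Periods in the patterns are escaped by the helper, so "lm_head.*" behaves
# like a name prefix here rather than a regex wildcard.
freeze_parameters_except(model, ["lm_head.*", "gate.*"])

for name, param in model.named_parameters():
    print(name, param.requires_grad)
# expected: embed_tokens.weight -> False; gate.* and lm_head.* -> True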