Fix Deepspeed loading #950

Merged · 5 commits · Dec 13, 2023
Changes from 3 commits
39 changes: 39 additions & 0 deletions deepspeed/zero3_bf16.json
@@ -0,0 +1,39 @@
{
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 0,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 0,
    "stage3_max_reuse_distance": 0,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "bf16": {
    "enabled": true
  },
  "fp16": {
    "enabled": "auto",
    "auto_cast": false,
    "loss_scale": 0,
    "initial_scale_power": 32,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
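For readers unfamiliar with the "auto" placeholders above: they are resolved by the Hugging Face Trainer from its TrainingArguments when DeepSpeed is initialized, so only the ZeRO-3 and bf16 settings are pinned here. A minimal sketch of loading and sanity-checking the config (the file path comes from this PR; the checks themselves are illustrative, not part of the change):

```python
import json

# Path of the config file added in this PR.
with open("deepspeed/zero3_bf16.json") as f:
    ds_config = json.load(f)

# ZeRO stage 3 shards parameters, gradients, and optimizer state across ranks;
# bf16 is pinned on, while fp16 and most sizes are left as "auto" placeholders.
assert ds_config["zero_optimization"]["stage"] == 3
assert ds_config["bf16"]["enabled"] is True

# "auto" values (batch sizes, optimizer hyperparameters, bucket sizes) are filled
# in by the HF Trainer from its TrainingArguments when DeepSpeed is initialized.
auto_keys = [key for key, value in ds_config.items() if value == "auto"]
print("top-level auto placeholders:", auto_keys)
```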
2 changes: 1 addition & 1 deletion src/axolotl/cli/train.py
@@ -22,8 +22,8 @@

def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
4 changes: 4 additions & 0 deletions src/axolotl/train.py
@@ -18,6 +18,7 @@
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch import neft_embeddings
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_parameters_except
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer

@@ -78,6 +79,9 @@ def train(
)
resume_from_checkpoint = cfg.resume_from_checkpoint

if cfg.unfrozen_parameters:
freeze_parameters_except(model, cfg.unfrozen_parameters)

trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
32 changes: 32 additions & 0 deletions src/axolotl/utils/freeze.py
@@ -0,0 +1,32 @@
"""
module to freeze/unfreeze parameters by name
"""
import re


def freeze_parameters_except(model, regex_patterns):
"""
Freezes all layers of the given model except for the layers that match given regex patterns.
Periods in the patterns are treated as literal periods, not as wildcard characters.

Parameters:
- model (nn.Module): The PyTorch model to be modified.
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.

Returns:
None; the model is modified in place.
"""
# Escape periods and compile the regex patterns
compiled_patterns = [
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
]

# First, freeze all parameters in the model
for param in model.parameters():
param.requires_grad = False

# Unfreeze layers that match the regex patterns
for name, child in model.named_children():
if any(pattern.match(name) for pattern in compiled_patterns):
for param in child.parameters():
param.requires_grad = True
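As a quick illustration of how this helper behaves (the toy module and child names below are hypothetical, not from the PR): everything is frozen first, then only the top-level children whose names match one of the patterns are made trainable again, with periods in the patterns matched literally.

```python
import torch.nn as nn

from axolotl.utils.freeze import freeze_parameters_except

# Toy model with two top-level children; the names are illustrative only.
model = nn.Module()
model.lm_head = nn.Linear(8, 8)
model.backbone = nn.Linear(8, 8)

# Keep only lm_head trainable; every other parameter is frozen in place.
freeze_parameters_except(model, ["lm_head"])

assert all(p.requires_grad for p in model.lm_head.parameters())
assert not any(p.requires_grad for p in model.backbone.parameters())
```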
4 changes: 4 additions & 0 deletions src/axolotl/utils/models.py
@@ -21,6 +21,7 @@
PreTrainedModel,
PreTrainedTokenizerBase,
)
from transformers.deepspeed import is_deepspeed_zero3_enabled

from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
@@ -285,6 +286,9 @@ def load_model(
model_kwargs["max_memory"] = cfg.max_memory
model_kwargs["torch_dtype"] = cfg.torch_dtype

if is_deepspeed_zero3_enabled():
del model_kwargs["device_map"]
Review comment (Collaborator): Why do you have to delete this key?

Reply (Author): deepspeed maps the weights on its own and doesn't want device_map set
https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2858-L2861

if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.gptq:
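For context on the exchange above, a minimal sketch of the intent behind the new guard (only the is_deepspeed_zero3_enabled() check and the device_map removal come from the diff; the surrounding model_kwargs are illustrative): when ZeRO-3 is active, DeepSpeed partitions and places the weights itself, so from_pretrained should not also receive a device_map.

```python
from transformers.deepspeed import is_deepspeed_zero3_enabled

# Illustrative kwargs; in load_model these are assembled from the axolotl cfg.
model_kwargs = {"device_map": "auto", "torch_dtype": "bfloat16"}

if is_deepspeed_zero3_enabled():
    # Under ZeRO-3, DeepSpeed shards and places parameters itself and does not
    # want a device_map passed to from_pretrained (see the transformers link above).
    model_kwargs.pop("device_map", None)
```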
1 change: 1 addition & 0 deletions src/axolotl/utils/trainer.py
@@ -276,6 +276,7 @@ def prepare_optim_env(cfg):
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed


def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
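To see how that one-line change is exercised end to end, a minimal standalone sketch (the cfg stand-in and the final print are assumptions; the environment variable names come from the diff): prepare_optim_env tells accelerate to use DeepSpeed and which JSON file to load, which is how the ZeRO-3 bf16 config added in this PR reaches the Trainer.

```python
import os

# Illustrative stand-in for the parsed axolotl config; normally this comes from YAML.
cfg = {"fsdp": None, "deepspeed": "deepspeed/zero3_bf16.json"}

if cfg.get("fsdp"):
    pass  # FSDP environment setup (not shown here)
elif cfg.get("deepspeed"):
    # Mirrors prepare_optim_env above: accelerate reads these variables and
    # points DeepSpeed at the config file when the Trainer is created.
    os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
    os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg["deepspeed"]

print(os.environ.get("ACCELERATE_DEEPSPEED_CONFIG_FILE"))
```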