Skip to content

Commit

Permalink
Determine FSDP/deepspeed settings on device select. (axolotl-ai-cloud…
Browse files Browse the repository at this point in the history
…#883)

* Determine FSDP/deepspeed settings on device select.

Without this, the OS env check for accelerate will fail.

* rename and move env setup call

* chore: lint

---------

Co-authored-by: Karl-Johan Alm <kalle@gmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
  • Loading branch information
3 people committed Nov 29, 2023
1 parent 74a4856 commit f8e4520
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
Expand Down Expand Up @@ -296,6 +297,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):

validate_config(cfg)

prepare_optim_env(cfg)

normalize_config(cfg)

setup_wandb_env_vars(cfg)
Expand Down
4 changes: 3 additions & 1 deletion src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,14 @@ def setup_fsdp_envs(cfg):
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap


def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
def prepare_optim_env(cfg):
if cfg.fsdp:
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"


def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset
Expand Down

0 comments on commit f8e4520

Please sign in to comment.