let hf trainer handle torch compile (axolotl-ai-cloud#516)
* let hf trainer handle torch compile

* remove torch compile checks, include option for backend

* suppress torch errors to get further

* require min torch version of 2.1.0 for torch compile to work

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
winglian and tmm1 committed Sep 13, 2023
1 parent b4f7b10 · commit 6a8d36e
Showing 3 changed files with 20 additions and 4 deletions.
README.md: 4 additions & 0 deletions

@@ -519,6 +519,10 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
 # where to save the finished model to
 output_dir: ./completed-model
 
+# whether to use torch.compile and which backend to use
+torch_compile: # bool
+torch_compile_backend: # Optional[str]
+
 # training hyperparameters
 gradient_accumulation_steps: 1
 micro_batch_size: 2
src/axolotl/train.py: 0 additions & 4 deletions

@@ -80,10 +80,6 @@ def train(
 
     model.config.use_cache = False
 
-    if torch.__version__ >= "2" and sys.platform != "win32":
-        LOG.info("Compiling torch model")
-        model = torch.compile(model)
-
     # go ahead and presave, so we have the adapter config available to inspect
     if peft_config:
         LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
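With this change, train() no longer wraps the model in torch.compile() itself; compilation is delegated to the Hugging Face Trainer through the torch_compile and torch_compile_backend fields on transformers.TrainingArguments. A minimal standalone sketch of that mechanism (not code from this repository; the model and dataset names are placeholders):

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./completed-model",
    torch_compile=True,                # the Trainer wraps the model with torch.compile()
    torch_compile_backend="inductor",  # optional; "inductor" is PyTorch's default backend
)
# trainer = Trainer(model=model, args=training_args, train_dataset=train_ds)
# trainer.train()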
src/axolotl/utils/trainer.py: 16 additions & 0 deletions

@@ -11,6 +11,7 @@
 from typing import Optional, Union
 
 import numpy as np
+import torch
 import torch.cuda
 import transformers
 from datasets import Dataset, set_caching_enabled

@@ -604,6 +605,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     if cfg.greater_is_better:
         training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
 
+    if cfg.torch_compile:
+        if torch.__version__ < "2.1.0":  # pylint: disable=protected-access
+            LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
+        else:
+            import torch._dynamo  # pylint: disable=redefined-outer-name
+
+            torch._dynamo.config.suppress_errors = (  # pylint: disable=protected-access
+                True
+            )
+            training_arguments_kwargs["torch_compile"] = cfg.torch_compile
+            if cfg.torch_compile_backend:
+                training_arguments_kwargs[
+                    "torch_compile_backend"
+                ] = cfg.torch_compile_backend
+
     # DDP Config
     if cfg.ddp_timeout:
         training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
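Setting torch._dynamo.config.suppress_errors = True makes torch.compile fall back to eager execution for graphs it cannot compile instead of raising, which is what lets training "get further" on models with unsupported ops. A small standalone sketch of that behavior (assumes torch >= 2.1; the function below is illustrative only):

import torch
import torch._dynamo

# With suppress_errors enabled, compilation failures are logged and the
# affected code runs eagerly rather than aborting the run.
torch._dynamo.config.suppress_errors = True

@torch.compile(backend="inductor")
def relu_plus_one(x):
    return torch.relu(x) + 1

print(relu_plus_one(torch.randn(4)))  # compiled where possible, eager otherwise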
