let hf trainer handle torch compile #516

Merged · 5 commits · Sep 13, 2023
4 changes: 4 additions & 0 deletions README.md
@@ -526,6 +526,10 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# where to save the finished model to
output_dir: ./completed-model

# whether to use torch.compile and which backend to use
torch_compile: # bool
torch_compile_backend: # Optional[str]
A collaborator commented on the new config lines:
Is it necessary to expose this "backend" config?

> Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.

And reading the torch compile source, it seems to only special-case inductor.

@tmm1 (Collaborator) replied, Sep 1, 2023:
Not sure; there are other backends but maybe they're not useful?

```python
>>> torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tvm']
```
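For anyone double-checking a value before putting it in the config, a minimal sketch (assuming a PyTorch 2.x environment where the chosen backend can actually build, e.g. inductor needs a working C++ compiler on CPU): any string returned by `torch._dynamo.list_backends()` can be passed straight to `torch.compile`, and the backend is only exercised on the first forward pass.

```python
import torch
import torch._dynamo

# Sanity-check a backend string before wiring it into the YAML config.
# torch.compile is lazy, so the first call to the compiled module is what
# actually invokes the chosen backend.
backend = "inductor"  # swap for e.g. "cudagraphs" to try another entry
assert backend in torch._dynamo.list_backends()

model = torch.nn.Linear(16, 16)
compiled = torch.compile(model, backend=backend)
print(compiled(torch.randn(4, 16)).shape)  # torch.Size([4, 16])
```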


# training hyperparameters
gradient_accumulation_steps: 1
micro_batch_size: 2
4 changes: 0 additions & 4 deletions src/axolotl/train.py
@@ -80,10 +80,6 @@ def train(

model.config.use_cache = False

if torch.__version__ >= "2" and sys.platform != "win32":
    LOG.info("Compiling torch model")
    model = torch.compile(model)

# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
    LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
7 changes: 7 additions & 0 deletions src/axolotl/utils/trainer.py
@@ -579,6 +579,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.bench_dataset:
    training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset

if cfg.torch_compile:
    training_arguments_kwargs["torch_compile"] = cfg.torch_compile
    if cfg.torch_compile_backend:
        training_arguments_kwargs[
            "torch_compile_backend"
        ] = cfg.torch_compile_backend

# DDP Config
if cfg.ddp_timeout:
    training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
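For reference, a rough sketch (assuming a transformers release that exposes these flags, roughly v4.27+) of where the kwargs end up: they are forwarded into `TrainingArguments`, and the HF `Trainer` then takes care of calling `torch.compile` on the model, which is why the explicit compile call in `train.py` above could be dropped.

```python
from transformers import TrainingArguments

# Illustrative values only; in axolotl these come from cfg.torch_compile and
# cfg.torch_compile_backend via training_arguments_kwargs above.
args = TrainingArguments(
    output_dir="./completed-model",
    torch_compile=True,
    torch_compile_backend="inductor",  # optional; HF picks its default backend when unset
)
```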