
let hf trainer handle torch compile #516

Merged
merged 5 commits on Sep 13, 2023
4 changes: 4 additions & 0 deletions README.md
@@ -526,6 +526,10 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps`
# where to save the finished model to
output_dir: ./completed-model

# whether to use torch.compile and which backend to use
torch_compile: # bool
torch_compile_backend: # Optional[str]
Collaborator commented on these lines:
Is it necessary to expose this "backend" config?

Refer to the PyTorch doc for possible values, and note that they may change across PyTorch versions.

Also, reading the source for torch compile, it seems to only have a special condition for inductor.

@tmm1 (Collaborator) replied on Sep 1, 2023:
Not sure; there are other backends but maybe they're not useful?

>>> torch._dynamo.list_backends()
['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tvm']


# training hyperparameters
gradient_accumulation_steps: 1
micro_batch_size: 2
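On the backend question above: which strings are valid depends on the installed PyTorch build, so a config-supplied value could be sanity-checked against what Dynamo actually registers. A minimal sketch, not part of this PR; the helper name and fallback choice are illustrative:

from typing import Optional

import torch._dynamo


def resolve_compile_backend(requested: Optional[str]) -> str:
    # Ask Dynamo which backends the current PyTorch install registers,
    # e.g. ['cudagraphs', 'inductor', 'onnxrt', ...].
    available = torch._dynamo.list_backends()
    if requested and requested in available:
        return requested
    # Fall back to inductor, the default backend shipped with PyTorch 2.x.
    return "inductor"


print(resolve_compile_backend("inductor"))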
4 changes: 0 additions & 4 deletions src/axolotl/train.py
@@ -80,10 +80,6 @@ def train(

model.config.use_cache = False

if torch.__version__ >= "2" and sys.platform != "win32":
    LOG.info("Compiling torch model")
    model = torch.compile(model)

# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
    LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
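With the manual torch.compile call removed from train(), compilation is delegated to the Hugging Face Trainer through its training arguments, which is what the trainer.py change below wires up. A minimal sketch of that mapping, assuming transformers>=4.27; the values shown are placeholders, not project defaults:

from transformers import TrainingArguments

# The axolotl config keys map onto the Trainer's own arguments; the Trainer
# then calls torch.compile on the model itself at train() time.
training_args = TrainingArguments(
    output_dir="./completed-model",
    torch_compile=True,                # from the `torch_compile` config key
    torch_compile_backend="inductor",  # from the `torch_compile_backend` key
)

print(training_args.torch_compile, training_args.torch_compile_backend)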
16 changes: 16 additions & 0 deletions src/axolotl/utils/trainer.py
@@ -11,6 +11,7 @@
from typing import Optional, Union

import numpy as np
import torch
import torch.cuda
import transformers
from datasets import Dataset, set_caching_enabled
@@ -579,6 +580,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.bench_dataset:
    training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset

if cfg.torch_compile:
    if torch.__version__ < "2.1.0":  # pylint: disable=protected-access
Collaborator (Author) commented on this line: @tmm1 lmk if you think this is good enough for now

        LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
    else:
        import torch._dynamo  # pylint: disable=redefined-outer-name

        torch._dynamo.config.suppress_errors = (  # pylint: disable=protected-access
            True
        )
        training_arguments_kwargs["torch_compile"] = cfg.torch_compile
        if cfg.torch_compile_backend:
            training_arguments_kwargs[
                "torch_compile_backend"
            ] = cfg.torch_compile_backend

# DDP Config
if cfg.ddp_timeout:
    training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
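For context on the suppress_errors flag set above: with it enabled, graphs that Dynamo fails to compile fall back to eager execution instead of raising, so an unsupported op or backend degrades to the uncompiled path rather than aborting the run. A standalone toy illustration, assuming PyTorch 2.x; the function is just a placeholder:

import torch
import torch._dynamo

# Fall back to eager execution when compilation fails instead of raising.
torch._dynamo.config.suppress_errors = True


@torch.compile(backend="inductor")
def scaled_sum(x: torch.Tensor) -> torch.Tensor:
    return (x * 2.0).sum()


# Runs the compiled graph when inductor succeeds, otherwise silently runs eager.
print(scaled_sum(torch.randn(8)))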