Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

let hf trainer handle torch compile #516

Merged
merged 5 commits into from
Sep 13, 2023
Merged

let hf trainer handle torch compile #516

merged 5 commits into from
Sep 13, 2023

Conversation

winglian
Copy link
Collaborator

No description provided.

src/axolotl/utils/trainer.py Outdated Show resolved Hide resolved
@winglian winglian added the hold don't merge this yet label Aug 31, 2023
@@ -526,6 +526,10 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# where to save the finished model to
output_dir: ./completed-model

# whether to use torch.compile and which backend to use
torch_compile: # bool
torch_compile_backend: # Optional[str]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to open this "backend" config?

Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.

and reading the source of the torch compile, it seems to only have a condition for inductor.

Copy link
Collaborator

@tmm1 tmm1 Sep 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure; there are other backends but maybe they're not useful?

torch._dynamo.list_backends()

['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tvm']

@NanoCode012
Copy link
Collaborator

It seems like there may be small benefits to leaving it to default on. Is there any case having it off is better?

@tmm1
Copy link
Collaborator

tmm1 commented Sep 1, 2023

Is there any case having it off is better?

it cannot default on because it doesn't work. if you try to train it errors out.

@@ -579,6 +580,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset

if cfg.torch_compile:
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tmm1 lmk if you think this is good enough for now

@winglian winglian merged commit a4e1bb6 into main Sep 13, 2023
6 checks passed
@winglian winglian deleted the torch-compile branch September 13, 2023 15:42
mkeoliya pushed a commit to mkeoliya/axolotl that referenced this pull request Dec 15, 2023
* let hf trainer handle torch compile

* remove torch compile checks, include option for backend

* suppress torch errors to get further

* require min torch version of 2.1.0 for torch compile to work

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
hold don't merge this yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants