Add support for torch 2.0 #2172

Merged: dakinggg merged 57 commits into dev from torch2branch2 on Apr 27, 2023

Conversation

@dakinggg (Contributor) commented Apr 26, 2023

What does this PR do?

This PR upgrades the torch pin to support torch 2.0. It includes related fixes, most of which result from using use_orig_params=True with FSDP, which is required to support torch.compile.

Changes:

  • Different way of loading FSDP state dicts for torch 2 (sketched below)
  • _LRScheduler -> LRScheduler (compatibility shim sketched below)
  • Using summon_full_params to get HF generate to work with FSDP (sketched below)
  • Fixing the way we compute optimizer metrics to account for the fact that not all ranks hold all params (sketched below)
  • Assorted minor fixes and test fixes
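
A minimal sketch of the sharded state dict flow on torch 2, assuming a hypothetical FSDP-wrapped model named fsdp_model; it illustrates the general torch 2 API rather than the exact code in this PR:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# With use_orig_params=True on torch 2, each rank holds only its own shards, so the
# checkpoint is collected and restored per rank instead of materialized on rank 0.
with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
    sharded_sd = fsdp_model.state_dict()    # only this rank's shards

# ... persist sharded_sd per rank (e.g. via torch.distributed.checkpoint) ...

with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
    fsdp_model.load_state_dict(sharded_sd)  # restore this rank's shards
```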
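
The scheduler base class became public in torch 2.0; a small compatibility shim (hypothetical, for illustration) keeps the same name importable on both versions:

```python
# torch 2.0 promotes the private _LRScheduler base class to the public LRScheduler;
# the fallback keeps annotations working on torch < 2.0.
try:
    from torch.optim.lr_scheduler import LRScheduler  # torch >= 2.0
except ImportError:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler  # torch < 2.0
```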
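
A sketch of the generate workaround, assuming fsdp_model wraps a Hugging Face causal LM and input_ids is already on the correct device; the actual call sites in Composer may differ:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# generate() needs the full, unsharded weights, so gather them for the duration of
# the call; writeback=False because generation does not modify the parameters.
with FSDP.summon_full_params(fsdp_model, writeback=False, recurse=True):
    with torch.no_grad():
        output_ids = fsdp_model.module.generate(input_ids, max_new_tokens=32)
```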
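
A minimal sketch of the kind of cross-rank reduction this implies, using a global gradient norm as a stand-in metric and assuming an initialized process group; the real metrics in optimizer_monitor.py are more involved:

```python
import torch
import torch.distributed as dist

# With use_orig_params=True a rank may own an empty shard (numel() == 0) of a given
# parameter, so metrics must be reduced across ranks rather than read on one rank.
local_sq_norm = torch.zeros(1, device='cuda')
for p in fsdp_model.parameters():
    if p.grad is not None and p.grad.numel() > 0:
        local_sq_norm += p.grad.detach().float().pow(2).sum()
dist.all_reduce(local_sq_norm, op=dist.ReduceOp.SUM)
global_grad_norm = local_sq_norm.sqrt()
```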

Manual tests:

  • sharded multinode autoresume
    [screenshot: Screen Shot 2023-04-26 at 7 37 10 PM]
  • local autoresume
    [screenshot: Screen Shot 2023-04-26 at 7 36 33 PM]
  • full resume
    [screenshot: Screen Shot 2023-04-26 at 7 35 59 PM]
    [screenshot: Screen Shot 2023-04-27 at 12 49 25 AM]

What issue(s) does this change relate to?

Closes CO-2029
Closes #2147

Before submitting

  • Have you read the contributor guidelines?
  • Was this change discussed/approved in a GitHub issue first? It is much more likely to be merged if so.
  • Did you update any related docs and document your change?
  • Did you update any related tests and add any new tests related to your change? (see testing)
  • Did you run the tests locally to make sure they pass?
  • Did you run pre-commit on your change? (see the pre-commit section of prerequisites)

@dakinggg marked this pull request as ready for review on April 27, 2023 at 02:38
@dakinggg requested review from a team as code owners on April 27, 2023 at 02:38
@mvpatel2000 (Contributor) left a comment

Why is there a gap in the resumption tests?

Mostly LGTM / minor comments. Will do one more pass after comments are resolved before approving, since this PR is massive.

Resolved review threads:
  • .github/workflows/pr-cpu.yaml (outdated)
  • .github/workflows/pr-cpu.yaml (outdated)
  • .github/workflows/pr-gpu.yaml
  • composer/callbacks/optimizer_monitor.py
  • composer/core/state.py (outdated)
  • tests/algorithms/test_gradient_clipping.py
  • tests/callbacks/test_optimizer_monitor.py (outdated)
  • tests/common/models.py (outdated)
  • tests/trainer/test_sharded_checkpoint.py (outdated)
  • tests/trainer/test_sharded_checkpoint.py
@dakinggg (Contributor, Author) commented Apr 27, 2023

@mvpatel2000 the resumptions with a gap are because run 1 ran for 10 batches, and then run 2 was started with autoresume=True and an increased max duration, rather than by deleting some checkpoints. Let me know if that makes sense.

@dakinggg requested a review from dskhudia as a code owner on April 27, 2023 at 07:12
dakinggg and others added 13 commits April 27, 2023 01:02
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Serialize and load torchmetrics through state_dict() and load_state_dict() instead of pickle
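
A sketch of what that commit's change means in practice, using a hypothetical torchmetrics Accuracy metric; metrics are nn.Modules, so their state round-trips through state_dict()/load_state_dict() without pickling the object:

```python
from torchmetrics.classification import MulticlassAccuracy

# Hypothetical example: persist only the metric's state, not the pickled object.
metric = MulticlassAccuracy(num_classes=10)
# ... metric.update(preds, targets) during training ...
saved_state = metric.state_dict()

restored = MulticlassAccuracy(num_classes=10)  # rebuild the metric, then load its state
restored.load_state_dict(saved_state)
```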
@karan6181 (Contributor) left a comment

Some minor comments. Overall looks good. I liked the detailed comments you added in multiple places to give a better understanding of the code.

Minor nit: a more descriptive PR title?

Resolved review threads:
  • composer/core/state.py (outdated)
  • composer/core/types.py
  • composer/utils/auto_log_hparams.py
@dakinggg changed the title from "Torch2" to "Add support for torch 2.0" on Apr 27, 2023
@mvpatel2000 (Contributor) left a comment

LGTM. Let's get more approvals though before merging. The only outstanding item is adding GPU daily tests.

@karan6181 (Contributor) left a comment

LGTM. Thanks!

@nik-mosaic (Contributor) left a comment

Should we add helpful error messages to the ONNX export method if a user runs into an error and is running PyTorch 2.0? We could suggest they try downgrading PyTorch versions if one of their model operators is not supported.

This is not a blocking suggestion --- we can merge without this.
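
For concreteness, a hypothetical shape of the guard being suggested here (not something this PR adds); the function name and message wording are illustrative only:

```python
import torch

def export_onnx_with_hint(model, sample_input, path):
    # Wrap the export so a failure under torch >= 2.0 also hints that an older torch
    # or a different opset may help when an operator is unsupported.
    try:
        torch.onnx.export(model, sample_input, path)
    except RuntimeError as err:
        if torch.__version__.startswith('2.'):
            raise RuntimeError(
                f'ONNX export failed under torch {torch.__version__}. If the cause is an '
                'unsupported operator, consider a different opset_version or downgrading '
                'torch.') from err
        raise
```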

@dakinggg (Contributor, Author) commented
I think the message you get from ONNX directly is about as clear as it gets, since we don't know which operator they're having trouble with and how that corresponds to the opset version and torch version. But I'm open to another suggestion.

@dakinggg merged commit 6180ef0 into dev on Apr 27, 2023
@dakinggg deleted the torch2branch2 branch on April 27, 2023 at 22:29
Labels: none yet
Projects: none yet
Development: successfully merging this pull request may close "Pytorch 2.0 support"
5 participants