Native Amp Support #1337

Closed
mcarilli opened this issue Apr 2, 2020 · 7 comments · Fixed by #1561

@mcarilli

mcarilli commented Apr 2, 2020

Native automatic mixed precision support (torch.cuda.amp) is finally merged:
https://pytorch.org/docs/master/amp.html
https://pytorch.org/docs/master/notes/amp_examples.html
Apex Amp has many known pain points (extension builds, forward/backward compatibility, DataParallel support, flaky checkpointing, I don’t even know if it can be hacked to handle double backward/gradient penalty, others…). torch.cuda.amp fixes all these, the interface is more flexible and intuitive, and the tighter integration brings more future performance optimizations into scope.

If you want to talk about adding torch.cuda.amp to Lightning, with an eye towards it becoming the true source of mixed precision and replacing Apex, message me on PyTorch Slack anytime. I pinged you there as well, but I’m not sure if you monitor it habitually.

@mcarilli mcarilli added feature Is an improvement or enhancement help wanted Open to be worked on labels Apr 2, 2020
@mcarilli
Author

mcarilli commented Apr 2, 2020

I think the torch.cuda.amp API is a much better fit for Lightning because its style is more functional (functional as in, it doesn't statefully alter anything outside itself). The necessary torch.cuda.amp calls could be contained entirely within trainer.fit() without any silent/weird effects elsewhere.

@williamFalcon
Contributor

williamFalcon commented Apr 2, 2020

this is awesome. will definitely add! eta on the next pt release?
we can add forward compatibility.

@mcarilli does it still have the issues with saving/loading weights with the loss scaling factor?

@PyTorchLightning/core-contributors anyone interested in making this change?

one key consideration is saving/loading weights when amp scales the loss.

@mcarilli
Author

mcarilli commented Apr 2, 2020

Yes, bitwise accurate saving/restoring is supported. You just need to call your GradScaler instance's state_dict() and load_state_dict() alongside the usual model/optimizer state_dict/load_state_dict calls.
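
As a rough illustration of the save/restore pattern described above, here is a minimal sketch. The toy `Linear` model, the `SGD` optimizer, the in-memory buffer, and the checkpoint keys (`"model"`, `"optimizer"`, `"scaler"`) are all assumptions made for the example, not anything Lightning prescribes. The scaler is only enabled when CUDA is available, so on a CPU-only machine its state dict is simply empty.

```python
import io

import torch

# Assumption: fall back to a disabled scaler when no GPU is present.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

model = torch.nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Save the scaler state alongside the usual model/optimizer state.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}
buffer = io.BytesIO()  # stands in for a checkpoint file on disk
torch.save(checkpoint, buffer)

# Restore into freshly constructed objects.
buffer.seek(0)
restored = torch.load(buffer, map_location=device)

new_model = torch.nn.Linear(4, 2).to(device)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1)
new_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

new_model.load_state_dict(restored["model"])
new_optimizer.load_state_dict(restored["optimizer"])
new_scaler.load_state_dict(restored["scaler"])
```

The only amp-specific part is the extra `"scaler"` entry; the rest is the usual checkpointing code.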

@Borda Borda added this to Bugs (highest priority) in Key features - Roadmap v1.0 Apr 2, 2020
@Borda Borda moved this from Bugs (highest priority) to Todo (next release) in Key features - Roadmap v1.0 Apr 2, 2020
@williamFalcon williamFalcon modified the milestones: 0.7.3, 0.7.2 Apr 4, 2020
@williamFalcon
Contributor

williamFalcon commented Apr 4, 2020

@mcarilli any chance you'd be interested in submitting the PR?
I might be able to get to it by early this week, but it'd be great to have in 0.7.2 which is coming early next week.

was going to add checks like:

if pytorch.__version__ >= 1.6:
   # new amp stuff
else:
   # old amp stuff

@williamFalcon williamFalcon added priority: 0 High priority task and removed priority: 0 High priority task labels Apr 4, 2020
@mcarilli
Author

mcarilli commented Apr 4, 2020

Hmm, I don't know the Lightning codebase at all, aside from the interface. It would take me longer than early next week to be sure I was making the right changes in the right places. The version is a more complex string though, so I'd use something like

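import torch

# torch.__version__ is a string and may carry suffixes (e.g. "1.6.0a0+9d1138a"),
# so compare only the integer major/minor components.
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
version_ge_16 = (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 6)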
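
To make the point about version strings concrete, here is a small sketch that works on plain strings and does not need torch. The helper names `parse_major_minor` and `version_at_least` are hypothetical, and the sample version strings are made up for illustration. It shows why comparing the raw string goes wrong: lexicographically, "1.10.0" sorts before "1.6.0".

```python
import re
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract the leading 'major.minor' integers from a version string."""
    match = re.match(r"^(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def version_at_least(version: str, major: int, minor: int) -> bool:
    """Compare numerically on (major, minor), ignoring any suffix."""
    return parse_major_minor(version) >= (major, minor)


# A naive string comparison gets the ordering wrong for two-digit minors.
naive_result = "1.10.0" >= "1.6.0"
numeric_result = version_at_least("1.10.0", 1, 6)
```

With a real install this would be called as `version_at_least(torch.__version__, 1, 6)`.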

@Borda
Member

Borda commented Apr 4, 2020

not sure about the particular condition if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 6): but yes, by parsing the version you can set this kind of default env

@mcarilli
Author

mcarilli commented Apr 4, 2020

Happy to review in-progress PRs though.

One key point is that torch.cuda.amp.autocast locally (in its context regions) enables Apex "O1"-like casting behavior. It casts inputs on the fly as they enter certain functions, so it doesn't need to touch the model or the optimizer at all, nor does it need them to change. You shouldn't manually call .half() on the model or input data. (scaler.step(optimizer) only decides whether or not to call optimizer.step(); it doesn't change the optimizer in any stateful way.)
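
For reference, a minimal sketch of what such a training step might look like. The toy model, the random data, the learning rate, and the number of steps are assumptions for illustration only. Autocast and the scaler are only enabled when CUDA is available, so the same loop runs unchanged in full precision on CPU.

```python
import torch

torch.manual_seed(0)

# Assumption: disable amp entirely when no GPU is present.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

# The model and optimizer are created as usual: no .half(), no wrapping.
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

inputs = torch.randn(16, 8, device=device)
targets = torch.randn(16, 1, device=device)
initial_weight = model.weight.detach().clone()

for _ in range(10):
    optimizer.zero_grad()
    # Casting happens only inside this context region.
    with torch.cuda.amp.autocast(enabled=use_amp):
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
    # Scale the loss, backpropagate, then let the scaler decide whether
    # to call optimizer.step() (it skips steps with inf/NaN gradients).
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

The parameters stay in float32 throughout, which is why the model and optimizer need no changes.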

Also, that versioning condition is based on what works for us. The particular number (1.6+) is a decent criterion for native amp availability, since the window of commits with torch.__version__ = 1.6.xyz that don't yet have autocast and GradScaler is small.

You could sidestep __version__ parsing entirely and check for full native amp support via

has_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")

Now that I mention it that's probably better.
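
One way this feature check could feed a backend choice is sketched below. The function names `has_native_amp` and `select_amp_backend` and the returned labels are hypothetical, not part of Lightning or PyTorch. The Apex check only looks for an installed `apex` package and does not import it.

```python
import importlib.util

import torch


def has_native_amp() -> bool:
    """Feature-detect native amp instead of parsing torch.__version__."""
    return hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")


def select_amp_backend() -> str:
    """Prefer native amp, fall back to Apex if installed, else no amp."""
    if has_native_amp():
        return "native"
    if importlib.util.find_spec("apex") is not None:
        return "apex"
    return "none"


backend = select_amp_backend()
```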

@Borda Borda modified the milestones: 0.7.2, 0.7.3 Apr 8, 2020
Key features - Roadmap v1.0 automation moved this from Todo (next release) to Done Apr 23, 2020
@Borda Borda modified the milestones: 0.7.4, v0.7.x Apr 18, 2021

3 participants