Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support for native amp #1561

Merged
merged 17 commits into from
Apr 23, 2020
Merged

support for native amp #1561

merged 17 commits into from
Apr 23, 2020

Conversation

williamFalcon
Copy link
Contributor

@williamFalcon williamFalcon commented Apr 22, 2020

Fixes #1336
Fixes #1337

@mcarilli mind taking a look?

Issue 1

We have a slight issue with the DP API...

MyModel(nn.Module):
    ...
    @autocast()
    def forward(self, input):
       ...

# Alternatively
MyModel(nn.Module):
    ...
    def forward(self, input):
        with autocast():
            ...

@ethanwharris suggested a way around this which we have in the PR

original_fwd = model.forward
model.forward = autocast()(model.forward)

# train and stuff
# ...

model.forward = original_fwd

Issue 2

How do we save the state of the scaling factor to resume training?
@mcarilli

@mergify mergify bot requested a review from a team April 22, 2020 16:22
@Borda Borda added feature Is an improvement or enhancement priority: 0 High priority task labels Apr 22, 2020
@Borda Borda added this to the 0.7.4 milestone Apr 22, 2020
@Borda
Copy link
Member

Borda commented Apr 22, 2020

since which version is amp in pytorch native?

@williamFalcon
Copy link
Contributor Author

1.6. but we don’t need to explicitly check. we can test the properties as i did

@mergify mergify bot requested a review from a team April 22, 2020 17:40
@mcarilli
Copy link

mcarilli commented Apr 22, 2020

How do we save the state of the scaling factor to resume training?

saved_state = scaler.state_dict()
scaler.load_state_dict(saved_state)

@@ -138,11 +138,20 @@ def backward(self, use_amp, loss, optimizer):
else:
loss.backward()

.. note:: with PyTorch 1.6+ + precision=16 + multiple optimizers, set .backward(retrain_graph=True)
Copy link

@mcarilli mcarilli Apr 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need this note.

The example is misleading, I guess. The retain_graph=True bit has nothing to do with Amp, it's only present because both losses interleave outputs from multiple models. Both backward passes use the same model graphs, so the first backward() must not tear them down. retain_graph=True would be necessary with or without Amp. That's unclear and maybe I should either change the example snippet so retain_graph=True is not needed, or add a comment clarifying that retain_graph=True there is not Amp-related.

return

if self.trainer.use_native_amp:
# don't forget to retain graph on backward with multiple optimizers

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mergify
Copy link
Contributor

mergify bot commented Apr 22, 2020

This pull request is now in conflict... :(

@williamFalcon
Copy link
Contributor Author

@Borda these tests are failing bc amp is not installed... did we remove amp?

@@ -281,6 +281,10 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
if on_gpu:
model.cuda(self.root_gpu)

# restore amp scaling
if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mcarilli sanity check this loading?

Copy link

@mcarilli mcarilli Apr 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good if you fix the saving https://github.com/PyTorchLightning/pytorch-lightning/pull/1561/files#r413418705

Like saving, loading should occur either at the very beginning of an iteration (before any training-related scaler calls for that iteration) or at the end of an iteration, after scaler.update(). It doesn't make a lot of sense to load state dicts at the end of an iteration, but if the saved state originated from a scaler.state_dict() call at the end of, say, iteration 1000 (i.e. after iteration 1000's call to scaler.update()), then it's ok to call load_state_dict at the beginning of iteration 1001 to resume.

@@ -316,6 +320,10 @@ def dump_checkpoint(self):

checkpoint['state_dict'] = model.state_dict()

# restore native amp scaling
if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mcarilli sanity check this saving?

Copy link

@mcarilli mcarilli Apr 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

state_dict is a method, as for modules and optimizers, so checkpoint['native_amp_scaling_state'] = self.scaler.state_dict() is what you want.
checkpoint['native_amp_scaling_state'] = self.scaler.state_dict would stash the bound-method object itself :P

Copy link

@mcarilli mcarilli Apr 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also you should make sure state_dict() is retrieved either at the very beginning of an iteration (before any scaler method calls) or at the very end (after scaler.update()), and that the model and optimizer state dicts are saved at that same spot.

I can't tell from these lines alone if the calling code occurs at a spot that obeys those criteria.

Copy link
Contributor Author

@williamFalcon williamFalcon Apr 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i thought it was a property haha, but i guess it's consistent with the other state_dict() calls haha

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol i see. it's consistent with the rest

Copy link

@mcarilli mcarilli Apr 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another thing to consider is that with torch.cuda.amp, it's permissible to

  • load a checkpoint from a model + optimizer not trained with Amp, and resume training with Amp enabled, or
  • load a checkpoint from a model + optimizer trained with Amp, and resume training without Amp.

I think your if criteria are flexible enough that both those cases can happen naturally with the appropriate user args but I'm not sure just from looking at it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this code works.

Case 1: Train with amp, load amp

works fine

case 2: Train amp, load and not use amp

in this case, lightning loads the amp state but amp is disabled so user doesn't use it at all

case 3: train regular, resume regular

works fine

case 4: train regular, resume with amp

in this case the checkpoint has no amp state and model starts normal but on amp.

@@ -316,6 +320,10 @@ def dump_checkpoint(self):

checkpoint['state_dict'] = model.state_dict()

# restore native amp scaling
if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
checkpoint['native_amp_scaling_state'] = self.scaler.state_dict

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()

@Borda
Copy link
Member

Borda commented Apr 23, 2020

@Borda these tests are failing bc amp is not installed... did we remove amp?

probably, unfortunately, it happened here with Horovoed #1529 (comment)
APEX was removed in 9257b37

@mergify
Copy link
Contributor

mergify bot commented Apr 23, 2020

This pull request is now in conflict... :(

.drone.yml Outdated Show resolved Hide resolved
@mergify mergify bot requested a review from a team April 23, 2020 18:14
@codecov
Copy link

codecov bot commented Apr 23, 2020

Codecov Report

Merging #1561 into master will decrease coverage by 0%.
The diff coverage is 55%.

@@          Coverage Diff           @@
##           master   #1561   +/-   ##
======================================
- Coverage      89%     88%   -0%     
======================================
  Files          68      68           
  Lines        3913    3955   +42     
======================================
+ Hits         3473    3496   +23     
- Misses        440     459   +19     

@williamFalcon williamFalcon merged commit 29ebe92 into master Apr 23, 2020
@Borda Borda deleted the apex branch April 23, 2020 20:41
kepler added a commit to kepler/pytorch-lightning that referenced this pull request May 11, 2020
@kepler kepler mentioned this pull request May 11, 2020
5 tasks
williamFalcon pushed a commit that referenced this pull request May 12, 2020
@Borda Borda modified the milestones: 0.7.4, v0.7.x Apr 18, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
feature Is an improvement or enhancement priority: 0 High priority task
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Native Amp Support Native Amp Support
4 participants