-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix LARC with mixed precision (#793)
The LARC optimizer wraps an underlying optimizer and then needs to be passed to amp.initialize for mixed precision. There were 3 different crashes happening in this situation, fix all of them and add a unit test. I don't know if the 'LARC' in sys.modules check ever worked. In my setup, the entry in sys.modules is 'apex.parallel.LARC'. Checking if the variable is defined seems more reliable though.
- Loading branch information
Showing
4 changed files
with
59 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import unittest | ||
|
||
import torch | ||
from torch import nn | ||
from torch.nn import Parameter | ||
|
||
from apex import amp | ||
from apex.parallel.LARC import LARC | ||
from utils import common_init | ||
|
||
|
||
class MyModel(torch.nn.Module): | ||
def __init__(self, unique): | ||
super(MyModel, self).__init__() | ||
self.weight0 = Parameter( | ||
unique + torch.arange(2, device="cuda", dtype=torch.float32) | ||
) | ||
|
||
def forward(self, input): | ||
return (input * self.weight0).sum() | ||
|
||
|
||
class TestLARC(unittest.TestCase): | ||
def setUp(self): | ||
self.x = torch.ones((2), device="cuda", dtype=torch.float32) | ||
common_init(self) | ||
|
||
def tearDown(self): | ||
pass | ||
|
||
def test_larc_mixed_precision(self): | ||
for opt_level in ["O0", "O1", "O2", "O3"]: | ||
model = MyModel(1) | ||
|
||
optimizer = LARC( | ||
torch.optim.SGD( | ||
[{"params": model.parameters(), "lr": 0.25}], momentum=0.125 | ||
) | ||
) | ||
|
||
model, optimizer = amp.initialize( | ||
model, optimizer, opt_level=opt_level, verbosity=0 | ||
) | ||
|
||
optimizer.zero_grad() | ||
loss = model(self.x) | ||
with amp.scale_loss(loss, optimizer) as scaled_loss: | ||
scaled_loss.backward() | ||
optimizer.step() | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |