Fix LARC with mixed precision (#793)

The LARC optimizer wraps an underlying optimizer and then needs to be passed
to amp.initialize for mixed precision. There were three different crashes in
this situation; this change fixes all of them and adds a unit test.

I don't know whether the 'LARC' in sys.modules check ever worked. In my setup, the
entry in sys.modules is 'apex.parallel.LARC'. Checking whether the name is defined
seems more reliable.
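
A minimal standalone illustration of the naming issue described above (assumes apex, and therefore torch, is installed; this snippet is not part of the commit):

import sys

# A from-import registers the *module* under its dotted path in sys.modules,
# so a bare 'LARC' key never appears there.
from apex.parallel.LARC import LARC

print('LARC' in sys.modules)                # False
print('apex.parallel.LARC' in sys.modules)  # True

# The class name is, however, bound in this module's namespace,
# which is what the globals() check relies on.
print('LARC' in globals())                  # True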
vreis authored Apr 22, 2020
1 parent 55716d8 commit 2ec84eb
Showing 4 changed files with 59 additions and 2 deletions.
2 changes: 1 addition & 1 deletion apex/amp/_initialize.py
@@ -146,7 +146,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     from .amp import init as amp_init
 
     optimizers_was_list = False
-    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in sys.modules and isinstance(optimizers, LARC)):
+    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):
         optimizers = [optimizers]
     elif optimizers is None:
         optimizers = []
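For context, a standalone sketch of the optional-import guard this change relies on (as_optimizer_list is a hypothetical helper, not apex code; only torch is required, apex is optional):

import torch

# 'LARC' only appears in this module's globals() when the optional import
# succeeds, and short-circuit evaluation keeps isinstance() from ever
# evaluating an undefined name.
try:
    from apex.parallel.LARC import LARC
except ImportError:
    pass

def as_optimizer_list(optimizers):
    if isinstance(optimizers, torch.optim.Optimizer) or (
        "LARC" in globals() and isinstance(optimizers, LARC)
    ):
        return [optimizers]
    return optimizers if optimizers is not None else []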
2 changes: 1 addition & 1 deletion apex/amp/handle.py
@@ -87,7 +87,7 @@ def scale_loss(loss,
         yield loss
         return
 
-    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in sys.modules and isinstance(optimizers, LARC)):
+    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):
         optimizers = [optimizers]
 
     loss_scaler = _amp_state.loss_scalers[loss_id]
4 changes: 4 additions & 0 deletions apex/parallel/LARC.py
@@ -48,6 +48,10 @@ def __getstate__(self):
     def __setstate__(self, state):
         self.optim.__setstate__(state)
 
+    @property
+    def state(self):
+        return self.optim.state
+
     def __repr__(self):
         return self.optim.__repr__()

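The new property forwards the inner optimizer's per-parameter state, which amp expects to find on whatever optimizer it is given; LARC wraps torch.optim.Optimizer rather than subclassing it, so the attribute has to be delegated explicitly. A standalone sketch of the same delegation pattern (WrappedOptimizer is illustrative, not apex code):

import torch

class WrappedOptimizer:
    def __init__(self, optim):
        self.optim = optim

    @property
    def state(self):
        # Without this property, code that reads wrapper.state would raise
        # AttributeError, because the wrapper is not itself a torch optimizer.
        return self.optim.state

    def step(self):
        self.optim.step()

    def zero_grad(self):
        self.optim.zero_grad()


param = torch.nn.Parameter(torch.zeros(2))
opt = WrappedOptimizer(torch.optim.SGD([param], lr=0.1, momentum=0.9))
print(opt.state)  # the SGD instance's per-parameter state dict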
53 changes: 53 additions & 0 deletions tests/L0/run_amp/test_larc.py
@@ -0,0 +1,53 @@
import unittest

import torch
from torch import nn
from torch.nn import Parameter

from apex import amp
from apex.parallel.LARC import LARC
from utils import common_init


class MyModel(torch.nn.Module):
    def __init__(self, unique):
        super(MyModel, self).__init__()
        self.weight0 = Parameter(
            unique + torch.arange(2, device="cuda", dtype=torch.float32)
        )

    def forward(self, input):
        return (input * self.weight0).sum()


class TestLARC(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones((2), device="cuda", dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass

    def test_larc_mixed_precision(self):
        for opt_level in ["O0", "O1", "O2", "O3"]:
            model = MyModel(1)

            optimizer = LARC(
                torch.optim.SGD(
                    [{"params": model.parameters(), "lr": 0.25}], momentum=0.125
                )
            )

            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, verbosity=0
            )

            optimizer.zero_grad()
            loss = model(self.x)
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()


if __name__ == "__main__":
    unittest.main()
