New apex amp API (#1465)
* use new apex amp API
* make apex opt_level an option
francoishernandez authored and vince62s committed Jun 13, 2019
1 parent e156cce commit aaa220b
Showing 5 changed files with 18 additions and 25 deletions.
2 changes: 0 additions & 2 deletions onmt/model_builder.py
@@ -214,8 +214,6 @@ def fix_key(s):

     model.generator = generator
     model.to(device)
-    if model_opt.model_dtype == 'fp16':
-        model.half()

     return model

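The deleted cast reflects the switch to the amp API: apex.amp.initialize, now called when the optimizer is built (see onmt/utils/optimizers.py below), casts the model according to the selected opt_level, so the model builder no longer halves the weights itself. A minimal sketch of that pattern, assuming NVIDIA apex is installed and a CUDA device is available (illustrative, not OpenNMT-py code):

import torch
import apex

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.Adam(model.parameters())

# Instead of calling model.half() by hand, let amp pick the precision:
# opt_level "O2" runs FP16 compute while the optimizer keeps FP32 master weights.
model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O2")
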
4 changes: 4 additions & 0 deletions onmt/opts.py
@@ -179,6 +179,10 @@ def model_opts(parser):
     group.add('--loss_scale', '-loss_scale', type=float, default=0,
               help="For FP16 training, the static loss scale to use. If not "
                    "set, the loss scale is dynamically computed.")
+    group.add('--apex_opt_level', '-apex_opt_level', type=str, default="O2",
+              choices=["O0", "O1", "O2", "O3"],
+              help="For FP16 training, the opt_level to use. "
+                   "See https://nvidia.github.io/apex/amp.html#opt-levels.")


 def preprocess_opts(parser):
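For reference, the four selectable values correspond to apex's predefined mixed-precision recipes. A descriptive sketch of what each choice means; this paraphrases the apex documentation linked in the help text and is not code from the repository:

# Rough meaning of the choices exposed by --apex_opt_level.
APEX_OPT_LEVELS = {
    "O0": "pure FP32 baseline, mixed precision disabled",
    "O1": "patch framework functions and cast per-op (conservative mixed precision)",
    "O2": "cast the model to FP16, keep FP32 batchnorm and FP32 master weights",
    "O3": "pure FP16, fastest but least numerically stable",
}
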
3 changes: 1 addition & 2 deletions onmt/trainer.py
@@ -294,8 +294,7 @@ def validate(self, valid_iter, moving_average=None):
             valid_model = deepcopy(self.model)
             for avg, param in zip(self.moving_average,
                                   valid_model.parameters()):
-                param.data = avg.data.half() if self.model_dtype == "fp16" \
-                    else avg.data
+                param.data = avg.data
         else:
             valid_model = self.model

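The validation path drops its dtype special case as well, presumably because amp now owns the precision handling, so the moving averages can be assigned into the copied model directly. A condensed view of the resulting logic (illustrative helper, not the trainer's actual method):

from copy import deepcopy

def build_valid_model(model, moving_average):
    # No explicit .half() cast of the averaged tensors is needed anymore;
    # the averages are simply copied into a deep copy of the training model.
    valid_model = deepcopy(model)
    for avg, param in zip(moving_average, valid_model.parameters()):
        param.data = avg.data
    return valid_model
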
30 changes: 13 additions & 17 deletions onmt/utils/optimizers.py
@@ -7,8 +7,6 @@
 from copy import copy
 from math import sqrt

-from onmt.utils.misc import fn_args
-

 def build_torch_optimizer(model, opt):
     """Builds the PyTorch optimizer.
@@ -87,17 +85,14 @@ def build_torch_optimizer(model, opt):

     if opt.model_dtype == 'fp16':
         import apex
-        static_loss_scale = opt.loss_scale
-        dynamic_loss_scale = opt.loss_scale == 0
-        # TODO: clean this up when APEX unify its optimizer API.
-        if opt.optim.startswith('fused'):
-            namespace = apex.optimizers  # Faster wrapper.
-        else:
-            namespace = apex.fp16_utils
-        optimizer = namespace.FP16_Optimizer(
+        loss_scale = "dynamic" if opt.loss_scale == 0 else opt.loss_scale
+        model, optimizer = apex.amp.initialize(
+            [model, model.generator],
             optimizer,
-            static_loss_scale=static_loss_scale,
-            dynamic_loss_scale=dynamic_loss_scale)
+            opt_level=opt.apex_opt_level,
+            loss_scale=loss_scale,
+            keep_batchnorm_fp32=False if opt.optim == "fusedadam" else None)

     return optimizer

@@ -317,10 +312,9 @@ def backward(self, loss):
         """Wrapper for backward pass. Some optimizer requires ownership of the
         backward pass."""
         if self._with_fp16_wrapper:
-            kwargs = {}
-            if "update_master_grads" in fn_args(self._optimizer.backward):
-                kwargs["update_master_grads"] = True
-            self._optimizer.backward(loss, **kwargs)
+            import apex
+            with apex.amp.scale_loss(loss, self._optimizer) as scaled_loss:
+                scaled_loss.backward()
         else:
             loss.backward()

@@ -336,7 +330,9 @@ def step(self):
             self._optimizer.update_master_grads()
         if hasattr(self._optimizer, "clip_master_grads") and \
                 self._max_grad_norm > 0:
-            self._optimizer.clip_master_grads(self._max_grad_norm)
+            import apex
+            torch.nn.utils.clip_grad_norm_(
+                apex.amp.master_params(self), self._max_grad_norm)
         for group in self._optimizer.param_groups:
             group['lr'] = learning_rate
         if not self._with_fp16_wrapper and self._max_grad_norm > 0:
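Read together, the optimizer changes replace apex's old FP16_Optimizer wrapper with the amp workflow: initialize model and optimizer once, run the backward pass through amp's loss scaler, and clip the FP32 master gradients before stepping. A condensed sketch of the resulting training step, assuming apex is installed; the function and argument names are illustrative, not OpenNMT-py's:

import torch
import apex

def fp16_train_step(optimizer, loss, max_grad_norm=5.0):
    # Scale the loss before backward so small FP16 gradients do not underflow;
    # amp unscales them into the FP32 master gradients afterwards.
    with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    # Clip the master (FP32) gradients rather than the FP16 copies.
    if max_grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(
            apex.amp.master_params(optimizer), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()
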
4 changes: 0 additions & 4 deletions onmt/utils/parse.py
@@ -63,10 +63,6 @@ def validate_model_opts(cls, model_opt):
             if model_opt.model_type != "text":
                 raise AssertionError(
                     "--share_embeddings requires --model_type text.")
-        if model_opt.model_dtype == "fp16":
-            logger.warning(
-                "FP16 is experimental, the generated checkpoints may "
-                "be incompatible with a future version")

     @classmethod
     def ckpt_model_opts(cls, ckpt_opt):
