support for native amp #1561

Merged (17 commits, Apr 23, 2020)
2 changes: 1 addition & 1 deletion .drone.yml
@@ -31,7 +31,7 @@ steps:
- pip install pip -U
- pip --version
- nvidia-smi
# - bash ./tests/install_AMP.sh
- bash ./tests/install_AMP.sh
- apt-get update && apt-get install -y cmake
- pip install -r requirements.txt --user -q
- pip install -r ./tests/requirements-devel.txt --user -q
10 changes: 8 additions & 2 deletions pytorch_lightning/core/hooks.py
@@ -140,9 +140,15 @@ def backward(self, use_amp, loss, optimizer):

"""
if trainer.precision == 16:

# .backward is not special on 16-bit with TPUs
if not trainer.on_tpu:
if trainer.on_tpu:
return

if self.trainer.use_native_amp:
self.trainer.scaler.scale(loss).backward()

# TODO: remove in v0.8.0
else:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
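For reference, the native branch above reduces to the standard torch.cuda.amp backward pattern. A minimal sketch outside of Lightning (the helper name and the fp32 fallback are illustrative, not part of this PR; a GradScaler is assumed to exist when native amp is in use):

```python
import torch


def scaled_backward(loss: torch.Tensor, scaler: "torch.cuda.amp.GradScaler" = None) -> None:
    if scaler is not None:
        # native amp: scale the loss so small fp16 gradients don't underflow
        scaler.scale(loss).backward()
    else:
        # plain fp32 (the apex path would use amp.scale_loss(loss, optimizer) instead)
        loss.backward()
```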
15 changes: 14 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -1157,9 +1157,22 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
if self.trainer.use_tpu and XLA_AVAILABLE:
xm.optimizer_step(optimizer)
elif isinstance(optimizer, torch.optim.LBFGS):

# native amp + lbfgs is a no go right now
if self.use_amp and self.use_native_amp:
m = 'native PyTorch amp and lbfgs are not compatible. To request, please file ' \
'a Github issue in PyTorch and tag @mcarilli'
raise MisconfigurationException(m)
optimizer.step(second_order_closure)
else:
optimizer.step()
if self.use_amp and self.use_native_amp:
self.trainer.scaler.step(optimizer)
else:
optimizer.step()

# in native 16-bit we need to update scaler after optimizer step
if self.use_amp and self.use_native_amp:
self.trainer.scaler.update()

# model hook
self.on_before_zero_grad(optimizer)
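The new optimizer path follows the documented GradScaler ordering: scaler.step() before scaler.update(), with closure-based LBFGS rejected because the native scaler does not support it at this point. A hedged, standalone sketch of that ordering (names are illustrative; `scaler` is assumed to be the trainer's torch.cuda.amp.GradScaler):

```python
import torch


def native_amp_optimizer_step(optimizer: torch.optim.Optimizer,
                              scaler: "torch.cuda.amp.GradScaler") -> None:
    # step() unscales gradients (if not already unscaled) and skips optimizer.step() on inf/NaN grads
    scaler.step(optimizer)
    # update() must run after step(): it adjusts the loss scale for the next iteration
    scaler.update()
    optimizer.zero_grad()
```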
24 changes: 23 additions & 1 deletion pytorch_lightning/trainer/auto_mix_precision.py
@@ -1,6 +1,8 @@
from abc import ABC
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn

try:
from apex import amp
@@ -15,8 +17,28 @@ class TrainerAMPMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
precision: int
use_native_amp: bool

def init_amp(self, use_amp):
# TODO: remove in v0.8.0
if self.use_native_amp:
rank_zero_warn("`amp_level` has been deprecated since v0.7.4 "
"(native amp does not require it)"
" and this argument will be removed in v0.8.0", DeprecationWarning)

# Backward compatibility, TODO: remove in v0.9.0
if use_amp is not None:
rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0"
" and this argument will be removed in v0.9.0", DeprecationWarning)
self.precision = 16 if use_amp else 32

assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

if use_amp and self.use_native_amp:
log.info('Using 16bit precision.')
return

# TODO: remove all below for v0.8.0
if use_amp and not APEX_AVAILABLE: # pragma: no-cover
raise ModuleNotFoundError("""
You set `use_amp=True` but do not have apex installed.
@@ -31,4 +53,4 @@ def init_amp(self, use_amp):

@property
def use_amp(self) -> bool:
return self.precision == 16 and APEX_AVAILABLE
return self.precision == 16
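Since `use_amp` now means only `precision == 16`, the choice between the native backend and apex comes down to whether torch.cuda.amp exists in the installed PyTorch. A rough sketch of that capability check (the helper below is illustrative, not part of the PR):

```python
import torch

NATIVE_AMP_AVAILABLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")


def pick_amp_backend(precision: int) -> str:
    if precision != 16:
        return "fp32"
    return "native" if NATIVE_AMP_AVAILABLE else "apex"
```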
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -151,6 +151,7 @@ class TrainerDDPMixin(ABC):
amp_level: str
use_tpu: bool
default_root_dir: str
use_native_amp: bool

@property
@abstractmethod
@@ -350,8 +351,8 @@ def ddp_train(self, process_idx, model):

# AMP
# run through amp wrapper before going to distributed DP
if self.use_amp:
# An example
# TODO: remove in v0.8.0
if self.use_amp and not self.use_native_amp:
model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
self.optimizers = optimizers

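On the apex path the model is still routed through LightningModule.configure_apex. A hedged reconstruction of what that hook does by default (the shipped source may differ slightly, but it is essentially a thin wrapper around amp.initialize):

```python
import pytorch_lightning as pl


class MyModule(pl.LightningModule):
    def configure_apex(self, amp, model, optimizers, amp_level):
        # delegate to apex; returns the patched model and optimizers
        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
        return model, optimizers
```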
15 changes: 13 additions & 2 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -394,6 +394,7 @@ class TrainerDPMixin(ABC):
tpu_local_core_rank: int
tpu_global_core_rank: int
use_tpu: bool
use_native_amp: bool
data_parallel_device_ids: ...
logger: Union[LightningLoggerBase, bool]

@@ -481,7 +482,8 @@ def single_gpu_train(self, model):
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

if self.use_amp:
# TODO: update for 0.8.0
if self.use_amp and not self.use_native_amp:
# An example
model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
self.optimizers = optimizers
@@ -528,9 +530,16 @@ def dp_train(self, model):

model.cuda(self.root_gpu)

# hack forward to do autocast for the user
model_autocast_original_forward = model.forward
if self.use_amp and self.use_native_amp:
# wrap the user's forward in autocast and give it back at the end
model.forward = torch.cuda.amp.autocast()(model.forward)

# TODO: remove in v0.8.0
# check for this bug (amp + dp + !01 doesn't work)
# https://github.com/NVIDIA/apex/issues/227
if self.use_dp and self.use_amp:
if self.use_dp and self.use_amp and not self.use_native_amp:
if self.amp_level == 'O2':
raise MisconfigurationException(
f'Amp level {self.amp_level} with DataParallel is not supported.'
Expand All @@ -551,6 +560,8 @@ def dp_train(self, model):

self.run_pretrain_routine(model)

model.forward = model_autocast_original_forward

def horovod_train(self, model):
# Horovod: initialize library
hvd.init()
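The dp_train change relies on a torch.cuda.amp.autocast() instance being usable as a decorator, so the user's forward can be wrapped once and handed back after fit. A standalone sketch of that monkey-patch, assuming a CUDA device is available (the model and shapes are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2).cuda()

original_forward = model.forward
model.forward = torch.cuda.amp.autocast()(model.forward)  # every forward now runs under autocast

out = model(torch.randn(8, 4, device="cuda"))
print(out.dtype)                # typically torch.float16 for a Linear under autocast

model.forward = original_forward  # restore the original forward at the end of fit
```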
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
@@ -268,7 +268,11 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
# -----------------
# RUN EVALUATION STEP
# -----------------
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
if self.use_amp and self.use_native_amp:
with torch.cuda.amp.autocast():
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
else:
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

# on dp / ddp2 might still want to do something with the batch parts
if test_mode:
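Evaluation takes the same approach: when native 16-bit is active, the forward pass simply runs inside an autocast context. A minimal sketch of that idea, not the Lightning code itself:

```python
import torch


def evaluation_forward(model: torch.nn.Module, batch: torch.Tensor, use_native_amp: bool):
    with torch.no_grad():
        if use_native_amp:
            with torch.cuda.amp.autocast():
                return model(batch)
        return model(batch)
```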
24 changes: 11 additions & 13 deletions pytorch_lightning/trainer/trainer.py
@@ -115,7 +115,6 @@ def __init__(
print_nan_grads: bool = False, # backward compatible, todo: remove in v0.9.0
weights_summary: Optional[str] = 'full',
weights_save_path: Optional[str] = None,
amp_level: str = 'O1',
num_sanity_val_steps: int = 5,
truncated_bptt_steps: Optional[int] = None,
resume_from_checkpoint: Optional[str] = None,
Expand All @@ -124,6 +123,7 @@ def __init__(
reload_dataloaders_every_epoch: bool = False,
auto_lr_find: Union[bool, str] = False,
replace_sampler_ddp: bool = True,
amp_level: str = 'O1', # backward compatible, todo: remove in v0.8.0
default_save_path=None, # backward compatible, todo: remove in v0.8.0
gradient_clip=None, # backward compatible, todo: remove in v0.8.0
nb_gpu_nodes=None, # backward compatible, todo: remove in v0.8.0
@@ -487,20 +487,18 @@ def __init__(
self.determine_data_use_amount(train_percent_check, val_percent_check,
test_percent_check, overfit_pct)

# 16 bit mixed precision training using apex
# AMP init
# These are the only lines needed after v0.8.0
# we wrap the user's forward with autocast and give it back at the end of fit
self.autocast_original_forward = None
self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
if self.use_native_amp and self.precision == 16:
self.scaler = torch.cuda.amp.GradScaler()
self.precision = precision

# TODO: remove for v0.8.0
self.amp_level = amp_level
self.precision = precision

# Backward compatibility, TODO: remove in v0.9.0
if use_amp is not None:
rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0"
" and this argument will be removed in v0.9.0", DeprecationWarning)
self.precision = 16 if use_amp else 32

assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

if self.precision == 16 and self.num_tpu_cores is None:
use_amp = True
self.init_amp(use_amp)

# Callback system
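From the user's side, precision=16 remains the only switch: the trainer picks the native backend when torch.cuda.amp is available and otherwise falls back to apex. A hedged usage sketch (MyModule stands in for any user-defined LightningModule and is assumed to be defined elsewhere):

```python
from pytorch_lightning import Trainer

model = MyModule()                       # illustrative user module
trainer = Trainer(gpus=1, precision=16)  # native amp if torch.cuda.amp is available
# trainer = Trainer(gpus=1, precision=16, amp_level='O2')  # apex path, deprecated by this PR
trainer.fit(model)
```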
12 changes: 12 additions & 0 deletions pytorch_lightning/trainer/training_io.py
@@ -281,6 +281,10 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
if on_gpu:
model.cuda(self.root_gpu)

# restore amp scaling
if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])

Review thread:
williamFalcon (Contributor, Author): @mcarilli sanity check this loading?
mcarilli (Apr 23, 2020): Looks good if you fix the saving (https://github.com/PyTorchLightning/pytorch-lightning/pull/1561/files#r413418705). Like saving, loading should occur either at the very beginning of an iteration (before any training-related scaler calls for that iteration) or at the end of an iteration, after scaler.update(). It doesn't make a lot of sense to load state dicts at the end of an iteration, but if the saved state originated from a scaler.state_dict() call at the end of, say, iteration 1000 (i.e. after iteration 1000's call to scaler.update()), then it's ok to call load_state_dict at the beginning of iteration 1001 to resume.

# load training state (affects trainer only)
self.restore_training_state(checkpoint)

@@ -316,6 +320,10 @@ def dump_checkpoint(self):

checkpoint['state_dict'] = model.state_dict()

# save native amp scaling
if self.use_amp and self.use_native_amp:
checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()

Review thread:
williamFalcon (Contributor, Author): @mcarilli sanity check this saving?
mcarilli (Apr 23, 2020): state_dict is a method, as for modules and optimizers, so checkpoint['native_amp_scaling_state'] = self.scaler.state_dict() is what you want. checkpoint['native_amp_scaling_state'] = self.scaler.state_dict would stash the bound-method object itself :P
mcarilli (Apr 23, 2020): Also, you should make sure state_dict() is retrieved either at the very beginning of an iteration (before any scaler method calls) or at the very end (after scaler.update()), and that the model and optimizer state dicts are saved at that same spot. I can't tell from these lines alone whether the calling code occurs at a spot that obeys those criteria.
williamFalcon (Contributor, Author, Apr 23, 2020): i thought it was a property haha, but i guess it's consistent with the other state_dict() calls haha
Follow-up reply: lol i see. it's consistent with the rest
mcarilli (Apr 23, 2020): Another thing to consider is that with torch.cuda.amp, it's permissible to
  • load a checkpoint from a model + optimizer not trained with Amp, and resume training with Amp enabled, or
  • load a checkpoint from a model + optimizer trained with Amp, and resume training without Amp.
I think your if criteria are flexible enough that both of those cases can happen naturally with the appropriate user args, but I'm not sure just from looking at it.
williamFalcon (Contributor, Author): yeah, this code works.
  Case 1 (train with amp, load amp): works fine.
  Case 2 (train with amp, load without amp): lightning loads the amp state, but amp is disabled, so the user doesn't use it at all.
  Case 3 (train regular, resume regular): works fine.
  Case 4 (train regular, resume with amp): the checkpoint has no amp state and the model starts normally, but on amp.

if hasattr(model, "hparams"):
is_namespace = isinstance(model.hparams, Namespace)
checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams
@@ -441,6 +449,10 @@ def hpc_load(self, folderpath, on_gpu):
# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])

# restore amp scaling
if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])

if self.root_gpu is not None:
model.cuda(self.root_gpu)

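Tying the review thread together, the contract is: capture scaler.state_dict() after scaler.update() at the end of an iteration, store it alongside the model and optimizer states, and tolerate checkpoints written without amp when resuming. A standalone sketch under those assumptions (function names and keys other than 'native_amp_scaling_state' are illustrative):

```python
import torch


def save_checkpoint(path, model, optimizer, scaler):
    torch.save({
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'native_amp_scaling_state': scaler.state_dict(),  # note: state_dict() is a method
    }, path)


def load_checkpoint(path, model, optimizer, scaler):
    ckpt = torch.load(path)
    model.load_state_dict(ckpt['state_dict'])
    optimizer.load_state_dict(ckpt['optimizer'])
    # case 4 from the thread above: a non-amp checkpoint resumed with amp has no scaler state
    if 'native_amp_scaling_state' in ckpt:
        scaler.load_state_dict(ckpt['native_amp_scaling_state'])
```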
11 changes: 9 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -148,6 +148,7 @@ def training_step(self, batch, batch_idx):

import numpy as np
from torch.utils.data import DataLoader
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
@@ -588,8 +589,12 @@ def run_training_batch(self, batch, batch_idx):
def optimizer_closure():
# forward pass
with self.profiler.profile('model_forward'):
output_dict = self.training_forward(
split_batch, batch_idx, opt_idx, self.hiddens)
if self.use_amp and self.use_native_amp:
with torch.cuda.amp.autocast():
output_dict = self.training_forward(split_batch, batch_idx,
opt_idx, self.hiddens)
else:
output_dict = self.training_forward(split_batch, batch_idx, opt_idx, self.hiddens)

# format and reduce outputs accordingly
processed_output = self.process_output(output_dict, train=True)
@@ -645,6 +650,8 @@ def optimizer_closure():
self.track_grad_norm)

# clip gradients
if self.use_amp and self.use_native_amp:
self.scaler.unscale_(optimizer)
self.clip_gradients()

# calls .step(), .zero_grad()
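Putting the training-loop pieces together, the per-iteration order implied by these diffs is: autocast forward, scaled backward, unscale, clip, scaler.step, scaler.update. An end-to-end sketch of that order, assuming a CUDA device (the model, data, and max_norm value are illustrative):

```python
import torch

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()


def training_iteration(batch: torch.Tensor, target: torch.Tensor) -> None:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():                      # forward under autocast
        loss = torch.nn.functional.mse_loss(model(batch), target)
    scaler.scale(loss).backward()                        # scaled backward
    scaler.unscale_(optimizer)                           # unscale before gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)                               # skipped internally on inf/NaN grads
    scaler.update()                                      # adjust the scale for the next step
```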
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/training_tricks.py
@@ -24,6 +24,7 @@ def get_model(self):
"""Warning: this is just empty shell for code implemented in other class."""

def clip_gradients(self):

# this code is a modification of torch.nn.utils.clip_grad_norm_
# with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
if self.gradient_clip_val > 0: