diff --git a/.drone.yml b/.drone.yml
index 23f520c0b77e3..407ebd066cf9b 100644
--- a/.drone.yml
+++ b/.drone.yml
@@ -31,7 +31,7 @@ steps:
   - pip install pip -U
   - pip --version
   - nvidia-smi
-  # - bash ./tests/install_AMP.sh
+  #- bash ./tests/install_AMP.sh
   - apt-get update && apt-get install -y cmake
   - pip install -r requirements.txt --user -q
   - pip install -r ./tests/requirements-devel.txt --user -q
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 67b67ec1e8ee6..1a3f05be11c50 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -140,9 +140,15 @@ def backward(self, use_amp, loss, optimizer):
         """
         if trainer.precision == 16:
-            # .backward is not special on 16-bit with TPUs
-            if not trainer.on_tpu:
+            if trainer.on_tpu:
+                return
+
+            if self.trainer.use_native_amp:
+                self.trainer.scaler.scale(loss).backward()
+
+            # TODO: remove in v0.8.0
+            else:
                 with amp.scale_loss(loss, optimizer) as scaled_loss:
                     scaled_loss.backward()
         else:
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index bde43a6a0f8f6..ace22fd75cb0e 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1157,9 +1157,22 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
         if self.trainer.use_tpu and XLA_AVAILABLE:
             xm.optimizer_step(optimizer)
         elif isinstance(optimizer, torch.optim.LBFGS):
+
+            # native amp + lbfgs is a no go right now
+            if self.use_amp and self.use_native_amp:
+                m = 'native PyTorch amp and lbfgs are not compatible. To request, please file' \
+                    'a Github issue in PyTorch and tag @mcarilli'
+                raise MisconfigurationException(m)
             optimizer.step(second_order_closure)
         else:
-            optimizer.step()
+            if self.use_amp and self.use_native_amp:
+                self.trainer.scaler.step(optimizer)
+            else:
+                optimizer.step()
+
+            # in native 16-bit we need to update scaler after optimizer step
+            if self.use_amp and self.use_native_amp:
+                self.trainer.scaler.update()

         # model hook
         self.on_before_zero_grad(optimizer)
diff --git a/pytorch_lightning/trainer/auto_mix_precision.py b/pytorch_lightning/trainer/auto_mix_precision.py
index 135cf83e288c8..2551b8a22dd0f 100644
--- a/pytorch_lightning/trainer/auto_mix_precision.py
+++ b/pytorch_lightning/trainer/auto_mix_precision.py
@@ -1,6 +1,8 @@
 from abc import ABC

+import torch
 from pytorch_lightning import _logger as log
+from pytorch_lightning.utilities import rank_zero_warn

 try:
     from apex import amp
@@ -15,8 +17,28 @@ class TrainerAMPMixin(ABC):
     # this is just a summary on variables used in this abstract class,
     #  the proper values/initialisation should be done in child class
     precision: int
+    use_native_amp: bool

     def init_amp(self, use_amp):
+        # TODO: remove in v 0.8.0
+        if self.use_native_amp:
+            rank_zero_warn("`amp_level` has been deprecated since v0.7.4 "
+                           "(native amp does not require it)"
+                           " and this argument will be removed in v0.8.0", DeprecationWarning)
+
+        # Backward compatibility, TODO: remove in v0.9.0
+        if use_amp is not None:
+            rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0"
+                           " and this argument will be removed in v0.9.0", DeprecationWarning)
+            self.precision = 16 if use_amp else 32
+
+        assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
+
+        if use_amp and self.use_native_amp:
+            log.info('Using 16bit precision.')
+            return
+
+        # TODO: remove all below for v0.8.0
         if use_amp and not APEX_AVAILABLE:  # pragma: no-cover
             raise ModuleNotFoundError("""
             You set `use_amp=True` but do not have apex installed.
@@ -31,4 +53,4 @@ def init_amp(self, use_amp):

     @property
     def use_amp(self) -> bool:
-        return self.precision == 16 and APEX_AVAILABLE
+        return self.precision == 16
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index bfc85ee883f6e..736af5cad928f 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -151,6 +151,7 @@ class TrainerDDPMixin(ABC):
     amp_level: str
     use_tpu: bool
     default_root_dir: str
+    use_native_amp: bool

     @property
     @abstractmethod
@@ -350,8 +351,8 @@ def ddp_train(self, process_idx, model):

         # AMP
         # run through amp wrapper before going to distributed DP
-        if self.use_amp:
-            # An example
+        # TODO: remove in v0.8.0
+        if self.use_amp and not self.use_native_amp:
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index 7ce61bbfb77e6..7b79922d82a00 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -394,6 +394,7 @@ class TrainerDPMixin(ABC):
     tpu_local_core_rank: int
     tpu_global_core_rank: int
     use_tpu: bool
+    use_native_amp: bool
     data_parallel_device_ids: ...
     logger: Union[LightningLoggerBase, bool]
@@ -481,7 +482,8 @@ def single_gpu_train(self, model):
         # allow for lr schedulers as well
         self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

-        if self.use_amp:
+        # TODO: update for 0.8.0
+        if self.use_amp and not self.use_native_amp:
             # An example
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
@@ -528,9 +530,16 @@ def dp_train(self, model):

         model.cuda(self.root_gpu)

+        # hack forward to do autocast for the user
+        model_autocast_original_forward = model.forward
+        if self.use_amp and self.use_native_amp:
+            # wrap the user's forward in autocast and give it back at the end
+            model.forward = torch.cuda.amp.autocast()(model.forward)
+
+        # TODO: remove in v0.8.0
         # check for this bug (amp + dp + !01 doesn't work)
         # https://github.com/NVIDIA/apex/issues/227
-        if self.use_dp and self.use_amp:
+        if self.use_dp and self.use_amp and not self.use_native_amp:
             if self.amp_level == 'O2':
                 raise MisconfigurationException(
                     f'Amp level {self.amp_level} with DataParallel is not supported.'
@@ -551,6 +560,8 @@ def dp_train(self, model):

         self.run_pretrain_routine(model)

+        model.forward = model_autocast_original_forward
+
     def horovod_train(self, model):
         # Horovod: initialize library
         hvd.init()
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index a996bd7a60d70..534c156ee7c6e 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -268,7 +268,11 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
                 # -----------------
                 # RUN EVALUATION STEP
                 # -----------------
-                output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
+                if self.use_amp and self.use_native_amp:
+                    with torch.cuda.amp.autocast():
+                        output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
+                else:
+                    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

                 # on dp / ddp2 might still want to do something with the batch parts
                 if test_mode:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 20ef14ca3cb50..e92f3c092631e 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -115,7 +115,6 @@ def __init__(
             print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
             weights_summary: Optional[str] = 'full',
             weights_save_path: Optional[str] = None,
-            amp_level: str = 'O1',
             num_sanity_val_steps: int = 5,
             truncated_bptt_steps: Optional[int] = None,
             resume_from_checkpoint: Optional[str] = None,
@@ -124,6 +123,7 @@ def __init__(
             reload_dataloaders_every_epoch: bool = False,
             auto_lr_find: Union[bool, str] = False,
             replace_sampler_ddp: bool = True,
+            amp_level: str = 'O1',  # backward compatible, todo: remove in v0.8.0
             default_save_path=None,  # backward compatible, todo: remove in v0.8.0
             gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
             nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
@@ -487,20 +487,18 @@ def __init__(
         self.determine_data_use_amount(train_percent_check, val_percent_check,
                                        test_percent_check, overfit_pct)

-        # 16 bit mixed precision training using apex
+        # AMP init
+        # These are the only lines needed after v0.8.0
+        # we wrap the user's forward with autocast and give it back at the end of fit
+        self.autocast_original_forward = None
+        self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
+        if self.use_native_amp and self.precision == 16:
+            self.scaler = torch.cuda.amp.GradScaler()
+        self.precision = precision
+
+        # TODO: remove for v0.8.0
         self.amp_level = amp_level
         self.precision = precision
-
-        # Backward compatibility, TODO: remove in v0.9.0
-        if use_amp is not None:
-            rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0"
-                           " and this argument will be removed in v0.9.0", DeprecationWarning)
-            self.precision = 16 if use_amp else 32
-
-        assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
-
-        if self.precision == 16 and self.num_tpu_cores is None:
-            use_amp = True
         self.init_amp(use_amp)

         # Callback system
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 47448132df28a..0e9d00c6c4019 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -281,6 +281,10 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         if on_gpu:
             model.cuda(self.root_gpu)

+        # restore amp scaling
+        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+            self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+
         # load training state (affects trainer only)
         self.restore_training_state(checkpoint)
@@ -316,6 +320,10 @@ def dump_checkpoint(self):

         checkpoint['state_dict'] = model.state_dict()

+        # restore native amp scaling
+        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+            checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
+
         if hasattr(model, "hparams"):
             is_namespace = isinstance(model.hparams, Namespace)
             checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams
@@ -441,6 +449,10 @@ def hpc_load(self, folderpath, on_gpu):
         # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])

+        # restore amp scaling
+        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+            self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+
         if self.root_gpu is not None:
             model.cuda(self.root_gpu)
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 5b3d13c72b5f1..4b9e906d32f29 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -148,6 +148,7 @@ def training_step(self, batch, batch_idx):

 import numpy as np
 from torch.utils.data import DataLoader
+import torch

 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
@@ -588,8 +589,12 @@ def run_training_batch(self, batch, batch_idx):
                 def optimizer_closure():
                     # forward pass
                     with self.profiler.profile('model_forward'):
-                        output_dict = self.training_forward(
-                            split_batch, batch_idx, opt_idx, self.hiddens)
+                        if self.use_amp and self.use_native_amp:
+                            with torch.cuda.amp.autocast():
+                                output_dict = self.training_forward(split_batch, batch_idx,
+                                                                    opt_idx, self.hiddens)
+                        else:
+                            output_dict = self.training_forward(split_batch, batch_idx, opt_idx, self.hiddens)

                         # format and reduce outputs accordingly
                         processed_output = self.process_output(output_dict, train=True)
@@ -645,6 +650,8 @@ def optimizer_closure():
                                self.track_grad_norm)

                 # clip gradients
+                if self.use_amp and self.use_native_amp:
+                    self.scaler.unscale_(optimizer)
                 self.clip_gradients()

                 # calls .step(), .zero_grad()
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 3364c9d305455..0d86d53b7bbc4 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -24,6 +24,7 @@ def get_model(self):
         """Warning: this is just empty shell for code implemented in other class."""

     def clip_gradients(self):
+        # this code is a modification of torch.nn.utils.clip_grad_norm_
         # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
         if self.gradient_clip_val > 0:
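For reference, a minimal sketch (not part of the patch) of the plain-PyTorch native AMP pattern these changes wire into the Trainer: autocast around the forward pass, GradScaler for the backward/step/update cycle, and an unscale_ before gradient clipping. `model`, `optimizer`, and `train_loader` below are hypothetical placeholders, not names from the PR.

# illustrative sketch of torch.cuda.amp usage, assuming model/optimizer/train_loader exist
import torch

scaler = torch.cuda.amp.GradScaler()

for batch, target in train_loader:
    optimizer.zero_grad()

    # forward under autocast (mirrors the wrapped training_forward/evaluation_forward)
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(batch), target)

    # scaled backward (mirrors LightningModule.backward for native amp)
    scaler.scale(loss).backward()

    # unscale before clipping (mirrors the new unscale_ call ahead of clip_gradients)
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # optimizer step and scaler update (mirrors optimizer_step)
    scaler.step(optimizer)
    scaler.update()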