diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 9496bbc844e5b..498a3ef0381a8 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -182,26 +182,20 @@ def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: i
 
         Example::
 
-            def backward(self, use_amp, loss, optimizer):
-                if use_amp:
-                    with amp.scale_loss(loss, optimizer) as scaled_loss:
-                        scaled_loss.backward()
-                else:
-                    loss.backward()
+            def backward(self, trainer, loss, optimizer, optimizer_idx):
+                loss.backward()
 
         """
-        if trainer.precision == 16:
-            # .backward is not special on 16-bit with TPUs
-            if trainer.on_tpu:
-                return
+        loss.backward()
+
+    def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx):
+        if self.trainer.use_native_amp:
+            scaled_loss = self.trainer.scaler.scale(unscaled_loss)
 
-            if self.trainer.use_native_amp:
-                self.trainer.scaler.scale(loss).backward()
-            else:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
         else:
-            loss.backward()
+            scaled_loss = amp.scale_loss(unscaled_loss, optimizer)
+
+        return scaled_loss
 
     def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
         """
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index ce805161ce47f..59a3d4122f89e 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -609,6 +609,11 @@ def optimizer_closure():
                     # backward pass
                     model_ref = self.get_model()
                     with self.profiler.profile('model_backward'):
+                        # scale loss for 16 bit
+                        if self.precision == 16 and not self.on_tpu:
+                            closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)
+
+                        # do backward pass
                         model_ref.backward(self, closure_loss, optimizer, opt_idx)
 
                     # track metrics for callbacks
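For context, here is a minimal runnable sketch (not part of this PR) of the two-step flow the diff introduces: the training loop first asks the module to scale the loss when running 16-bit precision off TPU, then hands the (possibly scaled) loss to the overridable `backward` hook. `FakeTrainer`, `DemoModule`, and the hard-coded attributes below are hypothetical stand-ins, not pytorch-lightning API; the apex branch of `amp_scale_loss` is omitted.

```python
import torch
from torch import nn


class FakeTrainer:
    """Hypothetical stand-in for the Trainer attributes this flow relies on."""
    precision = 16
    on_tpu = False
    use_native_amp = True

    def __init__(self):
        # GradScaler degrades to a pass-through when CUDA is unavailable
        self.scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


class DemoModule(nn.Module):
    """Hypothetical module mirroring the two hooks touched by this diff."""

    def __init__(self, trainer):
        super().__init__()
        self.trainer = trainer
        self.layer = nn.Linear(4, 1)

    def backward(self, trainer, loss, optimizer, optimizer_idx):
        # after the refactor the hook only runs the backward pass
        loss.backward()

    def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx):
        # return a scaled loss instead of calling .backward() here
        if self.trainer.use_native_amp:
            return self.trainer.scaler.scale(unscaled_loss)
        # apex branch omitted in this sketch
        return unscaled_loss


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trainer = FakeTrainer()
model = DemoModule(trainer).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
opt_idx = 0

closure_loss = model.layer(torch.randn(8, 4, device=device)).pow(2).mean()

# same ordering as the training_loop.py change: scale first, then backward
if trainer.precision == 16 and not trainer.on_tpu:
    closure_loss = model.amp_scale_loss(closure_loss, optimizer, opt_idx)
model.backward(trainer, closure_loss, optimizer, opt_idx)
```

The point of the split, as reflected in the simplified docstring example, is that loss scaling now happens before the hook is called, so a user overriding `backward` no longer has to branch on AMP or TPU themselves.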