From 6be31c3f598c9a1f4d99f3ec08fd886369df2511 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Fri, 19 Jun 2020 14:48:17 -0400
Subject: [PATCH 1/4] move backward

---
 pytorch_lightning/core/hooks.py | 21 +++------------------
 1 file changed, 3 insertions(+), 18 deletions(-)

diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 9496bbc844e5b..2d19aeb6d2dfd 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -182,26 +182,11 @@ def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: i
 
         Example::
 
-            def backward(self, use_amp, loss, optimizer):
-                if use_amp:
-                    with amp.scale_loss(loss, optimizer) as scaled_loss:
-                        scaled_loss.backward()
-                else:
-                    loss.backward()
+            def backward(self, trainer, loss, optimizer, optimizer_idx):
+                loss.backward()
 
         """
-        if trainer.precision == 16:
-            # .backward is not special on 16-bit with TPUs
-            if trainer.on_tpu:
-                return
-
-            if self.trainer.use_native_amp:
-                self.trainer.scaler.scale(loss).backward()
-            else:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
-        else:
-            loss.backward()
+        loss.backward()
 
     def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
         """

From 2d8e31ec9c6651f575b0ced077c3e0a9775e4a47 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Fri, 19 Jun 2020 14:52:40 -0400
Subject: [PATCH 2/4] refactor backward to remove 16 bit from user override

---
 pytorch_lightning/trainer/training_loop.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index ce805161ce47f..59a3d4122f89e 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -609,6 +609,11 @@ def optimizer_closure():
                     # backward pass
                     model_ref = self.get_model()
                     with self.profiler.profile('model_backward'):
+                        # scale loss for 16 bit
+                        if self.precision == 16 and not self.on_tpu:
+                            closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)
+
+                        # do backward pass
                         model_ref.backward(self, closure_loss, optimizer, opt_idx)
 
                     # track metrics for callbacks

From 9bf54da07aa31da97706440e0915914ab4328c81 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Fri, 19 Jun 2020 14:54:57 -0400
Subject: [PATCH 3/4] refactor backward to remove 16 bit from user override

---
 pytorch_lightning/core/hooks.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 2d19aeb6d2dfd..2d49c82bb7d1a 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -188,6 +188,16 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
         """
         loss.backward()
 
+    def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx):
+        if self.trainer.use_native_amp:
+            scaled_loss = self.trainer.scaler.scale(unscaled_loss)
+
+        else:
+            # TODO: remove in v0.8.0
+            scaled_loss = amp.scale_loss(unscaled_loss, optimizer)
+
+        return scaled_loss
+
     def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
         """
         Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors

From f66b67f3e76446a6253bfa6d8e2d04c96396da37 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Fri, 19 Jun 2020 14:55:49 -0400
Subject: [PATCH 4/4] Update pytorch_lightning/core/hooks.py

Co-authored-by: Jirka Borovec
---
 pytorch_lightning/core/hooks.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 2d49c82bb7d1a..498a3ef0381a8 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -193,7 +193,6 @@ def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx):
             scaled_loss = self.trainer.scaler.scale(unscaled_loss)
 
         else:
-            # TODO: remove in v0.8.0
             scaled_loss = amp.scale_loss(unscaled_loss, optimizer)
 
         return scaled_loss
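For orientation, here is a minimal sketch (not part of the patches) of the control flow this series ends up with: the trainer-side closure scales the loss for 16-bit precision via `amp_scale_loss` and only then calls the user-overridable `backward` hook, which stays a plain `loss.backward()`. `ToyModule`, `ToyTrainer`, and the constant scale factor are hypothetical stand-ins for `LightningModule`, the optimizer closure in `training_loop.py`, and the AMP scaler; only `torch` is assumed.

```python
import torch


class ToyModule:
    """Hypothetical stand-in for a LightningModule after this refactor."""

    def __init__(self):
        self.weight = torch.nn.Parameter(torch.ones(1))

    def backward(self, trainer, loss, optimizer, optimizer_idx):
        # The user override no longer needs any 16-bit branching.
        loss.backward()

    def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx):
        # The real hook defers to the native GradScaler or apex amp;
        # a constant factor stands in for the scaler here.
        return unscaled_loss * 2.0


class ToyTrainer:
    """Hypothetical stand-in for the optimizer closure in training_loop.py."""

    def __init__(self, precision=32, on_tpu=False):
        self.precision = precision
        self.on_tpu = on_tpu

    def backward_pass(self, model_ref, closure_loss, optimizer, opt_idx):
        # scale loss for 16 bit (moved out of the user-facing hook by PATCH 2/4)
        if self.precision == 16 and not self.on_tpu:
            closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)

        # do backward pass via the user-overridable hook
        model_ref.backward(self, closure_loss, optimizer, opt_idx)


if __name__ == "__main__":
    model = ToyModule()
    optimizer = torch.optim.SGD([model.weight], lr=0.1)
    loss = (model.weight * 3.0).sum()
    ToyTrainer(precision=16).backward_pass(model, loss, optimizer, opt_idx=0)
    print(model.weight.grad)  # d(2 * 3 * w)/dw -> tensor([6.])
```

The point of the refactor is visible in `ToyModule.backward`: precision handling lives entirely on the trainer side, so user overrides of the hook no longer branch on AMP or TPU.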