From 946aef6216924a1016e1d273ee122d1f1addae41 Mon Sep 17 00:00:00 2001
From: Anand Krishnamoorthy
Date: Sat, 25 Jan 2020 08:03:07 +0900
Subject: [PATCH] Added optimizer_idx to backward call (#733)

---
 pytorch_lightning/core/hooks.py            | 3 ++-
 pytorch_lightning/trainer/training_loop.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index dc5a502c5f89a..c89b82e00a101 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -124,12 +124,13 @@ def on_after_backward(self):
         """
         pass
 
-    def backward(self, use_amp, loss, optimizer):
+    def backward(self, use_amp, loss, optimizer, optimizer_idx):
         """Override backward with your own implementation if you need to
 
         :param use_amp: Whether amp was requested or not
         :param loss: Loss is already scaled by accumulated grads
         :param optimizer: Current optimizer being used
+        :param optimizer_idx: Index of the current optimizer being used
 
         :return: Called to perform backward step.
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 484983d7b62be..43f0cdce2ee13 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -490,13 +490,14 @@ def optimizer_closure():
 
                     # backward pass
                     model_ref = self.get_model()
-                    model_ref.backward(self.use_amp, closure_loss, optimizer)
+                    model_ref.backward(self.use_amp, closure_loss, optimizer,
+                                       opt_idx)
 
                     # track metrics for callbacks
                     all_callback_metrics.append(callback_metrics)
 
                     # track progress bar metrics
                     self.add_tqdm_metrics(progress_bar_metrics)
                     all_log_metrics.append(log_metrics)
 
                     # insert after step hook
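
Note: after this patch the trainer passes `opt_idx` through to the `backward` hook, so any LightningModule that overrides `backward` must accept the extra `optimizer_idx` argument; an override still written against the old three-argument signature will raise a TypeError at the first backward pass. Below is a minimal sketch of an override under the new signature. The class name `GanModel`, the two-optimizer setup, the retain_graph policy, and the apex `amp` usage are illustrative assumptions, not part of this patch:

import pytorch_lightning as pl

class GanModel(pl.LightningModule):
    # Hypothetical module with two optimizers (e.g. generator at index 0,
    # discriminator at index 1); only the overridden hook is shown.

    def backward(self, use_amp, loss, optimizer, optimizer_idx):
        # Assumed policy for illustration: retain the graph only on the
        # generator's backward pass so a later step can reuse it.
        retain = optimizer_idx == 0
        if use_amp:
            # NVIDIA apex was the amp backend Lightning relied on at the
            # time of this patch; the import here is an assumption.
            from apex import amp
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward(retain_graph=retain)
        else:
            loss.backward(retain_graph=retain)

Modules that do not override `backward` are unaffected: the default implementation in pytorch_lightning/core/hooks.py simply ignores the new index.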