Commit

Added optimizer_idx to backward call (#733)
AntixK authored and williamFalcon committed Jan 24, 2020
1 parent a804755 commit 946aef6
Showing 2 changed files with 4 additions and 2 deletions.
pytorch_lightning/core/hooks.py (3 changes: 2 additions & 1 deletion)
@@ -124,12 +124,13 @@ def on_after_backward(self):
         """
         pass
 
-    def backward(self, use_amp, loss, optimizer):
+    def backward(self, use_amp, loss, optimizer, optimizer_idx):
         """Override backward with your own implementation if you need to
 
         :param use_amp: Whether amp was requested or not
         :param loss: Loss is already scaled by accumulated grads
         :param optimizer: Current optimizer being used
+        :param optimizer_idx: Index of the current optimizer being used
         :return:
 
         Called to perform backward step.
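
For anyone overriding this hook, here is a minimal sketch of a custom backward under the new signature. It is illustrative only, not part of the commit: it assumes the usual pl.LightningModule base class, ignores the new optimizer_idx argument (a single-optimizer model can), and uses NVIDIA apex's amp.scale_loss for the amp branch, which is what use_amp referred to at the time.

import pytorch_lightning as pl


class MyModel(pl.LightningModule):

    def backward(self, use_amp, loss, optimizer, optimizer_idx):
        # optimizer_idx is the argument added by this commit; with a single
        # optimizer it is always 0, so this override simply ignores it.
        if use_amp:
            # NVIDIA apex is an optional dependency, only needed when amp is enabled
            from apex import amp
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
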
pytorch_lightning/trainer/training_loop.py (3 changes: 2 additions & 1 deletion)
@@ -490,13 +490,14 @@ def optimizer_closure():
 
                     # backward pass
                     model_ref = self.get_model()
-                    model_ref.backward(self.use_amp, closure_loss, optimizer)
+                    model_ref.backward(self.use_amp, closure_loss, optimizer, opt_idx)
 
                     # track metrics for callbacks
                     all_callback_metrics.append(callback_metrics)
 
                     # track progress bar metrics
                     self.add_tqdm_metrics(progress_bar_metrics)
+                    self.add_tqdm_metrics(progress_bar_metrics)
                     all_log_metrics.append(log_metrics)
 
                     # insert after step hook
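
On the trainer side, opt_idx is the position of the optimizer in the list returned by configure_optimizers, so the hook now receives the index that distinguishes optimizers in multi-optimizer setups. Below is a hedged sketch of where that index comes from on the user side; it is illustrative only, and the toy GAN modules, learning rates, and 0.5 loss scaling are assumptions, not part of this commit.

import torch
import pytorch_lightning as pl


class GAN(pl.LightningModule):
    # Other required LightningModule methods (training_step, dataloaders, ...)
    # are omitted; this sketch only shows how optimizer_idx gets its meaning.

    def __init__(self):
        super().__init__()
        # toy stand-ins for real generator / discriminator networks
        self.generator = torch.nn.Linear(32, 784)
        self.discriminator = torch.nn.Linear(784, 1)

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        # list position defines the index: optimizer_idx 0 -> opt_g, 1 -> opt_d
        return [opt_g, opt_d]

    def backward(self, use_amp, loss, optimizer, optimizer_idx):
        # hypothetical per-optimizer behaviour: down-weight discriminator gradients
        if optimizer_idx == 1:
            loss = loss * 0.5
        loss.backward()
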
