[refactor results 1] - refactor backward (#2276)
* move backward

* refactor backward to remove 16 bit from user override

* refactor backward to remove 16 bit from user override

* Update pytorch_lightning/core/hooks.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
williamFalcon and Borda committed Jun 19, 2020
1 parent e780072 commit 8d51279
Showing 2 changed files with 15 additions and 16 deletions.
26 changes: 10 additions & 16 deletions pytorch_lightning/core/hooks.py
@@ -182,26 +182,20 @@ def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
         Example::

-            def backward(self, use_amp, loss, optimizer):
-                if use_amp:
-                    with amp.scale_loss(loss, optimizer) as scaled_loss:
-                        scaled_loss.backward()
-                else:
-                    loss.backward()
+            def backward(self, trainer, loss, optimizer, optimizer_idx):
+                loss.backward()
         """
-        if trainer.precision == 16:
-            # .backward is not special on 16-bit with TPUs
-            if trainer.on_tpu:
-                return
-
-            if self.trainer.use_native_amp:
-                self.trainer.scaler.scale(loss).backward()
-            else:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
-        else:
-            loss.backward()
+        loss.backward()
+
+    def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx):
+        if self.trainer.use_native_amp:
+            scaled_loss = self.trainer.scaler.scale(unscaled_loss)
+        else:
+            scaled_loss = amp.scale_loss(unscaled_loss, optimizer)
+
+        return scaled_loss

     def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
         """
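
For user code, the upshot of this hunk is that an overridden backward no longer needs any 16-bit handling: the trainer scales the loss before calling the hook. A minimal before/after sketch (hypothetical LitModel name, assuming pytorch_lightning is importable; it mirrors the docstring example changed above):

    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        # Old style: the override itself had to branch on AMP and call apex's
        # amp.scale_loss. That branch is exactly what this commit removes.
        #
        # def backward(self, use_amp, loss, optimizer):
        #     if use_amp:
        #         with amp.scale_loss(loss, optimizer) as scaled_loss:
        #             scaled_loss.backward()
        #     else:
        #         loss.backward()

        # New style: scaling is done by the trainer via amp_scale_loss,
        # so the override reduces to a plain backward call.
        def backward(self, trainer, loss, optimizer, optimizer_idx):
            loss.backward()
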
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -609,6 +609,11 @@ def optimizer_closure():
                 # backward pass
                 model_ref = self.get_model()
                 with self.profiler.profile('model_backward'):
+                    # scale loss for 16 bit
+                    if self.precision == 16 and not self.on_tpu:
+                        closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)
+
+                    # do backward pass
                     model_ref.backward(self, closure_loss, optimizer, opt_idx)

                 # track metrics for callbacks
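
Read together, the two hunks establish the new division of labour: training_loop.py scales the loss for 16-bit (skipping TPUs, where .backward needs no special handling) and only then calls the user-overridable backward hook, which now does nothing but loss.backward(). A self-contained stand-in sketch of that flow, using hypothetical dummy classes in place of the real trainer, model, and loss:

    # Toy stand-ins only; the real logic lives in training_loop.py and hooks.py
    # as shown in the diffs above.

    class DummyLoss:
        def __init__(self, value):
            self.value = value

        def backward(self):
            print(f"backward() called on loss={self.value}")


    class DummyModel:
        def __init__(self, trainer):
            self.trainer = trainer

        def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx):
            # Mirrors hooks.amp_scale_loss: native AMP goes through the trainer's
            # scaler; the apex path is stubbed out in this sketch.
            if self.trainer.use_native_amp:
                return DummyLoss(unscaled_loss.value * self.trainer.loss_scale)
            return unscaled_loss

        def backward(self, trainer, loss, optimizer, optimizer_idx):
            # Mirrors the new hooks.backward: no AMP branching here.
            loss.backward()


    class DummyTrainer:
        precision = 16
        on_tpu = False
        use_native_amp = True
        loss_scale = 2 ** 16  # stand-in for GradScaler's dynamic scale

        def run_backward(self, model_ref, closure_loss, optimizer, opt_idx):
            # scale loss for 16 bit (never on TPU)
            if self.precision == 16 and not self.on_tpu:
                closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)

            # do backward pass via the (possibly user-overridden) hook
            model_ref.backward(self, closure_loss, optimizer, opt_idx)


    trainer = DummyTrainer()
    trainer.run_backward(DummyModel(trainer), DummyLoss(1.0), optimizer=None, opt_idx=0)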
