Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor results 1] - refactor backward #2276

Merged
merged 4 commits into from
Jun 19, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,26 +182,21 @@ def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: i

Example::

def backward(self, use_amp, loss, optimizer):
if use_amp:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
def backward(self, trainer, loss, optimizer, optimizer_idx):
loss.backward()

"""
if trainer.precision == 16:
# .backward is not special on 16-bit with TPUs
if trainer.on_tpu:
return
loss.backward()

def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx):
if self.trainer.use_native_amp:
scaled_loss = self.trainer.scaler.scale(unscaled_loss)

if self.trainer.use_native_amp:
self.trainer.scaler.scale(loss).backward()
else:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
# TODO: remove in v0.8.0
williamFalcon marked this conversation as resolved.
Show resolved Hide resolved
scaled_loss = amp.scale_loss(unscaled_loss, optimizer)

return scaled_loss

def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
"""
Expand Down
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,11 @@ def optimizer_closure():
# backward pass
model_ref = self.get_model()
with self.profiler.profile('model_backward'):
# scale loss for 16 bit
if self.precision == 16 and not self.on_tpu:
closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)

# do backward pass
model_ref.backward(self, closure_loss, optimizer, opt_idx)

# track metrics for callbacks
Expand Down