
Commit

autocast
williamFalcon committed Apr 22, 2020
1 parent 2763a48 commit 0ead459
Showing 2 changed files with 6 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -531,7 +531,9 @@ def dp_train(self, model):
         model.cuda(self.root_gpu)
 
         # hack forward to do autocast for the user
+        model_autocast_original_forward = model.forward
+        if self.use_amp and self.use_native_amp:
+            # wrap the user's forward in autocast and give it back at the end
+            model.forward = torch.cuda.amp.autocast()(model.forward)
 
         # TODO: remove in v0.8.0
@@ -558,6 +560,8 @@ def dp_train(self, model):

         self.run_pretrain_routine(model)
 
+        model.forward = model_autocast_original_forward
+
     def horovod_train(self, model):
         # Horovod: initialize library
         hvd.init()
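The distrib_parts.py change boils down to a small pattern: keep a reference to the user's original forward, swap in an autocast-decorated version while training runs, and put the original back afterwards. Below is a minimal, self-contained sketch of that pattern, not the Lightning code itself; the MyModel module and tensor shapes are made up for illustration, and it assumes a CUDA GPU plus a PyTorch build with native AMP (torch.cuda.amp).

import torch
from torch import nn


class MyModel(nn.Module):
    # placeholder model; any nn.Module with a forward() works the same way
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 2)

    def forward(self, x):
        return self.layer(x)


model = MyModel().cuda()

# keep a handle on the original forward so it can be given back at the end,
# mirroring model_autocast_original_forward in dp_train
original_forward = model.forward

use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
if use_native_amp:
    # autocast() can decorate a callable: the wrapped forward runs in mixed precision
    model.forward = torch.cuda.amp.autocast()(model.forward)

out = model(torch.randn(4, 8, device="cuda"))  # this call now runs under autocast

# restore the user's forward once the run is over
model.forward = original_forward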
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -489,6 +489,8 @@ def __init__(

         # AMP init
         # These are the only lines needed after v0.8.0
+        # we wrap the user's forward with autocast and give it back at the end of fit
+        self.autocast_original_forward = None
         self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
         if self.use_native_amp and self.precision == 16:
             self.scaler = torch.cuda.amp.GradScaler()
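For context, the flags initialized above are what the trainer later consults to pick the native-AMP path. Below is a hedged sketch of how use_native_amp and a GradScaler are typically combined in an fp16 training step; the training_step function and its model, optimizer, loss_fn, batch, and target arguments are stand-ins for illustration, not the actual Lightning training loop.

import torch

precision = 16  # stand-in for Trainer(precision=16)
use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")

# only create a scaler when native AMP is available and fp16 was requested
scaler = torch.cuda.amp.GradScaler() if (use_native_amp and precision == 16) else None


def training_step(model, optimizer, loss_fn, batch, target):
    optimizer.zero_grad()
    if scaler is not None:
        with torch.cuda.amp.autocast():
            loss = loss_fn(model(batch), target)
        # scale the loss so fp16 gradients do not underflow, then unscale before the step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss = loss_fn(model(batch), target)
        loss.backward()
        optimizer.step()
    return loss.detach()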
