
Commit

autocast
williamFalcon committed Apr 22, 2020
1 parent f0d2c78 commit 0a20aee
Showing 2 changed files with 7 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -520,7 +520,9 @@ def dp_train(self, model):
        model.cuda(self.root_gpu)

        # hack forward to do autocast for the user
        model_autocast_original_forward = model.forward
        if self.use_amp and self.use_native_amp:
            # wrap the user's forward in autocast and give it back at the end
            model.forward = torch.cuda.amp.autocast()(model.forward)

        # TODO: remove in v0.8.0
@@ -547,6 +549,9 @@ def dp_train(self, model):

        self.run_pretrain_routine(model)

        # when training completes give back the forward
        model.forward = model_autocast_original_forward


def normalize_parse_gpu_string_input(s):
    if isinstance(s, str):
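For readers unfamiliar with the trick above: torch.cuda.amp.autocast() also works as a decorator, so wrapping model.forward makes every forward pass run in mixed precision without the user touching their LightningModule. The snippet below is a minimal standalone sketch of that wrap-and-restore pattern, not the Lightning code itself; the nn.Linear model and the input tensor are placeholder assumptions, and it presumes a CUDA device with PyTorch >= 1.6.

import torch
import torch.nn as nn

# minimal sketch: wrap a module's forward in autocast, then give it back
model = nn.Linear(16, 4).cuda()

# keep a handle to the user's original forward, as dp_train does
original_forward = model.forward
# autocast() used as a decorator: the wrapped forward runs under mixed precision
model.forward = torch.cuda.amp.autocast()(model.forward)

x = torch.randn(8, 16, device="cuda")
out = model(x)                      # the matmul executes in float16 under autocast
print(out.dtype)                    # typically torch.float16

# when training completes, restore the original forward
model.forward = original_forward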
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -489,6 +489,8 @@ def __init__(

        # AMP init
        # These are the only lines needed after v0.8.0
        # we wrap the user's forward with autocast and give it back at the end of fit
        self.autocast_original_forward = None
        self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
        if self.use_native_amp and self.precision == 16:
            self.scaler = torch.cuda.amp.GradScaler()
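The Trainer lines above only detect native AMP support and create a GradScaler; the scaler is what keeps float16 gradients from underflowing during backward. Below is a hedged, self-contained sketch of how such a check and scaler are typically used in a 16-bit training step; the model, optimizer, and loss function are placeholders, not part of this commit.

import torch

# the same capability check the Trainer performs
use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler() if use_native_amp else None

x = torch.randn(8, 16, device="cuda")
y = torch.randn(8, 4, device="cuda")

optimizer.zero_grad()
if use_native_amp:
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()   # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)          # unscales gradients, then steps the optimizer
    scaler.update()
else:
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()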
