Skip to content

Commit

Permalink
autocast
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon committed Apr 22, 2020
1 parent 801d27d commit 2763a48
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,10 @@ def dp_train(self, model):

model.cuda(self.root_gpu)

# hack forward to do autocast for the user
if self.use_amp and self.use_native_amp:
model.forward = torch.cuda.amp.autocast()(model.forward)

# TODO: remove in v0.8.0
# check for this bug (amp + dp + !01 doesn't work)
# https://github.com/NVIDIA/apex/issues/227
Expand Down

0 comments on commit 2763a48

Please sign in to comment.