From 0ead459239587287f1b9124d6bc3504735d3c895 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 22 Apr 2020 12:42:02 -0400
Subject: [PATCH] autocast

---
 pytorch_lightning/trainer/distrib_parts.py | 4 ++++
 pytorch_lightning/trainer/trainer.py       | 2 ++
 2 files changed, 6 insertions(+)

diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index f6613ce1ac4ba..7b79922d82a00 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -531,7 +531,9 @@ def dp_train(self, model):
             model.cuda(self.root_gpu)
 
         # hack forward to do autocast for the user
+        model_autocast_original_forward = model.forward
         if self.use_amp and self.use_native_amp:
+            # wrap the user's forward in autocast and give it back at the end
             model.forward = torch.cuda.amp.autocast()(model.forward)
 
         # TODO: remove in v0.8.0
@@ -558,6 +560,8 @@ def dp_train(self, model):
 
         self.run_pretrain_routine(model)
 
+        model.forward = model_autocast_original_forward
+
     def horovod_train(self, model):
         # Horovod: initialize library
         hvd.init()
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 8f1c836ba83d2..8d5305f94cbda 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -489,6 +489,8 @@ def __init__(
 
         # AMP init
         # These are the only lines needed after v0.8.0
+        # we wrap the user's forward with autocast and give it back at the end of fit
+        self.autocast_original_forward = None
         self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
         if self.use_native_amp and self.precision == 16:
             self.scaler = torch.cuda.amp.GradScaler()
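
Note: outside the Lightning source, the save/wrap/restore pattern this patch applies can be sketched as below. This is a minimal, hypothetical helper (run_with_autocast_forward is not part of the Lightning API); it assumes native AMP is available (torch.cuda.amp.autocast) and simply keeps a reference to the bound forward, wraps it, and restores it when the run finishes.

    import torch

    def run_with_autocast_forward(model, run_fn):
        # hypothetical helper, not a Lightning API: keep a reference to the
        # user's forward so it can be restored later
        original_forward = model.forward
        # an autocast instance is callable as a decorator, so this wraps forward
        model.forward = torch.cuda.amp.autocast()(model.forward)
        try:
            run_fn(model)  # e.g. the training / pretrain routine
        finally:
            # give the user's original forward back, mirroring the patch above
            model.forward = original_forward

Restoring the original bound method in a finally-style step matters because the wrapped forward would otherwise leak out of the fit call and silently keep running the user's model under autocast.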