From 7b60d49432dbb3c4dd4146f17605724a5ae23a44 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 12 May 2020 00:25:06 -0400
Subject: [PATCH] fixed native amp + ddp (#1788)

* fixed native amp + ddp

* fixed native amp + ddp
---
 pytorch_lightning/trainer/trainer.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index cd3ee1c9d21db..e01f1dbb497d1 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -510,8 +510,7 @@ def __init__(
         self.autocast_original_forward = None
         self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
         self.precision = precision
-        if self.use_native_amp and self.precision == 16:
-            self.scaler = torch.cuda.amp.GradScaler()
+        self.scaler = None
 
         # TODO: remove for v0.8.0
         self.amp_level = amp_level
@@ -858,6 +857,10 @@ def run_pretrain_routine(self, model: LightningModule):
         # set local properties on the model
         self.copy_trainer_model_properties(ref_model)
 
+        # init amp. Must be done here instead of __init__ to allow ddp to work
+        if self.use_native_amp and self.precision == 16:
+            self.scaler = torch.cuda.amp.GradScaler()
+
         # log hyper-parameters
         if self.logger is not None:
             # save exp to get started
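
Context (editorial sketch, not part of the patch): the change defers construction of
torch.cuda.amp.GradScaler from Trainer.__init__ to run_pretrain_routine, which, per the
inline comment, is what lets native AMP work with DDP. The sketch below illustrates the
underlying PyTorch native-AMP pattern with a scaler created once per process; all names
(train_step, model, optimizer, use_amp) are hypothetical and not from the Lightning codebase.

import torch


def train_step(model, optimizer, scaler, batch, target, use_amp):
    # One optimization step with native AMP: forward under autocast,
    # scaled backward, then scaler-driven optimizer step and scale update.
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = torch.nn.functional.mse_loss(model(batch), target)
    scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)          # unscales gradients, then runs the optimizer step
    scaler.update()                 # adjusts the scale factor for the next iteration


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"
    model = torch.nn.Linear(8, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # Created once per process (mirroring the patch, which builds the scaler
    # in the pre-train routine rather than in the Trainer constructor).
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    batch = torch.randn(4, 8, device=device)
    target = torch.randn(4, 1, device=device)
    train_step(model, optimizer, scaler, batch, target, use_amp)

Design note: keeping self.scaler = None in __init__ and building the GradScaler later means
no CUDA-backed state is created before DDP has set up its worker processes, which is the
behavior the patch's comment calls out as required for DDP.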