From e62b404a0a5a39bca26fb9a270fd99804e53dffe Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Mon, 11 May 2020 22:22:05 -0400
Subject: [PATCH 1/2] fixed native amp + ddp

---
 pytorch_lightning/trainer/trainer.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index aa0565d9d83a2..b472c267834f3 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -509,7 +509,7 @@ def __init__(
         self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
         self.precision = precision
         if self.use_native_amp and self.precision == 16:
-            self.scaler = torch.cuda.amp.GradScaler()
+            self.scaler = None

         # TODO: remove for v0.8.0
         self.amp_level = amp_level
@@ -856,6 +856,10 @@ def run_pretrain_routine(self, model: LightningModule):
         # set local properties on the model
         self.copy_trainer_model_properties(ref_model)

+        # init amp. Must be done here instead of __init__ to allow ddp to work
+        if self.use_native_amp and self.precision == 16:
+            self.scaler = torch.cuda.amp.GradScaler()
+
         # log hyper-parameters
         if self.logger is not None:
             # save exp to get started

From e31423dac18043588a075edd2c61f1c4aa9c546b Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Mon, 11 May 2020 22:23:15 -0400
Subject: [PATCH 2/2] fixed native amp + ddp

---
 pytorch_lightning/trainer/trainer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b472c267834f3..d859626ecb23b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -508,8 +508,7 @@ def __init__(
         self.autocast_original_forward = None
         self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
         self.precision = precision
-        if self.use_native_amp and self.precision == 16:
-            self.scaler = None
+        self.scaler = None

         # TODO: remove for v0.8.0
         self.amp_level = amp_level
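
A minimal sketch of the pattern this patch applies, not part of the patch itself: the GradScaler is created after process setup (e.g. once per spawned DDP worker) rather than in the Trainer's __init__. The SketchTrainer class, its method names, and the training_step signature below are illustrative only; only the attribute names and the torch.cuda.amp calls mirror the patch and PyTorch's documented AMP API.

import torch

class SketchTrainer:
    def __init__(self, precision=16):
        # mirror the patched __init__: detect native amp, but do NOT build the scaler yet
        self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
        self.precision = precision
        self.scaler = None  # created later, once per DDP process

    def run_pretrain_routine(self):
        # init amp here instead of __init__ so each ddp worker builds its own scaler
        if self.use_native_amp and self.precision == 16:
            self.scaler = torch.cuda.amp.GradScaler()

    def training_step(self, model, batch, optimizer):
        # standard native-amp step: autocast the forward, scale the backward
        with torch.cuda.amp.autocast():
            loss = model(batch)
        self.scaler.scale(loss).backward()
        self.scaler.step(optimizer)
        self.scaler.update()
        optimizer.zero_grad()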