diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index aa0565d9d83a2..d859626ecb23b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -508,8 +508,7 @@ def __init__(
         self.autocast_original_forward = None
         self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
         self.precision = precision
-        if self.use_native_amp and self.precision == 16:
-            self.scaler = torch.cuda.amp.GradScaler()
+        self.scaler = None
 
         # TODO: remove for v0.8.0
         self.amp_level = amp_level
@@ -856,6 +855,10 @@ def run_pretrain_routine(self, model: LightningModule):
         # set local properties on the model
         self.copy_trainer_model_properties(ref_model)
 
+        # init amp. Must be done here instead of __init__ to allow ddp to work
+        if self.use_native_amp and self.precision == 16:
+            self.scaler = torch.cuda.amp.GradScaler()
+
         # log hyper-parameters
         if self.logger is not None:
             # save exp to get started
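
For context only (not part of the patch): a minimal sketch of the standard torch.cuda.amp GradScaler/autocast training-step pattern that the lazily created self.scaler feeds into. It assumes a CUDA device is available, and the model, optimizer, and loader names below are illustrative placeholders, not Trainer internals.

    # Minimal sketch of the native-AMP pattern the scaler enables (illustrative, not Trainer code).
    import torch

    model = torch.nn.Linear(10, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loader = [(torch.randn(4, 10).cuda(), torch.randint(0, 2, (4,)).cuda()) for _ in range(3)]

    scaler = torch.cuda.amp.GradScaler()  # same object the patch now builds in run_pretrain_routine

    for inputs, targets in loader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
            loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
        scaler.step(optimizer)         # unscales gradients, skips the step on inf/nan
        scaler.update()                # adjust the scale factor for the next iteration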