diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index cd3ee1c9d21db..e01f1dbb497d1 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -510,8 +510,7 @@ def __init__(
         self.autocast_original_forward = None
         self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
         self.precision = precision
-        if self.use_native_amp and self.precision == 16:
-            self.scaler = torch.cuda.amp.GradScaler()
+        self.scaler = None
 
         # TODO: remove for v0.8.0
         self.amp_level = amp_level
@@ -858,6 +857,10 @@ def run_pretrain_routine(self, model: LightningModule):
         # set local properties on the model
         self.copy_trainer_model_properties(ref_model)
 
+        # init amp. Must be done here instead of __init__ to allow ddp to work
+        if self.use_native_amp and self.precision == 16:
+            self.scaler = torch.cuda.amp.GradScaler()
+
         # log hyper-parameters
         if self.logger is not None:
             # save exp to get started
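
Context for reviewers: the `torch.cuda.amp.GradScaler` whose construction this patch moves from `__init__` into `run_pretrain_routine` is the object native AMP uses to scale the fp16 loss before backward. Below is a minimal standalone sketch of how such a scaler is typically driven in a training loop; the model, optimizer, and data here are illustrative placeholders, not code from this PR.

```python
import torch

# Illustrative only: a tiny model/optimizer pair standing in for the
# LightningModule and its configured optimizer.
model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Created once per process (in the PR, inside run_pretrain_routine).
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):  # stand-in for the real training loop
    x = torch.randn(8, 10, device="cuda")
    y = torch.randint(0, 2, (8,), device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
        loss = torch.nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()    # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)           # unscales gradients, then optimizer.step()
    scaler.update()                  # adjust the scale factor for the next step
```

Per the comment added in the second hunk, the scaler is constructed in `run_pretrain_routine` rather than `__init__` so that DDP works with native AMP; the sketch above only shows how the scaler is consumed once it exists, not where Lightning creates it.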