diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index 1c5e0f7e1042f..8dd14a327dae3 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -520,7 +520,9 @@ def dp_train(self, model):
         model.cuda(self.root_gpu)
 
         # hack forward to do autocast for the user
+        model_autocast_original_forward = model.forward
         if self.use_amp and self.use_native_amp:
+            # wrap the user's forward in autocast and give it back at the end
             model.forward = torch.cuda.amp.autocast()(model.forward)
 
         # TODO: remove in v0.8.0
@@ -547,6 +549,9 @@ def dp_train(self, model):
 
         self.run_pretrain_routine(model)
 
+        # when training completes give back the forward
+        model.forward = model_autocast_original_forward
+
 
 def normalize_parse_gpu_string_input(s):
     if isinstance(s, str):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6d96485ef3b4c..e375cb1953a56 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -489,6 +489,8 @@ def __init__(
 
         # AMP init
         # These are the only lines needed after v0.8.0
+        # we wrap the user's forward with autocast and give it back at the end of fit
+        self.autocast_original_forward = None
        self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
        if self.use_native_amp and self.precision == 16:
            self.scaler = torch.cuda.amp.GradScaler()
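
For context, the pattern this patch applies is: keep a reference to the user's original forward, monkey-patch it with a torch.cuda.amp.autocast()-wrapped version for the duration of the run, then restore the original so the model leaves fit() unmodified. Below is a minimal standalone sketch of that pattern; the function name run_with_autocast_forward and its arguments are hypothetical illustrations, not Lightning's actual API.

# Hypothetical sketch of the wrap-then-restore pattern used in the diff.
import torch
import torch.nn as nn


def run_with_autocast_forward(model: nn.Module, batch: torch.Tensor) -> torch.Tensor:
    # keep a handle to the user's original forward so it can be given back
    original_forward = model.forward
    try:
        # autocast() used as a decorator returns a callable that runs the
        # wrapped forward under mixed precision on CUDA
        model.forward = torch.cuda.amp.autocast()(model.forward)
        # nn.Module.__call__ looks up self.forward, so the patched version is used here
        return model(batch)
    finally:
        # give the original forward back once the run completes, mirroring the diff
        model.forward = original_forward

Restoring the forward in a finally block (or, as in the diff, after run_pretrain_routine returns) matters because the patched forward holds a closure over autocast; leaving it in place would silently keep inference and later fine-tuning in mixed precision.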