fixed native amp + ddp #1788

Merged: 2 commits, May 12, 2020
pytorch_lightning/trainer/trainer.py (7 changes: 5 additions & 2 deletions)
@@ -508,8 +508,7 @@ def __init__(
         self.autocast_original_forward = None
         self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
         self.precision = precision
-        if self.use_native_amp and self.precision == 16:
-            self.scaler = torch.cuda.amp.GradScaler()
+        self.scaler = None

         # TODO: remove for v0.8.0
         self.amp_level = amp_level
@@ -856,6 +855,10 @@ def run_pretrain_routine(self, model: LightningModule):
         # set local properties on the model
         self.copy_trainer_model_properties(ref_model)

+        # init amp. Must be done here instead of __init__ to allow ddp to work
+        if self.use_native_amp and self.precision == 16:
+            self.scaler = torch.cuda.amp.GradScaler()
+
         # log hyper-parameters
         if self.logger is not None:
             # save exp to get started
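For context, the GradScaler created in run_pretrain_routine above is the standard torch.cuda.amp object used for native mixed-precision training. Below is a minimal, self-contained sketch of how such a scaler is typically used in a plain PyTorch loop; it is illustrative only, not Lightning's internal training loop, and the model, optimizer, and toy batch are placeholders.

```python
import torch
import torch.nn.functional as F

# Minimal sketch of PyTorch native AMP with autocast + GradScaler.
# `model`, `optimizer`, and `batches` are hypothetical placeholders,
# not pytorch-lightning internals.
model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()  # one scaler per process, created after setup

batches = [(torch.randn(8, 32, device="cuda"),
            torch.randint(0, 2, (8,), device="cuda"))]

for inputs, targets in batches:
    optimizer.zero_grad()
    # run the forward pass in mixed precision
    with torch.cuda.amp.autocast():
        loss = F.cross_entropy(model(inputs), targets)
    # scale the loss so fp16 gradients do not underflow,
    # then step the optimizer and update the scale factor
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

Because GradScaler keeps per-process state, each DDP worker needs its own instance. Deferring its creation to run_pretrain_routine, which under DDP runs in each training process, rather than Trainer.__init__, which runs before the processes are spawned, is what lets native AMP and DDP work together in this change.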