save apex scaler states #2828

Merged · 1 commit · Aug 5, 2020
pytorch_lightning/trainer/training_io.py (13 additions, 0 deletions)
@@ -115,6 +115,13 @@
 else:
     XLA_AVAILABLE = True
 
+try:
+    from apex import amp
+except ImportError:
+    APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
+
 try:
     import horovod.torch as hvd
 except (ModuleNotFoundError, ImportError):
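For context, this added block follows the same optional-import pattern already used for XLA and Horovod around it: the import is attempted once at module load, and the APEX_AVAILABLE flag records the result so later code can branch without re-importing. A minimal, self-contained sketch of that pattern (the describe_amp_backend helper is illustrative, not part of the PR):

```python
# Optional-import guard, as in the diff above: runs whether or not NVIDIA apex is installed.
try:
    from apex import amp  # noqa: F401
except ImportError:
    APEX_AVAILABLE = False
else:
    APEX_AVAILABLE = True


def describe_amp_backend(use_native_amp: bool) -> str:
    # Hypothetical helper for illustration only; not part of the PR.
    if use_native_amp:
        return "torch.cuda.amp (native)"
    if APEX_AVAILABLE:
        return "apex.amp"
    return "mixed precision unavailable"


if __name__ == "__main__":
    print(describe_amp_backend(use_native_amp=False))
```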
@@ -317,6 +324,8 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         # restore amp scaling
         if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint:
             self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+        elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint:
+            amp.load_state_dict(checkpoint['amp_scaling_state'])
 
         # load training state (affects trainer only)
         self.restore_training_state(checkpoint)
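The restore path now mirrors the save path: native AMP state is loaded into the trainer's GradScaler, while apex state is loaded through the module-level amp.load_state_dict. A hedged, standalone sketch of the same branching, using the native scaler so it runs without apex (the function and variable names are illustrative, not the trainer's API):

```python
import torch


def restore_scaling_state(checkpoint: dict, scaler, use_native_amp: bool) -> None:
    # Standalone approximation of the branch added to restore(); `scaler` plays
    # the role of self.scaler, `use_native_amp` the role of the trainer's flags.
    if use_native_amp and 'native_amp_scaling_state' in checkpoint:
        scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
    elif not use_native_amp and 'amp_scaling_state' in checkpoint:
        from apex import amp  # requires NVIDIA apex; this path is GPU-only
        amp.load_state_dict(checkpoint['amp_scaling_state'])


# Example with the native scaler (automatically disabled on CPU-only machines):
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
ckpt = {'native_amp_scaling_state': scaler.state_dict()}
restore_scaling_state(ckpt, scaler, use_native_amp=True)
```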
@@ -368,6 +377,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         # save native amp scaling
         if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
             checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
+        elif self.use_amp and not NATIVE_AMP_AVALAIBLE:
+            checkpoint['amp_scaling_state'] = amp.state_dict()
 
         # add the module_arguments and state_dict from the model
         model = self.get_model()
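On the save side, dump_checkpoint stores the scaler state under a backend-specific key; with apex, amp.state_dict() captures the loss-scale state of the already initialised amp module. A sketch of how such a checkpoint dict might be assembled and written, again using the native scaler so it runs without apex; the helper name, checkpoint contents, and file name are illustrative:

```python
import torch


def add_scaling_state(checkpoint: dict, use_amp: bool, use_native_amp: bool, scaler) -> dict:
    # Mirrors the branch added to dump_checkpoint(): one checkpoint key per AMP backend.
    if use_amp and use_native_amp:
        checkpoint['native_amp_scaling_state'] = scaler.state_dict()
    elif use_amp and not use_native_amp:
        from apex import amp  # assumes amp.initialize(model, optimizer) ran during setup
        checkpoint['amp_scaling_state'] = amp.state_dict()
    return checkpoint


scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
checkpoint = add_scaling_state({'epoch': 0}, use_amp=True, use_native_amp=True, scaler=scaler)
torch.save(checkpoint, 'example.ckpt')  # illustrative path, not Lightning's layout
```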
@@ -523,6 +534,8 @@ def hpc_load(self, folderpath, on_gpu):
         # restore amp scaling
         if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint:
             self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+        elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint:
+            amp.load_state_dict(checkpoint['amp_scaling_state'])
 
         if self.root_gpu is not None:
             model.cuda(self.root_gpu)
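Taken together, the three hunks let AMP scaling survive a full save/load cycle, including HPC auto-resubmit checkpoints loaded via hpc_load. A hedged round-trip sketch with the native scaler (the apex path would store 'amp_scaling_state' from amp.state_dict() instead; the file layout below is illustrative, not Lightning's):

```python
import os
import tempfile

import torch

# Save a checkpoint containing the scaler state, then load it back into a fresh scaler.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
checkpoint = {'native_amp_scaling_state': scaler.state_dict()}

with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, 'hpc_ckpt_1.ckpt')  # file name is illustrative
    torch.save(checkpoint, path)

    loaded = torch.load(path, map_location='cpu')
    fresh_scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    if 'native_amp_scaling_state' in loaded:
        fresh_scaler.load_state_dict(loaded['native_amp_scaling_state'])
```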