
save apex scaler states (#2828)
ruotianluo committed Aug 5, 2020
1 parent 6034d5e commit bef27c5
Showing 1 changed file with 13 additions and 0 deletions.
pytorch_lightning/trainer/training_io.py: 13 additions, 0 deletions
@@ -115,6 +115,13 @@
 else:
     XLA_AVAILABLE = True
 
+try:
+    from apex import amp
+except ImportError:
+    APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
+
 try:
     import horovod.torch as hvd
 except (ModuleNotFoundError, ImportError):
@@ -317,6 +324,8 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         # restore amp scaling
         if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint:
             self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+        elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint:
+            amp.load_state_dict(checkpoint['amp_scaling_state'])
 
         # load training state (affects trainer only)
         self.restore_training_state(checkpoint)
@@ -368,6 +377,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         # save native amp scaling
         if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
             checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
+        elif self.use_amp and not NATIVE_AMP_AVALAIBLE:
+            checkpoint['amp_scaling_state'] = amp.state_dict()
 
         # add the module_arguments and state_dict from the model
         model = self.get_model()
@@ -523,6 +534,8 @@ def hpc_load(self, folderpath, on_gpu):
         # restore amp scaling
         if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint:
             self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+        elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint:
+            amp.load_state_dict(checkpoint['amp_scaling_state'])
 
         if self.root_gpu is not None:
             model.cuda(self.root_gpu)
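For context, the pattern this commit adds mirrors apex's documented checkpointing recipe: amp.state_dict() captures the loss-scaler state once amp.initialize() has run, and amp.load_state_dict() restores it on resume. A minimal standalone sketch of the same round-trip outside the Trainer (the model, optimizer, opt_level, key names other than 'amp_scaling_state', and file name are illustrative assumptions; it requires apex and a CUDA device):

import torch
from apex import amp

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# apex requires amp.initialize() before the scaler state can be saved or restored
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# save: analogous to dump_checkpoint() above, keyed as 'amp_scaling_state'
checkpoint = {
    "state_dict": model.state_dict(),
    "optimizer_states": optimizer.state_dict(),
    "amp_scaling_state": amp.state_dict(),
}
torch.save(checkpoint, "example.ckpt")

# restore: analogous to restore() / hpc_load() above
checkpoint = torch.load("example.ckpt")
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_states"])
amp.load_state_dict(checkpoint["amp_scaling_state"])

Note that the new elif branches only run when use_amp is set and NATIVE_AMP_AVALAIBLE is false, i.e. when training uses apex rather than torch.cuda.amp; the native GradScaler keeps its own state under the existing 'native_amp_scaling_state' key.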
