added state saving
williamFalcon committed Apr 23, 2020
1 parent afb6801 commit fa87d1d
Showing 1 changed file with 12 additions and 0 deletions.
pytorch_lightning/trainer/training_io.py: 12 additions & 0 deletions
@@ -281,6 +281,10 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         if on_gpu:
             model.cuda(self.root_gpu)
 
+        # restore amp scaling
+        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+            self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+
         # load training state (affects trainer only)
         self.restore_training_state(checkpoint)

@@ -316,6 +320,10 @@ def dump_checkpoint(self):
 
         checkpoint['state_dict'] = model.state_dict()

+        # save native amp scaling
+        if self.use_amp and self.use_native_amp:
+            checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
+
         if hasattr(model, "hparams"):
             is_namespace = isinstance(model.hparams, Namespace)
             checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams
@@ -441,6 +449,10 @@ def hpc_load(self, folderpath, on_gpu):
         # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])
 
+        # restore amp scaling
+        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+            self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+
         if self.root_gpu is not None:
             model.cuda(self.root_gpu)

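For context, a minimal standalone sketch of the save/restore round-trip these hunks implement, assuming PyTorch with native AMP support (torch.cuda.amp.GradScaler, available from 1.6); the checkpoint filename and the surrounding dict layout are illustrative, not Lightning's API:

import torch

scaler = torch.cuda.amp.GradScaler()

# saving: serialize the scaler's state next to the model weights,
# mirroring dump_checkpoint above
checkpoint = {'native_amp_scaling_state': scaler.state_dict()}
torch.save(checkpoint, 'example.ckpt')

# restoring: guard on the key so checkpoints written before this
# commit still load, mirroring restore() and hpc_load() above
checkpoint = torch.load('example.ckpt')
if 'native_amp_scaling_state' in checkpoint:
    scaler.load_state_dict(checkpoint['native_amp_scaling_state'])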
