Added atomic checkpoint creation
Frederik Diehl committed Jan 20, 2020
1 parent de2ccc0 commit b52365a
Showing 1 changed file with 9 additions and 4 deletions.
pytorch_lightning/trainer/training_io.py (13 changes: 9 additions & 4 deletions)
@@ -255,17 +255,22 @@ def term_handler(self, signum, frame):
     # --------------------
     # MODEL SAVE CHECKPOINT
     # --------------------
+    def _atomic_save(self, checkpoint, filepath):
+        tmp_path = str(filepath) + ".part"
+        torch.save(checkpoint, tmp_path)
+        os.replace(tmp_path, filepath)
+
     def save_checkpoint(self, filepath):
         checkpoint = self.dump_checkpoint()

         # do the actual save
         try:
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
         except AttributeError:
             if 'hparams' in checkpoint:
                 del checkpoint['hparams']

-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)

     def restore(self, checkpoint_path, on_gpu):
         # if on_gpu:
@@ -412,12 +417,12 @@ def hpc_save(self, folderpath, logger):
         # do the actual save
         # TODO: fix for anything with multiprocess DP, DDP, DDP2
         try:
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
         except AttributeError:
             if 'hparams' in checkpoint:
                 del checkpoint['hparams']

-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)

         return filepath

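Note on the change: this is the standard write-to-temporary-then-rename pattern. The checkpoint is fully serialized to a sibling ".part" file, then os.replace() swaps it into place in a single atomic step (on both POSIX and Windows, provided source and destination are on the same filesystem). If the process dies mid-save, the previous checkpoint file is left intact rather than truncated. A minimal standalone sketch of the same idea, assuming only os and torch (the atomic_save name and the demo payload are illustrative, not part of the commit):

    import os
    import torch

    def atomic_save(obj, filepath):
        # Serialize the full payload to a temp file on the same filesystem.
        tmp_path = str(filepath) + ".part"
        torch.save(obj, tmp_path)
        # Atomically swap the temp file into place; a crash during the
        # torch.save above leaves any existing checkpoint untouched.
        os.replace(tmp_path, filepath)

    atomic_save({'step': 1, 'weights': torch.zeros(3)}, 'model.ckpt')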
