Skip to content

Commit

Permalink
Added atomic checkpoint creation (#689)
Browse files Browse the repository at this point in the history
* Added atomic checkpoint creation

* Added documentation for _atomic_checkpoint
  • Loading branch information
fgerzer authored and williamFalcon committed Jan 20, 2020
1 parent 06242c2 commit 9aad69d
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions pytorch_lightning/trainer/training_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,17 +257,34 @@ def term_handler(self, signum, frame):
# --------------------
# MODEL SAVE CHECKPOINT
# --------------------
def _atomic_save(self, checkpoint, filepath):
"""Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
This will create a temporary checkpoint with a suffix of ``.part``, then copy it to the final location once
saving is finished.
Args:
checkpoint (object): The object to save.
Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save``
accepts.
filepath (str|pathlib.Path): The path to which the checkpoint will be saved.
This points to the file that the checkpoint will be stored in.
"""
tmp_path = str(filepath) + ".part"
torch.save(checkpoint, tmp_path)
os.replace(tmp_path, filepath)

def save_checkpoint(self, filepath):
    """Dump the current checkpoint and write it to ``filepath`` via ``_atomic_save``.

    The checkpoint is obtained from ``self.dump_checkpoint()``. All writing goes through
    ``self._atomic_save`` so the final path is never written to directly.

    Args:
        filepath (str|pathlib.Path): The file the checkpoint will be stored in.
    """
    checkpoint = self.dump_checkpoint()

    # do the actual save
    try:
        self._atomic_save(checkpoint, filepath)
    except AttributeError:
        # NOTE(review): presumably raised when ``hparams`` cannot be serialized - confirm.
        # Retry once without that entry so the rest of the checkpoint is still saved.
        if 'hparams' in checkpoint:
            del checkpoint['hparams']

        self._atomic_save(checkpoint, filepath)

def restore(self, checkpoint_path, on_gpu):

Expand Down Expand Up @@ -415,12 +432,12 @@ def hpc_save(self, folderpath, logger):
# do the actual save
# TODO: fix for anything with multiprocess DP, DDP, DDP2
try:
torch.save(checkpoint, filepath)
self._atomic_save(checkpoint, filepath)
except AttributeError:
if 'hparams' in checkpoint:
del checkpoint['hparams']

torch.save(checkpoint, filepath)
self._atomic_save(checkpoint, filepath)

return filepath

Expand Down

0 comments on commit 9aad69d

Please sign in to comment.