From 9aad69d85635a8a65e1f0ee995516c0f8183c0f3 Mon Sep 17 00:00:00 2001
From: Frederik Diehl <3789648+fdiehl@users.noreply.github.com>
Date: Mon, 20 Jan 2020 20:51:44 +0100
Subject: [PATCH] Added atomic checkpoint creation (#689)

* Added atomic checkpoint creation

* Added documentation for _atomic_save
---
 pytorch_lightning/trainer/training_io.py | 25 +++++++++++++++++++++----
 1 file changed, 21 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index b569a6c1499cf..6ea819ba1691c 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -257,17 +257,34 @@ def term_handler(self, signum, frame):
     # --------------------
     # MODEL SAVE CHECKPOINT
     # --------------------
+    def _atomic_save(self, checkpoint, filepath):
+        """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
+
+        This will create a temporary checkpoint with a suffix of ``.part``, then move it to the final location once
+        saving is finished.
+
+        Args:
+            checkpoint (object): The object to save.
+                Built to be used with the ``dump_checkpoint`` method, but accepts anything that ``torch.save``
+                can serialize.
+            filepath (str|pathlib.Path): The path to which the checkpoint will be saved.
+                This is the file the checkpoint will be stored in.
+        """
+        tmp_path = str(filepath) + ".part"
+        torch.save(checkpoint, tmp_path)
+        os.replace(tmp_path, filepath)
+
     def save_checkpoint(self, filepath):
         checkpoint = self.dump_checkpoint()
 
         # do the actual save
         try:
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
         except AttributeError:
             if 'hparams' in checkpoint:
                 del checkpoint['hparams']
 
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
 
     def restore(self, checkpoint_path, on_gpu):
 
@@ -415,12 +432,12 @@ def hpc_save(self, folderpath, logger):
 
         # do the actual save
         # TODO: fix for anything with multiprocess DP, DDP, DDP2
         try:
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
         except AttributeError:
             if 'hparams' in checkpoint:
                 del checkpoint['hparams']
 
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
 
         return filepath
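
The change works because ``os.replace`` performs an atomic rename: readers of
``filepath`` see either the previous complete checkpoint or the new complete
one, never a half-written file, and a crash during ``torch.save`` leaves only
the ``.part`` file behind. A minimal standalone sketch of the pattern follows;
the free-function name ``atomic_save`` and the demo path ``checkpoint.ckpt``
are illustrative only, not part of the PR:

    import os
    import torch

    def atomic_save(obj, filepath):
        """Serialize ``obj`` to ``filepath`` without exposing a partial file."""
        # Write to a temporary sibling first; torch.save can take a while
        # for large objects and may fail part-way through.
        tmp_path = str(filepath) + ".part"
        torch.save(obj, tmp_path)
        # Atomic when tmp_path and filepath live on the same filesystem;
        # silently overwrites any existing checkpoint at filepath.
        os.replace(tmp_path, filepath)

    if __name__ == "__main__":
        atomic_save({"step": 1, "weights": torch.zeros(3)}, "checkpoint.ckpt")
        print(torch.load("checkpoint.ckpt")["step"])  # -> 1

Renaming is preferred over copying here: a copy would recreate the very window
in which a reader can observe a truncated file, while a rename swaps the
directory entry in one step (guaranteed atomic on POSIX within one filesystem).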