From a2d048c20a757778e8e0882aa5e3c258ad65efab Mon Sep 17 00:00:00 2001
From: Frederik Diehl
Date: Wed, 15 Jan 2020 16:18:01 +0100
Subject: [PATCH] Added atomic checkpoint creation

---
 pytorch_lightning/trainer/training_io.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 7410a52371b30d..18f4cea6f6cd58 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -255,17 +255,22 @@ def term_handler(self, signum, frame):
     # --------------------
     # MODEL SAVE CHECKPOINT
     # --------------------
+    def _atomic_save(self, checkpoint, filepath):
+        tmp_path = str(filepath) + ".part"
+        torch.save(checkpoint, tmp_path)
+        os.replace(tmp_path, filepath)
+
     def save_checkpoint(self, filepath):
         checkpoint = self.dump_checkpoint()
 
         # do the actual save
         try:
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
         except AttributeError:
             if 'hparams' in checkpoint:
                 del checkpoint['hparams']
 
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
 
     def restore(self, checkpoint_path, on_gpu):
 
@@ -413,12 +418,12 @@ def hpc_save(self, folderpath, logger):
         # do the actual save
         # TODO: fix for anything with multiprocess DP, DDP, DDP2
         try:
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
         except AttributeError:
             if 'hparams' in checkpoint:
                 del checkpoint['hparams']
 
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
 
         return filepath
 
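For reference, the write-then-rename pattern the patch introduces can be
exercised on its own. A minimal standalone sketch follows; the free function
atomic_save below is illustrative and not part of the patched module:

    import os
    import torch

    def atomic_save(obj, filepath):
        # Write to a temporary sibling file first, so a crash mid-write
        # can never truncate an existing checkpoint at `filepath`.
        tmp_path = str(filepath) + ".part"
        torch.save(obj, tmp_path)
        # os.replace() swaps the destination in a single step; on POSIX
        # the rename is atomic, so readers observe either the old file
        # or the complete new one, never a partially written checkpoint.
        os.replace(tmp_path, filepath)

This relies on the temporary file and the destination living on the same
filesystem; os.replace() may fail with OSError across filesystems. A crash
between the two calls can leave a stale ".part" file behind, which is
harmless but may warrant cleanup.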