Skip to content

Commit

Permalink
change 4/5 - checkpoint_connector
Browse files Browse the repository at this point in the history
  • Loading branch information
tarepan committed Jan 2, 2021
1 parent 9c8acd7 commit e754e43
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, trainer):
# used to validate checkpointing logic
self.has_trained = False

def restore_weights(self, model: LightningModule):
def restore_weights(self, model: LightningModule) -> None:
"""
Attempt to restore a checkpoint (e.g. weights) in this priority:
1. from HPC weights
Expand Down Expand Up @@ -73,12 +73,18 @@ def restore_weights(self, model: LightningModule):
if self.trainer.on_gpu:
torch.cuda.empty_cache()

def restore(self, checkpoint_path: str, on_gpu: bool):
def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
"""
Load model/training states from a 'PyTorch-Lightning checkpoint' file by reading the file and restoring the states.
All restored states are listed in the return value description of `dump_checkpoint`.
"""

# Check whether a checkpoint file exists at `checkpoint_path`. If it does not exist, do not restore the checkpoint.
fs = get_filesystem(checkpoint_path)
if not fs.exists(checkpoint_path):
rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
return False

# read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

Expand All @@ -94,6 +100,9 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
# restore training state
self.restore_training_state(checkpoint)

rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
return True

def restore_model_state(self, model: LightningModule, checkpoint) -> None:
"""
Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
Expand Down

0 comments on commit e754e43

Please sign in to comment.