
Commit

Add empty resume_from_checkpoint acceptance Lightning-AI#4366
tarepan committed Oct 27, 2020
1 parent 4106e2f commit bdca4b7
Showing 3 changed files with 39 additions and 16 deletions.
42 changes: 26 additions & 16 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -52,12 +52,12 @@ def __init__(self, trainer):
# used to validate checkpointing logic
self.has_trained = False

- def restore_weights(self, model: LightningModule):
+ def restore_weights(self, model: LightningModule) -> None:
"""
- We attempt to restore weights in this order:
- 1. HPC weights.
- 2. if no HPC weights restore checkpoint_path weights
- 3. otherwise don't restore weights
+ Attempt to restore state from checkpoint in this priority:
+ 1. HPC weights
+ 2. `resume_from_checkpoint` file
+ 3. don't restore
"""
# clear cache before restore
if self.trainer.on_gpu:
@@ -70,9 +70,8 @@ def restore_weights(self, model: LightningModule):
if self.trainer.on_gpu:
torch.cuda.empty_cache()

- if not did_restore_hpc_weights:
-     if self.trainer.resume_from_checkpoint is not None:
-         self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)
+ if (not did_restore_hpc_weights) and (self.trainer.resume_from_checkpoint is not None):
+     self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)

# wait for all to catch up
self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')
@@ -81,21 +80,29 @@ def restore_weights(self, model: LightningModule):
if self.trainer.on_gpu:
torch.cuda.empty_cache()

- def restore(self, checkpoint_path: str, on_gpu: bool):
+ def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
"""
- Restore training state from checkpoint.
+ Try to restore training state from checkpoint.
Also restores all training state like:
- epoch
- callbacks
- schedulers
- optimizer
+ Returns:
+     `True` if restored successfully else `False`
"""

- # if on_gpu:
- #     checkpoint = torch.load(checkpoint_path)
- # else:
- #     load on CPU first
- checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
+ # Try to load the checkpoint from `checkpoint_path`. If loading fails, do not restore it.
+ try:
+     # if on_gpu:
+     #     checkpoint = torch.load(checkpoint_path)
+     # else:
+     #     load on CPU first
+     checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
+ except Exception as _:
+     log.info(f'failed to load model from checkpoint: {checkpoint_path}')
+     return False

# load model state
model = self.trainer.get_model()
@@ -122,6 +129,9 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
# load training state (affects trainer only)
self.restore_training_state(checkpoint)

+ log.info(f'Model restored from: {checkpoint_path}')
+ return True

def restore_training_state(self, checkpoint):
"""
Restore trainer state.
@@ -186,7 +196,7 @@ def restore_training_state(self, checkpoint):
for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
scheduler['scheduler'].load_state_dict(lrs_state)

- def restore_hpc_weights_if_needed(self, model: LightningModule):
+ def restore_hpc_weights_if_needed(self, model: LightningModule) -> bool:
"""If there is a set of hpc weights, use as signal to restore model."""
did_restore = False

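
For illustration, a minimal sketch of how calling code can consume the new boolean return value of `CheckpointConnector.restore()`; this snippet is not part of the commit, and the checkpoint path and printed message are placeholders.

# Illustrative sketch only: fall back to fresh training when the checkpoint
# passed to `restore()` is missing or cannot be loaded.
from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=2)
# With this patch, `restore()` returns True on success and False on failure
# instead of raising when the file does not exist.
restored = trainer.checkpoint_connector.restore("last.ckpt", on_gpu=trainer.on_gpu)
if not restored:
    # Weights stay freshly initialized; training simply starts from scratch.
    print("No usable checkpoint at 'last.ckpt'; starting from scratch.")
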
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -243,6 +243,7 @@ def __init__(
resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
This can be a URL.
+ If there is no checkpoint file at the specified path, training starts from scratch.
sync_batchnorm: Synchronize batch norm layers between process groups/whole world.
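
For context, a brief usage sketch of the behaviour documented above; it is not part of the commit, and the path and `max_epochs` value are placeholders.

# Illustrative sketch only: `resume_from_checkpoint` may now point at a file
# that does not exist yet (e.g. the `last.ckpt` a previous run would have written).
from pytorch_lightning import Trainer

ckpt_path = "checkpoints/last.ckpt"  # placeholder; may be absent on the first run
trainer = Trainer(
    max_epochs=10,
    resume_from_checkpoint=ckpt_path,  # missing file: training starts from scratch
)
# trainer.fit(model)  # `model` would be any LightningModule
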
12 changes: 12 additions & 0 deletions tests/models/test_restore.py
@@ -69,6 +69,18 @@ def test_resume_from_checkpoint(tmpdir):
trainer.fit(model)


+ def test_try_resume_from_non_existing_checkpoint(tmpdir):
+     """Test that trying to resume from a non-existing `resume_from_checkpoint` fails gracefully, without raising an error."""
+     model = EvalModelTemplate()
+     checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
+     trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, logger=False, checkpoint_callback=checkpoint_callback)
+     # Generate checkpoint `last.ckpt` with template model
+     trainer.fit(model)
+     # `restore()` returns `True` on success and `False` on failure
+     assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"), trainer.on_gpu)
+     assert not trainer.checkpoint_connector.restore(str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_running_test_pretrained_model_distrib_dp(tmpdir):
"""Verify `test()` on pretrained model."""
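
Assuming the repository's standard pytest setup, the new test can be run in isolation with `python -m pytest tests/models/test_restore.py::test_try_resume_from_non_existing_checkpoint`.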
