Add non-existing resume_from_checkpoint acceptance for auto-resubmit #4402

Merged
Commits (67)
43f989d
Add empty resume_from_checkpoint acceptance #4366
tarepan Oct 27, 2020
21b60b7
Fix general error catch with focused file check
tarepan Oct 29, 2020
1b29438
Add fsspec HTTP extras
tarepan Oct 29, 2020
697ebc3
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Oct 31, 2020
fb418b8
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Nov 2, 2020
0d9d763
Fix potential too much logging in DDP
tarepan Nov 3, 2020
742ce48
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Nov 3, 2020
a128112
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Nov 5, 2020
238c3e4
Add PR changelog
tarepan Nov 5, 2020
80500c7
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Nov 6, 2020
02be71c
Add well-written argument explanation
tarepan Nov 9, 2020
6db62b2
Fix DDP-compatible restore logging
tarepan Nov 9, 2020
c1186e4
Fix utility import pathes
tarepan Nov 9, 2020
b6a3cd1
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Nov 9, 2020
4bfc6ee
Refactor load step commentaries
tarepan Nov 9, 2020
913ab97
Refactor hpc ckpt suffix acquisition
tarepan Nov 9, 2020
41d2e32
Refactor restore/hpc_load match
tarepan Nov 9, 2020
e0e17b8
Refactor hpc load trial
tarepan Nov 9, 2020
7fbba16
Refactor checkpoint dir check
tarepan Nov 9, 2020
6710e6a
Refactor unneeded function nest
tarepan Nov 9, 2020
882ec2e
Refactor nested If
tarepan Nov 9, 2020
ececdea
Refactor duplicated cache clear
tarepan Nov 9, 2020
5f47685
Refactor attempt flow with if/elif
tarepan Nov 9, 2020
676b4ab
Fix pip8
tarepan Nov 9, 2020
cd2481a
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Nov 12, 2020
3a63c90
Merge branch 'master' into refactor/load
tchaton Nov 16, 2020
30f4f7d
Refactor hook commentary
tarepan Nov 16, 2020
9fb14ac
Fix pep8
tarepan Nov 16, 2020
afcf339
Refactor hpc load checkpoint path acquisition
tarepan Nov 16, 2020
585c761
Fix pip8
tarepan Nov 16, 2020
d76ab46
Fix typo
tarepan Nov 18, 2020
936a186
Fix typo
tarepan Nov 18, 2020
1d3cf0b
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Nov 18, 2020
b994660
Merge branch 'master' into refactor/load
tarepan Nov 27, 2020
a633327
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Nov 27, 2020
0cfba08
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Nov 30, 2020
b466cf6
Merge branch 'master' into refactor/load
tarepan Nov 30, 2020
deeee09
Merge branch 'master' into feature/4366_non_existing_checkpoint
Borda Nov 30, 2020
557f104
Merge remote-tracking branch 'upstream/master' into feature/4366_non_…
tarepan Dec 2, 2020
b26fc83
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
tarepan Dec 2, 2020
f7a65f1
Merge branch 'master' into feature/4366_non_existing_checkpoint
SeanNaren Dec 5, 2020
37f8392
Fix doc
tarepan Dec 8, 2020
b5f980e
Refactor None Union type with Optional
tarepan Dec 8, 2020
104017e
Merge branch 'master' into refactor/load
Borda Dec 8, 2020
e6fbc54
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
tarepan Dec 8, 2020
7134708
Merge branch 'master' into feature/4366_non_existing_checkpoint
SeanNaren Dec 9, 2020
faa6d96
Merge branch 'master' into feature/4366_non_existing_checkpoint
SeanNaren Dec 9, 2020
eb0a716
Merge branch 'master' into feature/4366_non_existing_checkpoint
SeanNaren Dec 9, 2020
8e01141
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Dec 10, 2020
1e692ed
Merge branch 'master' into feature/4366_non_existing_checkpoint
s-rog Dec 12, 2020
4887cd5
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
tarepan Dec 13, 2020
9fcb140
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Dec 13, 2020
79da4ff
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
tarepan Dec 17, 2020
1ac9f6a
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Dec 17, 2020
7a7ec4b
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Dec 20, 2020
5c031f4
Merge branch 'master' into feature/4366_non_existing_checkpoint
Borda Dec 23, 2020
292662f
Merge branch 'master' into feature/4366_non_existing_checkpoint
s-rog Dec 28, 2020
70ea89c
Merge branch 'master' into feature/4366_non_existing_checkpoint
tarepan Jan 2, 2021
77bf8c1
Fix build-doc CI failure debuged in #5329
tarepan Jan 2, 2021
383e40e
Fix fsspec import during build-doc #5329
tarepan Jan 2, 2021
f6eb95a
Fix test epoch
tarepan Jan 2, 2021
743fe31
Fix test with latest test models
tarepan Jan 2, 2021
1444238
.
Borda Jan 4, 2021
b7bdd64
Merge remote-tracking branch 'upstream/master' into feature/4366_non_…
tarepan Jan 4, 2021
9f43e53
Refactor argument doc of resume_from_checkpoint
tarepan Jan 4, 2021
6093989
Fix package extras strip for sphinx
tarepan Jan 4, 2021
ca6e21e
Fix unnesessary dependency for docs
tarepan Jan 4, 2021
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -45,13 +45,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added printing of total num of params, trainable and non-trainable params in ModelSummary ([#4521](https://github.com/PyTorchLightning/pytorch-lightning/pull/4521))


- Added `resume_from_checkpoint` to accept a non-existing file path ([#4402](https://github.com/PyTorchLightning/pytorch-lightning/pull/4402))


### Changed

- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))

- WandbLogger does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648))


### Deprecated

- Deprecated `prefix` argument in `ModelCheckpoint` ([#4765](https://github.com/PyTorchLightning/pytorch-lightning/pull/4765))
25 changes: 18 additions & 7 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -20,7 +20,7 @@
import pytorch_lightning
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
@@ -41,7 +41,7 @@ def __init__(self, trainer):
# used to validate checkpointing logic
self.has_trained = False

def restore_weights(self, model: LightningModule):
def restore_weights(self, model: LightningModule) -> None:
"""
Attempt to restore a checkpoint (e.g. weights) in this priority:
1. from HPC weights
@@ -59,9 +59,8 @@ def restore_weights(self, model: LightningModule):
if self.trainer.on_gpu:
torch.cuda.empty_cache()

if not did_restore_hpc_weights:
if self.trainer.resume_from_checkpoint is not None:
self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)
if (not did_restore_hpc_weights) and (self.trainer.resume_from_checkpoint is not None):
self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)

# wait for all to catch up
self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')
@@ -70,7 +69,7 @@ def restore_weights(self, model: LightningModule):
if self.trainer.on_gpu:
torch.cuda.empty_cache()

def restore(self, checkpoint_path: str, on_gpu: bool):
def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
"""
Load model/training states from the checkpoint file through file-read and state-restore.
Also restores all training state like:
@@ -79,8 +78,17 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
- schedulers
- optimizer
In detail, check return value description of `dump_checkpoint`

Returns:
`True` if restored successfully else `False`
"""

# Try to read the checkpoint file at `checkpoint_path`. If it does not exist, do not restore the checkpoint.
fs = get_filesystem(checkpoint_path)
if not fs.exists(checkpoint_path):
rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
return False

# if on_gpu:
Contributor: Why are those comments there?

Contributor: No idea, but it seems separate from this PR.

Contributor (author): This warning exists as a result of this discussion: #4402 (comment).
There is no strong reason why the warning lives in `restore`; it could be emitted elsewhere (checking the boolean return value of `restore` makes it possible to move the warning).
# checkpoint = torch.load(checkpoint_path)
# else:
@@ -101,6 +109,9 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
# restore training state
self.restore_training_state(checkpoint)

rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
return True

def restore_model_state(self, model: LightningModule, checkpoint) -> None:
"""
Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
@@ -187,7 +198,7 @@ def restore_training_state(self, checkpoint):
for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
scheduler['scheduler'].load_state_dict(lrs_state)

def restore_hpc_weights_if_needed(self, model: LightningModule):
def restore_hpc_weights_if_needed(self, model: LightningModule) -> bool:
"""If there is a set of hpc weights, use as signal to restore model."""
did_restore = False

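As the review thread above notes, the missing-checkpoint warning does not have to live inside `restore`: since `restore` now reports success via its boolean return value, a caller could emit the warning itself. A minimal sketch of that alternative, assuming an illustrative `connector` object (a `CheckpointConnector` instance) and illustrative `ckpt_path`/`on_gpu` values, none of which are part of this diff:

from pytorch_lightning.utilities import rank_zero_warn

def restore_or_warn(connector, ckpt_path: str, on_gpu: bool) -> bool:
    # `restore` returns False when no file exists at `ckpt_path`,
    # so the missing-file warning can be raised here instead of inside `restore`.
    restored = connector.restore(ckpt_path, on_gpu)
    if not restored:
        rank_zero_warn(f"No checkpoint file exists at `{ckpt_path}`. Starting from scratch.")
    return restored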
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -251,6 +251,7 @@ def __init__(

resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
This can be a URL.
If there is no checkpoint file at the specified path, start training from scratch.

sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

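A minimal usage sketch of the behaviour documented above (the path and model below are illustrative, not taken from this diff): when the `resume_from_checkpoint` path is missing, training starts from scratch instead of failing, which is what makes auto-resubmitted jobs work on their first run.

from pytorch_lightning import Trainer

ckpt_path = "checkpoints/last.ckpt"  # illustrative path; may not exist on the first run

trainer = Trainer(max_epochs=2, resume_from_checkpoint=ckpt_path)
# If `ckpt_path` does not exist, a warning is logged and training starts from scratch;
# if it does exist, model/optimizer/scheduler states are restored from it.
trainer.fit(model)  # `model` is an illustrative LightningModule instance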
2 changes: 1 addition & 1 deletion requirements.txt
@@ -6,5 +6,5 @@ future>=0.17.1 # required for builtins in setup.py
# pyyaml>=3.13
PyYAML>=5.1 # OmegaConf requirement >=5.1
tqdm>=4.41.0
fsspec>=0.8.0
fsspec[http]>=0.8.1
tensorboard>=2.2.0
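For context on the switch to `fsspec[http]` (a sketch under assumptions, not part of the diff): the `http` extra installs the optional dependencies (e.g. aiohttp) that fsspec's HTTP filesystem needs, which matters because `get_filesystem` resolves to that filesystem when `resume_from_checkpoint` is a URL. The URL below is illustrative.

import fsspec

# Requires the `http` extra; without it, creating the HTTP filesystem raises ImportError.
fs = fsspec.filesystem("http")
ckpt_url = "https://example.com/checkpoints/last.ckpt"  # illustrative URL
print(fs.exists(ckpt_url))  # the same kind of existence check `restore` performs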
12 changes: 12 additions & 0 deletions tests/models/test_restore.py
@@ -71,6 +71,18 @@ def test_model_properties_resume_from_checkpoint(tmpdir):
trainer.fit(model)


def test_try_resume_from_non_existing_checkpoint(tmpdir):
""" Test that trying to resume from non-existing `resume_from_checkpoint` fail without error."""
model = EvalModelTemplate()
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, logger=False, checkpoint_callback=checkpoint_callback)
# Generate checkpoint `last.ckpt` with template model
trainer.fit(model)
# `restore` returns `True` on a successful resume/restore and `False` otherwise
assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"), trainer.on_gpu)
assert not trainer.checkpoint_connector.restore(str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu)


class CaptureCallbacksBeforeTraining(Callback):
callbacks = []
