
[Feature] Resume from the latest checkpoint automatically. #61

Merged: 6 commits, Mar 8, 2022

7 changes: 7 additions & 0 deletions mmrazor/apis/mmcls/train.py
@@ -16,6 +16,7 @@
from mmrazor.core.hooks import DistSamplerSeedHook
from mmrazor.core.optimizer import build_optimizers
from mmrazor.datasets.utils import split_dataset
from mmrazor.utils import find_latest_checkpoint


def set_random_seed(seed, deterministic=False):
@@ -190,6 +191,12 @@ def train_model(model,
runner.register_hook(
eval_hook(val_dataloader, **eval_cfg), priority='LOW')

resume_from = None
if cfg.resume_from is None and cfg.get('auto_resume'):
resume_from = find_latest_checkpoint(cfg.work_dir)
if resume_from is not None:
cfg.resume_from = resume_from

if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
7 changes: 7 additions & 0 deletions mmrazor/apis/mmdet/train.py
@@ -15,6 +15,7 @@
from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
from mmrazor.core.hooks import DistSamplerSeedHook
from mmrazor.core.optimizer import build_optimizers
from mmrazor.utils import find_latest_checkpoint


def set_random_seed(seed, deterministic=False):
@@ -181,6 +182,12 @@ def train_detector(model,
runner.register_hook(
eval_hook(val_dataloader, **eval_cfg), priority='LOW')

resume_from = None
if cfg.resume_from is None and cfg.get('auto_resume'):
resume_from = find_latest_checkpoint(cfg.work_dir)
if resume_from is not None:
cfg.resume_from = resume_from

if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
7 changes: 7 additions & 0 deletions mmrazor/apis/mmseg/train.py
@@ -12,6 +12,7 @@

from mmrazor.core.distributed_wrapper import DistributedDataParallelWrapper
from mmrazor.core.optimizer import build_optimizers
from mmrazor.utils import find_latest_checkpoint


def set_random_seed(seed, deterministic=False):
@@ -137,6 +138,12 @@ def train_segmentor(model,
runner.register_hook(
eval_hook(val_dataloader, **eval_cfg), priority='LOW')

resume_from = None
if cfg.resume_from is None and cfg.get('auto_resume'):
resume_from = find_latest_checkpoint(cfg.work_dir)
if resume_from is not None:
cfg.resume_from = resume_from

if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
3 changes: 2 additions & 1 deletion mmrazor/utils/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .misc import find_latest_checkpoint
from .setup_env import setup_multi_processes

__all__ = ['setup_multi_processes']
__all__ = ['find_latest_checkpoint', 'setup_multi_processes']
38 changes: 38 additions & 0 deletions mmrazor/utils/misc.py
@@ -0,0 +1,38 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
import warnings


def find_latest_checkpoint(path, suffix='pth'):
"""Find the latest checkpoint from the working directory.

Args:
path(str): The path to find checkpoints.
suffix(str): File extension. Defaults to pth.

Returns:
latest_path(str | None): File path of the latest checkpoint.

References:
.. [1] https://github.com/microsoft/SoftTeacher
/blob/main/ssod/utils/patch.py
"""
if not osp.exists(path):
warnings.warn('The path of checkpoints does not exist.')
return None
if osp.exists(osp.join(path, f'latest.{suffix}')):
return osp.join(path, f'latest.{suffix}')

checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
if len(checkpoints) == 0:
warnings.warn('There are no checkpoints in the path.')
return None
latest = -1
latest_path = None
for checkpoint in checkpoints:
count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
if count > latest:
latest = count
latest_path = checkpoint
return latest_path
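
Not part of the diff: a minimal usage sketch (with hypothetical work_dir contents) of the selection rules implemented above. latest.pth wins if present; otherwise the *.pth file with the largest trailing number (epoch_12.pth, iter_8000.pth, ...) is returned; a missing or empty directory yields None with a warning.

import os
import tempfile

from mmrazor.utils import find_latest_checkpoint

with tempfile.TemporaryDirectory() as work_dir:
    # Only numbered checkpoints present: the largest trailing number is picked.
    for name in ('epoch_1.pth', 'epoch_12.pth'):
        open(os.path.join(work_dir, name), 'w').close()
    print(find_latest_checkpoint(work_dir))  # .../epoch_12.pth

    # Once latest.pth exists, it is returned directly.
    open(os.path.join(work_dir, 'latest.pth'), 'w').close()
    print(find_latest_checkpoint(work_dir))  # .../latest.pth
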
43 changes: 43 additions & 0 deletions tests/test_utils/test_misc.py
@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile

from mmrazor.utils import find_latest_checkpoint


def test_find_latest_checkpoint():
with tempfile.TemporaryDirectory() as tmpdir:
path = tmpdir
latest = find_latest_checkpoint(path)
# There are no checkpoints in the path.
assert latest is None

path = tmpdir + '/none'
latest = find_latest_checkpoint(path)
# The path does not exist.
assert latest is None

with tempfile.TemporaryDirectory() as tmpdir:
with open(tmpdir + '/latest.pth', 'w') as f:
f.write('latest')
path = tmpdir
latest = find_latest_checkpoint(path)
assert latest == osp.join(tmpdir, 'latest.pth')

with tempfile.TemporaryDirectory() as tmpdir:
with open(tmpdir + '/iter_4000.pth', 'w') as f:
f.write('iter_4000')
with open(tmpdir + '/iter_8000.pth', 'w') as f:
f.write('iter_8000')
path = tmpdir
latest = find_latest_checkpoint(path)
assert latest == osp.join(tmpdir, 'iter_8000.pth')

with tempfile.TemporaryDirectory() as tmpdir:
with open(tmpdir + '/epoch_1.pth', 'w') as f:
f.write('epoch_1')
with open(tmpdir + '/epoch_2.pth', 'w') as f:
f.write('epoch_2')
path = tmpdir
latest = find_latest_checkpoint(path)
assert latest == osp.join(tmpdir, 'epoch_2.pth')
5 changes: 5 additions & 0 deletions tools/mmcls/train_mmcls.py
@@ -26,6 +26,10 @@ def parse_args():
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
parser.add_argument(
'--auto-resume',
action='store_true',
help='resume from the latest checkpoint automatically')
parser.add_argument(
'--no-validate',
action='store_true',
@@ -101,6 +105,7 @@ def main():
osp.splitext(osp.basename(args.config))[0])
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.auto_resume = args.auto_resume
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '
5 changes: 5 additions & 0 deletions tools/mmdet/train_mmdet.py
@@ -34,6 +34,10 @@ def parse_args():
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
parser.add_argument(
'--auto-resume',
action='store_true',
help='resume from the latest checkpoint automatically')
parser.add_argument(
'--no-validate',
action='store_true',
@@ -112,6 +116,7 @@ def main():
osp.splitext(osp.basename(args.config))[0])
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.auto_resume = args.auto_resume
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '
5 changes: 5 additions & 0 deletions tools/mmseg/train_mmseg.py
@@ -37,6 +37,10 @@ def parse_args():
'--load-from', help='the checkpoint file to load weights from')
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
parser.add_argument(
'--auto-resume',
action='store_true',
help='resume from the latest checkpoint automatically')
parser.add_argument(
'--no-validate',
action='store_true',
@@ -114,6 +118,7 @@ def main():
cfg.load_from = args.load_from
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.auto_resume = args.auto_resume
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '
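
Not part of the diff: a rough sketch of how the new flag flows end to end under these changes, assuming the usual mmcv Config object used by these tools and hypothetical config values. parse_args() exposes --auto-resume, main() copies it onto the config, and the train_* APIs fill in resume_from only when it was not given explicitly, so an explicit --resume-from still takes precedence.

from mmcv import Config

from mmrazor.utils import find_latest_checkpoint

# What main() ends up with after `--auto-resume` (no explicit --resume-from);
# work_dir is a hypothetical example path.
cfg = Config(dict(work_dir='./work_dirs/example', resume_from=None, load_from=None))
cfg.auto_resume = True

# The resolution step added to train_model/train_detector/train_segmentor:
# auto_resume only fires when resume_from is still unset.
if cfg.resume_from is None and cfg.get('auto_resume'):
    latest = find_latest_checkpoint(cfg.work_dir)
    if latest is not None:
        cfg.resume_from = latest  # e.g. './work_dirs/example/epoch_2.pth'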