Self-contained checkpoint --resume (#8839)
* Single checkpoint resume

* Update train.py

* Add hyp

* Add hyp

* Add hyp

* FIX

* avoid resume on url data

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* avoid resume on url data

* avoid resume on url data

* Update

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
glenn-jocher and pre-commit-ci[bot] authored Aug 3, 2022
1 parent 4d8d84b commit a75a110
Showing 3 changed files with 27 additions and 15 deletions.
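
In short: train.py now stores the full set of training options inside every checkpoint ('opt': vars(opt), with the hyperparameters copied into opt.hyp beforehand), and main() falls back to those stored options when the run directory's opt.yaml is missing, so `python train.py --resume path/to/last.pt` can work from the checkpoint file alone. A minimal sketch of inspecting such a self-contained checkpoint (the path is a placeholder and the printed fields assume the standard YOLOv5 opt/hyp keys):

import torch

# Hypothetical checkpoint written by this version of train.py
ckpt = torch.load('runs/train/exp/weights/last.pt', map_location='cpu')

opt = ckpt['opt']   # training options dict (vars(opt) captured at save time)
hyp = opt['hyp']    # hyperparameters dict, copied into opt before saving
print(opt['data'], opt['epochs'], hyp['lr0'])  # dataset yaml, total epochs, initial learning rate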
27 changes: 18 additions & 9 deletions train.py
@@ -43,7 +43,7 @@
 from utils.autobatch import check_train_batch_size
 from utils.callbacks import Callbacks
 from utils.dataloaders import create_dataloader
-from utils.downloads import attempt_download
+from utils.downloads import attempt_download, is_url
 from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
                            check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
                            init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
@@ -77,6 +77,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
         with open(hyp, errors='ignore') as f:
             hyp = yaml.safe_load(f)  # load hyps dict
     LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
+    opt.hyp = hyp.copy()  # for saving hyps to checkpoints

     # Save run settings
     if not evolve:
@@ -377,6 +378,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
                     'updates': ema.updates,
                     'optimizer': optimizer.state_dict(),
                     'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
+                    'opt': vars(opt),
                     'date': datetime.now().isoformat()}

                 # Save last, best and delete
@@ -472,8 +474,7 @@ def parse_opt(known=False):
     parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
     parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')

-    opt = parser.parse_known_args()[0] if known else parser.parse_args()
-    return opt
+    return parser.parse_known_args()[0] if known else parser.parse_args()


 def main(opt, callbacks=Callbacks()):
@@ -484,12 +485,20 @@ def main(opt, callbacks=Callbacks()):
         check_requirements(exclude=['thop'])

     # Resume
-    if opt.resume and not check_wandb_resume(opt) and not opt.evolve:  # resume an interrupted run
-        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
-        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
-        with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
-            opt = argparse.Namespace(**yaml.safe_load(f))  # replace
-        opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
+    if opt.resume and not (check_wandb_resume(opt) or opt.evolve):  # resume an interrupted run
+        last = Path(opt.resume if isinstance(opt.resume, str) else get_latest_run())  # specified or most recent last.pt
+        assert last.is_file(), f'ERROR: --resume checkpoint {last} does not exist'
+        opt_yaml = last.parent.parent / 'opt.yaml'  # train options yaml
+        opt_data = opt.data  # original dataset
+        if opt_yaml.is_file():
+            with open(opt_yaml, errors='ignore') as f:
+                d = yaml.safe_load(f)
+        else:
+            d = torch.load(last, map_location='cpu')['opt']
+        opt = argparse.Namespace(**d)  # replace
+        opt.cfg, opt.weights, opt.resume = '', str(last), True  # reinstate
+        if is_url(opt.data):
+            opt.data = str(opt_data)  # avoid HUB resume auth timeout
     else:
         opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
             check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)  # checks
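The final is_url() guard above corresponds to the "avoid resume on url data" commits: a run started from a URL dataset (e.g. an Ultralytics HUB dataset) would otherwise try to re-resolve that URL on resume and hit an auth timeout, so the restored options keep the dataset argument from the current invocation instead. A small sketch of that branch in isolation (both dataset values are made-up placeholders, and the script is assumed to run from the repository root):

from utils.downloads import is_url

restored_data = 'https://example.com/dataset.zip'  # opt.data recovered from opt.yaml / the checkpoint
current_data = 'data/coco128.yaml'                 # opt.data passed on this --resume invocation

# Same guard as above; check_online=False here because the placeholder URL is not real,
# whereas train.py uses the default, which also requires the URL to respond.
data = current_data if is_url(restored_data, check_online=False) else restored_data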
10 changes: 6 additions & 4 deletions utils/downloads.py
@@ -16,12 +16,14 @@
 import torch


-def is_url(url):
+def is_url(url, check_online=True):
     # Check if online file exists
     try:
-        r = urllib.request.urlopen(url)  # response
-        return r.getcode() == 200
-    except urllib.request.HTTPError:
+        url = str(url)
+        result = urllib.parse.urlparse(url)
+        assert all([result.scheme, result.netloc, result.path])  # check if is url
+        return (urllib.request.urlopen(url).getcode() == 200) if check_online else True  # check if exists online
+    except (AssertionError, urllib.request.HTTPError):
         return False


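is_url() now validates the string syntactically first (scheme, host and path must all be present) and only then, if check_online=True, verifies that the URL answers with HTTP 200; check_online=False skips the network round-trip. A few illustrative calls (the URL is a placeholder, so the online result depends entirely on the server):

from utils.downloads import is_url

is_url('https://example.com/weights/yolov5s.pt', check_online=False)  # True: well-formed URL, no network call
is_url('runs/train/exp/weights/last.pt', check_online=False)          # False: local path, urlparse yields no scheme/netloc
is_url('https://example.com/weights/yolov5s.pt')                      # True only if the server returns HTTP 200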
5 changes: 3 additions & 2 deletions utils/torch_utils.py
@@ -317,8 +317,9 @@ def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
         ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
         ema.updates = ckpt['updates']
     if resume:
-        assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
-        LOGGER.info(f'Resuming training from {weights} for {epochs - start_epoch} more epochs to {epochs} total epochs')
+        assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.\n' \
+                                f"Start a new training without --resume, i.e. 'python train.py --weights {weights}'"
+        LOGGER.info(f'Resuming training from {weights} from epoch {start_epoch} to {epochs} total epochs')
     if epochs < start_epoch:
         LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
         epochs += ckpt['epoch']  # finetune additional epochs
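The improved messages rest on the epoch bookkeeping smart_resume already does: start_epoch is ckpt['epoch'] + 1, so a finished checkpoint (whose epoch is set to -1 when it is stripped at the end of training) gives start_epoch == 0 and trips the new assertion, while requesting fewer total epochs than the checkpoint has already trained switches into the fine-tune branch. A small worked example with made-up numbers:

# Illustrative values only
ckpt_epoch = 149                 # 'epoch' stored in the checkpoint (last completed epoch)
start_epoch = ckpt_epoch + 1     # 150, as smart_resume computes it
epochs = 300                     # --epochs requested on the command line

assert start_epoch > 0           # would fail for a finished/stripped checkpoint (ckpt_epoch == -1)

if epochs < start_epoch:         # e.g. --epochs 100 with the same checkpoint
    epochs += ckpt_epoch         # 100 + 149 = 249 total, i.e. roughly 100 further epochs from epoch 150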
