diff --git a/train.py b/train.py index c7bd9e6bf3da..39b996c34a38 100644 --- a/train.py +++ b/train.py @@ -372,7 +372,7 @@ def train(hyp, opt, device, tb_writer=None): parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path') parser.add_argument('--cfg', type=str, default='', help='model.yaml path') parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path') - parser.add_argument('--hyp', type=str, default='data/hyp.finetune.yaml', help='hyperparameters path') + parser.add_argument('--hyp', type=str, default='', help='hyperparameters path, i.e. data/hyp.scratch.yaml') parser.add_argument('--epochs', type=int, default=300) parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs') parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes') @@ -396,16 +396,17 @@ def train(hyp, opt, device, tb_writer=None): opt = parser.parse_args() # Resume - last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run - if last and not opt.weights: - print(f'Resuming training from {last}') - opt.weights = last if opt.resume and not opt.weights else opt.weights + if opt.resume: + last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run + if last and not opt.weights: + print(f'Resuming training from {last}') + opt.weights = last if opt.resume and not opt.weights else opt.weights if opt.local_rank == -1 or ("RANK" in os.environ and os.environ["RANK"] == "0"): check_git_status() + opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml') opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified' - assert len(opt.hyp), '--hyp must be specified' opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) device = select_device(opt.device, batch_size=opt.batch_size)