Skip to content

Commit

Permalink
Single-source training update (ultralytics#680)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Aug 9, 2020
1 parent f7c40c9 commit c96b169
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def train(hyp, opt, device, tb_writer=None):
parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
parser.add_argument('--hyp', type=str, default='data/hyp.finetune.yaml', help='hyperparameters path')
parser.add_argument('--hyp', type=str, default='', help='hyperparameters path, i.e. data/hyp.scratch.yaml')
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
Expand All @@ -396,16 +396,17 @@ def train(hyp, opt, device, tb_writer=None):
opt = parser.parse_args()

# Resume
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
if last and not opt.weights:
print(f'Resuming training from {last}')
opt.weights = last if opt.resume and not opt.weights else opt.weights
if opt.resume:
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
if last and not opt.weights:
print(f'Resuming training from {last}')
opt.weights = last if opt.resume and not opt.weights else opt.weights
if opt.local_rank == -1 or ("RANK" in os.environ and os.environ["RANK"] == "0"):
check_git_status()

opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml')
opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
assert len(opt.hyp), '--hyp must be specified'

opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
device = select_device(opt.device, batch_size=opt.batch_size)
Expand Down

0 comments on commit c96b169

Please sign in to comment.