diff --git a/data/hyp.finetune.yaml b/data/hyp.finetune.yaml index fe9cd55019f7..1b84cff95c2c 100644 --- a/data/hyp.finetune.yaml +++ b/data/hyp.finetune.yaml @@ -15,7 +15,7 @@ weight_decay: 0.00036 warmup_epochs: 2.0 warmup_momentum: 0.5 warmup_bias_lr: 0.05 -giou: 0.0296 +box: 0.0296 cls: 0.243 cls_pw: 0.631 obj: 0.301 diff --git a/data/hyp.scratch.yaml b/data/hyp.scratch.yaml index 9f53e86dd3ab..43354316c095 100644 --- a/data/hyp.scratch.yaml +++ b/data/hyp.scratch.yaml @@ -10,7 +10,7 @@ weight_decay: 0.0005 # optimizer weight decay 5e-4 warmup_epochs: 3.0 # warmup epochs (fractions ok) warmup_momentum: 0.8 # warmup initial momentum warmup_bias_lr: 0.1 # warmup initial bias lr -giou: 0.05 # box loss gain +box: 0.05 # box loss gain cls: 0.5 # cls loss gain cls_pw: 1.0 # cls BCELoss positive_weight obj: 1.0 # obj loss gain (scale with pixels) diff --git a/sotabench.py b/sotabench.py index daef5168b213..96ea6bffcbb0 100644 --- a/sotabench.py +++ b/sotabench.py @@ -113,7 +113,7 @@ def test(data, # Compute loss if training: # if model has loss hyperparameters - loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # GIoU, obj, cls + loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls # Run NMS t = time_synchronized() diff --git a/test.py b/test.py index e0bb7726f7d1..9e79a769f884 100644 --- a/test.py +++ b/test.py @@ -106,7 +106,7 @@ def test(data, # Compute loss if training: # if model has loss hyperparameters - loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # GIoU, obj, cls + loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls # Run NMS t = time_synchronized() diff --git a/train.py b/train.py index 4060a5701a8b..bbb69c7e2b53 100644 --- a/train.py +++ b/train.py @@ -195,7 +195,7 @@ def train(hyp, opt, device, tb_writer=None): hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset model.nc = nc # attach number of classes to model model.hyp = hyp # attach hyperparameters to model - model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou) + model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou) model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights model.names = names @@ -204,7 +204,7 @@ def train(hyp, opt, device, tb_writer=None): nw = max(round(hyp['warmup_epochs'] * nb), 1e3) # number of warmup iterations, max(3 epochs, 1k iterations) # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training maps = np.zeros(nc) # mAP per class - results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification' + results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) scheduler.last_epoch = start_epoch - 1 # do not move scaler = amp.GradScaler(enabled=cuda) logger.info('Image sizes %g train, %g test\nUsing %g dataloader workers\nLogging results to %s\n' @@ -234,7 +234,7 @@ def train(hyp, opt, device, tb_writer=None): if rank != -1: dataloader.sampler.set_epoch(epoch) pbar = enumerate(dataloader) - logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size')) + logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size')) if rank in [-1, 0]: pbar = tqdm(pbar, total=nb) # progress bar optimizer.zero_grad() @@ -245,7 +245,7 @@ def train(hyp, opt, device, tb_writer=None): # Warmup if ni <= nw: xi = [0, nw] # x interp - # model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou) + # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou) accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round()) for j, x in enumerate(optimizer.param_groups): # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 @@ -319,21 +319,21 @@ def train(hyp, opt, device, tb_writer=None): # Write with open(results_file, 'a') as f: - f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls) + f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) if len(opt.name) and opt.bucket: os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name)) # Tensorboard if tb_writer: - tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', # train loss + tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', - 'val/giou_loss', 'val/obj_loss', 'val/cls_loss', # val loss + 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss 'x/lr0', 'x/lr1', 'x/lr2'] # params for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags): tb_writer.add_scalar(tag, x, epoch) # Update best mAP - fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1] + fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95] if fi > best_fitness: best_fitness = fi @@ -463,7 +463,7 @@ def train(hyp, opt, device, tb_writer=None): 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok) 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr - 'giou': (1, 0.02, 0.2), # GIoU loss gain + 'box': (1, 0.02, 0.2), # box loss gain 'cls': (1, 0.2, 4.0), # cls loss gain 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels) diff --git a/utils/general.py b/utils/general.py index da016631e589..2530b10efb11 100755 --- a/utils/general.py +++ b/utils/general.py @@ -509,11 +509,11 @@ def compute_loss(p, targets, model): # predictions, targets, model pxy = ps[:, :2].sigmoid() * 2. - 0.5 pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i] pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box - giou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # giou(prediction, target) - lbox += (1.0 - giou).mean() # giou loss + iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target) + lbox += (1.0 - iou).mean() # iou loss # Objectness - tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype) # giou ratio + tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio # Classification if model.nc > 1: # cls loss (only if multiple classes) @@ -528,7 +528,7 @@ def compute_loss(p, targets, model): # predictions, targets, model lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss s = 3 / np # output count scaling - lbox *= h['giou'] * s + lbox *= h['box'] * s lobj *= h['obj'] * s * (1.4 if np == 4 else 1.) lcls *= h['cls'] * s bs = tobj.shape[0] # batch size @@ -1234,7 +1234,7 @@ def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.general im def plot_results_overlay(start=0, stop=0): # from utils.general import *; plot_results_overlay() # Plot training 'results*.txt', overlaying train and val losses s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends - t = ['GIoU', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles + t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')): results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T n = results.shape[1] # number of rows @@ -1254,13 +1254,13 @@ def plot_results_overlay(start=0, stop=0): # from utils.general import *; plot_ fig.savefig(f.replace('.txt', '.png'), dpi=200) -def plot_results(start=0, stop=0, bucket='', id=(), labels=(), - save_dir=''): # from utils.general import *; plot_results() +def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): + # from utils.general import *; plot_results() # Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training fig, ax = plt.subplots(2, 5, figsize=(12, 6)) ax = ax.ravel() - s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall', - 'val GIoU', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95'] + s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall', + 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95'] if bucket: # os.system('rm -rf storage.googleapis.com') # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]