From 0df3b970cf755ae27c4099324c5ccb3c1a1884a3 Mon Sep 17 00:00:00 2001 From: AyushExel Date: Wed, 7 Oct 2020 20:03:46 +0530 Subject: [PATCH 1/7] Add wandb metric logging and bounding box debugging --- test.py | 32 ++++++++++++++++++++++++++++++-- train.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 78 insertions(+), 8 deletions(-) diff --git a/test.py b/test.py index 3bcbbd86f442..1f83d882a013 100644 --- a/test.py +++ b/test.py @@ -32,7 +32,12 @@ def test(data, dataloader=None, save_dir='', merge=False, - save_txt=False): + save_txt=False, + bbox_debug=0): + # Import wandb if logging is enabled + if bbox_debug > 0: + import wandb + # Initialize/load model and set device training = model is not None if training: # called by train.py @@ -88,7 +93,7 @@ def test(data, s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95') p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0. loss = torch.zeros(3, device=device) - jdict, stats, ap, ap_class = [], [], [], [] + jdict, stats, ap, ap_class, wandb_image_log = [], [], [], [], [] for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)): img = img.to(device, non_blocking=True) img = img.half() if half else img.float() # uint8 to fp16/32 @@ -135,6 +140,25 @@ def test(data, with open(str(out / Path(paths[si]).stem) + '.txt', 'a') as f: f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format + # Log images with bounding boxes + if len(wandb_image_log) < bbox_debug: + x = pred.clone() + bbox_data = [{ + "position": { + "minX": float(xyxy[0]), + "minY": float(xyxy[1]), + "maxX": float(xyxy[2]), + "maxY": float(xyxy[3]) + }, + "class_id": int(cls), + "scores": { + "class_score": float(conf) + }, + "domain":"pixel" + } for *xyxy, conf, cls in x] + im = wandb.Image(img[si], boxes={"predictions": {"box_data":bbox_data}}) + wandb_image_log.append(im) + # Clip boxes to image bounds clip_coords(pred, (height, width)) @@ -192,6 +216,10 @@ def test(data, f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i) plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions + # Log the images to W&B + if len(wandb_image_log) > 0: + wandb.log({"outputs":wandb_image_log}) + # Compute statistics stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy if len(stats) and stats[0].any(): diff --git a/train.py b/train.py index c6fa31107a02..4f3178ce75c8 100644 --- a/train.py +++ b/train.py @@ -32,8 +32,28 @@ logger = logging.getLogger(__name__) +try: + import wandb + wandb_disabled = os.environ['WANDB_DISABLED'] if 'WANDB_DISABLED' in os.environ else None + if wandb_disabled is True: + wandb_log = False + else: + wandb_log = True + print("Automatic Weights & Biases logging enabled, to disable set os.environ['WANDB_DISABLED'] = 'true'") +except ImportError: + wandb_log = False + print("wandb is not installed. 
Install wandb using 'pip install wandb' to track your experiments and enable bounding box debugging") + def train(hyp, opt, device, tb_writer=None): + if wandb_log and not opt.resume: + name = opt.name if opt.name != '' else 'yoloV5' + run = wandb.init(project=name, config=opt) + # Do not log bounding box images if wandb is not initialized + if not wandb_log: + print("Setting num_bbox to 0") + opt.num_bbox = 0 + logger.info(f'Hyperparameters {hyp}') log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory wdir = log_dir / 'weights' # weights directory @@ -79,6 +99,20 @@ def train(hyp, opt, device, tb_writer=None): else: model = Model(opt.cfg, ch=3, nc=nc).to(device) # create + # Resume Logging in the same W&B run + if wandb_log and opt.resume: + if 'wandb_id' in ckpt: + try: + run = wandb.init(id=ckpt['wandb_id'],resume='must') + print('Resuming wandb logging') + except KeyError: + print('wandb run cannot be resumed, creating a new run') + + if wandb.run is None: + name = opt.name if opt.name != '' else 'yoloV5' + run = wandb.init(project=name, config=opt) + print('wandb logging enabled') + # Freeze freeze = ['', ] # parameter names to freeze (full or partial) if any(freeze): @@ -317,7 +351,8 @@ def train(hyp, opt, device, tb_writer=None): model=ema.ema, single_cls=opt.single_cls, dataloader=testloader, - save_dir=log_dir) + save_dir=log_dir, + bbox_debug=opt.num_bbox) # Write with open(results_file, 'a') as f: @@ -325,15 +360,20 @@ def train(hyp, opt, device, tb_writer=None): if len(opt.name) and opt.bucket: os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name)) + tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', # train loss + 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', + 'val/giou_loss', 'val/obj_loss', 'val/cls_loss', # val loss + 'x/lr0', 'x/lr1', 'x/lr2'] # params # Tensorboard if tb_writer: - tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', # train loss - 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', - 'val/giou_loss', 'val/obj_loss', 'val/cls_loss', # val loss - 'x/lr0', 'x/lr1', 'x/lr2'] # params for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags): tb_writer.add_scalar(tag, x, epoch) + # W&B logging + if wandb_log: + for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags): + wandb.log({tag:x}) + # Update best mAP fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1] if fi > best_fitness: @@ -347,7 +387,8 @@ def train(hyp, opt, device, tb_writer=None): 'best_fitness': best_fitness, 'training_results': f.read(), 'model': ema.ema, - 'optimizer': None if final_epoch else optimizer.state_dict()} + 'optimizer': None if final_epoch else optimizer.state_dict(), + 'wandb_id': run.id if wandb_log else None} # Save last, best and delete torch.save(ckpt, last) @@ -404,6 +445,7 @@ def train(hyp, opt, device, tb_writer=None): parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') parser.add_argument('--logdir', type=str, default='runs/', help='logging directory') parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers') + parser.add_argument('--num-bbox', type=int, default=50, help='maximum number of images logged to W&B for bounding box debugging') opt = parser.parse_args() # Set DDP variables From 9ef29eb017a65e92ac52385e8e501e52124de14c Mon Sep 17 00:00:00 2001 From: 
AyushExel Date: Thu, 8 Oct 2020 12:05:11 +0530 Subject: [PATCH 2/7] Improve formatting, readability --- test.py | 8 +++++--- train.py | 33 +++++++++++++-------------------- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/test.py b/test.py index 1f83d882a013..3b9f8813bac0 100644 --- a/test.py +++ b/test.py @@ -33,10 +33,12 @@ def test(data, save_dir='', merge=False, save_txt=False, - bbox_debug=0): + num_predictions=0): # Import wandb if logging is enabled - if bbox_debug > 0: + if num_predictions > 0: import wandb + if num_predictions > 100: + num_predictions = 100 # Initialize/load model and set device training = model is not None @@ -141,7 +143,7 @@ def test(data, f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format # Log images with bounding boxes - if len(wandb_image_log) < bbox_debug: + if len(wandb_image_log) < num_predictions: x = pred.clone() bbox_data = [{ "position": { diff --git a/train.py b/train.py index 4f3178ce75c8..6cfbda476c85 100644 --- a/train.py +++ b/train.py @@ -15,7 +15,7 @@ import torch.optim.lr_scheduler as lr_scheduler import torch.utils.data import yaml -from torch.cuda import amp +#from torch.cuda import amp from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm @@ -34,8 +34,8 @@ try: import wandb - wandb_disabled = os.environ['WANDB_DISABLED'] if 'WANDB_DISABLED' in os.environ else None - if wandb_disabled is True: + wandb_disabled = os.environ.get('WANDB_DISABLED') + if wandb_disabled == 'true': wandb_log = False else: wandb_log = True @@ -46,13 +46,13 @@ def train(hyp, opt, device, tb_writer=None): - if wandb_log and not opt.resume: - name = opt.name if opt.name != '' else 'yoloV5' - run = wandb.init(project=name, config=opt) + if wandb_log and not opt.resume and wandb.run is None: + project = opt.name if opt.name != '' else 'yoloV5' + run = wandb.init(project=project, config=opt) # Do not log bounding box images if wandb is not initialized if not wandb_log: - print("Setting num_bbox to 0") - opt.num_bbox = 0 + print("Setting num_predictions to 0") + opt.num_predictions = 0 logger.info(f'Hyperparameters {hyp}') log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory @@ -101,17 +101,9 @@ def train(hyp, opt, device, tb_writer=None): # Resume Logging in the same W&B run if wandb_log and opt.resume: - if 'wandb_id' in ckpt: - try: - run = wandb.init(id=ckpt['wandb_id'],resume='must') - print('Resuming wandb logging') - except KeyError: - print('wandb run cannot be resumed, creating a new run') - if wandb.run is None: - name = opt.name if opt.name != '' else 'yoloV5' - run = wandb.init(project=name, config=opt) - print('wandb logging enabled') + project = opt.name if opt.name != '' else 'yoloV5' + run = wandb.init(id=ckpt['wandb_id'], resume="allow", project=project, config=opt) # Freeze freeze = ['', ] # parameter names to freeze (full or partial) @@ -352,7 +344,7 @@ def train(hyp, opt, device, tb_writer=None): single_cls=opt.single_cls, dataloader=testloader, save_dir=log_dir, - bbox_debug=opt.num_bbox) + num_predictions=opt.num_predictions) # Write with open(results_file, 'a') as f: @@ -445,7 +437,8 @@ def train(hyp, opt, device, tb_writer=None): parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') parser.add_argument('--logdir', type=str, default='runs/', help='logging directory') parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers') 
- parser.add_argument('--num-bbox', type=int, default=50, help='maximum number of images logged to W&B for bounding box debugging') + parser.add_argument('--num-predictions', type=int, default=50, help='number of images logged to W&B for bounding box debugging. Maximum limit is 100') + opt = parser.parse_args() # Set DDP variables From ff7b1c99be38a012438d6053429440e361735b5b Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 8 Oct 2020 08:01:49 +0000 Subject: [PATCH 3/7] Remove mutliple path for init, improve formatting --- train.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/train.py b/train.py index 6cfbda476c85..8f9df42ccd6a 100644 --- a/train.py +++ b/train.py @@ -15,7 +15,7 @@ import torch.optim.lr_scheduler as lr_scheduler import torch.utils.data import yaml -#from torch.cuda import amp +from torch.cuda import amp from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm @@ -46,14 +46,6 @@ def train(hyp, opt, device, tb_writer=None): - if wandb_log and not opt.resume and wandb.run is None: - project = opt.name if opt.name != '' else 'yoloV5' - run = wandb.init(project=project, config=opt) - # Do not log bounding box images if wandb is not initialized - if not wandb_log: - print("Setting num_predictions to 0") - opt.num_predictions = 0 - logger.info(f'Hyperparameters {hyp}') log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory wdir = log_dir / 'weights' # weights directory @@ -98,12 +90,16 @@ def train(hyp, opt, device, tb_writer=None): logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report else: model = Model(opt.cfg, ch=3, nc=nc).to(device) # create - - # Resume Logging in the same W&B run - if wandb_log and opt.resume: - if wandb.run is None: - project = opt.name if opt.name != '' else 'yoloV5' - run = wandb.init(id=ckpt['wandb_id'], resume="allow", project=project, config=opt) + + # Initialize wandb + if wandb_log and wandb.run is None: + project = opt.name if opt.name != '' else 'yoloV5' + id = ckpt.get('wandb_id') if 'ckpt' in locals() else None + run = wandb.init(id=id, resume="allow", project=project, config=opt) + # Do not log bounding box images if wandb is not initialized + if not wandb_log: + print("Setting num_predictions to 0") + opt.num_predictions = 0 # Freeze freeze = ['', ] # parameter names to freeze (full or partial) From 7d187215c01b4f05ab3bab4754f5ae54336bbd5a Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Fri, 9 Oct 2020 14:35:34 +0000 Subject: [PATCH 4/7] Fix typo --- test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test.py b/test.py index 7430fb67d8c4..021c1819adce 100644 --- a/test.py +++ b/test.py @@ -32,7 +32,7 @@ def test(data, dataloader=None, save_dir=Path(''), # for saving images save_txt=False, # for auto-labelling - plots=True + plots=True, num_predictions=0): # Import wandb if logging is enabled if num_predictions > 0: From 1a09bc2e1325844b15389b6ec2d0ac78e1bf6d85 Mon Sep 17 00:00:00 2001 From: AyushExel Date: Wed, 28 Oct 2020 22:17:12 +0530 Subject: [PATCH 5/7] Fix argument conflicts --- test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test.py b/test.py index 021c1819adce..93d90dfd4be9 100644 --- a/test.py +++ b/test.py @@ -32,6 +32,7 @@ def test(data, dataloader=None, save_dir=Path(''), # for saving images save_txt=False, # for auto-labelling + save_conf=False, plots=True, 
num_predictions=0): # Import wandb if logging is enabled From 572a62756a75175cd7f23b9456d3ea2040be93fd Mon Sep 17 00:00:00 2001 From: AyushExel Date: Wed, 28 Oct 2020 22:27:02 +0530 Subject: [PATCH 6/7] Remove conflicts --- test.py | 49 ++++++++++++++++++----------- train.py | 96 ++++++++++++++++++++++++++++++++++---------------------- 2 files changed, 89 insertions(+), 56 deletions(-) diff --git a/test.py b/test.py index 93d90dfd4be9..d094c8a35270 100644 --- a/test.py +++ b/test.py @@ -50,15 +50,17 @@ def test(data, set_logging() device = select_device(opt.device, batch_size=batch_size) save_txt = opt.save_txt # save *.txt labels - if save_txt: - out = Path('inference/output') - if os.path.exists(out): - shutil.rmtree(out) # delete output folder - os.makedirs(out) # make new output folder # Remove previous - for f in glob.glob(str(save_dir / 'test_batch*.jpg')): - os.remove(f) + if os.path.exists(save_dir): + shutil.rmtree(save_dir) # delete dir + os.makedirs(save_dir) # make new dir + + if save_txt: + out = save_dir / 'autolabels' + if os.path.exists(out): + shutil.rmtree(out) # delete dir + os.makedirs(out) # make new dir # Load model model = attempt_load(weights, map_location=device) # load FP32 model @@ -114,7 +116,7 @@ def test(data, # Compute loss if training: # if model has loss hyperparameters - loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # GIoU, obj, cls + loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls # Run NMS t = time_synchronized() @@ -140,11 +142,12 @@ def test(data, x[:, :4] = scale_coords(img[si].shape[1:], x[:, :4], shapes[si][0], shapes[si][1]) # to original for *xyxy, conf, cls in x: xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh + line = (cls, conf, *xywh) if save_conf else (cls, *xywh) # label format with open(str(out / Path(paths[si]).stem) + '.txt', 'a') as f: - f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format + f.write(('%g ' * len(line) + '\n') % line) # Log images with bounding boxes - if len(wandb_image_log) < num_predictions: + if len(wandb_image_log) < bbox_debug: x = pred.clone() bbox_data = [{ "position": { @@ -214,9 +217,9 @@ def test(data, # Plot images if plots and batch_i < 1: - f = save_dir / ('test_batch%g_gt.jpg' % batch_i) # filename + f = save_dir / f'test_batch{batch_i}_gt.jpg' # filename plot_images(img, targets, paths, str(f), names) # ground truth - f = save_dir / ('test_batch%g_pred.jpg' % batch_i) + f = save_dir / f'test_batch{batch_i}_pred.jpg' plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions # Log the images to W&B @@ -249,11 +252,11 @@ def test(data, # Save JSON if save_json and len(jdict): - f = 'detections_val2017_%s_results.json' % \ - (weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '') # filename - print('\nCOCO mAP with pycocotools... saving %s...' % f) - with open(f, 'w') as file: - json.dump(jdict, file) + w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights + file = save_dir / f"detections_val2017_{w}_results.json" # predicted annotations file + print('\nCOCO mAP with pycocotools... saving %s...' 
% file) + with open(file, 'w') as f: + json.dump(jdict, f) try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb from pycocotools.coco import COCO @@ -261,7 +264,7 @@ def test(data, imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files] cocoGt = COCO(glob.glob('../coco/annotations/instances_val*.json')[0]) # initialize COCO ground truth api - cocoDt = cocoGt.loadRes(f) # initialize COCO pred api + cocoDt = cocoGt.loadRes(str(file)) # initialize COCO pred api cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') cocoEval.params.imgIds = imgIds # image IDs to evaluate cocoEval.evaluate() @@ -294,6 +297,8 @@ def test(data, parser.add_argument('--augment', action='store_true', help='augmented inference') parser.add_argument('--verbose', action='store_true', help='report mAP by class') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') + parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') + parser.add_argument('--save-dir', type=str, default='runs/test', help='directory to save results') opt = parser.parse_args() opt.save_json |= opt.data.endswith('coco.yaml') opt.data = check_file(opt.data) # check file @@ -309,7 +314,13 @@ def test(data, opt.save_json, opt.single_cls, opt.augment, - opt.verbose) + opt.verbose, + save_dir=Path(opt.save_dir), + save_txt=opt.save_txt, + save_conf=opt.save_conf, + ) + + print('Results saved to %s' % opt.save_dir) elif opt.task == 'study': # run over a range of settings and save/plot for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']: diff --git a/train.py b/train.py index c82619fb7c90..de48236a337b 100644 --- a/train.py +++ b/train.py @@ -1,12 +1,13 @@ import argparse import logging -import math import os import random import shutil import time from pathlib import Path +from warnings import warn +import math import numpy as np import torch.distributed as dist import torch.nn.functional as F @@ -33,8 +34,8 @@ try: import wandb - wandb_disabled = os.environ.get('WANDB_DISABLED') - if wandb_disabled == 'true': + wandb_disabled = os.environ['WANDB_DISABLED'] if 'WANDB_DISABLED' in os.environ else None + if wandb_disabled is True: wandb_log = False else: wandb_log = True @@ -45,6 +46,14 @@ def train(hyp, opt, device, tb_writer=None): + if wandb_log and not opt.resume: + name = opt.name if opt.name != '' else 'yoloV5' + run = wandb.init(project=name, config=opt) + # Do not log bounding box images if wandb is not initialized + if not wandb_log: + print("Setting num_bbox to 0") + opt.num_bbox = 0 + logger.info(f'Hyperparameters {hyp}') log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory wdir = log_dir / 'weights' # weights directory @@ -89,16 +98,20 @@ def train(hyp, opt, device, tb_writer=None): logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report else: model = Model(opt.cfg, ch=3, nc=nc).to(device) # create - - # Initialize wandb - if wandb_log and wandb.run is None: - project = opt.name if opt.name != '' else 'yoloV5' - id = ckpt.get('wandb_id') if 'ckpt' in locals() else None - run = wandb.init(id=id, resume="allow", project=project, config=opt) - # Do not log bounding box images if wandb is not initialized - if not wandb_log: - print("Setting num_predictions to 0") - opt.num_predictions = 0 + + # Resume Logging in the same W&B run + if wandb_log and opt.resume: + if 'wandb_id' in ckpt: + try: + run = 
wandb.init(id=ckpt['wandb_id'],resume='must') + print('Resuming wandb logging') + except KeyError: + print('wandb run cannot be resumed, creating a new run') + + if wandb.run is None: + name = opt.name if opt.name != '' else 'yoloV5' + run = wandb.init(project=name, config=opt) + print('wandb logging enabled') # Freeze freeze = ['', ] # parameter names to freeze (full or partial) @@ -217,7 +230,7 @@ def train(hyp, opt, device, tb_writer=None): hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset model.nc = nc # attach number of classes to model model.hyp = hyp # attach hyperparameters to model - model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou) + model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou) model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights model.names = names @@ -226,10 +239,11 @@ def train(hyp, opt, device, tb_writer=None): nw = max(round(hyp['warmup_epochs'] * nb), 1e3) # number of warmup iterations, max(3 epochs, 1k iterations) # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training maps = np.zeros(nc) # mAP per class - results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification' + results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) scheduler.last_epoch = start_epoch - 1 # do not move scaler = amp.GradScaler(enabled=cuda) - logger.info('Image sizes %g train, %g test\nUsing %g dataloader workers\nLogging results to %s\n' + logger.info('Image sizes %g train, %g test\n' + 'Using %g dataloader workers\nLogging results to %s\n' 'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, log_dir, epochs)) for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------ model.train() @@ -256,7 +270,7 @@ def train(hyp, opt, device, tb_writer=None): if rank != -1: dataloader.sampler.set_epoch(epoch) pbar = enumerate(dataloader) - logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size')) + logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size')) if rank in [-1, 0]: pbar = tqdm(pbar, total=nb) # progress bar optimizer.zero_grad() @@ -267,7 +281,7 @@ def train(hyp, opt, device, tb_writer=None): # Warmup if ni <= nw: xi = [0, nw] # x interp - # model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou) + # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou) accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round()) for j, x in enumerate(optimizer.param_groups): # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 @@ -311,11 +325,11 @@ def train(hyp, opt, device, tb_writer=None): # Plot if ni < 3: - f = str(log_dir / ('train_batch%g.jpg' % ni)) # filename + f = str(log_dir / f'train_batch{ni}.jpg') # filename result = plot_images(images=imgs, targets=targets, paths=paths, fname=f) - if tb_writer and result is not None: - tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) - # tb_writer.add_graph(model, imgs) # add model to tensorboard + # if tb_writer and result is not None: + # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) + # tb_writer.add_graph(model, imgs) # add model to tensorboard # end batch ------------------------------------------------------------------------------------------------ @@ -337,12 
+351,13 @@ def train(hyp, opt, device, tb_writer=None): single_cls=opt.single_cls, dataloader=testloader, save_dir=log_dir, - plots=epoch == 0 or final_epoch,# plot first and last - num_predictions=opt.num_predictions) + plots=epoch == 0 or final_epoch, + num_predictions=opt.num_predictions + ) # Write with open(results_file, 'a') as f: - f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls) + f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) if len(opt.name) and opt.bucket: os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name)) @@ -352,6 +367,10 @@ def train(hyp, opt, device, tb_writer=None): 'x/lr0', 'x/lr1', 'x/lr2'] # params # Tensorboard if tb_writer: + tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss + 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', + 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss + 'x/lr0', 'x/lr1', 'x/lr2'] # params for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags): tb_writer.add_scalar(tag, x, epoch) @@ -361,7 +380,7 @@ def train(hyp, opt, device, tb_writer=None): wandb.log({tag:x}) # Update best mAP - fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1] + fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95] if fi > best_fitness: best_fitness = fi @@ -422,7 +441,7 @@ def train(hyp, opt, device, tb_writer=None): parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') parser.add_argument('--cache-images', action='store_true', help='cache images for faster training') parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training') - parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied') + parser.add_argument('--name', default='', help='renames experiment folder exp{N} to exp{N}_{name} if supplied') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') @@ -431,8 +450,7 @@ def train(hyp, opt, device, tb_writer=None): parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') parser.add_argument('--logdir', type=str, default='runs/', help='logging directory') parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers') - parser.add_argument('--num-predictions', type=int, default=50, help='number of images logged to W&B for bounding box debugging. 
Maximum limit is 100') - + parser.add_argument('--num-bbox', type=int, default=50, help='maximum number of images logged to W&B for bounding box debugging') opt = parser.parse_args() # Set DDP variables @@ -460,9 +478,8 @@ def train(hyp, opt, device, tb_writer=None): opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1 - device = select_device(opt.device, batch_size=opt.batch_size) - # DDP mode + device = select_device(opt.device, batch_size=opt.batch_size) if opt.local_rank != -1: assert torch.cuda.device_count() > opt.local_rank torch.cuda.set_device(opt.local_rank) @@ -471,15 +488,20 @@ def train(hyp, opt, device, tb_writer=None): assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count' opt.batch_size = opt.total_batch_size // opt.world_size - logger.info(opt) + # Hyperparameters with open(opt.hyp) as f: hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps + if 'box' not in hyp: + warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' % + (opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120')) + hyp['box'] = hyp.pop('giou') # Train + logger.info(opt) if not opt.evolve: tb_writer = None if opt.global_rank in [-1, 0]: - logger.info('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % opt.logdir) + logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.logdir}", view at http://localhost:6006/') tb_writer = SummaryWriter(log_dir=log_dir) # runs/exp0 train(hyp, opt, device, tb_writer) @@ -494,7 +516,7 @@ def train(hyp, opt, device, tb_writer=None): 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok) 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr - 'giou': (1, 0.02, 0.2), # GIoU loss gain + 'box': (1, 0.02, 0.2), # box loss gain 'cls': (1, 0.2, 4.0), # cls loss gain 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels) @@ -519,7 +541,7 @@ def train(hyp, opt, device, tb_writer=None): assert opt.local_rank == -1, 'DDP mode not implemented for --evolve' opt.notest, opt.nosave = True, True # only test/save final epoch # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices - yaml_file = Path('runs/evolve/hyp_evolved.yaml') # save best result here + yaml_file = Path(opt.logdir) / 'evolve' / 'hyp_evolved.yaml' # save best result here if opt.bucket: os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists @@ -563,5 +585,5 @@ def train(hyp, opt, device, tb_writer=None): # Plot results plot_evolution(yaml_file) - print('Hyperparameter evolution complete. Best results saved as: %s\nCommand to train a new model with these ' - 'hyperparameters: $ python train.py --hyp %s' % (yaml_file, yaml_file)) + print(f'Hyperparameter evolution complete. 
Best results saved as: {yaml_file}\n' + f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}') From 597dead83278f01986eda18e6d4b276d563dfb2d Mon Sep 17 00:00:00 2001 From: AyushExel Date: Wed, 28 Oct 2020 22:47:50 +0530 Subject: [PATCH 7/7] Fix typo --- test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test.py b/test.py index d094c8a35270..f931ed1cc7e1 100644 --- a/test.py +++ b/test.py @@ -147,7 +147,7 @@ def test(data, f.write(('%g ' * len(line) + '\n') % line) # Log images with bounding boxes - if len(wandb_image_log) < bbox_debug: + if len(wandb_image_log) < num_predictions: x = pred.clone() bbox_data = [{ "position": {
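
For readers who want to try the logging these patches introduce without running the full training loop, below is a minimal, self-contained Python sketch of the pattern: predictions in `[x1, y1, x2, y2, conf, cls]` form are converted into the pixel-domain `box_data` dictionaries that `wandb.Image` expects, capped at the `num_predictions` limit (at most 100, per patch 2), and logged alongside the same scalar tags the TensorBoard writer uses. The helper name `predictions_to_box_data`, the synthetic image/prediction tensors, and the placeholder metric values are illustrative assumptions, not code from the PR; the box dictionary format, the 100-image cap, and the tag names mirror the diffs above.

```python
# Minimal sketch of the W&B logging pattern added by this patch series.
# The helper and the synthetic tensors are illustrative only; the PR itself
# threads this logic through test.py (image/box logging) and train.py (scalars).
import torch
import wandb

MAX_LOGGED_IMAGES = 100  # patch 2 clamps num_predictions to 100


def predictions_to_box_data(pred):
    """Convert an (N, 6) tensor of [x1, y1, x2, y2, conf, cls] rows into the
    pixel-domain box_data format accepted by wandb.Image, as in test.py."""
    return [{
        "position": {
            "minX": float(xyxy[0]),
            "minY": float(xyxy[1]),
            "maxX": float(xyxy[2]),
            "maxY": float(xyxy[3]),
        },
        "class_id": int(cls),
        "scores": {"class_score": float(conf)},
        "domain": "pixel",
    } for *xyxy, conf, cls in pred]


if __name__ == "__main__":
    run = wandb.init(project="yoloV5", config={"num_predictions": 50})

    # Stand-ins for img[si] and pred inside the test() loop.
    img = torch.zeros(3, 640, 640)
    pred = torch.tensor([[100., 120., 300., 340., 0.91, 0.]])  # one box, class 0

    wandb_image_log = []
    limit = min(run.config["num_predictions"], MAX_LOGGED_IMAGES)
    if len(wandb_image_log) < limit:
        boxes = {"predictions": {"box_data": predictions_to_box_data(pred)}}
        wandb_image_log.append(wandb.Image(img, boxes=boxes))

    # Scalar metrics reuse the same tags as the TensorBoard writer in train.py.
    tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',
            'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
            'val/box_loss', 'val/obj_loss', 'val/cls_loss',
            'x/lr0', 'x/lr1', 'x/lr2']
    values = [0.0] * len(tags)  # placeholders; train.py zips mloss, results and lr here
    wandb.log({"outputs": wandb_image_log, **dict(zip(tags, values))})
```

One design point worth noting from the train.py hunks: every saved checkpoint carries a `wandb_id` field, and on `--resume` the patches call `wandb.init(id=ckpt['wandb_id'], resume=...)` so metrics continue in the original run instead of starting a new one; when wandb is missing or disabled, `num_predictions` (earlier `num_bbox`) is forced to 0 so test.py never imports wandb.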