From 791dadb51c7da5641a4841eb8a5f319bbc24982b Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 6 Dec 2020 14:58:33 +0100
Subject: [PATCH] Pycocotools best.pt after COCO train (#1616)

* Pycocotools best.pt after COCO train

* cleanup
---
 models/hub/yolov3-tiny.yaml | 41 +++++++++++++++++++++++++++++
 models/hub/yolov3.yaml      | 51 +++++++++++++++++++++++++++++++++++++
 test.py                     |  5 ++--
 train.py                    | 33 ++++++++++++++++--------
 utils/google_utils.py       |  2 +-
 5 files changed, 117 insertions(+), 15 deletions(-)
 create mode 100644 models/hub/yolov3-tiny.yaml
 create mode 100644 models/hub/yolov3.yaml

diff --git a/models/hub/yolov3-tiny.yaml b/models/hub/yolov3-tiny.yaml
new file mode 100644
index 000000000000..85f9fbd498d0
--- /dev/null
+++ b/models/hub/yolov3-tiny.yaml
@@ -0,0 +1,41 @@
+# parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+
+# anchors
+anchors:
+  - [10,14, 23,27, 37,58]  # P4/16
+  - [81,82, 135,169, 344,319]  # P5/32
+
+# YOLOv3-tiny backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [16, 3, 1]],  # 0
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 1-P1/2
+   [-1, 1, Conv, [32, 3, 1]],
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 3-P2/4
+   [-1, 1, Conv, [64, 3, 1]],
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 5-P3/8
+   [-1, 1, Conv, [128, 3, 1]],
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 7-P4/16
+   [-1, 1, Conv, [256, 3, 1]],
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 9-P5/32
+   [-1, 1, Conv, [512, 3, 1]],
+   [-1, 1, nn.ZeroPad2d, [0, 1, 0, 1]],  # 11
+   [-1, 1, nn.MaxPool2d, [2, 1, 0]],  # 12
+  ]
+
+# YOLOv3-tiny head
+head:
+  [[-1, 1, Conv, [1024, 3, 1]],
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, Conv, [512, 3, 1]],  # 15 (P5/32-large)
+
+   [-2, 1, Conv, [128, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P4
+   [-1, 1, Conv, [256, 3, 1]],  # 19 (P4/16-medium)
+
+   [[19, 15], 1, Detect, [nc, anchors]],  # Detect(P4, P5)
+  ]
diff --git a/models/hub/yolov3.yaml b/models/hub/yolov3.yaml
new file mode 100644
index 000000000000..f2e761355469
--- /dev/null
+++ b/models/hub/yolov3.yaml
@@ -0,0 +1,51 @@
+# parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+
+# anchors
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# darknet53 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [32, 3, 1]],  # 0
+   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+   [-1, 1, Bottleneck, [64]],
+   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4
+   [-1, 2, Bottleneck, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 5-P3/8
+   [-1, 8, Bottleneck, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 7-P4/16
+   [-1, 8, Bottleneck, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 9-P5/32
+   [-1, 4, Bottleneck, [1024]],  # 10
+  ]
+
+# YOLOv3 head
+head:
+  [[-1, 1, Bottleneck, [1024, False]],
+   [-1, 1, Conv, [512, [1, 1]]],
+   [-1, 1, Conv, [1024, 3, 1]],
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, Conv, [1024, 3, 1]],  # 15 (P5/32-large)
+
+   [-2, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P4
+   [-1, 1, Bottleneck, [512, False]],
+   [-1, 1, Bottleneck, [512, False]],
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, Conv, [512, 3, 1]],  # 22 (P4/16-medium)
+
+   [-2, 1, Conv, [128, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P3
+   [-1, 1, Bottleneck, [256, False]],
+   [-1, 2, Bottleneck, [256, False]],  # 27 (P3/8-small)
+
+   [[27, 22, 15], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]
diff --git a/test.py b/test.py
index d1c231d4e201..cf22218e40f2 100644
--- a/test.py
+++ b/test.py
@@ -1,5 +1,4 @@
 import argparse
-import glob
 import json
 import os
 from pathlib import Path
@@ -246,7 +245,7 @@ def test(data,
     # Save JSON
     if save_json and len(jdict):
         w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else ''  # weights
-        anno_json = glob.glob('../coco/annotations/instances_val*.json')[0]  # annotations json
+        anno_json = '../coco/annotations/instances_val2017.json'  # annotations json
         pred_json = str(save_dir / f"{w}_predictions.json")  # predictions json
         print('\nEvaluating pycocotools mAP... saving %s...' % pred_json)
         with open(pred_json, 'w') as f:
@@ -266,7 +265,7 @@ def test(data,
             eval.summarize()
             map, map50 = eval.stats[:2]  # update results (mAP@0.5:0.95, mAP@0.5)
         except Exception as e:
-            print('ERROR: pycocotools unable to run: %s' % e)
+            print(f'pycocotools unable to run: {e}')

     # Return results
     if not training:
diff --git a/train.py b/train.py
index 353a199368ca..b88e44c16160 100644
--- a/train.py
+++ b/train.py
@@ -22,6 +22,7 @@
 from tqdm import tqdm

 import test  # import test.py to get mAP after each epoch
+from models.experimental import attempt_load
 from models.yolo import Model
 from utils.autoanchor import check_anchors
 from utils.datasets import create_dataloader
@@ -193,9 +194,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     # Process 0
     if rank in [-1, 0]:
         ema.updates = start_epoch * nb // accumulate  # set EMA updates
-        testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
+        testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,  # testloader
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True,
-                                       rank=-1, world_size=opt.world_size, workers=opt.workers)[0]  # testloader
+                                       rank=-1, world_size=opt.world_size, workers=opt.workers, pad=0.5)[0]

     if not opt.resume:
         labels = np.concatenate(dataset.labels, 0)
@@ -385,15 +386,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

     if rank in [-1, 0]:
         # Strip optimizers
-        n = opt.name if opt.name.isnumeric() else ''
-        fresults, flast, fbest = save_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt'
-        for f1, f2 in zip([wdir / 'last.pt', wdir / 'best.pt', results_file], [flast, fbest, fresults]):
-            if f1.exists():
-                os.rename(f1, f2)  # rename
-                if str(f2).endswith('.pt'):  # is *.pt
-                    strip_optimizer(f2)  # strip optimizer
-                    os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket else None  # upload
-        # Finish
+        for f in [last, best]:
+            if f.exists():  # is *.pt
+                strip_optimizer(f)  # strip optimizer
+                os.system('gsutil cp %s gs://%s/weights' % (f, opt.bucket)) if opt.bucket else None  # upload
+
+        # Plots
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
             if wandb:
@@ -401,6 +399,19 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
                                        if (save_dir / f).exists()]})
         logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
+
+        # Test best.pt
+        if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
+            results, _, _ = test.test(opt.data,
+                                      batch_size=total_batch_size,
+                                      imgsz=imgsz_test,
+                                      model=attempt_load(best if best.exists() else last, device).half(),
+                                      single_cls=opt.single_cls,
+                                      dataloader=testloader,
+                                      save_dir=save_dir,
+                                      save_json=True,  # use pycocotools
+                                      plots=False)
+
     else:
         dist.destroy_process_group()

diff --git a/utils/google_utils.py b/utils/google_utils.py
index a311efd1bb7b..11d1a6ca784b 100644
--- a/utils/google_utils.py
+++ b/utils/google_utils.py
@@ -17,7 +17,7 @@ def gsutil_getsize(url=''):

 def attempt_download(weights):
     # Attempt to download pretrained weights if not found locally
-    weights = weights.strip().replace("'", '')
+    weights = str(weights).strip().replace("'", '')
     file = Path(weights).name.lower()

     msg = weights + ' missing, try downloading from https://github.com/ultralytics/yolov5/releases/'
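
Note: with save_json=True, test.py writes detections to a {weights}_predictions.json file and scores it against instances_val2017.json, which is what the new "Test best.pt" step in train.py now triggers automatically after COCO training. Below is a minimal standalone sketch of that pycocotools evaluation, mirroring the flow visible in the test.py context lines above; the prediction path is a hypothetical example, not part of this patch.

# Sketch: score a YOLO predictions JSON with pycocotools (assumed paths, mirrors test.py).
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

anno_json = '../coco/annotations/instances_val2017.json'  # COCO ground-truth annotations
pred_json = 'runs/test/exp/best_predictions.json'         # hypothetical predictions file written by test.py

anno = COCO(anno_json)                # init annotations API
pred = anno.loadRes(pred_json)        # init predictions API
coco_eval = COCOeval(anno, pred, 'bbox')
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()                 # prints the standard COCO AP table
map, map50 = coco_eval.stats[:2]      # mAP@0.5:0.95, mAP@0.5 (the values test.py reports)
print(f'mAP@0.5:0.95 = {map:.3f}, mAP@0.5 = {map50:.3f}')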