From 69be8e738f45e7908c1270eedd350f98c0c7bfa4 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Mon, 4 Jan 2021 19:54:09 -0800
Subject: [PATCH] YOLOv5 v4.0 Release (#1837)

* Update C3 module
* update datasets
* update attempt_download()
* merge
* parameterize eps
* comments
* gs-multiple
* max_nms implemented
* Create one_cycle() function
* update study.png
* Update datasets.py
---
 README.md                 |  35 +++++++-----
 data/coco.yaml            |  18 +++---
 data/coco128.yaml         |  18 +++---
 data/voc.yaml             |   4 +-
 models/common.py          |  41 +++++++++++---
 models/experimental.py    |   4 +-
 models/hub/anchors.yaml   |  58 +++++++++++++++++++
 models/hub/yolov5-p2.yaml |  54 ++++++++++++++++++
 models/hub/yolov5-p6.yaml |  56 +++++++++++++++++++
 models/hub/yolov5-p7.yaml |  67 ++++++++++++++++++++++
 models/yolo.py            |  18 +++---
 models/yolov5l.yaml       |  16 +++---
 models/yolov5m.yaml       |  16 +++---
 models/yolov5s.yaml       |  16 +++---
 models/yolov5x.yaml       |  16 +++---
 train.py                  |  14 +++--
 utils/autoanchor.py       |   1 +
 utils/datasets.py         | 113 ++++++++++++++++++++++++++++++++++++--
 utils/general.py          |  11 ++--
 utils/google_utils.py     |  25 +++------
 utils/loss.py             |   4 +-
 utils/plots.py            |   4 +-
 utils/torch_utils.py      |   5 +-
 23 files changed, 489 insertions(+), 125 deletions(-)
 create mode 100644 models/hub/anchors.yaml
 create mode 100644 models/hub/yolov5-p2.yaml
 create mode 100644 models/hub/yolov5-p6.yaml
 create mode 100644 models/hub/yolov5-p7.yaml

diff --git a/README.md b/README.md
index f3d1ba3445e6..a3b5b00f4d7e 100755
--- a/README.md
+++ b/README.md
@@ -4,28 +4,32 @@
 ![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg)
 
-This repository represents Ultralytics open-source research into future object detection methods, and incorporates our lessons learned and best practices evolved over training thousands of models on custom client datasets with our previous YOLO repository https://github.com/ultralytics/yolov3. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk.
+This repository represents Ultralytics open-source research into future object detection methods, and incorporates lessons learned and best practices evolved over thousands of hours of training and evolution on anonymized client datasets. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk.
 
-** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 32, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS. 
EfficientDet data from [google/automl](https://github.com/google/automl) at batch size 8. +** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 32, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS. EfficientDet data from [google/automl](https://github.com/google/automl) at batch size 8. +- **January 5, 2021**: [v4.0 release](https://github.com/ultralytics/yolov5/releases/tag/v4.0): nn.SiLU() activations, [Weights & Biases](https://wandb.ai/) logging, [PyTorch Hub](https://pytorch.org/hub/ultralytics_yolov5/) integration. - **August 13, 2020**: [v3.0 release](https://github.com/ultralytics/yolov5/releases/tag/v3.0): nn.Hardswish() activations, data autodownload, native AMP. - **July 23, 2020**: [v2.0 release](https://github.com/ultralytics/yolov5/releases/tag/v2.0): improved model definition, training and mAP. - **June 22, 2020**: [PANet](https://arxiv.org/abs/1803.01534) updates: new heads, reduced parameters, improved speed and mAP [364fcfd](https://github.com/ultralytics/yolov5/commit/364fcfd7dba53f46edd4f04c037a039c0a287972). - **June 19, 2020**: [FP16](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.half) as new default for smaller checkpoints and faster inference [d4c6674](https://github.com/ultralytics/yolov5/commit/d4c6674c98e19df4c40e33a777610a18d1961145). -- **June 9, 2020**: [CSP](https://github.com/WongKinYiu/CrossStagePartialNetworks) updates: improved speed, size, and accuracy (credit to @WongKinYiu for CSP). -- **May 27, 2020**: Public release. YOLOv5 models are SOTA among all known YOLO implementations. ## Pretrained Checkpoints -| Model | APval | APtest | AP50 | SpeedGPU | FPSGPU || params | GFLOPS | -|---------- |------ |------ |------ | -------- | ------| ------ |------ | :------: | -| [YOLOv5s](https://github.com/ultralytics/yolov5/releases) | 37.0 | 37.0 | 56.2 | **2.4ms** | **416** || 7.5M | 17.5 -| [YOLOv5m](https://github.com/ultralytics/yolov5/releases) | 44.3 | 44.3 | 63.2 | 3.4ms | 294 || 21.8M | 52.3 -| [YOLOv5l](https://github.com/ultralytics/yolov5/releases) | 47.7 | 47.7 | 66.5 | 4.4ms | 227 || 47.8M | 117.2 -| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) | **49.2** | **49.2** | **67.7** | 6.9ms | 145 || 89.0M | 221.5 -| | | | | | || | -| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) + TTA|**50.8**| **50.8** | **68.9** | 25.5ms | 39 || 89.0M | 801.0 +| Model | size | APval | APtest | AP50 | SpeedV100 | FPSV100 || params | GFLOPS | +|---------- |------ |------ |------ |------ | -------- | ------| ------ |------ | :------: | +| [YOLOv5s](https://github.com/ultralytics/yolov5/releases) |640 |36.8 |36.8 |55.6 |**2.2ms** |**455** ||7.3M |17.0 +| [YOLOv5m](https://github.com/ultralytics/yolov5/releases) |640 |44.5 |44.5 |63.1 |2.9ms |345 ||21.4M |51.3 +| [YOLOv5l](https://github.com/ultralytics/yolov5/releases) |640 |48.1 |48.1 |66.4 |3.8ms |264 ||47.0M |115.4 +| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) |640 |**50.1** |**50.1** |**68.7** |6.0ms |167 ||87.7M |218.8 +| | | | | | | || | +| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) + TTA |832 |**51.9** |**51.9** |**69.6** |24.9ms |40 ||87.7M |1005.3 + + ** APtest denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results denote val2017 accuracy. ** All AP numbers are for single-model single-scale without ensemble or TTA. 
**Reproduce mAP** by `python test.py --data coco.yaml --img 640 --conf 0.001 --iou 0.65` @@ -33,6 +37,7 @@ This repository represents Ultralytics open-source research into future object d ** All checkpoints are trained to 300 epochs with default settings and hyperparameters (no autoaugmentation). ** Test Time Augmentation ([TTA](https://github.com/ultralytics/yolov5/issues/303)) runs at 3 image sizes. **Reproduce TTA** by `python test.py --data coco.yaml --img 832 --iou 0.65 --augment` + ## Requirements Python 3.8 or later with all [requirements.txt](https://github.com/ultralytics/yolov5/blob/master/requirements.txt) dependencies installed, including `torch>=1.7`. To install run: @@ -106,7 +111,7 @@ import torch from PIL import Image # Model -model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) # for PIL/cv2/np inputs and NMS +model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) # Images img1 = Image.open('zidane.jpg') @@ -114,13 +119,13 @@ img2 = Image.open('bus.jpg') imgs = [img1, img2] # batched list of images # Inference -prediction = model(imgs, size=640) # includes NMS +result = model(imgs) ``` ## Training -Download [COCO](https://github.com/ultralytics/yolov5/blob/master/data/scripts/get_coco.sh) and run command below. Training times for YOLOv5s/m/l/x are 2/4/6/8 days on a single V100 (multi-GPU times faster). Use the largest `--batch-size` your GPU allows (batch sizes shown for 16 GB devices). +Run commands below to reproduce results on [COCO](https://github.com/ultralytics/yolov5/blob/master/data/scripts/get_coco.sh) dataset (dataset auto-downloads on first use). Training times for YOLOv5s/m/l/x are 2/4/6/8 days on a single V100 (multi-GPU times faster). Use the largest `--batch-size` your GPU allows (batch sizes shown for 16 GB devices). 
```bash
$ python train.py --data coco.yaml --cfg yolov5s.yaml --weights '' --batch-size 64
                                         yolov5m                                40

diff --git a/data/coco.yaml b/data/coco.yaml
index 09f3a7890373..b9da2bf5919b 100644
--- a/data/coco.yaml
+++ b/data/coco.yaml
@@ -18,15 +18,15 @@ test: ../coco/test-dev2017.txt  # 20288 of 40670 images, submit to https://compe
 nc: 80
 
 # class names
-names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
-        'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
-        'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
-        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
-        'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
-        'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
-        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
-        'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
-        'hair drier', 'toothbrush']
+names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
+         'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+         'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+         'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
+         'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+         'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+         'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
+         'hair drier', 'toothbrush' ]
 
 # Print classes
 # with open('data/coco.yaml') as f:
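The truncated "Print classes" stub in the hunk context above can be completed as a runnable snippet. A minimal sketch, assuming PyYAML is installed and the repository root is the working directory:

```python
# Runnable version of the "# Print classes" stub in data/coco.yaml
# (sketch; assumes PyYAML and the repo root as working directory).
import yaml

with open('data/coco.yaml') as f:
    d = yaml.safe_load(f)  # dict with keys: train, val, test, nc, names

assert len(d['names']) == d['nc']  # 80 names for nc=80
for i, name in enumerate(d['names']):
    print(i, name)  # 0 person, 1 bicycle, ...
```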
diff --git a/data/coco128.yaml b/data/coco128.yaml
index 12e1d799718d..c41bccf2b8d5 100644
--- a/data/coco128.yaml
+++ b/data/coco128.yaml
@@ -17,12 +17,12 @@ val: ../coco128/images/train2017/  # 128 images
 nc: 80
 
 # class names
-names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
-        'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
-        'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
-        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
-        'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
-        'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
-        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
-        'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
-        'hair drier', 'toothbrush']
+names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
+         'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+         'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+         'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
+         'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+         'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+         'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
+         'hair drier', 'toothbrush' ]

diff --git a/data/voc.yaml b/data/voc.yaml
index b93f6a0e5717..851a9e0b060c 100644
--- a/data/voc.yaml
+++ b/data/voc.yaml
@@ -17,5 +17,5 @@ val: ../VOC/images/val/  # 4952 images
 nc: 20
 
 # class names
-names: ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
-        'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
+names: [ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
+         'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' ]

diff --git a/models/common.py b/models/common.py
index 17b9f017085e..a25707fdd387 100644
--- a/models/common.py
+++ b/models/common.py
@@ -30,7 +30,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, k
         super(Conv, self).__init__()
         self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
         self.bn = nn.BatchNorm2d(c2)
-        self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
+        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
 
     def forward(self, x):
         return self.act(self.bn(self.conv(x)))
@@ -105,9 +105,39 @@ class Focus(nn.Module):
     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
         super(Focus, self).__init__()
         self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
+        # self.contract = Contract(gain=2)
 
     def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
         return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
+        # return self.conv(self.contract(x))
+
+
+class Contract(nn.Module):
+    # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
+    def __init__(self, gain=2):
+        super().__init__()
+        self.gain = gain
+
+    def forward(self, x):
+        N, C, H, W = x.size()  # assert (H % s == 0) and (W % s == 0), 'Indivisible gain'
+        s = self.gain
+        x = x.view(N, C, H // s, s, W // s, s)  # x(1,64,40,2,40,2)
+        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # x(1,2,2,64,40,40)
+        return x.view(N, C * s * s, H // s, W // s)  # x(1,256,40,40)
+
+
+class Expand(nn.Module):
+    # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
+    def __init__(self, gain=2):
+        super().__init__()
+        self.gain = gain
+
+    def forward(self, x):
+        N, C, H, W = x.size()  # assert C % s ** 2 == 0, 'Indivisible gain'
+        s = self.gain
+        x = x.view(N, s, s, C // s ** 2, H, W)  # x(1,2,2,16,80,80)
+        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # x(1,16,80,2,80,2)
+        return x.view(N, C // s ** 2, H * s, W * s)  # x(1,16,160,160)
 
 
 class Concat(nn.Module):
@@ -253,20 +283,13 @@ def tolist(self):
         return x
 
 
-class Flatten(nn.Module):
-    # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
-    @staticmethod
-    def forward(x):
-        return x.view(x.size(0), -1)
-
-
 class Classify(nn.Module):
     # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
     def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
         super(Classify, self).__init__()
         self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
         self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
-        self.flat = Flatten()
+        self.flat = nn.Flatten()
 
     def forward(self, x):
         z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1)  # cat if list
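The new Contract and Expand modules are exact inverses for a matching gain, which is why Focus can optionally be expressed through Contract. A minimal shape round-trip sketch, assuming models/common.py from this patch is importable:

```python
# Shape round-trip for the new Contract/Expand modules
# (sketch; assumes the definitions above are importable).
import torch
from models.common import Contract, Expand

x = torch.randn(1, 64, 80, 80)
y = Contract(gain=2)(x)   # width-height into channels
print(y.shape)            # torch.Size([1, 256, 40, 40])
z = Expand(gain=2)(y)     # channels back into width-height
print(z.shape)            # torch.Size([1, 64, 80, 80])
assert torch.equal(x, z)  # Expand(gain) exactly inverts Contract(gain)
```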
diff --git a/models/experimental.py b/models/experimental.py
index 136e86d7f8e6..2dbbf7fa32f2 100644
--- a/models/experimental.py
+++ b/models/experimental.py
@@ -105,8 +105,8 @@ def forward(self, x, augment=False):
         for module in self:
             y.append(module(x, augment)[0])
         # y = torch.stack(y).max(0)[0]  # max ensemble
-        # y = torch.cat(y, 1)  # nms ensemble
-        y = torch.stack(y).mean(0)  # mean ensemble
+        # y = torch.stack(y).mean(0)  # mean ensemble
+        y = torch.cat(y, 1)  # nms ensemble
         return y, None  # inference, train output

diff --git a/models/hub/anchors.yaml b/models/hub/anchors.yaml
new file mode 100644
index 000000000000..a07a4dc72387
--- /dev/null
+++ b/models/hub/anchors.yaml
@@ -0,0 +1,58 @@
+# Default YOLOv5 anchors for COCO data
+
+
+# P5 -------------------------------------------------------------------------------------------------------------------
+# P5-640:
+anchors_p5_640:
+  - [ 10,13, 16,30, 33,23 ]  # P3/8
+  - [ 30,61, 62,45, 59,119 ]  # P4/16
+  - [ 116,90, 156,198, 373,326 ]  # P5/32
+
+
+# P6 -------------------------------------------------------------------------------------------------------------------
+# P6-640: thr=0.25: 0.9964 BPR, 5.54 anchors past thr, n=12, img_size=640, metric_all=0.281/0.716-mean/best, past_thr=0.469-mean: 9,11, 21,19, 17,41, 43,32, 39,70, 86,64, 65,131, 134,130, 120,265, 282,180, 247,354, 512,387
+anchors_p6_640:
+  - [ 9,11, 21,19, 17,41 ]  # P3/8
+  - [ 43,32, 39,70, 86,64 ]  # P4/16
+  - [ 65,131, 134,130, 120,265 ]  # P5/32
+  - [ 282,180, 247,354, 512,387 ]  # P6/64
+
+# P6-1280: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1280, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 19,27, 44,40, 38,94, 96,68, 86,152, 180,137, 140,301, 303,264, 238,542, 436,615, 739,380, 925,792
+anchors_p6_1280:
+  - [ 19,27, 44,40, 38,94 ]  # P3/8
+  - [ 96,68, 86,152, 180,137 ]  # P4/16
+  - [ 140,301, 303,264, 238,542 ]  # P5/32
+  - [ 436,615, 739,380, 925,792 ]  # P6/64
+
+# P6-1920: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1920, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 28,41, 67,59, 57,141, 144,103, 129,227, 270,205, 209,452, 455,396, 358,812, 653,922, 1109,570, 1387,1187
+anchors_p6_1920:
+  - [ 28,41, 67,59, 57,141 ]  # P3/8
+  - [ 144,103, 129,227, 270,205 ]  # P4/16
+  - [ 209,452, 455,396, 358,812 ]  # P5/32
+  - [ 653,922, 1109,570, 1387,1187 ]  # P6/64
+
+
+# P7
------------------------------------------------------------------------------------------------------------------- +# P7-640: thr=0.25: 0.9962 BPR, 6.76 anchors past thr, n=15, img_size=640, metric_all=0.275/0.733-mean/best, past_thr=0.466-mean: 11,11, 13,30, 29,20, 30,46, 61,38, 39,92, 78,80, 146,66, 79,163, 149,150, 321,143, 157,303, 257,402, 359,290, 524,372 +anchors_p7_640: + - [ 11,11, 13,30, 29,20 ] # P3/8 + - [ 30,46, 61,38, 39,92 ] # P4/16 + - [ 78,80, 146,66, 79,163 ] # P5/32 + - [ 149,150, 321,143, 157,303 ] # P6/64 + - [ 257,402, 359,290, 524,372 ] # P7/128 + +# P7-1280: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1280, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 19,22, 54,36, 32,77, 70,83, 138,71, 75,173, 165,159, 148,334, 375,151, 334,317, 251,626, 499,474, 750,326, 534,814, 1079,818 +anchors_p7_1280: + - [ 19,22, 54,36, 32,77 ] # P3/8 + - [ 70,83, 138,71, 75,173 ] # P4/16 + - [ 165,159, 148,334, 375,151 ] # P5/32 + - [ 334,317, 251,626, 499,474 ] # P6/64 + - [ 750,326, 534,814, 1079,818 ] # P7/128 + +# P7-1920: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1920, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 29,34, 81,55, 47,115, 105,124, 207,107, 113,259, 247,238, 222,500, 563,227, 501,476, 376,939, 749,711, 1126,489, 801,1222, 1618,1227 +anchors_p7_1920: + - [ 29,34, 81,55, 47,115 ] # P3/8 + - [ 105,124, 207,107, 113,259 ] # P4/16 + - [ 247,238, 222,500, 563,227 ] # P5/32 + - [ 501,476, 376,939, 749,711 ] # P6/64 + - [ 1126,489, 801,1222, 1618,1227 ] # P7/128 diff --git a/models/hub/yolov5-p2.yaml b/models/hub/yolov5-p2.yaml new file mode 100644 index 000000000000..0633a90fd065 --- /dev/null +++ b/models/hub/yolov5-p2.yaml @@ -0,0 +1,54 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: 3 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [ [ -1, 1, Focus, [ 64, 3 ] ], # 0-P1/2 + [ -1, 1, Conv, [ 128, 3, 2 ] ], # 1-P2/4 + [ -1, 3, C3, [ 128 ] ], + [ -1, 1, Conv, [ 256, 3, 2 ] ], # 3-P3/8 + [ -1, 9, C3, [ 256 ] ], + [ -1, 1, Conv, [ 512, 3, 2 ] ], # 5-P4/16 + [ -1, 9, C3, [ 512 ] ], + [ -1, 1, Conv, [ 1024, 3, 2 ] ], # 7-P5/32 + [ -1, 1, SPP, [ 1024, [ 5, 9, 13 ] ] ], + [ -1, 3, C3, [ 1024, False ] ], # 9 + ] + +# YOLOv5 head +head: + [ [ -1, 1, Conv, [ 512, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 6 ], 1, Concat, [ 1 ] ], # cat backbone P4 + [ -1, 3, C3, [ 512, False ] ], # 13 + + [ -1, 1, Conv, [ 256, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 4 ], 1, Concat, [ 1 ] ], # cat backbone P3 + [ -1, 3, C3, [ 256, False ] ], # 17 (P3/8-small) + + [ -1, 1, Conv, [ 128, 1, 1 ] ], + [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], + [ [ -1, 2 ], 1, Concat, [ 1 ] ], # cat backbone P2 + [ -1, 1, C3, [ 128, False ] ], # 21 (P2/4-xsmall) + + [ -1, 1, Conv, [ 128, 3, 2 ] ], + [ [ -1, 18 ], 1, Concat, [ 1 ] ], # cat head P3 + [ -1, 3, C3, [ 256, False ] ], # 24 (P3/8-small) + + [ -1, 1, Conv, [ 256, 3, 2 ] ], + [ [ -1, 14 ], 1, Concat, [ 1 ] ], # cat head P4 + [ -1, 3, C3, [ 512, False ] ], # 27 (P4/16-medium) + + [ -1, 1, Conv, [ 512, 3, 2 ] ], + [ [ -1, 10 ], 1, Concat, [ 1 ] ], # cat head P5 + [ -1, 3, C3, [ 1024, False ] ], # 30 (P5/32-large) + + [ [ 24, 27, 30 ], 1, Detect, [ nc, anchors ] ], # Detect(P3, P4, P5) + ] diff --git a/models/hub/yolov5-p6.yaml b/models/hub/yolov5-p6.yaml new file mode 100644 index 000000000000..3728a118f090 --- /dev/null +++ 
b/models/hub/yolov5-p6.yaml
@@ -0,0 +1,56 @@
+# parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+
+# anchors
+anchors: 3
+
+# YOLOv5 backbone
+backbone:
+  # [from, number, module, args]
+  [ [ -1, 1, Focus, [ 64, 3 ] ],  # 0-P1/2
+    [ -1, 1, Conv, [ 128, 3, 2 ] ],  # 1-P2/4
+    [ -1, 3, C3, [ 128 ] ],
+    [ -1, 1, Conv, [ 256, 3, 2 ] ],  # 3-P3/8
+    [ -1, 9, C3, [ 256 ] ],
+    [ -1, 1, Conv, [ 512, 3, 2 ] ],  # 5-P4/16
+    [ -1, 9, C3, [ 512 ] ],
+    [ -1, 1, Conv, [ 768, 3, 2 ] ],  # 7-P5/32
+    [ -1, 3, C3, [ 768 ] ],
+    [ -1, 1, Conv, [ 1024, 3, 2 ] ],  # 9-P6/64
+    [ -1, 1, SPP, [ 1024, [ 3, 5, 7 ] ] ],
+    [ -1, 3, C3, [ 1024, False ] ],  # 11
+  ]
+
+# YOLOv5 head
+head:
+  [ [ -1, 1, Conv, [ 768, 1, 1 ] ],
+    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
+    [ [ -1, 8 ], 1, Concat, [ 1 ] ],  # cat backbone P5
+    [ -1, 3, C3, [ 768, False ] ],  # 15
+
+    [ -1, 1, Conv, [ 512, 1, 1 ] ],
+    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
+    [ [ -1, 6 ], 1, Concat, [ 1 ] ],  # cat backbone P4
+    [ -1, 3, C3, [ 512, False ] ],  # 19
+
+    [ -1, 1, Conv, [ 256, 1, 1 ] ],
+    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
+    [ [ -1, 4 ], 1, Concat, [ 1 ] ],  # cat backbone P3
+    [ -1, 3, C3, [ 256, False ] ],  # 23 (P3/8-small)
+
+    [ -1, 1, Conv, [ 256, 3, 2 ] ],
+    [ [ -1, 20 ], 1, Concat, [ 1 ] ],  # cat head P4
+    [ -1, 3, C3, [ 512, False ] ],  # 26 (P4/16-medium)
+
+    [ -1, 1, Conv, [ 512, 3, 2 ] ],
+    [ [ -1, 16 ], 1, Concat, [ 1 ] ],  # cat head P5
+    [ -1, 3, C3, [ 768, False ] ],  # 29 (P5/32-large)
+
+    [ -1, 1, Conv, [ 768, 3, 2 ] ],
+    [ [ -1, 12 ], 1, Concat, [ 1 ] ],  # cat head P6
+    [ -1, 3, C3, [ 1024, False ] ],  # 32 (P6/64-xlarge)
+
+    [ [ 23, 26, 29, 32 ], 1, Detect, [ nc, anchors ] ],  # Detect(P3, P4, P5, P6)
+  ]

diff --git a/models/hub/yolov5-p7.yaml b/models/hub/yolov5-p7.yaml
new file mode 100644
index 000000000000..ca8f8492ce0e
--- /dev/null
+++ b/models/hub/yolov5-p7.yaml
@@ -0,0 +1,67 @@
+# parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+
+# anchors
+anchors: 3
+
+# YOLOv5 backbone
+backbone:
+  # [from, number, module, args]
+  [ [ -1, 1, Focus, [ 64, 3 ] ],  # 0-P1/2
+    [ -1, 1, Conv, [ 128, 3, 2 ] ],  # 1-P2/4
+    [ -1, 3, C3, [ 128 ] ],
+    [ -1, 1, Conv, [ 256, 3, 2 ] ],  # 3-P3/8
+    [ -1, 9, C3, [ 256 ] ],
+    [ -1, 1, Conv, [ 512, 3, 2 ] ],  # 5-P4/16
+    [ -1, 9, C3, [ 512 ] ],
+    [ -1, 1, Conv, [ 768, 3, 2 ] ],  # 7-P5/32
+    [ -1, 3, C3, [ 768 ] ],
+    [ -1, 1, Conv, [ 1024, 3, 2 ] ],  # 9-P6/64
+    [ -1, 3, C3, [ 1024 ] ],
+    [ -1, 1, Conv, [ 1280, 3, 2 ] ],  # 11-P7/128
+    [ -1, 1, SPP, [ 1280, [ 3, 5 ] ] ],
+    [ -1, 3, C3, [ 1280, False ] ],  # 13
+  ]
+
+# YOLOv5 head
+head:
+  [ [ -1, 1, Conv, [ 1024, 1, 1 ] ],
+    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
+    [ [ -1, 10 ], 1, Concat, [ 1 ] ],  # cat backbone P6
+    [ -1, 3, C3, [ 1024, False ] ],  # 17
+
+    [ -1, 1, Conv, [ 768, 1, 1 ] ],
+    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
+    [ [ -1, 8 ], 1, Concat, [ 1 ] ],  # cat backbone P5
+    [ -1, 3, C3, [ 768, False ] ],  # 21
+
+    [ -1, 1, Conv, [ 512, 1, 1 ] ],
+    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
+    [ [ -1, 6 ], 1, Concat, [ 1 ] ],  # cat backbone P4
+    [ -1, 3, C3, [ 512, False ] ],  # 25
+
+    [ -1, 1, Conv, [ 256, 1, 1 ] ],
+    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],
+    [ [ -1, 4 ], 1, Concat, [ 1 ] ],  # cat backbone P3
+    [ -1, 3, C3, [ 256, False ] ],  # 29 (P3/8-small)
+
+    [ -1, 1, Conv, [ 256, 3, 2 ] ],
+    [ [ -1, 26 ], 1, Concat, [ 1 ] ],  # cat head P4
+    [ -1,
3, C3, [ 512, False ] ], # 32 (P4/16-medium) + + [ -1, 1, Conv, [ 512, 3, 2 ] ], + [ [ -1, 22 ], 1, Concat, [ 1 ] ], # cat head P5 + [ -1, 3, C3, [ 768, False ] ], # 35 (P5/32-large) + + [ -1, 1, Conv, [ 768, 3, 2 ] ], + [ [ -1, 18 ], 1, Concat, [ 1 ] ], # cat head P6 + [ -1, 3, C3, [ 1024, False ] ], # 38 (P6/64-xlarge) + + [ -1, 1, Conv, [ 1024, 3, 2 ] ], + [ [ -1, 14 ], 1, Concat, [ 1 ] ], # cat head P7 + [ -1, 3, C3, [ 1280, False ] ], # 41 (P7/128-xxlarge) + + [ [ 29, 32, 35, 38, 41 ], 1, Detect, [ nc, anchors ] ], # Detect(P3, P4, P5, P6, P7) + ] diff --git a/models/yolo.py b/models/yolo.py index b75e939468a6..5dc8b57f4d98 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -1,17 +1,13 @@ import argparse import logging -import math import sys from copy import deepcopy from pathlib import Path -import torch -import torch.nn as nn - sys.path.append('./') # to run '$ python *.py' files in subdirectories logger = logging.getLogger(__name__) -from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, C3, Concat, NMS, autoShape +from models.common import * from models.experimental import MixConv2d, CrossConv from utils.autoanchor import check_anchor_order from utils.general import make_divisible, check_file, set_logging @@ -89,7 +85,7 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, # Build strides, anchors m = self.model[-1] # Detect() if isinstance(m, Detect): - s = 128 # 2x min stride + s = 256 # 2x min stride m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward m.anchors /= m.stride.view(-1, 1, 1) check_anchor_order(m) @@ -109,7 +105,7 @@ def forward(self, x, augment=False, profile=False): f = [None, 3, None] # flips (2-ud, 3-lr) y = [] # outputs for si, fi in zip(s, f): - xi = scale_img(x.flip(fi) if fi else x, si) + xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max())) yi = self.forward_once(xi)[0] # forward # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save yi[..., :4] /= si # de-scale @@ -242,13 +238,17 @@ def parse_model(d, ch): # model_dict, input_channels(3) elif m is nn.BatchNorm2d: args = [ch[f]] elif m is Concat: - c2 = sum([ch[-1 if x == -1 else x + 1] for x in f]) + c2 = sum([ch[x if x < 0 else x + 1] for x in f]) elif m is Detect: args.append([ch[x + 1] for x in f]) if isinstance(args[1], int): # number of anchors args[1] = [list(range(args[1] * 2))] * len(f) + elif m is Contract: + c2 = ch[f if f < 0 else f + 1] * args[0] ** 2 + elif m is Expand: + c2 = ch[f if f < 0 else f + 1] // args[0] ** 2 else: - c2 = ch[f] + c2 = ch[f if f < 0 else f + 1] m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module t = str(m)[8:-2].replace('__main__.', '') # module type diff --git a/models/yolov5l.yaml b/models/yolov5l.yaml index 13095541703f..71ebf86e5791 100644 --- a/models/yolov5l.yaml +++ b/models/yolov5l.yaml @@ -14,14 +14,14 @@ backbone: # [from, number, module, args] [[-1, 1, Focus, [64, 3]], # 0-P1/2 [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, BottleneckCSP, [128]], + [-1, 3, C3, [128]], [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 9, BottleneckCSP, [256]], + [-1, 9, C3, [256]], [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, BottleneckCSP, [512]], + [-1, 9, C3, [512]], [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 [-1, 1, SPP, [1024, [5, 9, 13]]], - [-1, 3, BottleneckCSP, [1024, False]], # 9 + [-1, 3, C3, [1024, False]], # 9 ] # YOLOv5 head @@ -29,20 +29,20 @@ head: [[-1, 1, Conv, [512, 1, 1]], [-1, 
1, nn.Upsample, [None, 2, 'nearest']], [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, BottleneckCSP, [512, False]], # 13 + [-1, 3, C3, [512, False]], # 13 [-1, 1, Conv, [256, 1, 1]], [-1, 1, nn.Upsample, [None, 2, 'nearest']], [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, BottleneckCSP, [256, False]], # 17 (P3/8-small) + [-1, 3, C3, [256, False]], # 17 (P3/8-small) [-1, 1, Conv, [256, 3, 2]], [[-1, 14], 1, Concat, [1]], # cat head P4 - [-1, 3, BottleneckCSP, [512, False]], # 20 (P4/16-medium) + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) [-1, 1, Conv, [512, 3, 2]], [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, BottleneckCSP, [1024, False]], # 23 (P5/32-large) + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) ] diff --git a/models/yolov5m.yaml b/models/yolov5m.yaml index eb50a713f2ff..3c749c916246 100644 --- a/models/yolov5m.yaml +++ b/models/yolov5m.yaml @@ -14,14 +14,14 @@ backbone: # [from, number, module, args] [[-1, 1, Focus, [64, 3]], # 0-P1/2 [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, BottleneckCSP, [128]], + [-1, 3, C3, [128]], [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 9, BottleneckCSP, [256]], + [-1, 9, C3, [256]], [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, BottleneckCSP, [512]], + [-1, 9, C3, [512]], [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 [-1, 1, SPP, [1024, [5, 9, 13]]], - [-1, 3, BottleneckCSP, [1024, False]], # 9 + [-1, 3, C3, [1024, False]], # 9 ] # YOLOv5 head @@ -29,20 +29,20 @@ head: [[-1, 1, Conv, [512, 1, 1]], [-1, 1, nn.Upsample, [None, 2, 'nearest']], [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, BottleneckCSP, [512, False]], # 13 + [-1, 3, C3, [512, False]], # 13 [-1, 1, Conv, [256, 1, 1]], [-1, 1, nn.Upsample, [None, 2, 'nearest']], [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, BottleneckCSP, [256, False]], # 17 (P3/8-small) + [-1, 3, C3, [256, False]], # 17 (P3/8-small) [-1, 1, Conv, [256, 3, 2]], [[-1, 14], 1, Concat, [1]], # cat head P4 - [-1, 3, BottleneckCSP, [512, False]], # 20 (P4/16-medium) + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) [-1, 1, Conv, [512, 3, 2]], [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, BottleneckCSP, [1024, False]], # 23 (P5/32-large) + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) ] diff --git a/models/yolov5s.yaml b/models/yolov5s.yaml index 2bec4529294e..aca669d60d8b 100644 --- a/models/yolov5s.yaml +++ b/models/yolov5s.yaml @@ -14,14 +14,14 @@ backbone: # [from, number, module, args] [[-1, 1, Focus, [64, 3]], # 0-P1/2 [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, BottleneckCSP, [128]], + [-1, 3, C3, [128]], [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 9, BottleneckCSP, [256]], + [-1, 9, C3, [256]], [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, BottleneckCSP, [512]], + [-1, 9, C3, [512]], [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 [-1, 1, SPP, [1024, [5, 9, 13]]], - [-1, 3, BottleneckCSP, [1024, False]], # 9 + [-1, 3, C3, [1024, False]], # 9 ] # YOLOv5 head @@ -29,20 +29,20 @@ head: [[-1, 1, Conv, [512, 1, 1]], [-1, 1, nn.Upsample, [None, 2, 'nearest']], [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, BottleneckCSP, [512, False]], # 13 + [-1, 3, C3, [512, False]], # 13 [-1, 1, Conv, [256, 1, 1]], [-1, 1, nn.Upsample, [None, 2, 'nearest']], [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, BottleneckCSP, [256, False]], # 17 (P3/8-small) + [-1, 3, C3, [256, False]], # 17 (P3/8-small) [-1, 1, Conv, [256, 3, 2]], [[-1, 14], 1, Concat, [1]], # 
cat head P4 - [-1, 3, BottleneckCSP, [512, False]], # 20 (P4/16-medium) + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) [-1, 1, Conv, [512, 3, 2]], [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, BottleneckCSP, [1024, False]], # 23 (P5/32-large) + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) ] diff --git a/models/yolov5x.yaml b/models/yolov5x.yaml index 9676402040ce..d3babdf7baf0 100644 --- a/models/yolov5x.yaml +++ b/models/yolov5x.yaml @@ -14,14 +14,14 @@ backbone: # [from, number, module, args] [[-1, 1, Focus, [64, 3]], # 0-P1/2 [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, BottleneckCSP, [128]], + [-1, 3, C3, [128]], [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 9, BottleneckCSP, [256]], + [-1, 9, C3, [256]], [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, BottleneckCSP, [512]], + [-1, 9, C3, [512]], [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 [-1, 1, SPP, [1024, [5, 9, 13]]], - [-1, 3, BottleneckCSP, [1024, False]], # 9 + [-1, 3, C3, [1024, False]], # 9 ] # YOLOv5 head @@ -29,20 +29,20 @@ head: [[-1, 1, Conv, [512, 1, 1]], [-1, 1, nn.Upsample, [None, 2, 'nearest']], [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, BottleneckCSP, [512, False]], # 13 + [-1, 3, C3, [512, False]], # 13 [-1, 1, Conv, [256, 1, 1]], [-1, 1, nn.Upsample, [None, 2, 'nearest']], [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, BottleneckCSP, [256, False]], # 17 (P3/8-small) + [-1, 3, C3, [256, False]], # 17 (P3/8-small) [-1, 1, Conv, [256, 3, 2]], [[-1, 14], 1, Concat, [1]], # cat head P4 - [-1, 3, BottleneckCSP, [512, False]], # 20 (P4/16-medium) + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) [-1, 1, Conv, [512, 3, 2]], [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, BottleneckCSP, [1024, False]], # 23 (P5/32-large) + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) ] diff --git a/train.py b/train.py index d0f3bdcc1ab1..d61ca4b3ac2a 100644 --- a/train.py +++ b/train.py @@ -104,6 +104,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): nbs = 64 # nominal batch size accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay + logger.info(f"Scaled weight_decay = {hyp['weight_decay']}") pg0, pg1, pg2 = [], [], [] # optimizer parameter groups for k, v in model.named_modules(): @@ -164,7 +165,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): del ckpt, state_dict # Image sizes - gs = int(max(model.stride)) # grid size (max stride) + gs = int(model.stride.max()) # grid size (max stride) + nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj']) imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples # DP mode @@ -187,7 +189,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank, world_size=opt.world_size, workers=opt.workers, - image_weights=opt.image_weights) + image_weights=opt.image_weights, quad=opt.quad) mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class nb = len(dataloader) # number of batches assert mlc < nc, 'Label class %g exceeds nc=%g in %s. 
Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1) @@ -214,7 +216,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # Model parameters - hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset + hyp['cls'] *= nc / 80. # scale hyp['cls'] to class count + hyp['obj'] *= imgsz ** 2 / 640. ** 2 * 3. / nl # scale hyp['obj'] to image size and output layers model.nc = nc # attach number of classes to model model.hyp = hyp # attach hyperparameters to model model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou) @@ -290,6 +293,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size if rank != -1: loss *= opt.world_size # gradient averaged between devices in DDP mode + if opt.quad: + loss *= 4. # Backward scaler.scale(loss).backward() @@ -458,10 +463,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): parser.add_argument('--project', default='runs/train', help='save to project/name') parser.add_argument('--name', default='exp', help='save to project/name') parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + parser.add_argument('--quad', action='store_true', help='quad dataloader') opt = parser.parse_args() # Set DDP variables - opt.total_batch_size = opt.batch_size opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1 set_logging(opt.global_rank) @@ -486,6 +491,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run # DDP mode + opt.total_batch_size = opt.batch_size device = select_device(opt.device, batch_size=opt.batch_size) if opt.local_rank != -1: assert torch.cuda.device_count() > opt.local_rank diff --git a/utils/autoanchor.py b/utils/autoanchor.py index 0c33dcbc3495..badefc154138 100644 --- a/utils/autoanchor.py +++ b/utils/autoanchor.py @@ -110,6 +110,7 @@ def print_results(k): print('WARNING: Extremely small objects found. ' '%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0))) wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels + # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1 # Kmeans calculation print('Running kmeans for %g anchors on %g points...' 
% (n, len(wh)))

diff --git a/utils/datasets.py b/utils/datasets.py
index e2b139fa6371..9001832aadec 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -15,6 +15,7 @@
 import cv2
 import numpy as np
 import torch
+import torch.nn.functional as F
 from PIL import Image, ExifTags
 from torch.utils.data import Dataset
 from tqdm import tqdm
@@ -55,7 +56,7 @@ def exif_size(img):
 
 def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
-                      rank=-1, world_size=1, workers=8, image_weights=False):
+                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False):
     # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
     with torch_distributed_zero_first(rank):
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -79,7 +80,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
                         num_workers=nw,
                         sampler=sampler,
                         pin_memory=True,
-                        collate_fn=LoadImagesAndLabels.collate_fn)
+                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
     return dataloader, dataset
 
@@ -578,6 +579,32 @@ def collate_fn(batch):
             l[:, 0] = i  # add target image index for build_targets()
         return torch.stack(img, 0), torch.cat(label, 0), path, shapes
 
+    @staticmethod
+    def collate_fn4(batch):
+        img, label, path, shapes = zip(*batch)  # transposed
+        n = len(shapes) // 4
+        img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
+
+        ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
+        wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
+        s = torch.tensor([[1, 1, .5, .5, .5, .5]])  # scale
+        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
+            i *= 4
+            if random.random() < 0.5:
+                im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
+                    0].type(img[i].type())
+                l = label[i]
+            else:
+                im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
+                l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
+            img4.append(im)
+            label4.append(l)
+
+        for i, l in enumerate(label4):
+            l[:, 0] = i  # add target image index for build_targets()
+
+        return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
+
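Geometry of the new quad collate: with --quad, each group of four (C, H, W) images either becomes one bilinearly upscaled image or one 2x2 tile, so the loader yields bs/4 images at double resolution. A standalone sketch of the tiling branch with dummy tensors (labels omitted):

```python
# Tiling branch of collate_fn4 in isolation (sketch with dummy data):
# four (C,H,W) images -> one (C,2H,2W) image, batch size drops 4x.
import torch

imgs = [torch.full((3, 320, 320), float(i)) for i in range(4)]
left = torch.cat((imgs[0], imgs[1]), 1)   # cat along height -> (3, 640, 320)
right = torch.cat((imgs[2], imgs[3]), 1)  # cat along height -> (3, 640, 320)
quad = torch.cat((left, right), 2)        # cat along width  -> (3, 640, 640)
print(quad.shape)                         # torch.Size([3, 640, 640])
# Labels (idx, cls, x, y, w, h) get the matching transform above: +1 x/y
# offsets (wo/ho) for the right/bottom tiles, then * s to halve coordinates.
```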
 # Ancillary functions --------------------------------------------------------------------------------------------------
 def load_image(self, index):
@@ -617,7 +644,7 @@ def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
 
 def load_mosaic(self, index):
-    # loads images in a mosaic
+    # loads images in a 4-mosaic
 
     labels4 = []
     s = self.img_size
@@ -674,6 +701,80 @@ def load_mosaic(self, index):
     return img4, labels4
 
 
+def load_mosaic9(self, index):
+    # loads images in a 9-mosaic
+
+    labels9 = []
+    s = self.img_size
+    indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(8)]  # 8 additional image indices
+    for i, index in enumerate(indices):
+        # Load image
+        img, _, (h, w) = load_image(self, index)
+
+        # place img in img9
+        if i == 0:  # center
+            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
+            h0, w0 = h, w
+            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
+        elif i == 1:  # top
+            c = s, s - h, s + w, s
+        elif i == 2:  # top right
+            c = s + wp, s - h, s + wp + w, s
+        elif i == 3:  # right
+            c = s + w0, s, s + w0 + w, s + h
+        elif i == 4:  # bottom right
+            c = s + w0, s + hp, s + w0 + w, s + hp + h
+        elif i == 5:  # bottom
+            c = s + w0 - w, s + h0, s + w0, s + h0 + h
+        elif i == 6:  # bottom left
+            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
+        elif i == 7:  # left
+            c = s - w, s + h0 - h, s, s + h0
+        elif i == 8:  # top left
+            c = s - w, s + h0 - hp - h, s, s + h0 - hp
+
+        padx, pady = c[:2]
+        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords
+
+        # Labels
+        x = self.labels[index]
+        labels = x.copy()
+        if x.size > 0:  # Normalized xywh to pixel xyxy format
+            labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padx
+            labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + pady
+            labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padx
+            labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + pady
+        labels9.append(labels)
+
+        # Image
+        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
+        hp, wp = h, w  # height, width previous
+
+    # Offset
+    yc, xc = [int(random.uniform(0, s)) for x in self.mosaic_border]  # mosaic center x, y
+    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]
+
+    # Concat/clip labels
+    if len(labels9):
+        labels9 = np.concatenate(labels9, 0)
+        labels9[:, [1, 3]] -= xc
+        labels9[:, [2, 4]] -= yc
+
+    np.clip(labels9[:, 1:], 0, 2 * s, out=labels9[:, 1:])  # use with random_perspective
+    # img9, labels9 = replicate(img9, labels9)  # replicate
+
+    # Augment
+    img9, labels9 = random_perspective(img9, labels9,
+                                       degrees=self.hyp['degrees'],
+                                       translate=self.hyp['translate'],
+                                       scale=self.hyp['scale'],
+                                       shear=self.hyp['shear'],
+                                       perspective=self.hyp['perspective'],
+                                       border=self.mosaic_border)  # border to remove
+
+    return img9, labels9
+
+
 def replicate(img, labels):
     # Replicate labels
     h, w = img.shape[:2]
@@ -811,12 +912,12 @@ def random_perspective(img, targets=(), degrees=10, translate=.1, scale=.1, shea
     return img, targets
 
-def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1):  # box1(4,n), box2(4,n)
+def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
     # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
     w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
     w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
-    ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))  # aspect ratio
-    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + 1e-16) > area_thr) & (ar < ar_thr)  # candidates
+    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
+    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates
 
 def cutout(image, labels):

diff --git a/utils/general.py b/utils/general.py
index fceefb0a4a97..797587b13978 100755
--- a/utils/general.py
+++ b/utils/general.py
@@ -281,6 +281,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
     # Settings
     min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
     max_det = 300  # maximum number of detections per image
+    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
     time_limit = 10.0  # seconds to quit after
     redundant = True  # require redundant detections
     multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)
@@ -328,13 +329,12 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
         # if not torch.isfinite(x).all():
         #     x = x[torch.isfinite(x).all(1)]
 
-        # If none remain process next image
+        # Check shape
         n = x.shape[0]  # number of boxes
-        if not n:
+        if not n:  # no boxes
             continue
-
-        # Sort by confidence
-        # x = x[x[:, 4].argsort(descending=True)]
+        elif n > max_nms:  # excess boxes
+            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
 
        #
Batched NMS c = x[:, 5:6] * (0 if agnostic else max_wh) # classes @@ -352,6 +352,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non output[xi] = x[i] if (time.time() - t) > time_limit: + print(f'WARNING: NMS time limit {time_limit}s exceeded') break # time limit exceeded return output diff --git a/utils/google_utils.py b/utils/google_utils.py index f97349d4594f..242270c1b033 100644 --- a/utils/google_utils.py +++ b/utils/google_utils.py @@ -6,6 +6,7 @@ import time from pathlib import Path +import requests import torch @@ -21,21 +22,14 @@ def attempt_download(weights): file = Path(weights).name.lower() msg = weights + ' missing, try downloading from https://github.com/ultralytics/yolov5/releases/' - models = ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt'] # available models - redundant = False # offer second download option - - if file in models and not os.path.isfile(weights): - # Google Drive - # d = {'yolov5s.pt': '1R5T6rIyy3lLwgFXNms8whc-387H0tMQO', - # 'yolov5m.pt': '1vobuEExpWQVpXExsJ2w-Mbf3HJjWkQJr', - # 'yolov5l.pt': '1hrlqD1Wdei7UT4OgT785BEk1JwnSvNEV', - # 'yolov5x.pt': '1mM8aZJlWTxOg7BZJvNUMrTnA2AbeCVzS'} - # r = gdrive_download(id=d[file], name=weights) if file in d else 1 - # if r == 0 and os.path.exists(weights) and os.path.getsize(weights) > 1E6: # check - # return + response = requests.get('https://github.com/gitapi/repos/ultralytics/yolov5/releases/latest').json() # github api + assets = [x['name'] for x in response['assets']] # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...] + redundant = False # second download option + if file in assets and not os.path.isfile(weights): try: # GitHub - url = 'https://github.com/ultralytics/yolov5/releases/download/v3.1/' + file + tag = response['tag_name'] # i.e. 'v1.0' + url = f'https://github.com/ultralytics/yolov5/releases/download/{tag}/{file}' print('Downloading %s to %s...' % (url, weights)) torch.hub.download_url_to_file(url, weights) assert os.path.exists(weights) and os.path.getsize(weights) > 1E6 # check @@ -53,10 +47,9 @@ def attempt_download(weights): return -def gdrive_download(id='1uH2BylpFxHKEGXKL6wJJlsgMU2YEjxuc', name='tmp.zip'): - # Downloads a file from Google Drive. from utils.google_utils import *; gdrive_download() +def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', name='tmp.zip'): + # Downloads a file from Google Drive. from yolov5.utils.google_utils import *; gdrive_download() t = time.time() - print('Downloading https://drive.google.com/uc?export=download&id=%s as %s... ' % (id, name), end='') os.remove(name) if os.path.exists(name) else None # remove existing os.remove('cookie') if os.path.exists('cookie') else None diff --git a/utils/loss.py b/utils/loss.py index 2049516904ff..46051f2eae49 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -106,7 +106,7 @@ def compute_loss(p, targets, model): # predictions, targets, model # Losses nt = 0 # number of targets no = len(p) # number of outputs - balance = [4.0, 1.0, 0.4] if no == 3 else [4.0, 1.0, 0.4, 0.1] # P3-5 or P3-6 + balance = [4.0, 1.0, 0.3, 0.1, 0.03] # P3-P7 for i, pi in enumerate(p): # layer index, layer predictions b, a, gj, gi = indices[i] # image, anchor, gridy, gridx tobj = torch.zeros_like(pi[..., 0], device=device) # target obj @@ -140,7 +140,7 @@ def compute_loss(p, targets, model): # predictions, targets, model s = 3 / no # output count scaling lbox *= h['box'] * s - lobj *= h['obj'] * s * (1.4 if no == 4 else 1.) 
+ lobj *= h['obj'] lcls *= h['cls'] * s bs = tobj.shape[0] # batch size diff --git a/utils/plots.py b/utils/plots.py index 6b4e30147b77..c883ea24f253 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -223,7 +223,7 @@ def plot_targets_txt(): # from utils.plots import *; plot_targets_txt() plt.savefig('targets.jpg', dpi=200) -def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt() +def plot_study_txt(path='study/', x=None): # from utils.plots import *; plot_study_txt() # Plot study.txt generated by test.py fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True) ax = ax.ravel() @@ -246,7 +246,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx ax2.grid() ax2.set_xlim(0, 30) - ax2.set_ylim(28, 50) + ax2.set_ylim(29, 51) ax2.set_yticks(np.arange(30, 55, 5)) ax2.set_xlabel('GPU Speed (ms/img)') ax2.set_ylabel('COCO AP val') diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 69a31213c543..75bcb7f9834f 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -225,8 +225,8 @@ def load_classifier(name='resnet101', n=2): return model -def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio - # scales img(bs,3,y,x) by ratio +def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416) + # scales img(bs,3,y,x) by ratio constrained to gs-multiple if ratio == 1.0: return img else: @@ -234,7 +234,6 @@ def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio s = (int(h * ratio), int(w * ratio)) # new size img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize if not same_shape: # pad/crop img - gs = 32 # (pixels) grid size h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)] return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
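The new gs parameter exists so that augmented inference can pad scaled images to the stride of the larger P6/P7 models (yolo.py now calls scale_img with gs=int(self.stride.max())). A usage sketch, assuming utils/torch_utils.py from this patch is importable:

```python
# Usage sketch for the updated scale_img: down-scale a BCHW batch, then pad
# height/width out to the nearest gs-multiple (sketch; import path assumed).
import torch
from utils.torch_utils import scale_img

x = torch.zeros(16, 3, 256, 416)     # BCHW batch
y = scale_img(x, ratio=0.83, gs=64)  # gs=64 for a P6 model (max stride 64)
print(y.shape)                       # torch.Size([16, 3, 256, 384])
```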