Skip to content

Commit

Permalink
YOLOv5 v4.0 Release (#1837)
Browse files Browse the repository at this point in the history
* Update C3 module

* Update C3 module

* Update C3 module

* Update C3 module

* update

* update

* update

* update

* update

* update

* update

* update

* update

* updates

* updates

* updates

* updates

* updates

* updates

* updates

* updates

* updates

* updates

* update

* update

* update

* update

* updates

* updates

* updates

* updates

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update datasets

* update

* update

* update

* update attempt_downlaod()

* merge

* merge

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* parameterize eps

* comments

* gs-multiple

* update

* max_nms implemented

* Create one_cycle() function

* update

* update

* update

* update

* update

* update

* update

* update study.png

* update study.png

* Update datasets.py
  • Loading branch information
glenn-jocher committed Jan 5, 2021
1 parent 0e341c5 commit 69be8e7
Show file tree
Hide file tree
Showing 23 changed files with 489 additions and 125 deletions.
35 changes: 20 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,40 @@

![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg)

This repository represents Ultralytics open-source research into future object detection methods, and incorporates our lessons learned and best practices evolved over training thousands of models on custom client datasets with our previous YOLO repository https://github.com/ultralytics/yolov3. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk.
This repository represents Ultralytics open-source research into future object detection methods, and incorporates lessons learned and best practices evolved over thousands of hours of training and evolution on anonymized client datasets. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk.

<img src="https://user-images.githubusercontent.com/26833433/90187293-6773ba00-dd6e-11ea-8f90-cd94afc0427f.png" width="1000">** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 32, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS. EfficientDet data from [google/automl](https://github.com/google/automl) at batch size 8.
<img src="https://user-images.githubusercontent.com/26833433/103594689-455e0e00-4eae-11eb-9cdf-7d753e2ceeeb.png" width="1000">** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 32, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS. EfficientDet data from [google/automl](https://github.com/google/automl) at batch size 8.

- **January 5, 2021**: [v4.0 release](https://github.com/ultralytics/yolov5/releases/tag/v4.0): nn.SiLU() activations, [Weights & Biases](https://wandb.ai/) logging, [PyTorch Hub](https://pytorch.org/hub/ultralytics_yolov5/) integration.
- **August 13, 2020**: [v3.0 release](https://github.com/ultralytics/yolov5/releases/tag/v3.0): nn.Hardswish() activations, data autodownload, native AMP.
- **July 23, 2020**: [v2.0 release](https://github.com/ultralytics/yolov5/releases/tag/v2.0): improved model definition, training and mAP.
- **June 22, 2020**: [PANet](https://arxiv.org/abs/1803.01534) updates: new heads, reduced parameters, improved speed and mAP [364fcfd](https://github.com/ultralytics/yolov5/commit/364fcfd7dba53f46edd4f04c037a039c0a287972).
- **June 19, 2020**: [FP16](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.half) as new default for smaller checkpoints and faster inference [d4c6674](https://github.com/ultralytics/yolov5/commit/d4c6674c98e19df4c40e33a777610a18d1961145).
- **June 9, 2020**: [CSP](https://github.com/WongKinYiu/CrossStagePartialNetworks) updates: improved speed, size, and accuracy (credit to @WongKinYiu for CSP).
- **May 27, 2020**: Public release. YOLOv5 models are SOTA among all known YOLO implementations.


## Pretrained Checkpoints

| Model | AP<sup>val</sup> | AP<sup>test</sup> | AP<sub>50</sub> | Speed<sub>GPU</sub> | FPS<sub>GPU</sub> || params | GFLOPS |
|---------- |------ |------ |------ | -------- | ------| ------ |------ | :------: |
| [YOLOv5s](https://github.com/ultralytics/yolov5/releases) | 37.0 | 37.0 | 56.2 | **2.4ms** | **416** || 7.5M | 17.5
| [YOLOv5m](https://github.com/ultralytics/yolov5/releases) | 44.3 | 44.3 | 63.2 | 3.4ms | 294 || 21.8M | 52.3
| [YOLOv5l](https://github.com/ultralytics/yolov5/releases) | 47.7 | 47.7 | 66.5 | 4.4ms | 227 || 47.8M | 117.2
| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) | **49.2** | **49.2** | **67.7** | 6.9ms | 145 || 89.0M | 221.5
| | | | | | || |
| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) + TTA|**50.8**| **50.8** | **68.9** | 25.5ms | 39 || 89.0M | 801.0
| Model | size | AP<sup>val</sup> | AP<sup>test</sup> | AP<sub>50</sub> | Speed<sub>V100</sub> | FPS<sub>V100</sub> || params | GFLOPS |
|---------- |------ |------ |------ |------ | -------- | ------| ------ |------ | :------: |
| [YOLOv5s](https://github.com/ultralytics/yolov5/releases) |640 |36.8 |36.8 |55.6 |**2.2ms** |**455** ||7.3M |17.0
| [YOLOv5m](https://github.com/ultralytics/yolov5/releases) |640 |44.5 |44.5 |63.1 |2.9ms |345 ||21.4M |51.3
| [YOLOv5l](https://github.com/ultralytics/yolov5/releases) |640 |48.1 |48.1 |66.4 |3.8ms |264 ||47.0M |115.4
| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) |640 |**50.1** |**50.1** |**68.7** |6.0ms |167 ||87.7M |218.8
| | | | | | | || |
| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) + TTA |832 |**51.9** |**51.9** |**69.6** |24.9ms |40 ||87.7M |1005.3

<!---
| [YOLOv5l6](https://github.com/ultralytics/yolov5/releases) |640 |49.0 |49.0 |67.4 |4.1ms |244 ||77.2M |117.7
| [YOLOv5l6](https://github.com/ultralytics/yolov5/releases) |1280 |53.0 |53.0 |70.8 |12.3ms |81 ||77.2M |117.7
--->

** AP<sup>test</sup> denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results denote val2017 accuracy.
** All AP numbers are for single-model single-scale without ensemble or TTA. **Reproduce mAP** by `python test.py --data coco.yaml --img 640 --conf 0.001 --iou 0.65`
** Speed<sub>GPU</sub> averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) V100 instance, and includes image preprocessing, FP16 inference, postprocessing and NMS. NMS is 1-2ms/img. **Reproduce speed** by `python test.py --data coco.yaml --img 640 --conf 0.25 --iou 0.45`
** All checkpoints are trained to 300 epochs with default settings and hyperparameters (no autoaugmentation).
** Test Time Augmentation ([TTA](https://github.com/ultralytics/yolov5/issues/303)) runs at 3 image sizes. **Reproduce TTA** by `python test.py --data coco.yaml --img 832 --iou 0.65 --augment`


## Requirements

Python 3.8 or later with all [requirements.txt](https://github.com/ultralytics/yolov5/blob/master/requirements.txt) dependencies installed, including `torch>=1.7`. To install run:
Expand Down Expand Up @@ -106,21 +111,21 @@ import torch
from PIL import Image

# Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) # for PIL/cv2/np inputs and NMS
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

# Images
img1 = Image.open('zidane.jpg')
img2 = Image.open('bus.jpg')
imgs = [img1, img2] # batched list of images

# Inference
prediction = model(imgs, size=640) # includes NMS
result = model(imgs)
```


## Training

Download [COCO](https://github.com/ultralytics/yolov5/blob/master/data/scripts/get_coco.sh) and run command below. Training times for YOLOv5s/m/l/x are 2/4/6/8 days on a single V100 (multi-GPU times faster). Use the largest `--batch-size` your GPU allows (batch sizes shown for 16 GB devices).
Run commands below to reproduce results on [COCO](https://github.com/ultralytics/yolov5/blob/master/data/scripts/get_coco.sh) dataset (dataset auto-downloads on first use). Training times for YOLOv5s/m/l/x are 2/4/6/8 days on a single V100 (multi-GPU times faster). Use the largest `--batch-size` your GPU allows (batch sizes shown for 16 GB devices).
```bash
$ python train.py --data coco.yaml --cfg yolov5s.yaml --weights '' --batch-size 64
yolov5m 40
Expand Down
18 changes: 9 additions & 9 deletions data/coco.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ test: ../coco/test-dev2017.txt # 20288 of 40670 images, submit to https://compe
nc: 80

# class names
names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush']
names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush' ]

# Print classes
# with open('data/coco.yaml') as f:
Expand Down
18 changes: 9 additions & 9 deletions data/coco128.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ val: ../coco128/images/train2017/ # 128 images
nc: 80

# class names
names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush']
names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush' ]
4 changes: 2 additions & 2 deletions data/voc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ val: ../VOC/images/val/ # 4952 images
nc: 20

# class names
names: ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
names: [ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' ]
41 changes: 32 additions & 9 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, k
super(Conv, self).__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

def forward(self, x):
return self.act(self.bn(self.conv(x)))
Expand Down Expand Up @@ -105,9 +105,39 @@ class Focus(nn.Module):
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
super(Focus, self).__init__()
self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
# self.contract = Contract(gain=2)

def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
# return self.conv(self.contract(x))


class Contract(nn.Module):
# Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
def __init__(self, gain=2):
super().__init__()
self.gain = gain

def forward(self, x):
N, C, H, W = x.size() # assert (H / s == 0) and (W / s == 0), 'Indivisible gain'
s = self.gain
x = x.view(N, C, H // s, s, W // s, s) # x(1,64,40,2,40,2)
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
return x.view(N, C * s * s, H // s, W // s) # x(1,256,40,40)


class Expand(nn.Module):
# Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
def __init__(self, gain=2):
super().__init__()
self.gain = gain

def forward(self, x):
N, C, H, W = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
s = self.gain
x = x.view(N, s, s, C // s ** 2, H, W) # x(1,2,2,16,80,80)
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
return x.view(N, C // s ** 2, H * s, W * s) # x(1,16,160,160)


class Concat(nn.Module):
Expand Down Expand Up @@ -253,20 +283,13 @@ def tolist(self):
return x


class Flatten(nn.Module):
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
@staticmethod
def forward(x):
return x.view(x.size(0), -1)


class Classify(nn.Module):
# Classification head, i.e. x(b,c1,20,20) to x(b,c2)
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
super(Classify, self).__init__()
self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
self.flat = Flatten()
self.flat = nn.Flatten()

def forward(self, x):
z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
Expand Down
4 changes: 2 additions & 2 deletions models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def forward(self, x, augment=False):
for module in self:
y.append(module(x, augment)[0])
# y = torch.stack(y).max(0)[0] # max ensemble
# y = torch.cat(y, 1) # nms ensemble
y = torch.stack(y).mean(0) # mean ensemble
# y = torch.stack(y).mean(0) # mean ensemble
y = torch.cat(y, 1) # nms ensemble
return y, None # inference, train output


Expand Down
58 changes: 58 additions & 0 deletions models/hub/anchors.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Default YOLOv5 anchors for COCO data


# P5 -------------------------------------------------------------------------------------------------------------------
# P5-640:
anchors_p5_640:
- [ 10,13, 16,30, 33,23 ] # P3/8
- [ 30,61, 62,45, 59,119 ] # P4/16
- [ 116,90, 156,198, 373,326 ] # P5/32


# P6 -------------------------------------------------------------------------------------------------------------------
# P6-640: thr=0.25: 0.9964 BPR, 5.54 anchors past thr, n=12, img_size=640, metric_all=0.281/0.716-mean/best, past_thr=0.469-mean: 9,11, 21,19, 17,41, 43,32, 39,70, 86,64, 65,131, 134,130, 120,265, 282,180, 247,354, 512,387
anchors_p6_640:
- [ 9,11, 21,19, 17,41 ] # P3/8
- [ 43,32, 39,70, 86,64 ] # P4/16
- [ 65,131, 134,130, 120,265 ] # P5/32
- [ 282,180, 247,354, 512,387 ] # P6/64

# P6-1280: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1280, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 19,27, 44,40, 38,94, 96,68, 86,152, 180,137, 140,301, 303,264, 238,542, 436,615, 739,380, 925,792
anchors_p6_1280:
- [ 19,27, 44,40, 38,94 ] # P3/8
- [ 96,68, 86,152, 180,137 ] # P4/16
- [ 140,301, 303,264, 238,542 ] # P5/32
- [ 436,615, 739,380, 925,792 ] # P6/64

# P6-1920: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1920, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 28,41, 67,59, 57,141, 144,103, 129,227, 270,205, 209,452, 455,396, 358,812, 653,922, 1109,570, 1387,1187
anchors_p6_1920:
- [ 28,41, 67,59, 57,141 ] # P3/8
- [ 144,103, 129,227, 270,205 ] # P4/16
- [ 209,452, 455,396, 358,812 ] # P5/32
- [ 653,922, 1109,570, 1387,1187 ] # P6/64


# P7 -------------------------------------------------------------------------------------------------------------------
# P7-640: thr=0.25: 0.9962 BPR, 6.76 anchors past thr, n=15, img_size=640, metric_all=0.275/0.733-mean/best, past_thr=0.466-mean: 11,11, 13,30, 29,20, 30,46, 61,38, 39,92, 78,80, 146,66, 79,163, 149,150, 321,143, 157,303, 257,402, 359,290, 524,372
anchors_p7_640:
- [ 11,11, 13,30, 29,20 ] # P3/8
- [ 30,46, 61,38, 39,92 ] # P4/16
- [ 78,80, 146,66, 79,163 ] # P5/32
- [ 149,150, 321,143, 157,303 ] # P6/64
- [ 257,402, 359,290, 524,372 ] # P7/128

# P7-1280: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1280, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 19,22, 54,36, 32,77, 70,83, 138,71, 75,173, 165,159, 148,334, 375,151, 334,317, 251,626, 499,474, 750,326, 534,814, 1079,818
anchors_p7_1280:
- [ 19,22, 54,36, 32,77 ] # P3/8
- [ 70,83, 138,71, 75,173 ] # P4/16
- [ 165,159, 148,334, 375,151 ] # P5/32
- [ 334,317, 251,626, 499,474 ] # P6/64
- [ 750,326, 534,814, 1079,818 ] # P7/128

# P7-1920: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1920, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 29,34, 81,55, 47,115, 105,124, 207,107, 113,259, 247,238, 222,500, 563,227, 501,476, 376,939, 749,711, 1126,489, 801,1222, 1618,1227
anchors_p7_1920:
- [ 29,34, 81,55, 47,115 ] # P3/8
- [ 105,124, 207,107, 113,259 ] # P4/16
- [ 247,238, 222,500, 563,227 ] # P5/32
- [ 501,476, 376,939, 749,711 ] # P6/64
- [ 1126,489, 801,1222, 1618,1227 ] # P7/128
Loading

0 comments on commit 69be8e7

Please sign in to comment.