From 5331384cb61e4e94a449e6cf1118b9a4347f3572 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 11 Sep 2022 21:24:25 +0300 Subject: [PATCH 1/7] New model.yaml `activation:` field Add optional model yaml activation field to define model-wide activations, i.e.: ```yaml activation: nn.LeakyReLU(0.1) # activation with arguments activation: nn.SiLU() # activation with no arguments ``` Signed-off-by: Glenn Jocher --- models/common.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/models/common.py b/models/common.py index 8b7dbbfa95fe..ecf441a8f465 100644 --- a/models/common.py +++ b/models/common.py @@ -39,11 +39,14 @@ def autopad(k, p=None, d=1): # kernel, padding, dilation class Conv(nn.Module): # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation) + act = nn.SiLU() # default activation + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): super().__init__() self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) self.bn = nn.BatchNorm2d(c2) - self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) + if act is not True: + self.act = act if isinstance(act, nn.Module) else nn.Identity() def forward(self, x): return self.act(self.bn(self.conv(x))) @@ -54,8 +57,8 @@ def forward_fuse(self, x): class DWConv(Conv): # Depth-wise convolution - def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups - super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act) + def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation + super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act) class DWConvTranspose2d(nn.ConvTranspose2d): From d03dee22e12df94abb74a19d5022c7be84719fc2 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 11 Sep 2022 21:34:20 +0300 Subject: [PATCH 2/7] Update yolo.py Signed-off-by: Glenn Jocher --- models/yolo.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/models/yolo.py b/models/yolo.py index fa05fcf9a8d9..ecc90fc651ec 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -271,8 +271,12 @@ def _from_yaml(self, cfg): def parse_model(d, ch): # model_dict, input_channels(3) + # Parse a YOLOv5 model.yaml dictionary LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}") - anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'] + anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation') + if act: + Conv.act = eval(act) # redefine default activation, i.e. Conv.act = nn.SiLU() + LOGGER.info(f"{colorstr('activation:')} {act}") # print na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors no = na * (nc + 5) # number of outputs = anchors * (classes + 5) From a07173164d06f13d70a42f205c164eb9d7f4a4de Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 14 Sep 2022 00:50:58 +0200 Subject: [PATCH 3/7] Add example models --- models/yolov5l-LeakyReLU00.yaml | 49 +++++++++++++++++++++++++++++++++ models/yolov5l-LeakyReLU01.yaml | 49 +++++++++++++++++++++++++++++++++ models/yolov5l-Mish.yaml | 49 +++++++++++++++++++++++++++++++++ models/yolov5l-ReLU.yaml | 49 +++++++++++++++++++++++++++++++++ models/yolov5l-SiLU.yaml | 49 +++++++++++++++++++++++++++++++++ models/yolov5l-Sigmoid.yaml | 49 +++++++++++++++++++++++++++++++++ models/yolov5l-Tanh.yaml | 49 +++++++++++++++++++++++++++++++++ 7 files changed, 343 insertions(+) create mode 100644 models/yolov5l-LeakyReLU00.yaml create mode 100644 models/yolov5l-LeakyReLU01.yaml create mode 100644 models/yolov5l-Mish.yaml create mode 100644 models/yolov5l-ReLU.yaml create mode 100644 models/yolov5l-SiLU.yaml create mode 100644 models/yolov5l-Sigmoid.yaml create mode 100644 models/yolov5l-Tanh.yaml diff --git a/models/yolov5l-LeakyReLU00.yaml b/models/yolov5l-LeakyReLU00.yaml new file mode 100644 index 000000000000..bf6e0647ee37 --- /dev/null +++ b/models/yolov5l-LeakyReLU00.yaml @@ -0,0 +1,49 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +activation: nn.LeakyReLU(0.0) +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/models/yolov5l-LeakyReLU01.yaml b/models/yolov5l-LeakyReLU01.yaml new file mode 100644 index 000000000000..3c09db04ee1e --- /dev/null +++ b/models/yolov5l-LeakyReLU01.yaml @@ -0,0 +1,49 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +activation: nn.LeakyReLU(0.1) +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/models/yolov5l-Mish.yaml b/models/yolov5l-Mish.yaml new file mode 100644 index 000000000000..1d365712198c --- /dev/null +++ b/models/yolov5l-Mish.yaml @@ -0,0 +1,49 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +activation: nn.Mish() +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/models/yolov5l-ReLU.yaml b/models/yolov5l-ReLU.yaml new file mode 100644 index 000000000000..00bab7f2b69f --- /dev/null +++ b/models/yolov5l-ReLU.yaml @@ -0,0 +1,49 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +activation: nn.ReLU() +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/models/yolov5l-SiLU.yaml b/models/yolov5l-SiLU.yaml new file mode 100644 index 000000000000..66bb4d39a0c3 --- /dev/null +++ b/models/yolov5l-SiLU.yaml @@ -0,0 +1,49 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +activation: nn.SiLU() +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/models/yolov5l-Sigmoid.yaml b/models/yolov5l-Sigmoid.yaml new file mode 100644 index 000000000000..babfd2a94556 --- /dev/null +++ b/models/yolov5l-Sigmoid.yaml @@ -0,0 +1,49 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +activation: nn.Sigmoid() +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/models/yolov5l-Tanh.yaml b/models/yolov5l-Tanh.yaml new file mode 100644 index 000000000000..73dbc822c56c --- /dev/null +++ b/models/yolov5l-Tanh.yaml @@ -0,0 +1,49 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +activation: nn.Tanh() +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] From 58c1f5822e6f0e22a0294f7e9ae927178fc10012 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 14 Sep 2022 22:02:27 +0200 Subject: [PATCH 4/7] l to m models --- models/{yolov5l-LeakyReLU00.yaml => yolov5m-LeakyReLU00.yaml} | 4 ++-- models/{yolov5l-LeakyReLU01.yaml => yolov5m-LeakyReLU01.yaml} | 4 ++-- models/{yolov5l-Mish.yaml => yolov5m-Mish.yaml} | 4 ++-- models/{yolov5l-ReLU.yaml => yolov5m-ReLU.yaml} | 4 ++-- models/{yolov5l-SiLU.yaml => yolov5m-SiLU.yaml} | 4 ++-- models/{yolov5l-Sigmoid.yaml => yolov5m-Sigmoid.yaml} | 4 ++-- models/{yolov5l-Tanh.yaml => yolov5m-Tanh.yaml} | 4 ++-- 7 files changed, 14 insertions(+), 14 deletions(-) rename models/{yolov5l-LeakyReLU00.yaml => yolov5m-LeakyReLU00.yaml} (93%) rename models/{yolov5l-LeakyReLU01.yaml => yolov5m-LeakyReLU01.yaml} (93%) rename models/{yolov5l-Mish.yaml => yolov5m-Mish.yaml} (93%) rename models/{yolov5l-ReLU.yaml => yolov5m-ReLU.yaml} (93%) rename models/{yolov5l-SiLU.yaml => yolov5m-SiLU.yaml} (93%) rename models/{yolov5l-Sigmoid.yaml => yolov5m-Sigmoid.yaml} (93%) rename models/{yolov5l-Tanh.yaml => yolov5m-Tanh.yaml} (93%) diff --git a/models/yolov5l-LeakyReLU00.yaml b/models/yolov5m-LeakyReLU00.yaml similarity index 93% rename from models/yolov5l-LeakyReLU00.yaml rename to models/yolov5m-LeakyReLU00.yaml index bf6e0647ee37..05e1d201f54b 100644 --- a/models/yolov5l-LeakyReLU00.yaml +++ b/models/yolov5m-LeakyReLU00.yaml @@ -2,8 +2,8 @@ # Parameters nc: 80 # number of classes -depth_multiple: 1.0 # model depth multiple -width_multiple: 1.0 # layer channel multiple +depth_multiple: 0.67 # model depth multiple +width_multiple: 0.75 # layer channel multiple anchors: - [10,13, 16,30, 33,23] # P3/8 - [30,61, 62,45, 59,119] # P4/16 diff --git a/models/yolov5l-LeakyReLU01.yaml b/models/yolov5m-LeakyReLU01.yaml similarity index 93% rename from models/yolov5l-LeakyReLU01.yaml rename to models/yolov5m-LeakyReLU01.yaml index 3c09db04ee1e..4108f41d46c6 100644 --- a/models/yolov5l-LeakyReLU01.yaml +++ b/models/yolov5m-LeakyReLU01.yaml @@ -2,8 +2,8 @@ # Parameters nc: 80 # number of classes -depth_multiple: 1.0 # model depth multiple -width_multiple: 1.0 # layer channel multiple +depth_multiple: 0.67 # model depth multiple +width_multiple: 0.75 # layer channel multiple anchors: - [10,13, 16,30, 33,23] # P3/8 - [30,61, 62,45, 59,119] # P4/16 diff --git a/models/yolov5l-Mish.yaml b/models/yolov5m-Mish.yaml similarity index 93% rename from models/yolov5l-Mish.yaml rename to models/yolov5m-Mish.yaml index 1d365712198c..3208ef8c0e69 100644 --- a/models/yolov5l-Mish.yaml +++ b/models/yolov5m-Mish.yaml @@ -2,8 +2,8 @@ # Parameters nc: 80 # number of classes -depth_multiple: 1.0 # model depth multiple -width_multiple: 1.0 # layer channel multiple +depth_multiple: 0.67 # model depth multiple +width_multiple: 0.75 # layer channel multiple anchors: - [10,13, 16,30, 33,23] # P3/8 - [30,61, 62,45, 59,119] # P4/16 diff --git a/models/yolov5l-ReLU.yaml b/models/yolov5m-ReLU.yaml similarity index 93% rename from models/yolov5l-ReLU.yaml rename to models/yolov5m-ReLU.yaml index 00bab7f2b69f..12e11b18eb1f 100644 --- a/models/yolov5l-ReLU.yaml +++ b/models/yolov5m-ReLU.yaml @@ -2,8 +2,8 @@ # Parameters nc: 80 # number of classes -depth_multiple: 1.0 # model depth multiple -width_multiple: 1.0 # layer channel multiple +depth_multiple: 0.67 # model depth multiple +width_multiple: 0.75 # layer channel multiple anchors: - [10,13, 16,30, 33,23] # P3/8 - [30,61, 62,45, 59,119] # P4/16 diff --git a/models/yolov5l-SiLU.yaml b/models/yolov5m-SiLU.yaml similarity index 93% rename from models/yolov5l-SiLU.yaml rename to models/yolov5m-SiLU.yaml index 66bb4d39a0c3..770828cf962a 100644 --- a/models/yolov5l-SiLU.yaml +++ b/models/yolov5m-SiLU.yaml @@ -2,8 +2,8 @@ # Parameters nc: 80 # number of classes -depth_multiple: 1.0 # model depth multiple -width_multiple: 1.0 # layer channel multiple +depth_multiple: 0.67 # model depth multiple +width_multiple: 0.75 # layer channel multiple anchors: - [10,13, 16,30, 33,23] # P3/8 - [30,61, 62,45, 59,119] # P4/16 diff --git a/models/yolov5l-Sigmoid.yaml b/models/yolov5m-Sigmoid.yaml similarity index 93% rename from models/yolov5l-Sigmoid.yaml rename to models/yolov5m-Sigmoid.yaml index babfd2a94556..ab9effb0050d 100644 --- a/models/yolov5l-Sigmoid.yaml +++ b/models/yolov5m-Sigmoid.yaml @@ -2,8 +2,8 @@ # Parameters nc: 80 # number of classes -depth_multiple: 1.0 # model depth multiple -width_multiple: 1.0 # layer channel multiple +depth_multiple: 0.67 # model depth multiple +width_multiple: 0.75 # layer channel multiple anchors: - [10,13, 16,30, 33,23] # P3/8 - [30,61, 62,45, 59,119] # P4/16 diff --git a/models/yolov5l-Tanh.yaml b/models/yolov5m-Tanh.yaml similarity index 93% rename from models/yolov5l-Tanh.yaml rename to models/yolov5m-Tanh.yaml index 73dbc822c56c..743465bdbdfb 100644 --- a/models/yolov5l-Tanh.yaml +++ b/models/yolov5m-Tanh.yaml @@ -2,8 +2,8 @@ # Parameters nc: 80 # number of classes -depth_multiple: 1.0 # model depth multiple -width_multiple: 1.0 # layer channel multiple +depth_multiple: 0.67 # model depth multiple +width_multiple: 0.75 # layer channel multiple anchors: - [10,13, 16,30, 33,23] # P3/8 - [30,61, 62,45, 59,119] # P4/16 From 14a8d1bf39c88a1b84b409820d7bf5897e92dede Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 15 Sep 2022 18:29:09 +0200 Subject: [PATCH 5/7] update --- models/common.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/models/common.py b/models/common.py index ecf441a8f465..6130dbc2349a 100644 --- a/models/common.py +++ b/models/common.py @@ -45,8 +45,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): super().__init__() self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) self.bn = nn.BatchNorm2d(c2) - if act is not True: - self.act = act if isinstance(act, nn.Module) else nn.Identity() + self.act = self.act if act is True else act if isinstance(act, nn.Module) else nn.Identity() def forward(self, x): return self.act(self.bn(self.conv(x))) From 56193f7325d04799ff6d1bd3a0ae1ff62a6c0ba8 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 16 Sep 2022 00:43:52 +0200 Subject: [PATCH 6/7] Add yolov5s-LeakyReLU.yaml --- .../yolov5s-LeakyReLU.yaml} | 6 +-- models/yolov5m-LeakyReLU00.yaml | 49 ------------------- models/yolov5m-Mish.yaml | 49 ------------------- models/yolov5m-ReLU.yaml | 49 ------------------- models/yolov5m-SiLU.yaml | 49 ------------------- models/yolov5m-Sigmoid.yaml | 49 ------------------- models/yolov5m-Tanh.yaml | 49 ------------------- 7 files changed, 3 insertions(+), 297 deletions(-) rename models/{yolov5m-LeakyReLU01.yaml => hub/yolov5s-LeakyReLU.yaml} (93%) delete mode 100644 models/yolov5m-LeakyReLU00.yaml delete mode 100644 models/yolov5m-Mish.yaml delete mode 100644 models/yolov5m-ReLU.yaml delete mode 100644 models/yolov5m-SiLU.yaml delete mode 100644 models/yolov5m-Sigmoid.yaml delete mode 100644 models/yolov5m-Tanh.yaml diff --git a/models/yolov5m-LeakyReLU01.yaml b/models/hub/yolov5s-LeakyReLU.yaml similarity index 93% rename from models/yolov5m-LeakyReLU01.yaml rename to models/hub/yolov5s-LeakyReLU.yaml index 4108f41d46c6..2da1ea7c5463 100644 --- a/models/yolov5m-LeakyReLU01.yaml +++ b/models/hub/yolov5s-LeakyReLU.yaml @@ -2,15 +2,15 @@ # Parameters nc: 80 # number of classes -depth_multiple: 0.67 # model depth multiple -width_multiple: 0.75 # layer channel multiple +activation: nn.LeakyReLU(0.1) +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.50 # layer channel multiple anchors: - [10,13, 16,30, 33,23] # P3/8 - [30,61, 62,45, 59,119] # P4/16 - [116,90, 156,198, 373,326] # P5/32 # YOLOv5 v6.0 backbone -activation: nn.LeakyReLU(0.1) backbone: # [from, number, module, args] [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 diff --git a/models/yolov5m-LeakyReLU00.yaml b/models/yolov5m-LeakyReLU00.yaml deleted file mode 100644 index 05e1d201f54b..000000000000 --- a/models/yolov5m-LeakyReLU00.yaml +++ /dev/null @@ -1,49 +0,0 @@ -# YOLOv5 🚀 by Ultralytics, GPL-3.0 license - -# Parameters -nc: 80 # number of classes -depth_multiple: 0.67 # model depth multiple -width_multiple: 0.75 # layer channel multiple -anchors: - - [10,13, 16,30, 33,23] # P3/8 - - [30,61, 62,45, 59,119] # P4/16 - - [116,90, 156,198, 373,326] # P5/32 - -# YOLOv5 v6.0 backbone -activation: nn.LeakyReLU(0.0) -backbone: - # [from, number, module, args] - [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 - [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, C3, [128]], - [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 6, C3, [256]], - [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, C3, [512]], - [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 - [-1, 3, C3, [1024]], - [-1, 1, SPPF, [1024, 5]], # 9 - ] - -# YOLOv5 v6.0 head -head: - [[-1, 1, Conv, [512, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, C3, [512, False]], # 13 - - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, C3, [256, False]], # 17 (P3/8-small) - - [-1, 1, Conv, [256, 3, 2]], - [[-1, 14], 1, Concat, [1]], # cat head P4 - [-1, 3, C3, [512, False]], # 20 (P4/16-medium) - - [-1, 1, Conv, [512, 3, 2]], - [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, C3, [1024, False]], # 23 (P5/32-large) - - [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) - ] diff --git a/models/yolov5m-Mish.yaml b/models/yolov5m-Mish.yaml deleted file mode 100644 index 3208ef8c0e69..000000000000 --- a/models/yolov5m-Mish.yaml +++ /dev/null @@ -1,49 +0,0 @@ -# YOLOv5 🚀 by Ultralytics, GPL-3.0 license - -# Parameters -nc: 80 # number of classes -depth_multiple: 0.67 # model depth multiple -width_multiple: 0.75 # layer channel multiple -anchors: - - [10,13, 16,30, 33,23] # P3/8 - - [30,61, 62,45, 59,119] # P4/16 - - [116,90, 156,198, 373,326] # P5/32 - -# YOLOv5 v6.0 backbone -activation: nn.Mish() -backbone: - # [from, number, module, args] - [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 - [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, C3, [128]], - [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 6, C3, [256]], - [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, C3, [512]], - [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 - [-1, 3, C3, [1024]], - [-1, 1, SPPF, [1024, 5]], # 9 - ] - -# YOLOv5 v6.0 head -head: - [[-1, 1, Conv, [512, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, C3, [512, False]], # 13 - - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, C3, [256, False]], # 17 (P3/8-small) - - [-1, 1, Conv, [256, 3, 2]], - [[-1, 14], 1, Concat, [1]], # cat head P4 - [-1, 3, C3, [512, False]], # 20 (P4/16-medium) - - [-1, 1, Conv, [512, 3, 2]], - [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, C3, [1024, False]], # 23 (P5/32-large) - - [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) - ] diff --git a/models/yolov5m-ReLU.yaml b/models/yolov5m-ReLU.yaml deleted file mode 100644 index 12e11b18eb1f..000000000000 --- a/models/yolov5m-ReLU.yaml +++ /dev/null @@ -1,49 +0,0 @@ -# YOLOv5 🚀 by Ultralytics, GPL-3.0 license - -# Parameters -nc: 80 # number of classes -depth_multiple: 0.67 # model depth multiple -width_multiple: 0.75 # layer channel multiple -anchors: - - [10,13, 16,30, 33,23] # P3/8 - - [30,61, 62,45, 59,119] # P4/16 - - [116,90, 156,198, 373,326] # P5/32 - -# YOLOv5 v6.0 backbone -activation: nn.ReLU() -backbone: - # [from, number, module, args] - [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 - [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, C3, [128]], - [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 6, C3, [256]], - [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, C3, [512]], - [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 - [-1, 3, C3, [1024]], - [-1, 1, SPPF, [1024, 5]], # 9 - ] - -# YOLOv5 v6.0 head -head: - [[-1, 1, Conv, [512, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, C3, [512, False]], # 13 - - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, C3, [256, False]], # 17 (P3/8-small) - - [-1, 1, Conv, [256, 3, 2]], - [[-1, 14], 1, Concat, [1]], # cat head P4 - [-1, 3, C3, [512, False]], # 20 (P4/16-medium) - - [-1, 1, Conv, [512, 3, 2]], - [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, C3, [1024, False]], # 23 (P5/32-large) - - [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) - ] diff --git a/models/yolov5m-SiLU.yaml b/models/yolov5m-SiLU.yaml deleted file mode 100644 index 770828cf962a..000000000000 --- a/models/yolov5m-SiLU.yaml +++ /dev/null @@ -1,49 +0,0 @@ -# YOLOv5 🚀 by Ultralytics, GPL-3.0 license - -# Parameters -nc: 80 # number of classes -depth_multiple: 0.67 # model depth multiple -width_multiple: 0.75 # layer channel multiple -anchors: - - [10,13, 16,30, 33,23] # P3/8 - - [30,61, 62,45, 59,119] # P4/16 - - [116,90, 156,198, 373,326] # P5/32 - -# YOLOv5 v6.0 backbone -activation: nn.SiLU() -backbone: - # [from, number, module, args] - [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 - [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, C3, [128]], - [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 6, C3, [256]], - [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, C3, [512]], - [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 - [-1, 3, C3, [1024]], - [-1, 1, SPPF, [1024, 5]], # 9 - ] - -# YOLOv5 v6.0 head -head: - [[-1, 1, Conv, [512, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, C3, [512, False]], # 13 - - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, C3, [256, False]], # 17 (P3/8-small) - - [-1, 1, Conv, [256, 3, 2]], - [[-1, 14], 1, Concat, [1]], # cat head P4 - [-1, 3, C3, [512, False]], # 20 (P4/16-medium) - - [-1, 1, Conv, [512, 3, 2]], - [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, C3, [1024, False]], # 23 (P5/32-large) - - [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) - ] diff --git a/models/yolov5m-Sigmoid.yaml b/models/yolov5m-Sigmoid.yaml deleted file mode 100644 index ab9effb0050d..000000000000 --- a/models/yolov5m-Sigmoid.yaml +++ /dev/null @@ -1,49 +0,0 @@ -# YOLOv5 🚀 by Ultralytics, GPL-3.0 license - -# Parameters -nc: 80 # number of classes -depth_multiple: 0.67 # model depth multiple -width_multiple: 0.75 # layer channel multiple -anchors: - - [10,13, 16,30, 33,23] # P3/8 - - [30,61, 62,45, 59,119] # P4/16 - - [116,90, 156,198, 373,326] # P5/32 - -# YOLOv5 v6.0 backbone -activation: nn.Sigmoid() -backbone: - # [from, number, module, args] - [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 - [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, C3, [128]], - [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 6, C3, [256]], - [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, C3, [512]], - [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 - [-1, 3, C3, [1024]], - [-1, 1, SPPF, [1024, 5]], # 9 - ] - -# YOLOv5 v6.0 head -head: - [[-1, 1, Conv, [512, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, C3, [512, False]], # 13 - - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, C3, [256, False]], # 17 (P3/8-small) - - [-1, 1, Conv, [256, 3, 2]], - [[-1, 14], 1, Concat, [1]], # cat head P4 - [-1, 3, C3, [512, False]], # 20 (P4/16-medium) - - [-1, 1, Conv, [512, 3, 2]], - [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, C3, [1024, False]], # 23 (P5/32-large) - - [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) - ] diff --git a/models/yolov5m-Tanh.yaml b/models/yolov5m-Tanh.yaml deleted file mode 100644 index 743465bdbdfb..000000000000 --- a/models/yolov5m-Tanh.yaml +++ /dev/null @@ -1,49 +0,0 @@ -# YOLOv5 🚀 by Ultralytics, GPL-3.0 license - -# Parameters -nc: 80 # number of classes -depth_multiple: 0.67 # model depth multiple -width_multiple: 0.75 # layer channel multiple -anchors: - - [10,13, 16,30, 33,23] # P3/8 - - [30,61, 62,45, 59,119] # P4/16 - - [116,90, 156,198, 373,326] # P5/32 - -# YOLOv5 v6.0 backbone -activation: nn.Tanh() -backbone: - # [from, number, module, args] - [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 - [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, C3, [128]], - [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 6, C3, [256]], - [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, C3, [512]], - [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 - [-1, 3, C3, [1024]], - [-1, 1, SPPF, [1024, 5]], # 9 - ] - -# YOLOv5 v6.0 head -head: - [[-1, 1, Conv, [512, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, C3, [512, False]], # 13 - - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, C3, [256, False]], # 17 (P3/8-small) - - [-1, 1, Conv, [256, 3, 2]], - [[-1, 14], 1, Concat, [1]], # cat head P4 - [-1, 3, C3, [512, False]], # 20 (P4/16-medium) - - [-1, 1, Conv, [512, 3, 2]], - [[-1, 10], 1, Concat, [1]], # cat head P5 - [-1, 3, C3, [1024, False]], # 23 (P5/32-large) - - [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) - ] From 316750e5d322afab0bb8718cf7540af99ff4d440 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 16 Sep 2022 00:47:28 +0200 Subject: [PATCH 7/7] Update yolov5s-LeakyReLU.yaml Signed-off-by: Glenn Jocher --- models/hub/yolov5s-LeakyReLU.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/hub/yolov5s-LeakyReLU.yaml b/models/hub/yolov5s-LeakyReLU.yaml index 2da1ea7c5463..3a179bf3311c 100644 --- a/models/hub/yolov5s-LeakyReLU.yaml +++ b/models/hub/yolov5s-LeakyReLU.yaml @@ -2,7 +2,7 @@ # Parameters nc: 80 # number of classes -activation: nn.LeakyReLU(0.1) +activation: nn.LeakyReLU(0.1) # <----- Conv() activation used throughout entire YOLOv5 model depth_multiple: 0.33 # model depth multiple width_multiple: 0.50 # layer channel multiple anchors: