From 9bd5b602c9d73ac28f73d011852e32be6a1dcbde Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 14 Aug 2021 12:55:22 +0200 Subject: [PATCH] Add `yolov5s-ghost.yaml` (#4412) * Add yolov5s-ghost.yaml * Finish C3Ghost * Add C3Ghost to list * Add C3Ghost to number of repeats if statement * Fixes * Cleanup --- models/common.py | 36 +++++++++++++++++++++++++++ models/experimental.py | 28 --------------------- models/hub/yolov5s-ghost.yaml | 46 +++++++++++++++++++++++++++++++++++ models/yolo.py | 4 +-- 4 files changed, 84 insertions(+), 30 deletions(-) create mode 100644 models/hub/yolov5s-ghost.yaml diff --git a/models/common.py b/models/common.py index 2d24672a6b44..5ef3996007a2 100644 --- a/models/common.py +++ b/models/common.py @@ -149,6 +149,14 @@ def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5): self.m = SPP(c_, c_, k) +class C3Ghost(C3): + # C3 module with GhostBottleneck() + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) # hidden channels + self.m = nn.Sequential(*[GhostBottleneck(c_, c_) for _ in range(n)]) + + class SPP(nn.Module): # Spatial pyramid pooling layer used in YOLOv3-SPP def __init__(self, c1, c2, k=(5, 9, 13)): @@ -177,6 +185,34 @@ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) # return self.conv(self.contract(x)) +class GhostConv(nn.Module): + # Ghost Convolution https://github.com/huawei-noah/ghostnet + def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups + super().__init__() + c_ = c2 // 2 # hidden channels + self.cv1 = Conv(c1, c_, k, s, None, g, act) + self.cv2 = Conv(c_, c_, 5, 1, None, c_, act) + + def forward(self, x): + y = self.cv1(x) + return torch.cat([y, self.cv2(y)], 1) + + +class GhostBottleneck(nn.Module): + # Ghost Bottleneck https://github.com/huawei-noah/ghostnet + def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride + super().__init__() + c_ = c2 // 2 + self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw + DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw + GhostConv(c_, c2, 1, 1, act=False)) # pw-linear + self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), + Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity() + + def forward(self, x): + return self.conv(x) + self.shortcut(x) + + class Contract(nn.Module): # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40) def __init__(self, gain=2): diff --git a/models/experimental.py b/models/experimental.py index 581c7b14b61e..5c690cce3d99 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -43,34 +43,6 @@ def forward(self, x): return y -class GhostConv(nn.Module): - # Ghost Convolution https://github.com/huawei-noah/ghostnet - def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups - super().__init__() - c_ = c2 // 2 # hidden channels - self.cv1 = Conv(c1, c_, k, s, None, g, act) - self.cv2 = Conv(c_, c_, 5, 1, None, c_, act) - - def forward(self, x): - y = self.cv1(x) - return torch.cat([y, self.cv2(y)], 1) - - -class GhostBottleneck(nn.Module): - # Ghost Bottleneck https://github.com/huawei-noah/ghostnet - def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride - super().__init__() - c_ = c2 // 2 - self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw - DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw - GhostConv(c_, c2, 1, 1, act=False)) # pw-linear - self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), - Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity() - - def forward(self, x): - return self.conv(x) + self.shortcut(x) - - class MixConv2d(nn.Module): # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595 def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): diff --git a/models/hub/yolov5s-ghost.yaml b/models/hub/yolov5s-ghost.yaml new file mode 100644 index 000000000000..d99d56d26e85 --- /dev/null +++ b/models/hub/yolov5s-ghost.yaml @@ -0,0 +1,46 @@ +# Parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.50 # layer channel multiple +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Focus, [64, 3]], # 0-P1/2 + [-1, 1, GhostConv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3Ghost, [128]], + [-1, 1, GhostConv, [256, 3, 2]], # 3-P3/8 + [-1, 9, C3Ghost, [256]], + [-1, 1, GhostConv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3Ghost, [512]], + [-1, 1, GhostConv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 13]]], + [-1, 3, C3Ghost, [1024, False]], # 9 + ] + +# YOLOv5 head +head: + [[-1, 1, GhostConv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3Ghost, [512, False]], # 13 + + [-1, 1, GhostConv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3Ghost, [256, False]], # 17 (P3/8-small) + + [-1, 1, GhostConv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3Ghost, [512, False]], # 20 (P4/16-medium) + + [-1, 1, GhostConv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3Ghost, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/models/yolo.py b/models/yolo.py index 98e578d20384..88adb71f8fea 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -236,13 +236,13 @@ def parse_model(d, ch): # model_dict, input_channels(3) n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, - C3, C3TR, C3SPP]: + C3, C3TR, C3SPP, C3Ghost]: c1, c2 = ch[f], args[0] if c2 != no: # if not output c2 = make_divisible(c2 * gw, 8) args = [c1, c2, *args[1:]] - if m in [BottleneckCSP, C3, C3TR]: + if m in [BottleneckCSP, C3, C3TR, C3Ghost]: args.insert(2, n) # number of repeats n = 1 elif m is nn.BatchNorm2d: