Skip to content

Commit

Permalink
Add yolov5s-ghost.yaml (ultralytics#4412)
Browse files Browse the repository at this point in the history
* Add yolov5s-ghost.yaml

* Finish C3Ghost

* Add C3Ghost to list

* Add C3Ghost to number of repeats if statement

* Fixes

* Cleanup
  • Loading branch information
glenn-jocher authored and CesarBazanAV committed Sep 29, 2021
1 parent 9aaf7db commit 9bd5b60
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 30 deletions.
36 changes: 36 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
self.m = SPP(c_, c_, k)


class C3Ghost(C3):
# C3 module with GhostBottleneck()
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
super().__init__(c1, c2, n, shortcut, g, e)
c_ = int(c2 * e) # hidden channels
self.m = nn.Sequential(*[GhostBottleneck(c_, c_) for _ in range(n)])


class SPP(nn.Module):
# Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, c1, c2, k=(5, 9, 13)):
Expand Down Expand Up @@ -177,6 +185,34 @@ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
# return self.conv(self.contract(x))


class GhostConv(nn.Module):
# Ghost Convolution https://github.com/huawei-noah/ghostnet
def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
super().__init__()
c_ = c2 // 2 # hidden channels
self.cv1 = Conv(c1, c_, k, s, None, g, act)
self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)

def forward(self, x):
y = self.cv1(x)
return torch.cat([y, self.cv2(y)], 1)


class GhostBottleneck(nn.Module):
# Ghost Bottleneck https://github.com/huawei-noah/ghostnet
def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
super().__init__()
c_ = c2 // 2
self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()

def forward(self, x):
return self.conv(x) + self.shortcut(x)


class Contract(nn.Module):
# Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
def __init__(self, gain=2):
Expand Down
28 changes: 0 additions & 28 deletions models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,34 +43,6 @@ def forward(self, x):
return y


class GhostConv(nn.Module):
# Ghost Convolution https://github.com/huawei-noah/ghostnet
def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
super().__init__()
c_ = c2 // 2 # hidden channels
self.cv1 = Conv(c1, c_, k, s, None, g, act)
self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)

def forward(self, x):
y = self.cv1(x)
return torch.cat([y, self.cv2(y)], 1)


class GhostBottleneck(nn.Module):
# Ghost Bottleneck https://github.com/huawei-noah/ghostnet
def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
super().__init__()
c_ = c2 // 2
self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()

def forward(self, x):
return self.conv(x) + self.shortcut(x)


class MixConv2d(nn.Module):
# Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
Expand Down
46 changes: 46 additions & 0 deletions models/hub/yolov5s-ghost.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32

# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, GhostConv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3Ghost, [128]],
[-1, 1, GhostConv, [256, 3, 2]], # 3-P3/8
[-1, 9, C3Ghost, [256]],
[-1, 1, GhostConv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3Ghost, [512]],
[-1, 1, GhostConv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 3, C3Ghost, [1024, False]], # 9
]

# YOLOv5 head
head:
[[-1, 1, GhostConv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3Ghost, [512, False]], # 13

[-1, 1, GhostConv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3Ghost, [256, False]], # 17 (P3/8-small)

[-1, 1, GhostConv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3Ghost, [512, False]], # 20 (P4/16-medium)

[-1, 1, GhostConv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3Ghost, [1024, False]], # 23 (P5/32-large)

[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
4 changes: 2 additions & 2 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,13 @@ def parse_model(d, ch): # model_dict, input_channels(3)

n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
C3, C3TR, C3SPP]:
C3, C3TR, C3SPP, C3Ghost]:
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8)

args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3, C3TR]:
if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
args.insert(2, n) # number of repeats
n = 1
elif m is nn.BatchNorm2d:
Expand Down

0 comments on commit 9bd5b60

Please sign in to comment.