From 91580c2a03d70494bb9a0f23cdad8aa8175b52a4 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 15 Aug 2021 14:12:08 +0200 Subject: [PATCH 1/3] Add `SPPF()` layer --- models/common.py | 20 +++++++++++++++++++- models/yolo.py | 10 ++++++---- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/models/common.py b/models/common.py index fe4319b0f370..99a966dc303a 100644 --- a/models/common.py +++ b/models/common.py @@ -161,7 +161,7 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): class SPP(nn.Module): - # Spatial pyramid pooling layer used in YOLOv3-SPP + # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729 def __init__(self, c1, c2, k=(5, 9, 13)): super().__init__() c_ = c1 // 2 # hidden channels @@ -176,6 +176,24 @@ def forward(self, x): return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) +class SPPF(nn.Module): + # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 + def __init__(self, c1, c2, k=5, n=3): # equivalent to SPP(k=(5, 9, 13)) + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * (n + 1), c2, 1, 1) + self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + + def forward(self, x): + x = self.cv1(x) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning + y1 = self.m(x) + y2 = self.m(y1) + return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1)) + + class Focus(nn.Module): # Focus wh information into c-space def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups diff --git a/models/yolo.py b/models/yolo.py index f3c1516f49f7..dee6032d069d 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -237,8 +237,8 @@ def parse_model(d, ch): # model_dict, input_channels(3) pass n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain - if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, - C3, C3TR, C3SPP, C3Ghost]: + if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, + BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]: c1, c2 = ch[f], args[0] if c2 != no: # if not output c2 = make_divisible(c2 * gw, 8) @@ -279,6 +279,7 @@ def parse_model(d, ch): # model_dict, input_channels(3) parser = argparse.ArgumentParser() parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--profile', action='store_true', help='profile model speed') opt = parser.parse_args() opt.cfg = check_file(opt.cfg) # check file set_logging() @@ -289,8 +290,9 @@ def parse_model(d, ch): # model_dict, input_channels(3) model.train() # Profile - # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 320, 320).to(device) - # y = model(img, profile=True) + if opt.profile: + img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device) + y = model(img, profile=True) # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898) # from torch.utils.tensorboard import SummaryWriter From 6494bb628085e9e61f2c8f5bfb0a74a553b9c4da Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 15 Aug 2021 14:15:41 +0200 Subject: [PATCH 2/3] Cleanup --- models/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/common.py b/models/common.py index 99a966dc303a..ea88390bbc6a 100644 --- a/models/common.py +++ b/models/common.py @@ -178,11 +178,11 @@ def forward(self, x): class SPPF(nn.Module): # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 - def __init__(self, c1, c2, k=5, n=3): # equivalent to SPP(k=(5, 9, 13)) + def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13)) super().__init__() c_ = c1 // 2 # hidden channels self.cv1 = Conv(c1, c_, 1, 1) - self.cv2 = Conv(c_ * (n + 1), c2, 1, 1) + self.cv2 = Conv(c_ * 4, c2, 1, 1) self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) def forward(self, x): From 8c89f4f332fde08eb72eff282b997c3df11af349 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 15 Aug 2021 14:23:41 +0200 Subject: [PATCH 3/3] Add credit --- models/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/common.py b/models/common.py index ea88390bbc6a..e1f5aea3abed 100644 --- a/models/common.py +++ b/models/common.py @@ -177,7 +177,7 @@ def forward(self, x): class SPPF(nn.Module): - # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 + # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13)) super().__init__() c_ = c1 // 2 # hidden channels