From b1be6850050959a09c3e26813646c52b2a73b1a0 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 19 Jul 2021 12:41:15 +0200 Subject: [PATCH] Module `super().__init__()` (#4065) * Module `super().__init__()` * remove NMS --- models/common.py | 42 ++++++++++++++++++------------------------ models/experimental.py | 12 ++++++------ models/yolo.py | 20 +++----------------- 3 files changed, 27 insertions(+), 47 deletions(-) diff --git a/models/common.py b/models/common.py index 4db90b54663e..901648b693a3 100644 --- a/models/common.py +++ b/models/common.py @@ -36,7 +36,7 @@ def DWConv(c1, c2, k=1, s=1, act=True): class Conv(nn.Module): # Standard convolution def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups - super(Conv, self).__init__() + super().__init__() self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) self.bn = nn.BatchNorm2d(c2) self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) @@ -87,7 +87,7 @@ def forward(self, x): class Bottleneck(nn.Module): # Standard bottleneck def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion - super(Bottleneck, self).__init__() + super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c_, c2, 3, 1, g=g) @@ -100,7 +100,7 @@ def forward(self, x): class BottleneckCSP(nn.Module): # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion - super(BottleneckCSP, self).__init__() + super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) @@ -119,7 +119,7 @@ def forward(self, x): class C3(nn.Module): # CSP Bottleneck with 3 convolutions def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion - super(C3, self).__init__() + super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c1, c_, 1, 1) @@ -139,10 +139,18 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): self.m = TransformerBlock(c_, c_, 4, n) +class C3SPP(C3): + # C3 module with SPP() + def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5): + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) + self.m = SPP(c_, c_, k) + + class SPP(nn.Module): # Spatial pyramid pooling layer used in YOLOv3-SPP def __init__(self, c1, c2, k=(5, 9, 13)): - super(SPP, self).__init__() + super().__init__() c_ = c1 // 2 # hidden channels self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) @@ -156,7 +164,7 @@ def forward(self, x): class Focus(nn.Module): # Focus wh information into c-space def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups - super(Focus, self).__init__() + super().__init__() self.conv = Conv(c1 * 4, c2, k, s, p, g, act) # self.contract = Contract(gain=2) @@ -196,27 +204,13 @@ def forward(self, x): class Concat(nn.Module): # Concatenate a list of tensors along dimension def __init__(self, dimension=1): - super(Concat, self).__init__() + super().__init__() self.d = dimension def forward(self, x): return torch.cat(x, self.d) -class NMS(nn.Module): - # Non-Maximum Suppression (NMS) module - conf = 0.25 # confidence threshold - iou = 0.45 # IoU threshold - classes = None # (optional list) filter by class - max_det = 1000 # maximum number of detections per image - - def __init__(self): - super(NMS, self).__init__() - - def forward(self, x): - return non_max_suppression(x[0], self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) - - class AutoShape(nn.Module): # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS conf = 0.25 # NMS confidence threshold @@ -225,7 +219,7 @@ class AutoShape(nn.Module): max_det = 1000 # maximum number of detections per image def __init__(self, model): - super(AutoShape, self).__init__() + super().__init__() self.model = model.eval() def autoshape(self): @@ -292,7 +286,7 @@ def forward(self, imgs, size=640, augment=False, profile=False): class Detections: # YOLOv5 detections class for inference results def __init__(self, imgs, pred, files, times=None, names=None, shape=None): - super(Detections, self).__init__() + super().__init__() d = pred[0].device # device gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations self.imgs = imgs # list of images as numpy arrays @@ -383,7 +377,7 @@ def __len__(self): class Classify(nn.Module): # Classification head, i.e. x(b,c1,20,20) to x(b,c2) def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups - super(Classify, self).__init__() + super().__init__() self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1) self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1) self.flat = nn.Flatten() diff --git a/models/experimental.py b/models/experimental.py index 30dc36192bc0..0d996d913b0c 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -12,7 +12,7 @@ class CrossConv(nn.Module): # Cross Convolution Downsample def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): # ch_in, ch_out, kernel, stride, groups, expansion, shortcut - super(CrossConv, self).__init__() + super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, (1, k), (1, s)) self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) @@ -25,7 +25,7 @@ def forward(self, x): class Sum(nn.Module): # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 def __init__(self, n, weight=False): # n: number of inputs - super(Sum, self).__init__() + super().__init__() self.weight = weight # apply weights boolean self.iter = range(n - 1) # iter object if weight: @@ -46,7 +46,7 @@ def forward(self, x): class GhostConv(nn.Module): # Ghost Convolution https://github.com/huawei-noah/ghostnet def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups - super(GhostConv, self).__init__() + super().__init__() c_ = c2 // 2 # hidden channels self.cv1 = Conv(c1, c_, k, s, None, g, act) self.cv2 = Conv(c_, c_, 5, 1, None, c_, act) @@ -59,7 +59,7 @@ def forward(self, x): class GhostBottleneck(nn.Module): # Ghost Bottleneck https://github.com/huawei-noah/ghostnet def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride - super(GhostBottleneck, self).__init__() + super().__init__() c_ = c2 // 2 self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw @@ -74,7 +74,7 @@ def forward(self, x): class MixConv2d(nn.Module): # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): - super(MixConv2d, self).__init__() + super().__init__() groups = len(k) if equal_ch: # equal c_ per group i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices @@ -98,7 +98,7 @@ def forward(self, x): class Ensemble(nn.ModuleList): # Ensemble of models def __init__(self): - super(Ensemble, self).__init__() + super().__init__() def forward(self, x, augment=False, profile=False, visualize=False): y = [] diff --git a/models/yolo.py b/models/yolo.py index 3a3af9b5fbde..2e7a20f813e2 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -33,7 +33,7 @@ class Detect(nn.Module): onnx_dynamic = False # ONNX export parameter def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer - super(Detect, self).__init__() + super().__init__() self.nc = nc # number of classes self.no = nc + 5 # number of outputs per anchor self.nl = len(anchors) # number of detection layers @@ -77,7 +77,7 @@ def _make_grid(nx=20, ny=20): class Model(nn.Module): def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes - super(Model, self).__init__() + super().__init__() if isinstance(cfg, dict): self.yaml = cfg # model dict else: # is *.yaml @@ -209,20 +209,6 @@ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers self.info() return self - def nms(self, mode=True): # add or remove NMS module - present = type(self.model[-1]) is NMS # last layer is NMS - if mode and not present: - LOGGER.info('Adding NMS... ') - m = NMS() # module - m.f = -1 # from - m.i = self.model[-1].i + 1 # index - self.model.add_module(name='%s' % m.i, module=m) # add - self.eval() - elif not mode and present: - LOGGER.info('Removing NMS... ') - self.model = self.model[:-1] # remove - return self - def autoshape(self): # add AutoShape module LOGGER.info('Adding AutoShape... ') m = AutoShape(self) # wrap model @@ -250,7 +236,7 @@ def parse_model(d, ch): # model_dict, input_channels(3) n = max(round(n * gd), 1) if n > 1 else n # depth gain if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, - C3, C3TR]: + C3, C3TR, C3SPP]: c1, c2 = ch[f], args[0] if c2 != no: # if not output c2 = make_divisible(c2 * gw, 8)