From dc2e5878ddc51fb542a1dddc6efbe2851cfd457d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 30 Oct 2021 13:38:51 +0200 Subject: [PATCH] Fix MixConv2d() remove shortcut + apply depthwise (#5410) --- models/common.py | 2 +- models/experimental.py | 21 +++++++++++---------- utils/torch_utils.py | 2 +- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/models/common.py b/models/common.py index d0fb0e8596ed..8b70a6fea595 100644 --- a/models/common.py +++ b/models/common.py @@ -113,7 +113,7 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) self.cv4 = Conv(2 * c_, c2, 1, 1) self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) - self.act = nn.LeakyReLU(0.1, inplace=True) + self.act = nn.SiLU() self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) def forward(self, x): diff --git a/models/experimental.py b/models/experimental.py index adb86c81fc06..2e92ccb36faf 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -2,7 +2,7 @@ """ Experimental modules """ - +import math import numpy as np import torch import torch.nn as nn @@ -48,26 +48,27 @@ def forward(self, x): class MixConv2d(nn.Module): # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595 - def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): + def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kernel, stride, ch_strategy super().__init__() - groups = len(k) + n = len(k) # number of convolutions if equal_ch: # equal c_ per group - i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices - c_ = [(i == g).sum() for g in range(groups)] # intermediate channels + i = torch.linspace(0, n - 1E-6, c2).floor() # c2 indices + c_ = [(i == g).sum() for g in range(n)] # intermediate channels else: # equal weight.numel() per group - b = [c2] + [0] * groups - a = np.eye(groups + 1, groups, k=-1) + b = [c2] + [0] * n + a = np.eye(n + 1, n, k=-1) a -= np.roll(a, 1, axis=1) a *= np.array(k) ** 2 a[0] = 1 c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b - self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) + self.m = nn.ModuleList( + [nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)]) self.bn = nn.BatchNorm2d(c2) - self.act = nn.LeakyReLU(0.1, inplace=True) + self.act = nn.SiLU() def forward(self, x): - return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) + return self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) class Ensemble(nn.ModuleList): diff --git a/utils/torch_utils.py b/utils/torch_utils.py index e6d8ebd743bf..fc214147da72 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -166,7 +166,7 @@ def initialize_weights(model): elif t is nn.BatchNorm2d: m.eps = 1e-3 m.momentum = 0.03 - elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]: + elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: m.inplace = True