Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MixConv2d() remove shortcut + apply depthwise #5410

Merged
merged 1 commit into from
Oct 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
self.cv4 = Conv(2 * c_, c2, 1, 1)
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
self.act = nn.LeakyReLU(0.1, inplace=True)
self.act = nn.SiLU()
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

def forward(self, x):
Expand Down
21 changes: 11 additions & 10 deletions models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
Experimental modules
"""

import math
import numpy as np
import torch
import torch.nn as nn
Expand Down Expand Up @@ -48,26 +48,27 @@ def forward(self, x):

class MixConv2d(nn.Module):
# Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kernel, stride, ch_strategy
super().__init__()
groups = len(k)
n = len(k) # number of convolutions
if equal_ch: # equal c_ per group
i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices
c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
i = torch.linspace(0, n - 1E-6, c2).floor() # c2 indices
c_ = [(i == g).sum() for g in range(n)] # intermediate channels
else: # equal weight.numel() per group
b = [c2] + [0] * groups
a = np.eye(groups + 1, groups, k=-1)
b = [c2] + [0] * n
a = np.eye(n + 1, n, k=-1)
a -= np.roll(a, 1, axis=1)
a *= np.array(k) ** 2
a[0] = 1
c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b

self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
self.m = nn.ModuleList(
[nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
self.bn = nn.BatchNorm2d(c2)
self.act = nn.LeakyReLU(0.1, inplace=True)
self.act = nn.SiLU()

def forward(self, x):
return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))


class Ensemble(nn.ModuleList):
Expand Down
2 changes: 1 addition & 1 deletion utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def initialize_weights(model):
elif t is nn.BatchNorm2d:
m.eps = 1e-3
m.momentum = 0.03
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
m.inplace = True


Expand Down