Skip to content

Commit

Permalink
MetaAconC(c2)_bias
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Apr 23, 2021
1 parent 69cb876 commit 766812c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
1 change: 1 addition & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, k
# self.act = Mish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = AconC() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = MetaAconC() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = SiLU_beta() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
self.act = MetaAconC(c2) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

def forward(self, x):
Expand Down
4 changes: 2 additions & 2 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, i
m = self.model[-1] # Detect()
if isinstance(m, Detect):
s = 256 # 2x min stride
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(2, ch, s, s))]) # forward
m.anchors /= m.stride.view(-1, 1, 1)
check_anchor_order(m)
self.stride = m.stride
Expand Down Expand Up @@ -264,7 +264,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
# Create model
model = Model(opt.cfg).to(device)
model.train()

# Profile
# img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 320, 320).to(device)
# y = model(img, profile=True)
Expand Down
19 changes: 13 additions & 6 deletions utils/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ def forward(x):
return x * torch.sigmoid(x)


class SiLU_beta(nn.Module): # SiLU() with B term: y = x * sigmoid(B * x)
def __init__(self, c1):
super().__init__()
self.beta = nn.Parameter(torch.ones(1, c1, 1, 1))

def forward(self, x):
return x * torch.sigmoid(self.beta * x)


class Hardswish(nn.Module): # export-friendly version of nn.Hardswish()
@staticmethod
def forward(x):
Expand Down Expand Up @@ -84,16 +93,14 @@ def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r
c2 = max(r, c1 // r)
self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
self.fc1 = nn.Conv2d(c1, c2, k, s, bias=False)
self.fc1 = nn.Conv2d(c1, c2, k, s, bias=True)
self.bn1 = nn.BatchNorm2d(c2)
self.fc2 = nn.Conv2d(c2, c1, k, s, bias=False)
self.fc2 = nn.Conv2d(c2, c1, k, s, bias=True)
self.bn2 = nn.BatchNorm2d(c1)

# batch-size 1 bug https://github.com/nmaac/acon/issues/4
def forward(self, x):
y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)
if x.shape[0] > 1: # batch-size 1 bug https://github.com/nmaac/acon/issues/4
beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))
else:
beta = torch.sigmoid(self.fc2(self.fc1(y)))
beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))
dpx = (self.p1 - self.p2) * x
return dpx * torch.sigmoid(beta * dpx) + self.p2 * x

0 comments on commit 766812c

Please sign in to comment.