Skip to content

Commit

Permalink
Update activations.py
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Apr 22, 2021
1 parent 4353c8c commit a9ea248
Showing 1 changed file with 22 additions and 19 deletions.
41 changes: 22 additions & 19 deletions utils/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,36 +58,39 @@ def forward(self, x):
# ACON https://arxiv.org/pdf/2009.04759.pdf ----------------------------------------------------------------------------
class AconC(nn.Module):
r""" ACON activation (activate or not).
# AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter
# according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter
according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
"""

def __init__(self, width):
def __init__(self, c1):
super().__init__()
self.p1 = nn.Parameter(torch.randn(1, width, 1, 1))
self.p2 = nn.Parameter(torch.randn(1, width, 1, 1))
self.beta = nn.Parameter(torch.ones(1, width, 1, 1))
self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
self.beta = nn.Parameter(torch.ones(1, c1, 1, 1))

def forward(self, x):
return (self.p1 * x - self.p2 * x) * torch.sigmoid(self.beta * (self.p1 * x - self.p2 * x)) + self.p2 * x
dpx = (self.p1 - self.p2) * x
return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x


class MetaAconC(nn.Module):
r""" ACON activation (activate or not).
# MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network
# according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network
according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
"""

def __init__(self, width, r=16):
def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r
super().__init__()
self.p1 = nn.Parameter(torch.randn(1, width, 1, 1))
self.p2 = nn.Parameter(torch.randn(1, width, 1, 1))
self.fc1 = nn.Conv2d(width, max(r, width // r), kernel_size=1, stride=1, bias=True)
self.bn1 = nn.BatchNorm2d(max(r, width // r))
self.fc2 = nn.Conv2d(max(r, width // r), width, kernel_size=1, stride=1, bias=True)
self.bn2 = nn.BatchNorm2d(width)
c2 = max(r, c1 // r)
self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
self.fc1 = nn.Conv2d(c1, c2, k, s, bias=False)
self.bn1 = nn.BatchNorm2d(c2)
self.fc2 = nn.Conv2d(c2, c1, k, s, bias=False)
self.bn2 = nn.BatchNorm2d(c1)

def forward(self, x):
beta = torch.sigmoid(
self.bn2(self.fc2(self.bn1(self.fc1(x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True))))))
return (self.p1 * x - self.p2 * x) * torch.sigmoid(beta * (self.p1 * x - self.p2 * x)) + self.p2 * x
y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)
beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))
dpx = (self.p1 - self.p2) * x
return dpx * torch.sigmoid(beta * dpx) + self.p2 * x

0 comments on commit a9ea248

Please sign in to comment.