From 4353c8c5269bf32ae0fb107c1ddda4fac536bc97 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 22 Apr 2021 12:58:25 +0200 Subject: [PATCH 1/2] ACON Activation Function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## 🚀 Feature There is a new activation function [ACON (CVPR 2021)](https://arxiv.org/pdf/2009.04759.pdf) that unifies ReLU and Swish. ACON is simple but very effective, code is here: https://github.com/nmaac/acon/blob/main/acon.py#L19 ![image](https://user-images.githubusercontent.com/5032208/115676962-a38dfe80-a382-11eb-9883-61fa3216e3e6.png) The improvements are very significant: ![image](https://user-images.githubusercontent.com/5032208/115680180-eac9be80-a385-11eb-9c7a-8643db552c69.png) ## Alternatives It also has an enhanced version meta-ACON that uses a small network to learn beta explicitly, which may influence the speed a bit. ## Additional context [Code](https://github.com/nmaac/acon) and [paper](https://arxiv.org/pdf/2009.04759.pdf). --- utils/activations.py | 55 ++++++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/utils/activations.py b/utils/activations.py index aa3ddf071d28..05f69945996b 100644 --- a/utils/activations.py +++ b/utils/activations.py @@ -19,23 +19,6 @@ def forward(x): return x * F.hardtanh(x + 3, 0., 6.) / 6. # for torchscript, CoreML and ONNX -class MemoryEfficientSwish(nn.Module): - class F(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - ctx.save_for_backward(x) - return x * torch.sigmoid(x) - - @staticmethod - def backward(ctx, grad_output): - x = ctx.saved_tensors[0] - sx = torch.sigmoid(x) - return grad_output * (sx * (1 + x * (1 - sx))) - - def forward(self, x): - return self.F.apply(x) - - # Mish https://github.com/digantamisra98/Mish -------------------------------------------------------------------------- class Mish(nn.Module): @staticmethod @@ -70,3 +53,41 @@ def __init__(self, c1, k=3): # ch_in, kernel def forward(self, x): return torch.max(x, self.bn(self.conv(x))) + + +# ACON https://arxiv.org/pdf/2009.04759.pdf ---------------------------------------------------------------------------- +class AconC(nn.Module): + r""" ACON activation (activate or not). + # AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter + # according to "Activate or Not: Learning Customized Activation" . + """ + + def __init__(self, width): + super().__init__() + self.p1 = nn.Parameter(torch.randn(1, width, 1, 1)) + self.p2 = nn.Parameter(torch.randn(1, width, 1, 1)) + self.beta = nn.Parameter(torch.ones(1, width, 1, 1)) + + def forward(self, x): + return (self.p1 * x - self.p2 * x) * torch.sigmoid(self.beta * (self.p1 * x - self.p2 * x)) + self.p2 * x + + +class MetaAconC(nn.Module): + r""" ACON activation (activate or not). + # MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network + # according to "Activate or Not: Learning Customized Activation" . + """ + + def __init__(self, width, r=16): + super().__init__() + self.p1 = nn.Parameter(torch.randn(1, width, 1, 1)) + self.p2 = nn.Parameter(torch.randn(1, width, 1, 1)) + self.fc1 = nn.Conv2d(width, max(r, width // r), kernel_size=1, stride=1, bias=True) + self.bn1 = nn.BatchNorm2d(max(r, width // r)) + self.fc2 = nn.Conv2d(max(r, width // r), width, kernel_size=1, stride=1, bias=True) + self.bn2 = nn.BatchNorm2d(width) + + def forward(self, x): + beta = torch.sigmoid( + self.bn2(self.fc2(self.bn1(self.fc1(x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)))))) + return (self.p1 * x - self.p2 * x) * torch.sigmoid(beta * (self.p1 * x - self.p2 * x)) + self.p2 * x From a9ea24898eb46f36dd170b7720f09607e5eff3c0 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 22 Apr 2021 17:08:19 +0200 Subject: [PATCH 2/2] Update activations.py --- utils/activations.py | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/utils/activations.py b/utils/activations.py index 05f69945996b..1d095c1cf0f1 100644 --- a/utils/activations.py +++ b/utils/activations.py @@ -58,36 +58,39 @@ def forward(self, x): # ACON https://arxiv.org/pdf/2009.04759.pdf ---------------------------------------------------------------------------- class AconC(nn.Module): r""" ACON activation (activate or not). - # AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter - # according to "Activate or Not: Learning Customized Activation" . + AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter + according to "Activate or Not: Learning Customized Activation" . """ - def __init__(self, width): + def __init__(self, c1): super().__init__() - self.p1 = nn.Parameter(torch.randn(1, width, 1, 1)) - self.p2 = nn.Parameter(torch.randn(1, width, 1, 1)) - self.beta = nn.Parameter(torch.ones(1, width, 1, 1)) + self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1)) + self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1)) + self.beta = nn.Parameter(torch.ones(1, c1, 1, 1)) def forward(self, x): - return (self.p1 * x - self.p2 * x) * torch.sigmoid(self.beta * (self.p1 * x - self.p2 * x)) + self.p2 * x + dpx = (self.p1 - self.p2) * x + return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x class MetaAconC(nn.Module): r""" ACON activation (activate or not). - # MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network - # according to "Activate or Not: Learning Customized Activation" . + MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network + according to "Activate or Not: Learning Customized Activation" . """ - def __init__(self, width, r=16): + def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r super().__init__() - self.p1 = nn.Parameter(torch.randn(1, width, 1, 1)) - self.p2 = nn.Parameter(torch.randn(1, width, 1, 1)) - self.fc1 = nn.Conv2d(width, max(r, width // r), kernel_size=1, stride=1, bias=True) - self.bn1 = nn.BatchNorm2d(max(r, width // r)) - self.fc2 = nn.Conv2d(max(r, width // r), width, kernel_size=1, stride=1, bias=True) - self.bn2 = nn.BatchNorm2d(width) + c2 = max(r, c1 // r) + self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1)) + self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1)) + self.fc1 = nn.Conv2d(c1, c2, k, s, bias=False) + self.bn1 = nn.BatchNorm2d(c2) + self.fc2 = nn.Conv2d(c2, c1, k, s, bias=False) + self.bn2 = nn.BatchNorm2d(c1) def forward(self, x): - beta = torch.sigmoid( - self.bn2(self.fc2(self.bn1(self.fc1(x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)))))) - return (self.p1 * x - self.p2 * x) * torch.sigmoid(beta * (self.p1 * x - self.p2 * x)) + self.p2 * x + y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True) + beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y))))) + dpx = (self.p1 - self.p2) * x + return dpx * torch.sigmoid(beta * dpx) + self.p2 * x