Skip to content

Commit

Permalink
ACON activation function (ultralytics#2893)
Browse files Browse the repository at this point in the history
* ACON Activation Function

## 🚀 Feature

There is a new activation function [ACON (CVPR 2021)](https://arxiv.org/pdf/2009.04759.pdf) that unifies ReLU and Swish. 
ACON is simple but very effective, code is here: https://github.com/nmaac/acon/blob/main/acon.py#L19

![image](https://user-images.githubusercontent.com/5032208/115676962-a38dfe80-a382-11eb-9883-61fa3216e3e6.png)

The improvements are very significant:
![image](https://user-images.githubusercontent.com/5032208/115680180-eac9be80-a385-11eb-9c7a-8643db552c69.png)

## Alternatives

It also has an enhanced version meta-ACON that uses a small network to learn beta explicitly, which may influence the speed a bit.

## Additional context

[Code](https://github.com/nmaac/acon) and [paper](https://arxiv.org/pdf/2009.04759.pdf).

* Update activations.py
  • Loading branch information
glenn-jocher committed Apr 22, 2021
1 parent 62618ee commit 1cdb58d
Showing 1 changed file with 41 additions and 17 deletions.
58 changes: 41 additions & 17 deletions utils/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,6 @@ def forward(x):
return x * F.hardtanh(x + 3, 0., 6.) / 6. # for torchscript, CoreML and ONNX


class MemoryEfficientSwish(nn.Module):
class F(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x * torch.sigmoid(x)

@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
sx = torch.sigmoid(x)
return grad_output * (sx * (1 + x * (1 - sx)))

def forward(self, x):
return self.F.apply(x)


# Mish https://github.com/digantamisra98/Mish --------------------------------------------------------------------------
class Mish(nn.Module):
@staticmethod
Expand Down Expand Up @@ -70,3 +53,44 @@ def __init__(self, c1, k=3): # ch_in, kernel

def forward(self, x):
return torch.max(x, self.bn(self.conv(x)))


# ACON https://arxiv.org/pdf/2009.04759.pdf ----------------------------------------------------------------------------
class AconC(nn.Module):
r""" ACON activation (activate or not).
AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter
according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
"""

def __init__(self, c1):
super().__init__()
self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
self.beta = nn.Parameter(torch.ones(1, c1, 1, 1))

def forward(self, x):
dpx = (self.p1 - self.p2) * x
return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x


class MetaAconC(nn.Module):
r""" ACON activation (activate or not).
MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network
according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
"""

def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r
super().__init__()
c2 = max(r, c1 // r)
self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
self.fc1 = nn.Conv2d(c1, c2, k, s, bias=False)
self.bn1 = nn.BatchNorm2d(c2)
self.fc2 = nn.Conv2d(c2, c1, k, s, bias=False)
self.bn2 = nn.BatchNorm2d(c1)

def forward(self, x):
y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)
beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))
dpx = (self.p1 - self.p2) * x
return dpx * torch.sigmoid(beta * dpx) + self.p2 * x

0 comments on commit 1cdb58d

Please sign in to comment.