Skip to content

Commit

Permalink
ACON Activation batch-size 1 bug path
Browse files Browse the repository at this point in the history
This is not a great solution to nmaac/acon#4 but it's all I could think of at the moment.

WARNING: YOLOv5 models with MetaAconC() activations are incapable of running inference at batch-size 1 properly due to a known bug in nmaac/acon#4 with no known solution.
  • Loading branch information
glenn-jocher committed Apr 25, 2021
1 parent c0d3f80 commit 804773c
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion utils/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r

def forward(self, x):
y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)
beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))
if x.shape[0] > 1: # batch-size 1 bug https://github.com/nmaac/acon/issues/4
beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))
else:
beta = torch.sigmoid(self.fc2(self.fc1(y)))
dpx = (self.p1 - self.p2) * x
return dpx * torch.sigmoid(beta * dpx) + self.p2 * x

0 comments on commit 804773c

Please sign in to comment.