From 9c7bb5a52cc716166c2145ce1a878a0ad2cf93be Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 25 Apr 2021 22:54:07 +0200 Subject: [PATCH] ACON Activation batch-size 1 bug patch (#2901) * ACON Activation batch-size 1 bug path This is not a great solution to https://github.com/nmaac/acon/issues/4 but it's all I could think of at the moment. WARNING: YOLOv5 models with MetaAconC() activations are incapable of running inference at batch-size 1 properly due to a known bug in https://github.com/nmaac/acon/issues/4 with no known solution. * Update activations.py * Update activations.py * Update activations.py * Update activations.py --- utils/activations.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/utils/activations.py b/utils/activations.py index 1d095c1cf0f1..92a3b5eaa54b 100644 --- a/utils/activations.py +++ b/utils/activations.py @@ -84,13 +84,15 @@ def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r c2 = max(r, c1 // r) self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1)) self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1)) - self.fc1 = nn.Conv2d(c1, c2, k, s, bias=False) - self.bn1 = nn.BatchNorm2d(c2) - self.fc2 = nn.Conv2d(c2, c1, k, s, bias=False) - self.bn2 = nn.BatchNorm2d(c1) + self.fc1 = nn.Conv2d(c1, c2, k, s, bias=True) + self.fc2 = nn.Conv2d(c2, c1, k, s, bias=True) + # self.bn1 = nn.BatchNorm2d(c2) + # self.bn2 = nn.BatchNorm2d(c1) def forward(self, x): y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True) - beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y))))) + # batch-size 1 bug/instabilities https://github.com/ultralytics/yolov5/issues/2891 + # beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y))))) # bug/unstable + beta = torch.sigmoid(self.fc2(self.fc1(y))) # bug patch BN layers removed dpx = (self.p1 - self.p2) * x return dpx * torch.sigmoid(beta * dpx) + self.p2 * x