diff --git a/models/common.py b/models/common.py index b13b9241f3da..1751ad09ff92 100644 --- a/models/common.py +++ b/models/common.py @@ -48,7 +48,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, k # self.act = AconC() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) # self.act = MetaAconC() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) # self.act = SiLU_beta() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) - self.act = SiLU_beta(c2) if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) + self.act = MetaAconC(c2) if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) def forward(self, x): return self.act(self.bn(self.conv(x))) diff --git a/utils/activations.py b/utils/activations.py index 4b233996bc91..5ef2bb37fd5d 100644 --- a/utils/activations.py +++ b/utils/activations.py @@ -94,13 +94,14 @@ def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1)) self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1)) self.fc1 = nn.Conv2d(c1, c2, k, s, bias=True) - self.bn1 = nn.BatchNorm2d(c2) + # self.bn1 = nn.BatchNorm2d(c2) self.fc2 = nn.Conv2d(c2, c1, k, s, bias=True) - self.bn2 = nn.BatchNorm2d(c1) + # self.bn2 = nn.BatchNorm2d(c1) # batch-size 1 bug https://github.com/nmaac/acon/issues/4 def forward(self, x): y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True) - beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y))))) + # beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y))))) + beta = torch.sigmoid(self.fc2(self.fc1(y))) dpx = (self.p1 - self.p2) * x return dpx * torch.sigmoid(beta * dpx) + self.p2 * x