diff --git a/efficientnet_pytorch/model.py b/efficientnet_pytorch/model.py
old mode 100644
new mode 100755
index 42f8bc3..706ebb9
--- a/efficientnet_pytorch/model.py
+++ b/efficientnet_pytorch/model.py
@@ -3,7 +3,6 @@
 from torch.nn import functional as F
 
 from .utils import (
-    relu_fn,
     round_filters,
     round_repeats,
     drop_connect,
@@ -11,6 +10,8 @@
     get_model_params,
     efficientnet_params,
     load_pretrained_weights,
+    Swish,
+    MemoryEfficientSwish,
 )
 
 class MBConvBlock(nn.Module):
@@ -61,6 +62,7 @@ def __init__(self, block_args, global_params):
         final_oup = self._block_args.output_filters
         self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
         self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
+        self._swish = MemoryEfficientSwish()
 
     def forward(self, inputs, drop_connect_rate=None):
         """
@@ -72,13 +74,13 @@
         # Expansion and Depthwise Convolution
         x = inputs
         if self._block_args.expand_ratio != 1:
-            x = relu_fn(self._bn0(self._expand_conv(inputs)))
-        x = relu_fn(self._bn1(self._depthwise_conv(x)))
+            x = self._swish(self._bn0(self._expand_conv(inputs)))
+        x = self._swish(self._bn1(self._depthwise_conv(x)))
 
         # Squeeze and Excitation
         if self.has_se:
             x_squeezed = F.adaptive_avg_pool2d(x, 1)
-            x_squeezed = self._se_expand(relu_fn(self._se_reduce(x_squeezed)))
+            x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
             x = torch.sigmoid(x_squeezed) * x
 
         x = self._bn2(self._project_conv(x))
@@ -91,6 +93,12 @@
             x = x + inputs  # skip connection
         return x
 
+    def set_swish(self, memory_efficient=True):
+        if memory_efficient:
+            self._swish = MemoryEfficientSwish()
+        else:
+            self._swish = Swish()
+
 
 class EfficientNet(nn.Module):
     """
@@ -153,12 +161,23 @@ def __init__(self, blocks_args=None, global_params=None):
         self._avg_pooling = nn.AdaptiveAvgPool2d(1)
         self._dropout = nn.Dropout(self._global_params.dropout_rate)
         self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+        self._swish = MemoryEfficientSwish()
+
+    def set_swish(self, memory_efficient=True):
+        if memory_efficient:
+            self._swish = MemoryEfficientSwish()
+        else:
+            self._swish = Swish()
+
+        for block in self._blocks:
+            block.set_swish(memory_efficient)
+
     def extract_features(self, inputs):
         """ Returns output of the final convolution layer """
 
         # Stem
-        x = relu_fn(self._bn0(self._conv_stem(inputs)))
+        x = self._swish(self._bn0(self._conv_stem(inputs)))
 
         # Blocks
         for idx, block in enumerate(self._blocks):
@@ -168,7 +187,7 @@
             x = block(x, drop_connect_rate=drop_connect_rate)
 
         # Head
-        x = relu_fn(self._bn1(self._conv_head(x)))
+        x = self._swish(self._bn1(self._conv_head(x)))
 
         return x
 
diff --git a/efficientnet_pytorch/utils.py b/efficientnet_pytorch/utils.py
old mode 100644
new mode 100755
index 02abc7f..198b3b4
--- a/efficientnet_pytorch/utils.py
+++ b/efficientnet_pytorch/utils.py
@@ -47,13 +47,13 @@ def backward(ctx, grad_output):
         return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
 
 
-class Swish(nn.Module):
-    @staticmethod
-    def forward(x):
+class MemoryEfficientSwish(nn.Module):
+    def forward(self, x):
         return SwishImplementation.apply(x)
 
-
-relu_fn = Swish()
+class Swish(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
 
 
 def round_filters(filters, global_params):
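Not part of the patch: a minimal usage sketch of the new `set_swish` hook, assuming the package's standard `EfficientNet.from_pretrained` entry point. `MemoryEfficientSwish` routes through a custom `torch.autograd.Function` that saves only the input and recomputes the sigmoid in `backward()`, lowering activation memory during training, but custom autograd Functions cannot be traced; `set_swish(memory_efficient=False)` swaps in the plain `x * torch.sigmoid(x)` module before export.

```python
# Usage sketch (not part of the patch; assumes the package's standard
# EfficientNet.from_pretrained entry point).
import torch
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b0')

# Training keeps the default MemoryEfficientSwish: its custom autograd
# Function saves only the input tensor and recomputes the sigmoid in
# backward(), trading a little compute for lower activation memory.
model.train()

# Custom autograd Functions cannot be traced, so swap in the plain
# Swish (x * torch.sigmoid(x)) before JIT/ONNX export.
model.set_swish(memory_efficient=False)
model.eval()

example = torch.randn(1, 3, 224, 224)
traced = torch.jit.trace(model, example)
```

Note that `EfficientNet.set_swish` also walks `self._blocks`, so one call switches the stem, head, and every `MBConvBlock` together.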