Add different swish implementations
qubvel committed Oct 14, 2019
1 parent 4268864 commit 8a5da1d
Showing 2 changed files with 30 additions and 11 deletions.
31 changes: 25 additions & 6 deletions efficientnet_pytorch/model.py
100644 → 100755
@@ -3,14 +3,15 @@
 from torch.nn import functional as F
 
 from .utils import (
-    relu_fn,
     round_filters,
     round_repeats,
     drop_connect,
     get_same_padding_conv2d,
     get_model_params,
     efficientnet_params,
     load_pretrained_weights,
+    Swish,
+    MemoryEfficientSwish,
 )
 
 class MBConvBlock(nn.Module):
@@ -61,6 +62,7 @@ def __init__(self, block_args, global_params):
         final_oup = self._block_args.output_filters
         self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
         self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
+        self._swish = MemoryEfficientSwish()
 
     def forward(self, inputs, drop_connect_rate=None):
         """
@@ -72,13 +74,13 @@ def forward(self, inputs, drop_connect_rate=None):
         # Expansion and Depthwise Convolution
         x = inputs
         if self._block_args.expand_ratio != 1:
-            x = relu_fn(self._bn0(self._expand_conv(inputs)))
-            x = relu_fn(self._bn1(self._depthwise_conv(x)))
+            x = self._swish(self._bn0(self._expand_conv(inputs)))
+            x = self._swish(self._bn1(self._depthwise_conv(x)))
 
         # Squeeze and Excitation
         if self.has_se:
             x_squeezed = F.adaptive_avg_pool2d(x, 1)
-            x_squeezed = self._se_expand(relu_fn(self._se_reduce(x_squeezed)))
+            x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
             x = torch.sigmoid(x_squeezed) * x
 
         x = self._bn2(self._project_conv(x))
@@ -91,6 +93,12 @@ def forward(self, inputs, drop_connect_rate=None):
             x = x + inputs  # skip connection
         return x
 
+    def set_swish(self, memory_efficient=True):
+        if memory_efficient:
+            self._swish = MemoryEfficientSwish()
+        else:
+            self._swish = Swish()
+
 
 class EfficientNet(nn.Module):
     """
@@ -153,12 +161,23 @@ def __init__(self, blocks_args=None, global_params=None):
         self._avg_pooling = nn.AdaptiveAvgPool2d(1)
         self._dropout = nn.Dropout(self._global_params.dropout_rate)
         self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+        self._swish = MemoryEfficientSwish()
+
+    def set_swish(self, memory_efficient=True):
+        if memory_efficient:
+            self._swish = MemoryEfficientSwish()
+        else:
+            self._swish = Swish()
+
+        for block in self._blocks:
+            block.set_swish(memory_efficient)
+
 
     def extract_features(self, inputs):
         """ Returns output of the final convolution layer """
 
         # Stem
-        x = relu_fn(self._bn0(self._conv_stem(inputs)))
+        x = self._swish(self._bn0(self._conv_stem(inputs)))
 
         # Blocks
         for idx, block in enumerate(self._blocks):
@@ -168,7 +187,7 @@ def extract_features(self, inputs):
             x = block(x, drop_connect_rate=drop_connect_rate)
 
         # Head
-        x = relu_fn(self._bn1(self._conv_head(x)))
+        x = self._swish(self._bn1(self._conv_head(x)))
 
         return x
 
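A likely use for the new set_swish toggle: MemoryEfficientSwish routes through a custom torch.autograd.Function, which tracing-based exporters generally cannot serialize, so swapping in the plain Swish first is the natural workflow. A minimal sketch, assuming the package's EfficientNet.from_pretrained entry point; the ONNX export call is illustrative only and not part of this commit:

    import torch
    from efficientnet_pytorch import EfficientNet

    model = EfficientNet.from_pretrained('efficientnet-b0')

    # Training: keep the default MemoryEfficientSwish (less activation memory).
    model.train()

    # Before tracing or export: replace every swish, including those inside
    # the MBConv blocks, with the plain module that tracers understand.
    model.set_swish(memory_efficient=False)
    model.eval()

    dummy = torch.randn(1, 3, 224, 224)
    torch.onnx.export(model, dummy, 'efficientnet-b0.onnx')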
10 changes: 5 additions & 5 deletions efficientnet_pytorch/utils.py
100644 → 100755
@@ -47,13 +47,13 @@ def backward(ctx, grad_output):
         return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
 
 
-class Swish(nn.Module):
-    @staticmethod
-    def forward(x):
+class MemoryEfficientSwish(nn.Module):
+    def forward(self, x):
         return SwishImplementation.apply(x)
 
 
-relu_fn = Swish()
+class Swish(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
 
 
 def round_filters(filters, global_params):
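Only the tail of the custom autograd function is visible in this hunk. For context, a sketch of what the full SwishImplementation plausibly looks like, reconstructed from the backward shown above; the forward body and the save_for_backward call are assumptions consistent with those visible lines:

    import torch

    class SwishImplementation(torch.autograd.Function):
        @staticmethod
        def forward(ctx, i):
            # Save only the input tensor; sigmoid is recomputed in backward.
            # Avoiding saved intermediates is where the memory saving comes from.
            ctx.save_for_backward(i)
            return i * torch.sigmoid(i)

        @staticmethod
        def backward(ctx, grad_output):
            i, = ctx.saved_tensors
            sigmoid_i = torch.sigmoid(i)
            # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
            return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))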
Expand Down

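The two modules should agree in both value and gradient, since sigmoid(x) * (1 + x * (1 - sigmoid(x))) is exactly the derivative of x * sigmoid(x). A quick parity check, assuming both classes are importable from efficientnet_pytorch.utils as the diff indicates:

    import torch
    from efficientnet_pytorch.utils import Swish, MemoryEfficientSwish

    x1 = torch.randn(32, dtype=torch.double, requires_grad=True)
    x2 = x1.detach().clone().requires_grad_(True)

    y1 = Swish()(x1)
    y2 = MemoryEfficientSwish()(x2)
    assert torch.allclose(y1, y2)            # identical forward values

    y1.sum().backward()
    y2.sum().backward()
    assert torch.allclose(x1.grad, x2.grad)  # hand-written backward matches autograd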