Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update relu_fn for better memory consumption #86

Merged
merged 2 commits into from
Oct 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# EfficientNet PyTorch

### Update (October 08, 2019)

This update changes activation function implementation to more memory-efficient. For more details please refer to: https://github.com/lukemelas/EfficientNet-PyTorch/issues/18

### Update (July 31, 2019)

_Upgrade the pip package with_ `pip install --upgrade efficientnet-pytorch`
Expand Down
37 changes: 28 additions & 9 deletions efficientnet_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from torch.nn import functional as F
from torch.utils import model_zoo


########################################################################
############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
########################################################################
Expand All @@ -24,21 +23,37 @@
'num_classes', 'width_coefficient', 'depth_coefficient',
'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])


# Parameters for an individual model block
BlockArgs = collections.namedtuple('BlockArgs', [
'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
'expand_ratio', 'id_skip', 'stride', 'se_ratio'])


# Change namedtuple defaults
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)


def relu_fn(x):
""" Swish activation function """
return x * torch.sigmoid(x)
class SwishImplementation(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result = i * torch.sigmoid(i)
ctx.save_for_backward(i)
return result

@staticmethod
def backward(ctx, grad_output):
i = ctx.saved_variables[0]
sigmoid_i = torch.sigmoid(i)
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class Swish(nn.Module):
@staticmethod
def forward(x):
return SwishImplementation.apply(x)


relu_fn = Swish()


def round_filters(filters, global_params):
Expand Down Expand Up @@ -84,11 +99,13 @@ def get_same_padding_conv2d(image_size=None):
else:
return partial(Conv2dStaticSamePadding, image_size=image_size)


class Conv2dDynamicSamePadding(nn.Conv2d):
""" 2D Convolutions like TensorFlow, for a dynamic image size """

def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]]*2
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

def forward(self, x):
ih, iw = x.size()[-2:]
Expand All @@ -98,12 +115,13 @@ def forward(self, x):
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2])
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class Conv2dStaticSamePadding(nn.Conv2d):
""" 2D Convolutions like TensorFlow, for a fixed image size"""

def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
Expand All @@ -128,7 +146,7 @@ def forward(self, x):


class Identity(nn.Module):
def __init__(self,):
def __init__(self, ):
super(Identity, self).__init__()

def forward(self, input):
Expand Down Expand Up @@ -286,6 +304,7 @@ def get_model_params(model_name, override_params):
'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth',
}


def load_pretrained_weights(model, model_name, load_fc=True):
""" Loads pretrained weights, and downloads if loading for the first time. """
state_dict = model_zoo.load_url(url_map[model_name])
Expand Down