Merge pull request #87 from lukemelas/relu_update
Updated ReLU, dropout, and more
lukemelas authored Oct 12, 2019
2 parents de40cbf + 3be143e commit 532a6d7
Showing 6 changed files with 159 additions and 17 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -1,5 +1,9 @@
# EfficientNet PyTorch

### Update (October 12, 2019)

This update changes the activation function (Swish) to a more memory-efficient implementation. For more details, please refer to https://github.com/lukemelas/EfficientNet-PyTorch/issues/18. Thanks to [Dmytro Panchenko](https://www.kaggle.com/hokmund) for the pull request.
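
A minimal sketch of the idea behind this change (the actual code lands in `efficientnet_pytorch/utils.py` in the diff below; the class name here is only illustrative): Swish is implemented as a custom `torch.autograd.Function` that saves only the input tensor and recomputes the sigmoid during the backward pass, instead of letting autograd keep the intermediate tensors of a plain `x * torch.sigmoid(x)`.

```python
import torch

class MemoryEfficientSwish(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Save only the input; everything else is recomputed in backward.
        ctx.save_for_backward(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        sig = torch.sigmoid(x)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * (sig * (1 + x * (1 - sig)))
```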

### Update (July 31, 2019)

_Upgrade the pip package with_ `pip install --upgrade efficientnet-pytorch`
2 changes: 1 addition & 1 deletion efficientnet_pytorch/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.4.0"
__version__ = "0.5.0"
from .model import EfficientNet
from .utils import (
GlobalParams,
22 changes: 17 additions & 5 deletions efficientnet_pytorch/model.py
@@ -150,7 +150,8 @@ def __init__(self, blocks_args=None, global_params=None):
self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

# Final linear layer
self._dropout = self._global_params.dropout_rate
self._avg_pooling = nn.AdaptiveAvgPool2d(1)
self._dropout = nn.Dropout(self._global_params.dropout_rate)
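        # nn.Dropout is a submodule, so it follows the parent model's .train()/.eval() state
        # and can be swapped out for a custom head (both are exercised in tests/test_model.py below).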
self._fc = nn.Linear(out_channels, self._global_params.num_classes)

def extract_features(self, inputs):
@@ -173,14 +174,14 @@ def extract_features(self, inputs):

def forward(self, inputs):
""" Calls extract_features to extract features, applies final linear layer, and returns logits. """

bs = inputs.size(0)
# Convolution layers
x = self.extract_features(inputs)

# Pooling and final linear layer
x = F.adaptive_avg_pool2d(x, 1).squeeze(-1).squeeze(-1)
if self._dropout:
x = F.dropout(x, p=self._dropout, training=self.training)
x = self._avg_pooling(x)
x = x.view(bs, -1)
x = self._dropout(x)
x = self._fc(x)
return x

@@ -190,10 +191,21 @@ def from_name(cls, model_name, override_params=None):
blocks_args, global_params = get_model_params(model_name, override_params)
return cls(blocks_args, global_params)

@classmethod
def from_pretrained(cls, model_name, num_classes=1000, in_channels = 3):
model = cls.from_name(model_name, override_params={'num_classes': num_classes})
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
if in_channels != 3:
Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size)
out_channels = round_filters(32, model._global_params)
model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
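            # The replacement stem conv is freshly initialized; the remaining pretrained weights are kept.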
return model

@classmethod
def from_pretrained(cls, model_name, num_classes=1000):
model = cls.from_name(model_name, override_params={'num_classes': num_classes})
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))

return model

@classmethod
39 changes: 29 additions & 10 deletions efficientnet_pytorch/utils.py
@@ -12,7 +12,6 @@
from torch.nn import functional as F
from torch.utils import model_zoo


########################################################################
############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
########################################################################
@@ -24,21 +23,37 @@
'num_classes', 'width_coefficient', 'depth_coefficient',
'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])


# Parameters for an individual model block
BlockArgs = collections.namedtuple('BlockArgs', [
'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
'expand_ratio', 'id_skip', 'stride', 'se_ratio'])


# Change namedtuple defaults
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)


def relu_fn(x):
""" Swish activation function """
return x * torch.sigmoid(x)
class SwishImplementation(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result = i * torch.sigmoid(i)
ctx.save_for_backward(i)
return result

@staticmethod
def backward(ctx, grad_output):
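        # Recompute sigmoid from the saved input:
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))).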
i = ctx.saved_variables[0]
sigmoid_i = torch.sigmoid(i)
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class Swish(nn.Module):
@staticmethod
def forward(x):
return SwishImplementation.apply(x)


relu_fn = Swish()
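# relu_fn is kept as a callable module instance, so existing call sites of the
# form relu_fn(x) keep working unchanged.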


def round_filters(filters, global_params):
@@ -84,11 +99,13 @@ def get_same_padding_conv2d(image_size=None):
else:
return partial(Conv2dStaticSamePadding, image_size=image_size)


class Conv2dDynamicSamePadding(nn.Conv2d):
""" 2D Convolutions like TensorFlow, for a dynamic image size """

def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]]*2
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

def forward(self, x):
ih, iw = x.size()[-2:]
@@ -98,12 +115,13 @@ def forward(self, x):
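        # TensorFlow-style 'SAME' padding: pad so the output spatial size is ceil(input / stride).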
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2])
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class Conv2dStaticSamePadding(nn.Conv2d):
""" 2D Convolutions like TensorFlow, for a fixed image size"""

def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
@@ -128,7 +146,7 @@ def forward(self, x):


class Identity(nn.Module):
def __init__(self,):
def __init__(self, ):
super(Identity, self).__init__()

def forward(self, input):
@@ -286,6 +304,7 @@ def get_model_params(model_name, override_params):
'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth',
}


def load_pretrained_weights(model, model_name, load_fc=True):
""" Loads pretrained weights, and downloads if loading for the first time. """
state_dict = model_zoo.load_url(url_map[model_name])
@@ -295,5 +314,5 @@ def load_pretrained_weights(model, model_name, load_fc=True):
state_dict.pop('_fc.weight')
state_dict.pop('_fc.bias')
res = model.load_state_dict(state_dict, strict=False)
assert str(res.missing_keys) == str(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
print('Loaded pretrained weights for {}'.format(model_name))
2 changes: 1 addition & 1 deletion setup.py
@@ -18,7 +18,7 @@
EMAIL = 'lmelaskyriazi@college.harvard.edu'
AUTHOR = 'Luke'
REQUIRES_PYTHON = '>=3.5.0'
VERSION = '0.4.0'
VERSION = '0.5.0'

# What packages are required for this module to be executed?
REQUIRED = [
107 changes: 107 additions & 0 deletions tests/test_model.py
@@ -0,0 +1,107 @@
from collections import OrderedDict

import pytest
import torch
import torch.nn as nn

from efficientnet_pytorch import EfficientNet


# -- fixtures -------------------------------------------------------------------------------------

@pytest.fixture(scope='module', params=[x for x in range(4)])
def model(request):
return 'efficientnet-b{}'.format(request.param)


@pytest.fixture(scope='module', params=[True, False])
def pretrained(request):
return request.param


@pytest.fixture(scope='function')
def net(model, pretrained):
return EfficientNet.from_pretrained(model) if pretrained else EfficientNet.from_name(model)


# -- tests ----------------------------------------------------------------------------------------

@pytest.mark.parametrize('img_size', [224, 256, 512])
def test_forward(net, img_size):
"""Test `.forward()` doesn't throw an error"""
data = torch.zeros((1, 3, img_size, img_size))
output = net(data)
assert not torch.isnan(output).any()


def test_dropout_training(net):
"""Test dropout `.training` is set by `.train()` on parent `nn.module`"""
net.train()
assert net._dropout.training == True


def test_dropout_eval(net):
"""Test dropout `.training` is set by `.eval()` on parent `nn.module`"""
net.eval()
assert net._dropout.training == False


def test_dropout_update(net):
"""Test dropout `.training` is updated by `.train()` and `.eval()` on parent `nn.module`"""
net.train()
assert net._dropout.training == True
net.eval()
assert net._dropout.training == False
net.train()
assert net._dropout.training == True
net.eval()
assert net._dropout.training == False


@pytest.mark.parametrize('img_size', [224, 256, 512])
def test_modify_dropout(net, img_size):
"""Test ability to modify dropout and fc modules of network"""
dropout = nn.Sequential(OrderedDict([
('_bn2', nn.BatchNorm1d(net._bn1.num_features)),
('_drop1', nn.Dropout(p=net._global_params.dropout_rate)),
('_linear1', nn.Linear(net._bn1.num_features, 512)),
('_relu', nn.ReLU()),
('_bn3', nn.BatchNorm1d(512)),
('_drop2', nn.Dropout(p=net._global_params.dropout_rate / 2))
]))
fc = nn.Linear(512, net._global_params.num_classes)

net._dropout = dropout
net._fc = fc

data = torch.zeros((2, 3, img_size, img_size))
output = net(data)
assert not torch.isnan(output).any()


@pytest.mark.parametrize('img_size', [224, 256, 512])
def test_modify_pool(net, img_size):
"""Test ability to modify pooling module of network"""

class AdaptiveMaxAvgPool(nn.Module):

def __init__(self):
super().__init__()
self.ada_avgpool = nn.AdaptiveAvgPool2d(1)
self.ada_maxpool = nn.AdaptiveMaxPool2d(1)

def forward(self, x):
avg_x = self.ada_avgpool(x)
max_x = self.ada_maxpool(x)
x = torch.cat((avg_x, max_x), dim=1)
return x

avg_pooling = AdaptiveMaxAvgPool()
fc = nn.Linear(net._fc.in_features * 2, net._global_params.num_classes)

net._avg_pooling = avg_pooling
net._fc = fc

data = torch.zeros((2, 3, img_size, img_size))
output = net(data)
assert not torch.isnan(output).any()
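
With `pytest` installed, the new suite can be run from the repository root with `pytest tests/test_model.py`; note that the `pretrained=True` cases download the corresponding checkpoints on first run.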
