
Commit

Added static padding
lukemelas committed Jun 29, 2019
1 parent 125e823 commit bd3c392
Showing 3 changed files with 72 additions and 24 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -114,6 +114,10 @@ example/test*
 *.pth*
 examples/imagenet/data/
 !examples/imagenet/data/README.md
+tmp
+tf_to_pytorch/pretrained_tensorflow
+!tf_to_pytorch/pretrained_tensorflow/download.sh
+examples/imagenet/run.sh



36 changes: 22 additions & 14 deletions efficientnet_pytorch/model.py
@@ -7,7 +7,7 @@
     round_filters,
     round_repeats,
     drop_connect,
-    Conv2dSamePadding,
+    get_same_padding_conv2d,
     get_model_params,
     efficientnet_params,
     load_pretrained_weights,
@@ -33,30 +33,33 @@ def __init__(self, block_args, global_params):
         self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
         self.id_skip = block_args.id_skip # skip connection and drop connect

+        # Get static or dynamic convolution depending on image size
+        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
+
         # Expansion phase
         inp = self._block_args.input_filters # number of input channels
         oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
         if self._block_args.expand_ratio != 1:
-            self._expand_conv = Conv2dSamePadding(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
+            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
             self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

         # Depthwise convolution phase
         k = self._block_args.kernel_size
         s = self._block_args.stride
-        self._depthwise_conv = Conv2dSamePadding(
+        self._depthwise_conv = Conv2d(
             in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
             kernel_size=k, stride=s, bias=False)
         self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

         # Squeeze and Excitation layer, if desired
         if self.has_se:
             num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
-            self._se_reduce = Conv2dSamePadding(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
-            self._se_expand = Conv2dSamePadding(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
+            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
+            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

         # Output phase
         final_oup = self._block_args.output_filters
-        self._project_conv = Conv2dSamePadding(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
+        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
         self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)

     def forward(self, inputs, drop_connect_rate=None):
@@ -109,14 +112,17 @@ def __init__(self, blocks_args=None, global_params=None):
         self._global_params = global_params
         self._blocks_args = blocks_args

+        # Get static or dynamic convolution depending on image size
+        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
+
         # Batch norm parameters
         bn_mom = 1 - self._global_params.batch_norm_momentum
         bn_eps = self._global_params.batch_norm_epsilon

         # Stem
         in_channels = 3 # rgb
         out_channels = round_filters(32, self._global_params) # number of output channels
-        self._conv_stem = Conv2dSamePadding(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
         self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

         # Build blocks
@@ -140,7 +146,7 @@ def __init__(self, blocks_args=None, global_params=None):
         # Head
         in_channels = block_args.output_filters # output of final block
         out_channels = round_filters(1280, self._global_params)
-        self._conv_head = Conv2dSamePadding(in_channels, out_channels, kernel_size=1, bias=False)
+        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
         self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

         # Final linear layer
@@ -158,7 +164,10 @@ def extract_features(self, inputs):
             drop_connect_rate = self._global_params.drop_connect_rate
             if drop_connect_rate:
                 drop_connect_rate *= float(idx) / len(self._blocks)
-            x = block(x, drop_connect_rate)
+            x = block(x, drop_connect_rate=drop_connect_rate)
+
+        # Head
+        x = relu_fn(self._bn1(self._conv_head(x)))

         return x

@@ -168,8 +177,7 @@ def forward(self, inputs):
         # Convolution layers
         x = self.extract_features(inputs)

-        # Head
-        x = relu_fn(self._bn1(self._conv_head(x)))
+        # Pooling and final linear layer
         x = F.adaptive_avg_pool2d(x, 1).squeeze(-1).squeeze(-1)
         if self._dropout:
             x = F.dropout(x, p=self._dropout, training=self.training)
@@ -183,9 +191,9 @@ def from_name(cls, model_name, override_params=None):
         return EfficientNet(blocks_args, global_params)

     @classmethod
-    def from_pretrained(cls, model_name):
-        model = EfficientNet.from_name(model_name)
-        load_pretrained_weights(model, model_name)
+    def from_pretrained(cls, model_name, num_classes=1000):
+        model = EfficientNet.from_name(model_name, override_params={'num_classes': 1000})
+        load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
         return model

     @classmethod
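As an illustration of the API after this commit (not part of the diff): from_pretrained now takes num_classes and only loads the final fully connected layer when it matches the 1000-class checkpoint, while the image size recorded in GlobalParams selects static same-padding convolutions. A minimal usage sketch, assuming the 'efficientnet-b0' entries exist in efficientnet_params and url_map as in the rest of the repository:

import torch
from efficientnet_pytorch import EfficientNet

# Pretrained 1000-class model; image_size comes from efficientnet_params via
# get_model_params, so every conv is built with static same-padding.
model = EfficientNet.from_pretrained('efficientnet-b0')
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # shape (1, 1000)

# Randomly initialized variant with a custom classifier head via override_params.
tiny_head = EfficientNet.from_name('efficientnet-b0', override_params={'num_classes': 10})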
56 changes: 46 additions & 10 deletions efficientnet_pytorch/utils.py
@@ -6,6 +6,7 @@
 import re
 import math
 import collections
+from functools import partial
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -21,7 +22,7 @@
 GlobalParams = collections.namedtuple('GlobalParams', [
     'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
     'num_classes', 'width_coefficient', 'depth_coefficient',
-    'depth_divisor', 'min_depth', 'drop_connect_rate',])
+    'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])


 # Parameters for an individual model block
@@ -75,8 +76,16 @@ def drop_connect(inputs, p, training):
     return output


-class Conv2dSamePadding(nn.Conv2d):
-    """ 2D Convolutions like TensorFlow """
+def get_same_padding_conv2d(image_size=None):
+    """ Chooses static padding if you have specified an image size, and dynamic padding otherwise.
+        Static padding is necessary for ONNX exporting of models. """
+    if image_size is None:
+        return Conv2dDynamicSamePadding
+    else:
+        return partial(Conv2dStaticSamePadding, image_size=image_size)
+
+class Conv2dDynamicSamePadding(nn.Conv2d):
+    """ 2D Convolutions like TensorFlow, for a dynamic image size """
     def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
         super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
         self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]]*2
@@ -93,6 +102,31 @@ def forward(self, x):
         return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


+class Conv2dStaticSamePadding(nn.Conv2d):
+    """ 2D Convolutions like TensorFlow, for a fixed image size"""
+    def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
+        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
+        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+        # Calculate padding based on image size and save it
+        assert image_size is not None
+        ih, iw = image_size if type(image_size) == list else [image_size, image_size]
+        kh, kw = self.weight.size()[-2:]
+        sh, sw = self.stride
+        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+        if pad_h > 0 or pad_w > 0:
+            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
+        else:
+            self.static_padding = nn.Identity()
+
+    def forward(self, x):
+        x = self.static_padding(x)
+        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+        return x
+
+
 ########################################################################
 ############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
 ########################################################################
@@ -189,8 +223,8 @@ def encode(blocks_args):
     return block_strings


-def efficientnet(width_coefficient=None, depth_coefficient=None,
-                 dropout_rate=0.2, drop_connect_rate=0.2):
+def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
+                 drop_connect_rate=0.2, image_size=None, num_classes=1000):
     """ Creates a efficientnet model. """

     blocks_args = [
@@ -207,11 +241,12 @@ def efficientnet(width_coefficient=None, depth_coefficient=None,
         dropout_rate=dropout_rate,
         drop_connect_rate=drop_connect_rate,
         # data_format='channels_last', # removed, this is always true in PyTorch
-        num_classes=1000,
+        num_classes=num_classes,
         width_coefficient=width_coefficient,
         depth_coefficient=depth_coefficient,
         depth_divisor=8,
-        min_depth=None
+        min_depth=None,
+        image_size=image_size,
     )

     return blocks_args, global_params
@@ -220,9 +255,10 @@ def efficientnet(width_coefficient=None, depth_coefficient=None,
 def get_model_params(model_name, override_params):
     """ Get the block args and global params for a given model """
     if model_name.startswith('efficientnet'):
-        w, d, _, p = efficientnet_params(model_name)
+        w, d, s, p = efficientnet_params(model_name)
         # note: all models have drop connect rate = 0.2
-        blocks_args, global_params = efficientnet(width_coefficient=w, depth_coefficient=d, dropout_rate=p)
+        blocks_args, global_params = efficientnet(
+            width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
     else:
         raise NotImplementedError('model name is not pre-defined: %s' % model_name)
     if override_params:
@@ -240,7 +276,7 @@ def get_model_params(model_name, override_params):
     'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet-b5-586e6cc6.pth',
 }

-def load_pretrained_weights(model, model_name):
+def load_pretrained_weights(model, model_name, load_fc=True):
     """ Loads pretrained weights, and downloads if loading for the first time. """
     state_dict = model_zoo.load_url(url_map[model_name])
     model.load_state_dict(state_dict)
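A worked sketch of the padding arithmetic in Conv2dStaticSamePadding (not part of the commit; the kernel size, stride, and 224x224 input are arbitrary illustrative choices, and the import path assumes the package layout shown in this diff):

import torch
from efficientnet_pytorch.utils import get_same_padding_conv2d

# For ih = 224, k = 3, s = 2: oh = ceil(224 / 2) = 112 and
# pad_h = max((112 - 1) * 2 + (3 - 1) * 1 + 1 - 224, 0) = 1, split as (0, 1) by ZeroPad2d.
Conv2d = get_same_padding_conv2d(image_size=224)   # static padding, computed once in __init__
conv = Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, bias=False)

x = torch.randn(1, 3, 224, 224)
print(conv(x).shape)  # torch.Size([1, 32, 112, 112]), i.e. ceil(224 / 2) like TF 'SAME'

# Without an image size, the dynamic variant pads inside forward() from the input shape.
DynConv2d = get_same_padding_conv2d(image_size=None)
dyn = DynConv2d(3, 32, kernel_size=3, stride=2, bias=False)
print(dyn(x).shape)   # also torch.Size([1, 32, 112, 112])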

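The new docstring notes that static padding is what makes ONNX export possible, so here is a hedged sketch of that use case (not part of the commit; the output file name, input size, and availability of the 'efficientnet-b0' checkpoint are assumptions):

import torch
from efficientnet_pytorch import EfficientNet

# With a fixed image size every conv is a Conv2dStaticSamePadding, whose padding is a
# constant nn.ZeroPad2d, so tracing does not have to capture shape-dependent padding logic.
model = EfficientNet.from_pretrained('efficientnet-b0')
model.eval()

dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy, 'efficientnet-b0.onnx')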