Skip to content

Commit

Permalink
Equalized EfficientNetV2 PT and TF implementations (#52)
Browse files Browse the repository at this point in the history
* quickfix

* fix mbconv and fusedmbconv for pytorch

* fix mbconv and fusedmbconv for pytorch

* fixed kernel size in mbconv pt

* mbconv patch pt

* equalized FusedMBConv param count with TF

* config fixes

* fixed kernel size in mbconv pt

* mbconv patch pt

* equalized FusedMBConv param count with TF

* config fixes
  • Loading branch information
DavidLandup0 authored Feb 8, 2023
1 parent 4683958 commit ec8a23f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 38 deletions.
43 changes: 20 additions & 23 deletions deepvision/layers/fused_mbconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ def __init__(
momentum=self.bn_momentum,
)

self.bn2 = layers.BatchNormalization(momentum=self.bn_momentum)

self.se_conv1 = layers.Conv2D(
self.filters_se,
1,
Expand All @@ -69,7 +67,7 @@ def __init__(
use_bias=False,
)

self.bn3 = layers.BatchNormalization(momentum=self.bn_momentum)
self.bn_out = layers.BatchNormalization(momentum=self.bn_momentum)

def call(self, inputs):
# Expansion
Expand All @@ -92,7 +90,7 @@ def call(self, inputs):

# Output projection
x = self.output_conv(x)
x = self.bn3(x)
x = self.bn_out(x)
if self.expand_ratio == 1:
x = self.activation(x)

Expand All @@ -101,10 +99,8 @@ def call(self, inputs):
if self.dropout:
x = layers.Dropout(
self.dropout,
noise_shape=(None, 1, 1, 1),
name=self.name + "drop",
)(x)
x = layers.add([x, inputs], name=self.name + "add")
x = layers.add([x, inputs])
return x

def get_config(self):
Expand Down Expand Up @@ -153,20 +149,20 @@ def __init__(
self.filters = self.input_filters * self.expand_ratio
self.filters_se = max(1, int(input_filters * se_ratio))

self.conv1 = nn.Conv2d(
in_channels=self.input_filters,
out_channels=self.filters,
kernel_size=kernel_size,
stride=strides,
padding=same_padding(kernel_size, strides),
bias=False,
)
self.bn1 = nn.BatchNorm2d(self.filters, momentum=self.bn_momentum)
self.bn2 = nn.BatchNorm2d(self.filters, momentum=self.bn_momentum)

self.se_conv1 = nn.Conv2d(self.filters, self.filters_se, 1, padding="same")
if self.expand_ratio != 1:
self.conv1 = nn.Conv2d(
in_channels=self.input_filters,
out_channels=self.filters,
kernel_size=kernel_size,
stride=strides,
padding=same_padding(kernel_size, strides),
bias=False,
)
self.bn1 = nn.BatchNorm2d(self.filters, momentum=self.bn_momentum)

self.se_conv2 = nn.Conv2d(self.filters_se, self.filters, 1, padding="same")
if 0 < self.se_ratio <= 1:
self.se_conv1 = nn.Conv2d(self.filters, self.filters_se, 1, padding="same")
self.se_conv2 = nn.Conv2d(self.filters_se, self.filters, 1, padding="same")

self.output_conv = nn.Conv2d(
in_channels=self.filters,
Expand All @@ -177,7 +173,7 @@ def __init__(
bias=False,
)

self.bn3 = nn.BatchNorm2d(self.output_filters, momentum=self.bn_momentum)
self.bn_out = nn.BatchNorm2d(self.output_filters, momentum=self.bn_momentum)

def forward(self, inputs):
if self.expand_ratio != 1:
Expand All @@ -190,7 +186,8 @@ def forward(self, inputs):
# Squeeze-and-Excite
if 0 < self.se_ratio <= 1:
se = nn.AvgPool2d(x.shape[2])(x)
se = se.reshape(x.shape[0], self.filters, 1, 1)
# No need to reshape, output is already [B, C, 1, 1]
# se = se.reshape(x.shape[0], self.filters, 1, 1)

se = self.se_conv1(se)
se = self.activation()(se)
Expand All @@ -200,7 +197,7 @@ def forward(self, inputs):

# Output projection
x = self.output_conv(x)
x = self.bn3(x)
x = self.bn_out(x)
if self.expand_ratio == 1:
x = self.activation()(x)

Expand Down
29 changes: 16 additions & 13 deletions deepvision/layers/mbconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
self.bn1 = layers.BatchNormalization(momentum=self.bn_momentum)

self.depthwise = layers.DepthwiseConv2D(
kernel_size=kernel_size,
kernel_size=3,
strides=strides,
padding="same",
use_bias=False,
Expand Down Expand Up @@ -160,15 +160,16 @@ def __init__(
self.filters = self.input_filters * self.expand_ratio
self.filters_se = max(1, int(input_filters * se_ratio))

self.conv1 = nn.Conv2d(
in_channels=self.input_filters,
out_channels=self.filters,
kernel_size=kernel_size,
stride=1,
padding=same_padding(kernel_size, strides),
bias=False,
)
self.bn1 = nn.BatchNorm2d(self.filters, momentum=self.bn_momentum)
if self.expand_ratio != 1:
self.conv1 = nn.Conv2d(
in_channels=self.input_filters,
out_channels=self.filters,
kernel_size=1,
stride=1,
padding="same",
bias=False,
)
self.bn1 = nn.BatchNorm2d(self.filters, momentum=self.bn_momentum)

# Depthwise = same in_channels as groups
self.depthwise = nn.Conv2d(
Expand All @@ -182,8 +183,9 @@ def __init__(
)
self.bn2 = nn.BatchNorm2d(self.filters, momentum=self.bn_momentum)

self.se_conv1 = nn.Conv2d(self.filters, self.filters_se, 1, padding="same")
self.se_conv2 = nn.Conv2d(self.filters_se, self.filters, 1, padding="same")
if 0 < self.se_ratio <= 1:
self.se_conv1 = nn.Conv2d(self.filters, self.filters_se, 1, padding="same")
self.se_conv2 = nn.Conv2d(self.filters_se, self.filters, 1, padding="same")

self.output_conv = nn.Conv2d(
in_channels=self.filters,
Expand Down Expand Up @@ -213,7 +215,8 @@ def forward(self, inputs):
# Squeeze-and-excite
if 0 < self.se_ratio <= 1:
se = nn.AvgPool2d(x.shape[2])(x)
se = se.reshape(x.shape[0], self.filters, 1, 1)
# No need to reshape, output is already [B, C, 1, 1]
# se = se.reshape(x.shape[0], self.filters, 1, 1)

se = self.se_conv1(se)
se = self.activation()(se)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
],
},
"EfficientNetV2M": {
"block_kernel_sizes": [3, 3, 3, 3, 3, 3],
"block_kernel_sizes": [3, 3, 3, 3, 3, 3, 3],
"block_num_repeat": [3, 5, 5, 7, 14, 18, 5],
"block_in_filters": [24, 24, 48, 80, 160, 176, 304],
"block_out_filters": [24, 48, 80, 160, 176, 304, 512],
Expand All @@ -61,10 +61,11 @@
"mbconv",
"mbconv",
"mbconv",
"mbconv",
],
},
"EfficientNetV2L": {
"block_kernel_sizes": [3, 3, 3, 3, 3, 3],
"block_kernel_sizes": [3, 3, 3, 3, 3, 3, 3],
"block_num_repeat": [4, 7, 7, 10, 19, 25, 7],
"block_in_filters": [32, 32, 64, 96, 192, 224, 384],
"block_out_filters": [
Expand Down

0 comments on commit ec8a23f

Please sign in to comment.