Skip to content

Commit

Permalink
Merge pull request #208 from nwschurink/#192_include_top
Browse files Browse the repository at this point in the history
#192 include top
  • Loading branch information
lukemelas committed Aug 25, 2020
2 parents a746930 + a78e84e commit 761ac94
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
13 changes: 7 additions & 6 deletions efficientnet_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ class EfficientNet(nn.Module):
[1] https://arxiv.org/abs/1905.11946 (EfficientNet)
Example:
>>> import torch
import torch
>>> from efficientnet.model import EfficientNet
>>> inputs = torch.rand(1, 3, 224, 224)
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
Expand Down Expand Up @@ -307,13 +309,12 @@ def forward(self, inputs):
"""
# Convolution layers
x = self.extract_features(inputs)

# Pooling and final linear layer
x = self._avg_pooling(x)
x = x.flatten(start_dim=1)
x = self._dropout(x)
x = self._fc(x)

if self._global_params.include_top:
x = x.flatten(start_dim=1)
x = self._dropout(x)
x = self._fc(x)
return x

@classmethod
Expand Down
5 changes: 3 additions & 2 deletions efficientnet_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
GlobalParams = collections.namedtuple('GlobalParams', [
'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
'drop_connect_rate', 'depth_divisor', 'min_depth'])
'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top'])

# Parameters for an individual model block
BlockArgs = collections.namedtuple('BlockArgs', [
Expand Down Expand Up @@ -475,7 +475,7 @@ def efficientnet_params(model_name):


def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000):
dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True):
"""Create BlockArgs and GlobalParams for efficientnet model.
Args:
Expand Down Expand Up @@ -517,6 +517,7 @@ def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None
drop_connect_rate=drop_connect_rate,
depth_divisor=8,
min_depth=None,
include_top=include_top,
)

return blocks_args, global_params
Expand Down

0 comments on commit 761ac94

Please sign in to comment.