Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cosmetic improvements to code #215

Merged
merged 3 commits into from
Aug 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 34 additions & 35 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# EfficientNet PyTorch

### Quickstart
### Quickstart

Install with `pip install efficientnet_pytorch` and load a pretrained EfficientNet with:
```python
Expand All @@ -12,46 +12,45 @@ model = EfficientNet.from_pretrained('efficientnet-b0')

#### Update (May 14, 2020)

This update adds comprehensive comments and documentation (thanks to @workingcoder).
This update adds comprehensive comments and documentation (thanks to @workingcoder).

#### Update (January 23, 2020)

This update adds a new category of pre-trained model based on adversarial training, called _advprop_. It is important to note that the preprocessing required for the advprop pretrained models is slightly different from normal ImageNet preprocessing. As a result, by default, advprop models are not used. To load a model with advprop, use:
```
```python
model = EfficientNet.from_pretrained("efficientnet-b0", advprop=True)
```
There is also a new, large `efficientnet-b8` pretrained model that is only available in advprop form. When using these models, replace ImageNet preprocessing code as follows:
```
```python
if advprop: # for models using advprop pretrained weights
normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
else:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

```
This update also addresses multiple other issues ([#115](https://github.com/lukemelas/EfficientNet-PyTorch/issues/115), [#128](https://github.com/lukemelas/EfficientNet-PyTorch/issues/128)).
This update also addresses multiple other issues ([#115](https://github.com/lukemelas/EfficientNet-PyTorch/issues/115), [#128](https://github.com/lukemelas/EfficientNet-PyTorch/issues/128)).

#### Update (October 15, 2019)

This update allows you to choose whether to use a memory-efficient Swish activation. The memory-efficient version is chosen by default, but it cannot be used when exporting using PyTorch JIT. For this purpose, we have also included a standard (export-friendly) swish activation function. To switch to the export-friendly version, simply call `model.set_swish(memory_efficient=False)` after loading your desired model. This update addresses issues [#88](https://github.com/lukemelas/EfficientNet-PyTorch/pull/88) and [#89](https://github.com/lukemelas/EfficientNet-PyTorch/pull/89).

#### Update (October 12, 2019)

This update makes the Swish activation function more memory-efficient. It also addresses pull requests [#72](https://github.com/lukemelas/EfficientNet-PyTorch/pull/72), [#73](https://github.com/lukemelas/EfficientNet-PyTorch/pull/73), [#85](https://github.com/lukemelas/EfficientNet-PyTorch/pull/85), and [#86](https://github.com/lukemelas/EfficientNet-PyTorch/pull/86). Thanks to the authors of all the pull requests!
This update makes the Swish activation function more memory-efficient. It also addresses pull requests [#72](https://github.com/lukemelas/EfficientNet-PyTorch/pull/72), [#73](https://github.com/lukemelas/EfficientNet-PyTorch/pull/73), [#85](https://github.com/lukemelas/EfficientNet-PyTorch/pull/85), and [#86](https://github.com/lukemelas/EfficientNet-PyTorch/pull/86). Thanks to the authors of all the pull requests!

#### Update (July 31, 2019)

_Upgrade the pip package with_ `pip install --upgrade efficientnet-pytorch`

The B6 and B7 models are now available. Additionally, _all_ pretrained models have been updated to use AutoAugment preprocessing, which translates to better performance across the board. Usage is the same as before:
The B6 and B7 models are now available. Additionally, _all_ pretrained models have been updated to use AutoAugment preprocessing, which translates to better performance across the board. Usage is the same as before:
```python
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b7')
model = EfficientNet.from_pretrained('efficientnet-b7')
```

#### Update (June 29, 2019)

This update adds easy model exporting ([#20](https://github.com/lukemelas/EfficientNet-PyTorch/issues/20)) and feature extraction ([#38](https://github.com/lukemelas/EfficientNet-PyTorch/issues/38)).
This update adds easy model exporting ([#20](https://github.com/lukemelas/EfficientNet-PyTorch/issues/20)) and feature extraction ([#38](https://github.com/lukemelas/EfficientNet-PyTorch/issues/38)).

* [Example: Export to ONNX](#example-export)
* [Example: Extract features](#example-feature-extraction)
Expand All @@ -60,29 +59,29 @@ This update adds easy model exporting ([#20](https://github.com/lukemelas/Effici
It is also now incredibly simple to load a pretrained model with a new number of classes for transfer learning:
```python
model = EfficientNet.from_pretrained('efficientnet-b1', num_classes=23)
```
```


#### Update (June 23, 2019)

The B4 and B5 models are now available. Their usage is identical to the other models:
The B4 and B5 models are now available. Their usage is identical to the other models:
```python
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b4')
model = EfficientNet.from_pretrained('efficientnet-b4')
```

### Overview
This repository contains an op-for-op PyTorch reimplementation of [EfficientNet](https://arxiv.org/abs/1905.11946), along with pre-trained models and examples.
This repository contains an op-for-op PyTorch reimplementation of [EfficientNet](https://arxiv.org/abs/1905.11946), along with pre-trained models and examples.

The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects. This implementation is a work in progress -- new features are currently being implemented.
The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects. This implementation is a work in progress -- new features are currently being implemented.

At the moment, you can easily:
* Load pretrained EfficientNet models
* Use EfficientNet models for classification or feature extraction
At the moment, you can easily:
* Load pretrained EfficientNet models
* Use EfficientNet models for classification or feature extraction
* Evaluate EfficientNet models on ImageNet or your own images

_Upcoming features_: In the next few days, you will be able to:
* Train new models from scratch on ImageNet with a simple command
* Train new models from scratch on ImageNet with a simple command
* Quickly finetune an EfficientNet on your own dataset
* Export EfficientNet models for production

Expand All @@ -95,11 +94,11 @@ _Upcoming features_: In the next few days, you will be able to:
* [Example: Classify](#example-classification)
* [Example: Extract features](#example-feature-extraction)
* [Example: Export to ONNX](#example-export)
6. [Contributing](#contributing)
6. [Contributing](#contributing)

### About EfficientNet

If you're new to EfficientNets, here is an explanation straight from the official TensorFlow implementation:
If you're new to EfficientNets, here is an explanation straight from the official TensorFlow implementation:

EfficientNets are a family of image classification models, which achieve state-of-the-art accuracy, yet being an order-of-magnitude smaller and faster than previous models. We develop EfficientNets based on AutoML and Compound Scaling. In particular, we first use [AutoML Mobile framework](https://ai.googleblog.com/2018/08/mnasnet-towards-automating-design-of.html) to develop a mobile-size baseline network, named as EfficientNet-B0; Then, we use the compound scaling method to scale up this baseline to obtain EfficientNet-B1 to B7.

Expand Down Expand Up @@ -141,25 +140,25 @@ Or install from source:
git clone https://github.com/lukemelas/EfficientNet-PyTorch
cd EfficientNet-Pytorch
pip install -e .
```
```

### Usage

#### Loading pretrained models

Load an EfficientNet:
Load an EfficientNet:
```python
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_name('efficientnet-b0')
```

Load a pretrained EfficientNet:
Load a pretrained EfficientNet:
```python
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0')
```

Details about the models are below:
Details about the models are below:

| *Name* |*# Params*|*Top-1 Acc.*|*Pretrained?*|
|:-----------------:|:--------:|:----------:|:-----------:|
Expand All @@ -177,7 +176,7 @@ Details about the models are below:

Below is a simple, complete example. It may also be found as a jupyter notebook in `examples/simple` or as a [Colab Notebook](https://colab.research.google.com/drive/1Jw28xZ1NJq4Cja4jLe6tJ6_F5lCzElb4).

We assume that in your current directory, there is a `img.jpg` file and a `labels_map.txt` file (ImageNet class names). These are both included in `examples/simple`.
We assume that in your current directory, there is a `img.jpg` file and a `labels_map.txt` file (ImageNet class names). These are both included in `examples/simple`.

```python
import json
Expand Down Expand Up @@ -210,7 +209,7 @@ for idx in torch.topk(outputs, k=5).indices.squeeze(0).tolist():
print('{label:<75} ({p:.2f}%)'.format(label=labels_map[idx], p=prob*100))
```

#### Example: Feature Extraction
#### Example: Feature Extraction

You can easily extract features with `model.extract_features`:
```python
Expand All @@ -224,20 +223,20 @@ features = model.extract_features(img)
print(features.shape) # torch.Size([1, 1280, 7, 7])
```

#### Example: Export to ONNX
#### Example: Export to ONNX

Exporting to ONNX for deploying to production is now simple:
Exporting to ONNX for deploying to production is now simple:
```python
import torch
import torch
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b1')
dummy_input = torch.randn(10, 3, 240, 240)

torch.onnx.export(model, dummy_input, "test-b1.onnx", verbose=True)
```
```

[Here](https://colab.research.google.com/drive/1rOAEXeXHaA8uo3aG2YcFDHItlRJMV0VP) is a Colab example.
[Here](https://colab.research.google.com/drive/1rOAEXeXHaA8uo3aG2YcFDHItlRJMV0VP) is a Colab example.


#### ImageNet
Expand All @@ -246,6 +245,6 @@ See `examples/imagenet` for details about evaluating on ImageNet.

### Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.
If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!
I look forward to seeing what the community does with these models!
1 change: 0 additions & 1 deletion efficientnet_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@
efficientnet,
get_model_params,
)

32 changes: 16 additions & 16 deletions efficientnet_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, block_args, global_params, image_size=None):

# Squeeze and Excitation layer, if desired
if self.has_se:
Conv2d = get_same_padding_conv2d(image_size=(1,1))
Conv2d = get_same_padding_conv2d(image_size=(1, 1))
num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
Expand Down Expand Up @@ -147,7 +147,7 @@ class EfficientNet(nn.Module):
Args:
blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
global_params (namedtuple): A set of GlobalParams shared between blocks.

References:
[1] https://arxiv.org/abs/1905.11946 (EfficientNet)

Expand Down Expand Up @@ -261,12 +261,12 @@ def extract_endpoints(self, inputs):
drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
x = block(x, drop_connect_rate=drop_connect_rate)
if prev_x.size(2) > x.size(2):
endpoints[f'reduction_{len(endpoints)+1}'] = prev_x
endpoints['reduction_{}'.format(len(endpoints)+1)] = prev_x
prev_x = x

# Head
x = self._swish(self._bn1(self._conv_head(x)))
endpoints[f'reduction_{len(endpoints)+1}'] = x
endpoints['reduction_{}'.format(len(endpoints)+1)] = x

return endpoints

Expand All @@ -277,7 +277,7 @@ def extract_features(self, inputs):
inputs (tensor): Input tensor.

Returns:
Output of the final convolution
Output of the final convolution
layer in the efficientnet model.
"""
# Stem
Expand All @@ -289,7 +289,7 @@ def extract_features(self, inputs):
if drop_connect_rate:
drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
x = block(x, drop_connect_rate=drop_connect_rate)

# Head
x = self._swish(self._bn1(self._conv_head(x)))

Expand Down Expand Up @@ -323,7 +323,7 @@ def from_name(cls, model_name, in_channels=3, **override_params):
Args:
model_name (str): Name for efficientnet.
in_channels (int): Input data's channel number.
override_params (other key word params):
override_params (other key word params):
Params to override model's global_params.
Optional key:
'width_coefficient', 'depth_coefficient',
Expand All @@ -342,35 +342,35 @@ def from_name(cls, model_name, in_channels=3, **override_params):
return model

@classmethod
def from_pretrained(cls, model_name, weights_path=None, advprop=False,
def from_pretrained(cls, model_name, weights_path=None, advprop=False,
in_channels=3, num_classes=1000, **override_params):
"""create an efficientnet model according to name.

Args:
model_name (str): Name for efficientnet.
weights_path (None or str):
weights_path (None or str):
str: path to pretrained weights file on the local disk.
None: use pretrained weights downloaded from the Internet.
advprop (bool):
advprop (bool):
Whether to load pretrained weights
trained with advprop (valid when weights_path is None).
in_channels (int): Input data's channel number.
num_classes (int):
num_classes (int):
Number of categories for classification.
It controls the output size for final linear layer.
override_params (other key word params):
override_params (other key word params):
Params to override model's global_params.
Optional key:
'width_coefficient', 'depth_coefficient',
'image_size', 'dropout_rate',
'num_classes', 'batch_norm_momentum',
'batch_norm_momentum',
'batch_norm_epsilon', 'drop_connect_rate',
'depth_divisor', 'min_depth'

Returns:
A pretrained efficientnet model.
"""
model = cls.from_name(model_name, num_classes = num_classes, **override_params)
model = cls.from_name(model_name, num_classes=num_classes, **override_params)
load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop)
model._change_in_channels(in_channels)
return model
Expand All @@ -391,7 +391,7 @@ def get_image_size(cls, model_name):

@classmethod
def _check_model_name_is_valid(cls, model_name):
"""Validates model name.
"""Validates model name.

Args:
model_name (str): Name for efficientnet.
Expand All @@ -409,6 +409,6 @@ def _change_in_channels(self, in_channels):
in_channels (int): Input data's channel number.
"""
if in_channels != 3:
Conv2d = get_same_padding_conv2d(image_size = self._global_params.image_size)
Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
out_channels = round_filters(32, self._global_params)
self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
Loading