Merge pull request #215 from chrisyeh96/linting
Cosmetic improvements to code
lukemelas committed Aug 25, 2020
2 parents 2eb7a7d + 85e0a35 commit f543b74
Showing 4 changed files with 62 additions and 76 deletions.
69 changes: 34 additions & 35 deletions README.md
@@ -1,6 +1,6 @@
# EfficientNet PyTorch

### Quickstart

Install with `pip install efficientnet_pytorch` and load a pretrained EfficientNet with:
```python
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0')
```

#### Update (May 14, 2020)

This update adds comprehensive comments and documentation (thanks to @workingcoder).

#### Update (January 23, 2020)

This update adds a new category of pre-trained model based on adversarial training, called _advprop_. It is important to note that the preprocessing required for the advprop pretrained models is slightly different from normal ImageNet preprocessing. As a result, by default, advprop models are not used. To load a model with advprop, use:
```python
model = EfficientNet.from_pretrained("efficientnet-b0", advprop=True)
```
There is also a new, large `efficientnet-b8` pretrained model that is only available in advprop form. When using these models, replace ImageNet preprocessing code as follows:
```python
if advprop:  # for models using advprop pretrained weights
    normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
else:
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
```
This update also addresses multiple other issues ([#115](https://github.com/lukemelas/EfficientNet-PyTorch/issues/115), [#128](https://github.com/lukemelas/EfficientNet-PyTorch/issues/128)).

#### Update (October 15, 2019)

This update allows you to choose whether to use a memory-efficient Swish activation. The memory-efficient version is chosen by default, but it cannot be used when exporting using PyTorch JIT. For this purpose, we have also included a standard (export-friendly) swish activation function. To switch to the export-friendly version, simply call `model.set_swish(memory_efficient=False)` after loading your desired model. This update addresses issues [#88](https://github.com/lukemelas/EfficientNet-PyTorch/pull/88) and [#89](https://github.com/lukemelas/EfficientNet-PyTorch/pull/89).
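
For instance, switching to the export-friendly Swish before JIT tracing might look like the sketch below (the `torch.jit.trace` call and input shape are illustrative, not part of this update):
```python
import torch
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b0')
model.set_swish(memory_efficient=False)  # use the standard, export-friendly Swish
model.eval()

# Tracing would fail with the default memory-efficient Swish.
traced = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
```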

#### Update (October 12, 2019)

This update makes the Swish activation function more memory-efficient. It also addresses pull requests [#72](https://github.com/lukemelas/EfficientNet-PyTorch/pull/72), [#73](https://github.com/lukemelas/EfficientNet-PyTorch/pull/73), [#85](https://github.com/lukemelas/EfficientNet-PyTorch/pull/85), and [#86](https://github.com/lukemelas/EfficientNet-PyTorch/pull/86). Thanks to the authors of all the pull requests!
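
The saving comes from not storing the intermediate sigmoid output for the backward pass: only the input is kept, and the gradient is recomputed analytically. A minimal sketch of this pattern using a custom `torch.autograd.Function` (an illustration of the technique, not necessarily this package's exact code):
```python
import torch
from torch import nn

class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)  # keep only the input tensor
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)
```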

#### Update (July 31, 2019)

_Upgrade the pip package with_ `pip install --upgrade efficientnet-pytorch`

The B6 and B7 models are now available. Additionally, _all_ pretrained models have been updated to use AutoAugment preprocessing, which translates to better performance across the board. Usage is the same as before:
```python
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b7')
```

#### Update (June 29, 2019)

This update adds easy model exporting ([#20](https://github.com/lukemelas/EfficientNet-PyTorch/issues/20)) and feature extraction ([#38](https://github.com/lukemelas/EfficientNet-PyTorch/issues/38)).

* [Example: Export to ONNX](#example-export)
* [Example: Extract features](#example-feature-extraction)
@@ -60,29 +59,29 @@
It is also now incredibly simple to load a pretrained model with a new number of classes for transfer learning:
```python
model = EfficientNet.from_pretrained('efficientnet-b1', num_classes=23)
```


#### Update (June 23, 2019)

The B4 and B5 models are now available. Their usage is identical to the other models:
```python
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b4')
```

### Overview
This repository contains an op-for-op PyTorch reimplementation of [EfficientNet](https://arxiv.org/abs/1905.11946), along with pre-trained models and examples.

The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects. This implementation is a work in progress -- new features are currently being implemented.

At the moment, you can easily:
* Load pretrained EfficientNet models
* Use EfficientNet models for classification or feature extraction
* Evaluate EfficientNet models on ImageNet or your own images

_Upcoming features_: In the next few days, you will be able to:
* Train new models from scratch on ImageNet with a simple command
* Quickly finetune an EfficientNet on your own dataset
* Export EfficientNet models for production

@@ -95,11 +94,11 @@
* [Example: Classify](#example-classification)
* [Example: Extract features](#example-feature-extraction)
* [Example: Export to ONNX](#example-export)
6. [Contributing](#contributing)

### About EfficientNet

If you're new to EfficientNets, here is an explanation straight from the official TensorFlow implementation:

EfficientNets are a family of image classification models, which achieve state-of-the-art accuracy, yet being an order-of-magnitude smaller and faster than previous models. We develop EfficientNets based on AutoML and Compound Scaling. In particular, we first use [AutoML Mobile framework](https://ai.googleblog.com/2018/08/mnasnet-towards-automating-design-of.html) to develop a mobile-size baseline network, named as EfficientNet-B0; Then, we use the compound scaling method to scale up this baseline to obtain EfficientNet-B1 to B7.
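
For a rough sense of the compound scaling method, the paper ties depth, width, and resolution to a single coefficient phi, with alpha * beta^2 * gamma^2 ≈ 2, so total FLOPs grow by about 2^phi. A small illustrative snippet (constants from the EfficientNet paper; this is a sketch, not this package's internals):
```python
# EfficientNet compound scaling (https://arxiv.org/abs/1905.11946):
# depth d = alpha**phi, width w = beta**phi, resolution r = gamma**phi
alpha, beta, gamma = 1.2, 1.1, 1.15  # found by grid search in the paper

def compound_scale(phi):
    return alpha ** phi, beta ** phi, gamma ** phi

for phi in range(4):  # roughly B0 through B3
    d, w, r = compound_scale(phi)
    print('phi={}: depth x{:.2f}, width x{:.2f}, resolution x{:.2f}'.format(phi, d, w, r))
```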

@@ -141,25 +140,25 @@
Or install from source:
```bash
git clone https://github.com/lukemelas/EfficientNet-PyTorch
cd EfficientNet-PyTorch
pip install -e .
```

### Usage

#### Loading pretrained models

Load an EfficientNet:
```python
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_name('efficientnet-b0')
```

Load a pretrained EfficientNet:
```python
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0')
```

Details about the models are below:

| *Name* |*# Params*|*Top-1 Acc.*|*Pretrained?*|
|:-----------------:|:--------:|:----------:|:-----------:|
@@ -177,7 +176,7 @@

Below is a simple, complete example. It may also be found as a jupyter notebook in `examples/simple` or as a [Colab Notebook](https://colab.research.google.com/drive/1Jw28xZ1NJq4Cja4jLe6tJ6_F5lCzElb4).

We assume that in your current directory, there is an `img.jpg` file and a `labels_map.txt` file (ImageNet class names). These are both included in `examples/simple`.

```python
import json
# ... (model setup, preprocessing, and inference are elided in this diff view) ...
for idx in torch.topk(outputs, k=5).indices.squeeze(0).tolist():
    prob = torch.softmax(outputs, dim=1)[0, idx].item()
    print('{label:<75} ({p:.2f}%)'.format(label=labels_map[idx], p=prob*100))
```

#### Example: Feature Extraction

You can easily extract features with `model.extract_features`:
```python
# ... (model loading and image preprocessing are elided in this diff view) ...
features = model.extract_features(img)
print(features.shape) # torch.Size([1, 1280, 7, 7])
```

#### Example: Export to ONNX

Exporting to ONNX for deploying to production is now simple:
```python
import torch
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b1')
dummy_input = torch.randn(10, 3, 240, 240)

torch.onnx.export(model, dummy_input, "test-b1.onnx", verbose=True)
```

[Here](https://colab.research.google.com/drive/1rOAEXeXHaA8uo3aG2YcFDHItlRJMV0VP) is a Colab example.


#### ImageNet
@@ -246,6 +245,6 @@
See `examples/imagenet` for details about evaluating on ImageNet.
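
For a sense of what that evaluation involves, a minimal top-1 accuracy loop might look like the sketch below; the dataset path is a placeholder, and `examples/imagenet` remains the authoritative script:
```python
import torch
from torchvision import datasets, transforms
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b0').eval()

tfms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# '/path/to/imagenet/val' is a placeholder for a local ImageNet validation set.
val_set = datasets.ImageFolder('/path/to/imagenet/val', tfms)
loader = torch.utils.data.DataLoader(val_set, batch_size=64)

correct = total = 0
with torch.no_grad():
    for images, labels in loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print('top-1 accuracy: {:.2%}'.format(correct / total))
```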

### Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!
1 change: 0 additions & 1 deletion efficientnet_pytorch/__init__.py
@@ -7,4 +7,3 @@
    efficientnet,
    get_model_params,
)

32 changes: 16 additions & 16 deletions efficientnet_pytorch/model.py
@@ -76,7 +76,7 @@ def __init__(self, block_args, global_params, image_size=None):

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
@@ -147,7 +147,7 @@ class EfficientNet(nn.Module):
    Args:
        blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
        global_params (namedtuple): A set of GlobalParams shared between blocks.

    References:
        [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
@@ -261,12 +261,12 @@ def extract_endpoints(self, inputs):
                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
            x = block(x, drop_connect_rate=drop_connect_rate)
            if prev_x.size(2) > x.size(2):
                endpoints['reduction_{}'.format(len(endpoints)+1)] = prev_x
            prev_x = x

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        endpoints['reduction_{}'.format(len(endpoints)+1)] = x

        return endpoints
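
As context for this change, `extract_endpoints` returns a dict of intermediate feature maps keyed `reduction_1`, `reduction_2`, and so on. A hypothetical usage sketch (shapes in the comment are for `efficientnet-b0` and are approximate):
```python
import torch
from efficientnet_pytorch import EfficientNet

model = EfficientNet.from_pretrained('efficientnet-b0')
endpoints = model.extract_endpoints(torch.randn(1, 3, 224, 224))
for name, feat in endpoints.items():
    print(name, tuple(feat.shape))  # e.g. reduction_1 (1, 16, 112, 112)
```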

@@ -277,7 +277,7 @@ def extract_features(self, inputs):
            inputs (tensor): Input tensor.

        Returns:
            Output of the final convolution
            layer in the efficientnet model.
        """
        # Stem
@@ -289,7 +289,7 @@
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

@@ -323,7 +323,7 @@ def from_name(cls, model_name, in_channels=3, **override_params):
        Args:
            model_name (str): Name for efficientnet.
            in_channels (int): Input data's channel number.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
@@ -342,35 +342,35 @@
        return model

    @classmethod
    def from_pretrained(cls, model_name, weights_path=None, advprop=False,
                        in_channels=3, num_classes=1000, **override_params):
"""create an efficientnet model according to name.
Args:
model_name (str): Name for efficientnet.
weights_path (None or str):
weights_path (None or str):
str: path to pretrained weights file on the local disk.
None: use pretrained weights downloaded from the Internet.
advprop (bool):
advprop (bool):
Whether to load pretrained weights
trained with advprop (valid when weights_path is None).
in_channels (int): Input data's channel number.
num_classes (int):
num_classes (int):
Number of categories for classification.
It controls the output size for final linear layer.
override_params (other key word params):
override_params (other key word params):
Params to override model's global_params.
Optional key:
'width_coefficient', 'depth_coefficient',
'image_size', 'dropout_rate',
'num_classes', 'batch_norm_momentum',
'batch_norm_momentum',
'batch_norm_epsilon', 'drop_connect_rate',
'depth_divisor', 'min_depth'
Returns:
A pretrained efficientnet model.
"""
        model = cls.from_name(model_name, num_classes=num_classes, **override_params)
        load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop)
        model._change_in_channels(in_channels)
        return model
@@ -391,7 +391,7 @@ def get_image_size(cls, model_name):

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validates model name.

        Args:
            model_name (str): Name for efficientnet.
@@ -409,6 +409,6 @@ def _change_in_channels(self, in_channels):
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)