Skip to content

Commit

Permalink
Add PNASNetlarge with pretrained weights
Browse files Browse the repository at this point in the history
  • Loading branch information
taehoonlee committed May 12, 2018
1 parent fb04b2c commit e2e0f0f
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 60 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ with tf.Session() as sess:
| [InceptionResNet2](tensornets/inceptions.py#L264) | 19.744 | 4.748 | 3.962 | 55.9M | 54.3M | 656.8 | [[paper]](https://arxiv.org/abs/1602.07261) [[tf-slim]](https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py) |
| [NASNetAlarge](tensornets/nasnets.py#L100) | 17.502 | 3.996 | 3.412 | 93.5M | 89.5M | 2081 | [[paper]](https://arxiv.org/abs/1707.07012) [[tf-slim]](https://github.com/tensorflow/models/tree/master/research/slim/nets/nasnet) |
| [NASNetAmobile](tensornets/nasnets.py#L108) | 25.634 | 8.146 | 6.758 | 7.7M | 6.7M | 165.8 | [[paper]](https://arxiv.org/abs/1707.07012) [[tf-slim]](https://github.com/tensorflow/models/tree/master/research/slim/nets/nasnet) |
| [PNASNetlarge](tensornets/nasnets.py#L100) | 17.366 | 3.950 | 3.358 | 86.2M | 81.9M | 2081 | [[paper]](https://arxiv.org/abs/1712.00559) [[tf-slim]](https://github.com/tensorflow/models/tree/master/research/slim/nets/nasnet) |
| [VGG16](tensornets/vggs.py#L71) | 28.732 | 9.950 | 8.834 | 138.4M | 14.7M | 348.4 | [[paper]](https://arxiv.org/abs/1409.1556) [[keras]](https://github.com/keras-team/keras/blob/master/keras/applications/vgg16.py) |
| [VGG19](tensornets/vggs.py#L78) | 28.744 | 10.012 | 8.774 | 143.7M | 20.0M | 399.8 | [[paper]](https://arxiv.org/abs/1409.1556) [[keras]](https://github.com/keras-team/keras/blob/master/keras/applications/vgg19.py) |
| [DenseNet121](tensornets/densenets.py#L63) | 25.480 | 8.022 | 6.842 | 8.1M | 7.0M | 202.9 | [[paper]](https://arxiv.org/abs/1608.06993) [[torch]](https://github.com/liuzhuang13/DenseNet/blob/master/models/densenet.lua) |
Expand Down Expand Up @@ -316,14 +317,15 @@ with tf.Session() as sess:

## News 📰

- PNASNetlarge is released, [12 May 2018]().
- The six variants of MobileNetv2 are released, [5 May 2018](https://github.com/taehoonlee/tensornets/commit/fb429b6637f943875249dff50f4bc6220d9d50bf).
- YOLOv3 for COCO and VOC are released, [4 April 2018](https://github.com/taehoonlee/tensornets/commit/d8b2d8a54dc4b775a174035da63561028deb6624).
- Generic object detection models for YOLOv2 and FasterRCNN are released, [26 March 2018](https://github.com/taehoonlee/tensornets/commit/67915e659d2097a96c82ba7740b9e43a8c69858d).

## Future work 🔥

- Add training codes.
- Add image classification models (PolyNet, PNASNet).
- Add image classification models (PolyNet).
- Add object detection models (MaskRCNN, SSD).
- Add image segmentation models (FCN, UNet).
- Add image datasets (COCO, OpenImages).
Expand Down
1 change: 1 addition & 0 deletions tensornets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from .nasnets import NASNetAlarge
from .nasnets import NASNetAmobile
from .nasnets import PNASNetlarge

from .vggs import VGG16
from .vggs import VGG19
Expand Down
19 changes: 14 additions & 5 deletions tensornets/middles.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ def names_squeezenet():


def names_nasnets(k):
names = ["normal%d/concat:0" % (i + 1) for i in range(k)]
names.insert(k // 3, "reduction%d/concat:0" % (k // 3))
names.insert(k // 3 * 2 + 1, "reduction%d/concat:0" % (k // 3 * 2))
names = []
for i in range(3):
base = sum(k[:i])
names += ["normal%d/concat:0" % (base + j + 1) for j in range(k[i])]
if i < 2:
names += ["reduction%d/concat:0" % (base + k[i])]
return names


Expand Down Expand Up @@ -210,15 +213,21 @@ def direct(model_name):
'nasnetAlarge': (
list(range(145, 371, 45)) + [416] + list(range(466, 692, 45)) + [748] +
list(range(798, 1024, 45)),
names_nasnets(18),
names_nasnets([6, 6, 6]),
-8
),
'nasnetAmobile': (
list(range(145, 281, 45)) + [326] + list(range(376, 512, 45)) + [568] +
list(range(618, 754, 45)),
names_nasnets(12),
names_nasnets([4, 4, 4]),
-6
),
'pnasnetlarge': (
list(range(169, 323, 51)) + [376] + list(range(432, 535, 51)) + [588] +
list(range(644, 747, 51)),
names_nasnets([4, 3, 3]),
-5
),
'vgg16': (
list(range(11, 16, 2)) + list(range(18, 23, 2)) +
list(range(25, 30, 2)),
Expand Down
93 changes: 83 additions & 10 deletions tensornets/nasnets.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
"""Collection of NASNet variants
The reference paper:
The reference papers:
1. Original (a.k.a. NASNet)
- Learning Transferable Architectures for Scalable Image Recognition,
arXiv 2017
- Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le
- https://arxiv.org/abs/1707.07012
2. PNASNet
- Progressive Neural Architecture Search, arXiv 2017
- Chenxi Liu et al.
- https://arxiv.org/abs/1712.00559
The reference implementation:
1. TF Slim
- https://github.com/tensorflow/models/blob/master/research/slim/nets/
nasnet/nasnet.py
nasnet/{nasnet,pnasnet}.py
"""
from __future__ import absolute_import
from __future__ import division
Expand Down Expand Up @@ -66,16 +71,15 @@ def nasnet(x, stem_filters, normals, filters, skip_reduction, use_aux,
x, p = reductionA(x, None, filters * scaling ** (-2), scope='stem1')
x, p = reductionA(x, p, filters * scaling ** (-1), scope='stem2')

for i in range(normals):
x, p = normalA(x, p, filters, scope="normal%d" % (i + 1))
for i in range(1, normals + 1):
x, p = normalA(x, p, filters, scope="normal%d" % i)

x, p0 = reductionA(x, p, filters * scaling,
scope="reduction%d" % normals)
p = p0 if not skip_reduction else p

for i in range(normals):
x, p = normalA(x, p, filters * scaling,
scope="normal%d" % (i + normals + 1))
for i in range(normals + 1, normals * 2 + 1):
x, p = normalA(x, p, filters * scaling, scope="normal%d" % i)

if use_aux is True:
a = aux(x, classes, scope='aux')
Expand All @@ -84,9 +88,8 @@ def nasnet(x, stem_filters, normals, filters, skip_reduction, use_aux,
scope="reduction%d" % (normals * 2))
p = p0 if not skip_reduction else p

for i in range(normals):
x, p = normalA(x, p, filters * scaling ** 2,
scope="normal%d" % (i + normals * 2 + 1))
for i in range(normals * 2 + 1, normals * 3 + 1):
x, p = normalA(x, p, filters * scaling ** 2, scope="normal%d" % i)

x = relu(x, name='relu')
if stem: return x
Expand All @@ -113,6 +116,44 @@ def nasnetAmobile(x, is_training=False, classes=1000,
is_training, classes, stem, scope, reuse)


def pnasnet(x, stem_filters, blocks, filters,
scaling, is_training, classes, stem, scope=None, reuse=None):
x = conv(x, stem_filters, 3, stride=2, padding='VALID', scope='conv0')

x, p = normalP(x, None, filters * scaling ** (-2), stride=2, scope='stem1')
x, p = normalP(x, p, filters * scaling ** (-1), stride=2, scope='stem2')

for i in range(1, blocks + 1):
x, p = normalP(x, p, filters, scope="normal%d" % i)

x, p = normalP(x, p, filters * scaling, stride=2,
scope="reduction%d" % blocks)

for i in range(blocks + 1, blocks * 2):
x, p = normalP(x, p, filters * scaling, scope="normal%d" % i)

x, p = normalP(x, p, filters * scaling ** 2, stride=2,
scope="reduction%d" % (blocks * 2 - 1))

for i in range(blocks * 2, blocks * 3 - 1):
x, p = normalP(x, p, filters * scaling ** 2, scope="normal%d" % i)

x = relu(x, name='relu')
if stem: return x
x = reduce_mean(x, [1, 2], name='avgpool')
x = dropout(x, keep_prob=0.5, scope='dropout')
x = fc(x, classes, scope='logits')
x = softmax(x, name='probs')
return x


@var_scope('pnasnetlarge')
@set_args(__args__)
def pnasnetlarge(x, is_training=False, classes=1000,
stem=False, scope=None, reuse=None):
return pnasnet(x, 96, 4, 216, 2, is_training, classes, stem, scope, reuse)


@var_scope('adjust')
def adjust(p, x, filters, scope=None):
if p is None:
Expand Down Expand Up @@ -173,6 +214,37 @@ def reductionA(x, p, filters, scope=None):
return concat([x2, x3, x5, x4], axis=3, name='concat'), x


@var_scope('pool')
def pool(x, filters, stride, scope=None):
y = max_pool2d(x, 3, stride=stride)
if int(x.shape[-1]) != filters:
y = conv(y, filters, 1, scope='1x1')
return y


@var_scope('normalP')
def normalP(x, p, filters, stride=1, scope=None):
p = adjust(p, x, filters)

h = relu(x)
h = conv(h, filters, 1, scope='1x1')

x1 = add(sconv(p, filters, 5, stride, scope='left1'),
pool(p, filters, stride, scope='right1'), name='add1')
x2 = add(sconv(h, filters, 7, stride, scope='left2'),
pool(h, filters, stride, scope='right2'), name='add2')
x3 = add(sconv(h, filters, 5, stride, scope='left3'),
sconv(h, filters, 3, stride, scope='right3'), name='add3')
x4 = add(sconv(x3, filters, 3, scope='left4'),
pool(h, filters, stride, scope='right4'), name='add4')
x5 = add(
sconv(p, filters, 3, stride, scope='left5'),
conv(relu(h), filters, 1, stride, scope='right5') if stride > 1 else h,
name='add5')

return concat([x1, x2, x3, x4, x5], axis=3, name='concat'), x


@var_scope('aux')
def aux(x, classes, scope=None):
x = relu(x, name='relu1')
Expand All @@ -190,3 +262,4 @@ def aux(x, classes, scope=None):
# Simple alias.
NASNetAlarge = nasnetAlarge
NASNetAmobile = nasnetAmobile
PNASNetlarge = pnasnetlarge
1 change: 1 addition & 0 deletions tensornets/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def faster_rcnn_preprocess(x):
'wideresnet50': wrn_preprocess,
'nasnetAlarge': tfslim_preprocess,
'nasnetAmobile': tfslim_preprocess,
'pnasnetlarge': tfslim_preprocess,
'vgg16': keras_resnet_preprocess,
'vgg19': keras_resnet_preprocess,
'densenet': fb_preprocess,
Expand Down
12 changes: 12 additions & 0 deletions tensornets/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,17 @@ def load_nasnetAmobile(scopes, return_fn=_assign):
return return_fn(scopes, values)


def load_pnasnetlarge(scopes, return_fn=_assign):
"""Converted from the [TF Slim][2]."""
filename = 'pnasnet_large.npz'
weights_path = get_file(
filename, __model_url__ + 'nasnet/' + filename,
cache_subdir='models',
file_hash='a1afc7679b950b643865aa7ea716c901')
values = parse_weights(weights_path)
return return_fn(scopes, values)


def load_vgg16(scopes, return_fn=_assign):
"""Copied from [Keras][3]."""
filename = 'vgg16_weights_tf_dim_ordering_tf_kernels.h5'
Expand Down Expand Up @@ -761,6 +772,7 @@ def load_ref_faster_rcnn_vgg16_voc(scopes, return_fn=_assign):
'wideresnet50': load_wideresnet50,
'nasnetAlarge': load_nasnetAlarge,
'nasnetAmobile': load_nasnetAmobile,
'pnasnetlarge': load_pnasnetlarge,
'vgg16': load_vgg16,
'vgg19': load_vgg19,
'densenet121': load_densenet121,
Expand Down
6 changes: 6 additions & 0 deletions tests/basics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@
marks=pytest.mark.xfail(
LooseVersion(tf.__version__) < LooseVersion('1.3.0'),
reason='NASNetAmobile requires TensorFlow >= 1.3.0')),
pytest.param(
nets.PNASNetlarge, (331, 331, 3),
marks=pytest.mark.xfail(
LooseVersion(tf.__version__) < LooseVersion('1.3.0'),
reason='PNASNetlarge requires TensorFlow >= 1.3.0')),
random.choice([
(nets.VGG16, (224, 224, 3)),
(nets.VGG19, (224, 224, 3)),
Expand Down Expand Up @@ -87,6 +92,7 @@
'InceptionResNet2',
'NASNetAlarge',
'NASNetAmobile',
'PNASNetlarge',
'VGG',
'DenseNet',
'MobileNet',
Expand Down
63 changes: 63 additions & 0 deletions translations/tfslim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import numpy as np
import tensorflow as tf
import tensornets as nets
import tensorflow_hub as hub

models_list = [
(nets.MobileNet35v2, (224, 224, 3), 'mobilenet_v2_035_224'),
(nets.MobileNet50v2, (224, 224, 3), 'mobilenet_v2_050_224'),
(nets.MobileNet75v2, (224, 224, 3), 'mobilenet_v2_075_224'),
(nets.MobileNet100v2, (224, 224, 3), 'mobilenet_v2_100_224'),
(nets.MobileNet130v2, (224, 224, 3), 'mobilenet_v2_130_224'),
(nets.MobileNet140v2, (224, 224, 3), 'mobilenet_v2_140_224'),
(nets.PNASNetlarge, (331, 331, 3), 'pnasnet_large'),
]

url = 'https://tfhub.dev/google/imagenet'


for (net, shape, model_name) in models_list:

with tf.Graph().as_default():

inputs = tf.placeholder(tf.float32, [None] + list(shape))
model = net(inputs, scope='a')

tfhub = hub.Module("%s/%s/classification/1" % (url, model_name))
features = tfhub(inputs, signature="image_classification",
as_dict=True)
model_tfhub = tf.nn.softmax(features['default'])

img = nets.utils.load_img('cat.png',
target_size=int(shape[0] * 8 / 7),
crop_size=shape[0])

with tf.Session() as sess:

# Retrieve values
sess.run(tf.global_variables_initializer())
weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
scope='module')
values = sess.run(weights)
for i in range(-2, 0):
values[i] = np.delete(np.squeeze(values[i]), 0, axis=-1)

# Adjust the order of the values to cover TF < 1.4.0
names = [w.name[2:] for w in model.get_weights()]
for i in range(len(names) - 1):
if 'gamma:0' in names[i] and 'beta:0' in names[i + 1]:
names[i], names[i + 1] = names[i + 1], names[i]
values[i], values[i + 1] = values[i + 1], values[i]

# Save the values as the TensorNets format
np.savez(model_name, names=names, values=values)

# Load and set the values
weights = model.get_weights()
values = nets.utils.parse_weights(model_name + '.npz')
sess.run([w.assign(v) for (w, v) in zip(weights, values)])

# Run equivalence tests
preds = sess.run(model, {inputs: model.preprocess(img)})
preds_tfhub = sess.run(model_tfhub, {inputs: img / 255.})
np.testing.assert_allclose(preds, preds_tfhub[:, 1:], atol=1e-4)
44 changes: 0 additions & 44 deletions translations/tfslim_mobilenets.py

This file was deleted.

0 comments on commit e2e0f0f

Please sign in to comment.