Skip to content

Commit

Permalink
1. Add tf slim resnet_v1_101 model into examples.
Browse files Browse the repository at this point in the history
2. Implement pytorch model saver.
  • Loading branch information
kitstar committed Dec 6, 2017
1 parent 52c0386 commit 4e05f1c
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
10 changes: 6 additions & 4 deletions mmdnn/conversion/examples/imagenet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class TestKit(object):
'tensorflow' : {
'vgg19' : [(21, 11.285443), (144, 10.240093), (23, 9.1792336), (22, 8.1113129), (128, 8.1065922)],
'resnet' : [(22, 11.756789), (147, 8.5718527), (24, 6.1751032), (88, 4.3121386), (141, 4.1778097)],
'resnet_v1_101' : [(21, 14.384739), (23, 14.262486), (144, 14.068737), (94, 12.17205), (134, 12.064575)],

This comment has been minimized.

Copy link
@GongXinyuu

GongXinyuu Apr 24, 2018

Hey! I'm wondering what are these magic numbers. Could you please help me? THX!

'inception_v3' : [(22, 9.4921198), (24, 4.0932288), (25, 3.700398), (23, 3.3715961), (147, 3.3620636)],
'mobilenet' : [(22, 16.223597), (24, 14.54775), (147, 13.173758), (145, 11.36431), (728, 11.083847)]
},
Expand Down Expand Up @@ -56,6 +57,7 @@ class TestKit(object):
'vgg19' : lambda path : TestKit.ZeroCenter(path, 224, False),
'inception_v3' : lambda path : TestKit.Standard(path, 299),
'resnet' : lambda path : TestKit.Standard(path, 299),
'resnet_v1_101' : lambda path : TestKit.ZeroCenter(path, 224, False),
'resnet152' : lambda path : TestKit.Standard(path, 299),
'mobilenet' : lambda path : TestKit.Standard(path, 224)
},
Expand Down Expand Up @@ -92,11 +94,11 @@ def __init__(self):
parser.add_argument('-n', type=_text_type, default='kit_imagenet',
help='Network structure file name.')

parser.add_argument('-s', type = _text_type, help = 'Source Framework Type',
choices = ["caffe", "tensorflow", "keras", "cntk", "mxnet"])
parser.add_argument('-s', type=_text_type, help='Source Framework Type',
choices=self.truth.keys())

parser.add_argument('-w',
type = _text_type, help = 'Network weights file name', required = True)
parser.add_argument('-w', type=_text_type, required=True,
help='Network weights file name')

parser.add_argument('--image', '-i',
type = _text_type,
Expand Down
12 changes: 11 additions & 1 deletion mmdnn/conversion/examples/pytorch/imagenet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ def inference(self, image_path):
self.test_truth()


def dump(self, path=None):
if path is None: path = self.args.dump
torch.save(self.model, path)
print('PyTorch model file is saved as [{}], generated by [{}.py] and [{}].'.format(
path, self.args.n, self.args.w))


if __name__=='__main__':
tester = TestTorch()
tester.inference(tester.args.image)
if tester.args.dump:
tester.dump()
else:
tester.inference(tester.args.image)
26 changes: 15 additions & 11 deletions mmdnn/conversion/examples/tensorflow/extract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,23 @@
import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets import vgg
from tensorflow.contrib.slim.python.slim.nets import inception
from tensorflow.contrib.slim.python.slim.nets import resnet_v1
from tensorflow.contrib.slim.python.slim.nets import resnet_v2
from mmdnn.conversion.examples.imagenet_test import TestKit

slim = tf.contrib.slim

input_layer_map = {
'vgg16' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 224, 224, 3]),
'vgg19' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 224, 224, 3]),
'inception_v1' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 224, 224, 3]),
'inception_v2' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 299, 299, 3]),
'inception_v3' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 299, 299, 3]),
'resnet50' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 299, 299, 3]),
'resnet101' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 299, 299, 3]),
'resnet152' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 299, 299, 3]),
'resnet200' : lambda : tf.placeholder(name = 'input', dtype = tf.float32, shape=[None, 299, 299, 3]),
'vgg16' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]),
'vgg19' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]),
'inception_v1' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]),
'inception_v2' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
'inception_v3' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
'resnet50' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
'resnet_v1_101' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 224, 224, 3]),
'resnet101' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
'resnet152' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
'resnet200' : lambda : tf.placeholder(name='input', dtype=tf.float32, shape=[None, 299, 299, 3]),
}

arg_scopes_map = {
Expand All @@ -32,6 +34,7 @@
'inception_v2' : inception.inception_v3_arg_scope,
'inception_v3' : inception.inception_v3_arg_scope,
'resnet50' : resnet_v2.resnet_arg_scope,
'resnet_v1_101' : resnet_v2.resnet_arg_scope,

This comment has been minimized.

Copy link
@GongXinyuu

GongXinyuu Apr 24, 2018

This should be 'resnet_v1_101' : resnet_v1.resnet_arg_scope instead?

'resnet101' : resnet_v2.resnet_arg_scope,
'resnet152' : resnet_v2.resnet_arg_scope,
'resnet200' : resnet_v2.resnet_arg_scope,
Expand All @@ -44,6 +47,7 @@
'inception_v1' : lambda : inception.inception_v1,
'inception_v2' : lambda : inception.inception_v2,
'inception_v3' : lambda : inception.inception_v3,
'resnet_v1_101' : lambda : resnet_v1.resnet_v1_101,
'resnet50' : lambda : resnet_v2.resnet_v2_50,
'resnet101' : lambda : resnet_v2.resnet_v2_101,
'resnet152' : lambda : resnet_v2.resnet_v2_152,
Expand All @@ -65,11 +69,11 @@ def _main():

args = parser.parse_args()

num_classes = 1000 if args.network in ('vgg16', 'vgg19') else 1001
num_classes = 1000 if args.network in ('vgg16', 'vgg19', 'resnet_v1_101') else 1001

with slim.arg_scope(arg_scopes_map[args.network]()):
data_input = input_layer_map[args.network]()
logits, endpoints = networks_map[args.network]()(data_input, num_classes = num_classes, is_training = False)
logits, endpoints = networks_map[args.network]()(data_input, num_classes=num_classes, is_training=False)
labels = tf.squeeze(logits)

init = tf.global_variables_initializer()
Expand Down

0 comments on commit 4e05f1c

Please sign in to comment.