-
Notifications
You must be signed in to change notification settings - Fork 1
/
model_inspect.py
418 lines (362 loc) · 16.1 KB
/
model_inspect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
# Copyright 2020 Google Research. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Tool to inspect a model."""
from __future__ import absolute_import
from __future__ import division
# gtype import
from __future__ import print_function
import os
import time
from absl import flags
from absl import logging
import numpy as np
from PIL import Image
import tensorflow.compat.v1 as tf
from typing import Text, Tuple, List
import det_model_fn
import hparams_config
import inference
import utils
# Command-line flags for model_inspect.
flags.DEFINE_string('model_name', 'efficientdet-d1', 'Model.')
flags.DEFINE_string('logdir', '/tmp/deff/', 'log directory.')
flags.DEFINE_string(
    'runmode', 'dry',
    'Run mode: {dry, freeze, ckpt, bm, infer, saved_model, '
    'saved_model_infer, saved_model_benchmark}')
flags.DEFINE_string('trace_filename', None, 'Trace file name.')
flags.DEFINE_integer('num_classes', 90, 'Number of classes.')
# Adjacent string literals are concatenated verbatim, so each piece carries
# its own trailing space.
flags.DEFINE_string(
    'input_image_size', None,
    'Size of input image. Enter a single integer if the image height is '
    'equal to the width; otherwise, enter two integers separated by an "x", '
    'e.g. "1280x640" if width=1280 and height=640.')
flags.DEFINE_integer('threads', 0, 'Number of threads.')
flags.DEFINE_integer('bm_runs', 20, 'Number of benchmark runs.')
flags.DEFINE_string('tensorrt', None, 'TensorRT mode: {None, FP32, FP16, INT8}')
flags.DEFINE_bool('delete_logdir', True, 'Whether to delete logdir.')
flags.DEFINE_bool('freeze', False, 'Freeze graph.')
flags.DEFINE_bool('xla', False, 'Run with xla optimization.')
flags.DEFINE_string('ckpt_path', None, 'checkpoint dir used for eval.')
flags.DEFINE_string('export_ckpt', None, 'Path for exporting new models.')
flags.DEFINE_bool('enable_ema', True, 'Use ema variables for eval.')
flags.DEFINE_string('data_format', None, 'data format, e.g., channels_last.')
flags.DEFINE_string('input_image', None, 'Input image path for inference.')
flags.DEFINE_string('output_image_dir', None, 'Output dir for inference.')

# For visualization.
flags.DEFINE_integer('line_thickness', None, 'Line thickness for box.')
flags.DEFINE_integer('max_boxes_to_draw', None, 'Max number of boxes to draw.')
flags.DEFINE_float('min_score_thresh', None, 'Score threshold to show box.')

# For saved model.
flags.DEFINE_string('saved_model_dir', '/tmp/saved_model',
                    'Folder path for saved model.')
FLAGS = flags.FLAGS
class ModelInspector(object):
  """A simple helper class for inspecting a model.

  Depending on the run mode it builds the detection model, saves or restores
  checkpoints, freezes the graph, benchmarks it (optionally with TensorRT or
  XLA), exports/loads a SavedModel, and runs inference on an image.
  """

  def __init__(self,
               model_name: Text,
               image_size: Text,
               num_classes: int,
               logdir: Text,
               tensorrt: Text = None,
               use_xla: bool = False,
               ckpt_path: Text = None,
               enable_ema: bool = True,
               export_ckpt: Text = None,
               saved_model_dir: Text = None,
               data_format: Text = None):
    """Initializes the inspector.

    Args:
      model_name: model name passed to hparams_config and det_model_fn,
        e.g. 'efficientdet-d1'.
      image_size: None to use the model config's image size; an int or an
        integer string for a square image; or a 'WIDTHxHEIGHT' string.
      num_classes: number of classes; overrides the model config.
      logdir: directory for tfevent files, checkpoints and graphs.
      tensorrt: TensorRT precision mode (e.g. 'FP32', 'FP16', 'INT8'); any
        falsy value disables TensorRT.
      use_xla: whether to enable XLA JIT compilation when benchmarking.
      ckpt_path: checkpoint directory used for restore/inference.
      enable_ema: whether to use EMA variables.
      export_ckpt: optional path to re-save the restored checkpoint to.
      saved_model_dir: directory to export a SavedModel to / load it from.
      data_format: optional data format override, e.g. 'channels_last'.
    """
    self.model_name = model_name
    self.model_params = hparams_config.get_detection_config(model_name)
    self.logdir = logdir
    self.tensorrt = tensorrt
    self.use_xla = use_xla
    self.ckpt_path = ckpt_path
    self.enable_ema = enable_ema
    self.export_ckpt = export_ckpt
    self.saved_model_dir = saved_model_dir

    # Normalize image_size into a (height, width) tuple.
    if image_size is None:
      default_size = self.model_params.image_size
      image_size = (default_size, default_size)
    elif isinstance(image_size, int):
      image_size = (image_size, image_size)
    elif 'x' in image_size:
      # image_size is in format of WIDTHxHEIGHT.
      width, height = image_size.split('x')
      image_size = (int(height), int(width))
    else:
      # image_size is a single integer, with the same width and height.
      image_size = (int(image_size), int(image_size))

    self.model_overrides = {
        'image_size': image_size,
        'num_classes': num_classes
    }
    if data_format:
      self.model_overrides['data_format'] = data_format

    # A few fixed parameters.
    self.batch_size = 1
    self.num_classes = num_classes
    self.data_format = data_format
    self.inputs_shape = [self.batch_size, image_size[0], image_size[1], 3]
    self.labels_shape = [self.batch_size, self.num_classes]
    self.image_size = image_size

  def build_model(self, inputs: tf.Tensor,
                  is_training: bool = False) -> List[tf.Tensor]:
    """Build the model on `inputs` and print out model stats.

    Also writes the current default graph to a tfevent file in self.logdir
    for TensorBoard.

    Args:
      inputs: input image tensor; callers in this class feed a placeholder of
        shape self.inputs_shape.
      is_training: forwarded to the model as `is_training_bn`.

    Returns:
      A flat list of all class outputs followed by all box outputs.
    """
    logging.info('start building model')
    model_arch = det_model_fn.get_model_arch(self.model_name)
    cls_outputs, box_outputs = model_arch(
        inputs,
        model_name=self.model_name,
        is_training_bn=is_training,
        use_bfloat16=False,
        **self.model_overrides)

    print('backbone+fpn+box params/flops = {:.6f}M, {:.9f}B'.format(
        *utils.num_params_flops()))

    # Write to tfevent for tensorboard. Close the writer (close() also
    # flushes) so repeated build_model calls do not leak open writers.
    train_writer = tf.summary.FileWriter(self.logdir)
    try:
      train_writer.add_graph(tf.get_default_graph())
    finally:
      train_writer.close()

    return list(cls_outputs.values()) + list(box_outputs.values())

  def export_saved_model(self, **kwargs):
    """Export a saved model for inference into self.saved_model_dir.

    Args:
      **kwargs: extra keyword arguments forwarded to inference.ServingDriver.
    """
    tf.enable_resource_variables()
    driver = inference.ServingDriver(
        self.model_name,
        self.ckpt_path,
        enable_ema=self.enable_ema,
        use_xla=self.use_xla,
        data_format=self.data_format,
        **kwargs)
    driver.build(params_override=self.model_overrides)
    driver.export(self.saved_model_dir)

  def saved_model_inference(self, image_path_pattern, output_dir, **kwargs):
    """Perform inference for the given saved model.

    Loads the SavedModel from self.saved_model_dir, runs it on the image at
    `image_path_pattern` and writes visualized results to
    `output_dir/<index>.jpg`.

    Args:
      image_path_pattern: path of the input image (opened with PIL).
      output_dir: existing directory to write the output images to.
      **kwargs: forwarded to both inference.ServingDriver and
        driver.visualize.
    """
    driver = inference.ServingDriver(
        self.model_name,
        self.ckpt_path,
        enable_ema=self.enable_ema,
        use_xla=self.use_xla,
        data_format=self.data_format,
        **kwargs)
    driver.load(self.saved_model_dir)

    raw_images = []
    image = Image.open(image_path_pattern)
    raw_images.append(np.array(image))
    detections_bs = driver.serve_images(raw_images)
    for i, detections in enumerate(detections_bs):
      img = driver.visualize(raw_images[i], detections, **kwargs)
      output_image_path = os.path.join(output_dir, str(i) + '.jpg')
      Image.fromarray(img).save(output_image_path)
      logging.info('writing file to %s', output_image_path)

  def saved_model_benchmark(self, image_path_pattern, **kwargs):
    """Benchmark the saved model in self.saved_model_dir on one image.

    Args:
      image_path_pattern: path of the input image (opened with PIL).
      **kwargs: forwarded to inference.ServingDriver.
    """
    driver = inference.ServingDriver(
        self.model_name,
        self.ckpt_path,
        enable_ema=self.enable_ema,
        use_xla=self.use_xla,
        data_format=self.data_format,
        **kwargs)
    driver.load(self.saved_model_dir)
    raw_images = []
    image = Image.open(image_path_pattern)
    raw_images.append(np.array(image))
    driver.benchmark(raw_images, FLAGS.trace_filename)

  def inference_single_image(self, image_image_path, output_dir, **kwargs):
    """Run checkpoint-based inference on one image via InferenceDriver.

    Args:
      image_image_path: path of the input image.
      output_dir: directory passed to InferenceDriver.inference for outputs.
      **kwargs: forwarded to InferenceDriver.inference.
    """
    driver = inference.InferenceDriver(self.model_name, self.ckpt_path,
                                       self.image_size, self.num_classes,
                                       self.enable_ema, self.data_format)
    driver.inference(image_image_path, output_dir, **kwargs)

  def build_and_save_model(self):
    """Build the model, run it once and save it into self.logdir.

    Writes a checkpoint named after the model and a serialized GraphDef
    '<model_name>_train.pb' into self.logdir.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      inputs = tf.placeholder(tf.float32, name='input', shape=self.inputs_shape)
      outputs = self.build_model(inputs, is_training=False)

      # Run a single forward pass on random data.
      inputs_val = np.random.rand(*self.inputs_shape).astype(float)
      sess.run(tf.global_variables_initializer())
      sess.run(outputs, feed_dict={inputs: inputs_val})

      all_saver = tf.train.Saver(save_relative_paths=True)
      all_saver.save(sess, os.path.join(self.logdir, self.model_name))

      tf_graph = os.path.join(self.logdir, self.model_name + '_train.pb')
      with tf.io.gfile.GFile(tf_graph, 'wb') as f:
        f.write(sess.graph_def.SerializeToString())

  def restore_model(self, sess, ckpt_path, enable_ema=True, export_ckpt=None):
    """Restore variables from a given checkpoint.

    Args:
      sess: session whose graph already contains the model.
      ckpt_path: directory searched with tf.train.latest_checkpoint.
      enable_ema: if True, restore through an ExponentialMovingAverage name
        map built from utils.get_ema_vars(); otherwise restore those
        variables directly.
      export_ckpt: optional path; if set, the restored variables are saved
        there (running the EMA apply op first when enable_ema is True).

    Raises:
      ValueError: if no checkpoint is found under ckpt_path.
    """
    checkpoint = tf.train.latest_checkpoint(ckpt_path)
    if checkpoint is None:
      raise ValueError('No checkpoint found in {}'.format(ckpt_path))

    if enable_ema:
      ema = tf.train.ExponentialMovingAverage(decay=0.0)
      ema_vars = utils.get_ema_vars()
      var_dict = ema.variables_to_restore(ema_vars)
      ema_assign_op = ema.apply(ema_vars)
    else:
      var_dict = utils.get_ema_vars()
      ema_assign_op = None

    tf.train.get_or_create_global_step()
    # Initialize only after ema.apply() and the global step have created
    # their variables, so variables not covered by the restore get a value.
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(var_dict, max_to_keep=1)
    saver.restore(sess, checkpoint)

    if export_ckpt:
      print('export model to {}'.format(export_ckpt))
      if ema_assign_op is not None:
        sess.run(ema_assign_op)
      saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)
      saver.save(sess, export_ckpt)

  def eval_ckpt(self):
    """Build the model and restore it from self.ckpt_path.

    If self.export_ckpt is set, the restored model is re-saved there.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      inputs = tf.placeholder(tf.float32, name='input', shape=self.inputs_shape)
      self.build_model(inputs, is_training=False)
      self.restore_model(
          sess, self.ckpt_path, self.enable_ema, self.export_ckpt)

  def freeze_model(self) -> tf.GraphDef:
    """Freeze the model into a GraphDef with variables turned into constants.

    Rebuilds the model, restores the latest checkpoint in self.logdir (such
    as the one written by build_and_save_model), and converts variables to
    constants.

    Returns:
      The frozen GraphDef, keeping the model's output nodes.

    Raises:
      ValueError: if no checkpoint is found in self.logdir.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      inputs = tf.placeholder(tf.float32, name='input', shape=self.inputs_shape)
      outputs = self.build_model(inputs, is_training=False)

      checkpoint = tf.train.latest_checkpoint(self.logdir)
      logging.info('Loading checkpoint: %s', checkpoint)
      if checkpoint is None:
        raise ValueError('No checkpoint found in {}'.format(self.logdir))
      saver = tf.train.Saver()

      # Restore the Variables from the checkpoint and freeze the Graph.
      saver.restore(sess, checkpoint)

      output_node_names = [node.name.split(':')[0] for node in outputs]
      graphdef = tf.graph_util.convert_variables_to_constants(
          sess, sess.graph_def, output_node_names)

    return graphdef

  def benchmark_model(self, warmup_runs, bm_runs, num_threads,
                      trace_filename=None):
    """Benchmark model latency on random input and print statistics.

    Args:
      warmup_runs: number of untimed warm-up runs.
      bm_runs: number of timed runs.
      num_threads: intra-op thread count; values <= 0 use TF's default.
      trace_filename: if set, a Chrome trace of one extra run in the middle
        of the benchmark is written to this path.
    """
    graphdef = None
    if self.tensorrt:
      print('Using tensorrt ', self.tensorrt)
      self.build_and_save_model()
      graphdef = self.freeze_model()

    if num_threads > 0:
      print('num_threads for benchmarking: {}'.format(num_threads))
      sess_config = tf.ConfigProto(
          intra_op_parallelism_threads=num_threads,
          inter_op_parallelism_threads=1)
    else:
      sess_config = tf.ConfigProto()

    # 2 == rewriter_config_pb2.RewriterConfig.OFF
    sess_config.graph_options.rewrite_options.dependency_optimization = 2
    if self.use_xla:
      sess_config.graph_options.optimizer_options.global_jit_level = (
          tf.OptimizerOptions.ON_2)

    with tf.Graph().as_default(), tf.Session(config=sess_config) as sess:
      inputs = tf.placeholder(tf.float32, name='input', shape=self.inputs_shape)
      output = self.build_model(inputs, is_training=False)

      img = np.random.uniform(size=self.inputs_shape)

      sess.run(tf.global_variables_initializer())
      if self.tensorrt:
        fetches = [inputs.name] + [t.name for t in output]
        goutput = self.convert_tr(graphdef, fetches)
        inputs, output = goutput[0], goutput[1:]

      if not self.use_xla:
        # Group outputs only without XLA: XLA removes the whole graph for
        # tf.group, so with XLA the raw output list is fetched instead.
        output = tf.group(*output)

      for i in range(warmup_runs):
        start_time = time.time()
        sess.run(output, feed_dict={inputs: img})
        print('Warm up: {} {:.4f}s'.format(i, time.time() - start_time))

      print('Start benchmark runs total={}'.format(bm_runs))
      timev = []
      for i in range(bm_runs):
        if trace_filename and i == (bm_runs // 2):
          # Profile one extra run and dump it as a Chrome trace.
          run_options = tf.RunOptions()
          run_options.trace_level = tf.RunOptions.FULL_TRACE
          run_metadata = tf.RunMetadata()
          sess.run(output, feed_dict={inputs: img},
                   options=run_options, run_metadata=run_metadata)
          logging.info('Dumping trace to %s', trace_filename)
          trace_dir = os.path.dirname(trace_filename)
          # A bare filename has no directory part; avoid makedirs('').
          if trace_dir and not tf.io.gfile.exists(trace_dir):
            tf.io.gfile.makedirs(trace_dir)
          from tensorflow.python.client import timeline  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
          trace = timeline.Timeline(step_stats=run_metadata.step_stats)
          with tf.io.gfile.GFile(trace_filename, 'w') as trace_file:
            trace_file.write(
                trace.generate_chrome_trace_format(show_memory=True))

        start_time = time.time()
        sess.run(output, feed_dict={inputs: img})
        timev.append(time.time() - start_time)

      if not timev:
        print('No benchmark runs were timed (bm_runs={}).'.format(bm_runs))
        return

      timev.sort()
      # Drop the 2 fastest and 2 slowest runs as outliers, but only when
      # samples remain afterwards; otherwise the stats would be NaN.
      if len(timev) > 4:
        timev = timev[2:-2]
      print('{} {}runs {}threads: mean {:.4f} std {:.4f} min {:.4f} max {:.4f}'
            .format(self.model_name, len(timev), num_threads, np.mean(timev),
                    np.std(timev), np.min(timev), np.max(timev)))
      print('Images per second FPS = {:.1f}'.format(
          self.batch_size / float(np.mean(timev))))

  def convert_tr(self, graph_def, fetches):
    """Convert a frozen graph to TensorRT and import it into the current graph.

    Args:
      graph_def: frozen GraphDef to convert.
      fetches: tensor names to keep out of TensorRT segments and to return.

    Returns:
      The imported tensors corresponding to `fetches`, in the same order.
    """
    from tensorflow.python.compiler.tensorrt import trt  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
    converter = trt.TrtGraphConverter(
        nodes_blacklist=[t.split(':')[0] for t in fetches],
        input_graph_def=graph_def,
        precision_mode=self.tensorrt)
    infer_graph = converter.convert()
    goutput = tf.import_graph_def(infer_graph, return_elements=fetches)
    return goutput

  def run_model(self, runmode, threads=0):
    """Run the model on devices.

    Args:
      runmode: one of 'dry', 'freeze', 'ckpt', 'saved_model_benchmark',
        'infer', 'saved_model', 'saved_model_infer' or 'bm'.
      threads: intra-op thread count for 'bm'; 0 uses TF's default.

    Raises:
      ValueError: if `runmode` is not recognized.
    """
    if runmode == 'dry':
      self.build_and_save_model()
    elif runmode == 'freeze':
      self.build_and_save_model()
      self.freeze_model()
    elif runmode == 'ckpt':
      self.eval_ckpt()
    elif runmode == 'saved_model_benchmark':
      self.saved_model_benchmark(FLAGS.input_image)
    elif runmode in ('infer', 'saved_model', 'saved_model_infer'):
      config_dict = {}
      # Compare with None so explicit zero values (e.g. --min_score_thresh=0)
      # are forwarded instead of silently dropped.
      if FLAGS.line_thickness is not None:
        config_dict['line_thickness'] = FLAGS.line_thickness
      if FLAGS.max_boxes_to_draw is not None:
        config_dict['max_boxes_to_draw'] = FLAGS.max_boxes_to_draw
      if FLAGS.min_score_thresh is not None:
        config_dict['min_score_thresh'] = FLAGS.min_score_thresh

      if runmode == 'infer':
        self.inference_single_image(
            FLAGS.input_image, FLAGS.output_image_dir, **config_dict)
      elif runmode == 'saved_model':
        self.export_saved_model(**config_dict)
      else:  # 'saved_model_infer'
        self.saved_model_inference(
            FLAGS.input_image, FLAGS.output_image_dir, **config_dict)
    elif runmode == 'bm':
      self.benchmark_model(warmup_runs=5, bm_runs=FLAGS.bm_runs,
                           num_threads=threads,
                           trace_filename=FLAGS.trace_filename)
    else:
      raise ValueError('Unknown runmode: {}'.format(runmode))
def main(_):
  """Program entry point invoked by tf.app.run.

  Deletes FLAGS.logdir when it exists and --delete_logdir is set, then builds
  a ModelInspector from the command-line flags and runs FLAGS.runmode.
  """
  logdir_present = tf.io.gfile.exists(FLAGS.logdir)
  if logdir_present and FLAGS.delete_logdir:
    logging.info('Deleting log dir ...')
    tf.io.gfile.rmtree(FLAGS.logdir)

  inspector_kwargs = dict(
      model_name=FLAGS.model_name,
      image_size=FLAGS.input_image_size,
      num_classes=FLAGS.num_classes,
      logdir=FLAGS.logdir,
      tensorrt=FLAGS.tensorrt,
      use_xla=FLAGS.xla,
      ckpt_path=FLAGS.ckpt_path,
      enable_ema=FLAGS.enable_ema,
      export_ckpt=FLAGS.export_ckpt,
      saved_model_dir=FLAGS.saved_model_dir,
      data_format=FLAGS.data_format)
  ModelInspector(**inspector_kwargs).run_model(FLAGS.runmode, FLAGS.threads)
if __name__ == '__main__':
  # Only show warnings and above from absl logging.
  logging.set_verbosity(logging.WARNING)
  # The inspector is written against TF1 graph/session APIs (tf.Session,
  # tf.placeholder), so v2 behavior is disabled before main runs.
  tf.disable_v2_behavior()
  tf.app.run(main)