From 2c2ef25f8bb351b34aef89f8fce75742c698e847 Mon Sep 17 00:00:00 2001 From: Jiacong Fang Date: Sat, 25 Sep 2021 05:18:15 +0800 Subject: [PATCH] TensorFlow.js export enhancements (#4905) * Add arguments to TensorFlow NMS call * Add regex substitution to reorder Identity_* * Delete reorder in docstring * Cleanup * Cleanup2 * Removed `+ \` on string ends (not needed) Co-authored-by: Glenn Jocher --- export.py | 29 +++++++++++++++++++++++++++-- models/tf.py | 2 +- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/export.py b/export.py index e876af234592..d5b63c410af8 100644 --- a/export.py +++ b/export.py @@ -14,7 +14,6 @@ yolov5s.tflite TensorFlow.js: - $ # Edit yolov5s_web_model/model.json to sort Identity* in ascending order $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example $ npm install $ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model @@ -213,16 +212,32 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): # YOLOv5 TensorFlow.js export try: check_requirements(('tensorflowjs',)) + import re import tensorflowjs as tfjs print(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...') f = str(file).replace('.pt', '_web_model') # js dir f_pb = file.with_suffix('.pb') # *.pb path + f_json = f + '/model.json' # *.json path cmd = f"tensorflowjs_converter --input_format=tf_frozen_model " \ f"--output_node_names='Identity,Identity_1,Identity_2,Identity_3' {f_pb} {f}" subprocess.run(cmd, shell=True) + json = open(f_json).read() + with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order + subst = re.sub( + r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, ' + r'"Identity.?.?": {"name": "Identity.?.?"}, ' + r'"Identity.?.?": {"name": "Identity.?.?"}, ' + r'"Identity.?.?": {"name": "Identity.?.?"}}}', + r'{"outputs": {"Identity": {"name": "Identity"}, ' + r'"Identity_1": {"name": "Identity_1"}, ' + r'"Identity_2": {"name": "Identity_2"}, ' + r'"Identity_3": {"name": "Identity_3"}}}', + json) + j.write(subst) + print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') except Exception as e: print(f'\n{prefix} export failure: {e}') @@ -243,6 +258,10 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' dynamic=False, # ONNX/TF: dynamic axes simplify=False, # ONNX: simplify model opset=12, # ONNX: opset version + topk_per_class=100, # TF.js NMS: topk per class to keep + topk_all=100, # TF.js NMS: topk for all classes to keep + iou_thres=0.45, # TF.js NMS: IoU threshold + conf_thres=0.25 # TF.js NMS: confidence threshold ): t = time.time() include = [x.lower() for x in include] @@ -290,7 +309,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' if any(tf_exports): pb, tflite, tfjs = tf_exports[1:] assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.' - model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs) # keras model + model = export_saved_model(model, im, file, dynamic, tf_nms=tfjs, agnostic_nms=tfjs, + topk_per_class=topk_per_class, topk_all=topk_all, conf_thres=conf_thres, + iou_thres=iou_thres) # keras model if pb or tfjs: # pb prerequisite to tfjs export_pb(model, im, file) if tflite: @@ -319,6 +340,10 @@ def parse_opt(): parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes') parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model') parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version') + parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep') + parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep') + parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold') + parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold') parser.add_argument('--include', nargs='+', default=['torchscript', 'onnx'], help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)') diff --git a/models/tf.py b/models/tf.py index 3265b7b75f55..b7d99359c863 100644 --- a/models/tf.py +++ b/models/tf.py @@ -367,7 +367,7 @@ class AgnosticNMS(keras.layers.Layer): # TF Agnostic NMS def call(self, input, topk_all, iou_thres, conf_thres): # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450 - return tf.map_fn(self._nms, input, + return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres), input, fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32), name='agnostic_nms')