From 6451f55d58c4580908c17086012819f059fbcae9 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 3 Jan 2022 19:45:23 -0800 Subject: [PATCH 1/2] Global export sort --- detect.py | 6 +- export.py | 148 +++++++++++++++++++++++------------------------ models/common.py | 68 +++++++++++----------- val.py | 6 +- 4 files changed, 114 insertions(+), 114 deletions(-) diff --git a/detect.py b/detect.py index 6aa5b825da48..a4a1fb69b42e 100644 --- a/detect.py +++ b/detect.py @@ -15,13 +15,13 @@ $ python path/to/detect.py --weights yolov5s.pt # PyTorch yolov5s.torchscript # TorchScript yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn - yolov5s.mlmodel # CoreML (under development) yolov5s.xml # OpenVINO + yolov5s.engine # TensorRT + yolov5s.mlmodel # CoreML (under development) yolov5s_saved_model # TensorFlow SavedModel - yolov5s.pb # TensorFlow protobuf + yolov5s.pb # TensorFlow GraphDef yolov5s.tflite # TensorFlow Lite yolov5s_edgetpu.tflite # TensorFlow Edge TPU - yolov5s.engine # TensorRT """ import argparse diff --git a/export.py b/export.py index fa40864ac378..3b677d2ca144 100644 --- a/export.py +++ b/export.py @@ -2,19 +2,19 @@ """ Export a YOLOv5 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit -Format | Example | `--include ...` argument +Format | `export.py --include` | Model --- | --- | --- -PyTorch | yolov5s.pt | - -TorchScript | yolov5s.torchscript | `torchscript` -ONNX | yolov5s.onnx | `onnx` -CoreML | yolov5s.mlmodel | `coreml` -OpenVINO | yolov5s_openvino_model/ | `openvino` -TensorFlow SavedModel | yolov5s_saved_model/ | `saved_model` -TensorFlow GraphDef | yolov5s.pb | `pb` -TensorFlow Lite | yolov5s.tflite | `tflite` -TensorFlow Edge TPU | yolov5s_edgetpu.tflite | `edgetpu` -TensorFlow.js | yolov5s_web_model/ | `tfjs` -TensorRT | yolov5s.engine | `engine` +PyTorch | - | yolov5s.pt +TorchScript | `torchscript` | yolov5s.torchscript +ONNX | `onnx` | yolov5s.onnx +OpenVINO | `openvino` | yolov5s_openvino_model/ +TensorRT | `engine` | yolov5s.engine +CoreML | `coreml` | yolov5s.mlmodel +TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/ +TensorFlow GraphDef | `pb` | yolov5s.pb +TensorFlow Lite | `tflite` | yolov5s.tflite +TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite +TensorFlow.js | `tfjs` | yolov5s_web_model/ Usage: $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml openvino saved_model tflite tfjs @@ -23,13 +23,13 @@ $ python path/to/detect.py --weights yolov5s.pt # PyTorch yolov5s.torchscript # TorchScript yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn - yolov5s.mlmodel # CoreML (under development) yolov5s.xml # OpenVINO + yolov5s.engine # TensorRT + yolov5s.mlmodel # CoreML (under development) yolov5s_saved_model # TensorFlow SavedModel - yolov5s.pb # TensorFlow protobuf + yolov5s.pb # TensorFlow GraphDef yolov5s.tflite # TensorFlow Lite yolov5s_edgetpu.tflite # TensorFlow Edge TPU - yolov5s.engine # TensorRT TensorFlow.js: $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example @@ -126,6 +126,23 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst LOGGER.info(f'{prefix} export failure: {e}') +def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')): + # YOLOv5 OpenVINO export + try: + check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/ + import openvino.inference_engine as ie + + LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...') + f = str(file).replace('.pt', '_openvino_model' + os.sep) + + cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f}" + subprocess.check_output(cmd, shell=True) + + LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') + except Exception as e: + LOGGER.info(f'\n{prefix} export failure: {e}') + + def export_coreml(model, im, file, prefix=colorstr('CoreML:')): # YOLOv5 CoreML export ct_model = None @@ -148,27 +165,57 @@ def export_coreml(model, im, file, prefix=colorstr('CoreML:')): return ct_model -def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')): - # YOLOv5 OpenVINO export +def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): + # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt try: - check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/ - import openvino.inference_engine as ie + check_requirements(('tensorrt',)) + import tensorrt as trt - LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...') - f = str(file).replace('.pt', '_openvino_model' + os.sep) + opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x + export_onnx(model, im, file, opset, train, False, simplify) + onnx = file.with_suffix('.onnx') + assert onnx.exists(), f'failed to export ONNX file: {onnx}' - cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f}" - subprocess.check_output(cmd, shell=True) + LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') + f = file.with_suffix('.engine') # TensorRT engine file + logger = trt.Logger(trt.Logger.INFO) + if verbose: + logger.min_severity = trt.Logger.Severity.VERBOSE + + builder = trt.Builder(logger) + config = builder.create_builder_config() + config.max_workspace_size = workspace * 1 << 30 + + flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + network = builder.create_network(flag) + parser = trt.OnnxParser(network, logger) + if not parser.parse_from_file(str(onnx)): + raise RuntimeError(f'failed to load ONNX file: {onnx}') + + inputs = [network.get_input(i) for i in range(network.num_inputs)] + outputs = [network.get_output(i) for i in range(network.num_outputs)] + LOGGER.info(f'{prefix} Network Description:') + for inp in inputs: + LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}') + for out in outputs: + LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}') + half &= builder.platform_has_fast_fp16 + LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}') + if half: + config.set_flag(trt.BuilderFlag.FP16) + with builder.build_engine(network, config) as engine, open(f, 'wb') as t: + t.write(engine.serialize()) LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') + except Exception as e: LOGGER.info(f'\n{prefix} export failure: {e}') def export_saved_model(model, im, file, dynamic, tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45, - conf_thres=0.25, prefix=colorstr('TensorFlow saved_model:')): - # YOLOv5 TensorFlow saved_model export + conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')): + # YOLOv5 TensorFlow SavedModel export keras_model = None try: import tensorflow as tf @@ -304,53 +351,6 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): LOGGER.info(f'\n{prefix} export failure: {e}') -def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): - # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt - try: - check_requirements(('tensorrt',)) - import tensorrt as trt - - opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x - export_onnx(model, im, file, opset, train, False, simplify) - onnx = file.with_suffix('.onnx') - assert onnx.exists(), f'failed to export ONNX file: {onnx}' - - LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') - f = file.with_suffix('.engine') # TensorRT engine file - logger = trt.Logger(trt.Logger.INFO) - if verbose: - logger.min_severity = trt.Logger.Severity.VERBOSE - - builder = trt.Builder(logger) - config = builder.create_builder_config() - config.max_workspace_size = workspace * 1 << 30 - - flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) - network = builder.create_network(flag) - parser = trt.OnnxParser(network, logger) - if not parser.parse_from_file(str(onnx)): - raise RuntimeError(f'failed to load ONNX file: {onnx}') - - inputs = [network.get_input(i) for i in range(network.num_inputs)] - outputs = [network.get_output(i) for i in range(network.num_outputs)] - LOGGER.info(f'{prefix} Network Description:') - for inp in inputs: - LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}') - for out in outputs: - LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}') - - half &= builder.platform_has_fast_fp16 - LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}') - if half: - config.set_flag(trt.BuilderFlag.FP16) - with builder.build_engine(network, config) as engine, open(f, 'wb') as t: - t.write(engine.serialize()) - LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') - - except Exception as e: - LOGGER.info(f'\n{prefix} export failure: {e}') - - @torch.no_grad() def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' weights=ROOT / 'yolov5s.pt', # weights path @@ -417,12 +417,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' export_torchscript(model, im, file, optimize) if ('onnx' in include) or ('openvino' in include): # OpenVINO requires ONNX export_onnx(model, im, file, opset, train, dynamic, simplify) + if 'openvino' in include: + export_openvino(model, im, file) if 'engine' in include: export_engine(model, im, file, train, half, simplify, workspace, verbose) if 'coreml' in include: export_coreml(model, im, file) - if 'openvino' in include: - export_openvino(model, im, file) # TensorFlow Exports if any(tf_exports): diff --git a/models/common.py b/models/common.py index 519ce611d7ef..37e84b527141 100644 --- a/models/common.py +++ b/models/common.py @@ -316,17 +316,6 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None): if extra_files['config.txt']: d = json.loads(extra_files['config.txt']) # extra_files dict stride, names = int(d['stride']), d['names'] - elif coreml: # CoreML - LOGGER.info(f'Loading {w} for CoreML inference...') - import coremltools as ct - model = ct.models.MLModel(w) - elif xml: # OpenVINO - LOGGER.info(f'Loading {w} for OpenVINO inference...') - check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/ - import openvino.inference_engine as ie - core = ie.IECore() - network = core.read_network(model=w, weights=Path(w).with_suffix('.bin')) # *.xml, *.bin paths - executable_network = core.load_network(network, device_name='CPU', num_requests=1) elif dnn: # ONNX OpenCV DNN LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...') check_requirements(('opencv-python>=4.5.4',)) @@ -338,6 +327,13 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None): import onnxruntime providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider'] session = onnxruntime.InferenceSession(w, providers=providers) + elif xml: # OpenVINO + LOGGER.info(f'Loading {w} for OpenVINO inference...') + check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/ + import openvino.inference_engine as ie + core = ie.IECore() + network = core.read_network(model=w, weights=Path(w).with_suffix('.bin')) # *.xml, *.bin paths + executable_network = core.load_network(network, device_name='CPU', num_requests=1) elif engine: # TensorRT LOGGER.info(f'Loading {w} for TensorRT inference...') import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download @@ -356,9 +352,17 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None): binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) context = model.create_execution_context() batch_size = bindings['images'].shape[0] - else: # TensorFlow (TFLite, pb, saved_model) - if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt - LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...') + elif coreml: # CoreML + LOGGER.info(f'Loading {w} for CoreML inference...') + import coremltools as ct + model = ct.models.MLModel(w) + else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) + if saved_model: # SavedModel + LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...') + import tensorflow as tf + model = tf.keras.models.load_model(w) + elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt + LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...') import tensorflow as tf def wrap_frozen_graph(gd, inputs, outputs): @@ -369,19 +373,15 @@ def wrap_frozen_graph(gd, inputs, outputs): graph_def = tf.Graph().as_graph_def() graph_def.ParseFromString(open(w, 'rb').read()) frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0") - elif saved_model: - LOGGER.info(f'Loading {w} for TensorFlow saved_model inference...') - import tensorflow as tf - model = tf.keras.models.load_model(w) elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python - if 'edgetpu' in w.lower(): + if 'edgetpu' in w.lower(): # Edge TPU LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...') import tflite_runtime.interpreter as tfli delegate = {'Linux': 'libedgetpu.so.1', # install https://coral.ai/software/#edgetpu-runtime 'Darwin': 'libedgetpu.1.dylib', 'Windows': 'edgetpu.dll'}[platform.system()] interpreter = tfli.Interpreter(model_path=w, experimental_delegates=[tfli.load_delegate(delegate)]) - else: + else: # Lite LOGGER.info(f'Loading {w} for TensorFlow Lite inference...') import tensorflow as tf interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model @@ -396,21 +396,13 @@ def forward(self, im, augment=False, visualize=False, val=False): if self.pt or self.jit: # PyTorch y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize) return y if val else y[0] - elif self.coreml: # CoreML - im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) - im = Image.fromarray((im[0] * 255).astype('uint8')) - # im = im.resize((192, 320), Image.ANTIALIAS) - y = self.model.predict({'image': im}) # coordinates are xywh normalized - box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels - conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float) - y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1) - elif self.onnx: # ONNX + elif self.dnn: # ONNX OpenCV DNN + im = im.cpu().numpy() # torch to numpy + self.net.setInput(im) + y = self.net.forward() + elif self.onnx: # ONNX Runtime im = im.cpu().numpy() # torch to numpy - if self.dnn: # ONNX OpenCV DNN - self.net.setInput(im) - y = self.net.forward() - else: # ONNX Runtime - y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0] + y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0] elif self.xml: # OpenVINO im = im.cpu().numpy() # FP32 desc = self.ie.TensorDesc(precision='FP32', dims=im.shape, layout='NCHW') # Tensor Description @@ -423,6 +415,14 @@ def forward(self, im, augment=False, visualize=False, val=False): self.binding_addrs['images'] = int(im.data_ptr()) self.context.execute_v2(list(self.binding_addrs.values())) y = self.bindings['output'].data + elif self.coreml: # CoreML + im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) + im = Image.fromarray((im[0] * 255).astype('uint8')) + # im = im.resize((192, 320), Image.ANTIALIAS) + y = self.model.predict({'image': im}) # coordinates are xywh normalized + box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels + conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float) + y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1) else: # TensorFlow model (TFLite, pb, saved_model) im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) if self.pb: diff --git a/val.py b/val.py index 704a7a46eb38..4d707f62bffa 100644 --- a/val.py +++ b/val.py @@ -9,13 +9,13 @@ $ python path/to/val.py --weights yolov5s.pt # PyTorch yolov5s.torchscript # TorchScript yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn - yolov5s.mlmodel # CoreML (under development) yolov5s.xml # OpenVINO + yolov5s.engine # TensorRT + yolov5s.mlmodel # CoreML (under development) yolov5s_saved_model # TensorFlow SavedModel - yolov5s.pb # TensorFlow protobuf + yolov5s.pb # TensorFlow GraphDef yolov5s.tflite # TensorFlow Lite yolov5s_edgetpu.tflite # TensorFlow Edge TPU - yolov5s.engine # TensorRT """ import argparse From 72644026d3d06ff3da35ee16525d7741763af984 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 3 Jan 2022 20:00:00 -0800 Subject: [PATCH 2/2] Cleanup --- models/common.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/models/common.py b/models/common.py index 37e84b527141..284dd2bb3af0 100644 --- a/models/common.py +++ b/models/common.py @@ -423,13 +423,13 @@ def forward(self, im, augment=False, visualize=False, val=False): box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float) y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1) - else: # TensorFlow model (TFLite, pb, saved_model) + else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) - if self.pb: - y = self.frozen_func(x=self.tf.constant(im)).numpy() - elif self.saved_model: + if self.saved_model: # SavedModel y = self.model(im, training=False).numpy() - elif self.tflite: + elif self.pb: # GraphDef + y = self.frozen_func(x=self.tf.constant(im)).numpy() + elif self.tflite: # Lite input, output = self.input_details[0], self.output_details[0] int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model if int8: @@ -451,7 +451,7 @@ def forward(self, im, augment=False, visualize=False, val=False): def warmup(self, imgsz=(1, 3, 640, 640), half=False): # Warmup model by running inference once - if self.pt or self.engine or self.onnx: # warmup types + if self.pt or self.jit or self.onnx or self.engine: # warmup types if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float) # input image self.forward(im) # warmup