diff --git a/export.py b/export.py
index cdf5dcddd07a..262b11a1a268 100644
--- a/export.py
+++ b/export.py
@@ -15,6 +15,7 @@
 TensorFlow Lite               | `tflite`                      | yolov5s.tflite
 TensorFlow Edge TPU           | `edgetpu`                     | yolov5s_edgetpu.tflite
 TensorFlow.js                 | `tfjs`                        | yolov5s_web_model/
+PaddlePaddle                  | `paddle`                      | yolov5s_paddle_model/
 
 Requirements:
     $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu  # CPU
@@ -54,7 +55,6 @@
 
 import pandas as pd
 import torch
-import yaml
 from torch.utils.mobile_optimizer import optimize_for_mobile
 
 FILE = Path(__file__).resolve()
@@ -68,7 +68,7 @@
 from models.yolo import ClassificationModel, Detect
 from utils.dataloaders import LoadImages
 from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
-                           check_yaml, colorstr, file_size, get_default_args, print_args, url2file)
+                           check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
 from utils.torch_utils import select_device, smart_inference_mode
 
 
@@ -85,7 +85,8 @@ def export_formats():
         ['TensorFlow GraphDef', 'pb', '.pb', True, True],
         ['TensorFlow Lite', 'tflite', '.tflite', True, False],
         ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
-        ['TensorFlow.js', 'tfjs', '_web_model', False, False],]
+        ['TensorFlow.js', 'tfjs', '_web_model', False, False],
+        ['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
     return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
 
 
@@ -180,7 +181,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
 
 
 @try_export
-def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
+def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):
     # YOLOv5 OpenVINO export
     check_requirements('openvino-dev')  # requires openvino-dev: https://pypi.org/project/openvino-dev/
     import openvino.inference_engine as ie
@@ -189,9 +190,23 @@
     f = str(file).replace('.pt', f'_openvino_model{os.sep}')
 
     cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
-    subprocess.check_output(cmd.split())  # export
-    with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g:
-        yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g)  # add metadata.yaml
+    subprocess.run(cmd.split(), check=True, env=os.environ)  # export
+    yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata)  # add metadata.yaml
+    return f, None
+
+
+@try_export
+def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')):
+    # YOLOv5 Paddle export
+    check_requirements(('paddlepaddle', 'x2paddle'))
+    import x2paddle
+    from x2paddle.convert import pytorch2paddle
+
+    LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
+    f = str(file).replace('.pt', f'_paddle_model{os.sep}')
+
+    pytorch2paddle(module=model, save_dir=f, jit_type='trace', input_examples=[im])  # export
+    yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata)  # add metadata.yaml
     return f, None
 
 
@@ -464,7 +479,7 @@ def run(
     fmts = tuple(export_formats()['Argument'][1:])  # --include arguments
     flags = [x in include for x in fmts]
     assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
-    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags  # export booleans
+    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags  # export booleans
     file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights
 
     # Load PyTorch model
@@ -497,47 +512,48 @@ def run(
     if half and not coreml:
         im, model = im.half(), model.half()  # to FP16
     shape = tuple((y[0] if isinstance(y, tuple) else y).shape)  # model output shape
+    metadata = {'stride': int(max(model.stride)), 'names': model.names}  # model metadata
     LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
 
     # Exports
-    f = [''] * 10  # exported filenames
+    f = [''] * len(fmts)  # exported filenames
     warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
-    if jit:
+    if jit:  # TorchScript
         f[0], _ = export_torchscript(model, im, file, optimize)
     if engine:  # TensorRT required before ONNX
         f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
     if onnx or xml:  # OpenVINO requires ONNX
         f[2], _ = export_onnx(model, im, file, opset, train, dynamic, simplify)
     if xml:  # OpenVINO
-        f[3], _ = export_openvino(model, file, half)
-    if coreml:
+        f[3], _ = export_openvino(file, metadata, half)
+    if coreml:  # CoreML
         f[4], _ = export_coreml(model, im, file, int8, half)
-
-    # TensorFlow Exports
-    if any((saved_model, pb, tflite, edgetpu, tfjs)):
+    if any((saved_model, pb, tflite, edgetpu, tfjs)):  # TensorFlow formats
         if int8 or edgetpu:  # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
             check_requirements('flatbuffers==1.12')  # required before `import tensorflow`
         assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
         assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
-        f[5], model = export_saved_model(model.cpu(),
-                                         im,
-                                         file,
-                                         dynamic,
-                                         tf_nms=nms or agnostic_nms or tfjs,
-                                         agnostic_nms=agnostic_nms or tfjs,
-                                         topk_per_class=topk_per_class,
-                                         topk_all=topk_all,
-                                         iou_thres=iou_thres,
-                                         conf_thres=conf_thres,
-                                         keras=keras)
+        f[5], s_model = export_saved_model(model.cpu(),
+                                           im,
+                                           file,
+                                           dynamic,
+                                           tf_nms=nms or agnostic_nms or tfjs,
+                                           agnostic_nms=agnostic_nms or tfjs,
+                                           topk_per_class=topk_per_class,
+                                           topk_all=topk_all,
+                                           iou_thres=iou_thres,
+                                           conf_thres=conf_thres,
+                                           keras=keras)
         if pb or tfjs:  # pb prerequisite to tfjs
-            f[6], _ = export_pb(model, file)
+            f[6], _ = export_pb(s_model, file)
         if tflite or edgetpu:
-            f[7], _ = export_tflite(model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
+            f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
         if edgetpu:
             f[8], _ = export_edgetpu(file)
         if tfjs:
             f[9], _ = export_tfjs(file)
+    if paddle:  # PaddlePaddle
+        f[10], _ = export_paddle(model, im, file, metadata)
 
     # Finish
     f = [str(x) for x in f if x]  # filter out '' and None
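The new `export_paddle()` above is a thin wrapper around X2Paddle's trace-based converter, plus the same metadata YAML (`stride`, class `names`) that the reworked OpenVINO exporter writes. A minimal standalone sketch of that conversion call, using a hypothetical stand-in module and save directory rather than a YOLOv5 checkpoint:

```python
# Minimal sketch of the pytorch2paddle() call used by export_paddle() above.
# `demo_model`, `demo_input` and 'demo_paddle_model' are illustrative stand-ins.
import torch
from x2paddle.convert import pytorch2paddle

demo_model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU()).eval()
demo_input = torch.zeros(1, 3, 640, 640)  # example tensor that drives the trace
pytorch2paddle(module=demo_model, save_dir='demo_paddle_model', jit_type='trace', input_examples=[demo_input])
```

As with OpenVINO, `yaml_save()` then drops a `.yaml` file with the stride and class names next to the converted model so downstream loaders can recover them.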
diff --git a/models/common.py b/models/common.py
index 0e01b60e81e5..396b5de0b505 100644
--- a/models/common.py
+++ b/models/common.py
@@ -320,14 +320,16 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
         #   TensorFlow GraphDef:            *.pb
         #   TensorFlow Lite:                *.tflite
         #   TensorFlow Edge TPU:            *_edgetpu.tflite
+        #   PaddlePaddle:                   *_paddle_model
         from models.experimental import attempt_download, attempt_load  # scoped to avoid circular import
 
         super().__init__()
         w = str(weights[0] if isinstance(weights, list) else weights)
-        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self._model_type(w)  # get backend
+        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = self._model_type(w)  # type
         w = attempt_download(w)  # download if not local
         fp16 &= pt or jit or onnx or engine  # FP16
         stride = 32  # default stride
+        cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
 
         if pt:  # PyTorch
             model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
@@ -351,7 +353,6 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
             net = cv2.dnn.readNetFromONNX(w)
         elif onnx:  # ONNX Runtime
             LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
-            cuda = torch.cuda.is_available() and device.type != 'cpu'
             check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
             import onnxruntime
             providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
@@ -408,48 +409,60 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
             LOGGER.info(f'Loading {w} for CoreML inference...')
             import coremltools as ct
             model = ct.models.MLModel(w)
-        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
-            if saved_model:  # SavedModel
-                LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
-                import tensorflow as tf
-                keras = False  # assume TF1 saved_model
-                model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
-            elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
-                LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
+        elif saved_model:  # TF SavedModel
+            LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
+            import tensorflow as tf
+            keras = False  # assume TF1 saved_model
+            model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
+        elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+            LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
+            import tensorflow as tf
+
+            def wrap_frozen_graph(gd, inputs, outputs):
+                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
+                ge = x.graph.as_graph_element
+                return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
+
+            gd = tf.Graph().as_graph_def()  # TF GraphDef
+            with open(w, 'rb') as f:
+                gd.ParseFromString(f.read())
+            frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
+        elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
+            try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
+                from tflite_runtime.interpreter import Interpreter, load_delegate
+            except ImportError:
                 import tensorflow as tf
-
-                def wrap_frozen_graph(gd, inputs, outputs):
-                    x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
-                    ge = x.graph.as_graph_element
-                    return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
-
-                gd = tf.Graph().as_graph_def()  # graph_def
-                with open(w, 'rb') as f:
-                    gd.ParseFromString(f.read())
-                frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
-            elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
-                try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
-                    from tflite_runtime.interpreter import Interpreter, load_delegate
-                except ImportError:
-                    import tensorflow as tf
-                    Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
-                if edgetpu:  # Edge TPU https://coral.ai/software/#edgetpu-runtime
-                    LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
-                    delegate = {
-                        'Linux': 'libedgetpu.so.1',
-                        'Darwin': 'libedgetpu.1.dylib',
-                        'Windows': 'edgetpu.dll'}[platform.system()]
-                    interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
-                else:  # Lite
-                    LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
-                    interpreter = Interpreter(model_path=w)  # load TFLite model
-                interpreter.allocate_tensors()  # allocate
-                input_details = interpreter.get_input_details()  # inputs
-                output_details = interpreter.get_output_details()  # outputs
-            elif tfjs:
-                raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
-            else:
-                raise NotImplementedError(f'ERROR: {w} is not a supported format')
+                Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
+            if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
+                LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
+                delegate = {
+                    'Linux': 'libedgetpu.so.1',
+                    'Darwin': 'libedgetpu.1.dylib',
+                    'Windows': 'edgetpu.dll'}[platform.system()]
+                interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
+            else:  # TFLite
+                LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
+                interpreter = Interpreter(model_path=w)  # load TFLite model
+            interpreter.allocate_tensors()  # allocate
+            input_details = interpreter.get_input_details()  # inputs
+            output_details = interpreter.get_output_details()  # outputs
+        elif tfjs:  # TF.js
+            raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
+        elif paddle:  # PaddlePaddle
+            LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
+            check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
+            import paddle.inference as pdi
+            if not Path(w).is_file():  # if not *.pdmodel
+                w = next(Path(w).rglob('*.pdmodel'))  # get *.pdmodel file from *_paddle_model dir
+            weights = Path(w).with_suffix('.pdiparams')
+            config = pdi.Config(str(w), str(weights))
+            if cuda:
+                config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
+            predictor = pdi.create_predictor(config)
+            input_names = predictor.get_input_names()
+            input_handle = predictor.get_input_handle(input_names[0])
+        else:
+            raise NotImplementedError(f'ERROR: {w} is not a supported format')
 
         # class names
         if 'names' not in locals():
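The Paddle branch above only builds the `predictor` and its `input_handle`; the matching `forward()` logic follows in the next hunk. Read together, the inference round trip looks like the sketch below, which restates only API calls that appear in the patch itself (the `.pdmodel`/`.pdiparams` paths are illustrative, since the real code discovers them with `rglob`):

```python
# Standalone sketch of the Paddle Inference round trip wired into
# DetectMultiBackend; file paths are illustrative stand-ins.
import numpy as np
import paddle.inference as pdi

config = pdi.Config('yolov5s_paddle_model/model.pdmodel',    # graph
                    'yolov5s_paddle_model/model.pdiparams')  # weights
predictor = pdi.create_predictor(config)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])

im = np.zeros((1, 3, 640, 640), dtype='float32')  # BCHW input, matching forward()
input_handle.copy_from_cpu(im)
predictor.run()
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
y = output_handle.copy_to_cpu()  # predictions as a numpy array
```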
@@ -502,6 +515,13 @@ def forward(self, im, augment=False, visualize=False):
             else:
                 k = 'var_' + str(sorted(int(k.replace('var_', '')) for k in y)[-1])  # output key
                 y = y[k]  # output
+        elif self.paddle:  # PaddlePaddle
+            im = im.cpu().numpy().astype("float32")
+            self.input_handle.copy_from_cpu(im)
+            self.predictor.run()
+            output_names = self.predictor.get_output_names()
+            output_handle = self.predictor.get_output_handle(output_names[0])
+            y = output_handle.copy_to_cpu()
         else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
             im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
             if self.saved_model:  # SavedModel
@@ -542,13 +562,13 @@ def warmup(self, imgsz=(1, 3, 640, 640)):
     def _model_type(p='path/to/model.pt'):
         # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
         from export import export_formats
-        suffixes = list(export_formats().Suffix) + ['.xml']  # export suffixes
-        check_suffix(p, suffixes)  # checks
+        sf = list(export_formats().Suffix) + ['.xml']  # export suffixes
+        check_suffix(p, sf)  # checks
         p = Path(p).name  # eliminate trailing separators
-        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, xml2 = (s in p for s in suffixes)
+        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, xml2 = (s in p for s in sf)
         xml |= xml2  # *_openvino_model or *.xml
         tflite &= not edgetpu  # *.tflite
-        return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
+        return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle
 
     @staticmethod
     def _load_metadata(f=Path('path/to/meta.yaml')):
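`_model_type()` keys entirely off the `Suffix` column of `export_formats()`, so a Paddle export is recognized by its `_paddle_model` directory name. A quick illustration of that substring test, restating the suffix list under the assumption that the untouched head of the table matches the rest of the file:

```python
# Restatement of the _model_type() suffix test: a *_paddle_model path should
# set only the new paddle flag. Suffix order follows export_formats().
from pathlib import Path

suffixes = ['.pt', '.torchscript', '.onnx', '_openvino_model', '.engine', '.mlmodel',
            '_saved_model', '.pb', '.tflite', '_edgetpu.tflite', '_web_model', '_paddle_model']
p = Path('yolov5s_paddle_model').name  # eliminate trailing separators
flags = [s in p for s in suffixes]
assert flags == [False] * 11 + [True]  # only the PaddlePaddle flag is set
```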
diff --git a/utils/benchmarks.py b/utils/benchmarks.py
index d5f4c1d61fbe..9d5c7f2965d5 100644
--- a/utils/benchmarks.py
+++ b/utils/benchmarks.py
@@ -61,7 +61,7 @@ def run(
     device = select_device(device)
     for i, (name, f, suffix, cpu, gpu) in export.export_formats().iterrows():  # index, (name, file, suffix, CPU, GPU)
         try:
-            assert i not in (9, 10), 'inference not supported'  # Edge TPU and TF.js are unsupported
+            assert i not in (9, 10, 11), 'inference not supported'  # Edge TPU, TF.js and Paddle are unsupported
             assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13'  # CoreML
             if 'cpu' in device.type:
                 assert cpu, 'inference not supported on CPU'
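The benchmark skip list relies on the row order of `export_formats()`, which the PaddlePaddle entry extends to index 11. A quick sanity check of that assumption:

```python
# Rows 9-11 of export_formats() are the formats benchmarks.py cannot run:
# TensorFlow Edge TPU, TensorFlow.js and (now) PaddlePaddle.
from export import export_formats

print(export_formats().loc[[9, 10, 11], 'Format'].tolist())
# expected: ['TensorFlow Edge TPU', 'TensorFlow.js', 'PaddlePaddle']
```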