
Add PaddlePaddle export and inference #9240

Merged · 38 commits · Sep 10, 2022
Commits
b740d2f
Add PaddlePaddle Model Export
kisaragychihaya Sep 1, 2022
8b9bbf8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 1, 2022
a82cf77
Cleanup Paddle Export
glenn-jocher Sep 3, 2022
0b02870
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 3, 2022
5be74f3
Update common.py
glenn-jocher Sep 3, 2022
ceb3fa0
Merge branch 'master' into master
glenn-jocher Sep 3, 2022
dce0b6e
Update export.py
glenn-jocher Sep 3, 2022
bad1b3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 3, 2022
dd72df3
Update export.py
glenn-jocher Sep 3, 2022
4b93588
Update export.py
glenn-jocher Sep 3, 2022
83c76cd
Update export.py
glenn-jocher Sep 3, 2022
9246755
Use PyTorch2Paddle
glenn-jocher Sep 3, 2022
8b994cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 3, 2022
953d587
Paddle no longer requires ONNX
glenn-jocher Sep 3, 2022
a776476
Update export.py
glenn-jocher Sep 3, 2022
59973ed
Update export.py
glenn-jocher Sep 3, 2022
f462079
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 3, 2022
19a39e0
Update benchmarks.py
glenn-jocher Sep 3, 2022
102463e
Merge branch 'master' into master
glenn-jocher Sep 3, 2022
6a2056e
Merge branch 'master' into master
glenn-jocher Sep 4, 2022
03f1fd5
Merge branch 'master' into master
glenn-jocher Sep 4, 2022
d6d5571
Merge branch 'master' into master
glenn-jocher Sep 4, 2022
8aae460
Add inference code of PaddlePaddle
kisaragychihaya Sep 6, 2022
cc094f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2022
190291d
Update common.py
kisaragychihaya Sep 6, 2022
b8f6694
Merge branch 'master' into master
glenn-jocher Sep 6, 2022
005cf01
Update common.py
glenn-jocher Sep 6, 2022
9861170
Add paddlepaddle-gpu install if cuda
glenn-jocher Sep 7, 2022
5c4f26f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2022
7b424e6
Update common.py
glenn-jocher Sep 7, 2022
452f5c8
Update common.py
glenn-jocher Sep 7, 2022
cb4b04d
Merge branch 'master' into master
kisaragychihaya Sep 7, 2022
209f645
Merge branch 'master' into master
glenn-jocher Sep 7, 2022
fa0999d
Merge branch 'master' into master
glenn-jocher Sep 9, 2022
afa0d93
Merge branch 'master' into master
glenn-jocher Sep 10, 2022
2ce58c3
Update common.py
glenn-jocher Sep 10, 2022
ea5bd98
Merge branch 'master' into master
glenn-jocher Sep 10, 2022
7004d95
Merge branch 'master' into master
glenn-jocher Sep 10, 2022
export.py: 72 changes (44 additions, 28 deletions)
@@ -15,6 +15,7 @@
TensorFlow Lite | `tflite` | yolov5s.tflite
TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite
TensorFlow.js | `tfjs` | yolov5s_web_model/
PaddlePaddle | `paddle` | yolov5s_paddle_model/

Requirements:
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
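The new `paddle` value slots into `--include` like any other format. A minimal sketch of driving the new path programmatically, equivalent to `python export.py --weights yolov5s.pt --include paddle` (`run()` is defined later in this file):

```python
import export  # this module, run from a YOLOv5 checkout

# Export yolov5s.pt to a yolov5s_paddle_model/ directory (sketch)
export.run(weights='yolov5s.pt', include=('paddle',))
```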
@@ -54,7 +55,6 @@

import pandas as pd
import torch
import yaml
from torch.utils.mobile_optimizer import optimize_for_mobile

FILE = Path(__file__).resolve()
@@ -68,7 +68,7 @@
from models.yolo import ClassificationModel, Detect
from utils.dataloaders import LoadImages
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
check_yaml, colorstr, file_size, get_default_args, print_args, url2file)
check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
from utils.torch_utils import select_device, smart_inference_mode


@@ -85,7 +85,8 @@ def export_formats():
['TensorFlow GraphDef', 'pb', '.pb', True, True],
['TensorFlow Lite', 'tflite', '.tflite', True, False],
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
['TensorFlow.js', 'tfjs', '_web_model', False, False],]
['TensorFlow.js', 'tfjs', '_web_model', False, False],
['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
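With the new row in place, the registry reports PaddlePaddle as exportable on both CPU and GPU. A quick sanity check (sketch, run from a YOLOv5 checkout):

```python
from export import export_formats

df = export_formats()
print(df[df['Argument'] == 'paddle'])
#           Format Argument         Suffix   CPU   GPU
# 11  PaddlePaddle   paddle  _paddle_model  True  True
```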


@@ -180,7 +181,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):


@try_export
def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):
# YOLOv5 OpenVINO export
check_requirements('openvino-dev') # requires openvino-dev: https://pypi.org/project/openvino-dev/
import openvino.inference_engine as ie
@@ -189,9 +190,23 @@ def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
f = str(file).replace('.pt', f'_openvino_model{os.sep}')

cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
subprocess.check_output(cmd.split()) # export
with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g:
yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g) # add metadata.yaml
subprocess.run(cmd.split(), check=True, env=os.environ) # export
yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml
return f, None
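Both this exporter and the new Paddle one below now write their `metadata.yaml` through the `yaml_save` helper imported at the top of the file, replacing the inline `yaml.dump`. A minimal stand-in, assuming the helper simply safe-dumps a dict (the real implementation lives in `utils/general.py` and may differ in detail):

```python
import yaml

def yaml_save(file='data.yaml', data=None):
    # Single-document YAML dump used for the exporters' metadata.yaml (sketch)
    with open(file, 'w') as f:
        yaml.safe_dump(data or {}, f, sort_keys=False)
```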


@try_export
def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')):
# YOLOv5 Paddle export
check_requirements(('paddlepaddle', 'x2paddle'))
import x2paddle
from x2paddle.convert import pytorch2paddle

LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
f = str(file).replace('.pt', f'_paddle_model{os.sep}')

pytorch2paddle(module=model, save_dir=f, jit_type='trace', input_examples=[im]) # export
yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml
return f, None
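The same trace-based conversion can be run standalone; a sketch outside the exporter (the `pytorch2paddle` keyword arguments mirror the call above, and loading via `torch.hub` is illustrative only):

```python
import torch
from x2paddle.convert import pytorch2paddle

model = torch.hub.load('ultralytics/yolov5', 'yolov5s').cpu().eval()  # any traceable module works
im = torch.zeros(1, 3, 640, 640)  # dummy BCHW input used to trace the graph
pytorch2paddle(module=model, save_dir='yolov5s_paddle_model', jit_type='trace', input_examples=[im])
# x2paddle writes a .pdmodel/.pdiparams pair under save_dir, which is why the
# inference backend later locates the graph with rglob('*.pdmodel')
```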


@@ -464,7 +479,7 @@ def run(
fmts = tuple(export_formats()['Argument'][1:]) # --include arguments
flags = [x in include for x in fmts]
assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags # export booleans
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights

# Load PyTorch model
@@ -497,47 +512,48 @@
if half and not coreml:
im, model = im.half(), model.half() # to FP16
shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
metadata = {'stride': int(max(model.stride)), 'names': model.names} # model metadata
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")

# Exports
f = [''] * 10 # exported filenames
f = [''] * len(fmts) # exported filenames
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
if jit:
if jit: # TorchScript
f[0], _ = export_torchscript(model, im, file, optimize)
if engine: # TensorRT required before ONNX
f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
if onnx or xml: # OpenVINO requires ONNX
f[2], _ = export_onnx(model, im, file, opset, train, dynamic, simplify)
if xml: # OpenVINO
f[3], _ = export_openvino(model, file, half)
if coreml:
f[3], _ = export_openvino(file, metadata, half)
if coreml: # CoreML
f[4], _ = export_coreml(model, im, file, int8, half)

# TensorFlow Exports
if any((saved_model, pb, tflite, edgetpu, tfjs)):
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
check_requirements('flatbuffers==1.12') # required before `import tensorflow`
assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
f[5], model = export_saved_model(model.cpu(),
im,
file,
dynamic,
tf_nms=nms or agnostic_nms or tfjs,
agnostic_nms=agnostic_nms or tfjs,
topk_per_class=topk_per_class,
topk_all=topk_all,
iou_thres=iou_thres,
conf_thres=conf_thres,
keras=keras)
f[5], s_model = export_saved_model(model.cpu(),
im,
file,
dynamic,
tf_nms=nms or agnostic_nms or tfjs,
agnostic_nms=agnostic_nms or tfjs,
topk_per_class=topk_per_class,
topk_all=topk_all,
iou_thres=iou_thres,
conf_thres=conf_thres,
keras=keras)
if pb or tfjs: # pb prerequisite to tfjs
f[6], _ = export_pb(model, file)
f[6], _ = export_pb(s_model, file)
if tflite or edgetpu:
f[7], _ = export_tflite(model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
if edgetpu:
f[8], _ = export_edgetpu(file)
if tfjs:
f[9], _ = export_tfjs(file)
if paddle: # PaddlePaddle
f[10], _ = export_paddle(model, im, file, metadata)

# Finish
f = [str(x) for x in f if x] # filter out '' and None
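Worth noting how the bookkeeping above scales: `f` now has one slot per registry row rather than a hard-coded 10, and PaddlePaddle occupies the last slot. A sketch of the behavior (format names per `export_formats()`):

```python
fmts = ('torchscript', 'onnx', 'openvino', 'engine', 'coreml',
        'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs', 'paddle')
f = [''] * len(fmts)             # exported filenames, one slot per format
f[10] = 'yolov5s_paddle_model/'  # hypothetical result of export_paddle()
print([str(x) for x in f if x])  # ['yolov5s_paddle_model/']
```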
models/common.py: 114 changes (67 additions, 47 deletions)
@@ -320,14 +320,16 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
# TensorFlow GraphDef: *.pb
# TensorFlow Lite: *.tflite
# TensorFlow Edge TPU: *_edgetpu.tflite
# PaddlePaddle: *_paddle_model
from models.experimental import attempt_download, attempt_load # scoped to avoid circular import

super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self._model_type(w) # get backend
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = self._model_type(w) # type
w = attempt_download(w) # download if not local
fp16 &= pt or jit or onnx or engine # FP16
stride = 32 # default stride
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA

if pt: # PyTorch
model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
@@ -351,7 +353,6 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
net = cv2.dnn.readNetFromONNX(w)
elif onnx: # ONNX Runtime
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
cuda = torch.cuda.is_available() and device.type != 'cpu'
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
import onnxruntime
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
@@ -408,48 +409,60 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
LOGGER.info(f'Loading {w} for CoreML inference...')
import coremltools as ct
model = ct.models.MLModel(w)
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
if saved_model: # SavedModel
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
import tensorflow as tf
keras = False # assume TF1 saved_model
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
elif saved_model: # TF SavedModel
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
import tensorflow as tf
keras = False # assume TF1 saved_model
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
import tensorflow as tf

def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
ge = x.graph.as_graph_element
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

gd = tf.Graph().as_graph_def() # TF GraphDef
with open(w, 'rb') as f:
gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
from tflite_runtime.interpreter import Interpreter, load_delegate
except ImportError:
import tensorflow as tf

def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
ge = x.graph.as_graph_element
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

gd = tf.Graph().as_graph_def() # graph_def
with open(w, 'rb') as f:
gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
from tflite_runtime.interpreter import Interpreter, load_delegate
except ImportError:
import tensorflow as tf
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
if edgetpu: # Edge TPU https://coral.ai/software/#edgetpu-runtime
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
delegate = {
'Linux': 'libedgetpu.so.1',
'Darwin': 'libedgetpu.1.dylib',
'Windows': 'edgetpu.dll'}[platform.system()]
interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
else: # Lite
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
interpreter = Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
elif tfjs:
raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
else:
raise NotImplementedError(f'ERROR: {w} is not a supported format')
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
delegate = {
'Linux': 'libedgetpu.so.1',
'Darwin': 'libedgetpu.1.dylib',
'Windows': 'edgetpu.dll'}[platform.system()]
interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
else: # TFLite
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
interpreter = Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
elif tfjs: # TF.js
raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
elif paddle: # PaddlePaddle
LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
import paddle.inference as pdi
if not Path(w).is_file(): # if not *.pdmodel
w = next(Path(w).rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
weights = Path(w).with_suffix('.pdiparams')
config = pdi.Config(str(w), str(weights))
if cuda:
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
predictor = pdi.create_predictor(config)
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
else:
raise NotImplementedError(f'ERROR: {w} is not a supported format')
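Pulled out of the class, the Paddle branch above reduces to a short, self-contained flow; a minimal sketch (model paths hypothetical; the 2048 MB GPU memory pool and device 0 match the choices in this PR):

```python
import numpy as np
import paddle.inference as pdi

config = pdi.Config('yolov5s_paddle_model/model.pdmodel',
                    'yolov5s_paddle_model/model.pdiparams')
predictor = pdi.create_predictor(config)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])

input_handle.copy_from_cpu(np.zeros((1, 3, 640, 640), dtype=np.float32))  # dummy BCHW batch
predictor.run()
y = predictor.get_output_handle(predictor.get_output_names()[0]).copy_to_cpu()
print(y.shape)  # raw predictions; non_max_suppression is applied downstream
```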

# class names
if 'names' not in locals():
@@ -502,6 +515,13 @@ def forward(self, im, augment=False, visualize=False):
else:
k = 'var_' + str(sorted(int(k.replace('var_', '')) for k in y)[-1]) # output key
y = y[k] # output
elif self.paddle: # PaddlePaddle
im = im.cpu().numpy().astype("float32")
self.input_handle.copy_from_cpu(im)
self.predictor.run()
output_names = self.predictor.get_output_names()
output_handle = self.predictor.get_output_handle(output_names[0])
y = output_handle.copy_to_cpu()
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.saved_model: # SavedModel
@@ -542,13 +562,13 @@ def warmup(self, imgsz=(1, 3, 640, 640)):
def _model_type(p='path/to/model.pt'):
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
from export import export_formats
suffixes = list(export_formats().Suffix) + ['.xml'] # export suffixes
check_suffix(p, suffixes) # checks
sf = list(export_formats().Suffix) + ['.xml'] # export suffixes
check_suffix(p, sf) # checks
p = Path(p).name # eliminate trailing separators
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, xml2 = (s in p for s in suffixes)
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, xml2 = (s in p for s in sf)
xml |= xml2 # *_openvino_model or *.xml
tflite &= not edgetpu # *.tflite
return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle

@staticmethod
def _load_metadata(f=Path('path/to/meta.yaml')):
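The suffix test above is what routes a `*_paddle_model` directory to the new backend; an end-to-end usage sketch (weights path hypothetical, CPU device):

```python
import torch
from models.common import DetectMultiBackend

flags = DetectMultiBackend._model_type('yolov5s_paddle_model')
assert flags[-1]  # paddle=True, all other format booleans False

model = DetectMultiBackend('yolov5s_paddle_model', device=torch.device('cpu'))
y = model(torch.zeros(1, 3, 640, 640))  # forwards through paddle.inference
```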
utils/benchmarks.py: 2 changes (1 addition, 1 deletion)
@@ -61,7 +61,7 @@ def run(
device = select_device(device)
for i, (name, f, suffix, cpu, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, CPU, GPU)
try:
assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
assert i not in (9, 10, 11), 'inference not supported' # Edge TPU, TF.js and Paddle are unsupported
assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
if 'cpu' in device.type:
assert cpu, 'inference not supported on CPU'
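For context, rows 9, 10 and 11 of `export_formats()` are Edge TPU, TF.js and the new PaddlePaddle entry, so the widened assert simply exports those formats without timing their inference. A sketch of the gate, mirroring the loop above:

```python
import export

for i, (name, f, suffix, cpu, gpu) in export.export_formats().iterrows():
    if i in (9, 10, 11):  # Edge TPU, TF.js, PaddlePaddle: export-only in benchmarks
        print(f'{name}: skipping inference benchmark')
```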