From bb8c17e7074b3b9fd715acc2064fb9197ffe459f Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 9 Apr 2022 13:27:49 +0200 Subject: [PATCH] Add ONNX export metadata (#7353) --- export.py | 8 +++++++- models/common.py | 3 +++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/export.py b/export.py index ceb7862a49be..ecead3ef5a90 100644 --- a/export.py +++ b/export.py @@ -140,7 +140,13 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst # Checks model_onnx = onnx.load(f) # load onnx model onnx.checker.check_model(model_onnx) # check onnx model - # LOGGER.info(onnx.helper.printable_graph(model_onnx.graph)) # print + + # Metadata + d = {'stride': int(max(model.stride)), 'names': model.names} + for k, v in d.items(): + meta = model_onnx.metadata_props.add() + meta.key, meta.value = k, str(v) + onnx.save(model_onnx, f) # Simplify if simplify: diff --git a/models/common.py b/models/common.py index 5a83bce33fc8..49175f76a53a 100644 --- a/models/common.py +++ b/models/common.py @@ -328,6 +328,9 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, import onnxruntime providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider'] session = onnxruntime.InferenceSession(w, providers=providers) + meta = session.get_modelmeta().custom_metadata_map # metadata + if 'stride' in meta: + stride, names = int(meta['stride']), eval(meta['names']) elif xml: # OpenVINO LOGGER.info(f'Loading {w} for OpenVINO inference...') check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/