diff --git a/export.py b/export.py
index 93845a0c14fa..e43d9b730fc6 100644
--- a/export.py
+++ b/export.py
@@ -45,6 +45,7 @@
 """
 
 import argparse
+import contextlib
 import json
 import os
 import platform
@@ -453,6 +454,39 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
     return f, None
 
 
+def add_tflite_metadata(file, metadata, num_outputs):
+    # Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
+    with contextlib.suppress(ImportError):
+        # check_requirements('tflite_support')
+        from tflite_support import flatbuffers
+        from tflite_support import metadata as _metadata
+        from tflite_support import metadata_schema_py_generated as _metadata_fb
+
+        tmp_file = Path('/tmp/meta.txt')
+        with open(tmp_file, 'w') as meta_f:
+            meta_f.write(str(metadata))
+
+        model_meta = _metadata_fb.ModelMetadataT()
+        label_file = _metadata_fb.AssociatedFileT()
+        label_file.name = tmp_file.name
+        model_meta.associatedFiles = [label_file]
+
+        subgraph = _metadata_fb.SubGraphMetadataT()
+        subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
+        subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
+        model_meta.subgraphMetadata = [subgraph]
+
+        b = flatbuffers.Builder(0)
+        b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
+        metadata_buf = b.Output()
+
+        populator = _metadata.MetadataPopulator.with_model_file(file)
+        populator.load_metadata_buffer(metadata_buf)
+        populator.load_associated_files([str(tmp_file)])
+        populator.populate()
+        tmp_file.unlink()
+
+
 @smart_inference_mode()
 def run(
         data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
@@ -550,8 +584,9 @@ def run(
             f[6], _ = export_pb(s_model, file)
         if tflite or edgetpu:
             f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
-        if edgetpu:
-            f[8], _ = export_edgetpu(file)
+            if edgetpu:
+                f[8], _ = export_edgetpu(file)
+            add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
         if tfjs:
             f[9], _ = export_tfjs(file)
     if paddle:  # PaddlePaddle
diff --git a/models/common.py b/models/common.py
index af8132fffb7a..6347e51cdf0b 100644
--- a/models/common.py
+++ b/models/common.py
@@ -3,10 +3,13 @@
 Common modules
 """
 
+import ast
+import contextlib
 import json
 import math
 import platform
 import warnings
+import zipfile
 from collections import OrderedDict, namedtuple
 from copy import copy
 from pathlib import Path
@@ -462,6 +465,12 @@ def gd_outputs(gd):
             interpreter.allocate_tensors()  # allocate
             input_details = interpreter.get_input_details()  # inputs
             output_details = interpreter.get_output_details()  # outputs
+            # load metadata
+            with contextlib.suppress(zipfile.BadZipFile):
+                with zipfile.ZipFile(w, "r") as model:
+                    meta_file = model.namelist()[0]
+                    meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
+                    stride, names = int(meta['stride']), meta['names']
         elif tfjs:  # TF.js
             raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
         elif paddle:  # PaddlePaddle
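
For context, a minimal standalone sketch of how the embedded metadata can be read back out of an exported *.tflite file, mirroring the zipfile/ast.literal_eval approach added to DetectMultiBackend above. The read_tflite_metadata helper name and the example model path are assumptions for illustration only and are not part of this diff.

import ast
import contextlib
import zipfile


def read_tflite_metadata(path):
    # Assumed helper (not in this PR): the metadata populator appends associated
    # files to the *.tflite flatbuffer in zip format, so the model file can be
    # opened with zipfile and the dict written by add_tflite_metadata() parsed back.
    with contextlib.suppress(zipfile.BadZipFile):
        with zipfile.ZipFile(path, 'r') as model:
            meta_file = model.namelist()[0]  # e.g. 'meta.txt'
            return ast.literal_eval(model.read(meta_file).decode('utf-8'))
    return None  # no metadata embedded (a plain *.tflite is not a valid zip)


# Example (assumed path); expected to yield e.g. {'stride': 32, 'names': {0: 'person', ...}}
# meta = read_tflite_metadata('yolov5s-fp16.tflite')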