From 54f49fa581aac1d9909636bfc13f94001b08b55b Mon Sep 17 00:00:00 2001 From: paradigm Date: Tue, 25 Oct 2022 17:53:22 +0200 Subject: [PATCH] Add TFLite Metadata to TFLite and Edge TPU models (#9903) * added embedded meta data to tflite models * added try block for inference * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactored tfite meta data into separate function * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Creat tmp file in /tmp * Update export.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update export.py * Update export.py * Update export.py * Update export.py * Update common.py * Update export.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update common.py Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- export.py | 39 +++++++++++++++++++++++++++++++++++++-- models/common.py | 9 +++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/export.py b/export.py index 93845a0c14fa..e43d9b730fc6 100644 --- a/export.py +++ b/export.py @@ -45,6 +45,7 @@ """ import argparse +import contextlib import json import os import platform @@ -453,6 +454,39 @@ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')): return f, None +def add_tflite_metadata(file, metadata, num_outputs): + # Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata + with contextlib.suppress(ImportError): + # check_requirements('tflite_support') + from tflite_support import flatbuffers + from tflite_support import metadata as _metadata + from tflite_support import metadata_schema_py_generated as _metadata_fb + + tmp_file = Path('/tmp/meta.txt') + with open(tmp_file, 'w') as meta_f: + meta_f.write(str(metadata)) + + model_meta = _metadata_fb.ModelMetadataT() + label_file = _metadata_fb.AssociatedFileT() + label_file.name = tmp_file.name + model_meta.associatedFiles = [label_file] + + subgraph = _metadata_fb.SubGraphMetadataT() + subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()] + subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs + model_meta.subgraphMetadata = [subgraph] + + b = flatbuffers.Builder(0) + b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_buf = b.Output() + + populator = _metadata.MetadataPopulator.with_model_file(file) + populator.load_metadata_buffer(metadata_buf) + populator.load_associated_files([str(tmp_file)]) + populator.populate() + tmp_file.unlink() + + @smart_inference_mode() def run( data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' @@ -550,8 +584,9 @@ def run( f[6], _ = export_pb(s_model, file) if tflite or edgetpu: f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms) - if edgetpu: - f[8], _ = export_edgetpu(file) + if edgetpu: + f[8], _ = export_edgetpu(file) + add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs)) if tfjs: f[9], _ = export_tfjs(file) if paddle: # PaddlePaddle diff --git a/models/common.py b/models/common.py index af8132fffb7a..6347e51cdf0b 100644 --- a/models/common.py +++ b/models/common.py @@ -3,10 +3,13 @@ Common modules """ +import ast +import contextlib import json import math import platform import warnings +import zipfile from collections import OrderedDict, namedtuple from copy import copy from pathlib import Path @@ -462,6 +465,12 @@ def gd_outputs(gd): interpreter.allocate_tensors() # allocate input_details = interpreter.get_input_details() # inputs output_details = interpreter.get_output_details() # outputs + # load metadata + with contextlib.suppress(zipfile.BadZipFile): + with zipfile.ZipFile(w, "r") as model: + meta_file = model.namelist()[0] + meta = ast.literal_eval(model.read(meta_file).decode("utf-8")) + stride, names = int(meta['stride']), meta['names'] elif tfjs: # TF.js raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported') elif paddle: # PaddlePaddle