
Add TensorFlow and TFLite export #1127

Merged
merged 71 commits on Aug 17, 2021

Conversation

zldrobit
Contributor

@zldrobit zldrobit commented Oct 13, 2020

Since this PR has been merged into the master branch (with some subsequent code changes), TensorFlow/TFLite models can be exported using

python export.py --weights yolov5s.pt --include saved_model pb tflite [--int8] --img 640
python export.py --weights yolov5s.pt --include tfjs --img 640

and validated using

python detect.py --weights yolov5s_saved_model --img 640
                           yolov5s.pb
                           yolov5s-fp16.tflite
                           yolov5s-int8.tflite

After exporting TFLite models, https://github.com/zldrobit/yolov5/tree/tf-android can be used as an Android demo.
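
Optionally, the exported files can be sanity-checked from Python before deployment; here is a minimal sketch (assuming the default yolov5s-fp16.tflite and yolov5s-int8.tflite output names):

import tensorflow as tf

# Print input shape and dtype of each exported TFLite file (int8 exports typically take uint8 inputs)
for f in ['yolov5s-fp16.tflite', 'yolov5s-int8.tflite']:
    interpreter = tf.lite.Interpreter(model_path=f)
    interpreter.allocate_tensors()
    d = interpreter.get_input_details()[0]
    print(f, d['shape'], d['dtype'])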

For Edge TPU model export, please refer to #3630.

Original export method (obsolete): This PR is a simplified version of https://github.com//pull/959, which only adds TensorFlow and TFLite export functionality.

Export TensorFlow models (GraphDef and saved model) and fp16 TFLite models using:

python models/tf.py --weights weights/yolov5s.pt --cfg models/yolov5s.yaml --img 320

Export int8 quantized TFLite models using:

python3 models/tf.py --weights weights/yolov5s.pt --cfg models/yolov5s.yaml --tfl-int8 --source /data/dataset/coco/coco2017/train2017 --ncalib 100

Run TensorFlow/TFLite model inference using:

python3 detect.py --weights weights/yolov5s.pb          --img 320
                            weights/yolov5s_saved_model
                            weights/yolov5s-fp16.tflite
                            weights/yolov5s-int8.tflite           --tfl-int8
For TensorFlow.js:

Export the *.pb model with class-agnostic NMS (tf.image.non_max_suppression):

python3 models/tf.py --weights weights/yolov5s.pt --cfg models/yolov5s.yaml --img 320 --tf-nms --agnostic-nms

Convert *.pb to a tfjs model with:

# install tfjs converter: pip install tensorflowjs
tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='Identity,Identity_1,Identity_2,Identity_3' \
    weights/yolov5s.pb \
    weights/web_model

Edit weights/web_model/model.json to reorder the output nodes (see tensorflow/tfjs#3942), changing
"signature": {"outputs": {"Identity": {"name": "Identity"}, "Identity_3": {"name": "Identity_3"}, "Identity_1": {"name": "Identity_1"}, "Identity_2": {"name": "Identity_2"}}}
to
"signature": {"outputs": {"Identity": {"name": "Identity"}, "Identity_1": {"name": "Identity_1"}, "Identity_2": {"name": "Identity_2"}, "Identity_3": {"name": "Identity_3"}}},
either manually or with:

sed -i 's/{"outputs": {"Identity.\?.\?": {"name": "Identity.\?.\?"}, "Identity.\?.\?": {"name": "Identity.\?.\?"}, "Identity.\?.\?": {"name": "Identity.\?.\?"}, "Identity.\?.\?": {"name": "Identity.\?.\?"}}}/'\
'{"outputs": {"Identity": {"name": "Identity"}, "Identity_1": {"name": "Identity_1"}, "Identity_2": {"name": "Identity_2"}, "Identity_3": {"name": "Identity_3"}}}/' \
./weights/web_model/model.json

Deploy the model with tfjs-yolov5-example.

This PR is tested successfully with PyTorch 1.8 and TensorFlow 2.4.0/2.4.1.

EDIT:

  • Change --img from 640 to 320.
  • Remove --no-tfl-detect. It is now only used in the tf-edgetpu branch (formerly tf-android-tfl-detect) for Edge TPU model export.
  • Deprecate TensorFlow 1.x support. If you still want to use TF 1.x, please refer to the archived tf-export branch. Note this branch does not support TFLite model export.
  • For Edge TPU support, the tf-android-tfl-detect branch is superseded by the tf-edgetpu branch.
  • Add a --tf-raw-resize option to map resize ops to the Edge TPU, which accelerates inference (not necessary with Edge TPU compiler v16).
  • Add instructions for model inference.
  • Add TensorFlow.js model export and the detection example tfjs-yolov5-example. Demo page at https://zldrobit.github.io/tfjs-yolov5-example/
  • Remove --cfg since the config yaml is stored in the model weights now.

FAQ:

🛠️ PR Summary

Made with ❤️ by Ultralytics Actions

🌟 Summary

Added TensorFlow (TF) and TensorFlow Lite (TFLite) model support to the YOLOv5 detect.py script.

📊 Key Changes

  • Introduced the ability to run inference with TensorFlow models, including SavedModel format, TFLite, and quantized TFLite (INT8).
  • Imports numpy as np and TensorFlow (import tensorflow as tf) in detect.py.
  • TensorFlow models can now be loaded from their respective formats (.pb, TFLite, SavedModel).
  • Adjusted the image loading, preprocessing, and model inference code to support TensorFlow model formats in addition to existing PyTorch and ONNX support.
  • Added a command-line argument (--tfl-int8) to enable INT8 quantized TFLite model inference.
  • Added a new file models/tf.py to convert YOLOv5 models to TensorFlow formats.
  • Updated requirements.txt to include TensorFlow as an optional dependency.

🎯 Purpose & Impact

  • 🚀 Enables users to utilize the versatility and optimizations available in the TensorFlow ecosystem.
  • 🛠 Provides additional deployment options, such as running on mobile devices with TFLite or serving with TensorFlow Serving.
  • ⚡ Improved performance with INT8 quantization support for TFLite, potentially reducing latency and resource utilization during inference.
  • 🧠 Opens the door for AI applications that benefit from TensorFlow's advanced features, like model serving and support for a wide range of devices.

@glenn-jocher
Member

@zldrobit thanks for the PR! Can you explain the role of auto in the dataloaders?

@glenn-jocher
Member

@zldrobit oh also, is there a tensorflow version requirement here? We should probably add it to requirements.txt under export:

yolov5/requirements.txt

Lines 19 to 24 in 00917a6

# export --------------------------------------
# packaging # for coremltools
# coremltools==4.0
# onnx>=1.7.0
# scikit-learn==0.19.2 # for coreml quantization

@zldrobit
Contributor Author

zldrobit commented Oct 14, 2020

@glenn-jocher Of course. auto controls whether to pad a resized input image to a square.
auto is set to True in PyTorch by default, because rectangular (width != height) image inference is supported.
In TensorFlow and TFLite, the images need to be padded to squares for inference.
Though this PR does not include any inference code, int8 quantization requires preprocessed (resized/normalized) images for calibration.
Therefore, auto is set to False for TFLite int8 quantization.

TensorFlow==2.3.0 is tested and can be used for export and inference.
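
For illustration, a minimal sketch of the shape difference auto makes (assuming YOLOv5's letterbox helper; its import path may vary across versions):

import numpy as np
from utils.datasets import letterbox  # helper from this repo; the path may differ in later versions

img = np.zeros((720, 1280, 3), dtype=np.uint8)  # dummy HD frame
rect, _, _ = letterbox(img, new_shape=(640, 640), auto=True)    # pad only to a stride multiple
fixed, _, _ = letterbox(img, new_shape=(640, 640), auto=False)  # pad all the way to 640x640
print(rect.shape, fixed.shape)  # e.g. (384, 640, 3) vs (640, 640, 3)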

@glenn-jocher
Member

@zldrobit ah, I see. That's interesting. Yes, we use rectangular inference (height != width) in PyTorch and in CoreML in our iDetection app. This helps speed up inference significantly, proportional to the smaller area. For example, 640x320 inference typically takes half the time of 640x640 inference.

We resize to --img-size first and pad the shorter dimension as required to reach a 32 multiple, which is the max stride of the current YOLOv5 models.

Do you know where the square inference requirement comes from in TensorFlow?

@zldrobit
Contributor Author

zldrobit commented Oct 14, 2020

@glenn-jocher Thanks for your explanation. I didn't express my idea correctly.
Actually, TensorFlow does not require square inference.
I should have said fixed input size instead of square inference requirement.
I set auto=False to ensure the padded image size equals new_shape in letterbox:

def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):

Thus, the input image sizes after preprocessing are all the same.

The input size is fixed while exporting TensorFlow and TFLite:

yolov5/models/tf.py

Lines 394 to 395 in 23fe35e

inputs = keras.Input(shape=(*opt.img_size, 3))
keras_model = keras.Model(inputs=inputs, outputs=tf_model.predict(inputs))

Take the COCO dataset for example: during int8 calibration, if one uses auto=True, images of different sizes will be fed to the TFLite model
while the model's input size is fixed.
For rectangular inference with TFLite int8 calibration, one could set --img-size to a rectangle (e.g. 640x320) while keeping auto=False.
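
As a small illustration of that fixed input, a sketch of the Keras input placeholder with a hypothetical 320x192 (height x width) size:

from tensorflow import keras

img_size = (320, 192)                       # example rectangular size; both values are multiples of 32
inputs = keras.Input(shape=(*img_size, 3))  # fixed NHWC placeholder: (None, 320, 192, 3)
print(inputs.shape)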

@thhart

thhart commented Oct 15, 2020

@zldrobit Thanks for this great work. I tested it briefly but unfortunately it failed with the following error. I ran it in an Android emulator with your libraries specified in Gradle. The app starts successfully, however on model load the following exception is thrown:

E/tensorflow: CameraActivity: Exception!
    java.lang.RuntimeException: java.lang.IllegalStateException: Internal error: Unexpected failure when preparing tensor allocations: tensorflow/lite/kernels/add.cc:86 NumInputs(node) != 2 (1 != 2)
    Node number 5 (ADD) failed to prepare.
        at org.tensorflow.lite.examples.detection.tflite.YoloV5ClassifierDetect.create(YoloV5ClassifierDetect.java:116)
        at org.tensorflow.lite.examples.detection.tflite.DetectorFactory.getDetector(DetectorFactory.java:49)
        at org.tensorflow.lite.examples.detection.DetectorActivity.onPreviewSizeChosen(DetectorActivity.java:83)
        at org.tensorflow.lite.examples.detection.CameraActivity.onPreviewFrame(CameraActivity.java:253)
        at android.hardware.Camera$EventHandler.handleMessage(Camera.java:1209)
        at android.os.Handler.dispatchMessage(Handler.java:106)
        at android.os.Looper.loop(Looper.java:193)
        at android.app.ActivityThread.main(ActivityThread.java:6669)
        at java.lang.reflect.Method.invoke(Native Method)
        at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:493)
        at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:858)
     Caused by: java.lang.IllegalStateException: Internal error: Unexpected failure when preparing tensor allocations: tensorflow/lite/kernels/add.cc:86 NumInputs(node) != 2 (1 != 2)
    Node number 5 (ADD) failed to prepare.
    
        at org.tensorflow.lite.NativeInterpreterWrapper.allocateTensors(Native Method)
        at org.tensorflow.lite.NativeInterpreterWrapper.init(NativeInterpreterWrapper.java:87)
        at org.tensorflow.lite.NativeInterpreterWrapper.(NativeInterpreterWrapper.java:63)
        at org.tensorflow.lite.Interpreter.(Interpreter.java:266)
        at org.tensorflow.lite.examples.detection.tflite.YoloV5ClassifierDetect.create(YoloV5ClassifierDetect.java:114)
        	... 10 more

Tested on your tf-android branch.
Downloaded the official yolov5s.pt.
Converted with: python models/tf.py --no-tfl-detect
Copied yolov5s-fp16.tflite to the assets folder.

@zldrobit
Contributor Author

@thhart Thanks for using the code.
What is your TensorFlow version?
The tf-android branch is tested with TensorFlow 2.3.0

@thhart

thhart commented Oct 16, 2020

@zldrobit I checked with 2.4.0-dev20201011 (tf-nightly), then switched back to 2.3.0 and it works better. Extremely slow in the emulator, but I will check on a real device soon...

@thhart

thhart commented Oct 16, 2020

@zldrobit Checked fp16 and int8. 450 ms inference time on a Samsung Note 10 (GPU). This is with the full 640 size, which might be a bit too much. But this is out of scope for now, of course.
Good work again, it looks very promising and worth integrating. Maybe you should move the Android stuff into a different project, since I understand from Ultralytics that they don't want to bloat the repo too much.

Further side notes and questions, maybe for later enhancement:
Is different resolution handling really supported? I can see some hard-coded 640 in tf.py.
Where does yolov5s.tflite come from? I can see it is referenced in Java but commented out in tf.py, maybe a leftover from first tests?
In the code the possibility of webcam calibration is mentioned (ncalib), but it does not look integrated yet; maybe remove it for the first release to avoid confusion?
It may be worth checking on TF 2.4.0 why it fails there; maybe there is a change in the API?

@zldrobit
Contributor Author

zldrobit commented Oct 19, 2020

@thhart Thanks for your suggestion. I am now keeping all the TensorFlow and TFLite related code in https://github.com/zldrobit/yolov5/tree/tf-android

Is different resolution handling really supported, I can see some hard coded 640 in tf.py?

640 in tf.py is just a default parameter value. You could change it to any multiple of 32 using --img.
For example, use

PYTHONPATH=. python3  models/tf.py --weight weights/yolov5s.pt --cfg models/yolov5s.yaml --img 320 --no-tfl-detect --tfl-int8 --source /data/dataset/coco/coco2017/train2017 --ncalib 100

to generate TF and TFLite models.
Then, use one of

python3 detect.py --weight weights/yolov5s.pb --img 320
python3 detect.py --weight weights/yolov5s_saved_model/ --img 320
python3 detect.py --weight weights/yolov5s-fp16.tflite --img 320 --tfl-detect
python3 detect.py --weight weights/yolov5s-int8.tflite --img 320 --tfl-int8 --tfl-detect

to detect objects.

Or put the TFLite models into the assets folder of the Android project, and replace

inputSize = 640;
output_width = new int[]{80, 40, 20};

and
inputSize = 640;
output_width = new int[]{80, 40, 20};

with

            inputSize = 320;
            output_width = new int[]{40, 20, 10};

to build and run the Android project.
This reduces inference time by roughly:

75% for the fp16 model on a Snapdragon 820 CPU (4 threads), from 1.9 s to 0.5 s,
70% for the fp16 model on a Snapdragon 820 GPU, from 1.3 s to 0.4 s,
70% for the int8 model on a Snapdragon 820 CPU (4 threads), from 1.7 s to less than 0.5 s.
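
As a side note on the grid widths above, a small sketch of the relation they follow (assuming the standard YOLOv5 strides of 8, 16 and 32):

# Each output grid width equals the input size divided by the layer stride (assumed strides: 8, 16, 32)
def grid_sizes(input_size, strides=(8, 16, 32)):
    return [input_size // s for s in strides]

print(grid_sizes(640))  # [80, 40, 20]
print(grid_sizes(320))  # [40, 20, 10]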

@zldrobit
Contributor Author

@thhart

Where does the yolov5s.tflite comes from, I can see it is referenced in Java but commented out in tf.py maybe a leftover from first tests?

You can uncomment

yolov5/models/tf.py

Lines 434 to 440 in eb626a6

# converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
# converter.allow_custom_ops = False
# converter.experimental_new_converter = True
# tflite_model = converter.convert()
# f = opt.weights.replace('.pt', '.tflite') # filename
# open(f, "wb").write(tflite_model)

to generate yolov5s.tflite.
Since the default inference precision of TFLite on Android is fp16, yolov5s-fp16.tflite is enough for inference.
I commented out the yolov5s.tflite generation code and left it as a note.
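
For reference, a hedged sketch of how an fp16 TFLite file is typically produced with the TFLite converter; the exact flags used in models/tf.py may differ:

import tensorflow as tf

# 'yolov5s_saved_model' is assumed to be the SavedModel folder produced by the export above
converter = tf.lite.TFLiteConverter.from_saved_model('yolov5s_saved_model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]  # store weights in fp16
open('yolov5s-fp16.tflite', 'wb').write(converter.convert())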

In the code there is the possibility of webcam calibration mentioned (ncalib) but it looks not integrated yet, maybe remove for first release to avoid confusion about?

--ncalib is unrelated to webcam calibration; it is used for int8 TFLite quantization.

Testing and check on TF 2.4.0 why it fails there maybe there is a change in the API?

I think it's a TensorFlow issue with breaking backward compatibility.

@glenn-jocher
Member

@zldrobit can you add tensorflow and any other dependencies required to requirements.txt export section?

Thanks for the explanation, so that's great news that square inference is not required. The actual inference size for the iDetection iOS app is 320 vertical by 192 horizontal, to accommodate vertical video in any of the 16:9 aspect ratio formats like 4K, 1080p, etc. Is it possible to export to TFLite in a similar shape?

Yes I see about the auto resizing in the dataloader. I'll think about that for a bit.

@zjutlzt

zjutlzt commented Oct 20, 2020

Hello, I have some trouble using your branch. I put the code directly below,
Thanks,

import argparse
import os
import torch
import tensorflow as tf
from tensorflow import keras
import cv2
import numpy as np
def select_device(device='', batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    cpu_request = device.lower() == 'cpu'
    if device and not cpu_request:  # if device requested other than 'cpu'
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device  # check availablity

    cuda = False if cpu_request else torch.cuda.is_available()
    if cuda:
        c = 1024 ** 2  # bytes to MB
        ng = torch.cuda.device_count()
        if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
        x = [torch.cuda.get_device_properties(i) for i in range(ng)]
        s = 'Using CUDA '
        for i in range(0, ng):
            if i == 1:
                s = ' ' * len(s)
    return torch.device('cuda:0' if cuda else 'cpu')

def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    # Resize image to a 32-pixel-multiple rectangle ultralytics/yolov3#232
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)

coco_names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush']

def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    if print_graph == True:
        print("-" * 50)
        print("Frozen model layers: ")
        layers = [op.name for op in import_graph.get_operations()]
        for layer in layers:
            print(layer)
        print("-" * 50)
    return wrapped_import.prune(tf.nest.map_structure(import_graph.as_graph_element, inputs), tf.nest.map_structure(import_graph.as_graph_element, outputs))

device = select_device("0")

graph = tf.Graph()
graph_def = graph.as_graph_def()
graph_def.ParseFromString(open("yolov5s.pb", 'rb').read())
frozen_func = wrap_frozen_graph(graph_def=graph_def, inputs="x:0", outputs="Identity:0", print_graph=False)

img = torch.zeros((1, 3, 640, 640), device=device) # init img
_ = frozen_func(x=tf.constant(img.permute(0, 2, 3, 1).cpu().numpy()))
dete_image = cv2.imread("1.jpg")
detect_image = letterbox(dete_image, new_shape=(640,640))[0]
detect_image = detect_image[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
detect_image = np.ascontiguousarray(detect_image)
img = torch.from_numpy(detect_image).to(device)
img = img.float()
img /= 255.0
if img.ndimension() == 3:
    img = img.unsqueeze(0)

pred = frozen_func(x=tf.constant(img.permute(0, 2, 3, 1).cpu().numpy()))
pred = torch.tensor(pred.numpy())
print(pred)

pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes,agnostic=opt.agnostic_nms)

tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [1,256,40,40] vs. shape[1] = [1,256,24,40]

I don't know why, can you help?

@zldrobit
Contributor Author

@zjutlzt It seems that you changed the input size of the model.
I cannot figure out which code is relevant.
Could you provide minimal reproducible code or show me the changed code?

A suggestion: you could surround your pasted code in github issue with
```
code
code
code
```

for better illustration:

code
code
code

@zjutlzt

zjutlzt commented Oct 20, 2020

Sorry, my fault.

I did not change any export code:
PYTHONPATH=. python models/tf.py --weights weights/yolov5s.pt --cfg models/yolov5s.yaml --img 640

and I want to make minimal code that uses yolov5s.pb for detection, so I wrote this code

import argparse
import os
import torch
import tensorflow as tf
from tensorflow import keras
import cv2
import  numpy as np
def select_device(device='', batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    cpu_request = device.lower() == 'cpu'
    if device and not cpu_request:  # if device requested other than 'cpu'
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device  # check availablity

    cuda = False if cpu_request else torch.cuda.is_available()
    if cuda:
        c = 1024 ** 2  # bytes to MB
        ng = torch.cuda.device_count()
        if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
        x = [torch.cuda.get_device_properties(i) for i in range(ng)]
        s = 'Using CUDA '
        for i in range(0, ng):
            if i == 1:
                s = ' ' * len(s)
    return torch.device('cuda:0' if cuda else 'cpu')


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2
    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
coco_names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
         'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
         'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
         'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
         'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
         'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
         'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
         'hair drier', 'toothbrush']

def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    if print_graph == True:
        print("-" * 50)
        print("Frozen model layers: ")
        layers = [op.name for op in import_graph.get_operations()]
        for layer in layers:
            print(layer)
        print("-" * 50)
    return wrapped_import.prune(tf.nest.map_structure(import_graph.as_graph_element, inputs), tf.nest.map_structure(import_graph.as_graph_element, outputs))

device = select_device("0")

graph = tf.Graph()
graph_def = graph.as_graph_def()
graph_def.ParseFromString(open("yolov5s.pb", 'rb').read())
frozen_func = wrap_frozen_graph(graph_def=graph_def, inputs="x:0", outputs="Identity:0", print_graph=False)

img = torch.zeros((1, 3, 640, 640), device=device)  # init img
_ = frozen_func(x=tf.constant(img.permute(0, 2, 3, 1).cpu().numpy()))
dete_image = cv2.imread("1.jpg")
detect_image = letterbox(dete_image, new_shape=(640,640))[0]
detect_image = detect_image[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
detect_image = np.ascontiguousarray(detect_image)
img = torch.from_numpy(detect_image).to(device)
img = img.float()
img /= 255.0
if img.ndimension() == 3:
    img = img.unsqueeze(0)

pred = frozen_func(x=tf.constant(img.permute(0, 2, 3, 1).cpu().numpy()))
pred = torch.tensor(pred.numpy())
print(pred)
# pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes,agnostic=opt.agnostic_nms)

But something went wrong

tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [1,256,40,40] vs. shape[1] = [1,256,24,40]
[[node functional_1/tf__concat/concat (defined at D:/A_WorkSpace/TF_work_space/yolov5/demo_do/detect.py:93) ]] [Op:__inference_pruned_3280]

Function call stack:
pruned

@zldrobit
Contributor Author

@zldrobit can you add tensorflow and any other dependencies required to requirements.txt export section?

Thanks for the explanation, so that's great news square inference is not required. The actual inference size for iDetection iOS app is 320 vertical by 192 horizontal for accommodating vertical video in any of the 16:9 aspect ratio formats like 4k, 1080p etc. Is it possible to export to tflite in a similar shape?

Yes I see about the auto resizing in the dataloader. I'll think about that for a bit.

@glenn-jocher Sure. I have updated requirements.txt.

Yes. Run

PYTHONPATH=. python3  models/tf.py --weight weights/yolov5s.pt --cfg models/yolov5s.yaml --img 320 192 --no-tfl-detect --tfl-int8 --source /data/dataset/coco/coco2017/train2017 --ncalib 100

to export a TFLite model with a 320 (vertical) by 192 (horizontal) input, and run one of

python3 detect.py --weight weights/yolov5s-int8.tflite --img 320 192 --tfl-detect --tfl-int8
python3 detect.py --weight weights/yolov5s-fp16.tflite --img 320 192 --tfl-detect

with tf-android branch to detect.

@zldrobit
Contributor Author

@zjutlzt you could change

detect_image = letterbox(dete_image, new_shape=(640,640))[0]

to

detect_image = letterbox(dete_image, new_shape=(640,640), auto=False)[0]

The auto argument controls whether to force the resized input to new_shape;
otherwise it may use a rectangular input.
If you are using a TensorFlow or TFLite model, you have to set auto=False.

@zjutlzt

zjutlzt commented Oct 21, 2020

@zldrobit
it works, thanks 😊

@idenc
Contributor

idenc commented Oct 22, 2020

When converting my custom model to TFLite and running inference with TFLite, the detections are significantly different from those of the source model. Using the COCO pre-trained model results in the same detections, but running the same code with my model results in different detections. I am using TensorFlow 2.3.0 and my custom model has 38 classes. Here is the original vs. the TFLite model.
Original
test
TFLite
test

@zldrobit
Contributor Author

@idenc From the two images, my first guess is the anchors have been changed.
Could you check whether the anchors have been changed by auto anchor?
Just compare anchors from *.yaml files and
the trained model anchors with

print(torch.load('your_custom_model.pt')['model'].model[-1].anchors)

as in #447 (comment).

If auto anchor has changed the anchors, that won't be reflected in the *.yaml config file from which the TFLite export is generated.

The following code changes anchors via auto anchor:

yolov5/train.py

Lines 192 to 193 in 83deec1

if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

yolov5/utils/general.py

Lines 109 to 110 in 83deec1

m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss

That's the reason why auto anchor affects the detection boxes.
If that is the case, you could substitute the auto-generated anchors into your *.yaml file and try TFLite export and detection again.

PS:
Could you check your custom model using PyTorch inference?
Which --img did you use to export the TFLite model?

Reference for auto anchor generation:
#503
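
For convenience, here is a small sketch that prints the trained anchors in pixel units, one row per detection layer, so they can be pasted into the yaml anchors field (your_custom_model.pt is a hypothetical checkpoint name):

import torch

# Run from the yolov5 repo root so the pickled model classes can be resolved
m = torch.load('your_custom_model.pt', map_location='cpu')['model'].model[-1]  # Detect() layer
for layer in m.anchor_grid.view(m.anchor_grid.shape[0], -1).tolist():  # pixel units, one row per layer
    print([round(float(a)) for a in layer])  # e.g. [10, 13, 16, 30, 33, 23]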

@glenn-jocher
Member

glenn-jocher commented Oct 22, 2020

@zldrobit @idenc yes the model anchors may no longer be equal to the yaml definition for a variety of reasons. To access current model anchors:

import torch

model = torch.load('yolov5s.pt')['model']

m = model.model[-1]  # Detect()
m.anchors  # in stride units
m.anchor_grid  # in pixel units

print(m.anchor_grid.view(-1,2))
tensor([[ 10.,  13.],
        [ 16.,  30.],
        [ 33.,  23.],
        [ 30.,  61.],
        [ 62.,  45.],
        [ 59., 119.],
        [116.,  90.],
        [156., 198.],
        [373., 326.]])

/rebase

@idenc
Contributor

idenc commented Oct 22, 2020

Thanks guys, it was due to the anchors changing from auto anchor. Substituting the model's anchors into the config .yaml fixed the detection.

@glenn-jocher
Member

@zldrobit could you read in anchors from the model.pt rather than the yaml? In YOLOv5 we only use the yaml for training, but not afterwards (so we don't run into this situation).

@zldrobit
Contributor Author

@glenn-jocher Thanks for your explanation of anchors and yaml files.
According to your suggestion, I have updated the code to read anchors directly from PyTorch model.pt.

@4yougames

!PYTHONPATH=. python3 models/tf.py --weight runs/results/weights/best.pt --cfg models/yolov5s.yaml --no-tfl-detect --tfl-int8 --source /content/dataset/images/train --ncalib 100

best.pt size = 15.7 mb
best-fp16.tflite size = 15.6 mb
best-int8.tflite size = 8.3 mb

Hello! After converting, I get an int8 file size of about 8 MB. Is that a normal size?
This is more than I expected. Could it be the number of classes? I have 326 classes.

I have attached the config files: cfg.zip
https://colab.research.google.com/drive/14Pk5dFl5ap6qVg_Hu5yjsjzSw-udFoHN?usp=sharing

@zldrobit
Contributor Author

@Sergey-sib Hello! It is normal.
The model size is basically not related to the number of classes.
What's the model size you expected?

@4yougames

@Sergey-sib Hello! It is normal.
The model size is basically not related to the number of classes.
What's the model size you expected?

I thought that if the full YOLOv5s model in float32 is 15 MB, the result of conversion to float16 should be ~7.5 MB, and conversion to int8 ~3.7 MB.

@zldrobit
Contributor Author

@Sergey-sib I guess you mean that best.pt is the full model.
Actually, it is stored in fp16 precision.
Despite the differences between PyTorch and TFLite, the size of best.pt is approximately equal to the size of best-fp16.tflite.

@4yougames

Is it possible to use the latest repository https://github.com/ultralytics/yolov5 for training and then convert to TFLite using this repository https://github.com/zldrobit/yolov5?

@zldrobit
Contributor Author

zldrobit commented Oct 27, 2020

@Sergey-sib Yes, as long as you are using YOLOv5 version v2/v3.
Because the latest repo https://github.com/ultralytics/yolov5 is version v3, you could convert *.pt weights to a TFLite model with
https://github.com/zldrobit/yolov5/tree/tf-android

@wiky231

wiky231 commented May 5, 2022

@axlecky If you're interested in detecting objects in a still picture, please refer to #1127 (comment).

Hi @zldrobit, could you elaborate more regarding this? Which part of the MainActivity.java and AndroidManifest.xml do I have to modify? Thanks!

@zldrobit
Contributor Author

zldrobit commented May 5, 2022

And is there any way to get the FPS when detecting objects with the yolov5s model on a smartphone? I am running this source code: https://github.com/zldrobit/yolov5/tree/tf-android

@thaingoc2604 Running https://github.com/zldrobit/yolov5/tree/tf-android on a smartphone would give you the inference delay on the screen.

I trained the model and got the mAP@.5 table shown in the picture, but when I run detection it only detects objects whose mAP@.5 is >= 65%. Does the MINIMUM_CONFIDENCE_TF_OD_API value affect the results? (I set MINIMUM_CONFIDENCE_TF_OD_API = 0.3f.) What value do I need to change to get better detection results, i.e. so that it can also detect objects whose mAP@.5 is lower than 65%? Thanks

Just set MINIMUM_CONFIDENCE_TF_OD_API to a lower value (e.g. 0.01~0.1). That way you will get more bounding boxes, but there will also be more false positive detections.

@zldrobit
Contributor Author

zldrobit commented May 5, 2022

@axlecky The TFLite backend is optimized for ARM processors. I guess you have run a YOLOv5 TFLite model on an x86 CPU, and the inference speed was very slow, as expected. Running inference with python detect.py is only for validation purposes. If you're interested in running YOLOv5 TFLite models on Android, please refer to https://github.com/zldrobit/yolov5.

@zldrobit Thanks for the response. Is the TFLite model supposed to run equally fast on an Android device? Because I have tried deploying my TFLite model on your Android app example, and the inference speed is relatively slow as well.

@axlecky Maybe you could try switching to GPU inference in the TFLite Android demo. GPU mode is much faster than CPU mode. Please be aware of the input size. I chose a 320x320 input resolution because the 640x640 model runs very slowly.

@axlecky If you're interested in detecting objects in a still picture, please refer to #1127 (comment).
Hi @zldrobit, could you elaborate more regarding this? Which part of the MainActivity.java and AndroidManifest.xml do I have to modify? Thanks!

Maybe you have to figure out the details yourself; I didn't run or test inference from a single image. Or you could ask @chainyo, since he seems to have succeeded in detecting with a single image.

@wiky231

wiky231 commented May 5, 2022

@zldrobit Thanks for the response!

@wiky231

wiky231 commented May 5, 2022

It worked perfectly, thanks for the help !!!

@chainyo Hi, can you show me what amendments you made to the code to perform inference from imported phone images/videos? Thanks!

@thaingoc2604

@zldrobit
I am using the Android source at this link: https://github.com/zldrobit/yolov5/tree/tf-android
Instead of detecting objects with the rear camera, I want to upload videos from the phone's storage to the application and then detect the objects in those videos. But I still can't find a way to do it; please suggest how. Thanks

@zldrobit
Contributor Author

zldrobit commented May 8, 2022

@thaingoc2604 I made that repo just to demonstrate YOLOv5 TFLite deployment. Maybe you should consult an Android expert about the video input/output part; I am not that familiar with Android.

@thaingoc2604

@thaingoc2604 I made that repo just to demonstrate YOLOv5 TFLite deployment. Maybe you should consult an Android expert about the video input/output part; I am not that familiar with Android.

Thank you

@Kimsoohyeun

KakaoTalk_20220514_163509728
KakaoTalk_20220514_163509728_01

I tried it like you said and finally got best-fp16.tflite.

I even deployed it on mobile, but it behaves like the pictures above.

What do I need to fix?

@zldrobit
Contributor Author

@Kimsoohyeun Please share some information for us to investigate this problem:

  • What version of YOLOv5 are you using?
  • Did you check the tflite model with python detect.py --weights best-fp16.tflite?
  • Did you edit the label text file according to your custom dataset?

@pcanas

pcanas commented May 16, 2022

@zldrobit Hi. I was using your solution for our app, thanks so much for your work!
I was wondering if you could explain the difference between the YoloV5ClassifierDetect and YoloV5Classifier classes. I saw that the former is not used anywhere in the example app, and that it uses masks and anchors. Would you be able to elaborate on their differences and how the masks and anchors are used? Thank you so much!

@zldrobit
Contributor Author

@pcanas The YoloV5ClassifierDetect class loads a YOLOv5 TFLite model and reconstructs the Detect layer in Java code. The anchors are used in that reconstruction. The YoloV5Classifier class loads an entire YOLOv5 model including the Detect layer, and is the one used in the current branch. The YoloV5ClassifierDetect class was for older versions of TensorFlow and should be deprecated. I should delete the YoloV5ClassifierDetect class from the repo so there won't be any confusion.

@wiky231

wiky231 commented Jun 29, 2022

Hi @zldrobit. Is the code for building the Android app based on the official TensorFlow Android object detection example at this link?
[https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android]

@zldrobit
Contributor Author

@axlecky The yolov5 android repo is partly based on the yolov4 tflite repo.

@Zengyf-CVer
Contributor

@zldrobit
I would like to ask: if --weights is not yolov5s.pt but a weight file I customized, such as widerface.pt, should the --data parameter also be changed to a custom yaml file?

python export.py --weights yolov5s.pt --include tflite --int8 --img 320 --data data/coco128.yaml

@zldrobit
Contributor Author

zldrobit commented Jul 17, 2022

@Zengyf-CVer

@zldrobit

I would like to ask, if --weights is not yolov5s.pt, but a weight file I customized, such as widerface.pt, then the latter --data parameter should also be changed to a custom one yaml file?

python export.py --weights yolov5s.pt --include tflite --int8 --img 320 --data data/coco128.yaml

The --data option is not used in TensorFlow/TFLite export anymore, so you can just ignore it.

@olliestanley

Hi @zldrobit,

Have you tried training (finetuning) a YOLOv5 model in Keras after this conversion? I noticed there is specific code in the TFDetect layer for training. However, when comparing a PyTorch YOLOv5 model to a Keras one in training mode, I found the output shapes were slightly different (Keras version was flatter). I modified TFDetect to reshape the output to the same shape as PyTorch training outputs. However, I then found that the outputs themselves were different. With three detectors, I found that the first detector produced roughly identical outputs to the PyTorch model, but the other two produced very different outputs. Is this something you have experience with?

Thanks

BjarneKuehl pushed a commit to fhkiel-mlaip/yolov5 that referenced this pull request Aug 26, 2022
* Add models/tf.py for TensorFlow and TFLite export

* Set auto=False for int8 calibration

* Update requirements.txt for TensorFlow and TFLite export

* Read anchors directly from PyTorch weights

* Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export

* Remove check_anchor_order, check_file, set_logging from import

* Reformat code and optimize imports

* Autodownload model and check cfg

* update --source path, img-size to 320, single output

* Adjust representative_dataset

* Put representative dataset in tfl_int8 block

* detect.py TF inference

* weights to string

* weights to string

* cleanup tf.py

* Add --dynamic-batch-size

* Add xywh normalization to reduce calibration error

* Update requirements.txt

TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error

* Fix imports

Move C3 from models.experimental to models.common


* implement C3() and SiLU()

* Fix reshape dim to support dynamic batching

* Add epsilon argument in tf_BN, which is different between TF and PT

* Set stride to None if not using PyTorch, and do not warmup without PyTorch

* Add list support in check_img_size()

* Add list input support in detect.py

* sys.path.append('./') to run from yolov5/

* Add int8 quantization support for TensorFlow 2.5

* Add get_coco128.sh

* Remove --no-tfl-detect in models/tf.py (Use tf-android-tfl-detect branch for EdgeTPU)

* Update requirements.txt

* Replace torch.load() with attempt_load()

* Update requirements.txt

* Add --tf-raw-resize to set half_pixel_centers=False

* Add --agnostic-nms for TF class-agnostic NMS

* Cleanup after merge

* Cleanup2 after merge

* Cleanup3 after merge

* Add tf.py docstring with credit and usage

* pb saved_model and tflite use only one model in detect.py

* Add use cases in docstring of tf.py

* Remove redundant `stride` definition

* Remove keras direct import

* Fix `check_requirements(('tensorflow>=2.4.1',))`

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
BjarneKuehl pushed a commit to fhkiel-mlaip/yolov5 that referenced this pull request Aug 26, 2022
* Auto TFLite uint8 detection

This PR automatically determines if TFLite models are uint8 quantized rather than accepting a manual argument.

The quantization determination is based on @zldrobit comment ultralytics#1127 (comment)

* Cleanup
@zldrobit
Contributor Author

@olliestanley

Have you tried training (finetuning) a YOLOv5 model in Keras after this conversion?

I haven't trained YOLOv5 models in Keras, but I considered doing so before. The code in the repo from v6.1 onward (inclusive) exports TF YOLOv5 models with untrainable parameters.

I modified TFDetect to reshape the output to the same shape as PyTorch training outputs. However, I then found that the outputs themselves were different. With three detectors, I found that the first detector produced roughly identical outputs to the PyTorch model, but the other two produced very different outputs. Is this something you have experience with?

Other than TFDetect, exported TF models include additional normalization for the bboxes' xywh. These xywh values are denormalized during inference. Is this what you're looking for?
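
For illustration, a minimal sketch of that denormalization step (shapes and sizes here are placeholders):

import numpy as np

h, w = 640, 640  # export image size (height, width)
pred = np.random.rand(1, 25200, 85).astype(np.float32)  # placeholder TF/TFLite output with xywh normalized to [0, 1]
pred[..., :4] *= [w, h, w, h]  # scale xywh back to pixels before NMS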

@Deep55033

@zldrobit I want to use the new instance segmentation model introduced by YOLOv5, how should I modify the code?

@zldrobit
Contributor Author

@Deep55033
You could export YOLOv5 segmentation models with

python export.py --weights yolov5m-seg.pt --include tflite --img 640

and then verify it with

python segment/predict.py --weights yolov5m-seg-fp16.tflite --source data/images/bus.jpg

If you want to use the YOLOv5 demo repo, I just have some preliminary ideas, and the material below is a good starting point:

The yolov5 demo repo does not include any functions related to masks, so you have to add code for drawing the segmentation masks.

SecretStar112 added a commit to SecretStar112/yolov5 that referenced this pull request May 24, 2023
* Auto TFLite uint8 detection

This PR automatically determines if TFLite models are uint8 quantized rather than accepting a manual argument.

The quantization determination is based on @zldrobit comment ultralytics/yolov5#1127 (comment)

* Cleanup
@Aandre99

@Sergey-sib I guess you mean the best.pt is the full model. Actually, it is stored in fp16 precision. Despite the difference of PyTorch and TFLite, the size of best.pt is approximately equal to the size of best-fp16.tflite, .

@zldrobit, shouldn't model-fp16.tflite be half the size of the original .pt model? According to "How to know if model is half or full precision", PyTorch's default floating point precision is fp32.

@zldrobit
Contributor Author

zldrobit commented Nov 1, 2023

@Aandre99 Despite PyTorch's default precision (fp32), the YOLOv5 model is saved in half precision once training completes. In strip_optimizer(), the model is converted to half precision and saved to disk, as in the following code:

yolov5/utils/general.py

Lines 1002 to 1015 in 915bbf2

def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
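
A quick way to check a checkpoint's precision (a minimal sketch):

import torch

# Run from the yolov5 repo root so the pickled model class can be resolved
ckpt = torch.load('best.pt', map_location='cpu')
print(next(ckpt['model'].parameters()).dtype)  # torch.float16 after strip_optimizer(), torch.float32 otherwise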

@Aandre99

Aandre99 commented Nov 1, 2023

@Aandre99 Despite PyTorch's default precision (fp32), the YOLOv5 model is saved in half precision once training completes. In strip_optimizer(), the model is converted to half precision and saved to disk, as shown in the strip_optimizer() code above.


@zldrobit thanks, I finally understand now!

@JonSnow4243

JonSnow4243 commented Feb 13, 2024

Issue with loading a TFLite model on a Jetson Nano 2GB developer kit (without GPU)

We are currently facing an issue while attempting to load a TFLite model on our Jetson Nano device (running Jetpack 4.6.1 and TensorFlow 2.4.1) using Torch. Previously, we successfully loaded a .pt model using the following code:

model = torch.hub.load('ultralytics/yolov5:v6.0', 'yolov5s')

However, when we tried to load the TFLite version of the model with the following code:

model = torch.hub.load('ultralytics/yolov5:v6.0', 'yolov5s.tflite', force_reload=True)

We encountered the following error:

RuntimeError: Cannot find callable yolov5s.tflite in hubconf

We understand that this error might be related to compatibility issues, and we opted for version v6.0 of the YOLOv5 model for Jetson compatibility.

Could you please provide guidance or insights on how to resolve this issue or any alternative approaches to load the TFLite model successfully on the Jetson Nano?

Thank you for your assistance.

@zldrobit
Contributor Author

zldrobit commented Feb 26, 2024

@JonSnow4243 TFLite models cannot be loaded with torch.hub.load. It is suggested to use the TFLite Python interpreter (tflite_runtime or tf.lite) to load TFLite models. You could refer to the code from this repo:

yolov5/models/common.py

Lines 546 to 573 in 41603da

elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
    try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
        from tflite_runtime.interpreter import Interpreter, load_delegate
    except ImportError:
        import tensorflow as tf
        Interpreter, load_delegate = (
            tf.lite.Interpreter,
            tf.lite.experimental.load_delegate,
        )
    if edgetpu:  # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
        LOGGER.info(f"Loading {w} for TensorFlow Lite Edge TPU inference...")
        delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[
            platform.system()
        ]
        interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
    else:  # TFLite
        LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")
        interpreter = Interpreter(model_path=w)  # load TFLite model
    interpreter.allocate_tensors()  # allocate
    input_details = interpreter.get_input_details()  # inputs
    output_details = interpreter.get_output_details()  # outputs
    # load metadata
    with contextlib.suppress(zipfile.BadZipFile):
        with zipfile.ZipFile(w, "r") as model:
            meta_file = model.namelist()[0]
            meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
            stride, names = int(meta["stride"]), meta["names"]
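
For a quick standalone check outside the repo, a minimal sketch (assuming an fp16 export named yolov5s-fp16.tflite with a 640x640 input):

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='yolov5s-fp16.tflite')
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
img = np.zeros((1, 640, 640, 3), dtype=np.float32)  # NHWC dummy input, values expected in [0, 1]
interpreter.set_tensor(inp['index'], img)
interpreter.invoke()
pred = interpreter.get_tensor(out['index'])
print(pred.shape)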
