Commit b834ab4

Scope TF imports in DetectMultiBackend() (ultralytics#5792)

* tensorflow or tflite exclusively as interpreter

As per bug report ultralytics#5709, I think there should be only one attempt to assign the interpreter, and it appears tflite is only ever needed for the Edge TPU model case (a sketch of this selection logic follows the diff below).

* Scope imports

* Nested definition line fix

* Update common.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
phodgers and glenn-jocher committed Nov 25, 2021
1 parent ac3daa6 commit b834ab4
Showing 1 changed file with 5 additions and 2 deletions.

models/common.py
@@ -337,19 +337,21 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
             context = model.create_execution_context()
             batch_size = bindings['images'].shape[0]
         else:  # TensorFlow model (TFLite, pb, saved_model)
-            import tensorflow as tf
             if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+                LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
+                import tensorflow as tf
+
                 def wrap_frozen_graph(gd, inputs, outputs):
                     x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
                     return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
                                    tf.nest.map_structure(x.graph.as_graph_element, outputs))
 
-                LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
                 graph_def = tf.Graph().as_graph_def()
                 graph_def.ParseFromString(open(w, 'rb').read())
                 frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
             elif saved_model:
                 LOGGER.info(f'Loading {w} for TensorFlow saved_model inference...')
+                import tensorflow as tf
                 model = tf.keras.models.load_model(w)
             elif tflite:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
                 if 'edgetpu' in w.lower():
@@ -361,6 +363,7 @@ def wrap_frozen_graph(gd, inputs, outputs):
                     interpreter = tfli.Interpreter(model_path=w, experimental_delegates=[tfli.load_delegate(delegate)])
                 else:
                     LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
+                    import tensorflow as tf
                     interpreter = tf.lite.Interpreter(model_path=w)  # load TFLite model
                 interpreter.allocate_tensors()  # allocate
                 input_details = interpreter.get_input_details()  # inputs
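
For context on the first commit message's reasoning: after this change, each model format performs exactly one interpreter assignment, and the heavyweight import runs only on the branch actually taken. The sketch below restates that selection logic in isolation. It is a simplified illustration, not the yolov5 source; the `load_interpreter` helper, the example model filename, and the per-OS Edge TPU delegate names are assumptions for demonstration.

```python
import platform

def load_interpreter(w):
    # Hypothetical helper: pick exactly one TFLite interpreter for model file `w`,
    # importing the relevant package only on the branch that needs it.
    if 'edgetpu' in w.lower():
        import tflite_runtime.interpreter as tfli  # imported only for Edge TPU models
        delegate = {'Linux': 'libedgetpu.so.1',    # assumed per-OS delegate library names
                    'Darwin': 'libedgetpu.1.dylib',
                    'Windows': 'edgetpu.dll'}[platform.system()]
        return tfli.Interpreter(model_path=w,
                                experimental_delegates=[tfli.load_delegate(delegate)])
    else:
        import tensorflow as tf  # imported only for plain TFLite models
        return tf.lite.Interpreter(model_path=w)

interpreter = load_interpreter('yolov5s-fp16.tflite')  # hypothetical model path
interpreter.allocate_tensors()
```

Either way there is a single assignment site per call, which addresses the double-assignment concern raised in ultralytics#5709.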
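A note on why branch-local imports are safe as well as lazy: Python caches every imported module in `sys.modules`, so a scoped `import tensorflow as tf` pays the full import cost at most once per process; re-executing the statement is just a cache lookup plus a local name binding. The self-contained sketch below demonstrates the caching, with the standard-library `json` module standing in for a heavyweight dependency.

```python
import sys
import time

def timed_import():
    # Run a function-local import and return how long the statement took.
    t0 = time.perf_counter()
    import json  # stand-in for a heavyweight library such as tensorflow
    return time.perf_counter() - t0

first = timed_import()        # pays the real import cost if json was not yet loaded
print('json' in sys.modules)  # True: the module is now cached process-wide
second = timed_import()       # near-free: a sys.modules lookup plus a name binding
print(f'first: {first:.6f}s, repeat: {second:.6f}s')
```

The corollary for this commit is that a user on a non-TensorFlow backend never pays the TensorFlow import at all.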
