diff --git a/export.py b/export.py
index 8a483ae878b9..6a8c4f6f94a0 100644
--- a/export.py
+++ b/export.py
@@ -247,11 +247,11 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
 
 def export_saved_model(model, im, file, dynamic,
                        tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
-                       conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')):
+                       conf_thres=0.25, keras=False, prefix=colorstr('TensorFlow SavedModel:')):
     # YOLOv5 TensorFlow SavedModel export
     try:
         import tensorflow as tf
-        from tensorflow import keras
+        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
 
         from models.tf import TFDetect, TFModel
 
@@ -262,13 +262,26 @@ def export_saved_model(model, im, file, dynamic,
         tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
         im = tf.zeros((batch_size, *imgsz, 3))  # BHWC order for TensorFlow
         _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
-        inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
+        inputs = tf.keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
         outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
-        keras_model = keras.Model(inputs=inputs, outputs=outputs)
+        keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
         keras_model.trainable = False
         keras_model.summary()
-        keras_model.save(f, save_format='tf')
-
+        if keras:
+            keras_model.save(f, save_format='tf')
+        else:
+            m = tf.function(lambda x: keras_model(x))  # full model
+            spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
+            m = m.get_concrete_function(spec)
+            frozen_func = convert_variables_to_constants_v2(m)
+            tfm = tf.Module()
+            tfm.__call__ = tf.function(lambda x: frozen_func(x), [spec])
+            tfm.__call__(im)
+            tf.saved_model.save(
+                tfm,
+                f,
+                options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if
+                check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
         LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
         return keras_model, f
     except Exception as e:
diff --git a/models/common.py b/models/common.py
index 38b94129e274..8831723ffa25 100644
--- a/models/common.py
+++ b/models/common.py
@@ -359,7 +359,8 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
             if saved_model:  # SavedModel
                 LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
                 import tensorflow as tf
-                model = tf.keras.models.load_model(w)
+                keras = False  # assume TF1 saved_model
+                model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
             elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
                 LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
                 import tensorflow as tf
@@ -431,7 +432,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
         else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
             im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
             if self.saved_model:  # SavedModel
-                y = self.model(im, training=False).numpy()
+                y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
             elif self.pb:  # GraphDef
                 y = self.frozen_func(x=self.tf.constant(im)).numpy()
             elif self.tflite:  # Lite
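
For context, a minimal usage sketch of how the two SavedModel flavors produced by this patch behave at load time, mirroring the DetectMultiBackend changes above. The paths, image size, and batch size are illustrative assumptions, not part of the patch.

    import numpy as np
    import tensorflow as tf

    im = np.zeros((1, 640, 640, 3), dtype=np.float32)  # dummy BHWC input (assumed 640x640, batch 1)

    # keras=True export: a standard Keras SavedModel, callable with training=False
    model = tf.keras.models.load_model('yolov5s_keras_saved_model')  # hypothetical path
    y = model(im, training=False).numpy()

    # keras=False export (default): a frozen-graph tf.Module; calling it returns a list of
    # output tensors, hence the [0] indexing added to DetectMultiBackend.forward()
    model = tf.saved_model.load('yolov5s_saved_model')  # hypothetical path
    y = model(im)[0].numpy()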