From 4b4db08dac1b7e93e6e50a11b1c7a926982a2abb Mon Sep 17 00:00:00 2001 From: Louis Combaldieu Date: Fri, 25 Feb 2022 10:56:37 +0100 Subject: [PATCH] Fix export for 1-channel images (#6780) Export failed for 1-channel input shape, 1-liner fix --- export.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/export.py b/export.py index 15e92a784a50..286df623d252 100644 --- a/export.py +++ b/export.py @@ -260,9 +260,9 @@ def export_saved_model(model, im, file, dynamic, batch_size, ch, *imgsz = list(im.shape) # BCHW tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) - im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow + im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) - inputs = tf.keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) + inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size) outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) keras_model = tf.keras.Model(inputs=inputs, outputs=outputs) keras_model.trainable = False