diff --git a/export.py b/export.py index e146dad42980..cc7a74db9af2 100644 --- a/export.py +++ b/export.py @@ -285,12 +285,12 @@ def export_saved_model(model, if keras: keras_model.save(f, save_format='tf') else: - m = tf.function(lambda x: keras_model(x)) # full model spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype) + m = tf.function(lambda x: keras_model(x)) # full model m = m.get_concrete_function(spec) frozen_func = convert_variables_to_constants_v2(m) tfm = tf.Module() - tfm.__call__ = tf.function(lambda x: frozen_func(x)[0], [spec]) + tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x)[0], [spec]) tfm.__call__(im) tf.saved_model.save(tfm, f,