diff --git a/models/tf.py b/models/tf.py index ae58ca738e2e..0520c30a96df 100644 --- a/models/tf.py +++ b/models/tf.py @@ -310,7 +310,7 @@ def call(self, inputs): y = tf.concat([xy, wh, tf.sigmoid(y[..., 4:5 + self.nc]), y[..., 5 + self.nc:]], -1) z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no])) - return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1), x) + return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1),) @staticmethod def _make_grid(nx=20, ny=20):