diff --git a/efficientdet/keras/efficientdet_keras.py b/efficientdet/keras/efficientdet_keras.py index e22795726..118ff3ae2 100644 --- a/efficientdet/keras/efficientdet_keras.py +++ b/efficientdet/keras/efficientdet_keras.py @@ -87,18 +87,14 @@ def fuse_features(self, nodes): dtype = nodes[0].dtype if self.weight_method == 'attn': - edge_weights = [] - for var in self.vars: - var = tf.cast(var, dtype=dtype) - edge_weights.append(var) + edge_weights = [tf.cast(var, dtype=dtype) + for var in self.vars] normalized_weights = tf.nn.softmax(tf.stack(edge_weights)) nodes = tf.stack(nodes, axis=-1) new_node = tf.reduce_sum(nodes * normalized_weights, -1) elif self.weight_method == 'fastattn': - edge_weights = [] - for var in self.vars: - var = tf.cast(var, dtype=dtype) - edge_weights.append(var) + edge_weights = [tf.nn.relu(tf.cast(var, dtype=dtype)) + for var in self.vars] weights_sum = add_n(edge_weights) nodes = [ nodes[i] * edge_weights[i] / (weights_sum + 0.0001) @@ -106,19 +102,14 @@ def fuse_features(self, nodes): ] new_node = add_n(nodes) elif self.weight_method == 'channel_attn': - edge_weights = [] - for var in self.vars: - var = tf.cast(var, dtype=dtype) - edge_weights.append(var) + edge_weights = [tf.cast(var, dtype=dtype) + for var in self.vars] normalized_weights = tf.nn.softmax(tf.stack(edge_weights, -1), axis=-1) nodes = tf.stack(nodes, axis=-1) new_node = tf.reduce_sum(nodes * normalized_weights, -1) elif self.weight_method == 'channel_fastattn': - edge_weights = [] - for var in self.vars: - var = tf.cast(var, dtype=dtype) - edge_weights.append(var) - + edge_weights = [tf.nn.relu(tf.cast(var, dtype=dtype)) + for var in self.vars] weights_sum = add_n(edge_weights) nodes = [ nodes[i] * edge_weights[i] / (weights_sum + 0.0001)