From 4e2e251056463b7c0e18f480912c3e52c9244119 Mon Sep 17 00:00:00 2001 From: Lorenzobattistela Date: Sat, 15 Jul 2023 19:03:53 -0300 Subject: [PATCH] [feat] changes requested in LSL --- .../losses/lifted_structure_loss.py | 43 ++++--------------- 1 file changed, 8 insertions(+), 35 deletions(-) diff --git a/tensorflow_similarity/losses/lifted_structure_loss.py b/tensorflow_similarity/losses/lifted_structure_loss.py index 0c63b3c6..1ca6b4c2 100644 --- a/tensorflow_similarity/losses/lifted_structure_loss.py +++ b/tensorflow_similarity/losses/lifted_structure_loss.py @@ -24,7 +24,7 @@ from tensorflow_similarity.algebra import build_masks from tensorflow_similarity.distances import Distance, distance_canonicalizer from tensorflow_similarity.types import FloatTensor, IntTensor -from typing import Callable +from tensorflow_similarity import losses as tfsim_losses from .metric_loss import MetricLoss from .utils import compute_loss, negative_distances, positive_distances @@ -35,28 +35,9 @@ def lifted_struct_loss( distance: Distance, positive_mining_strategy: str = "hard", negative_mining_strategy: str = "easy", - soft_margin: bool = False, margin: float = 1.0, ) -> FloatTensor: - """Lifted Struct loss computations - - Args: - labels: labels associated with the embed - embeddings: Embedded examples. - distance: Which distance function to use to compute the pairwise - distances between embeddings. - positive_mining_strategy: What mining strategy to use to select - embedding from the same class. Defaults to 'hard'. - Available: {'easy', 'hard'} - negative_mining_strategy: What mining strategy to use for select the - embedding from the different class. Defaults to 'easy'. - Available: {'hard', 'semi-hard', 'easy'} - soft_margin: Use a soft margin instead of an explicit one. - margin: Use an explicit value for the margin term. - - Returns: - Loss: The loss value for the current batch. - """ + """Lifted Struct loss computations""" # Compute pairwise distances pairwise_distances = distance(embeddings) @@ -70,7 +51,7 @@ def lifted_struct_loss( ) # Get negative distances - negative_dists = negative_distances( + negative_dists, _ = negative_distances( negative_mining_strategy, pairwise_distances, negative_mask ) @@ -78,18 +59,13 @@ def lifted_struct_loss( reordered_pairwise_distances = tf.gather(pairwise_distances, positive_indices, axis=1) reordered_negative_mask = tf.gather(negative_mask, positive_indices, axis=1) - # Concatenate pairwise distances and negative masks along axis=1 - concatenated_distances = tf.concat([pairwise_distances, reordered_pairwise_distances], axis=1) - concatenated_negative_mask = tf.concat([negative_mask, reordered_negative_mask], axis=1) - - # Compute log sum exp with concatenated distances and negative mask - logsumexp_result = tf.reduce_logsumexp(concatenated_distances, axis=1, keepdims=True) - masked_logsumexp = tf.where(concatenated_negative_mask, logsumexp_result, tf.zeros_like(concatenated_distances)) + # Compute (margin - neg_dist) logsum_exp values for each row (equation 4 in the paper) + neg_logsumexp = tfsim_losses.utils.logsumexp(margin - reordered_pairwise_distances, reordered_negative_mask) # Calculate the loss - pairwise_diff = pairwise_distances - masked_logsumexp - j_values = tf.reduce_sum(tf.maximum(0.0, pairwise_diff), axis=1) - loss = tf.reduce_mean(j_values) / (2 * tf.reduce_sum(positive_mask)) + j_values = neg_logsumexp + positive_dists + + loss = j_values / 2.0 return loss @@ -117,7 +93,6 @@ def __init__( distance: Distance | str = "cosine", positive_mining_strategy: str = "hard", negative_mining_strategy: str = "easy", - soft_margin: bool = False, margin: float = 1.0, name: str = "LiftedStructLoss", **kwargs, @@ -133,7 +108,6 @@ def __init__( negative_mining_strategy: What mining strategy to use for select the embedding from the different class. Defaults to 'easy'. Available: {'hard', 'semi-hard', 'easy'} - soft_margin: Use a soft margin instead of an explicit one. margin: Use an explicit value for the margin term. name: Loss name. Defaults to "LiftedStructLoss". @@ -159,7 +133,6 @@ def __init__( distance=distance, positive_mining_strategy=positive_mining_strategy, negative_mining_strategy=negative_mining_strategy, - soft_margin=soft_margin, margin=margin, **kwargs, )