Skip to content

Commit

Permalink
[feat] changes requested in LSL
Browse files Browse the repository at this point in the history
  • Loading branch information
Lorenzobattistela committed Jul 15, 2023
1 parent e2ef8e6 commit 4e2e251
Showing 1 changed file with 8 additions and 35 deletions.
43 changes: 8 additions & 35 deletions tensorflow_similarity/losses/lifted_structure_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tensorflow_similarity.algebra import build_masks
from tensorflow_similarity.distances import Distance, distance_canonicalizer
from tensorflow_similarity.types import FloatTensor, IntTensor
from typing import Callable
from tensorflow_similarity import losses as tfsim_losses
from .metric_loss import MetricLoss
from .utils import compute_loss, negative_distances, positive_distances

Expand All @@ -35,28 +35,9 @@ def lifted_struct_loss(
distance: Distance,
positive_mining_strategy: str = "hard",
negative_mining_strategy: str = "easy",
soft_margin: bool = False,
margin: float = 1.0,
) -> FloatTensor:
"""Lifted Struct loss computations
Args:
labels: labels associated with the embed
embeddings: Embedded examples.
distance: Which distance function to use to compute the pairwise
distances between embeddings.
positive_mining_strategy: What mining strategy to use to select
embedding from the same class. Defaults to 'hard'.
Available: {'easy', 'hard'}
negative_mining_strategy: What mining strategy to use for select the
embedding from the different class. Defaults to 'easy'.
Available: {'hard', 'semi-hard', 'easy'}
soft_margin: Use a soft margin instead of an explicit one.
margin: Use an explicit value for the margin term.
Returns:
Loss: The loss value for the current batch.
"""
"""Lifted Struct loss computations"""

# Compute pairwise distances
pairwise_distances = distance(embeddings)
Expand All @@ -70,26 +51,21 @@ def lifted_struct_loss(
)

# Get negative distances
negative_dists = negative_distances(
negative_dists, _ = negative_distances(
negative_mining_strategy, pairwise_distances, negative_mask
)

# Reorder pairwise distances and negative mask based on positive indices
reordered_pairwise_distances = tf.gather(pairwise_distances, positive_indices, axis=1)
reordered_negative_mask = tf.gather(negative_mask, positive_indices, axis=1)

# Concatenate pairwise distances and negative masks along axis=1
concatenated_distances = tf.concat([pairwise_distances, reordered_pairwise_distances], axis=1)
concatenated_negative_mask = tf.concat([negative_mask, reordered_negative_mask], axis=1)

# Compute log sum exp with concatenated distances and negative mask
logsumexp_result = tf.reduce_logsumexp(concatenated_distances, axis=1, keepdims=True)
masked_logsumexp = tf.where(concatenated_negative_mask, logsumexp_result, tf.zeros_like(concatenated_distances))
# Compute (margin - neg_dist) logsum_exp values for each row (equation 4 in the paper)
neg_logsumexp = tfsim_losses.utils.logsumexp(margin - reordered_pairwise_distances, reordered_negative_mask)

# Calculate the loss
pairwise_diff = pairwise_distances - masked_logsumexp
j_values = tf.reduce_sum(tf.maximum(0.0, pairwise_diff), axis=1)
loss = tf.reduce_mean(j_values) / (2 * tf.reduce_sum(positive_mask))
j_values = neg_logsumexp + positive_dists

loss = j_values / 2.0

return loss

Expand Down Expand Up @@ -117,7 +93,6 @@ def __init__(
distance: Distance | str = "cosine",
positive_mining_strategy: str = "hard",
negative_mining_strategy: str = "easy",
soft_margin: bool = False,
margin: float = 1.0,
name: str = "LiftedStructLoss",
**kwargs,
Expand All @@ -133,7 +108,6 @@ def __init__(
negative_mining_strategy: What mining strategy to use for select the
embedding from the different class. Defaults to 'easy'.
Available: {'hard', 'semi-hard', 'easy'}
soft_margin: Use a soft margin instead of an explicit one.
margin: Use an explicit value for the margin term.
name: Loss name. Defaults to "LiftedStructLoss".
Expand All @@ -159,7 +133,6 @@ def __init__(
distance=distance,
positive_mining_strategy=positive_mining_strategy,
negative_mining_strategy=negative_mining_strategy,
soft_margin=soft_margin,
margin=margin,
**kwargs,
)

0 comments on commit 4e2e251

Please sign in to comment.