Skip to content

Commit

Permalink
[tests] starting to write testing for LSL
Browse files Browse the repository at this point in the history
  • Loading branch information
Lorenzobattistela committed Jul 15, 2023
1 parent 2e5d8ba commit e2ef8e6
Showing 1 changed file with 32 additions and 6 deletions.
38 changes: 32 additions & 6 deletions tests/losses/test_lifted_structure_loss.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,42 @@
import tensorflow as tf
from absl.testing import parameterized
from tensorflow.keras.losses import Reduction
from tensorflow.python.framework import combinations
from tensorflow.keras.losses import Reduction
from tensorflow_similarity import losses
from . import utils

@combinations.generate(combinations.combine(mode=["graph", "eager"]))
class LiftedStructureLossTest(tf.test.TestCase, parameterized.TestCase):
class TestLiftedStructLoss(tf.test.TestCase, parameterized.TestCase):
def test_config(self):
lsl_obj = losses.LiftedStructLoss( name="lifted_struct_loss", distance="cosine"
lifted_obj = losses.LiftedStructLoss(
reduction=Reduction.SUM,
name="lifted_loss",
distance="cosine",
)
self.assertEqual(lifted_obj.distance.name, "cosine")
self.assertEqual(lifted_obj.name, "lifted_loss")
self.assertEqual(lifted_obj.reduction, Reduction.SUM)

@parameterized.named_parameters(
{
"testcase_name": "_soft_margin",
"margin": None,
"expected_loss": 0.31326169,
},
{
"testcase_name": "_fixed_margin",
"margin": 1.1,
"expected_loss": 0.1,
},
)
def test_all_correct_unweighted(self, margin, expected_loss):
lifted_obj = losses.LiftedStructLoss(
reduction=Reduction.SUM_OVER_BATCH_SIZE,
margin=margin,
)
self.assertEqual(lsl_obj.distance.name, "cosine")
self.assertEqual(lsl_obj.name, "lifted_struct_loss")
y_true, y_preds = utils.generate_perfect_test_batch(batch_size=4)
loss = lifted_obj(y_true, y_preds)
self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)



# TODO calculate results by hand before

0 comments on commit e2ef8e6

Please sign in to comment.