diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py
index d03f2287b..1818a4684 100644
--- a/tests/library/test_metrics.py
+++ b/tests/library/test_metrics.py
@@ -8,6 +8,8 @@
     F1Micro,
     F1MicroMultiLabel,
     F1Weighted,
+    KendallTauMetric,
+    RocAuc,
     Rouge,
     Squad,
     TokenOverlap,
@@ -388,6 +390,26 @@ def test_token_overlap(self):
         for target, value in global_targets.items():
             self.assertAlmostEqual(value, outputs[0]["score"]["global"][target])
 
+    def test_roc_auc(self):
+        metric = RocAuc()
+        predictions = ["0.2", "0.8", "1.0"]
+        references = [["1.0"], ["0.0"], ["1.0"]]
+        outputs = apply_metric(
+            metric=metric, predictions=predictions, references=references
+        )
+        global_target = 0.5
+        self.assertAlmostEqual(global_target, outputs[0]["score"]["global"]["score"])
+
+    def test_kendalltau(self):
+        metric = KendallTauMetric()
+        predictions = ["1.0", "2.0", "1.0"]
+        references = [["-1.0"], ["1.0"], ["0.0"]]
+        outputs = apply_metric(
+            metric=metric, predictions=predictions, references=references
+        )
+        global_target = 0.81649658092772
+        self.assertAlmostEqual(global_target, outputs[0]["score"]["global"]["score"])
+
 
 class TestConfidenceIntervals(UnitxtTestCase):
     def test_confidence_interval_off(self):
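
Note: the hard-coded global_target values in the two new tests can be reproduced independently. The snippet below is a quick sanity check, assuming RocAuc and KendallTauMetric are backed by sklearn.metrics.roc_auc_score and scipy.stats.kendalltau respectively; that backend mapping is an assumption, not something this diff confirms.

    # Sanity-check the expected global scores used in test_roc_auc and test_kendalltau.
    # Assumes these are the same backends the metrics use internally (not verified here).
    from scipy.stats import kendalltau
    from sklearn.metrics import roc_auc_score

    # test_roc_auc: of the two positive/negative pairs, only one is ranked correctly -> AUC = 0.5
    print(roc_auc_score([1.0, 0.0, 1.0], [0.2, 0.8, 1.0]))  # 0.5

    # test_kendalltau: tau-b with one tied pair in the predictions -> 2 / sqrt(6) ~= 0.81649658092772
    tau, _ = kendalltau([1.0, 2.0, 1.0], [-1.0, 1.0, 0.0])
    print(tau)  # 0.8164965809277261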