diff --git a/prepare/metrics/rag_answer_correctness.py b/prepare/metrics/rag_answer_correctness.py
index b9f7cb712..c9c2ccea5 100644
--- a/prepare/metrics/rag_answer_correctness.py
+++ b/prepare/metrics/rag_answer_correctness.py
@@ -1,37 +1,6 @@
 from unitxt import add_to_catalog
 from unitxt.metrics import MetricPipeline
-from unitxt.operators import Copy, RenameFields
-from unitxt.test_utils.metrics import test_evaluate, test_metric
-
-
-def test_answer_correctness(task_data, catalog_name, global_target, instance_targets):
-    # test the evaluate call
-    test_evaluate(
-        global_target,
-        instance_targets=[
-            {"score": instance["score"]} for instance in instance_targets
-        ],
-        task_data=task_data,
-        metric_name=catalog_name,
-    )
-    # test using the usual metric pipeline
-    test_pipeline = MetricPipeline(
-        main_score="score",
-        preprocess_steps=[
-            RenameFields(field_to_field={"task_data/ground_truths": "ground_truths"}),
-            RenameFields(field_to_field={"task_data/answer": "answer"}),
-        ],
-        metric=f"{catalog_name}",
-    )
-    test_metric(
-        metric=test_pipeline,
-        predictions=[None] * len(instance_targets),
-        references=[[]] * len(instance_targets),
-        instance_targets=instance_targets,
-        global_target=global_target,
-        task_data=task_data,
-    )
-
+from unitxt.operators import Copy
 
 base = "metrics.rag.answer_correctness"
 default = "token_recall"
@@ -59,129 +28,3 @@ def test_answer_correctness(task_data, catalog_name, global_target, instance_tar
 
     if new_catalog_name == default:
         add_to_catalog(metric, base, overwrite=True)
-
-if __name__ == "__main__":
-    # don't use "A" as a token because it is considered an article and removed by the token overlap
-    # metric
-    task_data = [
-        {  # recall is 0.5 for the first ground_truth, 0 for the second ground_truth.
-            # so overall its max(0.5, 0) = 0.5
-            "ground_truths": ["B C", "C"],
-            "answer": "B",
-        },
-        {  # recall is 1/3
-            "ground_truths": ["D E F"],
-            "answer": "B C D",
-        },
-    ]
-
-    recall_instance_targets = [
-        {"f1": 0.67, "precision": 1.0, "recall": 0.5, "score": 0.5, "score_name": "f1"},
-        {
-            "f1": 0.33,
-            "precision": 0.33,
-            "recall": 0.33,
-            "score": 0.33,
-            "score_name": "f1",
-        },
-    ]
-
-    recall_global_target = {
-        "f1": 0.5,
-        "f1_ci_high": 0.67,
-        "f1_ci_low": 0.33,
-        "precision": 0.67,
-        "precision_ci_high": 1.0,
-        "precision_ci_low": 0.33,
-        "recall": 0.42,
-        "recall_ci_high": 0.5,
-        "recall_ci_low": 0.33,
-        "score": 0.42,
-        "score_ci_high": 0.67,
-        "score_ci_low": 0.33,
-        "score_name": "f1",
-    }
-
-    for catalog_name, global_target, instance_targets in [
-        (
-            "metrics.rag.answer_correctness",
-            recall_global_target,
-            recall_instance_targets,
-        ),
-        ("metrics.rag.recall", recall_global_target, recall_instance_targets),
-    ]:
-        test_answer_correctness(
-            task_data, catalog_name, global_target, instance_targets
-        )
-
-    test_answer_correctness(
-        task_data,
-        catalog_name="metrics.rag.bert_recall",
-        global_target={
-            "f1": 0.71,
-            "f1_ci_high": 0.71,
-            "f1_ci_low": 0.71,
-            "precision": 0.74,
-            "precision_ci_high": 0.77,
-            "precision_ci_low": 0.71,
-            "recall": 0.71,
-            "recall_ci_high": 0.71,
-            "recall_ci_low": 0.71,
-            "score": 0.71,
-            "score_ci_high": 0.71,
-            "score_ci_low": 0.71,
-            "score_name": "f1",
-        },
-        instance_targets=[
-            {
-                "f1": 0.71,
-                "precision": 0.77,
-                "recall": 0.71,
-                "score": 0.71,
-                "score_name": "f1",
-            },
-            {
-                "f1": 0.71,
-                "precision": 0.71,
-                "recall": 0.71,
-                "score": 0.71,
-                "score_name": "f1",
-            },
-        ],
-    )
-
-    test_answer_correctness(
-        task_data,
-        catalog_name="metrics.rag.bert_recall_ml",
-        global_target={
-            "f1": 0.86,
-            "f1_ci_high": 0.97,
-            "f1_ci_low": 0.74,
-            "precision": 0.86,
-            "precision_ci_high": 0.97,
-            "precision_ci_low": 0.74,
-            "recall": 0.86,
-            "recall_ci_high": 0.97,
-            "recall_ci_low": 0.74,
-            "score": 0.86,
-            "score_ci_high": 0.97,
-            "score_ci_low": 0.74,
-            "score_name": "f1",
-        },
-        instance_targets=[
-            {
-                "f1": 0.97,
-                "precision": 0.97,
-                "recall": 0.97,
-                "score": 0.97,
-                "score_name": "f1",
-            },
-            {
-                "f1": 0.74,
-                "precision": 0.74,
-                "recall": 0.74,
-                "score": 0.74,
-                "score_name": "f1",
-            },
-        ],
-    )
diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py
index 2b3eea812..bbcf411af 100644
--- a/tests/library/test_metrics.py
+++ b/tests/library/test_metrics.py
@@ -1942,3 +1942,168 @@ def text_context_correctness(self):
             global_target=global_target,
             task_data=task_data,
         )
+
+    @staticmethod
+    def test_answer_correctness(
+        task_data, catalog_name, global_target, instance_targets
+    ):
+        # test the evaluate call
+        test_evaluate(
+            global_target,
+            instance_targets=[
+                {"score": instance["score"]} for instance in instance_targets
+            ],
+            task_data=task_data,
+            metric_name=catalog_name,
+        )
+        # test using the usual metric pipeline
+        test_pipeline = MetricPipeline(
+            main_score="score",
+            preprocess_steps=[
+                RenameFields(
+                    field_to_field={"task_data/ground_truths": "ground_truths"}
+                ),
+                RenameFields(field_to_field={"task_data/answer": "answer"}),
+            ],
+            metric=f"{catalog_name}",
+        )
+        test_metric(
+            metric=test_pipeline,
+            predictions=[None] * len(instance_targets),
+            references=[[]] * len(instance_targets),
+            instance_targets=instance_targets,
+            global_target=global_target,
+            task_data=task_data,
+        )
+
+    def test_answer_correctness_metrics(self):
+        # don't use "A" as a token because it is considered an article and removed by the token overlap
+        # metric
+        task_data = [
+            {  # recall is 0.5 for the first ground_truth, 0 for the second ground_truth.
+                # so overall its max(0.5, 0) = 0.5
+                "ground_truths": ["B C", "C"],
+                "answer": "B",
+            },
+            {  # recall is 1/3
+                "ground_truths": ["D E F"],
+                "answer": "B C D",
+            },
+        ]
+
+        recall_instance_targets = [
+            {
+                "f1": 0.67,
+                "precision": 1.0,
+                "recall": 0.5,
+                "score": 0.5,
+                "score_name": "f1",
+            },
+            {
+                "f1": 0.33,
+                "precision": 0.33,
+                "recall": 0.33,
+                "score": 0.33,
+                "score_name": "f1",
+            },
+        ]
+
+        recall_global_target = {
+            "f1": 0.5,
+            "f1_ci_high": 0.67,
+            "f1_ci_low": 0.33,
+            "precision": 0.67,
+            "precision_ci_high": 1.0,
+            "precision_ci_low": 0.33,
+            "recall": 0.42,
+            "recall_ci_high": 0.5,
+            "recall_ci_low": 0.33,
+            "score": 0.42,
+            "score_ci_high": 0.67,
+            "score_ci_low": 0.33,
+            "score_name": "f1",
+        }
+
+        for catalog_name, global_target, instance_targets in [
+            (
+                "metrics.rag.answer_correctness",
+                recall_global_target,
+                recall_instance_targets,
+            ),
+            ("metrics.rag.recall", recall_global_target, recall_instance_targets),
+        ]:
+            self.test_answer_correctness(
+                task_data, catalog_name, global_target, instance_targets
+            )
+
+        self.test_answer_correctness(
+            task_data,
+            catalog_name="metrics.rag.bert_recall",
+            global_target={
+                "f1": 0.71,
+                "f1_ci_high": 0.71,
+                "f1_ci_low": 0.71,
+                "precision": 0.74,
+                "precision_ci_high": 0.77,
+                "precision_ci_low": 0.71,
+                "recall": 0.71,
+                "recall_ci_high": 0.71,
+                "recall_ci_low": 0.71,
+                "score": 0.71,
+                "score_ci_high": 0.71,
+                "score_ci_low": 0.71,
+                "score_name": "f1",
+            },
+            instance_targets=[
+                {
+                    "f1": 0.71,
+                    "precision": 0.77,
+                    "recall": 0.71,
+                    "score": 0.71,
+                    "score_name": "f1",
+                },
+                {
+                    "f1": 0.71,
+                    "precision": 0.71,
+                    "recall": 0.71,
+                    "score": 0.71,
+                    "score_name": "f1",
+                },
+            ],
+        )
+
+        self.test_answer_correctness(
+            task_data,
+            catalog_name="metrics.rag.bert_recall_ml",
+            global_target={
+                "f1": 0.86,
+                "f1_ci_high": 0.97,
+                "f1_ci_low": 0.74,
+                "precision": 0.86,
+                "precision_ci_high": 0.97,
+                "precision_ci_low": 0.74,
+                "recall": 0.86,
+                "recall_ci_high": 0.97,
+                "recall_ci_low": 0.74,
+                "score": 0.86,
+                "score_ci_high": 0.97,
+                "score_ci_low": 0.74,
+                "score_name": "f1",
+            },
+            instance_targets=[
+                {
+                    "f1": 0.97,
+                    "precision": 0.97,
+                    "recall": 0.97,
+                    "score": 0.97,
+                    "score_name": "f1",
+                },
+                {
+                    "f1": 0.74,
+                    "precision": 0.74,
+                    "recall": 0.74,
+                    "score": 0.74,
+                    "score_name": "f1",
+                },
+            ],
+        )
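For reference, here is a minimal sketch of the token-recall arithmetic that the comments in the relocated test data assert (the per-instance score is the best recall over the available ground truths). The `token_recall` helper below is hypothetical and is not the unitxt implementation; the catalog metric additionally normalizes tokens (e.g. it drops articles such as "A", which is why the test data avoids them), a step this sketch omits because the example tokens need no normalization.

```python
# Hypothetical helper, not the unitxt implementation: whitespace tokenization,
# recall = |ground-truth tokens found in the answer| / |ground-truth tokens|,
# taking the max over the available ground truths per instance.
def token_recall(answer: str, ground_truth: str) -> float:
    answer_tokens = set(answer.split())
    gt_tokens = ground_truth.split()
    if not gt_tokens:
        return 0.0
    return sum(token in answer_tokens for token in gt_tokens) / len(gt_tokens)


examples = [
    (["B C", "C"], "B"),   # max(1/2, 0) = 0.5
    (["D E F"], "B C D"),  # only "D" matches -> 1/3 ~= 0.33
]
for ground_truths, answer in examples:
    best = max(token_recall(answer, gt) for gt in ground_truths)
    print(round(best, 2))  # prints 0.5, then 0.33
```

Running the sketch prints 0.5 and 0.33, matching the `recall` values in `recall_instance_targets`.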