separating rag metrics and adding bge metrics - catalog non backward compatible change #1104

Merged
14 commits, Aug 5, 2024
23 changes: 18 additions & 5 deletions examples/evaluate_rag.py
@@ -31,14 +31,27 @@
result, _ = evaluate(
df,
metric_names=[
"metrics.rag.mrr",
"metrics.rag.map",
"metrics.rag.answer_correctness",
# default implementations
"metrics.rag.context_correctness",
"metrics.rag.context_relevance",
"metrics.rag.context_perplexity",
"metrics.rag.faithfulness",
"metrics.rag.answer_reward",
"metrics.rag.context_correctness",
"metrics.rag.context_perplexity",
"metrics.rag.answer_correctness",
# specific implementations
"metrics.rag.context_correctness.mrr",
"metrics.rag.context_correctness.map",
"metrics.rag.context_relevance.perplexity_flan_t5_small",
"metrics.rag.context_relevance.sentence_bert_bge",
"metrics.rag.context_relevance.sentence_bert_mini_lm",
"metrics.rag.faithfulness.token_k_precision",
"metrics.rag.faithfulness.bert_score_k_precision",
"metrics.rag.faithfulness.sentence_bert_bge",
"metrics.rag.faithfulness.sentence_bert_mini_lm",
"metrics.rag.answer_correctness.token_recall",
"metrics.rag.answer_correctness.bert_score_recall",
"metrics.rag.answer_correctness.sentence_bert_bge",
"metrics.rag.answer_correctness.sentence_bert_mini_lm",
],
)
result.round(2).to_csv("dataset_out.csv")
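
The example above exercises the new two-level naming scheme: a bare family name such as metrics.rag.answer_correctness resolves to its default implementation, while a suffixed name such as metrics.rag.answer_correctness.sentence_bert_bge pins a specific backend. A minimal, self-contained sketch of the same call pattern, assuming df is a pandas DataFrame whose columns match the fields the pipelines copy from (question, answer, contexts, ground_truths) and assuming evaluate comes from unitxt.eval_utils (the example's import lines are outside the shown hunk):

import pandas as pd

from unitxt.eval_utils import evaluate  # assumed import path; not visible in this diff

# Toy single-row RAG dataset; column names follow the Copy(...) steps in this PR's metric pipelines.
df = pd.DataFrame(
    [
        {
            "question": "What is the capital of France?",
            "answer": "Paris is the capital of France.",
            "contexts": ["Paris is the capital and most populous city of France."],
            "ground_truths": ["Paris"],
        }
    ]
)

result, _ = evaluate(
    df,
    metric_names=[
        "metrics.rag.answer_correctness",  # default implementation (token_recall)
        "metrics.rag.answer_correctness.sentence_bert_bge",  # explicitly pinned BGE backend
    ],
)
print(result.round(2))
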
68 changes: 6 additions & 62 deletions prepare/metrics/rag.py
@@ -30,6 +30,12 @@
"metrics.sentence_bert.mpnet_base_v2": SentenceBert(
model_name="sentence-transformers/all-mpnet-base-v2"
),
"metrics.sentence_bert.minilm_l12_v2": SentenceBert(
model_name="sentence-transformers/all-MiniLM-L12-v2"
),
"metrics.sentence_bert.bge_large_en_1_5": SentenceBert(
model_name="BAAI/bge-large-en-v1.5"
),
"metrics.reward.deberta_v3_large_v2": Reward(
model_name="OpenAssistant/reward-model-deberta-v3-large-v2"
),
@@ -316,68 +322,6 @@
# metrics.rag.correctness
# metrics.rag.recall
# metrics.rag.bert_recall
context_relevance = MetricPipeline(
main_score="perplexity",
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="question", to_field="prediction"),
],
metric="metrics.perplexity_q.flan_t5_small",
)
add_to_catalog(context_relevance, "metrics.rag.context_relevance", overwrite=True)
context_perplexity = MetricPipeline(
main_score="score",
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="question", to_field="prediction"),
],
metric="metrics.perplexity_q.flan_t5_small",
postpreprocess_steps=[
Copy(field="score/instance/reference_scores", to_field="score/instance/score")
],
)
add_to_catalog(context_perplexity, "metrics.rag.context_perplexity", overwrite=True)
for new_catalog_name, base_catalog_name in [
("metrics.rag.faithfulness", "metrics.token_overlap"),
("metrics.rag.k_precision", "metrics.token_overlap"),
("metrics.rag.bert_k_precision", "metrics.bert_score.deberta_large_mnli"),
(
"metrics.rag.bert_k_precision_ml",
"metrics.bert_score.deberta_v3_base_mnli_xnli_ml",
),
]:
metric = MetricPipeline(
main_score="precision",
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="answer", to_field="prediction"),
],
metric=base_catalog_name,
)
add_to_catalog(metric, new_catalog_name, overwrite=True)

answer_reward = MetricPipeline(
main_score="score",
preprocess_steps=[
Copy(field="question", to_field="references"),
Copy(field="answer", to_field="prediction"),
# This metric compares the answer (as the prediction) to the question (as the reference).
# We have to wrap the question in a list (otherwise it would remain a string),
# because references are expected to be lists
ListFieldValues(fields=["references"], to_field="references"),
],
metric="metrics.reward.deberta_v3_large_v2",
)
add_to_catalog(answer_reward, "metrics.rag.answer_reward", overwrite=True)
answer_inference = MetricPipeline(
main_score="perplexity",
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="answer", to_field="prediction"),
],
metric="metrics.perplexity_nli.t5_nli_mixture",
)
add_to_catalog(answer_inference, "metrics.rag.answer_inference", overwrite=True)

for axis, base_metric, main_score in [
("correctness", "token_overlap", "f1"),
25 changes: 18 additions & 7 deletions prepare/metrics/rag_answer_correctness.py
@@ -33,21 +33,32 @@ def test_answer_correctness(task_data, catalog_name, global_target, instance_tar
)


for new_catalog_name, base_catalog_name in [
("metrics.rag.answer_correctness", "metrics.token_overlap"),
("metrics.rag.recall", "metrics.token_overlap"),
("metrics.rag.bert_recall", "metrics.bert_score.deberta_large_mnli"),
("metrics.rag.bert_recall_ml", "metrics.bert_score.deberta_v3_base_mnli_xnli_ml"),
base = "metrics.rag.answer_correctness"
default = "token_recall"

for new_catalog_name, base_catalog_name, main_score in [
("token_recall", "metrics.token_overlap", "recall"),
("bert_score_recall", "metrics.bert_score.deberta_large_mnli", "recall"),
(
"bert_score_recall_ml",
"metrics.bert_score.deberta_v3_base_mnli_xnli_ml",
"recall",
),
("sentence_bert_bge", "metrics.sentence_bert.bge_large_en_1_5", "score"),
("sentence_bert_mini_lm", "metrics.sentence_bert.bge_large_en_1_5", "score"),
]:
metric = MetricPipeline(
main_score="recall",
main_score=main_score,
preprocess_steps=[
Copy(field="ground_truths", to_field="references"),
Copy(field="answer", to_field="prediction"),
],
metric=base_catalog_name,
)
add_to_catalog(metric, new_catalog_name, overwrite=True)
add_to_catalog(metric, f"{base}.{new_catalog_name}", overwrite=True)

if new_catalog_name == default:
add_to_catalog(metric, base, overwrite=True)

if __name__ == "__main__":
# don't use "A" as a token because it is considered an article and removed by the token overlap
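
For existing callers, the practical effect of this file is a rename of the answer-correctness catalog entries. A hedged summary of the old-to-new mapping implied by the removed and added loops, matched by base metric (the dict name is illustrative only; whether the old JSON catalog entries are also deleted is not visible in this diff):

# Old catalog name (removed loop) -> new catalog name (added loop)
ANSWER_CORRECTNESS_RENAMES = {
    "metrics.rag.answer_correctness": "metrics.rag.answer_correctness",  # still valid; now an alias of the token_recall pipeline
    "metrics.rag.recall": "metrics.rag.answer_correctness.token_recall",
    "metrics.rag.bert_recall": "metrics.rag.answer_correctness.bert_score_recall",
    "metrics.rag.bert_recall_ml": "metrics.rag.answer_correctness.bert_score_recall_ml",
}
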
28 changes: 28 additions & 0 deletions prepare/metrics/rag_answer_relevance.py
@@ -0,0 +1,28 @@
from unitxt import add_to_catalog
from unitxt.metrics import (
MetricPipeline,
)
from unitxt.operators import Copy, ListFieldValues

answer_reward = MetricPipeline(
main_score="score",
preprocess_steps=[
Copy(field="question", to_field="references"),
Copy(field="answer", to_field="prediction"),
# This metric compares the answer (as the prediction) to the question (as the reference).
# We have to wrap the question in a list (otherwise it would remain a string),
# because references are expected to be lists
ListFieldValues(fields=["references"], to_field="references"),
],
metric="metrics.reward.deberta_v3_large_v2",
)
add_to_catalog(answer_reward, "metrics.rag.answer_reward", overwrite=True)
answer_inference = MetricPipeline(
main_score="perplexity",
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="answer", to_field="prediction"),
],
metric="metrics.perplexity_nli.t5_nli_mixture",
)
add_to_catalog(answer_inference, "metrics.rag.answer_inference", overwrite=True)
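
The ListFieldValues step in answer_reward matters because the reward metric receives a single question string while the metric contract expects references to be a list. Roughly, for one instance the preprocessing behaves as sketched here (field names come from the code above; the values are illustrative):

# Illustrative instance before preprocessing.
instance = {"question": "Who wrote Hamlet?", "answer": "Shakespeare wrote Hamlet."}

# After Copy(field="question", to_field="references") and Copy(field="answer", to_field="prediction"):
#   instance["references"] == "Who wrote Hamlet?"          (still a plain string)
#   instance["prediction"] == "Shakespeare wrote Hamlet."
#
# After ListFieldValues(fields=["references"], to_field="references"):
#   instance["references"] == ["Who wrote Hamlet?"]        (a one-element list, as the metric expects)
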
21 changes: 13 additions & 8 deletions prepare/metrics/rag_context_correctness.py
@@ -3,20 +3,25 @@
from unitxt.metrics import MetricPipeline
from unitxt.operators import Copy

for metric_name, catalog_name in [
("map", "metrics.rag.map"),
("mrr", "metrics.rag.mrr"),
("mrr", "metrics.rag.context_correctness"),
("retrieval_at_k", "metrics.rag.retrieval_at_k"),
base = "metrics.rag.context_correctness"
default = "mrr"

for new_catalog_name, base_catalog_name, main_score in [
("mrr", "metrics.mrr", "score"),
("map", "metrics.map", "score"),
("retrieval_at_k", "metrics.retrieval_at_k", "score"),
]:
metric = MetricPipeline(
main_score="score",
main_score=main_score,
preprocess_steps=[
Copy(field="context_ids", to_field="prediction"),
Wrap(
field="ground_truths_context_ids", inside="list", to_field="references"
),
],
metric=f"metrics.{metric_name}",
metric=base_catalog_name,
)
add_to_catalog(metric, catalog_name, overwrite=True)
add_to_catalog(metric, f"{base}.{new_catalog_name}", overwrite=True)

if new_catalog_name == default:
add_to_catalog(metric, base, overwrite=True)
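
This file is where the catalog break named in the PR title is most visible: the retrieval metrics lose their old top-level names and move under the context_correctness family, with mrr doubling as the family default. A hedged before/after summary of the registrations (the dict name is illustrative only):

# Old catalog name (removed loop) -> new catalog name (added loop)
CONTEXT_CORRECTNESS_RENAMES = {
    "metrics.rag.mrr": "metrics.rag.context_correctness.mrr",
    "metrics.rag.map": "metrics.rag.context_correctness.map",
    "metrics.rag.retrieval_at_k": "metrics.rag.context_correctness.retrieval_at_k",
}
# "metrics.rag.context_correctness" itself stays valid and, as before, resolves to the mrr pipeline.
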
39 changes: 39 additions & 0 deletions prepare/metrics/rag_context_relevance.py
@@ -0,0 +1,39 @@
from unitxt import add_to_catalog
from unitxt.metrics import (
MetricPipeline,
)
from unitxt.operators import Copy

base = "metrics.rag.context_relevance"
default = "perplexity_flan_t5_small"

for new_catalog_name, base_catalog_name, main_score in [
("perplexity_flan_t5_small", "metrics.perplexity_q.flan_t5_small", "perplexity"),
("sentence_bert_bge", "metrics.sentence_bert.bge_large_en_1_5", "score"),
("sentence_bert_mini_lm", "metrics.sentence_bert.bge_large_en_1_5", "score"),
]:
metric = MetricPipeline(
main_score=main_score,
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="question", to_field="prediction"),
],
metric=base_catalog_name,
)
add_to_catalog(metric, f"{base}.{new_catalog_name}", overwrite=True)

if new_catalog_name == default:
add_to_catalog(metric, base, overwrite=True)

context_perplexity = MetricPipeline(
main_score="score",
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="question", to_field="prediction"),
],
metric="metrics.perplexity_q.flan_t5_small",
postpreprocess_steps=[
Copy(field="score/instance/reference_scores", to_field="score/instance/score")
],
)
add_to_catalog(context_perplexity, "metrics.rag.context_perplexity", overwrite=True)
32 changes: 32 additions & 0 deletions prepare/metrics/rag_faithfulness.py
@@ -0,0 +1,32 @@
from unitxt import add_to_catalog
from unitxt.metrics import (
MetricPipeline,
)
from unitxt.operators import Copy

base = "metrics.rag.faithfulness"
default = "token_k_precision"

for new_catalog_name, base_catalog_name, main_score in [
("token_k_precision", "metrics.token_overlap", "precision"),
("bert_score_k_precision", "metrics.bert_score.deberta_large_mnli", "precision"),
(
"bert_score_k_precision_ml",
"metrics.bert_score.deberta_v3_base_mnli_xnli_ml",
"precision",
),
("sentence_bert_bge", "metrics.sentence_bert.bge_large_en_1_5", "score"),
("sentence_bert_mini_lm", "metrics.sentence_bert.bge_large_en_1_5", "score"),
]:
metric = MetricPipeline(
main_score=main_score,
preprocess_steps=[
Copy(field="contexts", to_field="references"),
Copy(field="answer", to_field="prediction"),
],
metric=base_catalog_name,
)
add_to_catalog(metric, f"{base}.{new_catalog_name}", overwrite=True)

if new_catalog_name == default:
add_to_catalog(metric, base, overwrite=True)
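
Across these prepare scripts the same MetricPipeline pattern repeats; only the field wiring and the wrapped base metric change. As a compact reference derived from the Copy steps above (a reading aid for the diffs, not an API of its own):

# family -> (prediction source field, references source field), as wired by the preprocess steps
RAG_FIELD_WIRING = {
    "metrics.rag.context_correctness": ("context_ids", "ground_truths_context_ids"),
    "metrics.rag.context_relevance": ("question", "contexts"),
    "metrics.rag.faithfulness": ("answer", "contexts"),
    "metrics.rag.answer_correctness": ("answer", "ground_truths"),
    "metrics.rag.answer_reward": ("answer", "question"),
}
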
@@ -0,0 +1,17 @@
{
"__type__": "metric_pipeline",
"main_score": "score",
"preprocess_steps": [
{
"__type__": "copy",
"field": "ground_truths",
"to_field": "references"
},
{
"__type__": "copy",
"field": "answer",
"to_field": "prediction"
}
],
"metric": "metrics.sentence_bert.bge_large_en_1_5"
}
@@ -0,0 +1,17 @@
{
"__type__": "metric_pipeline",
"main_score": "score",
"preprocess_steps": [
{
"__type__": "copy",
"field": "ground_truths",
"to_field": "references"
},
{
"__type__": "copy",
"field": "answer",
"to_field": "prediction"
}
],
"metric": "metrics.sentence_bert.bge_large_en_1_5"
}
@@ -0,0 +1,17 @@
{
"__type__": "metric_pipeline",
"main_score": "perplexity",
"preprocess_steps": [
{
"__type__": "copy",
"field": "contexts",
"to_field": "references"
},
{
"__type__": "copy",
"field": "question",
"to_field": "prediction"
}
],
"metric": "metrics.perplexity_q.flan_t5_small"
}
@@ -0,0 +1,17 @@
{
"__type__": "metric_pipeline",
"main_score": "score",
"preprocess_steps": [
{
"__type__": "copy",
"field": "contexts",
"to_field": "references"
},
{
"__type__": "copy",
"field": "question",
"to_field": "prediction"
}
],
"metric": "metrics.sentence_bert.bge_large_en_1_5"
}
@@ -0,0 +1,17 @@
{
"__type__": "metric_pipeline",
"main_score": "score",
"preprocess_steps": [
{
"__type__": "copy",
"field": "contexts",
"to_field": "references"
},
{
"__type__": "copy",
"field": "question",
"to_field": "prediction"
}
],
"metric": "metrics.sentence_bert.bge_large_en_1_5"
}
17 changes: 17 additions & 0 deletions src/unitxt/catalog/metrics/rag/faithfulness/sentence_bert_bge.json
@@ -0,0 +1,17 @@
{
"__type__": "metric_pipeline",
"main_score": "score",
"preprocess_steps": [
{
"__type__": "copy",
"field": "contexts",
"to_field": "references"
},
{
"__type__": "copy",
"field": "answer",
"to_field": "prediction"
}
],
"metric": "metrics.sentence_bert.bge_large_en_1_5"
}