add LlamaIndex faithfulness metric (#971)
* add LlamaIndex faithfulness metric

Signed-off-by: Ariel Gera <ariel.gera1@ibm.com>

* share code between LlamaIndex metrics

Signed-off-by: Ariel Gera <ariel.gera1@ibm.com>

* use existing 'score_prefix' field

Signed-off-by: Ariel Gera <ariel.gera1@ibm.com>

* remove unused field

Signed-off-by: Ariel Gera <ariel.gera1@ibm.com>

---------

Signed-off-by: Ariel Gera <ariel.gera1@ibm.com>
arielge authored and gitMichal committed Jul 15, 2024
1 parent 1483095 commit 643a438
Showing 6 changed files with 143 additions and 94 deletions.
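For orientation, here is a minimal sketch (not part of the commit) of the naming scheme the refactor described above is expected to produce: each metric keeps the shared llama_index_by_<model>_judge main score from the new base class, and the existing score_prefix field distinguishes correctness from faithfulness. The loop only reproduces the string logic visible in the diff; the printed names match the catalog IDs and test targets in the preparation script below.

# Illustrative only: mirrors the naming logic in prepare/metrics/llama_index_metrics.py.
for metric_name in ("correctness", "faithfulness"):
    for model_name in ("gpt-3.5-turbo", "mock"):
        normalized = model_name.replace(".", "_").replace("-", "_")
        catalog_id = f"metrics.rag.{metric_name}.llama_index_by_{normalized}"
        score_name = f"{metric_name}_llama_index_by_{normalized}_judge"
        print(catalog_id, "->", score_name)
# e.g. metrics.rag.correctness.llama_index_by_gpt_3_5_turbo -> correctness_llama_index_by_gpt_3_5_turbo_judge
#      metrics.rag.faithfulness.llama_index_by_mock -> faithfulness_llama_index_by_mock_judge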
58 changes: 0 additions & 58 deletions prepare/metrics/llama_index_correctness.py

This file was deleted.

68 changes: 68 additions & 0 deletions prepare/metrics/llama_index_metrics.py
@@ -0,0 +1,68 @@
from unitxt import add_to_catalog
from unitxt.metrics import LlamaIndexCorrectness, LlamaIndexFaithfulness
from unitxt.test_utils.metrics import test_metric

# Test with mock
model_name = "mock"
model_name_normalized = model_name.replace(".", "_").replace("-", "_")

predictions = ["The right answer"]
references = [["The right answer", "The wrong answer"]]
task_data = [
    {
        "question": "question number 1",
        "contexts": ["context number 1"],
        # "reference_answers": ["The right answer", "The wrong answer"],
    },
]

metric_classes = {
    "correctness": LlamaIndexCorrectness,
    "faithfulness": LlamaIndexFaithfulness,
}

for metric_name, metric_class in metric_classes.items():
    metric = metric_class(model_name=model_name)

    score_name = f"{metric_name}_llama_index_by_{model_name_normalized}_judge"

    instance_targets = [
        {
            "score": 1.0,
            "score_name": score_name,
            score_name: 1.0,
            # "feedback": "The generated answer is fully correct and relevant to the user query, matching the reference answer exactly.",
        }
    ] * len(predictions)

    global_target = {
        "score": 1.0,
        "score_name": score_name,
        score_name: 1.0,
    }

    outputs = test_metric(
        metric=metric,
        predictions=predictions,
        references=references,
        task_data=task_data,
        instance_targets=instance_targets,
        global_target=global_target,
    )

    # GPT model to catalog
    model_names = ["gpt-3.5-turbo", "mock"]
    for model_name in model_names:
        model_name_normalized = model_name.replace(".", "_").replace("-", "_")

        metric = (
            metric_class(model_name=model_name, data_classification_policy=["public"])
            if model_name != "mock"
            else metric_class(model_name=model_name)
        )

        add_to_catalog(
            metric,
            f"metrics.rag.{metric_name}.llama_index_by_{model_name_normalized}",
            overwrite=True,
        )
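Once this script has run, the registered entries (shown as JSON below) can be pulled back from the local catalog. A minimal sketch, assuming unitxt's fetch_artifact helper and its tuple return value — both are assumptions about the library version, not part of this commit:

from unitxt.artifact import fetch_artifact  # assumption: helper exposed at this path in this unitxt version

# Assumed to return the instantiated metric plus catalog bookkeeping; only the metric is used here.
metric, _ = fetch_artifact("metrics.rag.faithfulness.llama_index_by_mock")
print(type(metric).__name__)  # expected: LlamaIndexFaithfulness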
@@ -1,4 +1,7 @@
{
    "__type__": "llama_index_correctness",
-    "model_name": "gpt-3.5-turbo"
+    "model_name": "gpt-3.5-turbo",
+    "data_classification_policy": [
+        "public"
+    ]
}
@@ -0,0 +1,7 @@
{
    "__type__": "llama_index_faithfulness",
    "model_name": "gpt-3.5-turbo",
    "data_classification_policy": [
        "public"
    ]
}
@@ -0,0 +1,4 @@
{
    "__type__": "llama_index_faithfulness",
    "model_name": "mock"
}
95 changes: 60 additions & 35 deletions src/unitxt/metrics.py
@@ -2134,9 +2134,7 @@ def compute(
        return self.pipe(predictions, batch_size=self.batch_size)


-class LlamaIndexCorrectness(InstanceMetric):
-    """LlamaIndex based metric class for evaluating correctness."""
-
+class LlamaIndexLLMMetric(InstanceMetric):
    model_name: str = ""
    main_score: str = ""
    prediction_type: str = "str"
@@ -2151,6 +2149,34 @@ class LlamaIndexCorrectness(InstanceMetric):

    _requirements_list: List[str] = ["llama_index"]

+    def prepare(self):
+        self.model_name_normalized = self.model_name.replace(".", "_").replace("-", "_")
+        self.main_score: str = f"llama_index_by_{self.model_name_normalized}_judge"
+
+        self.reduction_map: Dict[str, List[str]] = {"mean": [self.main_score]}
+
+        if self.model_name in self.openai_models:
+            from llama_index.llms.openai import OpenAI
+
+            self.llm = OpenAI("gpt-3.5-turbo")
+        elif self.model_name in self.mock_models:
+            from llama_index.core.llms.mock import MockLLM
+
+            self.llm = MockLLM(system_prompt="5")  # perfect score
+        else:
+            raise NotImplementedError(
+                f"LlamaIndexLLM metric does not support {self.model_name}, currently only gpt-3.5-turbo is supported"
+            )
+
+    def _model_using_extrnal_api(self):
+        return self.model_name in self.external_api_models
+
+
+class LlamaIndexCorrectness(LlamaIndexLLMMetric):
+    """LlamaIndex based metric class for evaluating correctness."""
+
+    score_prefix = "correctness_"
+
    @staticmethod
    def _custom_parser(eval_response: str):
        """Default parser function for evaluation response.
@@ -2174,37 +2200,14 @@ def _custom_parser(eval_response: str):
        reasoning = reasoning_str.lstrip("\n")
        return score, reasoning

-    def _model_using_extrnal_api(self):
-        return self.model_name in self.external_api_models
-
    def prepare(self):
        """Initialization method for the metric. Initializes the CorrectnessEvaluator with the OpenAI model."""
        super().prepare()

-        self.model_name_normalized = self.model_name.replace(".", "_").replace("-", "_")
-        self.main_score: str = (
-            f"correctness_llama_index_by_{self.model_name_normalized}_judge"
-        )
-
-        self.reduction_map: Dict[str, List[str]] = {"mean": [self.main_score]}
-
        from llama_index.core.evaluation import CorrectnessEvaluator

-        if self.model_name in self.openai_models:
-            from llama_index.llms.openai import OpenAI
-
-            llm = OpenAI("gpt-3.5-turbo")
-        elif self.model_name in self.mock_models:
-            from llama_index.core.llms.mock import MockLLM
-
-            llm = MockLLM(system_prompt="5")  # perfect score
-        else:
-            raise NotImplementedError(
-                f"LlamaIndexCorrectnessMetric does not support {self.model_name}, currently only gpt-3.5-turbo is supported"
-            )
-
        self.evaluator = CorrectnessEvaluator(
-            llm=llm, parser_function=self._custom_parser
+            llm=self.llm, parser_function=self._custom_parser
        )

    def compute(
@@ -2226,9 +2229,6 @@ def compute(
        Raises:
            AssertionError: If the input does not meet the expected format.
        """
-        # treat the references as the questions and the predictions as answers
-        # assume a single reference
-
        query = task_data["question"]

        contexts = None
@@ -2247,11 +2247,36 @@
        )
        result = max([results.score for results in per_reference_results])

-        return {
-            self.main_score: result / 5,
-            # "score_name": self.main_score,
-            # "feedback": result.feedback, # removed since this cannot be tested
-        }
+        return {self.main_score: result / 5}


+class LlamaIndexFaithfulness(LlamaIndexLLMMetric):
+    """LlamaIndex based metric class for evaluating faithfulness."""
+
+    score_prefix = "faithfulness_"
+
+    def prepare(self):
+        """Initialization method for the metric. Initializes the FaithfulnessEvaluator with the OpenAI model."""
+        super().prepare()
+
+        from llama_index.core.evaluation import FaithfulnessEvaluator
+
+        self.evaluator = FaithfulnessEvaluator(llm=self.llm)
+
+    def compute(
+        self,
+        references: List[str],
+        prediction: str,
+        task_data: Dict,
+    ) -> Dict[str, Any]:
+        result = self.evaluator.evaluate(
+            query=task_data["question"],
+            response=prediction,
+            contexts=task_data["contexts"],
+        )
+        score = result.score
+
+        return {self.main_score: score}


class Perplexity(BulkInstanceMetric):
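For reference, a minimal sketch (not part of the diff) of exercising the new faithfulness metric directly with the mock judge. The constructor argument, compute() signature, and task_data keys are taken from the changes above; calling prepare() explicitly is an assumption about driving the class outside the full unitxt pipeline, and the llama_index extra must be installed.

from unitxt.metrics import LlamaIndexFaithfulness

metric = LlamaIndexFaithfulness(model_name="mock")
metric.prepare()  # builds MockLLM(system_prompt="5") and the FaithfulnessEvaluator (the pipeline normally handles this)

scores = metric.compute(
    references=["The right answer"],  # not used by faithfulness; kept for the shared compute() signature
    prediction="The right answer",
    task_data={"question": "question number 1", "contexts": ["context number 1"]},
)
print(scores)  # expected, per the test targets above: {"llama_index_by_mock_judge": 1.0}
# The pipeline then applies score_prefix, reporting the score as "faithfulness_llama_index_by_mock_judge".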
