From 94daea377322c8a9c811632c8e14d2dd0686346a Mon Sep 17 00:00:00 2001
From: Yoav Katz <68273864+yoavkatz@users.noreply.github.com>
Date: Mon, 22 Jul 2024 19:36:53 +0300
Subject: [PATCH] Refactor Rouge and Meteor to InstanceMetric for faster score computation (#1011)

* Remove confidence interval calculation for meteor metric by default

added a new metric with interval calculations

Signed-off-by: Yoav Katz

* Added error message when metrics are not a list

Signed-off-by: Yoav Katz

* Added error message when post processors are not a list

Signed-off-by: Yoav Katz

* Changed Rouge to be HuggingfaceBulkMetric to avoid recalculation of the metric on every resample

Signed-off-by: Yoav Katz

* added meteor as a HuggingFaceInstanceMetric

Signed-off-by: dafnapension

* removed meteor_with_confidence_intervals.json

Signed-off-by: dafnapension

* fixed test_metric_utils.py by concentrating on rougeL only

Signed-off-by: dafnapension

* comment about rounded floats in tested scores

Signed-off-by: dafnapension

* while generating metric meteor, compare against the HF implementation

Signed-off-by: dafnapension

* added a test comparing new Rouge with HF Rouge, and per arielge's good advice, changed bootstrap method to percentile in case of 100 or more instances

Signed-off-by: dafnapension

* implemented Meteor and Rouge with in-house code

Signed-off-by: dafnapension

* download quietly, and import in prepare

Signed-off-by: dafnapension

* trying to avoid .secrets.baseline

Signed-off-by: dafnapension

* secret.baseline - how do I get rid of it?

Signed-off-by: dafnapension

---------

Signed-off-by: Yoav Katz
Signed-off-by: dafnapension
Co-authored-by: dafnapension
Co-authored-by: Elron Bandel
---
 .secrets.baseline                         |   4 +-
 prepare/metrics/meteor.py                 |  61 ++++++-
 prepare/metrics/rouge.py                  |  34 ++--
 src/unitxt/catalog/metrics/meteor.json    |   5 +-
 src/unitxt/catalog/metrics/rouge.json     |   3 +-
 .../rouge_with_confidence_intervals.json  |   3 +-
 src/unitxt/metrics.py                     | 149 ++++++++++++++++--
 src/unitxt/standard.py                    |  10 ++
 tests/library/test_metric_utils.py        |  10 +-
 tests/library/test_metrics.py             |  57 +++++--
 10 files changed, 281 insertions(+), 55 deletions(-)

diff --git a/.secrets.baseline b/.secrets.baseline
index 32b037230..6ddf4c07e 100644
--- a/.secrets.baseline
+++ b/.secrets.baseline
@@ -3,7 +3,7 @@
     "files": "^.secrets.baseline$",
     "lines": null
   },
-  "generated_at": "2024-07-09T07:07:12Z",
+  "generated_at": "2024-07-22T10:56:00Z",
   "plugins_used": [
     {
       "name": "AWSKeyDetector"
@@ -82,7 +82,7 @@
         "hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
         "is_secret": false,
         "is_verified": false,
-        "line_number": 1531,
+        "line_number": 1607,
         "type": "Hex High Entropy String",
         "verified_result": null
       }
diff --git a/prepare/metrics/meteor.py b/prepare/metrics/meteor.py
index c90b3d7cb..626198210 100644
--- a/prepare/metrics/meteor.py
+++ b/prepare/metrics/meteor.py
@@ -1,8 +1,65 @@
 from unitxt import add_to_catalog
-from unitxt.metrics import HuggingfaceMetric
+from unitxt.metrics import HuggingfaceMetric, Meteor
+from unitxt.test_utils.metrics import test_metric
 
-metric = HuggingfaceMetric(
+metric = Meteor()
+
+predictions = [
+    "It is a guide to action which ensures that the military always obeys the commands of the party",
+    "We strive for peace",
+    "On the rag sat the cat",
+    "I caught the ball",
+]
+references = [
+    [
+        "It is a guide to action that ensures that the military will forever heed Party commands"
+    ],
+    ["We hope for peace"],
+    ["The cat sat on the rag"],
+    ["He threw the ball"],
+]
+
+# the floats shown here are rounded just for the test; the actual
+# returned scores have 15-16 digits to the right of the decimal point
+instance_targets = [
+    {"meteor": 0.69, "score": 0.69, "score_name": "meteor"},
+    {"meteor": 0.64, "score": 0.64, "score_name": "meteor"},
+    {"meteor": 0.5, "score": 0.5, "score_name": "meteor"},
+    {"meteor": 0.47, "score": 0.47, "score_name": "meteor"},
+]
+
+global_target = {
+    "meteor": 0.58,
+    "meteor_ci_high": 0.59,
+    "meteor_ci_low": 0.58,
+    "score": 0.58,
+    "score_ci_high": 0.59,
+    "score_ci_low": 0.58,
+    "score_name": "meteor",
+}
+
+metric.n_resamples = 3
+# to match the setting used when testing the global (HF) version, metric2, below
+
+outputs = test_metric(
+    metric=metric,
+    predictions=predictions,
+    references=references,
+    instance_targets=instance_targets,
+    global_target=global_target,
+)
+
+# compare results with the HF version of meteor
+metric2 = HuggingfaceMetric(
     hf_metric_name="meteor", main_score="meteor", prediction_type="str"
 )
+outputs = test_metric(
+    metric=metric2,
+    predictions=predictions,
+    references=references,
+    instance_targets=instance_targets,
+    global_target=global_target,
+)
+
 
 add_to_catalog(metric, "metrics.meteor", overwrite=True)
diff --git a/prepare/metrics/rouge.py b/prepare/metrics/rouge.py
index 56517b6c6..357806c54 100644
--- a/prepare/metrics/rouge.py
+++ b/prepare/metrics/rouge.py
@@ -2,7 +2,7 @@
 from unitxt.metrics import Rouge
 from unitxt.test_utils.metrics import test_metric
 
-metric = Rouge(n_resamples=None)
+metric = Rouge()
 
 predictions = ["hello there", "general kenobi"]
 references = [["hello", "there"], ["general kenobi", "general yoda"]]
@@ -28,13 +28,22 @@
 
 global_target = {
     "rouge1": 0.83,
+    "rouge1_ci_high": 1.0,
+    "rouge1_ci_low": 0.67,
     "rouge2": 0.5,
+    "rouge2_ci_high": 1.0,
+    "rouge2_ci_low": 0.0,
     "rougeL": 0.83,
+    "rougeL_ci_high": 1.0,
+    "rougeL_ci_low": 0.67,
     "rougeLsum": 0.83,
+    "rougeLsum_ci_high": 1.0,
+    "rougeLsum_ci_low": 0.67,
     "score": 0.83,
+    "score_ci_high": 1.0,
+    "score_ci_low": 0.67,
     "score_name": "rougeL",
 }
-
 outputs = test_metric(
     metric=metric,
     predictions=predictions,
@@ -43,27 +52,12 @@
     global_target=global_target,
 )
 add_to_catalog(metric, "metrics.rouge", overwrite=True)
-
-global_target_with_confidence_intervals = global_target.copy()
-global_target_with_confidence_intervals.update(
-    {
-        "rougeL_ci_low": 0.83,
-        "rougeL_ci_high": 0.83,
-        "score_ci_low": 0.83,
-        "score_ci_high": 0.83,
-    }
+metric = Rouge(
+    __description__="This is deprecated. Use 'metrics.rouge' which also generates confidence intervals"
 )
-metric_with_confidence_intervals = Rouge()
-outputs = test_metric(
-    metric=metric_with_confidence_intervals,
-    predictions=predictions,
-    references=references,
-    instance_targets=instance_targets,
-    global_target=global_target_with_confidence_intervals,
-)
 
 add_to_catalog(
-    metric_with_confidence_intervals,
+    metric,
     "metrics.rouge_with_confidence_intervals",
     overwrite=True,
 )
diff --git a/src/unitxt/catalog/metrics/meteor.json b/src/unitxt/catalog/metrics/meteor.json
index 293c6eae8..1b36f4d7f 100644
--- a/src/unitxt/catalog/metrics/meteor.json
+++ b/src/unitxt/catalog/metrics/meteor.json
@@ -1,6 +1,3 @@
 {
-    "__type__": "huggingface_metric",
-    "hf_metric_name": "meteor",
-    "main_score": "meteor",
-    "prediction_type": "str"
+    "__type__": "meteor"
 }
diff --git a/src/unitxt/catalog/metrics/rouge.json b/src/unitxt/catalog/metrics/rouge.json
index 448f21f09..82844033a 100644
--- a/src/unitxt/catalog/metrics/rouge.json
+++ b/src/unitxt/catalog/metrics/rouge.json
@@ -1,4 +1,3 @@
 {
-    "__type__": "rouge",
-    "n_resamples": null
+    "__type__": "rouge"
 }
diff --git a/src/unitxt/catalog/metrics/rouge_with_confidence_intervals.json b/src/unitxt/catalog/metrics/rouge_with_confidence_intervals.json
index 82844033a..85da472ec 100644
--- a/src/unitxt/catalog/metrics/rouge_with_confidence_intervals.json
+++ b/src/unitxt/catalog/metrics/rouge_with_confidence_intervals.json
@@ -1,3 +1,4 @@
 {
-    "__type__": "rouge"
+    "__type__": "rouge",
+    "__description__": "This is deprecated. Use 'metrics.rouge' which also generates confidence intervals"
 }
diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py
index 73fadeb2a..79e720699 100644
--- a/src/unitxt/metrics.py
+++ b/src/unitxt/metrics.py
@@ -327,6 +327,7 @@ def score_based_confidence_interval(
         # otherwise, the aggregation_func needs to be applied AFTER resampling the instances;
         #   that is, re-form the groups, calculate the function, and take the mean of the group scores
         aggregation_func = self.average_item_scores
+
         for score_name in score_names:
             # If all computed instance level scores are the same, there is no point in computing
             # confidence intervals. So skip to the next score.
@@ -1300,6 +1301,81 @@ def compute(
         return results
 
 
+class HuggingfaceInstanceMetric(InstanceMetric):
+    hf_metric_name: str
+
+    hf_metric_fields: List[str]
+    hf_compute_args: dict = {}
+
+    def prepare(self):
+        super().prepare()
+        self.metric = evaluate.load(
+            self.hf_metric_name, experiment_id=str(uuid.uuid4())
+        )
+
+    def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
+        # invokes module.compute, which invokes, e.g., meteor's _compute
+
+        try:
+            score = self.metric.compute(
+                predictions=[prediction],
+                references=[references],
+                **self.hf_compute_args,
+            )
+        except:
+            score = {self.main_score: np.nan}
+
+        if self.hf_metric_fields is not None and len(self.hf_metric_fields) > 0:
+            to_ret = {field: score[field] for field in self.hf_metric_fields}
+            score = to_ret
+
+        return score
+
+
+class Meteor(InstanceMetric):
+    main_score = "meteor"
+    ci_scores = ["meteor"]
+    reduction_map = {"mean": ["meteor"]}
+    prediction_type = "str"
+
+    _requirements_list: List[str] = ["nltk"]
+    alpha: float = 0.9
+    beta: int = 3
+    gamma: float = 0.5
+    # unitxt uses nltk version >= 3.8
+
+    def prepare(self):
+        import nltk
+
+        nltk.download("wordnet", quiet=True)
+        nltk.download("omw-1.4", quiet=True)
+        from nltk import word_tokenize
+        from nltk.translate import meteor_score
+
+        self.word_tokenize = word_tokenize
+        self.meteor_score = meteor_score
+
+    def verify(self):
+        import importlib.metadata as importlib_metadata
+
+        from datasets.config import version
+
+        nltk_version = version.parse(importlib_metadata.version("nltk"))
+        assert nltk_version >= version.Version(
+            "3.6.6"
+        ), "nltk version must be at least 3.6.6"
+
+    def compute(self, references, prediction, task_data):
+        score = self.meteor_score.meteor_score(
+            [self.word_tokenize(ref) for ref in references],
+            self.word_tokenize(prediction),
+            alpha=self.alpha,
+            beta=self.beta,
+            gamma=self.gamma,
+        )
+        return {"meteor": score}
+
+
 class F1(GlobalMetric):
     _metric = None
     main_score = "f1_macro"
@@ -1691,7 +1767,49 @@ class F1MacroMultiLabel(F1MultiLabel):
     average = None
 
 
-class Rouge(HuggingfaceMetric):
+class Rouge(InstanceMetric):
+    main_score = "rougeL"
+    prediction_type = "str"
+    single_reference_per_prediction = False  # multiple references allowed
+    rouge_types: List[str] = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
+    reduction_map = {"mean": ["rouge1", "rouge2", "rougeL", "rougeLsum"]}
+    ci_scores = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
+
+    sent_split_newline: bool = True
+    _requirements_list: List[str] = ["nltk", "rouge_score"]
+
+    def prepare(self):
+        import nltk
+        from rouge_score import rouge_scorer
+
+        self.rouge_scorer = rouge_scorer
+
+        nltk.download("punkt", quiet=True)
+        self.sent_tokenize = nltk.sent_tokenize
+
+    def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
+        # for a single instance, prediction is of type str, and references is a list of str
+        if self.sent_split_newline:
+            prediction = "\n".join(self.sent_tokenize(prediction.strip()))
+
+            references = [
+                "\n".join(self.sent_tokenize(reference.strip()))
+                for reference in references
+            ]
+
+        # the following is taken from HF rouge, using the defaults:
+        #   use_aggregator=True, use_stemmer=False, tokenizer=None
+        scorer = self.rouge_scorer.RougeScorer(
+            rouge_types=self.rouge_types, use_stemmer=False, tokenizer=None
+        )
+        # with Unitxt, references is a list
+        score = scorer.score_multi(references, prediction)
+        for key in score:
+            score[key] = score[key].fmeasure
+        return score
+
+
+class RougeHF(HuggingfaceInstanceMetric):
     hf_metric_name = "rouge"
     main_score = "rougeL"
     scale = 1.0
@@ -1699,8 +1817,10 @@ class Rouge(HuggingfaceMetric):
     prediction_type = "str"
     single_reference_per_prediction = False  # multiple references allowed
 
-    use_aggregator: bool = True
     rouge_types: List[str] = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
+    reduction_map = {"mean": ["rouge1", "rouge2", "rougeL", "rougeLsum"]}
+    hf_metric_fields = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
+    ci_scores = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
 
     sent_split_newline: bool = True
 
@@ -1709,26 +1829,33 @@ class Rouge(HuggingfaceMetric):
 
     def prepare(self):
         super().prepare()
+        # We don't use the aggregator, to avoid the costly bootstrapping run by the
+        # internal library; Unitxt does it in any case.
         self.hf_compute_args.update(
-            {"use_aggregator": self.use_aggregator, "rouge_types": self.rouge_types}
+            {"use_aggregator": False, "rouge_types": self.rouge_types}
         )
 
         import nltk
 
-        nltk.download("punkt")
+        nltk.download("punkt", quiet=True)
         self.sent_tokenize = nltk.sent_tokenize
 
-    def compute(self, references, predictions, task_data: List[Dict]):
+    def compute(self, references, prediction, task_data: List[Dict]):
+        # for a single instance, prediction is of type str, and references is a list of str
         if self.sent_split_newline:
-            predictions = [
-                "\n".join(self.sent_tokenize(prediction.strip()))
-                for prediction in predictions
-            ]
+            prediction = "\n".join(self.sent_tokenize(prediction.strip()))
+
             references = [
-                ["\n".join(self.sent_tokenize(r.strip())) for r in reference]
+                "\n".join(self.sent_tokenize(reference.strip()))
                 for reference in references
             ]
-        return super().compute(references, predictions, task_data)
+
+        hf_score = super().compute(references, prediction, task_data)
+        for metric_field in self.hf_metric_fields:
+            if isinstance(hf_score[metric_field], list):
+                assert len(hf_score[metric_field]) == 1
+                hf_score[metric_field] = hf_score[metric_field][0]
+        return hf_score
 
 
 # Computes char edit distance, ignoring whitespace
diff --git a/src/unitxt/standard.py b/src/unitxt/standard.py
index 3b110644f..eed5cde04 100644
--- a/src/unitxt/standard.py
+++ b/src/unitxt/standard.py
@@ -96,6 +96,16 @@ def verify(self):
             raise ValueError(
                 f"max_train_instances should not exceed loader_limit ({self.loader_limit}), Got max_train_instances={self.max_train_instances}"
             )
+        if self.metrics is not None and not isinstance(self.metrics, List):
+            raise ValueError(
+                f"metrics must be a list of metrics. Got metrics = {self.metrics}"
+            )
+        if self.postprocessors is not None and not isinstance(
+            self.postprocessors, List
+        ):
+            raise ValueError(
+                f"postprocessors must be a list of postprocessors. Got postprocessors = {self.postprocessors}"
+            )
 
     def prepare_refiners(self):
         self.train_refiner.max_instances = self.max_train_instances
diff --git a/tests/library/test_metric_utils.py b/tests/library/test_metric_utils.py
index 1f13b43e3..c03ab306c 100644
--- a/tests/library/test_metric_utils.py
+++ b/tests/library/test_metric_utils.py
@@ -21,12 +21,16 @@ class AvgRougeNoBootstrap(Rouge):
             def prepare(self):
                 self.n_resamples = None
                 self.rouge_types = ["rougeL"]
+                self.ci_scores = ["rougeL"]
+                self.hf_metric_fields = ["rougeL"]
+                self.reduction_map = {"mean": ["rougeL"]}
                 self.use_aggregator = False
                 super().prepare()
 
-            def compute(self, references, predictions, task_data: List[Dict]):
-                res_list = super().compute(references, predictions, task_data)["rougeL"]
-                return {"rougeL": nanmean(res_list)}
+            def compute(self, references, prediction, task_data: List[Dict]):
+                # single score for a single instance
+                res = super().compute(references, prediction, task_data)["rougeL"]
+                return {"rougeL": res}
 
         metric = AvgRougeNoBootstrap()
         references = [
diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py
index 9c5a1991e..f1a3f27d9 100644
--- a/tests/library/test_metrics.py
+++ b/tests/library/test_metrics.py
@@ -1,4 +1,5 @@
 from math import isnan
+from typing import Dict, List
 
 from unitxt.inference import MockInferenceEngine
 from unitxt.llm_as_judge import LLMAsJudge
@@ -38,6 +39,7 @@
     GroupMeanAccuracy,
     GroupMeanStringContainment,
     GroupMeanTokenOverlap,
+    HuggingfaceMetric,
     KendallTauMetric,
     LlamaIndexCorrectness,
     MaxAccuracy,
@@ -799,19 +801,54 @@ def test_rouge(self):
         global_target = 5 / 6
         self.assertAlmostEqual(global_target, outputs[0]["score"]["global"]["score"])
 
-    def test_rouge_l(self):
-        metric = Rouge(
-            n_resamples=None,  # disable confidence interval calculation which fails for this metric configuration
-            use_aggregator=False,
-            rouge_types=["rougeL"],
-        )
-        references = [["hello", "there"], ["general kenobi", "general yoda"]]
-        predictions = ["hello there", "general kenobi"]
+        # compare with the HF implementation
+        class OldRouge(HuggingfaceMetric):
+            hf_metric_name = "rouge"
+            main_score = "rougeL"
+            scale = 1.0
+
+            prediction_type = "str"
+            single_reference_per_prediction = False  # multiple references allowed
+
+            use_aggregator: bool = True
+            rouge_types: List[str] = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
+
+            sent_split_newline: bool = True
+
+            _requirements_list: List[str] = ["nltk", "rouge_score"]
+
+            def prepare(self):
+                super().prepare()
+
+                self.hf_compute_args.update(
+                    {
+                        "use_aggregator": self.use_aggregator,
+                        "rouge_types": self.rouge_types,
+                    }
+                )
+
+                import nltk
+
+                nltk.download("punkt")
+                self.sent_tokenize = nltk.sent_tokenize
+
+            def compute(self, references, predictions, task_data: List[Dict]):
+                if self.sent_split_newline:
+                    predictions = [
+                        "\n".join(self.sent_tokenize(prediction.strip()))
+                        for prediction in predictions
+                    ]
+                    references = [
+                        ["\n".join(self.sent_tokenize(r.strip())) for r in reference]
+                        for reference in references
+                    ]
+                return super().compute(references, predictions, task_data)
+
+        metric = OldRouge()
         outputs = apply_metric(
             metric=metric, predictions=predictions, references=references
         )
-        global_target = [2 / 3, 1.0]
-        self.assertListEqual(global_target, outputs[0]["score"]["global"]["score"])
+        self.assertAlmostEqual(global_target, outputs[0]["score"]["global"]["score"])
 
     def test_token_overlap(self):
         metric = TokenOverlap()
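
As a quick sanity check of the per-instance METEOR computation added to src/unitxt/metrics.py above, here is a minimal standalone sketch. It is an illustration only, not part of the patch: it assumes nltk >= 3.6.6 is installed, reuses the alpha/beta/gamma defaults of the new Meteor class, and scores the first prediction/reference pair from prepare/metrics/meteor.py (whose instance target is 0.69).

# Minimal sketch (illustration, not part of the patch): per-instance METEOR,
# mirroring Meteor.compute above. Assumes nltk >= 3.6.6 is installed.
import nltk
from nltk import word_tokenize
from nltk.translate import meteor_score

nltk.download("wordnet", quiet=True)
nltk.download("omw-1.4", quiet=True)
nltk.download("punkt", quiet=True)  # tokenizer data used by word_tokenize

prediction = "It is a guide to action which ensures that the military always obeys the commands of the party"
references = [
    "It is a guide to action that ensures that the military will forever heed Party commands"
]

# meteor_score expects a tokenized hypothesis and pre-tokenized references
score = meteor_score.meteor_score(
    [word_tokenize(ref) for ref in references],
    word_tokenize(prediction),
    alpha=0.9,  # same defaults as the new Meteor class
    beta=3,
    gamma=0.5,
)
print(round(score, 2))  # ~0.69, the first instance target in prepare/metrics/meteor.py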