Commit
Refactor Rouge and Meteor to InstanceMetric for faster score computation (#1011)

* Remove confidence interval calculation for meteor metric by default; added a new metric with interval calculations
  Signed-off-by: Yoav Katz <katz@il.ibm.com>
* Added error message when metrics not a list
  Signed-off-by: Yoav Katz <katz@il.ibm.com>
* Added error message when post processors are not a list
  Signed-off-by: Yoav Katz <katz@il.ibm.com>
* Changed Rouge to be HuggingfaceBulkMetric to avoid recalculation of metric on every resample
  Signed-off-by: Yoav Katz <katz@il.ibm.com>
* added meteor as an HuggingFaceInstanceMetric
  Signed-off-by: dafnapension <dafnashein@yahoo.com>
* removed meteor_with_confidence_intervals.json
  Signed-off-by: dafnapension <dafnashein@yahoo.com>
* fixed test_metric_utils.py by better concentrating on rougeL only
  Signed-off-by: dafnapension <dafnashein@yahoo.com>
* comment about rounded floats in tested scores
  Signed-off-by: dafnapension <dafnashein@yahoo.com>
* while generating metric meteor, compare against HF implementation
  Signed-off-by: dafnapension <dafnashein@yahoo.com>
* added a test comparing new Rouge with HF Rouge, and per arielge's good advice, changed bootstrap method to percentile in case of 100 or more instances
  Signed-off-by: dafnapension <dafnashein@yahoo.com>
* implemented Meteor and Rouge with inhouse code
  Signed-off-by: dafnapension <dafnashein@yahoo.com>
* download quietly, and import in prepare
  Signed-off-by: dafnapension <dafnashein@yahoo.com>
* trying to avoid .secrets.baseline
  Signed-off-by: dafnapension <dafnashein@yahoo.com>
* secret.baseline how do I get rid of it?
  Signed-off-by: dafnapension <dafnashein@yahoo.com>

---------

Signed-off-by: Yoav Katz <katz@il.ibm.com>
Signed-off-by: dafnapension <dafnashein@yahoo.com>
Co-authored-by: dafnapension <dafnashein@yahoo.com>
Co-authored-by: Elron Bandel <elronbandel@gmail.com>
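The commit message motivates the refactor by avoiding recalculation of the metric on every resample and mentions a percentile bootstrap. The following is a minimal sketch of that idea, not unitxt's implementation: it assumes per-instance scores have already been computed once, and the helper name `percentile_bootstrap_ci` and its parameters are hypothetical.

```python
import random
from statistics import mean


def percentile_bootstrap_ci(instance_scores, n_resamples=1000, confidence=0.95, seed=0):
    """Percentile bootstrap CI for the mean of precomputed per-instance scores.

    Hypothetical helper for illustration only; it is not part of unitxt.
    """
    rng = random.Random(seed)
    n = len(instance_scores)
    # Each resample only re-averages stored scores, so the (possibly
    # expensive) metric is never run again during bootstrapping.
    resampled_means = sorted(
        mean(rng.choices(instance_scores, k=n)) for _ in range(n_resamples)
    )
    alpha = (1.0 - confidence) / 2.0
    low_idx = int(alpha * (n_resamples - 1))
    high_idx = int((1.0 - alpha) * (n_resamples - 1))
    return resampled_means[low_idx], resampled_means[high_idx]


# Rounded per-instance targets taken from the Meteor test in this commit.
scores = [0.69, 0.64, 0.5, 0.47]
ci_low, ci_high = percentile_bootstrap_ci(scores)
```

With a metric that can only be evaluated on a whole batch, every resample would instead require a full metric call, which is the cost an instance-level metric avoids.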
1 parent db595cc, commit 94daea3
Showing 10 changed files with 281 additions and 55 deletions.
@@ -1,8 +1,65 @@
 from unitxt import add_to_catalog
-from unitxt.metrics import HuggingfaceMetric
+from unitxt.metrics import HuggingfaceMetric, Meteor
+from unitxt.test_utils.metrics import test_metric
 
-metric = HuggingfaceMetric(
+metric = Meteor()
+
+predictions = [
+    "It is a guide to action which ensures that the military always obeys the commands of the party",
+    "We strive for peace",
+    "On the rag sat the cat",
+    "I caught the ball",
+]
+references = [
+    [
+        "It is a guide to action that ensures that the military will forever heed Party commands"
+    ],
+    ["We hope for peace"],
+    ["The cat sat on the rag"],
+    ["He threw the ball"],
+]
+
+# the floats shown here are rounded just for the test. the actually
+# returned score are 15-16 digits to the right of the decimal point
+instance_targets = [
+    {"meteor": 0.69, "score": 0.69, "score_name": "meteor"},
+    {"meteor": 0.64, "score": 0.64, "score_name": "meteor"},
+    {"meteor": 0.5, "score": 0.5, "score_name": "meteor"},
+    {"meteor": 0.47, "score": 0.47, "score_name": "meteor"},
+]
+
+global_target = {
+    "meteor": 0.58,
+    "meteor_ci_high": 0.59,
+    "meteor_ci_low": 0.58,
+    "score": 0.58,
+    "score_ci_high": 0.59,
+    "score_ci_low": 0.58,
+    "score_name": "meteor",
+}
+
+metric.n_resamples = 3
+# to match the setting to occur by testing on the global version, metric2, below
+
+outputs = test_metric(
+    metric=metric,
+    predictions=predictions,
+    references=references,
+    instance_targets=instance_targets,
+    global_target=global_target,
+)
+
+# compare results with the HF version of meteor
+metric2 = HuggingfaceMetric(
     hf_metric_name="meteor", main_score="meteor", prediction_type="str"
 )
 
+outputs = test_metric(
+    metric=metric2,
+    predictions=predictions,
+    references=references,
+    instance_targets=instance_targets,
+    global_target=global_target,
+)
+
 add_to_catalog(metric, "metrics.meteor", overwrite=True)
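The comment in this file notes that the target floats are rounded for the test while the returned scores carry 15-16 decimal digits. A minimal sketch of such a comparison follows; the helper `round_floats` is hypothetical, the `returned` value is made up for illustration, and how `test_metric` actually performs the rounding is not shown in this diff.

```python
def round_floats(obj, ndigits=2):
    """Recursively round every float in a nested dict/list structure.

    Hypothetical helper for illustration; not unitxt's test utility.
    """
    if isinstance(obj, float):
        return round(obj, ndigits)
    if isinstance(obj, dict):
        return {key: round_floats(value, ndigits) for key, value in obj.items()}
    if isinstance(obj, list):
        return [round_floats(value, ndigits) for value in obj]
    return obj  # strings, ints, None, etc. pass through unchanged


# Illustrative full-precision score (an assumed value, not a real output).
returned = {"meteor": 0.6944444444444445, "score": 0.6944444444444445, "score_name": "meteor"}
target = {"meteor": 0.69, "score": 0.69, "score_name": "meteor"}

matches = round_floats(returned) == target
```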
@@ -1,6 +1,3 @@
 {
-    "__type__": "huggingface_metric",
-    "hf_metric_name": "meteor",
-    "main_score": "meteor",
-    "prediction_type": "str"
+    "__type__": "meteor"
 }
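After this change the catalog entry carries only a `__type__` field. As a rough sketch of how a type tag like this can select a class at load time (this is an assumed, simplified mechanism, not unitxt's loader; `REGISTRY`, `load_artifact`, and the stand-in `Meteor` class are all hypothetical):

```python
import json


class Meteor:
    """Stand-in class for illustration; not the real unitxt Meteor."""

    def __init__(self, **params):
        self.params = params


# Hypothetical mapping from "__type__" values to classes.
REGISTRY = {"meteor": Meteor}


def load_artifact(text):
    """Build an object from a JSON spec whose "__type__" names its class."""
    spec = json.loads(text)
    type_name = spec.pop("__type__")
    # Remaining keys, if any, are passed to the constructor as parameters.
    return REGISTRY[type_name](**spec)


artifact = load_artifact('{"__type__": "meteor"}')
```

Under this reading, dropping `hf_metric_name`, `main_score`, and `prediction_type` from the JSON means those settings now live in the class itself rather than in the catalog file.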
@@ -1,4 +1,3 @@
 {
-    "__type__": "rouge",
-    "n_resamples": null
+    "__type__": "rouge"
 }
src/unitxt/catalog/metrics/rouge_with_confidence_intervals.json (3 changes: 2 additions & 1 deletion)
@@ -1,3 +1,4 @@
 {
-    "__type__": "rouge"
+    "__type__": "rouge",
+    "__description__": "This is deprecated. Use 'metrics.rouge' which also generate confidence intervals"
 }