Skip to content

Commit

Permalink
additions
Browse files Browse the repository at this point in the history
  • Loading branch information
assaftibm committed Apr 4, 2024
1 parent e86f213 commit 902d8bc
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 0 deletions.
70 changes: 70 additions & 0 deletions prepare/metrics/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,76 @@
global_target=global_target,
)

metric = metrics["metrics.bert_score.deberta_large_mnli"]
predictions = ["hello there general dude", "foo bar foobar"]
references = [
["hello there general kenobi", "hello there!"],
["foo bar foobar", "foo bar"],
]
instance_targets = [
{"f1": 0.73, "precision": 0.83, "recall": 0.79, "score": 0.73, "score_name": "f1"},
{"f1": 1.0, "precision": 1.0, "recall": 1.0, "score": 1.0, "score_name": "f1"},
]

global_target = {
"f1": 0.87,
"f1_ci_high": 1.0,
"f1_ci_low": 0.73,
"precision": 0.92,
"precision_ci_high": 1.0,
"precision_ci_low": 0.83,
"recall": 0.9,
"recall_ci_high": 1.0,
"recall_ci_low": 0.79,
"score": 0.87,
"score_ci_high": 1.0,
"score_ci_low": 0.73,
"score_name": "f1",
}

test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)

metric = metrics["metrics.bert_score.deberta_base_mnli"]
predictions = ["hello there general dude", "foo bar foobar"]
references = [
["hello there general kenobi", "hello there!"],
["foo bar foobar", "foo bar"],
]
instance_targets = [
{"f1": 0.81, "precision": 0.85, "recall": 0.81, "score": 0.81, "score_name": "f1"},
{"f1": 1.0, "precision": 1.0, "recall": 1.0, "score": 1.0, "score_name": "f1"},
]

global_target = {
"f1": 0.9,
"f1_ci_high": 1.0,
"f1_ci_low": 0.81,
"precision": 0.93,
"precision_ci_high": 1.0,
"precision_ci_low": 0.85,
"recall": 0.91,
"recall_ci_high": 1.0,
"recall_ci_low": 0.81,
"score": 0.9,
"score_ci_high": 1.0,
"score_ci_low": 0.81,
"score_name": "f1",
}

test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)

metric = metrics["metrics.bert_score.distilbert_base_uncased"]
predictions = ["hello there general dude", "foo bar foobar"]
references = [
Expand Down
4 changes: 4 additions & 0 deletions src/unitxt/catalog/metrics/bert_score/deberta_base_mnli.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "bert_score",
"model_name": "microsoft/deberta-base-mnli"
}
4 changes: 4 additions & 0 deletions src/unitxt/catalog/metrics/bert_score/deberta_large_mnli.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "bert_score",
"model_name": "microsoft/deberta-large-mnli"
}

0 comments on commit 902d8bc

Please sign in to comment.