Add metrics for binary tasks with float predictions #654

Merged · Mar 13, 2024 · Changes from all 9 commits
prepare/metrics/accuracy.py (8 changes: 7 additions & 1 deletion)
@@ -1,5 +1,5 @@
from src.unitxt import add_to_catalog
from src.unitxt.metrics import Accuracy
from src.unitxt.metrics import Accuracy, BinaryAccuracy, BinaryMaxAccuracy
from src.unitxt.test_utils.metrics import test_metric

metric = Accuracy()
@@ -32,3 +32,9 @@
)

add_to_catalog(metric, "metrics.accuracy", overwrite=True)

metric = BinaryAccuracy()
add_to_catalog(metric, "metrics.accuracy_binary", overwrite=True)

metric = BinaryMaxAccuracy()
add_to_catalog(metric, "metrics.max_accuracy_binary", overwrite=True)
prepare/processors/processors.py (11 changes: 11 additions & 0 deletions)
@@ -16,6 +16,7 @@
TakeFirstWord,
ToYesOrNone,
YesNoToInt,
YesToOneElseZero,
)

logger = get_logger()
@@ -160,6 +161,16 @@
overwrite=True,
)

add_to_catalog(
SequentialOperator(
steps=[
YesToOneElseZero(field="prediction", process_every_value=False),
]
),
"processors.predictions_yes_1_else_0",
overwrite=True,
)

add_to_catalog(
SequentialOperator(
steps=[
src/unitxt/catalog/metrics/accuracy_binary.json (3 changes: 3 additions & 0 deletions)
@@ -0,0 +1,3 @@
{
"type": "binary_accuracy"
}
src/unitxt/catalog/metrics/max_accuracy_binary.json (3 changes: 3 additions & 0 deletions)
@@ -0,0 +1,3 @@
{
"type": "binary_max_accuracy"
}
src/unitxt/catalog/processors/predictions_yes_1_else_0.json (10 changes: 10 additions & 0 deletions)
@@ -0,0 +1,10 @@
{
"type": "sequential_operator",
"steps": [
{
"type": "yes_to_one_else_zero",
"field": "prediction",
"process_every_value": false
}
]
}
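Once registered, these entries can be looked up by their catalog names. A hedged sketch of retrieval: fetch_artifact is the helper used in the post-processor test further down, but the import path shown here is an assumption about this checkout.

# Sketch only: loading the newly added artifacts by catalog name.
# Import path is assumed; fetch_artifact appears in tests/library/test_postprocessors.py.
from src.unitxt.artifact import fetch_artifact

metric, _ = fetch_artifact("metrics.accuracy_binary")          # BinaryAccuracy instance
max_metric, _ = fetch_artifact("metrics.max_accuracy_binary")  # BinaryMaxAccuracy instance
processor, _ = fetch_artifact("processors.predictions_yes_1_else_0")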
src/unitxt/metrics.py (129 changes: 88 additions & 41 deletions)
@@ -1,4 +1,3 @@
import itertools
import re
import string
import uuid
@@ -1098,11 +1097,6 @@ def get_str_id(self, str):
self.id_to_str[id] = str
return self.str_to_id[str]

def _labels_match_average_format(
self, references: List[List[str]], predictions: List[str]
):
return True

def compute(
self,
references: List[List[str]],
@@ -1112,8 +1106,6 @@ def compute(
assert all(
len(reference) == 1 for reference in references
), "Only a single reference per prediction is allowed in F1 metric"
if not self._labels_match_average_format(references, predictions):
return {self.main_score: np.nan}

self.str_to_id = {}
self.id_to_str = {}
@@ -1149,27 +1141,29 @@ class F1Micro(F1):


class F1Binary(F1):
"""Calculate f1 for a binary task, using 0.5 as the threshold in the case of float predictions."""

process_single_instances = False
main_score = "f1_binary"
average = "binary"
pos_classes = {"1", "1.0", "yes", "true"}
threshold = 0.5

def get_str_id(self, str):
if str.lower() in self.pos_classes:
return 1
return 0
return int(str)

# References and predictions must include up to 2 unique values, one of them in pos_classes
def _labels_match_average_format(
self, references: List[List[str]], predictions: List[str]
):
classes = set(predictions + list(itertools.chain(*references)))
n_classes = len(classes)
if n_classes > 2:
return False
if n_classes == 2 and len(set(classes).difference(self.pos_classes)) == 0:
return False
return True
def compute(
self,
references: List[List[str]],
predictions: List[str],
task_data: List[Dict],
) -> dict:
predictions_floats = [to_float_or_default(p) for p in predictions]
predictions = [str(int(p > self.threshold)) for p in predictions_floats]
references = [
["1"] if r[0].lower() in self.pos_classes else ["0"] for r in references
]
return super().compute(references, predictions, task_data)


class RecallBinary(F1Binary):
@@ -3088,6 +3082,8 @@ class FixedGroupAbsvalNormHedgesGParaphraseStringContainment(StringContainment):


class BinaryMaxF1(F1Binary):
"""Calculate the maximal F1 and the decision threshold that achieves it for a binary task with float predictions."""

main_score = "max_f1_binary"

def compute(
@@ -3099,31 +3095,14 @@ def compute(
assert all(
len(reference) == 1 for reference in references
), "Only a single reference per prediction is allowed in F1 metric"
classes = set(itertools.chain(*references))
n_clases = len(classes)
assert len(classes) <= 2, "References of BinaryMaxF1 must be binary"
pos_classes = classes.intersection(self.pos_classes)
neg_classes = classes.difference(self.pos_classes)
n_pos_classes = len(pos_classes)
if n_clases == 2:
assert (
n_pos_classes == 1
), "Only one positive class is allowed in BinaryMaxF1"
pos_class = next(iter(pos_classes)) if n_pos_classes > 0 else "1.0"
neg_class = next(iter(neg_classes)) if len(neg_classes) > 0 else "0.0"

float_predictions = []
for prediction in predictions:
try:
float_predictions.append(float(prediction))
except Exception:
float_predictions.append(0)
float_predictions = [to_float_or_default(p) for p in predictions]

best_thr = -1
best_f1 = -1
for thr in set(float_predictions):
new_predictions = [
pos_class if float_prediction >= thr else neg_class
"1" if float_prediction >= thr else "0"
for float_prediction in float_predictions
]
f1 = super().compute(references, new_predictions, task_data)[
@@ -3134,3 +3113,71 @@
best_thr = thr

return {self.main_score: best_f1, "best_thr_maxf1": best_thr}


class BinaryAccuracy(InstanceMetric):
"""Calculate accuracy for a binary task, using 0.5 as the threshold in the case of float predictions."""

reduction_map = {"mean": ["accuracy_binary"]}
main_score = "accuracy_binary"
ci_scores = ["accuracy_binary"]
pos_classes = {"1", "1.0", "yes", "true"}
threshold = 0.5

def compute(
self, references: List[Any], prediction: Any, task_data: List[Dict]
) -> dict:
assert (
len(references) == 1
), "Only a single reference per prediction is allowed in Binary Accuracy metric"

float_prediction = to_float_or_default(prediction)
prediction = str(int(float_prediction > self.threshold))
references = ["1"] if references[0].lower() in self.pos_classes else ["0"]

result = {self.main_score: float([prediction] == references)}
result["score"] = result[self.main_score]
result["score_name"] = self.main_score
return result


class BinaryMaxAccuracy(GlobalMetric):
"""Calculate the maximal accuracy and the decision threshold that achieves it for a binary task with float predictions."""

process_single_instances = False
main_score = "max_accuracy_binary"
pos_classes = {"1", "1.0", "yes", "true"}

def compute(
self,
references: List[List[str]],
predictions: List[List[str]],
task_data: List[Dict],
) -> dict:
assert all(
len(reference) == 1 for reference in references
), "Only a single reference per prediction is allowed in BinaryMaxAccuracy metric"

float_predictions = [to_float_or_default(p) for p in predictions]
references = [
["1"] if r[0].lower() in self.pos_classes else ["0"] for r in references
]

best_thr = -1
best_acc = -1
for thr in set(float_predictions):
new_predictions = [
"1" if float_prediction >= thr else "0"
for float_prediction in float_predictions
]
acc = np.mean(
[
[prediction] == reference
for prediction, reference in zip(new_predictions, references)
]
)
if acc > best_acc:
best_acc = acc
best_thr = thr

return {self.main_score: best_acc, "best_thr_max_acc": best_thr}
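BinaryMaxF1 and BinaryMaxAccuracy share the same sweep: every observed prediction value is tried as a decision threshold (with a >= comparison), and the best score is reported together with the threshold that achieved it. A self-contained sketch of that sweep for accuracy, in plain Python with illustrative names (the real classes go through the unitxt scoring machinery):

# Illustrative threshold sweep, mirroring BinaryMaxAccuracy.compute above.
POS_CLASSES = {"1", "1.0", "yes", "true"}

def max_binary_accuracy(references, predictions):
    floats = [float(p) for p in predictions]  # the real code falls back to a default on parse errors
    gold = ["1" if r.lower() in POS_CLASSES else "0" for r in references]
    best_thr, best_acc = -1.0, -1.0
    for thr in set(floats):  # only observed values can change how the instances split
        labels = ["1" if f >= thr else "0" for f in floats]
        acc = sum(lab == g for lab, g in zip(labels, gold)) / len(gold)
        if acc > best_acc:
            best_acc, best_thr = acc, thr
    return best_acc, best_thr

# Mirrors tests/library/test_metrics.py: the best reachable accuracy is 0.8.
print(max_binary_accuracy(["1", "0", "0", "1", "0"], ["0.3", "0", "0.7", "1.0", "0.2"]))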
src/unitxt/processors.py (7 changes: 7 additions & 0 deletions)
@@ -152,6 +152,13 @@ def process_value(self, text: Any) -> Any:
return text


class YesToOneElseZero(FieldOperator):
def process_value(self, text: Any) -> Any:
if text == "yes":
return "1"
return "0"


class StrToFloatFormat(FieldOperator):
def process_value(self, text: Any) -> Any:
try:
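YesToOneElseZero maps the exact string "yes" to "1" and everything else to "0" (the comparison is case-sensitive, unlike the lower-cased pos_classes matching in the metrics), and the catalog entry processors.predictions_yes_1_else_0 applies it to the prediction field so yes/no model outputs can feed the binary metrics. A tiny illustration of the value-level behavior, matching the post-processor test further down:

# Value-level behavior of YesToOneElseZero (see test_predictions_yes_1_else_0 below).
for text in ["yes", "no", "Yes", "yaa"]:
    print(text, "->", "1" if text == "yes" else "0")
# yes -> 1, no -> 0, Yes -> 0 (exact match, case-sensitive), yaa -> 0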
tests/library/test_metrics.py (58 changes: 40 additions & 18 deletions)
@@ -3,6 +3,8 @@
from src.unitxt.logging_utils import get_logger
from src.unitxt.metrics import (
Accuracy,
BinaryAccuracy,
BinaryMaxAccuracy,
BinaryMaxF1,
F1Binary,
F1Macro,
@@ -169,8 +171,8 @@ def test_f1_micro(self):

def test_f1_binary(self):
metric = F1Binary()
references = [["1"], ["0"], ["0"], ["0"], ["1"], ["1"]]
predictions = ["1", "1", "0", "0", "1", "1"]
references = [["1"], ["0"], ["0"], ["0"], ["Yes"], ["1"]]
predictions = ["0.8", "1", "0.2", "0", "0.6", "1"]

global_target = 0.8571428571428
outputs = apply_metric(
@@ -213,38 +215,58 @@ def test_recall_binary(self):
self.assertEqual("recall_binary", outputs[0]["score"]["global"]["score_name"])
self.assertEqual("recall_binary", outputs[0]["score"]["instance"]["score_name"])

def test_f1_binary_non_binary(self):
metric = F1Binary()
references = [["1"], ["0"], ["yes"], ["0"], ["1"], ["1"]]
predictions = ["1", "1", "0", "0", "1", "1"]
def test_max_f1(self):
metric = BinaryMaxF1()
references = [["1"], ["0"], ["0"]]
predictions = ["0.3", "0", "0.7"]

global_target = 0.666666666666
outputs = apply_metric(
metric=metric, predictions=predictions, references=references
)
self.assertTrue(isnan(outputs[0]["score"]["global"]["score"]))

metric = F1Binary()
references = [["1"], ["yes"], ["1"], ["1"]]
predictions = ["1", "1", "1", "1"]
self.assertAlmostEqual(global_target, outputs[0]["score"]["global"]["score"])
self.assertEqual("max_f1_binary", outputs[0]["score"]["global"]["score_name"])
self.assertEqual("max_f1_binary", outputs[0]["score"]["instance"]["score_name"])

def test_accuracy_binary(self):
metric = BinaryAccuracy()
references = [["1"], ["0"], ["0"], ["1"], ["0"]]
predictions = ["0.3", "0", "0.7", "1.0", "0.2"]

expected_global_result = {
"accuracy_binary": 3 / 5,
"score": 3 / 5,
"score_name": "accuracy_binary",
}

outputs = apply_metric(
metric=metric, predictions=predictions, references=references
)
self.assertTrue(isnan(outputs[0]["score"]["global"]["score"]))
global_result = {
k: v
for k, v in outputs[0]["score"]["global"].items()
if k in expected_global_result
}
self.assertDictEqual(expected_global_result, global_result)

def test_max_f1(self):
metric = BinaryMaxF1()
references = [["1"], ["0"], ["0"]]
predictions = ["0.3", "0", "0.7"]
def test_binary_max_accuracy(self):
metric = BinaryMaxAccuracy()
references = [["1"], ["0"], ["0"], ["1"], ["0"]]
predictions = ["0.3", "0", "0.7", "1.0", "0.2"]

global_target = 0.666666666666
global_target = 0.8
outputs = apply_metric(
metric=metric, predictions=predictions, references=references
)

self.assertAlmostEqual(global_target, outputs[0]["score"]["global"]["score"])
self.assertEqual("max_f1_binary", outputs[0]["score"]["global"]["score_name"])
self.assertEqual("max_f1_binary", outputs[0]["score"]["instance"]["score_name"])
self.assertEqual(
"max_accuracy_binary", outputs[0]["score"]["global"]["score_name"]
)
self.assertEqual(
"max_accuracy_binary", outputs[0]["score"]["instance"]["score_name"]
)

def test_f1_macro(self):
metric = F1Macro()
tests/library/test_postprocessors.py (16 changes: 16 additions & 0 deletions)
@@ -170,6 +170,22 @@ def test_to_yes_or_none(self):
tester=self,
)

def test_predictions_yes_1_else_0(self):
parser, _ = fetch_artifact("processors.predictions_yes_1_else_0")
inputs = ["yes", "no", "yaa"]
targets = [
{"references": ["yes"], "prediction": "1"},
{"references": ["no"], "prediction": "0"},
{"references": ["yaa"], "prediction": "0"},
]

check_operator(
operator=parser,
inputs=list_to_stream_with_prediction_and_references(inputs),
targets=targets,
tester=self,
)

def test_str_to_float_format(self):
parser, _ = fetch_artifact("processors.str_to_float_format")
inputs = ["-2.4", "5", "5a"]