Unify preds, target input arguments for text metrics [2of2] cer, ter, wer, mer, rouge, squad (#727)

* Update naming conventions for cer, ter, wer, sacrebleu, rouge, squad
* Add warnings for BC

Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Jan 13, 2022
1 parent 77babc9 commit 9094985
Showing 19 changed files with 423 additions and 332 deletions.
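
The unifying change is that every metric touched here now takes predictions first and references second, under the names preds and target (the old argument names are kept behind backward-compatibility warnings, per the commit message). Below is a minimal sketch of the resulting call pattern; the import paths follow the test files in this diff, and the sample strings are made up for illustration only:

# Sketch of the unified (preds, target) convention this commit rolls out.
# Import paths mirror the test files changed in this diff; data is illustrative.
from torchmetrics.text.mer import MatchErrorRate
from torchmetrics.text.wer import WordErrorRate

preds = ["hello world", "there is an other sample"]
target = ["hello duck", "there is another one"]

# Predictions always come first, references second.
print(WordErrorRate()(preds, target))
print(MatchErrorRate()(preds, target))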
40 changes: 40 additions & 0 deletions tests/text/inputs.py
@@ -14,6 +14,7 @@
from collections import namedtuple

Input = namedtuple("Input", ["preds", "targets"])
SquadInput = namedtuple("SquadInput", ["preds", "targets", "exact_match", "f1"])

# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu and adjusted
@@ -64,6 +65,45 @@

_inputs_error_rate_batch_size_2 = Input(**ERROR_RATES_BATCHES_2)

SAMPLE_1 = {
    "exact_match": 100.0,
    "f1": 100.0,
    "preds": {"prediction_text": "1976", "id": "id1"},
    "targets": {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
}

SAMPLE_2 = {
    "exact_match": 0.0,
    "f1": 0.0,
    "preds": {"prediction_text": "Hello", "id": "id2"},
    "targets": {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
}

BATCH = {
    "exact_match": [100.0, 0.0],
    "f1": [100.0, 0.0],
    "preds": [
        {"prediction_text": "1976", "id": "id1"},
        {"prediction_text": "Hello", "id": "id2"},
    ],
    "targets": [
        {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
        {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
    ],
}

_inputs_squad_exact_match = SquadInput(
    preds=SAMPLE_1["preds"], targets=SAMPLE_1["targets"], exact_match=SAMPLE_1["exact_match"], f1=SAMPLE_1["f1"]
)

_inputs_squad_exact_mismatch = SquadInput(
    preds=SAMPLE_2["preds"], targets=SAMPLE_2["targets"], exact_match=SAMPLE_2["exact_match"], f1=SAMPLE_2["f1"]
)

_inputs_squad_batch_match = SquadInput(
    preds=BATCH["preds"], targets=BATCH["targets"], exact_match=BATCH["exact_match"], f1=BATCH["f1"]
)

# single reference
TUPLE_OF_SINGLE_REFERENCES = (((REFERENCE_1A), (REFERENCE_1B)), ((REFERENCE_1B), (REFERENCE_1C)))
_inputs_single_reference = Input(preds=TUPLE_OF_HYPOTHESES, targets=TUPLE_OF_SINGLE_REFERENCES)
4 changes: 2 additions & 2 deletions tests/text/test_cer.py
@@ -15,8 +15,8 @@
compute_measures = Callable


def compare_fn(prediction: Union[str, List[str]], reference: Union[str, List[str]]):
    return cer(reference, prediction)
def compare_fn(preds: Union[str, List[str]], target: Union[str, List[str]]):
    return cer(target, preds)


@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
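
For context, the compare_fn above exists because jiwer takes the reference first while torchmetrics now takes predictions first. A minimal sketch of the two call orders side by side, assuming jiwer is installed (the test is skipped otherwise) and that the functional char_error_rate is available as in current torchmetrics:

# Sketch of the argument-order difference that compare_fn bridges.
from jiwer import cer
from torchmetrics.functional import char_error_rate

preds = ["hello world"]
target = ["hello duck"]

reference_score = cer(target, preds)  # jiwer: reference first, hypothesis second
metric_score = char_error_rate(preds, target)  # torchmetrics: preds first, target second
print(reference_score, float(metric_score))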
4 changes: 2 additions & 2 deletions tests/text/test_mer.py
@@ -15,8 +15,8 @@
from torchmetrics.text.mer import MatchErrorRate


def _compute_mer_metric_jiwer(prediction: Union[str, List[str]], reference: Union[str, List[str]]):
    return compute_measures(reference, prediction)["mer"]
def _compute_mer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]):
    return compute_measures(target, preds)["mer"]


@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
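
The jiwer baseline used here returns a dictionary of measures from a single call; the helper above only picks out "mer". A minimal sketch, assuming a jiwer version whose compute_measures returns keys such as "wer" and "mer" (which is what the test relies on) and that the functional match_error_rate exists as in current torchmetrics:

# Sketch of the jiwer baseline behind _compute_mer_metric_jiwer.
from jiwer import compute_measures
from torchmetrics.functional import match_error_rate

preds = ["hello world"]
target = ["hello duck"]

measures = compute_measures(target, preds)  # jiwer: reference first, hypothesis second
print(measures["mer"])  # baseline match error rate
print(match_error_rate(preds, target))  # torchmetrics: preds first, target second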
14 changes: 7 additions & 7 deletions tests/text/test_rouge.py
@@ -35,27 +35,27 @@

def _compute_rouge_score(
    preds: Sequence[str],
    targets: Sequence[Sequence[str]],
    target: Sequence[Sequence[str]],
    use_stemmer: bool,
    rouge_level: str,
    metric: str,
    accumulate: str,
):
    """Evaluates rouge scores from rouge-score package for baseline evaluation."""
    if isinstance(targets, list) and all(isinstance(target, str) for target in targets):
        targets = [targets] if isinstance(preds, str) else [[target] for target in targets]
    if isinstance(target, list) and all(isinstance(tgt, str) for tgt in target):
        target = [target] if isinstance(preds, str) else [[tgt] for tgt in target]

    if isinstance(preds, str):
        preds = [preds]

    if isinstance(targets, str):
        targets = [[targets]]
    if isinstance(target, str):
        target = [[target]]

    scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = BootstrapAggregator()

    for target_raw, pred_raw in zip(targets, preds):
        list_results = [scorer.score(target, pred_raw) for target in target_raw]
    for target_raw, pred_raw in zip(target, preds):
        list_results = [scorer.score(tgt, pred_raw) for tgt in target_raw]
        aggregator_avg = BootstrapAggregator()

        if accumulate == "best":
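
The reference implementation above wraps the rouge-score package, whose scorer takes the reference first and the prediction second. A minimal sketch of that baseline for one prediction with several references, keeping the best-scoring one as in the test's accumulate="best" branch; the example strings are made up:

# Sketch of the rouge-score baseline wrapped by _compute_rouge_score.
from rouge_score.rouge_scorer import RougeScorer
from rouge_score.scoring import BootstrapAggregator

pred = "the cat sat on the mat"
references = ["a cat was sitting on the mat", "the cat lay on the mat"]

scorer = RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
results = [scorer.score(ref, pred) for ref in references]  # reference first, prediction second
best = max(results, key=lambda scores: scores["rouge1"].fmeasure)  # accumulate="best"

aggregator = BootstrapAggregator()
aggregator.add_scores(best)
print(aggregator.aggregate()["rouge1"].mid.fmeasure)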
72 changes: 35 additions & 37 deletions tests/text/test_squad.py
@@ -6,42 +6,26 @@
import torch.multiprocessing as mp

from tests.helpers.testers import _assert_allclose, _assert_tensor
from tests.text.inputs import _inputs_squad_batch_match, _inputs_squad_exact_match, _inputs_squad_exact_mismatch
from torchmetrics.functional.text import squad
from torchmetrics.text.squad import SQuAD

SAMPLE_1 = {
    "exact_match": 100.0,
    "f1": 100.0,
    "predictions": {"prediction_text": "1976", "id": "id1"},
    "references": {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
}

SAMPLE_2 = {
    "exact_match": 0.0,
    "f1": 0.0,
    "predictions": {"prediction_text": "Hello", "id": "id2"},
    "references": {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
}

BATCH = {
    "exact_match": [100.0, 0.0],
    "f1": [100.0, 0.0],
    "predictions": [
        {"prediction_text": "1976", "id": "id1"},
        {"prediction_text": "Hello", "id": "id2"},
    ],
    "references": [
        {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
        {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
    ],
}


@pytest.mark.parametrize(
    "preds,targets,exact_match,f1",
    "preds, targets, exact_match, f1",
    [
        (SAMPLE_1["predictions"], SAMPLE_1["references"], SAMPLE_1["exact_match"], SAMPLE_1["exact_match"]),
        (SAMPLE_2["predictions"], SAMPLE_2["references"], SAMPLE_2["exact_match"], SAMPLE_2["exact_match"]),
        (
            _inputs_squad_exact_match.preds,
            _inputs_squad_exact_match.targets,
            _inputs_squad_exact_match.exact_match,
            _inputs_squad_exact_match.f1,
        ),
        (
            _inputs_squad_exact_mismatch.preds,
            _inputs_squad_exact_mismatch.targets,
            _inputs_squad_exact_mismatch.exact_match,
            _inputs_squad_exact_mismatch.f1,
        ),
    ],
)
def test_score_fn(preds, targets, exact_match, f1):
@@ -54,14 +38,21 @@ def test_score_fn(preds, targets, exact_match, f1):


@pytest.mark.parametrize(
    "preds,targets,exact_match,f1",
    [(BATCH["predictions"], BATCH["references"], BATCH["exact_match"], BATCH["f1"])],
    "preds, targets, exact_match, f1",
    [
        (
            _inputs_squad_batch_match.preds,
            _inputs_squad_batch_match.targets,
            _inputs_squad_batch_match.exact_match,
            _inputs_squad_batch_match.f1,
        )
    ],
)
def test_accumulation(preds, targets, exact_match, f1):
    """Tests for metric works with accumulation."""
    squad_metric = SQuAD()
    for pred, target in zip(preds, targets):
        squad_metric.update(preds=[pred], targets=[target])
        squad_metric.update(preds=[pred], target=[target])
    metrics_score = squad_metric.compute()

    _assert_tensor(metrics_score["exact_match"])
@@ -70,13 +61,13 @@ def test_accumulation(preds, targets, exact_match, f1):
    _assert_allclose(metrics_score["f1"], torch.mean(torch.tensor(f1)))


def _squad_score_ddp(rank, world_size, pred, target, exact_match, f1):
def _squad_score_ddp(rank, world_size, pred, targets, exact_match, f1):
"""Define a DDP process for SQuAD metric."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group("gloo", rank=rank, world_size=world_size)
squad_metric = SQuAD()
squad_metric.update(pred, target)
squad_metric.update(pred, targets)
metrics_score = squad_metric.compute()
_assert_tensor(metrics_score["exact_match"])
_assert_tensor(metrics_score["f1"])
@@ -91,8 +82,15 @@ def _test_score_ddp_fn(rank, world_size, preds, targets, exact_match, f1):


@pytest.mark.parametrize(
    "preds,targets,exact_match,f1",
    [(BATCH["predictions"], BATCH["references"], BATCH["exact_match"], BATCH["f1"])],
    "preds, targets, exact_match, f1",
    [
        (
            _inputs_squad_batch_match.preds,
            _inputs_squad_batch_match.targets,
            _inputs_squad_batch_match.exact_match,
            _inputs_squad_batch_match.f1,
        )
    ],
)
@pytest.mark.skipif(not dist.is_available(), reason="test requires torch distributed")
def test_score_ddp(preds, targets, exact_match, f1):
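
test_accumulation above exercises the renamed update(preds=..., target=...) keywords on the SQuAD module. A minimal sketch of that accumulation pattern, reusing the data from the BATCH fixture in tests/text/inputs.py:

# Sketch of accumulating SQuAD updates with the renamed keyword arguments.
from torchmetrics.text.squad import SQuAD

preds = [
    {"prediction_text": "1976", "id": "id1"},
    {"prediction_text": "Hello", "id": "id2"},
]
target = [
    {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"},
    {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"},
]

metric = SQuAD()
for pred, tgt in zip(preds, target):
    metric.update(preds=[pred], target=[tgt])

scores = metric.compute()
print(scores["exact_match"], scores["f1"])  # means over the two samples; the fixture expects 50.0 each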
42 changes: 21 additions & 21 deletions tests/text/test_ter.py
@@ -16,7 +16,7 @@

def sacrebleu_ter_fn(
    preds: Sequence[str],
    targets: Sequence[Sequence[str]],
    target: Sequence[Sequence[str]],
    normalized: bool,
    no_punct: bool,
    asian_support: bool,
@@ -26,8 +26,8 @@ def sacrebleu_ter_fn(
        normalized=normalized, no_punct=no_punct, asian_support=asian_support, case_sensitive=case_sensitive
    )
    # Sacrebleu TER expects a different input format
    targets = [[target[i] for target in targets] for i in range(len(targets[0]))]
    sacrebleu_ter = sacrebleu_ter.corpus_score(preds, targets).score / 100
    target = [[tgt[i] for tgt in target] for i in range(len(target[0]))]
    sacrebleu_ter = sacrebleu_ter.corpus_score(preds, target).score / 100
    return tensor(sacrebleu_ter)


@@ -118,41 +118,41 @@ def test_chrf_score_differentiability(self, preds, targets, normalize, no_punctu


def test_ter_empty_functional():
    hyp = []
    ref = [[]]
    assert translation_edit_rate(hyp, ref) == tensor(0.0)
    preds = []
    targets = [[]]
    assert translation_edit_rate(preds, targets) == tensor(0.0)


def test_ter_empty_class():
    ter_metric = TranslationEditRate()
    hyp = []
    ref = [[]]
    assert ter_metric(hyp, ref) == tensor(0.0)
    preds = []
    targets = [[]]
    assert ter_metric(preds, targets) == tensor(0.0)


def test_ter_empty_with_non_empty_hyp_functional():
    hyp = ["python"]
    ref = [[]]
    assert translation_edit_rate(hyp, ref) == tensor(0.0)
    preds = ["python"]
    targets = [[]]
    assert translation_edit_rate(preds, targets) == tensor(0.0)


def test_ter_empty_with_non_empty_hyp_class():
    ter_metric = TranslationEditRate()
    hyp = ["python"]
    ref = [[]]
    assert ter_metric(hyp, ref) == tensor(0.0)
    preds = ["python"]
    targets = [[]]
    assert ter_metric(preds, targets) == tensor(0.0)


def test_ter_return_sentence_level_score_functional():
    hyp = _inputs_single_sentence_multiple_references.preds
    ref = _inputs_single_sentence_multiple_references.targets
    _, sentence_ter = translation_edit_rate(hyp, ref, return_sentence_level_score=True)
    preds = _inputs_single_sentence_multiple_references.preds
    targets = _inputs_single_sentence_multiple_references.targets
    _, sentence_ter = translation_edit_rate(preds, targets, return_sentence_level_score=True)
    isinstance(sentence_ter, Tensor)


def test_ter_return_sentence_level_class():
    ter_metric = TranslationEditRate(return_sentence_level_score=True)
    hyp = _inputs_single_sentence_multiple_references.preds
    ref = _inputs_single_sentence_multiple_references.targets
    _, sentence_ter = ter_metric(hyp, ref)
    preds = _inputs_single_sentence_multiple_references.preds
    targets = _inputs_single_sentence_multiple_references.targets
    _, sentence_ter = ter_metric(preds, targets)
    isinstance(sentence_ter, Tensor)
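
The two sentence-level tests above rely on return_sentence_level_score producing a (corpus score, per-sentence scores) pair from both the functional and the module interface. A minimal sketch of that renamed interface; the import paths are assumed to match the test file, and the example sentences are made up:

# Sketch of the renamed TER interface with optional sentence-level scores.
from torchmetrics.functional import translation_edit_rate
from torchmetrics.text.ter import TranslationEditRate

preds = ["the cat is on the mat"]
target = [["there is a cat on the mat", "a cat is on the mat"]]

corpus_ter, sentence_ter = translation_edit_rate(preds, target, return_sentence_level_score=True)

ter_metric = TranslationEditRate(return_sentence_level_score=True)
corpus_ter, sentence_ter = ter_metric(preds, target)
print(corpus_ter, sentence_ter)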
4 changes: 2 additions & 2 deletions tests/text/test_wer.py
@@ -15,8 +15,8 @@
from torchmetrics.text.wer import WordErrorRate


def _compute_wer_metric_jiwer(prediction: Union[str, List[str]], reference: Union[str, List[str]]):
    return compute_measures(reference, prediction)["wer"]
def _compute_wer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]):
    return compute_measures(target, preds)["wer"]


@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
(The remaining 12 changed files in this commit are not shown.)