diff --git a/docs/source/base_evaluator.mdx b/docs/source/base_evaluator.mdx
index 82b34340d..1a0ecd834 100644
--- a/docs/source/base_evaluator.mdx
+++ b/docs/source/base_evaluator.mdx
@@ -7,6 +7,7 @@ Currently supported tasks are:
 - `"token-classification"`: will use the [`TokenClassificationEvaluator`].
 - `"question-answering"`: will use the [`QuestionAnsweringEvaluator`].
 - `"image-classification"`: will use the [`ImageClassificationEvaluator`].
+- `"text-generation"`: will use the [`TextGenerationEvaluator`].
 - `"text2text-generation"`: will use the [`Text2TextGenerationEvaluator`].
 - `"summarization"`: will use the [`SummarizationEvaluator`].
 - `"translation"`: will use the [`TranslationEvaluator`].
diff --git a/docs/source/package_reference/evaluator_classes.mdx b/docs/source/package_reference/evaluator_classes.mdx
index 788b18970..76d6ce02e 100644
--- a/docs/source/package_reference/evaluator_classes.mdx
+++ b/docs/source/package_reference/evaluator_classes.mdx
@@ -32,6 +32,11 @@ The base class for all evaluator classes:
 [[autodoc]] evaluate.TokenClassificationEvaluator
     - compute
 
+### TextGenerationEvaluator
+
+[[autodoc]] evaluate.TextGenerationEvaluator
+    - compute
+
 ### Text2TextGenerationEvaluator
 
 [[autodoc]] evaluate.Text2TextGenerationEvaluator
diff --git a/src/evaluate/__init__.py b/src/evaluate/__init__.py
index 888fe676d..ca7ac67d5 100644
--- a/src/evaluate/__init__.py
+++ b/src/evaluate/__init__.py
@@ -33,6 +33,7 @@
     SummarizationEvaluator,
     Text2TextGenerationEvaluator,
     TextClassificationEvaluator,
+    TextGenerationEvaluator,
     TokenClassificationEvaluator,
     TranslationEvaluator,
     evaluator,
diff --git a/src/evaluate/evaluator/__init__.py b/src/evaluate/evaluator/__init__.py
index 5aa5fa9c7..a70a79453 100644
--- a/src/evaluate/evaluator/__init__.py
+++ b/src/evaluate/evaluator/__init__.py
@@ -29,6 +29,7 @@
 from .question_answering import QuestionAnsweringEvaluator
 from .text2text_generation import SummarizationEvaluator, Text2TextGenerationEvaluator, TranslationEvaluator
 from .text_classification import TextClassificationEvaluator
+from .text_generation import TextGenerationEvaluator
 from .token_classification import TokenClassificationEvaluator
 
 
@@ -49,6 +50,10 @@
         "implementation": TokenClassificationEvaluator,
         "default_metric_name": "seqeval",
     },
+    "text-generation": {
+        "implementation": TextGenerationEvaluator,
+        "default_metric_name": "word_count",
+    },
     "text2text-generation": {
         "implementation": Text2TextGenerationEvaluator,
         "default_metric_name": "bleu",
diff --git a/src/evaluate/evaluator/base.py b/src/evaluate/evaluator/base.py
index 9601e6388..b5d43f57d 100644
--- a/src/evaluate/evaluator/base.py
+++ b/src/evaluate/evaluator/base.py
@@ -14,7 +14,7 @@
 
 from abc import ABC, abstractmethod
 from numbers import Number
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 # Lint as: python3
 from datasets import Dataset, load_dataset
@@ -234,7 +234,7 @@ def compute(
         input_column: str = "text",
         label_column: str = "label",
         label_mapping: Optional[Dict[str, Number]] = None,
-    ) -> Tuple[Dict[str, float], Any]:
+    ) -> Dict[str, float]:
 
         result = {}
 
@@ -347,7 +347,7 @@ def load_data(self, data: Union[str, Dataset], subset: str = None, split: str =
                 "Please specify a valid `data` object - either a `str` with a name or a `Dataset` object."
             )
 
-    def prepare_data(self, data: Dataset, input_column: str, label_column: str):
+    def prepare_data(self, data: Dataset, input_column: str, label_column: str, *args, **kwargs):
         """
         Prepare data.
 
diff --git a/src/evaluate/evaluator/text_generation.py b/src/evaluate/evaluator/text_generation.py
new file mode 100644
index 000000000..870b5492d
--- /dev/null
+++ b/src/evaluate/evaluator/text_generation.py
@@ -0,0 +1,68 @@
+# Copyright 2022 The HuggingFace Evaluate Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Tuple
+
+from datasets import Dataset
+
+from .base import Evaluator
+from .utils import DatasetColumn
+
+
+TASK_DOCUMENTATION_KWARGS = r"""
+    input_column (`str`, defaults to `"text"`):
+        the name of the column containing the input text in the dataset specified by `data`.
+    generation_kwargs (`Dict`, *optional*, defaults to `None`):
+        The generation kwargs are passed to the pipeline and set the text generation strategy.
+"""
+
+
+class TextGenerationEvaluator(Evaluator):
+    """
+    Text generation evaluator.
+    This Text generation evaluator can currently be loaded from [`evaluator`] using the default task name
+    `text-generation`.
+    Methods in this class assume a data format compatible with the [`TextGenerationPipeline`].
+    """
+
+    def predictions_processor(self, predictions, *args, **kwargs):
+        """
+        Args:
+            predictions: A list of lists of dicts
+
+        Returns:
+            `dict`: All the generated texts are flattened and stored under the "data" key.
+        """
+        return {"data": [pred[f"{self.predictions_prefix}_text"] for pred_list in predictions for pred in pred_list]}
+
+    def __init__(self, task="text-generation", default_metric_name=None, predictions_prefix: str = "generated"):
+        super().__init__(task=task, default_metric_name=default_metric_name)
+        self.predictions_prefix = predictions_prefix
+
+    def prepare_data(self, data: Dataset, input_column: str, *args, **kwargs) -> Tuple[Dict, DatasetColumn]:
+        """
+        Prepare data.
+
+        Args:
+            data (`Dataset`): Specifies the dataset we will run evaluation on.
+            input_column (`str`, defaults to `"text"`):
+                the name of the column containing the text feature in the dataset specified by `data`.
+        Returns:
+            `dict`: metric inputs.
+            `list`: pipeline inputs.
+        """
+
+        self.check_required_columns(data, {"input_column": input_column})
+
+        return {}, DatasetColumn(data, input_column)
diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py
index b51610b1f..8cd56333f 100644
--- a/tests/test_evaluator.py
+++ b/tests/test_evaluator.py
@@ -36,12 +36,23 @@
     QuestionAnsweringEvaluator,
     Text2TextGenerationEvaluator,
     TextClassificationEvaluator,
+    TextGenerationEvaluator,
     TokenClassificationEvaluator,
     evaluator,
     load,
 )
 
 
+class DummyTextGenerationPipeline:
+    def __init__(self, prefix="generated", task="text-generation", num_return_sequences=1):
+        self.task = task
+        self.prefix = prefix
+        self.num_return_sequences = num_return_sequences
+
+    def __call__(self, inputs, **kwargs):
+        return [[{f"{self.prefix}_text": "Lorem ipsum"} for _ in range(self.num_return_sequences)] for _ in inputs]
+
+
 class DummyText2TextGenerationPipeline:
     def __init__(self, prefix="generated", task="text2text-generation"):
         self.task = task
@@ -781,6 +792,53 @@ def test_predictions_processor(self):
         self.assertListEqual(predictions["predictions"][0], ["B-LOC", "O", "O", "O", "B-LOC", "O"])
 
 
+class TestTextGenerationEvaluator(TestCase):
+    def setUp(self):
+        self.data = Dataset.from_dict({"text": ["Lorem ipsum"]})
+        self.pipe = DummyTextGenerationPipeline(num_return_sequences=4)
+        self.evaluator = evaluator("text-generation")
+
+    def test_class_init(self):
+        evaluator = TextGenerationEvaluator()
+        self.assertEqual(evaluator.task, "text-generation")
+        self.assertIsNone(evaluator.default_metric_name)
+
+        results = evaluator.compute(
+            model_or_pipeline=self.pipe,
+            data=self.data,
+            metric="word_count",
+        )
+        self.assertIsInstance(results["unique_words"], int)
+
+    def test_default_pipe_init(self):
+        results = self.evaluator.compute(data=self.data)
+        self.assertIsInstance(results["unique_words"], int)
+
+    def test_overwrite_default_metric(self):
+        word_length = load("word_length")
+        results = self.evaluator.compute(
+            model_or_pipeline=self.pipe,
+            data=self.data,
+            metric=word_length,
+        )
+        self.assertIsInstance(results["average_word_length"], int)
+        results = self.evaluator.compute(
+            model_or_pipeline=self.pipe,
+            data=self.data,
+            metric="word_length",
+        )
+        self.assertIsInstance(results["average_word_length"], int)
+
+    def test_process_predictions_multiple_return_sequences(self):
+        processed_predictions = self.evaluator.predictions_processor(
+            [
+                [{"generated_text": "A"}, {"generated_text": "B"}],
+                [{"generated_text": "C"}, {"generated_text": "D"}],
+            ]
+        )
+        self.assertEqual(processed_predictions, {"data": ["A", "B", "C", "D"]})
+
+
 class TestText2TextGenerationEvaluator(TestCase):
     def setUp(self):
         self.data = Dataset.from_dict(