From 40d0a961157d634233198b3873aa48c9c63a0b5d Mon Sep 17 00:00:00 2001 From: Maria Luisa <32341580+luisaadanttas@users.noreply.github.com> Date: Wed, 17 Jul 2024 03:47:15 -0300 Subject: [PATCH] Chore/rename task fields (#994) * chore: Rename Task inputs and outputs fields Signed-off-by: luisaadanttas * docs: Rename Task inputs and outputs fields Signed-off-by: luisaadanttas * chore: update remaining input_fields and reference_fields in Tasks Signed-off-by: luisaadanttas * refactor: handle deprecated input/output fields and add prepare method for compatibility Signed-off-by: luisaadanttas * test: add tests for deprecated inputs/outputs and conflicting fields in Task Signed-off-by: luisaadanttas * test: update tests for task initialization with detailed field checks Signed-off-by: luisaadanttas * refactor: separate checks for input_fields and reference_fields Signed-off-by: luisaadanttas * fix:update field names in atta_q, attaq_500, and bold cards Signed-off-by: luisaadanttas --------- Signed-off-by: luisaadanttas Co-authored-by: Yoav Katz <68273864+yoavkatz@users.noreply.github.com> --- docs/docs/adding_dataset.rst | 8 +- docs/docs/adding_metric.rst | 4 +- docs/docs/adding_task.rst | 8 +- .../standalone_evaluation_llm_as_judge.py | 4 +- examples/standalone_qa_evaluation.py | 4 +- prepare/cards/atta_q.py | 4 +- prepare/cards/attaq_500.py | 4 +- prepare/cards/bold.py | 4 +- prepare/cards/human_eval.py | 4 +- prepare/cards/mbpp.py | 4 +- prepare/cards/mrpc.py | 4 +- prepare/cards/pop_qa.py | 4 +- prepare/cards/qqp.py | 4 +- prepare/cards/wsc.py | 4 +- prepare/operators/balancers/per_task.py | 10 +-- prepare/tasks/classification.py | 28 +++---- prepare/tasks/completion/multiple_choice.py | 20 +++-- prepare/tasks/evaluation.py | 4 +- prepare/tasks/generation.py | 4 +- prepare/tasks/grammatical_error_correction.py | 4 +- prepare/tasks/language_identification.py | 4 +- prepare/tasks/ner.py | 8 +- prepare/tasks/qa/multiple_choice/tasks.py | 16 ++-- prepare/tasks/qa/tasks.py | 12 +-- prepare/tasks/rag/response_generation.py | 4 +- prepare/tasks/regression/tasks.py | 12 +-- .../pairwise_comparison/multi_turn.py | 4 +- .../multi_turn_with_reference.py | 4 +- .../pairwise_comparison/single_turn.py | 4 +- .../single_turn_with_reference.py | 4 +- .../response_assessment/rating/multi_turn.py | 4 +- .../rating/multi_turn_with_reference.py | 4 +- .../response_assessment/rating/single_turn.py | 4 +- .../rating/single_turn_with_reference.py | 4 +- prepare/tasks/rewriting.py | 8 +- prepare/tasks/selection.py | 4 +- prepare/tasks/span_labeling.py | 4 +- prepare/tasks/summarization/abstractive.py | 4 +- .../tasks/targeted_sentiment_extraction.py | 8 +- prepare/tasks/translation/directed.py | 8 +- src/unitxt/catalog/cards/atta_q.json | 4 +- src/unitxt/catalog/cards/attaq_500.json | 4 +- src/unitxt/catalog/cards/bold.json | 4 +- src/unitxt/catalog/cards/human_eval.json | 4 +- src/unitxt/catalog/cards/mbpp.json | 4 +- src/unitxt/catalog/cards/mrpc.json | 4 +- src/unitxt/catalog/cards/pop_qa.json | 4 +- src/unitxt/catalog/cards/qqp.json | 4 +- src/unitxt/catalog/cards/wsc.json | 4 +- .../balancers/classification/by_label.json | 2 +- .../minimum_one_example_per_class.json | 2 +- .../multi_label/zero_vs_many_labels.json | 2 +- .../balancers/ner/zero_vs_many_entities.json | 2 +- .../operators/balancers/qa/by_answer.json | 2 +- .../catalog/tasks/classification/binary.json | 4 +- .../classification/binary/zero_or_one.json | 4 +- .../tasks/classification/multi_class.json | 4 +- .../classification/multi_class/relation.json | 4 +- .../multi_class/topic_classification.json | 4 +- .../with_classes_descriptions.json | 4 +- .../tasks/classification/multi_label.json | 4 +- .../catalog/tasks/completion/abstractive.json | 4 +- .../catalog/tasks/completion/extractive.json | 4 +- .../tasks/completion/multiple_choice.json | 4 +- .../catalog/tasks/evaluation/preference.json | 4 +- src/unitxt/catalog/tasks/generation.json | 4 +- .../tasks/grammatical_error_correction.json | 4 +- .../tasks/language_identification.json | 4 +- .../catalog/tasks/ner/all_entity_types.json | 4 +- .../catalog/tasks/ner/single_entity_type.json | 4 +- .../tasks/qa/multiple_choice/open.json | 4 +- .../qa/multiple_choice/with_context.json | 4 +- .../with_context/with_topic.json | 4 +- .../tasks/qa/multiple_choice/with_topic.json | 4 +- src/unitxt/catalog/tasks/qa/open.json | 4 +- .../tasks/qa/with_context/abstractive.json | 4 +- .../tasks/qa/with_context/extractive.json | 4 +- .../tasks/rag/response_generation.json | 4 +- .../catalog/tasks/regression/single_text.json | 4 +- .../catalog/tasks/regression/two_texts.json | 4 +- .../regression/two_texts/similarity.json | 4 +- .../pairwise_comparison/multi_turn.json | 4 +- .../multi_turn_with_reference.json | 4 +- .../pairwise_comparison/single_turn.json | 4 +- .../single_turn_with_reference.json | 4 +- .../rating/multi_turn.json | 4 +- .../rating/multi_turn_with_reference.json | 4 +- .../rating/single_turn.json | 4 +- .../rating/single_turn_with_reference.json | 4 +- .../catalog/tasks/rewriting/by_attribute.json | 4 +- .../catalog/tasks/rewriting/paraphrase.json | 4 +- .../catalog/tasks/selection/by_attribute.json | 4 +- .../tasks/span_labeling/extraction.json | 4 +- .../tasks/summarization/abstractive.json | 4 +- .../all_sentiment_classes.json | 4 +- .../single_sentiment_class.json | 4 +- .../catalog/tasks/translation/directed.json | 4 +- src/unitxt/operators.py | 2 +- src/unitxt/schema.py | 4 +- src/unitxt/splitters.py | 22 ++--- src/unitxt/task.py | 84 ++++++++++++++----- src/unitxt/templates.py | 29 ++++--- tests/library/test_api.py | 4 +- tests/library/test_card.py | 4 +- tests/library/test_operators.py | 34 ++++---- tests/library/test_splitters.py | 18 ++-- tests/library/test_tasks.py | 81 ++++++++++++++++-- tests/library/test_templates.py | 54 ++++++------ 108 files changed, 478 insertions(+), 334 deletions(-) diff --git a/docs/docs/adding_dataset.rst b/docs/docs/adding_dataset.rst index 4d5924efa..d82b255d1 100644 --- a/docs/docs/adding_dataset.rst +++ b/docs/docs/adding_dataset.rst @@ -29,8 +29,8 @@ an Engish to French translation task or for a French to English translation task The Task schema is a formal definition of the NLP task , including its inputs, outputs, and default evaluation metrics. -The `inputs` of the task are a set of fields that are used to format the textual input to the model. -The `output` of the task are a set of fields that are used to format the textual expected output from the model (gold references). +The `input_fields` of the task are a set of fields that are used to format the textual input to the model. +The `reference_fields` of the task are a set of fields that are used to format the textual expected output from the model (gold references). The `metrics` of the task are a set of default metrics to be used to evaluate the outputs of the model. While language models generate textual predictions, the metrics often evaluate on a different datatypes. For example, @@ -46,8 +46,8 @@ We will use the `bleu` metric for a reference based evaluation. .. code-block:: python task=Task( - inputs= { "text" : "str", "source_language" : "str", "target_language" : "str"}, - outputs= {"translation" : "str"}, + input_fields= { "text" : "str", "source_language" : "str", "target_language" : "str"}, + reference_fields= {"translation" : "str"}, prediction_type="str", metrics=["metrics.bleu"], ), diff --git a/docs/docs/adding_metric.rst b/docs/docs/adding_metric.rst index 9a749fcda..5ee74514c 100644 --- a/docs/docs/adding_metric.rst +++ b/docs/docs/adding_metric.rst @@ -19,8 +19,8 @@ For example: .. code-block:: python task = Task( - inputs={ "question" : "str" }, - outputs={ "answer" : str }, + input_fields={ "question" : "str" }, + reference_fields={ "answer" : str }, prediction_type="str", metrics=[ "metrics.rouge", diff --git a/docs/docs/adding_task.rst b/docs/docs/adding_task.rst index 9de631b6d..b09a52c83 100644 --- a/docs/docs/adding_task.rst +++ b/docs/docs/adding_task.rst @@ -13,8 +13,8 @@ Tasks are fundamental to Unitxt, acting as standardized interface for integratin The Task schema is a formal definition of the NLP task, including its inputs, outputs, and default evaluation metrics. -The `inputs` of the task are a set of fields that are used to format the textual input to the model. -The `output` of the task are a set of fields that are used to format the expected textual output from the model (gold references). +The `input_fields` of the task are a set of fields that are used to format the textual input to the model. +The `reference_fields` of the task are a set of fields that are used to format the expected textual output from the model (gold references). The `metrics` of the task are a set of default metrics to be used to evaluate the outputs of the model. As an example, consider an evaluation task for LLMs to evaluate how well they are able to calculate the sum of two integer numbers. @@ -25,8 +25,8 @@ The task is formally defined as: from unitxt.blocks import Task task = Task( - inputs={"num1" : "int", "num2" : "int"}, - outputs={"sum" : "int"}, + input_fields={"num1" : "int", "num2" : "int"}, + reference_fields={"sum" : "int"}, prediction_type="int", metrics=[ "metrics.sum_accuracy", diff --git a/examples/standalone_evaluation_llm_as_judge.py b/examples/standalone_evaluation_llm_as_judge.py index a7d15ffe9..20ae7ad31 100644 --- a/examples/standalone_evaluation_llm_as_judge.py +++ b/examples/standalone_evaluation_llm_as_judge.py @@ -56,8 +56,8 @@ card = TaskCard( loader=LoadFromDictionary(data=data), task=Task( - inputs={"question": "str"}, - outputs={"answer": "str"}, + input_fields={"question": "str"}, + reference_fields={"answer": "str"}, prediction_type="str", metrics=[llm_judge_metric], ), diff --git a/examples/standalone_qa_evaluation.py b/examples/standalone_qa_evaluation.py index 8470f347c..44e2c50d4 100644 --- a/examples/standalone_qa_evaluation.py +++ b/examples/standalone_qa_evaluation.py @@ -24,8 +24,8 @@ loader=LoadFromDictionary(data=data), # Define the QA task input and output and metrics. task=Task( - inputs={"question": "str"}, - outputs={"answer": "str"}, + input_fields={"question": "str"}, + reference_fields={"answer": "str"}, prediction_type="str", metrics=["metrics.accuracy"], ), diff --git a/prepare/cards/atta_q.py b/prepare/cards/atta_q.py index 2c5d3d36a..fcbb1b70f 100644 --- a/prepare/cards/atta_q.py +++ b/prepare/cards/atta_q.py @@ -23,7 +23,9 @@ DumpJson(field="input_label"), ], task=Task( - inputs=["input"], outputs=["input_label"], metrics=["metrics.safety_metric"] + input_fields=["input"], + reference_fields=["input_label"], + metrics=["metrics.safety_metric"], ), templates=TemplatesList( [ diff --git a/prepare/cards/attaq_500.py b/prepare/cards/attaq_500.py index 8bae75b61..46c86b585 100644 --- a/prepare/cards/attaq_500.py +++ b/prepare/cards/attaq_500.py @@ -527,7 +527,9 @@ DumpJson(field="input_label"), ], task=Task( - inputs=["input"], outputs=["input_label"], metrics=["metrics.safety_metric"] + input_fields=["input"], + reference_fields=["input_label"], + metrics=["metrics.safety_metric"], ), templates=TemplatesList( [ diff --git a/prepare/cards/bold.py b/prepare/cards/bold.py index b29fa0334..15a8048a6 100644 --- a/prepare/cards/bold.py +++ b/prepare/cards/bold.py @@ -35,8 +35,8 @@ DumpJson(field="input_label"), ], task=Task( - inputs=["first_prompt"], - outputs=["input_label"], + input_fields=["first_prompt"], + reference_fields=["input_label"], metrics=["metrics.regard_metric"], ), templates=TemplatesList( diff --git a/prepare/cards/human_eval.py b/prepare/cards/human_eval.py index 8628feac7..2681a1da3 100644 --- a/prepare/cards/human_eval.py +++ b/prepare/cards/human_eval.py @@ -26,8 +26,8 @@ ) ], task=Task( - inputs=["prompt"], - outputs=["prompt", "canonical_solution", "test_list"], + input_fields=["prompt"], + reference_fields=["prompt", "canonical_solution", "test_list"], metrics=["metrics.bleu"], ), templates=TemplatesList( diff --git a/prepare/cards/mbpp.py b/prepare/cards/mbpp.py index 0b633b57b..0c04ba5af 100644 --- a/prepare/cards/mbpp.py +++ b/prepare/cards/mbpp.py @@ -17,8 +17,8 @@ JoinStr(field_to_field={"test_list": "test_list_str"}, separator=os.linesep), ], task=Task( - inputs=["text", "test_list_str"], - outputs=["test_list", "code"], + input_fields=["text", "test_list_str"], + reference_fields=["test_list", "code"], metrics=["metrics.bleu"], ), templates=TemplatesList( diff --git a/prepare/cards/mrpc.py b/prepare/cards/mrpc.py index 40e7de1cf..d0f942d80 100644 --- a/prepare/cards/mrpc.py +++ b/prepare/cards/mrpc.py @@ -31,8 +31,8 @@ ), ], task=Task( - inputs=["choices", "sentence1", "sentence2"], - outputs=["label"], + input_fields=["choices", "sentence1", "sentence2"], + reference_fields=["label"], metrics=["metrics.accuracy"], ), templates=TemplatesList( diff --git a/prepare/cards/pop_qa.py b/prepare/cards/pop_qa.py index 2ad3c9347..38a8c22df 100644 --- a/prepare/cards/pop_qa.py +++ b/prepare/cards/pop_qa.py @@ -17,8 +17,8 @@ LoadJson(field="possible_answers"), ], task=Task( - inputs=["question", "prop", "subj"], - outputs=["possible_answers"], + input_fields=["question", "prop", "subj"], + reference_fields=["possible_answers"], metrics=["metrics.accuracy"], ), templates=TemplatesList( diff --git a/prepare/cards/qqp.py b/prepare/cards/qqp.py index 1c16ebd3a..841c0d10a 100644 --- a/prepare/cards/qqp.py +++ b/prepare/cards/qqp.py @@ -24,8 +24,8 @@ ), ], task=Task( - inputs=["choices", "question1", "question2"], - outputs=["label"], + input_fields=["choices", "question1", "question2"], + reference_fields=["label"], metrics=["metrics.accuracy"], ), templates=TemplatesList( diff --git a/prepare/cards/wsc.py b/prepare/cards/wsc.py index b95c36ca8..82e28c6b6 100644 --- a/prepare/cards/wsc.py +++ b/prepare/cards/wsc.py @@ -22,8 +22,8 @@ ), ], task=Task( - inputs=["choices", "text", "span1_text", "span2_text"], - outputs=["label"], + input_fields=["choices", "text", "span1_text", "span2_text"], + reference_fields=["label"], metrics=["metrics.accuracy"], ), templates=TemplatesList( diff --git a/prepare/operators/balancers/per_task.py b/prepare/operators/balancers/per_task.py index bd1510845..bf9999433 100644 --- a/prepare/operators/balancers/per_task.py +++ b/prepare/operators/balancers/per_task.py @@ -5,27 +5,27 @@ MinimumOneExamplePerLabelRefiner, ) -balancer = DeterministicBalancer(fields=["outputs/label"]) +balancer = DeterministicBalancer(fields=["reference_fields/label"]) add_to_catalog(balancer, "operators.balancers.classification.by_label", overwrite=True) -balancer = DeterministicBalancer(fields=["outputs/answer"]) +balancer = DeterministicBalancer(fields=["reference_fields/answer"]) add_to_catalog(balancer, "operators.balancers.qa.by_answer", overwrite=True) -balancer = LengthBalancer(fields=["outputs/labels"], segments_boundaries=[1]) +balancer = LengthBalancer(fields=["reference_fields/labels"], segments_boundaries=[1]) add_to_catalog( balancer, "operators.balancers.multi_label.zero_vs_many_labels", overwrite=True ) -balancer = LengthBalancer(fields=["outputs/labels"], segments_boundaries=[1]) +balancer = LengthBalancer(fields=["reference_fields/labels"], segments_boundaries=[1]) add_to_catalog( balancer, "operators.balancers.ner.zero_vs_many_entities", overwrite=True ) -balancer = MinimumOneExamplePerLabelRefiner(fields=["outputs/label"]) +balancer = MinimumOneExamplePerLabelRefiner(fields=["reference_fields/label"]) add_to_catalog( balancer, diff --git a/prepare/tasks/classification.py b/prepare/tasks/classification.py index cb1af9e7c..3bb243507 100644 --- a/prepare/tasks/classification.py +++ b/prepare/tasks/classification.py @@ -3,8 +3,8 @@ add_to_catalog( Task( - inputs={"text": "str", "text_type": "str", "class": "str"}, - outputs={"class": "str", "label": "List[str]"}, + input_fields={"text": "str", "text_type": "str", "class": "str"}, + reference_fields={"class": "str", "label": "List[str]"}, prediction_type="List[str]", metrics=[ "metrics.f1_micro_multi_label", @@ -20,8 +20,8 @@ add_to_catalog( Task( - inputs={"text": "str", "text_type": "str", "class": "str"}, - outputs={"class": "str", "label": "int"}, + input_fields={"text": "str", "text_type": "str", "class": "str"}, + reference_fields={"class": "str", "label": "int"}, prediction_type="float", metrics=[ "metrics.accuracy", @@ -36,13 +36,13 @@ add_to_catalog( Task( - inputs={ + input_fields={ "text": "str", "text_type": "str", "classes": "List[str]", "type_of_classes": "str", }, - outputs={"labels": "List[str]"}, + reference_fields={"labels": "List[str]"}, prediction_type="List[str]", metrics=[ "metrics.f1_micro_multi_label", @@ -58,13 +58,13 @@ add_to_catalog( Task( - inputs={ + input_fields={ "text": "str", "text_type": "str", "classes": "List[str]", "type_of_class": "str", }, - outputs={"label": "str"}, + reference_fields={"label": "str"}, prediction_type="str", metrics=["metrics.f1_micro", "metrics.accuracy", "metrics.f1_macro"], augmentable_inputs=["text"], @@ -76,7 +76,7 @@ add_to_catalog( Task( - inputs={ + input_fields={ "text_a": "str", "text_a_type": "str", "text_b": "str", @@ -84,7 +84,7 @@ "classes": "List[str]", "type_of_relation": "str", }, - outputs={"label": "str"}, + reference_fields={"label": "str"}, prediction_type="str", metrics=["metrics.f1_micro", "metrics.accuracy", "metrics.f1_macro"], augmentable_inputs=["text_a", "text_b"], @@ -97,14 +97,14 @@ add_to_catalog( Task( - inputs={ + input_fields={ "text": "str", "text_type": "str", "classes": "List[str]", "type_of_class": "str", "classes_descriptions": "str", }, - outputs={"label": "str"}, + reference_fields={"label": "str"}, prediction_type="str", metrics=["metrics.f1_micro", "metrics.accuracy", "metrics.f1_macro"], augmentable_inputs=["text"], @@ -116,13 +116,13 @@ add_to_catalog( Task( - inputs={ + input_fields={ "text": "str", "text_type": "str", "classes": "List[str]", "type_of_class": "str", }, - outputs={"label": "str"}, + reference_fields={"label": "str"}, prediction_type="str", metrics=["metrics.f1_micro", "metrics.accuracy", "metrics.f1_macro"], augmentable_inputs=["text"], diff --git a/prepare/tasks/completion/multiple_choice.py b/prepare/tasks/completion/multiple_choice.py index 103ec8278..a057e1e3e 100644 --- a/prepare/tasks/completion/multiple_choice.py +++ b/prepare/tasks/completion/multiple_choice.py @@ -3,8 +3,8 @@ add_to_catalog( Task( - inputs={"context": "str", "context_type": "str", "choices": "List[str]"}, - outputs={"answer": "int", "choices": "List[str]"}, + input_fields={"context": "str", "context_type": "str", "choices": "List[str]"}, + reference_fields={"answer": "int", "choices": "List[str]"}, prediction_type="Any", metrics=["metrics.accuracy"], ), @@ -14,8 +14,12 @@ add_to_catalog( Task( - inputs={"context": "str", "context_type": "str", "completion_type": "str"}, - outputs={"completion": "str"}, + input_fields={ + "context": "str", + "context_type": "str", + "completion_type": "str", + }, + reference_fields={"completion": "str"}, prediction_type="str", metrics=["metrics.rouge"], ), @@ -25,8 +29,12 @@ add_to_catalog( Task( - inputs={"context": "str", "context_type": "str", "completion_type": "str"}, - outputs={"completion": "str"}, + input_fields={ + "context": "str", + "context_type": "str", + "completion_type": "str", + }, + reference_fields={"completion": "str"}, prediction_type="Dict[str,Any]", metrics=["metrics.squad"], ), diff --git a/prepare/tasks/evaluation.py b/prepare/tasks/evaluation.py index 44db7acdf..b942da41b 100644 --- a/prepare/tasks/evaluation.py +++ b/prepare/tasks/evaluation.py @@ -3,8 +3,8 @@ add_to_catalog( Task( - inputs=["input", "input_type", "output_type", "choices", "instruction"], - outputs=["choices", "output_choice"], + input_fields=["input", "input_type", "output_type", "choices", "instruction"], + reference_fields=["choices", "output_choice"], metrics=[ "metrics.accuracy", ], diff --git a/prepare/tasks/generation.py b/prepare/tasks/generation.py index 9f48b0819..82519ec68 100644 --- a/prepare/tasks/generation.py +++ b/prepare/tasks/generation.py @@ -3,8 +3,8 @@ add_to_catalog( Task( - inputs={"input": "str", "type_of_input": "str", "type_of_output": "str"}, - outputs={"output": "str"}, + input_fields={"input": "str", "type_of_input": "str", "type_of_output": "str"}, + reference_fields={"output": "str"}, prediction_type="str", metrics=["metrics.normalized_sacrebleu"], augmentable_inputs=["input"], diff --git a/prepare/tasks/grammatical_error_correction.py b/prepare/tasks/grammatical_error_correction.py index c13f868a4..48b1a8022 100644 --- a/prepare/tasks/grammatical_error_correction.py +++ b/prepare/tasks/grammatical_error_correction.py @@ -3,8 +3,8 @@ add_to_catalog( Task( - inputs=["original_text"], - outputs=["corrected_texts"], + input_fields=["original_text"], + reference_fields=["corrected_texts"], metrics=[ "metrics.char_edit_dist_accuracy", "metrics.rouge", diff --git a/prepare/tasks/language_identification.py b/prepare/tasks/language_identification.py index 892708a3d..0fca85998 100644 --- a/prepare/tasks/language_identification.py +++ b/prepare/tasks/language_identification.py @@ -3,8 +3,8 @@ add_to_catalog( Task( - inputs={"text": "str"}, - outputs={"label": "str"}, + input_fields={"text": "str"}, + reference_fields={"label": "str"}, prediction_type="str", metrics=["metrics.accuracy"], ), diff --git a/prepare/tasks/ner.py b/prepare/tasks/ner.py index 79c1ec3a3..36ce265b5 100644 --- a/prepare/tasks/ner.py +++ b/prepare/tasks/ner.py @@ -3,8 +3,8 @@ add_to_catalog( Task( - inputs={"text": "str", "entity_type": "str"}, - outputs={ + input_fields={"text": "str", "entity_type": "str"}, + reference_fields={ "spans_starts": "List[int]", "spans_ends": "List[int]", "text": "str", @@ -20,8 +20,8 @@ add_to_catalog( Task( - inputs={"text": "str", "entity_types": "List[str]"}, - outputs={ + input_fields={"text": "str", "entity_types": "List[str]"}, + reference_fields={ "spans_starts": "List[int]", "spans_ends": "List[int]", "text": "str", diff --git a/prepare/tasks/qa/multiple_choice/tasks.py b/prepare/tasks/qa/multiple_choice/tasks.py index a55a7eaae..c269199ca 100644 --- a/prepare/tasks/qa/multiple_choice/tasks.py +++ b/prepare/tasks/qa/multiple_choice/tasks.py @@ -3,13 +3,13 @@ add_to_catalog( Task( - inputs={ + input_fields={ "context": "str", "context_type": "str", "question": "str", "choices": "List[str]", }, - outputs={"answer": "Union[int,str]", "choices": "List[str]"}, + reference_fields={"answer": "Union[int,str]", "choices": "List[str]"}, prediction_type="str", metrics=["metrics.accuracy"], ), @@ -20,8 +20,8 @@ add_to_catalog( Task( - inputs={"topic": "str", "question": "str", "choices": "List[str]"}, - outputs={"answer": "Union[int,str]", "choices": "List[str]"}, + input_fields={"topic": "str", "question": "str", "choices": "List[str]"}, + reference_fields={"answer": "Union[int,str]", "choices": "List[str]"}, prediction_type="str", metrics=["metrics.accuracy"], ), @@ -31,8 +31,8 @@ add_to_catalog( Task( - inputs={"question": "str", "choices": "List[str]"}, - outputs={"answer": "Union[int,str]", "choices": "List[str]"}, + input_fields={"question": "str", "choices": "List[str]"}, + reference_fields={"answer": "Union[int,str]", "choices": "List[str]"}, prediction_type="str", metrics=["metrics.accuracy"], ), @@ -42,14 +42,14 @@ add_to_catalog( Task( - inputs={ + input_fields={ "topic": "str", "context": "str", "context_type": "str", "question": "str", "choices": "List[str]", }, - outputs={"answer": "Union[int,str]", "choices": "List[str]"}, + reference_fields={"answer": "Union[int,str]", "choices": "List[str]"}, prediction_type="str", metrics=["metrics.accuracy"], ), diff --git a/prepare/tasks/qa/tasks.py b/prepare/tasks/qa/tasks.py index 69d43a900..e3137ee87 100644 --- a/prepare/tasks/qa/tasks.py +++ b/prepare/tasks/qa/tasks.py @@ -3,8 +3,8 @@ add_to_catalog( Task( - inputs={"context": "str", "context_type": "str", "question": "str"}, - outputs={"answers": "List[str]"}, + input_fields={"context": "str", "context_type": "str", "question": "str"}, + reference_fields={"answers": "List[str]"}, prediction_type="str", metrics=["metrics.squad"], ), @@ -14,8 +14,8 @@ add_to_catalog( Task( - inputs={"context": "str", "context_type": "str", "question": "str"}, - outputs={"answers": "List[str]"}, + input_fields={"context": "str", "context_type": "str", "question": "str"}, + reference_fields={"answers": "List[str]"}, prediction_type="str", metrics=["metrics.rouge"], augmentable_inputs=["context", "question"], @@ -26,8 +26,8 @@ add_to_catalog( Task( - inputs={"question": "str"}, - outputs={"answers": "List[str]"}, + input_fields={"question": "str"}, + reference_fields={"answers": "List[str]"}, prediction_type="str", metrics=["metrics.rouge"], ), diff --git a/prepare/tasks/rag/response_generation.py b/prepare/tasks/rag/response_generation.py index 0a59afdd0..43d43b158 100644 --- a/prepare/tasks/rag/response_generation.py +++ b/prepare/tasks/rag/response_generation.py @@ -5,12 +5,12 @@ add_to_catalog( Task( - inputs={ + input_fields={ "contexts": "List[str]", "contexts_ids": "List[int]", "question": "str", }, - outputs={"reference_answers": "List[str]"}, + reference_fields={"reference_answers": "List[str]"}, metrics=[ "metrics.rag.response_generation.correctness.token_overlap", "metrics.rag.response_generation.faithfullness.token_overlap", diff --git a/prepare/tasks/regression/tasks.py b/prepare/tasks/regression/tasks.py index a73fd3488..4aa23d762 100644 --- a/prepare/tasks/regression/tasks.py +++ b/prepare/tasks/regression/tasks.py @@ -3,13 +3,13 @@ add_to_catalog( Task( - inputs={ + input_fields={ "text": "str", "attribute_name": "str", "min_value": "Optional[float]", "max_value": "Optional[float]", }, - outputs={"attribute_value": "float"}, + reference_fields={"attribute_value": "float"}, prediction_type="Any", metrics=["metrics.spearman"], augmentable_inputs=["text"], @@ -20,14 +20,14 @@ add_to_catalog( Task( - inputs={ + input_fields={ "text1": "str", "text2": "str", "attribute_name": "str", "min_value": "Optional[float]", "max_value": "Optional[float]", }, - outputs={"attribute_value": "float"}, + reference_fields={"attribute_value": "float"}, prediction_type="Any", metrics=["metrics.spearman"], augmentable_inputs=["text1", "text2"], @@ -38,14 +38,14 @@ add_to_catalog( Task( - inputs={ + input_fields={ "text1": "str", "text2": "str", "attribute_name": "str", "min_value": "Optional[float]", "max_value": "Optional[float]", }, - outputs={"attribute_value": "float"}, + reference_fields={"attribute_value": "float"}, prediction_type="Any", metrics=["metrics.spearman"], augmentable_inputs=["text1", "text2"], diff --git a/prepare/tasks/response_assessment/pairwise_comparison/multi_turn.py b/prepare/tasks/response_assessment/pairwise_comparison/multi_turn.py index b800ae87f..02da1eac9 100644 --- a/prepare/tasks/response_assessment/pairwise_comparison/multi_turn.py +++ b/prepare/tasks/response_assessment/pairwise_comparison/multi_turn.py @@ -3,11 +3,11 @@ add_to_catalog( Task( - inputs={ + input_fields={ "dialog_a": "List[Tuple[str, str]]", "dialog_b": "List[Tuple[str, str]]", }, - outputs={ + reference_fields={ "winner": "str" }, # TODO: Support and change to "Literal['choice_a', 'choice_b', 'tie']"}, metrics=["metrics.accuracy"], diff --git a/prepare/tasks/response_assessment/pairwise_comparison/multi_turn_with_reference.py b/prepare/tasks/response_assessment/pairwise_comparison/multi_turn_with_reference.py index c513406f8..b46418bb3 100644 --- a/prepare/tasks/response_assessment/pairwise_comparison/multi_turn_with_reference.py +++ b/prepare/tasks/response_assessment/pairwise_comparison/multi_turn_with_reference.py @@ -3,12 +3,12 @@ add_to_catalog( Task( - inputs={ + input_fields={ "dialog_a": "List[Tuple[str, str]]", "dialog_b": "List[Tuple[str, str]]", "reference_dialog": "List[Tuple[str, str]]", }, - outputs={ + reference_fields={ "winner": "str" }, # TODO: Support and change to "Literal['choice_a', 'choice_b', 'tie']"}, metrics=["metrics.accuracy"], diff --git a/prepare/tasks/response_assessment/pairwise_comparison/single_turn.py b/prepare/tasks/response_assessment/pairwise_comparison/single_turn.py index 4ad66b8ad..30e440de7 100644 --- a/prepare/tasks/response_assessment/pairwise_comparison/single_turn.py +++ b/prepare/tasks/response_assessment/pairwise_comparison/single_turn.py @@ -3,12 +3,12 @@ add_to_catalog( Task( - inputs={ + input_fields={ "question": "str", "answer_a": "str", "answer_b": "str", }, - outputs={ + reference_fields={ "winner": "str" }, # TODO: Support and change to "Literal['choice_a', 'choice_b', 'tie']" metrics=["metrics.accuracy"], diff --git a/prepare/tasks/response_assessment/pairwise_comparison/single_turn_with_reference.py b/prepare/tasks/response_assessment/pairwise_comparison/single_turn_with_reference.py index e187c0b47..2e0948df8 100644 --- a/prepare/tasks/response_assessment/pairwise_comparison/single_turn_with_reference.py +++ b/prepare/tasks/response_assessment/pairwise_comparison/single_turn_with_reference.py @@ -3,13 +3,13 @@ add_to_catalog( Task( - inputs={ + input_fields={ "question": "str", "answer_a": "str", "answer_b": "str", "reference_answer": "str", }, - outputs={ + reference_fields={ "winner": "str" }, # TODO: Support and change to "Literal['choice_a', 'choice_b', 'tie']"}, metrics=["metrics.accuracy"], diff --git a/prepare/tasks/response_assessment/rating/multi_turn.py b/prepare/tasks/response_assessment/rating/multi_turn.py index 0b902d6f3..4c98a89b9 100644 --- a/prepare/tasks/response_assessment/rating/multi_turn.py +++ b/prepare/tasks/response_assessment/rating/multi_turn.py @@ -3,8 +3,8 @@ add_to_catalog( Task( - inputs={"dialog": "List[Tuple[str, str]]"}, - outputs={"rating": "float"}, + input_fields={"dialog": "List[Tuple[str, str]]"}, + reference_fields={"rating": "float"}, metrics=["metrics.spearman"], ), "tasks.response_assessment.rating.multi_turn", diff --git a/prepare/tasks/response_assessment/rating/multi_turn_with_reference.py b/prepare/tasks/response_assessment/rating/multi_turn_with_reference.py index 5af1651b4..08c2ef2d5 100644 --- a/prepare/tasks/response_assessment/rating/multi_turn_with_reference.py +++ b/prepare/tasks/response_assessment/rating/multi_turn_with_reference.py @@ -3,11 +3,11 @@ add_to_catalog( Task( - inputs={ + input_fields={ "dialog": "List[Tuple[str, str]]", "reference_dialog": "List[Tuple[str, str]]", }, - outputs={"rating": "float"}, + reference_fields={"rating": "float"}, metrics=["metrics.spearman"], ), "tasks.response_assessment.rating.multi_turn_with_reference", diff --git a/prepare/tasks/response_assessment/rating/single_turn.py b/prepare/tasks/response_assessment/rating/single_turn.py index f0cbbfc2e..405262aa6 100644 --- a/prepare/tasks/response_assessment/rating/single_turn.py +++ b/prepare/tasks/response_assessment/rating/single_turn.py @@ -3,8 +3,8 @@ add_to_catalog( Task( - inputs={"question": "str", "answer": "str"}, - outputs={"rating": "float"}, + input_fields={"question": "str", "answer": "str"}, + reference_fields={"rating": "float"}, metrics=["metrics.spearman"], ), "tasks.response_assessment.rating.single_turn", diff --git a/prepare/tasks/response_assessment/rating/single_turn_with_reference.py b/prepare/tasks/response_assessment/rating/single_turn_with_reference.py index 6282b4bfb..c93a4114d 100644 --- a/prepare/tasks/response_assessment/rating/single_turn_with_reference.py +++ b/prepare/tasks/response_assessment/rating/single_turn_with_reference.py @@ -3,8 +3,8 @@ add_to_catalog( Task( - inputs={"question": "str", "answer": "str", "reference_answer": "str"}, - outputs={"rating": "float"}, + input_fields={"question": "str", "answer": "str", "reference_answer": "str"}, + reference_fields={"rating": "float"}, metrics=["metrics.spearman"], ), "tasks.response_assessment.rating.single_turn_with_reference", diff --git a/prepare/tasks/rewriting.py b/prepare/tasks/rewriting.py index e1779fd34..3fa3bdff3 100644 --- a/prepare/tasks/rewriting.py +++ b/prepare/tasks/rewriting.py @@ -3,13 +3,13 @@ add_to_catalog( Task( - inputs=[ + input_fields=[ "input_text", "input_text_type", "required_attribute", "output_text_type", ], - outputs=["output_text"], + reference_fields=["output_text"], metrics=[ "metrics.rouge", ], @@ -21,8 +21,8 @@ add_to_catalog( Task( - inputs=["input_text", "text_type"], - outputs=["output_text"], + input_fields=["input_text", "text_type"], + reference_fields=["output_text"], metrics=[ "metrics.rouge", ], diff --git a/prepare/tasks/selection.py b/prepare/tasks/selection.py index 848faa7f6..8d2e471a3 100644 --- a/prepare/tasks/selection.py +++ b/prepare/tasks/selection.py @@ -3,13 +3,13 @@ add_to_catalog( Task( - inputs=[ + input_fields=[ "required_attribute", "attribute_type", "choices_texts", "choices_text_type", ], - outputs=["choices_texts", "choice"], + reference_fields=["choices_texts", "choice"], metrics=[ "metrics.accuracy", ], diff --git a/prepare/tasks/span_labeling.py b/prepare/tasks/span_labeling.py index 93248bf09..9acaa1d35 100644 --- a/prepare/tasks/span_labeling.py +++ b/prepare/tasks/span_labeling.py @@ -3,13 +3,13 @@ add_to_catalog( Task( - inputs={ + input_fields={ "text": "str", "text_type": "str", "class_type": "str", "classes": "List[str]", }, - outputs={ + reference_fields={ "text": "str", "spans_starts": "List[int]", "spans_ends": "List[int]", diff --git a/prepare/tasks/summarization/abstractive.py b/prepare/tasks/summarization/abstractive.py index e7b722d19..b9581a2a1 100644 --- a/prepare/tasks/summarization/abstractive.py +++ b/prepare/tasks/summarization/abstractive.py @@ -3,8 +3,8 @@ add_to_catalog( Task( - inputs={"document": "str", "document_type": "str"}, - outputs={"summary": "str"}, + input_fields={"document": "str", "document_type": "str"}, + reference_fields={"summary": "str"}, prediction_type="str", metrics=["metrics.rouge"], defaults={"document_type": "document"}, diff --git a/prepare/tasks/targeted_sentiment_extraction.py b/prepare/tasks/targeted_sentiment_extraction.py index e77b66792..785f8a2c8 100644 --- a/prepare/tasks/targeted_sentiment_extraction.py +++ b/prepare/tasks/targeted_sentiment_extraction.py @@ -3,8 +3,8 @@ add_to_catalog( Task( - inputs={"text": "str", "text_type": "str", "sentiment_class": "str"}, - outputs={ + input_fields={"text": "str", "text_type": "str", "sentiment_class": "str"}, + reference_fields={ "spans_starts": "List[int]", "spans_ends": "List[int]", "text": "List[str]", @@ -21,8 +21,8 @@ add_to_catalog( Task( - inputs={"text": "str", "text_type": "str"}, - outputs={ + input_fields={"text": "str", "text_type": "str"}, + reference_fields={ "spans_starts": "List[int]", "spans_ends": "List[int]", "text": "List[str]", diff --git a/prepare/tasks/translation/directed.py b/prepare/tasks/translation/directed.py index 411a1ded3..f9620cd17 100644 --- a/prepare/tasks/translation/directed.py +++ b/prepare/tasks/translation/directed.py @@ -3,8 +3,12 @@ add_to_catalog( Task( - inputs={"text": "str", "source_language": "str", "target_language": "str"}, - outputs={"translation": "str"}, + input_fields={ + "text": "str", + "source_language": "str", + "target_language": "str", + }, + reference_fields={"translation": "str"}, prediction_type="str", metrics=["metrics.normalized_sacrebleu"], ), diff --git a/src/unitxt/catalog/cards/atta_q.json b/src/unitxt/catalog/cards/atta_q.json index ebf655624..2776b9f9d 100644 --- a/src/unitxt/catalog/cards/atta_q.json +++ b/src/unitxt/catalog/cards/atta_q.json @@ -35,10 +35,10 @@ ], "task": { "__type__": "task", - "inputs": [ + "input_fields": [ "input" ], - "outputs": [ + "reference_fields": [ "input_label" ], "metrics": [ diff --git a/src/unitxt/catalog/cards/attaq_500.json b/src/unitxt/catalog/cards/attaq_500.json index 54cae8293..d001673a5 100644 --- a/src/unitxt/catalog/cards/attaq_500.json +++ b/src/unitxt/catalog/cards/attaq_500.json @@ -543,10 +543,10 @@ ], "task": { "__type__": "task", - "inputs": [ + "input_fields": [ "input" ], - "outputs": [ + "reference_fields": [ "input_label" ], "metrics": [ diff --git a/src/unitxt/catalog/cards/bold.json b/src/unitxt/catalog/cards/bold.json index a88be5161..4257c7113 100644 --- a/src/unitxt/catalog/cards/bold.json +++ b/src/unitxt/catalog/cards/bold.json @@ -56,10 +56,10 @@ ], "task": { "__type__": "task", - "inputs": [ + "input_fields": [ "first_prompt" ], - "outputs": [ + "reference_fields": [ "input_label" ], "metrics": [ diff --git a/src/unitxt/catalog/cards/human_eval.json b/src/unitxt/catalog/cards/human_eval.json index 9d9a433b0..2f108de03 100644 --- a/src/unitxt/catalog/cards/human_eval.json +++ b/src/unitxt/catalog/cards/human_eval.json @@ -17,10 +17,10 @@ ], "task": { "__type__": "task", - "inputs": [ + "input_fields": [ "prompt" ], - "outputs": [ + "reference_fields": [ "prompt", "canonical_solution", "test_list" diff --git a/src/unitxt/catalog/cards/mbpp.json b/src/unitxt/catalog/cards/mbpp.json index e56d3f7e7..b5b58bc8e 100644 --- a/src/unitxt/catalog/cards/mbpp.json +++ b/src/unitxt/catalog/cards/mbpp.json @@ -17,11 +17,11 @@ ], "task": { "__type__": "task", - "inputs": [ + "input_fields": [ "text", "test_list_str" ], - "outputs": [ + "reference_fields": [ "test_list", "code" ], diff --git a/src/unitxt/catalog/cards/mrpc.json b/src/unitxt/catalog/cards/mrpc.json index cb389f7a4..acf383043 100644 --- a/src/unitxt/catalog/cards/mrpc.json +++ b/src/unitxt/catalog/cards/mrpc.json @@ -29,12 +29,12 @@ ], "task": { "__type__": "task", - "inputs": [ + "input_fields": [ "choices", "sentence1", "sentence2" ], - "outputs": [ + "reference_fields": [ "label" ], "metrics": [ diff --git a/src/unitxt/catalog/cards/pop_qa.json b/src/unitxt/catalog/cards/pop_qa.json index 43c3b5a92..d1d77e8af 100644 --- a/src/unitxt/catalog/cards/pop_qa.json +++ b/src/unitxt/catalog/cards/pop_qa.json @@ -16,12 +16,12 @@ ], "task": { "__type__": "task", - "inputs": [ + "input_fields": [ "question", "prop", "subj" ], - "outputs": [ + "reference_fields": [ "possible_answers" ], "metrics": [ diff --git a/src/unitxt/catalog/cards/qqp.json b/src/unitxt/catalog/cards/qqp.json index 46d6355fb..a044ea7c8 100644 --- a/src/unitxt/catalog/cards/qqp.json +++ b/src/unitxt/catalog/cards/qqp.json @@ -28,12 +28,12 @@ ], "task": { "__type__": "task", - "inputs": [ + "input_fields": [ "choices", "question1", "question2" ], - "outputs": [ + "reference_fields": [ "label" ], "metrics": [ diff --git a/src/unitxt/catalog/cards/wsc.json b/src/unitxt/catalog/cards/wsc.json index 01497bf79..97f09f91f 100644 --- a/src/unitxt/catalog/cards/wsc.json +++ b/src/unitxt/catalog/cards/wsc.json @@ -28,13 +28,13 @@ ], "task": { "__type__": "task", - "inputs": [ + "input_fields": [ "choices", "text", "span1_text", "span2_text" ], - "outputs": [ + "reference_fields": [ "label" ], "metrics": [ diff --git a/src/unitxt/catalog/operators/balancers/classification/by_label.json b/src/unitxt/catalog/operators/balancers/classification/by_label.json index 1a5693911..faa6c3f2f 100644 --- a/src/unitxt/catalog/operators/balancers/classification/by_label.json +++ b/src/unitxt/catalog/operators/balancers/classification/by_label.json @@ -1,6 +1,6 @@ { "__type__": "deterministic_balancer", "fields": [ - "outputs/label" + "reference_fields/label" ] } diff --git a/src/unitxt/catalog/operators/balancers/classification/minimum_one_example_per_class.json b/src/unitxt/catalog/operators/balancers/classification/minimum_one_example_per_class.json index 7ee1270fe..2e0832c7e 100644 --- a/src/unitxt/catalog/operators/balancers/classification/minimum_one_example_per_class.json +++ b/src/unitxt/catalog/operators/balancers/classification/minimum_one_example_per_class.json @@ -1,6 +1,6 @@ { "__type__": "minimum_one_example_per_label_refiner", "fields": [ - "outputs/label" + "reference_fields/label" ] } diff --git a/src/unitxt/catalog/operators/balancers/multi_label/zero_vs_many_labels.json b/src/unitxt/catalog/operators/balancers/multi_label/zero_vs_many_labels.json index 444e22495..fb247546d 100644 --- a/src/unitxt/catalog/operators/balancers/multi_label/zero_vs_many_labels.json +++ b/src/unitxt/catalog/operators/balancers/multi_label/zero_vs_many_labels.json @@ -1,7 +1,7 @@ { "__type__": "length_balancer", "fields": [ - "outputs/labels" + "reference_fields/labels" ], "segments_boundaries": [ 1 diff --git a/src/unitxt/catalog/operators/balancers/ner/zero_vs_many_entities.json b/src/unitxt/catalog/operators/balancers/ner/zero_vs_many_entities.json index 444e22495..fb247546d 100644 --- a/src/unitxt/catalog/operators/balancers/ner/zero_vs_many_entities.json +++ b/src/unitxt/catalog/operators/balancers/ner/zero_vs_many_entities.json @@ -1,7 +1,7 @@ { "__type__": "length_balancer", "fields": [ - "outputs/labels" + "reference_fields/labels" ], "segments_boundaries": [ 1 diff --git a/src/unitxt/catalog/operators/balancers/qa/by_answer.json b/src/unitxt/catalog/operators/balancers/qa/by_answer.json index 543693336..e06bba30d 100644 --- a/src/unitxt/catalog/operators/balancers/qa/by_answer.json +++ b/src/unitxt/catalog/operators/balancers/qa/by_answer.json @@ -1,6 +1,6 @@ { "__type__": "deterministic_balancer", "fields": [ - "outputs/answer" + "reference_fields/answer" ] } diff --git a/src/unitxt/catalog/tasks/classification/binary.json b/src/unitxt/catalog/tasks/classification/binary.json index c69bd87a5..51dd36cf7 100644 --- a/src/unitxt/catalog/tasks/classification/binary.json +++ b/src/unitxt/catalog/tasks/classification/binary.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str", "text_type": "str", "class": "str" }, - "outputs": { + "reference_fields": { "class": "str", "label": "List[str]" }, diff --git a/src/unitxt/catalog/tasks/classification/binary/zero_or_one.json b/src/unitxt/catalog/tasks/classification/binary/zero_or_one.json index 21cde64db..010022aff 100644 --- a/src/unitxt/catalog/tasks/classification/binary/zero_or_one.json +++ b/src/unitxt/catalog/tasks/classification/binary/zero_or_one.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str", "text_type": "str", "class": "str" }, - "outputs": { + "reference_fields": { "class": "str", "label": "int" }, diff --git a/src/unitxt/catalog/tasks/classification/multi_class.json b/src/unitxt/catalog/tasks/classification/multi_class.json index 02c5f82e4..d8651948d 100644 --- a/src/unitxt/catalog/tasks/classification/multi_class.json +++ b/src/unitxt/catalog/tasks/classification/multi_class.json @@ -1,12 +1,12 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str", "text_type": "str", "classes": "List[str]", "type_of_class": "str" }, - "outputs": { + "reference_fields": { "label": "str" }, "prediction_type": "str", diff --git a/src/unitxt/catalog/tasks/classification/multi_class/relation.json b/src/unitxt/catalog/tasks/classification/multi_class/relation.json index 115b22bee..24e9ffe3c 100644 --- a/src/unitxt/catalog/tasks/classification/multi_class/relation.json +++ b/src/unitxt/catalog/tasks/classification/multi_class/relation.json @@ -1,6 +1,6 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text_a": "str", "text_a_type": "str", "text_b": "str", @@ -8,7 +8,7 @@ "classes": "List[str]", "type_of_relation": "str" }, - "outputs": { + "reference_fields": { "label": "str" }, "prediction_type": "str", diff --git a/src/unitxt/catalog/tasks/classification/multi_class/topic_classification.json b/src/unitxt/catalog/tasks/classification/multi_class/topic_classification.json index abf09574b..abe7c1d1c 100644 --- a/src/unitxt/catalog/tasks/classification/multi_class/topic_classification.json +++ b/src/unitxt/catalog/tasks/classification/multi_class/topic_classification.json @@ -1,12 +1,12 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str", "text_type": "str", "classes": "List[str]", "type_of_class": "str" }, - "outputs": { + "reference_fields": { "label": "str" }, "prediction_type": "str", diff --git a/src/unitxt/catalog/tasks/classification/multi_class/with_classes_descriptions.json b/src/unitxt/catalog/tasks/classification/multi_class/with_classes_descriptions.json index 714a09286..ec6566884 100644 --- a/src/unitxt/catalog/tasks/classification/multi_class/with_classes_descriptions.json +++ b/src/unitxt/catalog/tasks/classification/multi_class/with_classes_descriptions.json @@ -1,13 +1,13 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str", "text_type": "str", "classes": "List[str]", "type_of_class": "str", "classes_descriptions": "str" }, - "outputs": { + "reference_fields": { "label": "str" }, "prediction_type": "str", diff --git a/src/unitxt/catalog/tasks/classification/multi_label.json b/src/unitxt/catalog/tasks/classification/multi_label.json index cc238c637..3fd11b4e8 100644 --- a/src/unitxt/catalog/tasks/classification/multi_label.json +++ b/src/unitxt/catalog/tasks/classification/multi_label.json @@ -1,12 +1,12 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str", "text_type": "str", "classes": "List[str]", "type_of_classes": "str" }, - "outputs": { + "reference_fields": { "labels": "List[str]" }, "prediction_type": "List[str]", diff --git a/src/unitxt/catalog/tasks/completion/abstractive.json b/src/unitxt/catalog/tasks/completion/abstractive.json index db670c45a..0d4c7ea46 100644 --- a/src/unitxt/catalog/tasks/completion/abstractive.json +++ b/src/unitxt/catalog/tasks/completion/abstractive.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "context": "str", "context_type": "str", "completion_type": "str" }, - "outputs": { + "reference_fields": { "completion": "str" }, "prediction_type": "str", diff --git a/src/unitxt/catalog/tasks/completion/extractive.json b/src/unitxt/catalog/tasks/completion/extractive.json index c0022dc26..69ba70e17 100644 --- a/src/unitxt/catalog/tasks/completion/extractive.json +++ b/src/unitxt/catalog/tasks/completion/extractive.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "context": "str", "context_type": "str", "completion_type": "str" }, - "outputs": { + "reference_fields": { "completion": "str" }, "prediction_type": "Dict[str,Any]", diff --git a/src/unitxt/catalog/tasks/completion/multiple_choice.json b/src/unitxt/catalog/tasks/completion/multiple_choice.json index e08c075fd..c11fc34c0 100644 --- a/src/unitxt/catalog/tasks/completion/multiple_choice.json +++ b/src/unitxt/catalog/tasks/completion/multiple_choice.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "context": "str", "context_type": "str", "choices": "List[str]" }, - "outputs": { + "reference_fields": { "answer": "int", "choices": "List[str]" }, diff --git a/src/unitxt/catalog/tasks/evaluation/preference.json b/src/unitxt/catalog/tasks/evaluation/preference.json index 6f8a6200b..d6488a2fa 100644 --- a/src/unitxt/catalog/tasks/evaluation/preference.json +++ b/src/unitxt/catalog/tasks/evaluation/preference.json @@ -1,13 +1,13 @@ { "__type__": "task", - "inputs": [ + "input_fields": [ "input", "input_type", "output_type", "choices", "instruction" ], - "outputs": [ + "reference_fields": [ "choices", "output_choice" ], diff --git a/src/unitxt/catalog/tasks/generation.json b/src/unitxt/catalog/tasks/generation.json index 94c6247a0..149df7c37 100644 --- a/src/unitxt/catalog/tasks/generation.json +++ b/src/unitxt/catalog/tasks/generation.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "input": "str", "type_of_input": "str", "type_of_output": "str" }, - "outputs": { + "reference_fields": { "output": "str" }, "prediction_type": "str", diff --git a/src/unitxt/catalog/tasks/grammatical_error_correction.json b/src/unitxt/catalog/tasks/grammatical_error_correction.json index 80935cb37..c4e3126d5 100644 --- a/src/unitxt/catalog/tasks/grammatical_error_correction.json +++ b/src/unitxt/catalog/tasks/grammatical_error_correction.json @@ -1,9 +1,9 @@ { "__type__": "task", - "inputs": [ + "input_fields": [ "original_text" ], - "outputs": [ + "reference_fields": [ "corrected_texts" ], "metrics": [ diff --git a/src/unitxt/catalog/tasks/language_identification.json b/src/unitxt/catalog/tasks/language_identification.json index db875f5d1..9d8f277aa 100644 --- a/src/unitxt/catalog/tasks/language_identification.json +++ b/src/unitxt/catalog/tasks/language_identification.json @@ -1,9 +1,9 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str" }, - "outputs": { + "reference_fields": { "label": "str" }, "prediction_type": "str", diff --git a/src/unitxt/catalog/tasks/ner/all_entity_types.json b/src/unitxt/catalog/tasks/ner/all_entity_types.json index 23029bca4..942bbd9ce 100644 --- a/src/unitxt/catalog/tasks/ner/all_entity_types.json +++ b/src/unitxt/catalog/tasks/ner/all_entity_types.json @@ -1,10 +1,10 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str", "entity_types": "List[str]" }, - "outputs": { + "reference_fields": { "spans_starts": "List[int]", "spans_ends": "List[int]", "text": "str", diff --git a/src/unitxt/catalog/tasks/ner/single_entity_type.json b/src/unitxt/catalog/tasks/ner/single_entity_type.json index a8bb62c53..72a509ff6 100644 --- a/src/unitxt/catalog/tasks/ner/single_entity_type.json +++ b/src/unitxt/catalog/tasks/ner/single_entity_type.json @@ -1,10 +1,10 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str", "entity_type": "str" }, - "outputs": { + "reference_fields": { "spans_starts": "List[int]", "spans_ends": "List[int]", "text": "str", diff --git a/src/unitxt/catalog/tasks/qa/multiple_choice/open.json b/src/unitxt/catalog/tasks/qa/multiple_choice/open.json index f011d1648..53c15f40f 100644 --- a/src/unitxt/catalog/tasks/qa/multiple_choice/open.json +++ b/src/unitxt/catalog/tasks/qa/multiple_choice/open.json @@ -1,10 +1,10 @@ { "__type__": "task", - "inputs": { + "input_fields": { "question": "str", "choices": "List[str]" }, - "outputs": { + "reference_fields": { "answer": "Union[int,str]", "choices": "List[str]" }, diff --git a/src/unitxt/catalog/tasks/qa/multiple_choice/with_context.json b/src/unitxt/catalog/tasks/qa/multiple_choice/with_context.json index ccc1abde4..6bfc2541d 100644 --- a/src/unitxt/catalog/tasks/qa/multiple_choice/with_context.json +++ b/src/unitxt/catalog/tasks/qa/multiple_choice/with_context.json @@ -1,12 +1,12 @@ { "__type__": "task", - "inputs": { + "input_fields": { "context": "str", "context_type": "str", "question": "str", "choices": "List[str]" }, - "outputs": { + "reference_fields": { "answer": "Union[int,str]", "choices": "List[str]" }, diff --git a/src/unitxt/catalog/tasks/qa/multiple_choice/with_context/with_topic.json b/src/unitxt/catalog/tasks/qa/multiple_choice/with_context/with_topic.json index 7e2acc751..bba0daef3 100644 --- a/src/unitxt/catalog/tasks/qa/multiple_choice/with_context/with_topic.json +++ b/src/unitxt/catalog/tasks/qa/multiple_choice/with_context/with_topic.json @@ -1,13 +1,13 @@ { "__type__": "task", - "inputs": { + "input_fields": { "topic": "str", "context": "str", "context_type": "str", "question": "str", "choices": "List[str]" }, - "outputs": { + "reference_fields": { "answer": "Union[int,str]", "choices": "List[str]" }, diff --git a/src/unitxt/catalog/tasks/qa/multiple_choice/with_topic.json b/src/unitxt/catalog/tasks/qa/multiple_choice/with_topic.json index e260fa63d..6a7d9b104 100644 --- a/src/unitxt/catalog/tasks/qa/multiple_choice/with_topic.json +++ b/src/unitxt/catalog/tasks/qa/multiple_choice/with_topic.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "topic": "str", "question": "str", "choices": "List[str]" }, - "outputs": { + "reference_fields": { "answer": "Union[int,str]", "choices": "List[str]" }, diff --git a/src/unitxt/catalog/tasks/qa/open.json b/src/unitxt/catalog/tasks/qa/open.json index cc1f21586..bd84344f6 100644 --- a/src/unitxt/catalog/tasks/qa/open.json +++ b/src/unitxt/catalog/tasks/qa/open.json @@ -1,9 +1,9 @@ { "__type__": "task", - "inputs": { + "input_fields": { "question": "str" }, - "outputs": { + "reference_fields": { "answers": "List[str]" }, "prediction_type": "str", diff --git a/src/unitxt/catalog/tasks/qa/with_context/abstractive.json b/src/unitxt/catalog/tasks/qa/with_context/abstractive.json index fbb159037..487525d97 100644 --- a/src/unitxt/catalog/tasks/qa/with_context/abstractive.json +++ b/src/unitxt/catalog/tasks/qa/with_context/abstractive.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "context": "str", "context_type": "str", "question": "str" }, - "outputs": { + "reference_fields": { "answers": "List[str]" }, "prediction_type": "str", diff --git a/src/unitxt/catalog/tasks/qa/with_context/extractive.json b/src/unitxt/catalog/tasks/qa/with_context/extractive.json index c84771061..bb42c969b 100644 --- a/src/unitxt/catalog/tasks/qa/with_context/extractive.json +++ b/src/unitxt/catalog/tasks/qa/with_context/extractive.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "context": "str", "context_type": "str", "question": "str" }, - "outputs": { + "reference_fields": { "answers": "List[str]" }, "prediction_type": "str", diff --git a/src/unitxt/catalog/tasks/rag/response_generation.json b/src/unitxt/catalog/tasks/rag/response_generation.json index 1bcd98e01..2a2fefee4 100644 --- a/src/unitxt/catalog/tasks/rag/response_generation.json +++ b/src/unitxt/catalog/tasks/rag/response_generation.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "contexts": "List[str]", "contexts_ids": "List[int]", "question": "str" }, - "outputs": { + "reference_fields": { "reference_answers": "List[str]" }, "metrics": [ diff --git a/src/unitxt/catalog/tasks/regression/single_text.json b/src/unitxt/catalog/tasks/regression/single_text.json index 70531788a..126e634e3 100644 --- a/src/unitxt/catalog/tasks/regression/single_text.json +++ b/src/unitxt/catalog/tasks/regression/single_text.json @@ -1,12 +1,12 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str", "attribute_name": "str", "min_value": "Optional[float]", "max_value": "Optional[float]" }, - "outputs": { + "reference_fields": { "attribute_value": "float" }, "prediction_type": "Any", diff --git a/src/unitxt/catalog/tasks/regression/two_texts.json b/src/unitxt/catalog/tasks/regression/two_texts.json index edeb53c41..1defeb102 100644 --- a/src/unitxt/catalog/tasks/regression/two_texts.json +++ b/src/unitxt/catalog/tasks/regression/two_texts.json @@ -1,13 +1,13 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text1": "str", "text2": "str", "attribute_name": "str", "min_value": "Optional[float]", "max_value": "Optional[float]" }, - "outputs": { + "reference_fields": { "attribute_value": "float" }, "prediction_type": "Any", diff --git a/src/unitxt/catalog/tasks/regression/two_texts/similarity.json b/src/unitxt/catalog/tasks/regression/two_texts/similarity.json index ba17bf6ea..5a384b15b 100644 --- a/src/unitxt/catalog/tasks/regression/two_texts/similarity.json +++ b/src/unitxt/catalog/tasks/regression/two_texts/similarity.json @@ -1,13 +1,13 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text1": "str", "text2": "str", "attribute_name": "str", "min_value": "Optional[float]", "max_value": "Optional[float]" }, - "outputs": { + "reference_fields": { "attribute_value": "float" }, "prediction_type": "Any", diff --git a/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/multi_turn.json b/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/multi_turn.json index 72cab8d42..a5d20dc10 100644 --- a/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/multi_turn.json +++ b/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/multi_turn.json @@ -1,10 +1,10 @@ { "__type__": "task", - "inputs": { + "input_fields": { "dialog_a": "List[Tuple[str, str]]", "dialog_b": "List[Tuple[str, str]]" }, - "outputs": { + "reference_fields": { "winner": "str" }, "metrics": [ diff --git a/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/multi_turn_with_reference.json b/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/multi_turn_with_reference.json index 3e3d9a0b8..6f59bdeea 100644 --- a/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/multi_turn_with_reference.json +++ b/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/multi_turn_with_reference.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "dialog_a": "List[Tuple[str, str]]", "dialog_b": "List[Tuple[str, str]]", "reference_dialog": "List[Tuple[str, str]]" }, - "outputs": { + "reference_fields": { "winner": "str" }, "metrics": [ diff --git a/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/single_turn.json b/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/single_turn.json index 1b1b6e536..ea2573d16 100644 --- a/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/single_turn.json +++ b/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/single_turn.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "question": "str", "answer_a": "str", "answer_b": "str" }, - "outputs": { + "reference_fields": { "winner": "str" }, "metrics": [ diff --git a/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/single_turn_with_reference.json b/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/single_turn_with_reference.json index 9cea8d25f..ca8f04df9 100644 --- a/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/single_turn_with_reference.json +++ b/src/unitxt/catalog/tasks/response_assessment/pairwise_comparison/single_turn_with_reference.json @@ -1,12 +1,12 @@ { "__type__": "task", - "inputs": { + "input_fields": { "question": "str", "answer_a": "str", "answer_b": "str", "reference_answer": "str" }, - "outputs": { + "reference_fields": { "winner": "str" }, "metrics": [ diff --git a/src/unitxt/catalog/tasks/response_assessment/rating/multi_turn.json b/src/unitxt/catalog/tasks/response_assessment/rating/multi_turn.json index aa7ac5200..4da763cb2 100644 --- a/src/unitxt/catalog/tasks/response_assessment/rating/multi_turn.json +++ b/src/unitxt/catalog/tasks/response_assessment/rating/multi_turn.json @@ -1,9 +1,9 @@ { "__type__": "task", - "inputs": { + "input_fields": { "dialog": "List[Tuple[str, str]]" }, - "outputs": { + "reference_fields": { "rating": "float" }, "metrics": [ diff --git a/src/unitxt/catalog/tasks/response_assessment/rating/multi_turn_with_reference.json b/src/unitxt/catalog/tasks/response_assessment/rating/multi_turn_with_reference.json index a7c8bfdff..082cb4414 100644 --- a/src/unitxt/catalog/tasks/response_assessment/rating/multi_turn_with_reference.json +++ b/src/unitxt/catalog/tasks/response_assessment/rating/multi_turn_with_reference.json @@ -1,10 +1,10 @@ { "__type__": "task", - "inputs": { + "input_fields": { "dialog": "List[Tuple[str, str]]", "reference_dialog": "List[Tuple[str, str]]" }, - "outputs": { + "reference_fields": { "rating": "float" }, "metrics": [ diff --git a/src/unitxt/catalog/tasks/response_assessment/rating/single_turn.json b/src/unitxt/catalog/tasks/response_assessment/rating/single_turn.json index 465a7d87c..4c496eeb5 100644 --- a/src/unitxt/catalog/tasks/response_assessment/rating/single_turn.json +++ b/src/unitxt/catalog/tasks/response_assessment/rating/single_turn.json @@ -1,10 +1,10 @@ { "__type__": "task", - "inputs": { + "input_fields": { "question": "str", "answer": "str" }, - "outputs": { + "reference_fields": { "rating": "float" }, "metrics": [ diff --git a/src/unitxt/catalog/tasks/response_assessment/rating/single_turn_with_reference.json b/src/unitxt/catalog/tasks/response_assessment/rating/single_turn_with_reference.json index 57f5d9a59..85d12c4be 100644 --- a/src/unitxt/catalog/tasks/response_assessment/rating/single_turn_with_reference.json +++ b/src/unitxt/catalog/tasks/response_assessment/rating/single_turn_with_reference.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "question": "str", "answer": "str", "reference_answer": "str" }, - "outputs": { + "reference_fields": { "rating": "float" }, "metrics": [ diff --git a/src/unitxt/catalog/tasks/rewriting/by_attribute.json b/src/unitxt/catalog/tasks/rewriting/by_attribute.json index 9bed596ba..f0b568da6 100644 --- a/src/unitxt/catalog/tasks/rewriting/by_attribute.json +++ b/src/unitxt/catalog/tasks/rewriting/by_attribute.json @@ -1,12 +1,12 @@ { "__type__": "task", - "inputs": [ + "input_fields": [ "input_text", "input_text_type", "required_attribute", "output_text_type" ], - "outputs": [ + "reference_fields": [ "output_text" ], "metrics": [ diff --git a/src/unitxt/catalog/tasks/rewriting/paraphrase.json b/src/unitxt/catalog/tasks/rewriting/paraphrase.json index 13a319954..94fb99c8f 100644 --- a/src/unitxt/catalog/tasks/rewriting/paraphrase.json +++ b/src/unitxt/catalog/tasks/rewriting/paraphrase.json @@ -1,10 +1,10 @@ { "__type__": "task", - "inputs": [ + "input_fields": [ "input_text", "text_type" ], - "outputs": [ + "reference_fields": [ "output_text" ], "metrics": [ diff --git a/src/unitxt/catalog/tasks/selection/by_attribute.json b/src/unitxt/catalog/tasks/selection/by_attribute.json index 0034a8362..5e155cf95 100644 --- a/src/unitxt/catalog/tasks/selection/by_attribute.json +++ b/src/unitxt/catalog/tasks/selection/by_attribute.json @@ -1,12 +1,12 @@ { "__type__": "task", - "inputs": [ + "input_fields": [ "required_attribute", "attribute_type", "choices_texts", "choices_text_type" ], - "outputs": [ + "reference_fields": [ "choices_texts", "choice" ], diff --git a/src/unitxt/catalog/tasks/span_labeling/extraction.json b/src/unitxt/catalog/tasks/span_labeling/extraction.json index 44ad1bb21..e98cfc5ee 100644 --- a/src/unitxt/catalog/tasks/span_labeling/extraction.json +++ b/src/unitxt/catalog/tasks/span_labeling/extraction.json @@ -1,12 +1,12 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str", "text_type": "str", "class_type": "str", "classes": "List[str]" }, - "outputs": { + "reference_fields": { "text": "str", "spans_starts": "List[int]", "spans_ends": "List[int]", diff --git a/src/unitxt/catalog/tasks/summarization/abstractive.json b/src/unitxt/catalog/tasks/summarization/abstractive.json index e14cd0e06..832591735 100644 --- a/src/unitxt/catalog/tasks/summarization/abstractive.json +++ b/src/unitxt/catalog/tasks/summarization/abstractive.json @@ -1,10 +1,10 @@ { "__type__": "task", - "inputs": { + "input_fields": { "document": "str", "document_type": "str" }, - "outputs": { + "reference_fields": { "summary": "str" }, "prediction_type": "str", diff --git a/src/unitxt/catalog/tasks/targeted_sentiment_extraction/all_sentiment_classes.json b/src/unitxt/catalog/tasks/targeted_sentiment_extraction/all_sentiment_classes.json index 964286609..49556d6c5 100644 --- a/src/unitxt/catalog/tasks/targeted_sentiment_extraction/all_sentiment_classes.json +++ b/src/unitxt/catalog/tasks/targeted_sentiment_extraction/all_sentiment_classes.json @@ -1,10 +1,10 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str", "text_type": "str" }, - "outputs": { + "reference_fields": { "spans_starts": "List[int]", "spans_ends": "List[int]", "text": "List[str]", diff --git a/src/unitxt/catalog/tasks/targeted_sentiment_extraction/single_sentiment_class.json b/src/unitxt/catalog/tasks/targeted_sentiment_extraction/single_sentiment_class.json index b117f3ba6..58af81082 100644 --- a/src/unitxt/catalog/tasks/targeted_sentiment_extraction/single_sentiment_class.json +++ b/src/unitxt/catalog/tasks/targeted_sentiment_extraction/single_sentiment_class.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str", "text_type": "str", "sentiment_class": "str" }, - "outputs": { + "reference_fields": { "spans_starts": "List[int]", "spans_ends": "List[int]", "text": "List[str]", diff --git a/src/unitxt/catalog/tasks/translation/directed.json b/src/unitxt/catalog/tasks/translation/directed.json index 8f4f967c1..11c803692 100644 --- a/src/unitxt/catalog/tasks/translation/directed.json +++ b/src/unitxt/catalog/tasks/translation/directed.json @@ -1,11 +1,11 @@ { "__type__": "task", - "inputs": { + "input_fields": { "text": "str", "source_language": "str", "target_language": "str" }, - "outputs": { + "reference_fields": { "translation": "str" }, "prediction_type": "str", diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index bb147f028..7f996091f 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -552,7 +552,7 @@ def prepare(self): def set_task_input_fields(self, task_input_fields: List[str]): self._task_input_fields = [ - "inputs/" + task_input_field for task_input_field in task_input_fields + "input_fields/" + task_input_field for task_input_field in task_input_fields ] def process( diff --git a/src/unitxt/schema.py b/src/unitxt/schema.py index 25aca85b8..cf4058fe3 100644 --- a/src/unitxt/schema.py +++ b/src/unitxt/schema.py @@ -36,8 +36,8 @@ def process( self, instance: Dict[str, Any], stream_name: Optional[str] = None ) -> Dict[str, Any]: task_data = { - **instance["inputs"], - **instance["outputs"], + **instance["input_fields"], + **instance["reference_fields"], "metadata": { "template": self.artifact_to_jsonable( instance["recipe_metadata"]["template"] diff --git a/src/unitxt/splitters.py b/src/unitxt/splitters.py index f07b5ea62..f181d147c 100644 --- a/src/unitxt/splitters.py +++ b/src/unitxt/splitters.py @@ -137,12 +137,14 @@ def sample( def filter_source_by_instance( self, instances_pool: List[Dict[str, object]], instance: Dict[str, object] ) -> List[Dict[str, object]]: - if "inputs" not in instance: - raise ValueError(f"'inputs' field is missing from '{instance}'.") + if "input_fields" not in instance: + raise ValueError(f"'input_fields' field is missing from '{instance}'.") # l = list(filter(lambda x: x["inputs"] != instance["inputs"], instances_pool)) try: return [ - item for item in instances_pool if item["inputs"] != instance["inputs"] + item + for item in instances_pool + if item["input_fields"] != instance["input_fields"] ] except Exception as e: raise e @@ -195,9 +197,9 @@ def prepare(self): self.labels_cache = None def exemplar_repr(self, exemplar): - if "inputs" not in exemplar: - raise ValueError(f"'inputs' field is missing from '{exemplar}'.") - inputs = exemplar["inputs"] + if "input_fields" not in exemplar: + raise ValueError(f"'input_fields' field is missing from '{exemplar}'.") + inputs = exemplar["input_fields"] if self.choices not in inputs: raise ValueError(f"'{self.choices}' field is missing from '{inputs}'.") choices = inputs[self.choices] @@ -209,13 +211,13 @@ def exemplar_repr(self, exemplar): f"Unexpected input choices value '{choices}'. Expected a list or a string." ) - if "outputs" not in exemplar: - raise ValueError(f"'outputs' field is missing from '{exemplar}'.") - outputs = exemplar["outputs"] + if "reference_fields" not in exemplar: + raise ValueError(f"'reference_fields' field is missing from '{exemplar}'.") + outputs = exemplar["reference_fields"] if self.labels not in outputs: raise ValueError(f"'{self.labels}' field is missing from '{outputs}'.") - exemplar_outputs = exemplar["outputs"][self.labels] + exemplar_outputs = exemplar["reference_fields"][self.labels] if not isinstance(exemplar_outputs, list): raise ValueError( f"Unexpected exemplar_outputs value '{exemplar_outputs}'. Expected a list." diff --git a/src/unitxt/task.py b/src/unitxt/task.py index da3d6289b..bbe26620d 100644 --- a/src/unitxt/task.py +++ b/src/unitxt/task.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Union from .artifact import fetch_artifact +from .dataclass import DeprecatedField from .logging_utils import get_logger from .operator import InstanceOperator from .type_utils import ( @@ -17,10 +18,10 @@ class Task(InstanceOperator): """Task packs the different instance fields into dictionaries by their roles in the task. Attributes: - inputs (Union[Dict[str, str], List[str]]): + input_fields (Union[Dict[str, str], List[str]]): Dictionary with string names of instance input fields and types of respective values. In case a list is passed, each type will be assumed to be Any. - outputs (Union[Dict[str, str], List[str]]): + reference_fields (Union[Dict[str, str], List[str]]): Dictionary with string names of instance output fields and types of respective values. In case a list is passed, each type will be assumed to be Any. metrics (List[str]): List of names of metrics to be used in the task. @@ -29,25 +30,64 @@ class Task(InstanceOperator): be set to Any. defaults (Optional[Dict[str, Any]]): An optional dictionary with default values for chosen input/output keys. Needs to be - consistent with names and types provided in 'inputs' and/or 'outputs' arguments. + consistent with names and types provided in 'input_fields' and/or 'output_fields' arguments. Will not overwrite values if already provided in a given instance. The output instance contains three fields: - "inputs" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'inputs'. + "inputs" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'input_fields'. "outputs" -- for the fields listed in Arg "outputs". "metrics" -- to contain the value of Arg 'metrics' """ - inputs: Union[Dict[str, str], List[str]] - outputs: Union[Dict[str, str], List[str]] + input_fields: Optional[Union[Dict[str, str], List[str]]] = None + reference_fields: Optional[Union[Dict[str, str], List[str]]] = None + inputs: Union[Dict[str, str], List[str]] = DeprecatedField( + default=None, + metadata={ + "deprecation_msg": "The 'inputs' field is deprecated. Please use 'input_fields' instead." + }, + ) + outputs: Union[Dict[str, str], List[str]] = DeprecatedField( + default=None, + metadata={ + "deprecation_msg": "The 'outputs' field is deprecated. Please use 'reference_fields' instead." + }, + ) metrics: List[str] prediction_type: Optional[str] = None augmentable_inputs: List[str] = [] defaults: Optional[Dict[str, Any]] = None + def prepare(self): + super().prepare() + if self.input_fields is not None and self.inputs is not None: + raise ValueError( + "Conflicting attributes: 'input_fields' cannot be set simultaneously with 'inputs'. Use only 'input_fields'" + ) + if self.reference_fields is not None and self.outputs is not None: + raise ValueError( + "Conflicting attributes: 'reference_fields' cannot be set simultaneously with 'output'. Use only 'reference_fields'" + ) + + self.input_fields = ( + self.input_fields if self.input_fields is not None else self.inputs + ) + self.reference_fields = ( + self.reference_fields if self.reference_fields is not None else self.outputs + ) + def verify(self): - for io_type in ["inputs", "outputs"]: - data = self.inputs if io_type == "inputs" else self.outputs + if self.input_fields is None: + raise ValueError("Missing attribute in task: 'input_fields' not set.") + if self.reference_fields is None: + raise ValueError("Missing attribute in task: 'reference_fields' not set.") + for io_type in ["input_fields", "reference_fields"]: + data = ( + self.input_fields + if io_type == "input_fields" + else self.reference_fields + ) + if not isoftype(data, Dict[str, str]): get_logger().warning( f"'{io_type}' field of Task should be a dictionary of field names and their types. " @@ -56,10 +96,10 @@ def verify(self): f"will raise an exception." ) data = {key: "Any" for key in data} - if io_type == "inputs": - self.inputs = data + if io_type == "input_fields": + self.input_fields = data else: - self.outputs = data + self.reference_fields = data if not self.prediction_type: get_logger().warning( @@ -74,8 +114,8 @@ def verify(self): for augmentable_input in self.augmentable_inputs: assert ( - augmentable_input in self.inputs - ), f"augmentable_input {augmentable_input} is not part of {self.inputs}" + augmentable_input in self.input_fields + ), f"augmentable_input {augmentable_input} is not part of {self.input_fields}" self.verify_defaults() @@ -121,13 +161,13 @@ def verify_defaults(self): f"however, the key '{default_name}' is of type '{type(default_name)}'." ) - val_type = self.inputs.get(default_name) or self.outputs.get( + val_type = self.input_fields.get( default_name - ) + ) or self.reference_fields.get(default_name) assert val_type, ( f"If specified, all keys of the 'defaults' must refer to a chosen " - f"key in either 'inputs' or 'outputs'. However, the name '{default_name}' " + f"key in either 'input_fields' or 'reference_fields'. However, the name '{default_name}' " f"was provided which does not match any of the keys." ) @@ -146,16 +186,16 @@ def process( ) -> Dict[str, Any]: instance = self.set_default_values(instance) - verify_required_schema(self.inputs, instance) - verify_required_schema(self.outputs, instance) + verify_required_schema(self.input_fields, instance) + verify_required_schema(self.reference_fields, instance) - inputs = {key: instance[key] for key in self.inputs.keys()} - outputs = {key: instance[key] for key in self.outputs.keys()} + input_fields = {key: instance[key] for key in self.input_fields.keys()} + reference_fields = {key: instance[key] for key in self.reference_fields.keys()} data_classification_policy = instance.get("data_classification_policy", []) return { - "inputs": inputs, - "outputs": outputs, + "input_fields": input_fields, + "reference_fields": reference_fields, "metrics": self.metrics, "data_classification_policy": data_classification_policy, } diff --git a/src/unitxt/templates.py b/src/unitxt/templates.py index 449537991..7ef322b55 100644 --- a/src/unitxt/templates.py +++ b/src/unitxt/templates.py @@ -67,7 +67,11 @@ def process( return instance inputs = instance.get("inputs") + if inputs is None: + inputs = instance.get("input_fields") outputs = instance.get("outputs") + if outputs is None: + outputs = instance.get("reference_fields") inputs, outputs = self.preprocess_inputs_and_outputs(inputs, outputs) self.set_titles(inputs) @@ -401,16 +405,20 @@ def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str: return target, [target] def _shuffle_choices(self, instance): - target_index = self.outputs_to_target_index(instance["outputs"]) - original_label_choice = instance["outputs"][self.choices_field][target_index] - choices = instance["inputs"][self.choices_field] + target_index = self.outputs_to_target_index(instance["reference_fields"]) + original_label_choice = instance["reference_fields"][self.choices_field][ + target_index + ] + choices = instance["input_fields"][self.choices_field] random_generator = new_random_generator( - {**instance["inputs"], **instance["outputs"]} + {**instance["input_fields"], **instance["reference_fields"]} ) random_generator.shuffle(choices) - instance["inputs"][self.choices_field] = choices - instance["outputs"][self.choices_field] = choices - instance["outputs"][self.target_field] = choices.index(original_label_choice) + instance["input_fields"][self.choices_field] = choices + instance["reference_fields"][self.choices_field] = choices + instance["reference_fields"][self.target_field] = choices.index( + original_label_choice + ) return instance def process( @@ -419,9 +427,10 @@ def process( if self.shuffle_choices: instance = self._shuffle_choices(instance) result = super().process(instance, stream_name) - if "options" not in result["outputs"]: - result["outputs"]["options"] = self.inputs_to_choices( - instance["outputs"], self.target_choice_format + + if "options" not in result["reference_fields"]: + result["reference_fields"]["options"] = self.inputs_to_choices( + instance["reference_fields"], self.target_choice_format ) return result diff --git a/tests/library/test_api.py b/tests/library/test_api.py index 5904d5cda..aa2421eee 100644 --- a/tests/library/test_api.py +++ b/tests/library/test_api.py @@ -189,8 +189,8 @@ def test_load_dataset_from_dict(self): card = TaskCard( loader=LoadHF(path="glue", name="wnli"), task=Task( - inputs=["sentence1", "sentence2"], - outputs=["label"], + input_fields=["sentence1", "sentence2"], + reference_fields=["label"], metrics=["metrics.accuracy"], ), templates=TemplatesList( diff --git a/tests/library/test_card.py b/tests/library/test_card.py index 6537119d2..9dcc5f134 100644 --- a/tests/library/test_card.py +++ b/tests/library/test_card.py @@ -27,8 +27,8 @@ ), ], task=Task( - inputs=["choices", "sentence1", "sentence2"], - outputs=["label"], + input_fields=["choices", "sentence1", "sentence2"], + reference_fields=["label"], metrics=["metrics.accuracy"], ), templates=TemplatesList( diff --git a/tests/library/test_operators.py b/tests/library/test_operators.py index 41d16b1e1..6651cfa18 100644 --- a/tests/library/test_operators.py +++ b/tests/library/test_operators.py @@ -2505,7 +2505,7 @@ def test_augment_whitespace_model_input(self): def test_augment_whitespace_task_input_with_error(self): text = "The dog ate my cat" - inputs = [{"inputs": {"text": text}}] + inputs = [{"input_fields": {"text": text}}] operator = AugmentWhitespace(augment_task_input=True) operator.set_task_input_fields(["sentence"]) with self.assertRaises(ValueError): @@ -2513,11 +2513,11 @@ def test_augment_whitespace_task_input_with_error(self): def test_augment_whitespace_task_input(self): text = "The dog ate my cat" - inputs = [{"inputs": {"text": text}}] + inputs = [{"input_fields": {"text": text}}] operator = AugmentWhitespace(augment_task_input=True) operator.set_task_input_fields(["text"]) outputs = apply_operator(operator, inputs) - normalized_output_source = outputs[0]["inputs"]["text"].split() + normalized_output_source = outputs[0]["input_fields"]["text"].split() normalized_input_source = text.split() assert ( normalized_output_source == normalized_input_source @@ -2525,10 +2525,10 @@ def test_augment_whitespace_task_input(self): def test_augment_whitespace_with_none_text_error(self): text = None - inputs = [{"inputs": {"text": text}}] + inputs = [{"input_fields": {"text": text}}] operator = AugmentWhitespace(augment_task_input=True) operator.set_task_input_fields(["text"]) - exception_text = "Error processing instance '0' from stream 'test' in AugmentWhitespace due to: Error augmenting value 'None' from 'inputs/text' in instance: {'inputs': {'text': None}}" + exception_text = "Error processing instance '0' from stream 'test' in AugmentWhitespace due to: Error augmenting value 'None' from 'input_fields/text' in instance: {'input_fields': {'text': None}}" check_operator_exception( operator, inputs, @@ -2614,7 +2614,7 @@ def verify(self): def test_augment_prefix_suffix_task_input_with_error(self): text = "She is riding a black horse\t\t " - inputs = [{"inputs": {"text": text}}] + inputs = [{"input_fields": {"text": text}}] suffixes = ["Q", "R", "S", "T"] operator = AugmentPrefixSuffix( augment_task_input=True, suffixes=suffixes, prefixes=None @@ -2624,12 +2624,12 @@ def test_augment_prefix_suffix_task_input_with_error(self): apply_operator(operator, inputs) self.assertEqual( str(ve.exception), - "Error processing instance '0' from stream 'test' in AugmentPrefixSuffix due to: Failed to get inputs/sentence from {'inputs': {'text': 'She is riding a black horse\\t\\t '}}", + "Error processing instance '0' from stream 'test' in AugmentPrefixSuffix due to: Failed to get input_fields/sentence from {'input_fields': {'text': 'She is riding a black horse\\t\\t '}}", ) def test_augment_prefix_suffix_task_input(self): text = "\n She is riding a black horse \t\t " - inputs = [{"inputs": {"text": text}}] + inputs = [{"input_fields": {"text": text}}] suffixes = ["Q", "R", "S", "T"] operator = AugmentPrefixSuffix( augment_task_input=True, @@ -2639,13 +2639,13 @@ def test_augment_prefix_suffix_task_input(self): ) operator.set_task_input_fields(["text"]) outputs = apply_operator(operator, inputs) - output0 = str(outputs[0]["inputs"]["text"]).rstrip("".join(suffixes)) + output0 = str(outputs[0]["input_fields"]["text"]).rstrip("".join(suffixes)) assert ( " \t\t " not in output0 and "\n" not in output0 ), f"Leading and trailing whitespaces should have been removed, but still found in the output: {output0}" assert ( output0 == text.strip()[: len(output0)] - ), f"The prefix of {outputs[0]['inputs']['text']!s} is not equal to the prefix of the stripped input: {text.strip()}" + ), f"The prefix of {outputs[0]['input_fields']['text']!s} is not equal to the prefix of the stripped input: {text.strip()}" def test_augment_prefix_suffix_with_non_string_suffixes_error(self): prefixes = [10, 20, "O", "P"] @@ -2660,13 +2660,13 @@ def test_augment_prefix_suffix_with_non_string_suffixes_error(self): def test_augment_prefix_suffix_with_none_input_error(self): text = None - inputs = [{"inputs": {"text": text}}] + inputs = [{"input_fields": {"text": text}}] suffixes = ["Q", "R", "S", "T"] operator = AugmentPrefixSuffix( augment_task_input=True, suffixes=suffixes, prefixes=None ) operator.set_task_input_fields(["text"]) - exception_text = "Error processing instance '0' from stream 'test' in AugmentPrefixSuffix due to: Error augmenting value 'None' from 'inputs/text' in instance: {'inputs': {'text': None}}" + exception_text = "Error processing instance '0' from stream 'test' in AugmentPrefixSuffix due to: Error augmenting value 'None' from 'input_fields/text' in instance: {'input_fields': {'text': None}}" check_operator_exception( operator, inputs, @@ -2676,10 +2676,10 @@ def test_augment_prefix_suffix_with_none_input_error(self): def test_test_operator_without_tester_param(self): text = None - inputs = [{"inputs": {"text": text}}] + inputs = [{"input_fields": {"text": text}}] operator = AugmentWhitespace(augment_task_input=True) operator.set_task_input_fields(["text"]) - exception_text = "Error processing instance '0' from stream 'test' in AugmentWhitespace due to: Error augmenting value 'None' from 'inputs/text' in instance: {'inputs': {'text': None}}" + exception_text = "Error processing instance '0' from stream 'test' in AugmentWhitespace due to: Error augmenting value 'None' from 'input_fields/text' in instance: {'input_fields': {'text': None}}" check_operator_exception( operator, @@ -2689,10 +2689,10 @@ def test_test_operator_without_tester_param(self): def test_test_operator_unexpected_pass(self): text = "Should be ok" - inputs = [{"inputs": {"text": text}}] + inputs = [{"input_fields": {"text": text}}] operator = AugmentWhitespace(augment_task_input=True) operator.set_task_input_fields(["text"]) - exception_text = "Error processing instance '0' from stream 'test' in AugmentWhitespace due to: Error augmenting value 'None' from 'inputs/text' in instance: {'inputs': {'text': None}}" + exception_text = "Error processing instance '0' from stream 'test' in AugmentWhitespace due to: Error augmenting value 'None' from 'input_fields/text' in instance: {'input_fields': {'text': None}}" try: check_operator_exception( @@ -2703,7 +2703,7 @@ def test_test_operator_unexpected_pass(self): except Exception as e: self.assertEqual( str(e), - "Did not receive expected exception Error processing instance '0' from stream 'test' in AugmentWhitespace due to: Error augmenting value 'None' from 'inputs/text' in instance: {'inputs': {'text': None}}", + "Did not receive expected exception Error processing instance '0' from stream 'test' in AugmentWhitespace due to: Error augmenting value 'None' from 'input_fields/text' in instance: {'input_fields': {'text': None}}", ) def test_duplicate_instance(self): diff --git a/tests/library/test_splitters.py b/tests/library/test_splitters.py index dfe9a01d0..bac1943f9 100644 --- a/tests/library/test_splitters.py +++ b/tests/library/test_splitters.py @@ -16,8 +16,8 @@ def new_exemplar(choices=None, labels=None, text=""): if choices is None: choices = ["class_a", "class_b"] return { - "inputs": {"choices": choices, "text": text}, - "outputs": { + "input_fields": {"choices": choices, "text": text}, + "reference_fields": { "labels": labels, }, } @@ -41,7 +41,7 @@ def test_sample(self): counts = Counter() for i in range(0, num_samples): - counts[result[i]["outputs"]["labels"][0]] += 1 + counts[result[i]["reference_fields"]["labels"][0]] += 1 self.assertEqual(counts["dog"], 1) self.assertEqual(counts["cat"], 1) self.assertEqual(len(counts.keys()), 3) @@ -65,7 +65,7 @@ def test_sample_no_empty_labels(self): counts = Counter() for i in range(0, num_samples): - counts[result[i]["outputs"]["labels"][0]] += 1 + counts[result[i]["reference_fields"]["labels"][0]] += 1 self.assertEqual(set(counts.keys()), {"dog", "cat"}) def test_sample_list(self): @@ -84,7 +84,7 @@ def test_sample_list(self): counts = Counter() for j in range(0, num_samples): - counts[str(result[j]["outputs"]["labels"])] += 1 + counts[str(result[j]["reference_fields"]["labels"])] += 1 self.assertTrue( counts["['dog', 'cat']"] == 1 or counts["['cat']"] == 1, f"unexpected counts: {counts}", @@ -123,8 +123,8 @@ def _test_exemplar_repr_missing_field(self, missing_field): ) def test_exemplar_repr_missing_fields(self): - self._test_exemplar_repr_missing_field(missing_field="inputs") - self._test_exemplar_repr_missing_field(missing_field="outputs") + self._test_exemplar_repr_missing_field(missing_field="input_fields") + self._test_exemplar_repr_missing_field(missing_field="reference_fields") def test_filter_with_bad_input(self): sampler = DiverseLabelsSampler(3) @@ -139,10 +139,10 @@ def test_filter_with_bad_input(self): filtered_instances = sampler.filter_source_by_instance(instances, instance) self.assertEqual(len(filtered_instances), 2) - del instance["inputs"] + del instance["input_fields"] with self.assertRaises(ValueError) as cm: sampler.filter_source_by_instance(instances, instance) self.assertEqual( - f"'inputs' field is missing from '{instance}'.", + f"'input_fields' field is missing from '{instance}'.", str(cm.exception), ) diff --git a/tests/library/test_tasks.py b/tests/library/test_tasks.py index 799754d37..c0dc477b4 100644 --- a/tests/library/test_tasks.py +++ b/tests/library/test_tasks.py @@ -5,6 +5,25 @@ class TestTasks(UnitxtTestCase): def test_task_metrics_type_checking(self): + operator = Task( + input_fields={"input": "str"}, + reference_fields={"label": "str"}, + prediction_type="str", + metrics=["metrics.wer", "metrics.rouge"], + ) + + operator.check_metrics_type() + + operator.prediction_type = "Dict" + with self.assertRaises(ValueError) as e: + operator.check_metrics_type() + self.assertEqual( + str(e.exception), + "The task's prediction type (typing.Dict) and 'metrics.wer' metric's prediction type " + "() are different.", + ) + + def test_task_metrics_type_checking_with_inputs_outputs(self): operator = Task( inputs={"input": "str"}, outputs={"label": "str"}, @@ -23,6 +42,58 @@ def test_task_metrics_type_checking(self): "() are different.", ) + def test_task_missing_input_fields(self): + with self.assertRaises(ValueError) as e: + Task( + input_fields=None, + reference_fields={"label": "str"}, + prediction_type="str", + metrics=["metrics.wer", "metrics.rouge"], + ) + self.assertEqual( + str(e.exception), "Missing attribute in task: 'input_fields' not set." + ) + + def test_task_missing_reference_fields(self): + with self.assertRaises(ValueError) as e: + Task( + input_fields={"input": "int"}, + reference_fields=None, + prediction_type="str", + metrics=["metrics.wer", "metrics.rouge"], + ) + self.assertEqual( + str(e.exception), "Missing attribute in task: 'reference_fields' not set." + ) + + def test_conflicting_input_fields(self): + with self.assertRaises(ValueError) as e: + Task( + inputs={"input": "int"}, + input_fields={"input": "int"}, + reference_fields={"label": "str"}, + prediction_type="str", + metrics=["metrics.wer", "metrics.rouge"], + ) + self.assertEqual( + str(e.exception), + "Conflicting attributes: 'input_fields' cannot be set simultaneously with 'inputs'. Use only 'input_fields'", + ) + + def test_conflicting_output_fields(self): + with self.assertRaises(ValueError) as e: + Task( + input_fields={"input": "int"}, + reference_fields={"label": "str"}, + outputs={"label": "int"}, + prediction_type="str", + metrics=["metrics.wer", "metrics.rouge"], + ) + self.assertEqual( + str(e.exception), + "Conflicting attributes: 'reference_fields' cannot be set simultaneously with 'output'. Use only 'reference_fields'", + ) + def test_set_defaults(self): instances = [ {"input": "Input1", "input_type": "something", "label": 0, "labels": []}, @@ -30,8 +101,8 @@ def test_set_defaults(self): ] operator = Task( - inputs={"input": "str", "input_type": "str"}, - outputs={"label": "int", "labels": "List[int]"}, + input_fields={"input": "str", "input_type": "str"}, + reference_fields={"label": "int", "labels": "List[int]"}, prediction_type="Any", metrics=["metrics.accuracy"], defaults={"input_type": "text", "labels": [0, 1, 2]}, @@ -60,8 +131,8 @@ def test_set_defaults(self): def test_verify_defaults(self): operator = Task( - inputs={"input": "str"}, - outputs={"label": "int"}, + input_fields={"input": "str"}, + reference_fields={"label": "int"}, prediction_type="Any", metrics=["metrics.accuracy"], ) @@ -73,7 +144,7 @@ def test_verify_defaults(self): self.assertEqual( str(e.exception), f"If specified, all keys of the 'defaults' must refer to a chosen " - f"key in either 'inputs' or 'outputs'. However, the name '{default_name}' " + f"key in either 'input_fields' or 'reference_fields'. However, the name '{default_name}' " f"was provided which does not match any of the keys.", ) diff --git a/tests/library/test_templates.py b/tests/library/test_templates.py index 8ad32becb..d3fcb6a25 100644 --- a/tests/library/test_templates.py +++ b/tests/library/test_templates.py @@ -593,23 +593,26 @@ def test_multiple_choice_template(self): choices = ["True", "False"] inputs = [ { - "inputs": {"choices": choices, "text": "example A"}, - "outputs": {"choices": choices, "label": 0}, + "input_fields": {"choices": choices, "text": "example A"}, + "reference_fields": {"choices": choices, "label": 0}, }, { - "inputs": {"choices": choices, "text": "example A"}, - "outputs": {"choices": choices, "label": "False"}, + "input_fields": {"choices": choices, "text": "example A"}, + "reference_fields": {"choices": choices, "label": "False"}, }, { - "inputs": {"choices": ["True", "small"], "text": "example A"}, - "outputs": {"choices": ["True", "small"], "label": "small"}, + "input_fields": {"choices": ["True", "small"], "text": "example A"}, + "reference_fields": { + "choices": ["True", "small"], + "label": "small", + }, }, ] targets = [ { - "inputs": {"choices": choices, "text": "example A"}, - "outputs": { + "input_fields": {"choices": choices, "text": "example A"}, + "reference_fields": { "choices": choices, "label": 0, "options": [f"{first}", f"{second}"], @@ -621,8 +624,8 @@ def test_multiple_choice_template(self): "target_prefix": "", }, { - "inputs": {"choices": choices, "text": "example A"}, - "outputs": { + "input_fields": {"choices": choices, "text": "example A"}, + "reference_fields": { "choices": choices, "label": "False", "options": [f"{first}", f"{second}"], @@ -634,8 +637,8 @@ def test_multiple_choice_template(self): "target_prefix": "", }, { - "inputs": {"choices": ["True", "small"], "text": "example A"}, - "outputs": { + "input_fields": {"choices": ["True", "small"], "text": "example A"}, + "reference_fields": { "choices": ["True", "small"], "label": "small", "options": [f"{first}", f"{second}"], @@ -679,23 +682,26 @@ def test_multiple_choice_template_with_shuffle(self): inputs = [ { - "inputs": {"choices": ["True", "False"], "text": "example A"}, - "outputs": {"choices": ["True", "False"], "label": 0}, + "input_fields": {"choices": ["True", "False"], "text": "example A"}, + "reference_fields": {"choices": ["True", "False"], "label": 0}, }, { - "inputs": {"choices": ["True", "False"], "text": "example A"}, - "outputs": {"choices": ["True", "False"], "label": "False"}, + "input_fields": {"choices": ["True", "False"], "text": "example A"}, + "reference_fields": { + "choices": ["True", "False"], + "label": "False", + }, }, { - "inputs": {"choices": ["True", temp], "text": "example A"}, - "outputs": {"choices": ["True", temp], "label": temp}, + "input_fields": {"choices": ["True", temp], "text": "example A"}, + "reference_fields": {"choices": ["True", temp], "label": temp}, }, ] targets = [ { - "inputs": {"choices": ["True", "False"], "text": "example A"}, - "outputs": { + "input_fields": {"choices": ["True", "False"], "text": "example A"}, + "reference_fields": { "choices": ["True", "False"], "label": 0, "options": [f"{first}", f"{second}"], @@ -707,8 +713,8 @@ def test_multiple_choice_template_with_shuffle(self): "target_prefix": "", }, { - "inputs": {"choices": ["True", "False"], "text": "example A"}, - "outputs": { + "input_fields": {"choices": ["True", "False"], "text": "example A"}, + "reference_fields": { "choices": ["True", "False"], "label": 1, "options": [f"{first}", f"{second}"], @@ -720,8 +726,8 @@ def test_multiple_choice_template_with_shuffle(self): "target_prefix": "", }, { - "inputs": {"choices": [temp, "True"], "text": "example A"}, - "outputs": { + "input_fields": {"choices": [temp, "True"], "text": "example A"}, + "reference_fields": { "choices": [temp, "True"], "label": 0, "options": [f"{first}", f"{second}"],