diff --git a/.secrets.baseline b/.secrets.baseline
index b529132ac..83526f3e7 100644
--- a/.secrets.baseline
+++ b/.secrets.baseline
@@ -3,7 +3,7 @@
     "files": "^.secrets.baseline$",
     "lines": null
   },
-  "generated_at": "2024-08-21T15:51:06Z",
+  "generated_at": "2024-09-21T12:22:01Z",
   "plugins_used": [
     {
       "name": "AWSKeyDetector"
     }
@@ -82,7 +82,7 @@
       "hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
       "is_secret": false,
       "is_verified": false,
-      "line_number": 1946,
+      "line_number": 1955,
       "type": "Hex High Entropy String",
       "verified_result": null
     }
diff --git a/examples/run_generic_inference_engine.py b/examples/run_generic_inference_engine.py
index bcd4d30ee..b234e6467 100644
--- a/examples/run_generic_inference_engine.py
+++ b/examples/run_generic_inference_engine.py
@@ -1,20 +1,52 @@
-from unitxt import get_logger, produce
-from unitxt.inference import GenericInferenceEngine
+from unitxt import get_logger, produce  # Import necessary functions from unitxt
+from unitxt.inference import GenericInferenceEngine  # Import the inference engine class
 
 if __name__ == "__main__":
-    generic_engine = GenericInferenceEngine(
-        default="engines.ibm_gen_ai.llama_3_8b_instruct"
+    # Create an instance of the GenericInferenceEngine with a default engine.
+    # The default is used whenever the UNITXT_INFERENCE_ENGINE environment variable is not set.
+    generic_engine_with_default = GenericInferenceEngine(
+        default="engines.ibm_gen_ai.llama_3_70b_instruct"
     )
+
+    # Define the recipe for data processing and model selection.
+    # - card: Specifies the underlying data (from cards.almost_evil).
+    # - template: Selects the specific template within the card (from templates.qa.open.simple).
+    # - demos_pool_size and num_demos: Control the number of demonstration examples used (set to 0 here).
     recipe = "card=cards.almost_evil,template=templates.qa.open.simple,demos_pool_size=0,num_demos=0"
+
+    # Create a list of instances (data points) for inference.
+    # Each instance has a "question" and its corresponding "answers".
     instances = [
-        {"question": "How many days there are in a week", "answers": ["7"]},
         {
-            "question": "If a ate an apple in the morning, and one in the evening, how many apples did I eat?",
+            "question": "How many days are there in a week? Answer only with numerals.",
+            "answers": ["7"],
+        },
+        {
+            "question": "If I ate an apple in the morning, and one in the evening, what is the number of apples I have eaten? Answer only with numerals.",
             "answers": ["2"],
         },
     ]
+
+    # Process the instances using the defined recipe.
+    # This formats the data according to the chosen card and template.
     dataset = produce(instances, recipe)
 
-    predictions = generic_engine.infer(dataset)
+    # Perform inference on the processed dataset using the engine with the default model.
+    predictions = generic_engine_with_default.infer(dataset)
+    get_logger().info(predictions)  # Log the predictions
+
+    # The following code block demonstrates how to use the GenericInferenceEngine without specifying a
+    # default engine. It expects the engine to be defined in the UNITXT_INFERENCE_ENGINE environment variable.
+    try:
+        # Attempt to create an instance without a default engine.
+        generic_engine_without_default = GenericInferenceEngine()
 
-    get_logger().info(predictions)
+        # Perform inference (will use the engine specified in the environment variable).
+        predictions = generic_engine_without_default.infer(dataset)
+        get_logger().info(predictions)  # Log the predictions
+    except Exception:
+        # Handle the case where the environment variable is not set.
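To exercise the second half of this example, the environment variable must point at a catalog engine. A minimal sketch — setting the variable from Python is illustrative only (an `export` in the shell works the same), and the engine reference used is the WML entry added later in this diff:

    import os

    from unitxt.inference import GenericInferenceEngine

    os.environ["UNITXT_INFERENCE_ENGINE"] = "engines.ibm_wml.llama_3_70b_instruct"

    engine = GenericInferenceEngine()  # no default needed now
    # engine.infer(dataset) would resolve to, and run through, the WML engine.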
+        get_logger().error(
+            "GenericInferenceEngine could not be initialized without a default since "
+            "the UNITXT_INFERENCE_ENGINE environment variable is not set."
+        )
diff --git a/prepare/cards/fin_qa.py b/prepare/cards/fin_qa.py
index bb6341855..3840f8fc1 100644
--- a/prepare/cards/fin_qa.py
+++ b/prepare/cards/fin_qa.py
@@ -1,38 +1,35 @@
 from unitxt.blocks import (
     LoadHF,
-    SerializeTableAsIndexedRowMajor,
     TaskCard,
     TemplatesList,
 )
 from unitxt.catalog import add_to_catalog
-from unitxt.operators import CopyFields, FilterByExpression
+from unitxt.operators import Copy, FilterByExpression
 from unitxt.struct_data_operators import MapTableListsToStdTableJSON
 from unitxt.task import Task
 from unitxt.templates import InputOutputTemplate
 from unitxt.test_utils.card import test_card
+from unitxt.types import Table
 
 card = TaskCard(
     loader=LoadHF(path="ibm/finqa", streaming=False),
     preprocess_steps=[
         FilterByExpression(expression="len(table) > 1"),
-        CopyFields(field_to_field=[["pre_text/0", "pre_text"]]),
-        CopyFields(field_to_field=[["post_text/0", "post_text"]]),
-        MapTableListsToStdTableJSON(field_to_field=[["table", "stdtable"]]),
-        SerializeTableAsIndexedRowMajor(
-            field_to_field=[["stdtable", "serialized_table"]]
-        ),
+        Copy(field="pre_text/0", to_field="pre_text"),
+        Copy(field="post_text/0", to_field="post_text"),
+        MapTableListsToStdTableJSON(field="table"),
     ],
     task=Task(
         inputs={
             "pre_text": str,
-            "serialized_table": str,
+            "table": Table,
             "post_text": str,
             "question": str,
         },
         outputs={"program_re": str, "answer": str},
         prediction_type=str,
         metrics=["metrics.fin_qa_metric"],
-        augmentable_inputs=["pre_text", "serialized_table", "post_text", "question"],
+        augmentable_inputs=["pre_text", "table", "post_text", "question"],
     ),
     templates=TemplatesList(
         [
@@ -52,7 +49,7 @@
             ["table-min", "table header", "number", "the minimum number of one table row"]]
             Answer with only the program, without any additional explanation.
             Pre-table text: {pre_text}
-            Table: {serialized_table}
+            Table: {table}
             Post-table text: {post_text}
             Question: {question}
             Program:
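For orientation, the `Table` values that now reach the task are plain header/rows dicts (the shape `MapTableListsToStdTableJSON` produces), and rendering them to text is deferred to the template's serializer. An illustration with made-up values, assuming the CSV-based `TableSerializer` shown near the end of this diff:

    # A Table value as the task receives it after preprocessing:
    table = {
        "header": ["year", "revenue"],
        "rows": [["2022", "10.5"], ["2023", "12.1"]],
    }

    # The default TableSerializer renders it as CSV at templating time:
    # year,revenue
    # 2022,10.5
    # 2023,12.1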
diff --git a/prepare/cards/numeric_nlg.py b/prepare/cards/numeric_nlg.py
index 7846b3696..4621ffd0a 100644
--- a/prepare/cards/numeric_nlg.py
+++ b/prepare/cards/numeric_nlg.py
@@ -2,16 +2,16 @@
     LoadHF,
     MapHTMLTableToJSON,
     Rename,
-    SerializeTableAsMarkdown,
     Set,
     TaskCard,
 )
 from unitxt.catalog import add_to_catalog
+from unitxt.operators import Copy
 from unitxt.templates import TemplatesList
 from unitxt.test_utils.card import test_card
 
 card = TaskCard(
-    loader=LoadHF(path="kasnerz/numericnlg"),  # TODO: load from github repo
+    loader=LoadHF(path="kasnerz/numericnlg"),
     preprocess_steps=[
         Set(
             fields={
@@ -21,7 +21,7 @@
             }
         ),
         MapHTMLTableToJSON(field="table_html_clean", to_field="table_out"),
-        SerializeTableAsMarkdown(field="table_out", to_field="input_a"),
+        Copy(field="table_out", to_field="input_a"),
         Rename(field="description", to_field="output"),
         Rename(field="caption", to_field="input_b"),
     ],
diff --git a/prepare/cards/rag/end_to_end/__init__.py b/prepare/cards/rag/end_to_end/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/prepare/cards/scigen.py b/prepare/cards/scigen.py
index 2a76a0ff5..7924d923d 100644
--- a/prepare/cards/scigen.py
+++ b/prepare/cards/scigen.py
@@ -3,7 +3,6 @@
     ConstructTableFromRowsCols,
     LoadHF,
     Rename,
-    SerializeTableAsIndexedRowMajor,
     Set,
     TaskCard,
 )
@@ -16,9 +15,8 @@
         FilterByCondition(values={"table_content_values": "[]"}, condition="ne"),
         ConstructTableFromRowsCols(
             fields=["table_column_names", "table_content_values"],
-            to_field="table",
+            to_field="input_a",
         ),
-        SerializeTableAsIndexedRowMajor(field_to_field=[["table", "input_a"]]),
         Rename(field_to_field={"table_caption": "input_b", "text": "output"}),
         Set(
             fields={
diff --git a/prepare/cards/tab_fact.py b/prepare/cards/tab_fact.py
index 446fa95d7..48273c4a4 100644
--- a/prepare/cards/tab_fact.py
+++ b/prepare/cards/tab_fact.py
@@ -2,7 +2,6 @@
     LoadHF,
     MapInstanceValues,
     Rename,
-    SerializeTableAsIndexedRowMajor,
     Set,
     TaskCard,
 )
@@ -16,8 +15,7 @@
         path="ibm/tab_fact", streaming=False, data_classification_policy=["public"]
     ),
     preprocess_steps=[
-        SerializeTableAsIndexedRowMajor(field_to_field=[["table", "table_serialized"]]),
-        Rename(field_to_field={"table_serialized": "text_a", "statement": "text_b"}),
+        Rename(field_to_field={"table": "text_a", "statement": "text_b"}),
         MapInstanceValues(mappers={"label": {"0": "refuted", "1": "entailed"}}),
         Set(
             fields={
diff --git a/prepare/cards/turl_col_type.py b/prepare/cards/turl_col_type.py
index 8ed468c23..0aae7172a 100644
--- a/prepare/cards/turl_col_type.py
+++ b/prepare/cards/turl_col_type.py
@@ -3,29 +3,25 @@
 from unitxt.blocks import (
     InputOutputTemplate,
     LoadHF,
-    SerializeTableAsIndexedRowMajor,
     Task,
     TaskCard,
     TemplatesList,
 )
 from unitxt.catalog import add_to_catalog
 from unitxt.test_utils.card import test_card
+from unitxt.types import Table
 
 card = TaskCard(
     loader=LoadHF(
         path="ibm/turl_table_col_type",
-        streaming=False,
         data_classification_policy=["public"],
     ),
-    preprocess_steps=[
-        SerializeTableAsIndexedRowMajor(field_to_field=[["table", "table_lin"]])
-    ],
     task=Task(
         input_fields={
             "page_title": str,
             "section_title": str,
             "table_caption": str,
-            "table_lin": str,
+            "table": Table,
             "vocab": List[str],
             "colname": str,
         },
@@ -41,7 +37,7 @@
         [
             InputOutputTemplate(
                 input_format="""
-                    This is a column type annotation task.
-                    The goal of this task is to choose the correct types for one selected column of the given input table from the given candidate types. The Wikipedia page, section and table caption (if any) provide important information for choosing the correct column types. \nPage Title: {page_title} \nSection Title: {section_title} \nTable caption: {table_caption} \nTable: \n{table_lin} \nSelected Column: {colname} \nCandidate Types: {vocab} \nOutput only the correct column types for this column (column name: {colname}) from the candidate types.
+                    This is a column type annotation task. The goal of this task is to choose the correct types for one selected column of the given input table from the given candidate types. The Wikipedia page, section and table caption (if any) provide important information for choosing the correct column types. \nPage Title: {page_title} \nSection Title: {section_title} \nTable caption: {table_caption} \nTable: \n{table} \nSelected Column: {colname} \nCandidate Types: {vocab} \nOutput only the correct column types for this column (column name: {colname}) from the candidate types.
                 """.strip(),
                 output_format="{annotations}",
                 postprocessors=["processors.to_list_by_comma"],
diff --git a/prepare/cards/wikitq.py b/prepare/cards/wikitq.py
index 59be1131d..38150ea2b 100644
--- a/prepare/cards/wikitq.py
+++ b/prepare/cards/wikitq.py
@@ -1,10 +1,9 @@
 from unitxt.blocks import (
     LoadHF,
-    SerializeTableAsIndexedRowMajor,
-    Set,
     TaskCard,
 )
 from unitxt.catalog import add_to_catalog
+from unitxt.operators import Copy, Set
 from unitxt.templates import MultiReferenceTemplate, TemplatesList
 from unitxt.test_utils.card import test_card
 
@@ -15,10 +14,7 @@
     ),
     preprocess_steps=[
         Set({"context_type": "table"}),
-        ## truncate only if needed as it can impact evaluation results.
-        # TruncateTableCells(max_length=15, table="table", text_output="answers"),
-        # TruncateTableRows(field="table", rows_to_keep=50),
-        SerializeTableAsIndexedRowMajor(field_to_field=[["table", "context"]]),
+        Copy(field="table", to_field="context"),
     ],
     task="tasks.qa.with_context.extractive[metrics=[metrics.f1_strings, metrics.unsorted_list_exact_match]]",
     templates=TemplatesList(
diff --git a/prepare/engines/ibm_wml/llama3.py b/prepare/engines/ibm_wml/llama3.py
new file mode 100644
index 000000000..587e01327
--- /dev/null
+++ b/prepare/engines/ibm_wml/llama3.py
@@ -0,0 +1,11 @@
+from unitxt.catalog import add_to_catalog
+from unitxt.inference import WMLInferenceEngine
+
+model_list = ["meta-llama/llama-3-70b-instruct"]
+
+for model in model_list:
+    model_label = model.split("/")[1].replace("-", "_").replace(".", ",").lower()
+    inference_model = WMLInferenceEngine(
+        model_name=model, max_new_tokens=2048, random_seed=42
+    )
+    add_to_catalog(inference_model, f"engines.ibm_wml.{model_label}", overwrite=True)
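A quick trace of the `model_label` derivation above (the `.` → `,` step presumably keeps literal dots out of catalog names, where `.` separates hierarchy levels):

    # "meta-llama/llama-3-70b-instruct"
    #   .split("/")[1]      -> "llama-3-70b-instruct"
    #   .replace("-", "_")  -> "llama_3_70b_instruct"
    #   .replace(".", ",")  -> "llama_3_70b_instruct"  (no-op here; guards names like "v2.1")
    #   .lower()            -> "llama_3_70b_instruct"
    # yielding the catalog name "engines.ibm_wml.llama_3_70b_instruct",
    # which matches the JSON entry added further down in this diff.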
diff --git a/prepare/engines/ollama/__init__.py b/prepare/engines/ollama/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/prepare/metrics/llm_as_judge/pairwise_rating/llama_3_arena_hard_template.py b/prepare/metrics/llm_as_judge/pairwise_rating/llama_3_arena_hard_template.py
new file mode 100644
index 000000000..c16a6e18b
--- /dev/null
+++ b/prepare/metrics/llm_as_judge/pairwise_rating/llama_3_arena_hard_template.py
@@ -0,0 +1,62 @@
+from unitxt import add_to_catalog
+from unitxt.inference import (
+    GenericInferenceEngine,
+    IbmGenAiInferenceEngine,
+    WMLInferenceEngine,
+)
+from unitxt.llm_as_judge import LLMAsJudge
+
+model_list = ["meta-llama/llama-3-8b-instruct", "meta-llama/llama-3-70b-instruct"]
+format = "formats.llama3_instruct"
+templates = [
+    "templates.response_assessment.pairwise_comparative_rating.arena_hard",
+    "templates.response_assessment.pairwise_comparative_rating.arena_hard_with_shuffling",
+]
+
+inference_engines = [
+    ("ibm_wml", WMLInferenceEngine),
+    ("ibm_genai", IbmGenAiInferenceEngine),
+    ("generic_engine", GenericInferenceEngine),
+]
+
+
+for template in templates:
+    task = "pairwise_comparative_rating.single_turn"
+
+    for model_id in model_list:
+        for inference_engine_name, inference_engine in inference_engines:
+            if (
+                inference_engine_name == "ibm_wml"
+                and model_id == "meta-llama/llama-3-8b-instruct"
+            ):
+                continue  # currently not supported
+
+            # if the inference engine is generic, these configurations will be defined when it is saved to the catalog
+            if inference_engine_name != "generic_engine":
+                inference_model = inference_engine(
+                    model_name=model_id, max_new_tokens=2048, random_seed=42
+                )
+            else:
+                inference_model = inference_engine(
+                    default="engines.ibm_gen_ai.llama_3_70b_instruct"
+                )
+
+            model_label = (
+                model_id.split("/")[1].replace("-", "_").replace(".", ",").lower()
+            )
+            model_label = f"{model_label}_{inference_engine_name}"
+            template_label = template.split(".")[-1]
+            metric_label = f"{model_label}_template_{template_label}"
+            metric = LLMAsJudge(
+                inference_model=inference_model,
+                template=template,
+                task=task,
+                format=format,
+                main_score=metric_label,
+            )
+
+            add_to_catalog(
+                metric,
+                f"metrics.llm_as_judge.pairwise_comparative_rating.{model_label}_template_{template_label}",
+                overwrite=True,
+            )
diff --git a/prepare/metrics/llm_as_judge/pairwise_rating/llama_3_ibm_genai_arena_hard_template.py b/prepare/metrics/llm_as_judge/pairwise_rating/llama_3_ibm_genai_arena_hard_template.py
deleted file mode 100644
index 856f1b1d4..000000000
--- a/prepare/metrics/llm_as_judge/pairwise_rating/llama_3_ibm_genai_arena_hard_template.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from unitxt import add_to_catalog
-from unitxt.inference import (
-    IbmGenAiInferenceEngine,
-)
-from unitxt.llm_as_judge import LLMAsJudge
-
-model_list = ["meta-llama/llama-3-8b-instruct", "meta-llama/llama-3-70b-instruct"]
-format = "formats.llama3_instruct"
-templates = [
-    "templates.response_assessment.pairwise_comparative_rating.arena_hard",
-    "templates.response_assessment.pairwise_comparative_rating.arena_hard_with_shuffling",
-]
-for template in templates:
-    task = "pairwise_comparative_rating.single_turn"
-
-    for model_id in model_list:
-        inference_model = IbmGenAiInferenceEngine(
-            model_name=model_id, max_new_tokens=2048, random_seed=42
-        )
-        model_label = model_id.split("/")[1].replace("-", "_").replace(".", ",").lower()
-        model_label = f"{model_label}_ibm_genai"
-        template_label = template.split(".")[-1]
-        metric_label = f"{model_label}_template_{template_label}"
-        metric = LLMAsJudge(
-            inference_model=inference_model,
-            template=template,
-            task=task,
-            format=format,
-            main_score=metric_label,
-        )
-
-        add_to_catalog(
-            metric,
-            f"metrics.llm_as_judge.pairwise_comparative_rating.{model_label}_template_{template_label}",
-            overwrite=True,
-        )
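Because the generic-engine variants defer backend selection to run time, the same catalog metric can be re-pointed without re-preparing anything. A hypothetical illustration (the metric and engine names are the ones added in this diff):

    import os

    # Choose the judge's backend at run time; the catalog metric itself stays unchanged.
    os.environ["UNITXT_INFERENCE_ENGINE"] = "engines.ibm_wml.llama_3_70b_instruct"

    # e.g., evaluate with:
    # metrics=["metrics.llm_as_judge.pairwise_comparative_rating.llama_3_70b_instruct_generic_engine_template_arena_hard"]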
diff --git a/prepare/tasks/classification.py b/prepare/tasks/classification.py
index c6741c691..dcbd30714 100644
--- a/prepare/tasks/classification.py
+++ b/prepare/tasks/classification.py
@@ -1,7 +1,8 @@
-from typing import List
+from typing import List, Union
 
-from unitxt.blocks import Task
 from unitxt.catalog import add_to_catalog
+from unitxt.task import Task
+from unitxt.types import Audio, Dialog, Image, Table, Text
 
 add_to_catalog(
     Task(
@@ -79,7 +80,7 @@
 add_to_catalog(
     Task(
         input_fields={
-            "text_a": str,
+            "text_a": Union[Text, Image, Audio, Table, Dialog],
             "text_a_type": str,
             "text_b": str,
             "text_b_type": str,
diff --git a/prepare/tasks/generation.py b/prepare/tasks/generation.py
index efeaed30f..5d12bd0a9 100644
--- a/prepare/tasks/generation.py
+++ b/prepare/tasks/generation.py
@@ -1,5 +1,8 @@
+from typing import Union
+
 from unitxt.blocks import Task
 from unitxt.catalog import add_to_catalog
+from unitxt.types import Audio, Dialog, Image, Table, Text
 
 add_to_catalog(
     Task(
@@ -35,9 +38,9 @@
 add_to_catalog(
     Task(
         input_fields={
-            "input_a": str,
+            "input_a": Union[Text, Image, Audio, Table, Dialog],
             "type_of_input_a": str,
-            "input_b": str,
+            "input_b": Union[Text, Image, Audio, Table, Dialog],
             "type_of_input_b": str,
             "type_of_output": str,
         },
diff --git a/prepare/tasks/qa/tasks.py b/prepare/tasks/qa/tasks.py
index 62287ce5c..a65c928e6 100644
--- a/prepare/tasks/qa/tasks.py
+++ b/prepare/tasks/qa/tasks.py
@@ -6,7 +6,11 @@
 
 add_to_catalog(
     Task(
-        input_fields={"context": Text, "context_type": str, "question": str},
+        input_fields={
+            "context": Union[Text, Table, Dialog],
+            "context_type": str,
+            "question": str,
+        },
         reference_fields={"answers": List[str]},
         prediction_type=str,
         metrics=["metrics.squad"],
diff --git a/prepare/tasks/rag/__init__.py b/prepare/tasks/rag/__init__.py
deleted file mode 100644
index e69de29bb..000000000
"templates.classification.multi_class.default", + "templates.classification.multi_class.instruction", + # "templates.classification.multi_class.title", + # "templates.classification.multi_class.empty", + "templates.classification.multi_class.instruct_question_selects", + "templates.classification.multi_class.instruct_question_select_i_think", + # "templates.classification.multi_class.instruct_select_question", + ] + ), + "templates.classification.multi_class.blue_bench", + overwrite=True, +) + # Multi label diff --git a/prepare/templates/completion/multiple_choice/templates.py b/prepare/templates/completion/multiple_choice/templates.py index 551531b61..869d497d5 100644 --- a/prepare/templates/completion/multiple_choice/templates.py +++ b/prepare/templates/completion/multiple_choice/templates.py @@ -72,3 +72,16 @@ "templates.completion.multiple_choice.all", overwrite=True, ) + +add_to_catalog( + TemplatesList( + [ + "templates.completion.multiple_choice.simple", + "templates.completion.multiple_choice.enumerated", + "templates.completion.multiple_choice.standard", + # "templates.completion.multiple_choice.title", + ] + ), + "templates.completion.multiple_choice.blue_bench", + overwrite=True, +) diff --git a/prepare/templates/qa/multiple_choice/templates.py b/prepare/templates/qa/multiple_choice/templates.py index ebc23be7a..9298bdd1d 100644 --- a/prepare/templates/qa/multiple_choice/templates.py +++ b/prepare/templates/qa/multiple_choice/templates.py @@ -476,3 +476,38 @@ def remove_duplicates(input_list): "templates.qa.multiple_choice.with_topic.all", overwrite=True, ) + +add_to_catalog( + TemplatesList( + [ + "templates.qa.multiple_choice.with_topic.mmlu", + "templates.qa.multiple_choice.with_topic.helm", + "templates.qa.multiple_choice.with_topic.lm_eval_harness", + ] + ), + "templates.qa.multiple_choice.with_topic.blue_bench", + overwrite=True, +) + +add_to_catalog( + TemplatesList( + [ + "templates.qa.multiple_choice.open.helm", + "templates.qa.multiple_choice.open.lm_eval_harness", + "templates.qa.multiple_choice.open.mmlu", + ] + ), + "templates.qa.multiple_choice.open.blue_bench", + overwrite=True, +) +add_to_catalog( + TemplatesList( + [ + "templates.qa.multiple_choice.with_context.lm_eval_harness", + "templates.qa.multiple_choice.with_context.no_intro.helm", + "templates.qa.multiple_choice.with_context.no_intro.mmlu", + ] + ), + "templates.qa.multiple_choice.with_context.blue_bench", + overwrite=True, +) diff --git a/prepare/templates/rag/response_generation.py b/prepare/templates/rag/response_generation.py index efde6c9a5..de17aded1 100644 --- a/prepare/templates/rag/response_generation.py +++ b/prepare/templates/rag/response_generation.py @@ -1,6 +1,7 @@ from unitxt import add_to_catalog from unitxt.templates import ( MultiReferenceTemplate, + TemplatesList, ) add_to_catalog( @@ -45,3 +46,15 @@ "templates.rag.response_generation.answer_based_on_context_inverted", overwrite=True, ) + +add_to_catalog( + TemplatesList( + [ + "templates.rag.response_generation.please_respond", + "templates.rag.response_generation.please_respond_chat", + "templates.rag.response_generation.answer_based_on_context", + ] + ), + "templates.rag.response_generation.blue_bench", + overwrite=True, +) diff --git a/prepare/templates/summarization/abstractive.py b/prepare/templates/summarization/abstractive.py index dec71f714..7f10ed9d7 100644 --- a/prepare/templates/summarization/abstractive.py +++ b/prepare/templates/summarization/abstractive.py @@ -190,3 +190,18 @@ "templates.summarization.abstractive.all", 
diff --git a/prepare/templates/translation/directed.py b/prepare/templates/translation/directed.py
index 7381b6ba2..7216986c4 100644
--- a/prepare/templates/translation/directed.py
+++ b/prepare/templates/translation/directed.py
@@ -76,3 +76,18 @@
     "templates.translation.directed.all",
     overwrite=True,
 )
+
+add_to_catalog(
+    TemplatesList(
+        [
+            "templates.translation.directed.simple",
+            "templates.translation.directed.formal",
+            "templates.translation.directed.casual",
+            # "templates.translation.directed.playful",
+            # "templates.translation.directed.instructional",
+            # "templates.translation.directed.title",
+        ]
+    ),
+    "templates.translation.directed.blue_bench",
+    overwrite=True,
+)
diff --git a/src/unitxt/catalog/cards/fin_qa.json b/src/unitxt/catalog/cards/fin_qa.json
index 72e2765e1..682c8592c 100644
--- a/src/unitxt/catalog/cards/fin_qa.json
+++ b/src/unitxt/catalog/cards/fin_qa.json
@@ -11,47 +11,25 @@
             "expression": "len(table) > 1"
         },
         {
-            "__type__": "copy_fields",
-            "field_to_field": [
-                [
-                    "pre_text/0",
-                    "pre_text"
-                ]
-            ]
+            "__type__": "copy",
+            "field": "pre_text/0",
+            "to_field": "pre_text"
        },
        {
-            "__type__": "copy_fields",
-            "field_to_field": [
-                [
-                    "post_text/0",
-                    "post_text"
-                ]
-            ]
+            "__type__": "copy",
+            "field": "post_text/0",
+            "to_field": "post_text"
        },
        {
            "__type__": "map_table_lists_to_std_table_json",
-            "field_to_field": [
-                [
-                    "table",
-                    "stdtable"
-                ]
-            ]
-        },
-        {
-            "__type__": "serialize_table_as_indexed_row_major",
-            "field_to_field": [
-                [
-                    "stdtable",
-                    "serialized_table"
-                ]
-            ]
+            "field": "table"
        }
    ],
    "task": {
        "__type__": "task",
        "inputs": {
            "pre_text": "str",
-            "serialized_table": "str",
+            "table": "Table",
            "post_text": "str",
            "question": "str"
        },
@@ -65,7 +43,7 @@
        ],
        "augmentable_inputs": [
            "pre_text",
-            "serialized_table",
+            "table",
            "post_text",
            "question"
        ]
@@ -75,7 +53,7 @@
        "items": [
            {
                "__type__": "input_output_template",
-                "input_format": "Presented with a financial report consisting of textual contents and a structured table, given a question, generate the reasoning program in the domain specific language (DSL) that will be executed to get the answer. \nThe DSL consists of mathematical operations and table operations as executable programs. The program consists of a sequence of operations. Each operation takes a list of arguments. \nThere are 6 mathematical operations: add, subtract, multiply, divide, greater, exp, and 4 table aggregation operations table-max, table-min, table-sum, table-average, that apply aggregation operations on table rows. The mathematical operations take arguments of either numbers from the given reports, or a numerical result from a previous step.\nThe table operations take arguments of table row names. We use the special token #n to denote the result from the nth step. \nFor example, in the example \"divide(9413, 20.01), divide(8249, 9.48), subtract(#0, #1)\", the program consists of 3 steps; The first and the second division steps take arguments from the table and the text, respectively, then the third step subtracts the results from the two previous steps.\n Definitions of all operations:\n [[\"Name\", \"Arguments\", \"Output\", \"Description\"],\n [\"add\", \"number1, number2\", \"number\", \"add two numbers: number1 + number2\"],\n [\"subtract\", \"number1, number2\", \"number\", \"subtract two numbers: number1 - number2\"],\n [\"multiply\", \"number1, number2\", \"number\", \"multiply two numbers: number1 * number2\"],\n [\"divide\", \"number1, number2\", \"number\", \"multiply two numbers: number1 / number2\"],\n [\"exp\", \"number1, number2\", \"number\", \"exponential: number1 ^ number2\"],\n [\"greater\", \"number1, number2\", \"bool\", \"comparison: number1 > number2\"],\n [\"table-sum\", \"table header\", \"number\", \"the summation of one table row\"],\n [\"table-average\", \"table header\", \"number\", \"the average of one table row\"],\n [\"table-max\", \"table header\", \"number\", \"the maximum number of one table row\"],\n [\"table-min\", \"table header\", \"number\", \"the minimum number of one table row\"]]\n Answer with only the program, without any additional explanation.\n Pre-table text: {pre_text}\n Table: {serialized_table}\n Post-table text: {post_text}\n Question: {question}\n Program:\n ",
+                "input_format": "Presented with a financial report consisting of textual contents and a structured table, given a question, generate the reasoning program in the domain specific language (DSL) that will be executed to get the answer. \nThe DSL consists of mathematical operations and table operations as executable programs. The program consists of a sequence of operations. Each operation takes a list of arguments. \nThere are 6 mathematical operations: add, subtract, multiply, divide, greater, exp, and 4 table aggregation operations table-max, table-min, table-sum, table-average, that apply aggregation operations on table rows. The mathematical operations take arguments of either numbers from the given reports, or a numerical result from a previous step.\nThe table operations take arguments of table row names. We use the special token #n to denote the result from the nth step. \nFor example, in the example \"divide(9413, 20.01), divide(8249, 9.48), subtract(#0, #1)\", the program consists of 3 steps; The first and the second division steps take arguments from the table and the text, respectively, then the third step subtracts the results from the two previous steps.\n Definitions of all operations:\n [[\"Name\", \"Arguments\", \"Output\", \"Description\"],\n [\"add\", \"number1, number2\", \"number\", \"add two numbers: number1 + number2\"],\n [\"subtract\", \"number1, number2\", \"number\", \"subtract two numbers: number1 - number2\"],\n [\"multiply\", \"number1, number2\", \"number\", \"multiply two numbers: number1 * number2\"],\n [\"divide\", \"number1, number2\", \"number\", \"divide two numbers: number1 / number2\"],\n [\"exp\", \"number1, number2\", \"number\", \"exponential: number1 ^ number2\"],\n [\"greater\", \"number1, number2\", \"bool\", \"comparison: number1 > number2\"],\n [\"table-sum\", \"table header\", \"number\", \"the summation of one table row\"],\n [\"table-average\", \"table header\", \"number\", \"the average of one table row\"],\n [\"table-max\", \"table header\", \"number\", \"the maximum number of one table row\"],\n [\"table-min\", \"table header\", \"number\", \"the minimum number of one table row\"]]\n Answer with only the program, without any additional explanation.\n Pre-table text: {pre_text}\n Table: {table}\n Post-table text: {post_text}\n Question: {question}\n Program:\n ",
                "output_format": "{program_re}",
                "postprocessors": []
            }
diff --git a/src/unitxt/catalog/cards/numeric_nlg.json b/src/unitxt/catalog/cards/numeric_nlg.json
index b4506d0a9..59b3b2dee 100644
--- a/src/unitxt/catalog/cards/numeric_nlg.json
+++ b/src/unitxt/catalog/cards/numeric_nlg.json
@@ -19,7 +19,7 @@
            "to_field": "table_out"
        },
        {
-            "__type__": "serialize_table_as_markdown",
+            "__type__": "copy",
            "field": "table_out",
            "to_field": "input_a"
        },
diff --git a/src/unitxt/catalog/cards/scigen.json b/src/unitxt/catalog/cards/scigen.json
index 306d13b6b..509079bd6 100644
--- a/src/unitxt/catalog/cards/scigen.json
+++ b/src/unitxt/catalog/cards/scigen.json
@@ -21,16 +21,7 @@
                "table_column_names",
                "table_content_values"
            ],
-            "to_field": "table"
-        },
-        {
-            "__type__": "serialize_table_as_indexed_row_major",
-            "field_to_field": [
-                [
-                    "table",
-                    "input_a"
-                ]
-            ]
+            "to_field": "input_a"
        },
        {
            "__type__": "rename",
diff --git a/src/unitxt/catalog/cards/tab_fact.json b/src/unitxt/catalog/cards/tab_fact.json
index b574d8cb7..1739bbf4b 100644
--- a/src/unitxt/catalog/cards/tab_fact.json
+++ b/src/unitxt/catalog/cards/tab_fact.json
@@ -9,19 +9,10 @@
        ]
    },
    "preprocess_steps": [
-        {
-            "__type__": "serialize_table_as_indexed_row_major",
-            "field_to_field": [
-                [
-                    "table",
-                    "table_serialized"
-                ]
-            ]
-        },
        {
            "__type__": "rename",
            "field_to_field": {
-                "table_serialized": "text_a",
+                "table": "text_a",
                "statement": "text_b"
            }
        },
diff --git a/src/unitxt/catalog/cards/turl_col_type.json b/src/unitxt/catalog/cards/turl_col_type.json
index 445fceef4..69b195ea7 100644
--- a/src/unitxt/catalog/cards/turl_col_type.json
+++ b/src/unitxt/catalog/cards/turl_col_type.json
@@ -3,29 +3,17 @@
    "loader": {
        "__type__": "load_hf",
        "path": "ibm/turl_table_col_type",
-        "streaming": false,
        "data_classification_policy": [
            "public"
        ]
    },
-    "preprocess_steps": [
-        {
-            "__type__": "serialize_table_as_indexed_row_major",
-            "field_to_field": [
-                [
-                    "table",
-                    "table_lin"
-                ]
-            ]
-        }
-    ],
    "task": {
        "__type__": "task",
        "input_fields": {
            "page_title": "str",
            "section_title": "str",
            "table_caption": "str",
-            "table_lin": "str",
+            "table": "Table",
            "vocab": "List[str]",
            "colname": "str"
        },
"str", - "table_lin": "str", + "table": "Table", "vocab": "List[str]", "colname": "str" }, @@ -44,7 +32,7 @@ "items": [ { "__type__": "input_output_template", - "input_format": "This is a column type annotation task. The goal of this task is to choose the correct types for one selected column of the given input table from the given candidate types. The Wikipedia page, section and table caption (if any) provide important information for choosing the correct column types. \nPage Title: {page_title} \nSection Title: {section_title} \nTable caption: {table_caption} \nTable: \n{table_lin} \nSelected Column: {colname} \nCandidate Types: {vocab} \nOutput only the correct column types for this column (column name: {colname}) from the candidate types.", + "input_format": "This is a column type annotation task. The goal of this task is to choose the correct types for one selected column of the given input table from the given candidate types. The Wikipedia page, section and table caption (if any) provide important information for choosing the correct column types. \nPage Title: {page_title} \nSection Title: {section_title} \nTable caption: {table_caption} \nTable: \n{table} \nSelected Column: {colname} \nCandidate Types: {vocab} \nOutput only the correct column types for this column (column name: {colname}) from the candidate types.", "output_format": "{annotations}", "postprocessors": [ "processors.to_list_by_comma" diff --git a/src/unitxt/catalog/cards/wikitq.json b/src/unitxt/catalog/cards/wikitq.json index 9774b5c8d..2113de3e0 100644 --- a/src/unitxt/catalog/cards/wikitq.json +++ b/src/unitxt/catalog/cards/wikitq.json @@ -16,13 +16,9 @@ } }, { - "__type__": "serialize_table_as_indexed_row_major", - "field_to_field": [ - [ - "table", - "context" - ] - ] + "__type__": "copy", + "field": "table", + "to_field": "context" } ], "task": "tasks.qa.with_context.extractive[metrics=[metrics.f1_strings, metrics.unsorted_list_exact_match]]", diff --git a/src/unitxt/catalog/engines/ibm_wml/llama_3_70b_instruct.json b/src/unitxt/catalog/engines/ibm_wml/llama_3_70b_instruct.json new file mode 100644 index 000000000..471a32eaf --- /dev/null +++ b/src/unitxt/catalog/engines/ibm_wml/llama_3_70b_instruct.json @@ -0,0 +1,6 @@ +{ + "__type__": "wml_inference_engine", + "model_name": "meta-llama/llama-3-70b-instruct", + "max_new_tokens": 2048, + "random_seed": 42 +} diff --git a/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_70b_instruct_generic_engine_template_arena_hard.json b/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_70b_instruct_generic_engine_template_arena_hard.json new file mode 100644 index 000000000..5dcfaec43 --- /dev/null +++ b/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_70b_instruct_generic_engine_template_arena_hard.json @@ -0,0 +1,11 @@ +{ + "__type__": "llm_as_judge", + "inference_model": { + "__type__": "generic_inference_engine", + "default": "engines.ibm_gen_ai.llama_3_70b_instruct" + }, + "template": "templates.response_assessment.pairwise_comparative_rating.arena_hard", + "task": "pairwise_comparative_rating.single_turn", + "format": "formats.llama3_instruct", + "main_score": "llama_3_70b_instruct_generic_engine_template_arena_hard" +} diff --git a/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_70b_instruct_generic_engine_template_arena_hard_with_shuffling.json 
diff --git a/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_70b_instruct_generic_engine_template_arena_hard_with_shuffling.json b/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_70b_instruct_generic_engine_template_arena_hard_with_shuffling.json
new file mode 100644
index 000000000..448f6cb95
--- /dev/null
+++ b/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_70b_instruct_generic_engine_template_arena_hard_with_shuffling.json
@@ -0,0 +1,11 @@
+{
+    "__type__": "llm_as_judge",
+    "inference_model": {
+        "__type__": "generic_inference_engine",
+        "default": "engines.ibm_gen_ai.llama_3_70b_instruct"
+    },
+    "template": "templates.response_assessment.pairwise_comparative_rating.arena_hard_with_shuffling",
+    "task": "pairwise_comparative_rating.single_turn",
+    "format": "formats.llama3_instruct",
+    "main_score": "llama_3_70b_instruct_generic_engine_template_arena_hard_with_shuffling"
+}
diff --git a/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_70b_instruct_ibm_wml_template_arena_hard.json b/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_70b_instruct_ibm_wml_template_arena_hard.json
new file mode 100644
index 000000000..e07e00f40
--- /dev/null
+++ b/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_70b_instruct_ibm_wml_template_arena_hard.json
@@ -0,0 +1,13 @@
+{
+    "__type__": "llm_as_judge",
+    "inference_model": {
+        "__type__": "wml_inference_engine",
+        "model_name": "meta-llama/llama-3-70b-instruct",
+        "max_new_tokens": 2048,
+        "random_seed": 42
+    },
+    "template": "templates.response_assessment.pairwise_comparative_rating.arena_hard",
+    "task": "pairwise_comparative_rating.single_turn",
+    "format": "formats.llama3_instruct",
+    "main_score": "llama_3_70b_instruct_ibm_wml_template_arena_hard"
+}
diff --git a/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_70b_instruct_ibm_wml_template_arena_hard_with_shuffling.json b/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_70b_instruct_ibm_wml_template_arena_hard_with_shuffling.json
new file mode 100644
index 000000000..873afcbf9
--- /dev/null
+++ b/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_70b_instruct_ibm_wml_template_arena_hard_with_shuffling.json
@@ -0,0 +1,13 @@
+{
+    "__type__": "llm_as_judge",
+    "inference_model": {
+        "__type__": "wml_inference_engine",
+        "model_name": "meta-llama/llama-3-70b-instruct",
+        "max_new_tokens": 2048,
+        "random_seed": 42
+    },
+    "template": "templates.response_assessment.pairwise_comparative_rating.arena_hard_with_shuffling",
+    "task": "pairwise_comparative_rating.single_turn",
+    "format": "formats.llama3_instruct",
+    "main_score": "llama_3_70b_instruct_ibm_wml_template_arena_hard_with_shuffling"
+}
diff --git a/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_8b_instruct_generic_engine_template_arena_hard.json b/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_8b_instruct_generic_engine_template_arena_hard.json
new file mode 100644
index 000000000..b00903295
--- /dev/null
+++ b/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_8b_instruct_generic_engine_template_arena_hard.json
@@ -0,0 +1,11 @@
+{
+    "__type__": "llm_as_judge",
+    "inference_model": {
+        "__type__": "generic_inference_engine",
+        "default": "engines.ibm_gen_ai.llama_3_70b_instruct"
+    },
+    "template": "templates.response_assessment.pairwise_comparative_rating.arena_hard",
+    "task": "pairwise_comparative_rating.single_turn",
+    "format": "formats.llama3_instruct",
+    "main_score": "llama_3_8b_instruct_generic_engine_template_arena_hard"
+}
"llama_3_8b_instruct_generic_engine_template_arena_hard" +} diff --git a/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_8b_instruct_generic_engine_template_arena_hard_with_shuffling.json b/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_8b_instruct_generic_engine_template_arena_hard_with_shuffling.json new file mode 100644 index 000000000..901e8e41f --- /dev/null +++ b/src/unitxt/catalog/metrics/llm_as_judge/pairwise_comparative_rating/llama_3_8b_instruct_generic_engine_template_arena_hard_with_shuffling.json @@ -0,0 +1,11 @@ +{ + "__type__": "llm_as_judge", + "inference_model": { + "__type__": "generic_inference_engine", + "default": "engines.ibm_gen_ai.llama_3_70b_instruct" + }, + "template": "templates.response_assessment.pairwise_comparative_rating.arena_hard_with_shuffling", + "task": "pairwise_comparative_rating.single_turn", + "format": "formats.llama3_instruct", + "main_score": "llama_3_8b_instruct_generic_engine_template_arena_hard_with_shuffling" +} diff --git a/src/unitxt/catalog/tasks/classification/multi_class/relation.json b/src/unitxt/catalog/tasks/classification/multi_class/relation.json index 24e9ffe3c..55d559132 100644 --- a/src/unitxt/catalog/tasks/classification/multi_class/relation.json +++ b/src/unitxt/catalog/tasks/classification/multi_class/relation.json @@ -1,7 +1,7 @@ { "__type__": "task", "input_fields": { - "text_a": "str", + "text_a": "Union[Text, Image, Audio, Table, Dialog]", "text_a_type": "str", "text_b": "str", "text_b_type": "str", diff --git a/src/unitxt/catalog/tasks/generation/from_pair.json b/src/unitxt/catalog/tasks/generation/from_pair.json index d3f59b95d..8474b002b 100644 --- a/src/unitxt/catalog/tasks/generation/from_pair.json +++ b/src/unitxt/catalog/tasks/generation/from_pair.json @@ -1,9 +1,9 @@ { "__type__": "task", "input_fields": { - "input_a": "str", + "input_a": "Union[Text, Image, Audio, Table, Dialog]", "type_of_input_a": "str", - "input_b": "str", + "input_b": "Union[Text, Image, Audio, Table, Dialog]", "type_of_input_b": "str", "type_of_output": "str" }, diff --git a/src/unitxt/catalog/tasks/qa/with_context/extractive.json b/src/unitxt/catalog/tasks/qa/with_context/extractive.json index 456cdf461..b5fd216ab 100644 --- a/src/unitxt/catalog/tasks/qa/with_context/extractive.json +++ b/src/unitxt/catalog/tasks/qa/with_context/extractive.json @@ -1,7 +1,7 @@ { "__type__": "task", "input_fields": { - "context": "Text", + "context": "Union[Text, Table, Dialog]", "context_type": "str", "question": "str" }, diff --git a/src/unitxt/catalog/templates/classification/multi_class/blue_bench.json b/src/unitxt/catalog/templates/classification/multi_class/blue_bench.json new file mode 100644 index 000000000..beb2f428e --- /dev/null +++ b/src/unitxt/catalog/templates/classification/multi_class/blue_bench.json @@ -0,0 +1,8 @@ +{ + "__type__": "templates_list", + "items": [ + "templates.classification.multi_class.instruction", + "templates.classification.multi_class.instruct_question_selects", + "templates.classification.multi_class.instruct_question_select_i_think" + ] +} diff --git a/src/unitxt/catalog/templates/completion/multiple_choice/blue_bench.json b/src/unitxt/catalog/templates/completion/multiple_choice/blue_bench.json new file mode 100644 index 000000000..b393e80e3 --- /dev/null +++ b/src/unitxt/catalog/templates/completion/multiple_choice/blue_bench.json @@ -0,0 +1,8 @@ +{ + "__type__": "templates_list", + "items": [ + "templates.completion.multiple_choice.simple", + 
"templates.completion.multiple_choice.enumerated", + "templates.completion.multiple_choice.standard" + ] +} diff --git a/src/unitxt/catalog/templates/qa/multiple_choice/open/blue_bench.json b/src/unitxt/catalog/templates/qa/multiple_choice/open/blue_bench.json new file mode 100644 index 000000000..d66873a22 --- /dev/null +++ b/src/unitxt/catalog/templates/qa/multiple_choice/open/blue_bench.json @@ -0,0 +1,8 @@ +{ + "__type__": "templates_list", + "items": [ + "templates.qa.multiple_choice.open.helm", + "templates.qa.multiple_choice.open.lm_eval_harness", + "templates.qa.multiple_choice.open.mmlu" + ] +} diff --git a/src/unitxt/catalog/templates/qa/multiple_choice/with_context/blue_bench.json b/src/unitxt/catalog/templates/qa/multiple_choice/with_context/blue_bench.json new file mode 100644 index 000000000..db6d0f224 --- /dev/null +++ b/src/unitxt/catalog/templates/qa/multiple_choice/with_context/blue_bench.json @@ -0,0 +1,8 @@ +{ + "__type__": "templates_list", + "items": [ + "templates.qa.multiple_choice.with_context.lm_eval_harness", + "templates.qa.multiple_choice.with_context.no_intro.helm", + "templates.qa.multiple_choice.with_context.no_intro.mmlu" + ] +} diff --git a/src/unitxt/catalog/templates/qa/multiple_choice/with_topic/blue_bench.json b/src/unitxt/catalog/templates/qa/multiple_choice/with_topic/blue_bench.json new file mode 100644 index 000000000..7e882e7f6 --- /dev/null +++ b/src/unitxt/catalog/templates/qa/multiple_choice/with_topic/blue_bench.json @@ -0,0 +1,8 @@ +{ + "__type__": "templates_list", + "items": [ + "templates.qa.multiple_choice.with_topic.mmlu", + "templates.qa.multiple_choice.with_topic.helm", + "templates.qa.multiple_choice.with_topic.lm_eval_harness" + ] +} diff --git a/src/unitxt/catalog/templates/rag/response_generation/blue_bench.json b/src/unitxt/catalog/templates/rag/response_generation/blue_bench.json new file mode 100644 index 000000000..011e8c792 --- /dev/null +++ b/src/unitxt/catalog/templates/rag/response_generation/blue_bench.json @@ -0,0 +1,8 @@ +{ + "__type__": "templates_list", + "items": [ + "templates.rag.response_generation.please_respond", + "templates.rag.response_generation.please_respond_chat", + "templates.rag.response_generation.answer_based_on_context" + ] +} diff --git a/src/unitxt/catalog/templates/summarization/abstractive/blue_bench.json b/src/unitxt/catalog/templates/summarization/abstractive/blue_bench.json new file mode 100644 index 000000000..2864f97aa --- /dev/null +++ b/src/unitxt/catalog/templates/summarization/abstractive/blue_bench.json @@ -0,0 +1,10 @@ +{ + "__type__": "templates_list", + "items": [ + "templates.summarization.abstractive.instruct_full", + "templates.summarization.abstractive.instruct_one_sentence", + "templates.summarization.abstractive.instruct_passive", + "templates.summarization.abstractive.instruct_write_succinct", + "templates.summarization.abstractive.instruct_tldr" + ] +} diff --git a/src/unitxt/catalog/templates/translation/directed/blue_bench.json b/src/unitxt/catalog/templates/translation/directed/blue_bench.json new file mode 100644 index 000000000..df9958090 --- /dev/null +++ b/src/unitxt/catalog/templates/translation/directed/blue_bench.json @@ -0,0 +1,8 @@ +{ + "__type__": "templates_list", + "items": [ + "templates.translation.directed.simple", + "templates.translation.directed.formal", + "templates.translation.directed.casual" + ] +} diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 79eae04a6..aeb3a0a96 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py 
@@ -209,12 +209,20 @@ class IbmGenAiInferenceEngineParams(Artifact):
 
 class GenericInferenceEngine(InferenceEngine):
-    default: str = None
+    default: Optional[str] = None
 
     def prepare_engine(self):
-        if "inference_engine" in os.environ:
-            engine_reference = os.environ["inference_engine"]
+        if "UNITXT_INFERENCE_ENGINE" in os.environ:
+            engine_reference = os.environ["UNITXT_INFERENCE_ENGINE"]
         else:
+            assert self.default is not None, (
+                "GenericInferenceEngine could not be initialized"
+                '\nThis is because the "UNITXT_INFERENCE_ENGINE" environment variable is not set and no default engine was provided.'
+                "\nFor example, you can fix it by adding"
+                "\nexport UNITXT_INFERENCE_ENGINE=engines.ibm_gen_ai.llama_3_70b_instruct"
+                "\nto your ~/.bashrc,"
+                "\nor by passing the required engine in the default argument."
+            )
             engine_reference = self.default
         self.engine, _ = fetch_artifact(engine_reference)
 
@@ -559,6 +567,7 @@ class WMLInferenceEngine(
         parameters (WMLInferenceEngineParams, optional): Instance of WMLInferenceEngineParams
             which defines inference parameters and their values. Deprecated attribute, please
             pass respective parameters directly to the WMLInferenceEngine class instead.
+        concurrency_limit (int): number of requests that will be sent in parallel, max is 10.
 
     Examples:
         from .api import load_dataset
@@ -592,7 +601,7 @@ class WMLInferenceEngine(
     }
     data_classification_policy = ["public", "proprietary"]
     parameters: Optional[WMLInferenceEngineParams] = None
-
+    concurrency_limit: int = 10
     _client: Any = InternalField(default=None, name="WML client")
 
     def verify(self):
@@ -663,10 +672,19 @@ def _infer(self, dataset):
             api_client=self._client,
         )
 
-        return model.generate_text(
-            prompt=dataset["source"],
-            params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
-        )
+        # the class was previously used with a dataset that is a single instance;
+        # record that before wrapping, so the return shape can match the input shape
+        is_single_instance = not isinstance(dataset, list)
+        dataset = [dataset] if is_single_instance else dataset
+
+        result = [
+            model.generate_text(
+                prompt=instance["source"],
+                params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
+            )
+            for instance in dataset
+        ]
+
+        return result[0] if is_single_instance else result
 
 
 class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
diff --git a/src/unitxt/llm_as_judge.py b/src/unitxt/llm_as_judge.py
index 2447c56c4..ef7d31e7a 100644
--- a/src/unitxt/llm_as_judge.py
+++ b/src/unitxt/llm_as_judge.py
@@ -181,9 +181,17 @@ def compute(
         results = []
         for instance in outputs:
             if self.task == "pairwise_comparative_rating.single_turn":
-                is_model_b_the_baseline = (
-                    instance["task_data"]["model_b"] == "baseline_model"
+                import json
+
+                # task_data sometimes arrives as a JSON string rather than a dict;
+                # normalize it here
+                task_data = (
+                    json.loads(instance["task_data"])
+                    if isinstance(instance["task_data"], str)
+                    else instance["task_data"]
                 )
+
+                is_model_b_the_baseline = task_data["model_b"] == "baseline_model"
                 if is_model_b_the_baseline:
                     model_a_preference_score = instance["prediction"]
                 else:
result["score_ci_low"] = ci.low result["score_ci_high"] = ci.high return result @@ -1183,7 +1183,11 @@ def compute_instance_scores( return instances def get_group_scores( - self, instances: List[dict], score_names: List[str], group_aggregation_func + self, + instances: List[dict], + score_names: List[str], + group_aggregation_func, + prepend_score_prefix: bool = True, ): """Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group. @@ -1193,6 +1197,8 @@ def get_group_scores( group_aggregation_func: Callable aggregation function accepting a list of numeric scores; or, if self.subgroup_column is not None, a dict of subgroup types scores by subgroup_column value. callable function returns a single score for the group + prepend_score_prefix: if True - prepend the score_prefix to the score names in the returned dicts. Set to False + if down the stream such a prepending is expected. Returns: List of dicts, each corresponding to a group of instances (defined by 'group_id'), @@ -1222,7 +1228,9 @@ def get_group_scores( ) for score_name in score_names: group_to_instance_scores[group_key][score_name][subgroup_type].append( - instance["score"]["instance"][score_name] + instance["score"]["instance"][ + (self.score_prefix if prepend_score_prefix else "") + score_name + ] ) # if group_aggregation_func expects a subgroup-types score dict, pass it; otherwise pass the default type list of scores @@ -1230,7 +1238,8 @@ def get_group_scores( { "score": { "instance": { - score_name: group_aggregation_func( + (self.score_prefix if prepend_score_prefix else "") + + score_name: group_aggregation_func( score_dict if uses_subgroups else score_dict[default_subgroup_name] @@ -1268,7 +1277,7 @@ def aggregation_function( group_aggregation_func=group_aggregation_func, ): group_scores = self.get_group_scores( - instances, [field_name], group_aggregation_func + instances, [field_name], group_aggregation_func, False ) return nan_mean( [group["score"]["instance"][field_name] for group in group_scores] diff --git a/src/unitxt/serializers.py b/src/unitxt/serializers.py index d73590c89..ab3a01290 100644 --- a/src/unitxt/serializers.py +++ b/src/unitxt/serializers.py @@ -1,11 +1,12 @@ import csv import io from abc import abstractmethod -from typing import Any, Dict, Union +from typing import Any, Dict, List, Union +from .dataclass import AbstractField, Field from .operators import InstanceFieldOperator -from .type_utils import isoftype -from .types import Dialog, Image, Number, Table, Text +from .type_utils import isoftype, to_type_string +from .types import Dialog, Image, Number, Table class Serializer(InstanceFieldOperator): @@ -22,6 +23,17 @@ def serialize(self, value: Any, instance: Dict[str, Any]) -> str: return str(value) +class SingleTypeSerializer(InstanceFieldOperator): + serialized_type: object = AbstractField() + + def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str: + if not isoftype(value, self.serialized_type): + raise ValueError( + f"SingleTypeSerializer for type {self.serialized_type} should get this type. 
diff --git a/src/unitxt/serializers.py b/src/unitxt/serializers.py
index d73590c89..ab3a01290 100644
--- a/src/unitxt/serializers.py
+++ b/src/unitxt/serializers.py
@@ -1,11 +1,12 @@
 import csv
 import io
 from abc import abstractmethod
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union
 
+from .dataclass import AbstractField, Field
 from .operators import InstanceFieldOperator
-from .type_utils import isoftype
-from .types import Dialog, Image, Number, Table, Text
+from .type_utils import isoftype, to_type_string
+from .types import Dialog, Image, Number, Table
 
 
 class Serializer(InstanceFieldOperator):
@@ -22,6 +23,17 @@ def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
         return str(value)
 
 
+class SingleTypeSerializer(InstanceFieldOperator):
+    serialized_type: object = AbstractField()
+
+    def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str:
+        if not isoftype(value, self.serialized_type):
+            raise ValueError(
+                f"SingleTypeSerializer for type {self.serialized_type} should get this type, got {to_type_string(value)}"
+            )
+        return self.serialize(value, instance)
+
+
 class DefaultListSerializer(Serializer):
     def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
         if isinstance(value, list):
@@ -29,13 +41,24 @@ def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
         return str(value)
 
 
-class DialogSerializer(Serializer):
+class ListSerializer(SingleTypeSerializer):
+    serialized_type = list
+
+    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
+        return ", ".join(str(item) for item in value)
+
+
+class DialogSerializer(SingleTypeSerializer):
+    serialized_type = Dialog
+
     def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str:
         # Convert the Dialog into a string representation, typically combining roles and content
         return "\n".join(f"{turn['role']}: {turn['content']}" for turn in value)
 
 
-class NumberSerializer(Serializer):
+class NumberSerializer(SingleTypeSerializer):
+    serialized_type = Number
+
     def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
         # Check if the value is an integer or a float
         if isinstance(value, int):
@@ -47,6 +70,7 @@ def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
 
 
 class NumberQuantizingSerializer(NumberSerializer):
+    serialized_type = Number
     quantum: Union[float, int] = 0.1
 
     def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
@@ -58,7 +82,9 @@ def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
             raise ValueError("Unsupported type for NumberSerializer")
 
 
-class TableSerializer(Serializer):
+class TableSerializer(SingleTypeSerializer):
+    serialized_type = Table
+
     def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
         output = io.StringIO()
         writer = csv.writer(output, lineterminator="\n")
@@ -71,7 +97,9 @@ def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
         return output.getvalue().strip()
 
 
-class ImageSerializer(Serializer):
+class ImageSerializer(SingleTypeSerializer):
+    serialized_type = Image
+
     def serialize(self, value: Image, instance: Dict[str, Any]) -> str:
         if "media" not in instance:
             instance["media"] = {}
@@ -83,31 +111,32 @@ def serialize(self, value: Image, instance: Dict[str, Any]) -> str:
         return value["image"]
 
 
-class DynamicSerializer(Serializer):
-    image: Serializer = ImageSerializer()
-    number: Serializer = DefaultSerializer()
-    table: Serializer = TableSerializer()
-    dialog: Serializer = DialogSerializer()
-    text: Serializer = DefaultSerializer()
-    list: Serializer = DefaultSerializer()
-
-    def serialize(self, value: Any, instance: Dict[str, Any]) -> Any:
-        if isoftype(value, Image):
-            return self.image.serialize(value, instance)
+class MultiTypeSerializer(Serializer):
+    serializers: List[SingleTypeSerializer] = Field(
+        default_factory=lambda: [
+            ImageSerializer(),
+            TableSerializer(),
+            DialogSerializer(),
+        ]
+    )
 
-        if isoftype(value, Table):
-            return self.table.serialize(value, instance)
+    def verify(self):
+        super().verify()
+        self._verify_serializers(self.serializers)
 
-        if isoftype(value, Dialog) and len(value) > 0:
-            return self.dialog.serialize(value, instance)
+    def _verify_serializers(self, serializers):
+        if not isoftype(serializers, List[SingleTypeSerializer]):
+            raise ValueError(
+                "MultiTypeSerializer requires the list of serializers to be List[SingleTypeSerializer]."
+            )
 
-        if isoftype(value, Text):
-            return self.text.serialize(value, instance)
+    def add_serializers(self, serializers: List[SingleTypeSerializer]):
+        self._verify_serializers(serializers)
+        self.serializers = serializers + self.serializers
 
-        if isoftype(value, Number):
-            return self.number.serialize(value, instance)
-
-        if isinstance(value, list):
-            return self.list.serialize(value, instance)
+    def serialize(self, value: Any, instance: Dict[str, Any]) -> Any:
+        for serializer in self.serializers:
+            if isoftype(value, serializer.serialized_type):
+                return serializer.serialize(value, instance)
 
         return str(value)
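Under this scheme, supporting a new rendering is a matter of one more `SingleTypeSerializer`. A minimal sketch — the class name and formatting choice here are hypothetical, not part of this diff:

    from typing import Any, Dict

    from unitxt.serializers import SingleTypeSerializer
    from unitxt.types import Dialog

    class CompactDialogSerializer(SingleTypeSerializer):
        serialized_type = Dialog

        def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str:
            # Render all turns on one line, separated by " | " instead of newlines.
            return " | ".join(f"{turn['role']}: {turn['content']}" for turn in value)

Because `MultiTypeSerializer.add_serializers` prepends, a serializer registered this way takes precedence over the built-in one for the same type.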
diff --git a/src/unitxt/standard.py b/src/unitxt/standard.py
index 20204380b..2dad7fd2b 100644
--- a/src/unitxt/standard.py
+++ b/src/unitxt/standard.py
@@ -15,6 +15,7 @@
 from .operators import Set, StreamRefiner
 from .recipe import Recipe
 from .schema import Finalize
+from .serializers import SingleTypeSerializer
 from .settings_utils import get_constants
 from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit
 from .stream import MultiStream
@@ -38,6 +39,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
     template: Union[Template, List[Template], TemplatesList] = None
     system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
     format: Format = Field(default_factory=SystemFormat)
+    serializer: Union[SingleTypeSerializer, List[SingleTypeSerializer]] = None
 
     # Additional parameters
     template_card_index: int = NonPositionalField(default=None)
@@ -146,6 +148,11 @@ def verify(self):
         else:
             self.verify_template(self.template)
 
+        if self.serializer is not None:
+            if not isinstance(self.serializer, list):
+                self.serializer = [self.serializer]
+            self.template.serializer.add_serializers(self.serializer)
+
     def prepare_refiners(self):
         self.train_refiner.max_instances = self.max_train_instances
         self.train_refiner.apply_to_streams = ["train"]
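This hook is what lets the table cards above stay serialization-free: a custom serializer can be attached per run. A sketch of the intended usage, assuming a standard recipe forwards its arguments to `BaseRecipe` as shown (the exact entry point may differ):

    from unitxt.standard import StandardRecipe
    from unitxt.struct_data_operators import SerializeTableAsMarkdown

    # Hypothetical: render wikitq tables as markdown instead of the default CSV.
    recipe = StandardRecipe(
        card="cards.wikitq",
        template_card_index=0,
        serializer=SerializeTableAsMarkdown(),
    )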
This abstract class defines the structure of a typical table serializer that any concrete implementation should follow.
    """

+    seed: int = 0
+    shuffle_rows: bool = False
+    shuffle_columns: bool = False
+
+    def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
+        value = deepcopy(value)
+        if self.shuffle_columns:
+            value = shuffle_columns(table=value, seed=self.seed)
+
+        if self.shuffle_rows:
+            value = shuffle_rows(table=value, seed=self.seed)
+
+        return self.serialize_table(value)
+
     # main method to serialize a table
     @abstractmethod
     def serialize_table(self, table_content: Dict) -> str:
@@ -60,10 +107,6 @@ class SerializeTableAsIndexedRowMajor(SerializeTable):
     Format: col : col1 | col2 | col 3 row 1 : val1 | val2 | val3 | val4 row 2 : val1 | ...
     """

-    def process_value(self, table: Any) -> Any:
-        table_input = deepcopy(table)
-        return self.serialize_table(table_content=table_input)
-
     # main method that processes a table
     # table_content must be in the prescribed input format
     def serialize_table(self, table_content: Dict) -> str:
@@ -111,10 +154,6 @@ class SerializeTableAsMarkdown(SerializeTable):
     ...
     """

-    def process_value(self, table: Any) -> Any:
-        table_input = deepcopy(table)
-        return self.serialize_table(table_content=table_input)
-
     # main method that serializes a table.
     # table_content must be in the prescribed input format.
     def serialize_table(self, table_content: Dict) -> str:
@@ -159,10 +198,6 @@ class SerializeTableAsDFLoader(SerializeTable):
     index=[0,1,2])
     """

-    def process_value(self, table: Any) -> Any:
-        table_input = deepcopy(table)
-        return self.serialize_table(table_content=table_input)
-
     # main method that serializes a table.
     # table_content must be in the prescribed input format.
     def serialize_table(self, table_content: Dict) -> str:
@@ -199,10 +234,6 @@ class SerializeTableAsJson(SerializeTable):
     }
     """

-    def process_value(self, table: Any) -> Any:
-        table_input = deepcopy(table)
-        return self.serialize_table(table_content=table_input)
-
     # main method that serializes a table.
     # table_content must be in the prescribed input format.
def serialize_table(self, table_content: Dict) -> str: @@ -493,20 +524,7 @@ class ShuffleTableRows(FieldOperator): def process_value(self, table: Any) -> Any: table_input = deepcopy(table) - return self.shuffle_rows(table_content=table_input) - - # shuffles table rows randomly - def shuffle_rows(self, table_content: Dict) -> str: - # extract header & rows from the dictionary - header = table_content.get("header", []) - rows = table_content.get("rows", []) - assert header and rows, "Incorrect input table format" - - # shuffle rows - random.shuffle(rows) - table_content["rows"] = rows - - return table_content + return shuffle_rows(table_input) class ShuffleTableColumns(FieldOperator): @@ -527,27 +545,7 @@ class ShuffleTableColumns(FieldOperator): def process_value(self, table: Any) -> Any: table_input = deepcopy(table) - return self.shuffle_columns(table_content=table_input) - - # shuffles table columns randomly - def shuffle_columns(self, table_content: Dict) -> str: - # extract header & rows from the dictionary - header = table_content.get("header", []) - rows = table_content.get("rows", []) - assert header and rows, "Incorrect input table format" - - # shuffle the indices first - indices = list(range(len(header))) - random.shuffle(indices) # - - # shuffle the header & rows based on that indices - shuffled_header = [header[i] for i in indices] - shuffled_rows = [[row[i] for i in indices] for row in rows] - - table_content["header"] = shuffled_header - table_content["rows"] = shuffled_rows - - return table_content + return shuffle_columns(table_input) class LoadJson(FieldOperator): diff --git a/src/unitxt/templates.py b/src/unitxt/templates.py index 4fc557aab..e68023077 100644 --- a/src/unitxt/templates.py +++ b/src/unitxt/templates.py @@ -11,10 +11,13 @@ from .operator import InstanceOperator from .random_utils import new_random_generator from .serializers import ( - DefaultListSerializer, - DynamicSerializer, + DialogSerializer, + ImageSerializer, + ListSerializer, + MultiTypeSerializer, NumberQuantizingSerializer, Serializer, + TableSerializer, ) from .settings_utils import get_constants from .type_utils import isoftype @@ -53,7 +56,14 @@ class Template(InstanceOperator): target_prefix: str = NonPositionalField(default="") title_fields: List[str] = NonPositionalField(default_factory=list) serializer: Serializer = NonPositionalField( - default_factory=lambda: DynamicSerializer(list=DefaultListSerializer()) + default_factory=lambda: MultiTypeSerializer( + serializers=[ + ImageSerializer(), + TableSerializer(), + DialogSerializer(), + ListSerializer(), + ] + ) ) def input_fields_to_instruction_and_target_prefix(self, input_fields): @@ -702,14 +712,16 @@ def reference_fields_to_target_and_references( class OutputQuantizingTemplate(InputOutputTemplate): - serializer: DynamicSerializer = NonPositionalField( - default_factory=DynamicSerializer + serializer: MultiTypeSerializer = NonPositionalField( + default_factory=MultiTypeSerializer ) quantum: Union[float, int] = 0.1 def prepare(self): super().prepare() - self.serializer.number = NumberQuantizingSerializer(quantum=self.quantum) + self.serializer.add_serializers( + [NumberQuantizingSerializer(quantum=self.quantum)] + ) class MultiLabelTemplate(InputOutputTemplate): @@ -737,7 +749,7 @@ def preprocess_reference_fields( class MultiReferenceTemplate(InputOutputTemplate): references_field: str = "references" random_reference: bool = False - serializer: Serializer = NonPositionalField(default_factory=DynamicSerializer) + serializer: Serializer = 
NonPositionalField(default_factory=MultiTypeSerializer) def serialize( self, data: Dict[str, Any], instance: Dict[str, Any] diff --git a/src/unitxt/version.py b/src/unitxt/version.py index 3be18c317..667df30e2 100644 --- a/src/unitxt/version.py +++ b/src/unitxt/version.py @@ -1 +1 @@ -version = "1.12.4" +version = "1.13.0" diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py index cb1b44bf3..b8b261d7b 100644 --- a/tests/library/test_metrics.py +++ b/tests/library/test_metrics.py @@ -1062,17 +1062,19 @@ def test_grouped_instance_metrics(self): 0.08060156608173413, ] for metric, target in zip(accuracy_metrics, global_targets): - outputs = apply_metric( - metric=metric, - predictions=GROUPED_INSTANCE_PREDICTIONS, - references=GROUPED_INSTANCE_REFERENCES, - task_data=GROUPED_INSTANCE_ADDL_INPUTS, - ) - self.assertAlmostEqual( - target, - outputs[0]["score"]["global"]["score"], - msg=f"metric {metric.__class__.__name__} output {outputs[0]['score']['global']['score_name']} does not equal the expected value {target}", - ) + for score_prefix in ["my_", ""]: + metric.score_prefix = score_prefix + outputs = apply_metric( + metric=metric, + predictions=GROUPED_INSTANCE_PREDICTIONS, + references=GROUPED_INSTANCE_REFERENCES, + task_data=GROUPED_INSTANCE_ADDL_INPUTS, + ) + self.assertAlmostEqual( + target, + outputs[0]["score"]["global"]["score"], + msg=f"metric {metric.__class__.__name__} output {outputs[0]['score']['global']['score_name']} does not equal the expected value {target}", + ) def test_grouped_instance_metric_errors(self): """Test certain value and assertion error raises for grouped instance metrics (with group_mean reduction).""" @@ -1457,24 +1459,26 @@ def test_grouped_instance_metric_confidence_interval(self): ) # pass global dict because there are additional fields other than the main score - self._test_grouped_instance_confidence_interval( - metric=GroupMeanTokenOverlap(), - expected_global_result={ - "group_mean_recall": 0.525, - "group_mean_f1": 0.5083333333333333, - "score": 0.5083333333333333, - "score_name": "group_mean_f1", - "group_mean_precision": 0.5, - "group_mean_recall_ci_low": 0.25, - "group_mean_recall_ci_high": 0.7083333333333334, - "group_mean_f1_ci_low": 0.22302503471948287, - "group_mean_f1_ci_high": 0.6805555555555555, - "score_ci_low": 0.22302503471948287, - "score_ci_high": 0.6805555555555555, - "group_mean_precision_ci_low": 0.2095091529536007, - "group_mean_precision_ci_high": 0.6666666666666666, - }, - ) + for score_prefix in ["my_", ""]: + self._test_grouped_instance_confidence_interval( + metric=GroupMeanTokenOverlap(), + expected_global_result={ + f"group_mean_{score_prefix}recall": 0.525, + f"group_mean_{score_prefix}f1": 0.5083333333333333, + "score": 0.5083333333333333, + "score_name": f"group_mean_{score_prefix}f1", + f"group_mean_{score_prefix}precision": 0.5, + f"group_mean_{score_prefix}recall_ci_low": 0.25, + f"group_mean_{score_prefix}recall_ci_high": 0.7083333333333334, + f"group_mean_{score_prefix}f1_ci_low": 0.22302503471948287, + f"group_mean_{score_prefix}f1_ci_high": 0.6805555555555555, + "score_ci_low": 0.22302503471948287, + "score_ci_high": 0.6805555555555555, + f"group_mean_{score_prefix}precision_ci_low": 0.2095091529536007, + f"group_mean_{score_prefix}precision_ci_high": 0.6666666666666666, + }, + input_score_prefixes=[score_prefix], + ) def _test_grouped_instance_confidence_interval( self, @@ -1482,32 +1486,41 @@ def _test_grouped_instance_confidence_interval( expected_ci_low=0.0, expected_ci_high=1.0, 
expected_global_result=None, + input_score_prefixes=None, ): """Test the calculation of confidence intervals for a given metric with group_mean reduction.""" - outputs = apply_metric( - metric=metric, - predictions=GROUPED_INSTANCE_PREDICTIONS, - references=GROUPED_INSTANCE_REFERENCES, - task_data=GROUPED_INSTANCE_ADDL_INPUTS, - ) - # get first element of reduction_map values - reduction_params = next(iter(metric.reduction_map.values())) - prefix = "fixed_group" if reduction_params["agg_func"][2] else "group" - group_score_name = "_".join( - [ - prefix, - metric.reduction_map["group_mean"]["agg_func"][0], - metric.main_score, - ] - ) + input_expected_global_result_is_none = expected_global_result is None + # to remember between score_prefixes - if expected_global_result is None: - expected_global_result = { - f"{group_score_name}_ci_low": expected_ci_low, - f"{group_score_name}_ci_high": expected_ci_high, - "score_ci_low": expected_ci_low, - "score_ci_high": expected_ci_high, - } + for score_prefix in ( + ["my_", ""] if input_score_prefixes is None else input_score_prefixes + ): + metric.score_prefix = score_prefix + outputs = apply_metric( + metric=metric, + predictions=GROUPED_INSTANCE_PREDICTIONS, + references=GROUPED_INSTANCE_REFERENCES, + task_data=GROUPED_INSTANCE_ADDL_INPUTS, + ) + # get first element of reduction_map values + reduction_params = next(iter(metric.reduction_map.values())) + prefix = "fixed_group" if reduction_params["agg_func"][2] else "group" + group_score_name = "_".join( + [ + prefix, + metric.reduction_map["group_mean"]["agg_func"][0], + score_prefix, + metric.main_score, + ] + ).replace("__", "_") # for the case of empty score_prefix + + if input_expected_global_result_is_none: + expected_global_result = { + f"{group_score_name}_ci_low": expected_ci_low, + f"{group_score_name}_ci_high": expected_ci_high, + "score_ci_low": expected_ci_low, + "score_ci_high": expected_ci_high, + } global_result = outputs[0]["score"]["global"].copy() logger.info(global_result) diff --git a/tests/library/test_recipe.py b/tests/library/test_recipe.py index 841d87424..871013a4d 100644 --- a/tests/library/test_recipe.py +++ b/tests/library/test_recipe.py @@ -2,13 +2,19 @@ import copy import json import re +from typing import Any, Dict from unitxt import dataset_file from unitxt.artifact import fetch_artifact +from unitxt.card import TaskCard from unitxt.formats import SystemFormat +from unitxt.loaders import LoadFromDictionary +from unitxt.serializers import SingleTypeSerializer, TableSerializer from unitxt.standard import StandardRecipe, StandardRecipeWithIndexes +from unitxt.task import Task from unitxt.templates import InputOutputTemplate, TemplatesList from unitxt.text_utils import print_dict +from unitxt.types import Table from tests.utils import UnitxtTestCase @@ -724,3 +730,57 @@ def test_standard_recipe_with_a_missing_sampler(self): str(e.exception), "Unexpected None value for card.sampler. 
To use num_demos > 0, please set a sampler on the TaskCard.", ) + + def test_set_serializer_from_recipe(self): + instances = [ + { + "table": { + "header": ["col1", "col2"], + "rows": [["val1", "val2"], ["val3"], ["val4"]], + }, + "answer": "2", + }, + ] + + class MyTableSerializer(SingleTypeSerializer): + serialized_type = Table + + def serialize(self, value: Table, instance: Dict[str, Any]) -> str: + return str(value) + + task = Task( + input_fields={"table": Table}, + reference_fields={"answer": str}, + prediction_type=str, + metrics=["metrics.accuracy"], + ) + + template = InputOutputTemplate( + input_format="Solve: {table}\nAnswer: ", + output_format="{answer}", + postprocessors=[], + ) + + card = TaskCard( + loader=LoadFromDictionary(data={"train": instances}), + preprocess_steps=[], + task=task, + ) + + recipe = StandardRecipe( + card=card, + template=template, + serializer=TableSerializer(), + ) + result = next(iter(recipe()["train"]))["source"] + target = "Solve: col1,col2\nval1,val2\nval3\nval4\nAnswer: \n" + self.assertEqual(result, target) + + recipe = StandardRecipe( + card=card, + template=template, + serializer=MyTableSerializer(), + ) + result = next(iter(recipe()["train"]))["source"] + target = "Solve: {'header': ['col1', 'col2'], 'rows': [['val1', 'val2'], ['val3'], ['val4']]}\nAnswer: \n" + self.assertEqual(result, target) diff --git a/tests/library/test_serializers.py b/tests/library/test_serializers.py index 601f6fbda..2098d4711 100644 --- a/tests/library/test_serializers.py +++ b/tests/library/test_serializers.py @@ -1,7 +1,7 @@ from unitxt.serializers import ( DefaultSerializer, DialogSerializer, - DynamicSerializer, + MultiTypeSerializer, NumberQuantizingSerializer, NumberSerializer, TableSerializer, @@ -17,9 +17,9 @@ def setUp(self): self.dialog_serializer = DialogSerializer() self.number_serializer = NumberSerializer() self.table_serializer = TableSerializer() - self.custom_serializer = DynamicSerializer() - self.custom_serializer_with_number = DynamicSerializer( - number=NumberSerializer() + self.custom_serializer = MultiTypeSerializer() + self.custom_serializer_with_number = MultiTypeSerializer( + serializers=[NumberSerializer()] ) self.number_quantizing_serializer = NumberQuantizingSerializer(quantum=0.2) diff --git a/tests/library/test_struct_data_operators.py b/tests/library/test_struct_data_operators.py index 0b307a510..c7295c52a 100644 --- a/tests/library/test_struct_data_operators.py +++ b/tests/library/test_struct_data_operators.py @@ -58,6 +58,40 @@ def test_serializetable_markdown(self): tester=self, ) + def test_serializetable_markdown_with_shuffle(self): + inputs = [ + { + "table": { + "header": ["name", "age"], + "rows": [["Alex", "26"], ["Raj", "34"], ["Donald", "39"]], + } + } + ] + + serialized_str = "|age|name|\n|---|---|\n|39|Donald|\n|34|Raj|\n|26|Alex|" + + targets = [ + { + "table": { + "header": ["name", "age"], + "rows": [["Alex", "26"], ["Raj", "34"], ["Donald", "39"]], + }, + "serialized_table": serialized_str, + } + ] + + check_operator( + operator=SerializeTableAsMarkdown( + field_to_field={"table": "serialized_table"}, + shuffle_columns=True, + shuffle_rows=True, + seed=1, + ), + inputs=inputs, + targets=targets, + tester=self, + ) + def test_serializetable_indexedrowmajor(self): inputs = [ { @@ -397,19 +431,11 @@ def test_shuffle_table_rows(self): { "table": { "header": ["name", "age"], - "rows": [ - ["Donald", 39], - ["Raj", 34], - ["Alex", 21], - ], + "rows": [["Raj", 34], ["Alex", 21], ["Donald", 39]], } } ] - import random - - 
random.seed(123) - check_operator( operator=ShuffleTableRows(field="table"), inputs=inputs, diff --git a/utils/prepare_all_artifacts.py b/utils/prepare_all_artifacts.py index 5b20abad6..efb561a75 100644 --- a/utils/prepare_all_artifacts.py +++ b/utils/prepare_all_artifacts.py @@ -28,137 +28,158 @@ def import_module_from_file(file_path): return module -def prepare_artifacts_for_prepare_files(prepare_files): - failed_prepare_files = [] - prepare_exceptions = [] - for i, file in enumerate(prepare_files): - logger.info("*" * 100) - logger.info(f"* {i}/{len(prepare_files)}: {file}") - logger.info("*") - try: - import_module_from_file(file) - except Exception as e: - logger.info(f"Failed to prepare: {file}") - failed_prepare_files.append(file) - prepare_exceptions.append(e) - - return failed_prepare_files, prepare_exceptions - +# flake8: noqa: C901 +def main(): + catalog_dir = constants.catalog_dir + catalog_back_dir = catalog_dir + "_back" -def prepare_all_catalog_artifacts(catalog_dir): os.environ["UNITXT_USE_ONLY_LOCAL_CATALOGS"] = "True" os.environ["UNITXT_TEST_CARD_DISABLE"] = "True" os.environ["UNITXT_TEST_METRIC_DISABLE"] = "True" + os.environ["UNITXT_ALLOW_UNVERIFIED_CODE"] = "True" os.environ["UNITXT_SKIP_ARTIFACTS_PREPARE_AND_VERIFY"] = "True" logger.info("*" * 100) logger.info("*" * 100) - logger.info("* DELETING OLD FM_EVAL CATALOG *** ") - logger.info("deleting all files from 'src/unitxt/catalog'") - shutil.rmtree(catalog_dir, ignore_errors=True) + logger.info( + "Copying all files from 'src/unitxt/catalog' to a backup 'src/unitxt/catalog_back'" + ) + shutil.rmtree(catalog_back_dir, ignore_errors=True) + shutil.copytree(catalog_dir, catalog_back_dir) + + logger.critical("Starting to reprepare the catalog...") prepare_dir = os.path.join(Path(catalog_dir).parent.parent.parent, "prepare") prepare_files = sorted(glob.glob(f"{prepare_dir}/**/*.py", recursive=True)) - continue_preparing = True - iteration = 0 - while continue_preparing: - iteration += 1 - amount_of_prepare_files_before_iteration = len(prepare_files) - logger.info( - f"Iteration {iteration}: Preparing {amount_of_prepare_files_before_iteration} files" - ) - prepare_files, prepare_exceptions = prepare_artifacts_for_prepare_files( - prepare_files - ) - if ( - len(prepare_files) == 0 - or len(prepare_files) == amount_of_prepare_files_before_iteration - or iteration > 100 - ): - continue_preparing = False - logger.info( - f"Done preparing files. Failed to prepare {len(prepare_files)} files:" + failing_prepare_files = [] + prepare_files_generating_entries_not_in_the_catalog = [] + prepare_files_generating_entries_of_different_content_from_what_is_in_the_catalog = [] + current_catalog_files = glob.glob(f"{catalog_dir}/**/*.json", recursive=True) + initial_time = os.path.getmtime(catalog_dir) + for current_catalog_file in current_catalog_files: + if os.path.getmtime(current_catalog_file) > initial_time: + initial_time = os.path.getmtime(current_catalog_file) + # initial_time is the most recent modification time of any catalog file + next_border_time = initial_time + for i, prepare_file in enumerate(prepare_files): + logger.info("*" * 100) + logger.info(f"* {i}/{len(prepare_files)}: {prepare_file}") + logger.info("*") + border_time = next_border_time + try: + import_module_from_file(prepare_file) + current_catalog_files = glob.glob( + f"{catalog_dir}/**/*.json", recursive=True ) - for file, exception in zip(prepare_files, prepare_exceptions): - logger.info(f"Failed to prepare {file}. 
Exception: {exception}")
+            new_times = []  # modification times of catalog files changed by prepare_file
+            for current_catalog_file in current_catalog_files:
+                if (
+                    os.path.getmtime(current_catalog_file) > border_time
+                ):  # current_catalog_file was just generated by prepare_file
+                    new_times.append(os.path.getmtime(current_catalog_file))
+                    if not os.path.exists(
+                        current_catalog_file.replace(catalog_dir, catalog_back_dir)
+                    ):
+                        # prepare_file generates a catalog file that is not a member of the branch's original catalog
+                        prepare_files_generating_entries_not_in_the_catalog.append(
+                            prepare_file
+                        )
+                        # return the branch's catalog to its original state:
+                        os.remove(current_catalog_file)
+                    elif not filecmp.cmp(
+                        current_catalog_file,
+                        current_catalog_file.replace(catalog_dir, catalog_back_dir),
+                        shallow=False,
+                    ):
+                        # prepare_file generates a catalog file that differs from the branch's existing catalog file of the same name
+                        prepare_files_generating_entries_of_different_content_from_what_is_in_the_catalog.append(
+                            prepare_file
+                        )
+                        # restore current_catalog_file from the backup catalog.
+                        shutil.copy(
+                            current_catalog_file.replace(catalog_dir, catalog_back_dir),
+                            current_catalog_file,
+                        )
+                        # the modification time of current_catalog_file is now the time of copying
+                        new_times.append(os.path.getmtime(current_catalog_file))
+
+            if new_times:
+                # new_times can be empty: several prepare files are entirely commented out, waiting for a fix, and generate nothing
+                next_border_time = max(new_times)
+        except Exception as e:
+            logger.info(f"Failed to run prepare file: {prepare_file}")
+            failing_prepare_files.append((prepare_file, e))

-def compare_dirs(old, new):
-    dir_cmp = filecmp.dircmp(old, new)
-    diffs = []
-    if (
-        dir_cmp.diff_files
-        or dir_cmp.left_only
-        or dir_cmp.right_only
-        or dir_cmp.funny_files
-    ):
-        if dir_cmp.left_only:
-            diffs.extend(
-                [
-                    {"file": os.path.join(new, file), "diff": "old only"}
-                    for file in dir_cmp.left_only
-                ]
-            )
-        if dir_cmp.right_only:
-            diffs.extend(
-                [
-                    {"file": os.path.join(new, file), "diff": "new only"}
-                    for file in dir_cmp.right_only
-                ]
-            )
-        if dir_cmp.diff_files:
-            diffs.extend(
-                [
-                    {"file": os.path.join(new, file), "diff": "diff"}
-                    for file in dir_cmp.diff_files
-                ]
-            )
-        if dir_cmp.funny_files:
-            diffs.extend(
-                [
-                    {"file": os.path.join(new, file), "diff": "failed"}
-                    for file in dir_cmp.funny_files
-                ]
+    # report errors discovered thus far
+    if failing_prepare_files:
+        logger.critical(
+            "Execution of the following prepare files failed for the following causes:"
+        )
+        for prepare_file, e in failing_prepare_files:
+            logger.critical(
+                f"prepare file: '{prepare_file}' failed, throwing exception: '{e}'"
             )

-    # Recursively compare subdirectories
-    for sub_dir, _ in dir_cmp.subdirs.items():
-        diffs.extend(
-            compare_dirs(os.path.join(old, sub_dir), os.path.join(new, sub_dir))
+    if prepare_files_generating_entries_not_in_the_catalog:
+        logger.critical(
+            "The following prepare files generated catalog files that are not included in the catalog. To fix: add the products of these prepare files to the catalog."
+ ) + prepare_files_generating_entries_not_in_the_catalog = sorted( + set(prepare_files_generating_entries_not_in_the_catalog) ) + for prepare_file in prepare_files_generating_entries_not_in_the_catalog: + logger.critical(f"{prepare_file}") - return diffs + if prepare_files_generating_entries_of_different_content_from_what_is_in_the_catalog: + prepare_files_generating_entries_of_different_content_from_what_is_in_the_catalog = sorted( + set( + prepare_files_generating_entries_of_different_content_from_what_is_in_the_catalog + ) + ) + logger.critical( + "The following prepare files generated catalog files of different contents from what is included in the (original branch's) catalog. To fix: update the branch's catalog files by the products of these prepare files." + ) + for prepare_file in prepare_files_generating_entries_of_different_content_from_what_is_in_the_catalog: + logger.critical(f"{prepare_file}") + # see if the branch's catalog contains any file that none of the branch's prepare file generates: + catalog_files_not_generated_by_any_prepare_file = [] + current_catalog_files = glob.glob(f"{catalog_dir}/**/*.json", recursive=True) + for current_catalog_file in current_catalog_files: + if ( + os.path.getmtime(current_catalog_file) > initial_time + ): # current_catalog_file was touched by a prepare file + continue + catalog_files_not_generated_by_any_prepare_file.append(current_catalog_file) + + if catalog_files_not_generated_by_any_prepare_file: + logger.critical( + "The following branch's catalog files are not generated by any of the branch's prepare files. To fix: remove them from the branch's catalog." + ) + for catalog_file in catalog_files_not_generated_by_any_prepare_file: + logger.critical(f"{catalog_file}") -def filter_known_diffs(diffs): - return [ - diff - for diff in diffs - if "news_category_classification_headline" - not in diff["file"] # in order to create we need Kaggle credentials - and "tablerow_classify" not in diff["file"] - ] # in order to create we need Kaggle credentials + # finally, restore branch's catalog, including modification times + shutil.rmtree(catalog_dir, ignore_errors=True) + shutil.copytree(catalog_back_dir, catalog_dir) + shutil.rmtree(catalog_back_dir, ignore_errors=True) + if failing_prepare_files: + raise RuntimeError( + "Checking consistency of branch's catalog against the total production of the branch's prepare files, we run each prepare file in turn, given the branch's catalog (which is needed as input by many of the prepare files). Some of the prepare files failed running. See details in the logs." + ) -def main(): - catalog_dir = constants.catalog_dir - catalog_back_dir = catalog_dir + "_back" - logger.info("move old catalog:") - try: - shutil.rmtree(catalog_back_dir) - except: - pass - shutil.move(catalog_dir, catalog_back_dir) - logger.critical("Starting reprepare catalog...") - prepare_all_catalog_artifacts(catalog_dir) - logger.critical("Comparing generated and old catalog...") - diffs = compare_dirs(new=catalog_dir, old=catalog_back_dir) - diffs = filter_known_diffs(diffs) - if diffs: - logger.critical("***** Directories has differences ******") - diffs.sort(key=lambda d: d["file"]) - for diff in diffs: - logger.critical(diff) - raise RuntimeError("Directories has differences") - logger.critical("Done. 
Catalog is consistent with prepare files") + if ( + catalog_files_not_generated_by_any_prepare_file + or prepare_files_generating_entries_not_in_the_catalog + or prepare_files_generating_entries_of_different_content_from_what_is_in_the_catalog + ): + raise RuntimeError( + "Branch's catalog is different from the total production of branch's prepare files. See details in the logs." + ) + + logger.critical( + "Done. Catalog is consistent with the total production of the prepare files." + ) if __name__ == "__main__":
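
A minimal usage sketch of the serializer composition introduced in this change, assuming the serializers can be invoked directly outside a recipe; the HeaderOnlyTableSerializer class and the sample table are hypothetical, not part of the diff:

from typing import Any, Dict

from unitxt.serializers import MultiTypeSerializer, SingleTypeSerializer
from unitxt.types import Table


class HeaderOnlyTableSerializer(SingleTypeSerializer):
    # hypothetical serializer: renders only the header row of a Table
    serialized_type = Table

    def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
        return " | ".join(value["header"])


serializer = MultiTypeSerializer()  # defaults to the Image, Table, and Dialog serializers
# add_serializers prepends, so the custom serializer is matched before the default TableSerializer
serializer.add_serializers([HeaderOnlyTableSerializer()])

table = {"header": ["name", "age"], "rows": [["Alex", "26"]]}
print(serializer.serialize(table, instance={}))  # -> "name | age"
print(serializer.serialize(3.14, instance={}))   # no serialized_type matches -> str() fallback: "3.14"

MultiTypeSerializer.serialize scans its serializers in order and returns the first serialized_type match, which is why prepending a custom serializer overrides the built-in handling for that type.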
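A second sketch, covering the deterministic shuffling knobs added to SerializeTable and mirroring test_serializetable_markdown_with_shuffle; it assumes the serializer interface can be called directly, and the table contents are illustrative:

from unitxt.struct_data_operators import SerializeTableAsMarkdown

table = {
    "header": ["name", "age"],
    "rows": [["Alex", "26"], ["Raj", "34"], ["Donald", "39"]],
}

# SerializeTable.serialize deep-copies the value, shuffles columns and/or rows with a
# generator seeded on (table, seed), then delegates to the format-specific serialize_table.
serializer = SerializeTableAsMarkdown(shuffle_rows=True, shuffle_columns=True, seed=1)

first = serializer.serialize(table, instance={})
second = serializer.serialize(table, instance={})
assert first == second  # same seed and same table -> identical augmentation on every run
print(first)

Seeding the generator on the table content and the seed field is what lets this change drop the module-level random.seed calls from the tests while keeping the shuffled outputs reproducible.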