diff --git a/prepare/cards/multidoc2dial.py b/prepare/cards/multidoc2dial.py index b35690035..75f516240 100644 --- a/prepare/cards/multidoc2dial.py +++ b/prepare/cards/multidoc2dial.py @@ -1,6 +1,6 @@ from src.unitxt.blocks import LoadHF, TaskCard from src.unitxt.catalog import add_to_catalog -from src.unitxt.operators import ExecuteQuery, ListFieldValues, RenameFields +from src.unitxt.operators import ExecuteExpression, ListFieldValues, RenameFields from src.unitxt.test_utils.card import test_card card_abstractive = TaskCard( @@ -11,7 +11,7 @@ use_query=True, ), ListFieldValues(fields=["utterance"], to_field="answer"), - ExecuteQuery(query="question.split('[SEP]')[0]", to_field="question"), + ExecuteExpression(expression="question.split('[SEP]')[0]", to_field="question"), ], task="tasks.qa.contextual.abstractive", templates="templates.qa.contextual.all", @@ -25,7 +25,7 @@ use_query=True, ), ListFieldValues(fields=["relevant_context"], to_field="answer"), - ExecuteQuery(query="question.split('[SEP]')[0]", to_field="question"), + ExecuteExpression(expression="question.split('[SEP]')[0]", to_field="question"), ], task="tasks.qa.contextual.extractive", templates="templates.qa.contextual.all", diff --git a/src/unitxt/catalog/cards/multidoc2dial/abstractive.json b/src/unitxt/catalog/cards/multidoc2dial/abstractive.json index 5b53f4c91..ccda88e78 100644 --- a/src/unitxt/catalog/cards/multidoc2dial/abstractive.json +++ b/src/unitxt/catalog/cards/multidoc2dial/abstractive.json @@ -20,8 +20,8 @@ "to_field": "answer" }, { - "type": "execute_query", - "query": "question.split('[SEP]')[0]", + "type": "execute_expression", + "expression": "question.split('[SEP]')[0]", "to_field": "question" } ], diff --git a/src/unitxt/catalog/cards/multidoc2dial/extractive.json b/src/unitxt/catalog/cards/multidoc2dial/extractive.json index e270e144f..5920d57d5 100644 --- a/src/unitxt/catalog/cards/multidoc2dial/extractive.json +++ b/src/unitxt/catalog/cards/multidoc2dial/extractive.json @@ -20,8 +20,8 @@ "to_field": "answer" }, { - "type": "execute_query", - "query": "question.split('[SEP]')[0]", + "type": "execute_expression", + "expression": "question.split('[SEP]')[0]", "to_field": "question" } ], diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index cf1b7718f..104ba8541 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -41,6 +41,7 @@ from collections import Counter from copy import deepcopy from dataclasses import field +from importlib import import_module from itertools import zip_longest from random import Random from typing import ( @@ -58,7 +59,7 @@ import requests from .artifact import Artifact, fetch_artifact -from .dataclass import NonPositionalField +from .dataclass import NonPositionalField, OptionalField from .dict_utils import dict_delete, dict_get, dict_set, is_subpath from .operator import ( MultiStream, @@ -1094,7 +1095,7 @@ def get_artifact(cls, artifact_identifier: str) -> Artifact: return cls.cache[artifact_identifier] -class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin): +class ApplyOperatorsField(StreamInstanceOperator): """Applies value operators to each instance in a stream based on specified fields. Args: @@ -1215,30 +1216,55 @@ def _is_required(self, instance: dict) -> bool: return True -class FilterByQuery(SingleStreamOperator): +class ComputeExpressionMixin(Artifact): + """Computes an expression expressed over fields of an instance. + + Args: + expression (str): the expression, in terms of names of fields of an instance + imports_list (List[str]): list of names of imports needed for the evaluation of the expression + """ + + expression: str + imports_list: List[str] = OptionalField(default_factory=list) + + def prepare(self): + # can not do the imports here, because object does not pickle with imports + self.globs = {} + self.to_import = True + + def compute_expression(self, instance: dict) -> Any: + if self.to_import: + for module_name in self.imports_list: + self.globs[module_name] = import_module(module_name) + self.to_import = False + + return eval(self.expression, self.globs, instance) + + +class FilterByExpression(SingleStreamOperator, ComputeExpressionMixin): """Filters a stream, yielding only instances which fulfil a condition specified as a string to be python's eval-uated. Raises an error if a field participating in the specified condition is missing from the instance Args: - query (str): a condition over fields of the instance, to be processed by python's eval() + expression (str): a condition over fields of the instance, to be processed by python's eval() + imports_list (List[str]): names of imports needed for the eval of the query (e.g. 're', 'json') error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True. Examples: - FilterByQuery(query = "a > 4") will yield only instances where "a">4 - FilterByQuery(query = "a <= 4 and b > 5") will yield only instances where the value of field "a" is not exceeding 4 and in field "b" -- greater than 5 - FilterByQuery(query = "a in [4, 8]") will yield only instances where "a" is 4 or 8 - FilterByQuery(query = "a not in [4, 8]") will yield only instances where "a" is neither 4 nor 8 + FilterByExpression(expression = "a > 4") will yield only instances where "a">4 + FilterByExpression(expression = "a <= 4 and b > 5") will yield only instances where the value of field "a" is not exceeding 4 and in field "b" -- greater than 5 + FilterByExpression(expression = "a in [4, 8]") will yield only instances where "a" is 4 or 8 + FilterByExpression(expression = "a not in [4, 8]") will yield only instances where "a" is neither 4 nor 8 """ - query: str error_on_filtered_all: bool = True def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator: yielded = False for instance in stream: - if eval(self.query, None, instance): + if self.compute_expression(instance): yielded = True yield instance @@ -1248,33 +1274,33 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato ) -class ExecuteQuery(StreamInstanceOperator): - """Compute an expression (query), expressed as a string to be eval-uated, over the instance's fields, and store the result in field to_field. +class ExecuteExpression(StreamInstanceOperator, ComputeExpressionMixin): + """Compute an expression, specified as a string to be eval-uated, over the instance's fields, and store the result in field to_field. Raises an error if a field mentioned in the query is missing from the instance. Args: - query (str): an expression to be evaluated over the fields of the instance + expression (str): an expression to be evaluated over the fields of the instance to_field (str): the field where the result is to be stored into + imports_list (List[str]): names of imports needed for the eval of the query (e.g. 're', 'json') Examples: When instance {"a": 2, "b": 3} is process-ed by operator - ExecuteQuery(query="a+b", to_field = "c") + ExecuteExpression(expression="a+b", to_field = "c") the result is {"a": 2, "b": 3, "c": 5} When instance {"a": "hello", "b": "world"} is process-ed by operator - ExecuteQuery(query = "a+' '+b", to_field = "c") + ExecuteExpression(expression = "a+' '+b", to_field = "c") the result is {"a": "hello", "b": "world", "c": "hello world"} """ - query: str to_field: str def process( self, instance: Dict[str, Any], stream_name: Optional[str] = None ) -> Dict[str, Any]: - instance[self.to_field] = eval(self.query, None, instance) + instance[self.to_field] = self.compute_expression(instance) return instance diff --git a/src/unitxt/test_utils/card.py b/src/unitxt/test_utils/card.py index 09afcf5ce..61b19c061 100644 --- a/src/unitxt/test_utils/card.py +++ b/src/unitxt/test_utils/card.py @@ -140,6 +140,7 @@ def print_recipe_output( return examples +# flake8: noqa: C901 def test_with_eval( card, debug=False, @@ -150,12 +151,22 @@ def test_with_eval( ): if type(card.templates) is TemplatesDict: for template_card_index in card.templates.keys(): + # restore the operators on the card, so that they are fresh for second invocation of + # StandardRecipe as they are for the first one + if card.preprocess_steps is not None: + for step in card.preprocess_steps: + step.prepare() examples = load_examples_from_standard_recipe( card, template_card_index=template_card_index, debug=debug, **kwargs ) else: num_templates = len(card.templates) for template_card_index in range(0, num_templates): + # restore the operators on the card, so that they are fresh for second invocation of + # StandardRecipe as they are for the first one + if card.preprocess_steps is not None: + for step in card.preprocess_steps: + step.prepare() examples = load_examples_from_standard_recipe( card, template_card_index=template_card_index, debug=debug, **kwargs ) diff --git a/tests/test_operators.py b/tests/test_operators.py index 2b0513a0a..947b8c8a6 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -19,12 +19,12 @@ DeterministicBalancer, DivideAllFieldsBy, EncodeLabels, - ExecuteQuery, + ExecuteExpression, ExtractFieldValues, ExtractMostCommonFieldValues, FieldOperator, FilterByCondition, - FilterByQuery, + FilterByExpression, FlattenInstances, FromIterables, IndexOf, @@ -277,7 +277,7 @@ def test_filter_by_values_with_required_values(self): tester=self, ) check_operator( - operator=FilterByQuery(query="a == 1 and b == 3"), + operator=FilterByExpression(expression="a == 1 and b == 3"), inputs=inputs, targets=targets, tester=self, @@ -291,7 +291,7 @@ def test_filter_by_values_with_required_values(self): tester=self, ) check_operator_exception( - operator=FilterByQuery(query="c == 5"), + operator=FilterByExpression(expression="c == 5"), inputs=inputs, exception_text="name 'c' is not defined", tester=self, @@ -311,7 +311,7 @@ def test_filter_by_condition_ne(self): tester=self, ) check_operator( - operator=FilterByQuery(query="a != 1 and b != 2"), + operator=FilterByExpression(expression="a != 1 and b != 2"), inputs=inputs, targets=targets, tester=self, @@ -331,7 +331,7 @@ def test_filter_by_condition_gt(self): tester=self, ) check_operator( - operator=FilterByQuery(query="a>1"), + operator=FilterByExpression(expression="a>1"), inputs=inputs, targets=targets, tester=self, @@ -359,7 +359,7 @@ def test_filter_by_condition_not_in(self): tester=self, ) check_operator( - operator=FilterByQuery(query="b not in [3, 4]"), + operator=FilterByExpression(expression="b not in [3, 4]"), inputs=inputs, targets=targets, tester=self, @@ -385,8 +385,8 @@ def test_filter_by_condition_not_in_multiple(self): tester=self, ) check_operator( - operator=FilterByQuery( - query="b not in [3, 4] and a not in [1]", + operator=FilterByExpression( + expression="b not in [3, 4] and a not in [1]", error_on_filtered_all=False, ), inputs=inputs, @@ -394,12 +394,12 @@ def test_filter_by_condition_not_in_multiple(self): tester=self, ) check_operator_exception( - operator=FilterByQuery( - query="b not in [3, 4] and a not in [1]", + operator=FilterByExpression( + expression="b not in [3, 4] and a not in [1]", error_on_filtered_all=True, ), inputs=inputs, - exception_text="FilterByQuery filtered out every instance in stream 'test'. If this is intended set error_on_filtered_all=False", + exception_text="FilterByExpression filtered out every instance in stream 'test'. If this is intended set error_on_filtered_all=False", tester=self, ) @@ -422,7 +422,7 @@ def test_filter_by_condition_in(self): tester=self, ) check_operator( - operator=FilterByQuery(query="b in [3, 4]"), + operator=FilterByExpression(expression="b in [3, 4]"), inputs=inputs, targets=targets, tester=self, @@ -452,7 +452,7 @@ def test_filter_by_condition_in(self): ) with self.assertRaises(Exception) as ne: check_operator( - operator=FilterByQuery(query="c in ['5']"), + operator=FilterByExpression(expression="c in ['5']"), inputs=inputs, targets=targets, tester=self, @@ -475,16 +475,16 @@ def test_filter_by_condition_error_when_the_entire_stream_is_filtered(self): "FilterByCondition filtered out every instance in stream 'test'. If this is intended set error_on_filtered_all=False", ) - def test_execute_query(self): + def test_execute_expression(self): inputs = [{"a": 2, "b": 3}] - operator = ExecuteQuery(query="a+b", to_field="c") + operator = ExecuteExpression(to_field="c", expression="a+b") targets = [{"a": 2, "b": 3, "c": 5}] check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) inputs = [{"a": "hello", "b": "world"}] - operator = ExecuteQuery(query="a+' '+b", to_field="c") + operator = ExecuteExpression(expression="a+' '+b", to_field="c") targets = [{"a": "hello", "b": "world", "c": "hello world"}] check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) - operator = ExecuteQuery(query="f'{a} {b}'", to_field="c") + operator = ExecuteExpression(expression="f'{a} {b}'", to_field="c") check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) with self.assertRaises(ValueError) as ve: check_operator( @@ -494,10 +494,31 @@ def test_execute_query(self): tester=self, ) self.assertEqual( - "Error processing instance '0' from stream 'test' in ExecuteQuery due to: name 'a' is not defined", + "Error processing instance '0' from stream 'test' in ExecuteExpression due to: name 'a' is not defined", str(ve.exception), ) + inputs = [{"json_string": '{"A":"a_value", "B":"b_value"}'}] + operator = ExecuteExpression( + expression='json.loads(json_string)["A"]', + imports_list=["json"], + to_field="c", + ) + self.assertEqual("a_value", operator.process(inputs[0])["c"]) + + pattern = "[0-9]+" + string = "Account Number - 12345, Amount - 586.32" + repl = "NN" + inputs = [{"pattern": pattern, "string": string, "repl": repl}] + operator = ExecuteExpression( + expression="re.sub(pattern, repl, string)", + imports_list=["re"], + to_field="c", + ) + self.assertEqual( + "Account Number - NN, Amount - NN.NN", operator.process(inputs[0])["c"] + ) + def test_intersect(self): inputs = [ {"label": ["a", "b"]},