From 6ca67c58026a7480f237a972349aae60ab330cc3 Mon Sep 17 00:00:00 2001 From: dafnapension Date: Wed, 31 Jan 2024 21:38:51 +0200 Subject: [PATCH 1/7] allow imports list for executequery and filterbyquery Signed-off-by: dafnapension --- src/unitxt/operators.py | 17 +++++++++++++++-- tests/test_operators.py | 17 +++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index cf1b7718f..a2e1eb880 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -41,6 +41,7 @@ from collections import Counter from copy import deepcopy from dataclasses import field +from importlib import import_module from itertools import zip_longest from random import Random from typing import ( @@ -1222,6 +1223,7 @@ class FilterByQuery(SingleStreamOperator): Args: query (str): a condition over fields of the instance, to be processed by python's eval() + imports_list (List(str)): names of imports needed for the eval of the query (e.g. 're', 'json') error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True. Examples: @@ -1233,12 +1235,17 @@ class FilterByQuery(SingleStreamOperator): """ query: str + imports_list: List[str] = [] error_on_filtered_all: bool = True + def prepare(self): + self.globals = {name: import_module(name) for name in self.imports_list} + super().prepare() + def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator: yielded = False for instance in stream: - if eval(self.query, None, instance): + if eval(self.query, self.globals, instance): yielded = True yield instance @@ -1256,6 +1263,7 @@ class ExecuteQuery(StreamInstanceOperator): Args: query (str): an expression to be evaluated over the fields of the instance to_field (str): the field where the result is to be stored into + imports_list (List(str)): names of imports needed for the eval of the query (e.g. 
're', 'json') Examples: When instance {"a": 2, "b": 3} is process-ed by operator @@ -1270,11 +1278,16 @@ class ExecuteQuery(StreamInstanceOperator): query: str to_field: str + imports_list: List[str] = [] + + def prepare(self): + self.globals = {name: import_module(name) for name in self.imports_list} + super().prepare() def process( self, instance: Dict[str, Any], stream_name: Optional[str] = None ) -> Dict[str, Any]: - instance[self.to_field] = eval(self.query, None, instance) + instance[self.to_field] = eval(self.query, self.globals, instance) return instance diff --git a/tests/test_operators.py b/tests/test_operators.py index 2b0513a0a..b1a36fa5e 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -498,6 +498,23 @@ def test_execute_query(self): str(ve.exception), ) + inputs = [{"json_string": '{"A":"a_value", "B":"b_value"}'}] + operator = ExecuteQuery( + query='json.loads(json_string)["A"]', imports_list=["json"], to_field="c" + ) + self.assertEqual("a_value", operator.process(inputs[0])["c"]) + + pattern = "[0-9]+" + string = "Account Number - 12345, Amount - 586.32" + repl = "NN" + inputs = [{"pattern": pattern, "string": string, "repl": repl}] + operator = ExecuteQuery( + query="re.sub(pattern, repl, string)", imports_list=["re"], to_field="c" + ) + self.assertEqual( + "Account Number - NN, Amount - NN.NN", operator.process(inputs[0])["c"] + ) + def test_intersect(self): inputs = [ {"label": ["a", "b"]}, From cc62ff89a24a6f65abddccd9a518bd1fa2fbb78b Mon Sep 17 00:00:00 2001 From: dafnapension Date: Thu, 1 Feb 2024 08:02:05 +0200 Subject: [PATCH 2/7] and now with Mixin, and rename Query to Expression Signed-off-by: dafnapension --- prepare/cards/multidoc2dial.py | 6 +- .../cards/multidoc2dial/abstractive.json | 4 +- .../cards/multidoc2dial/extractive.json | 4 +- src/unitxt/operators.py | 62 ++++++++++--------- tests/test_operators.py | 50 ++++++++------- 5 files changed, 68 insertions(+), 58 deletions(-) diff --git 
a/prepare/cards/multidoc2dial.py b/prepare/cards/multidoc2dial.py index b35690035..cbc7dedd8 100644 --- a/prepare/cards/multidoc2dial.py +++ b/prepare/cards/multidoc2dial.py @@ -1,6 +1,6 @@ from src.unitxt.blocks import LoadHF, TaskCard from src.unitxt.catalog import add_to_catalog -from src.unitxt.operators import ExecuteQuery, ListFieldValues, RenameFields +from src.unitxt.operators import AssignExpression, ListFieldValues, RenameFields from src.unitxt.test_utils.card import test_card card_abstractive = TaskCard( @@ -11,7 +11,7 @@ use_query=True, ), ListFieldValues(fields=["utterance"], to_field="answer"), - ExecuteQuery(query="question.split('[SEP]')[0]", to_field="question"), + AssignExpression(expression="question.split('[SEP]')[0]", to_field="question"), ], task="tasks.qa.contextual.abstractive", templates="templates.qa.contextual.all", @@ -25,7 +25,7 @@ use_query=True, ), ListFieldValues(fields=["relevant_context"], to_field="answer"), - ExecuteQuery(query="question.split('[SEP]')[0]", to_field="question"), + AssignExpression(expression="question.split('[SEP]')[0]", to_field="question"), ], task="tasks.qa.contextual.extractive", templates="templates.qa.contextual.all", diff --git a/src/unitxt/catalog/cards/multidoc2dial/abstractive.json b/src/unitxt/catalog/cards/multidoc2dial/abstractive.json index 5b53f4c91..2f875818d 100644 --- a/src/unitxt/catalog/cards/multidoc2dial/abstractive.json +++ b/src/unitxt/catalog/cards/multidoc2dial/abstractive.json @@ -20,8 +20,8 @@ "to_field": "answer" }, { - "type": "execute_query", - "query": "question.split('[SEP]')[0]", + "type": "assign_expression", + "expression": "question.split('[SEP]')[0]", "to_field": "question" } ], diff --git a/src/unitxt/catalog/cards/multidoc2dial/extractive.json b/src/unitxt/catalog/cards/multidoc2dial/extractive.json index e270e144f..befbfe3c8 100644 --- a/src/unitxt/catalog/cards/multidoc2dial/extractive.json +++ b/src/unitxt/catalog/cards/multidoc2dial/extractive.json @@ -20,8 +20,8 @@ 
"to_field": "answer" }, { - "type": "execute_query", - "query": "question.split('[SEP]')[0]", + "type": "assign_expression", + "expression": "question.split('[SEP]')[0]", "to_field": "question" } ], diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index a2e1eb880..114ade9fa 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -1095,7 +1095,7 @@ def get_artifact(cls, artifact_identifier: str) -> Artifact: return cls.cache[artifact_identifier] -class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin): +class ApplyOperatorsField(StreamInstanceOperator): """Applies value operators to each instance in a stream based on specified fields. Args: @@ -1216,36 +1216,48 @@ def _is_required(self, instance: dict) -> bool: return True -class FilterByQuery(SingleStreamOperator): +class ComputeExpressionMixin(Artifact): + """Computes an expression expressed over fields of an instance. + + Args: + expression (str): the expression, in terms of names of fields of an instance + imports_list (List[str]): list of names of imports needed for the evaluation of the expression + """ + + expression: str + imports_list: List[str] = [] + + def prepare(self): + self.globals = {name: import_module(name) for name in self.imports_list} + + def compute_expression(self, instance: dict) -> Any: + return eval(self.expression, self.globals, instance) + + +class FilterByExpression(SingleStreamOperator, ComputeExpressionMixin): """Filters a stream, yielding only instances which fulfil a condition specified as a string to be python's eval-uated. Raises an error if a field participating in the specified condition is missing from the instance Args: - query (str): a condition over fields of the instance, to be processed by python's eval() - imports_list (List(str)): names of imports needed for the eval of the query (e.g. 
're', 'json') + expression (str): a condition over fields of the instance, to be processed by python's eval() + imports_list (List[str]): names of imports needed for the eval of the query (e.g. 're', 'json') error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True. Examples: - FilterByQuery(query = "a > 4") will yield only instances where "a">4 - FilterByQuery(query = "a <= 4 and b > 5") will yield only instances where the value of field "a" is not exceeding 4 and in field "b" -- greater than 5 - FilterByQuery(query = "a in [4, 8]") will yield only instances where "a" is 4 or 8 - FilterByQuery(query = "a not in [4, 8]") will yield only instances where "a" is neither 4 nor 8 + FilterByExpression(expression = "a > 4") will yield only instances where "a">4 + FilterByExpression(expression = "a <= 4 and b > 5") will yield only instances where the value of field "a" is not exceeding 4 and in field "b" -- greater than 5 + FilterByExpression(expression = "a in [4, 8]") will yield only instances where "a" is 4 or 8 + FilterByExpression(expression = "a not in [4, 8]") will yield only instances where "a" is neither 4 nor 8 """ - query: str - imports_list: List[str] = [] error_on_filtered_all: bool = True - def prepare(self): - self.globals = {name: import_module(name) for name in self.imports_list} - super().prepare() - def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator: yielded = False for instance in stream: - if eval(self.query, self.globals, instance): + if self.compute_expression(instance): yielded = True yield instance @@ -1255,39 +1267,33 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato ) -class ExecuteQuery(StreamInstanceOperator): - """Compute an expression (query), expressed as a string to be eval-uated, over the instance's fields, and store the result in field to_field. 
+class AssignExpression(StreamInstanceOperator, ComputeExpressionMixin): + """Compute an expression, specified as a string to be eval-uated, over the instance's fields, and store the result in field to_field. Raises an error if a field mentioned in the query is missing from the instance. Args: - query (str): an expression to be evaluated over the fields of the instance + expression (str): an expression to be evaluated over the fields of the instance to_field (str): the field where the result is to be stored into - imports_list (List(str)): names of imports needed for the eval of the query (e.g. 're', 'json') + imports_list (List[str]): names of imports needed for the eval of the query (e.g. 're', 'json') Examples: When instance {"a": 2, "b": 3} is process-ed by operator - ExecuteQuery(query="a+b", to_field = "c") + AssignExpression(expression="a+b", to_field = "c") the result is {"a": 2, "b": 3, "c": 5} When instance {"a": "hello", "b": "world"} is process-ed by operator - ExecuteQuery(query = "a+' '+b", to_field = "c") + AssignExpression(expression = "a+' '+b", to_field = "c") the result is {"a": "hello", "b": "world", "c": "hello world"} """ - query: str to_field: str - imports_list: List[str] = [] - - def prepare(self): - self.globals = {name: import_module(name) for name in self.imports_list} - super().prepare() def process( self, instance: Dict[str, Any], stream_name: Optional[str] = None ) -> Dict[str, Any]: - instance[self.to_field] = eval(self.query, self.globals, instance) + instance[self.to_field] = self.compute_expression(instance) return instance diff --git a/tests/test_operators.py b/tests/test_operators.py index b1a36fa5e..63586cfe8 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -11,6 +11,7 @@ ApplyMetric, ApplyOperatorsField, ApplyStreamOperatorsField, + AssignExpression, Augmentor, AugmentPrefixSuffix, AugmentWhitespace, @@ -19,12 +20,11 @@ DeterministicBalancer, DivideAllFieldsBy, EncodeLabels, - ExecuteQuery, 
ExtractFieldValues, ExtractMostCommonFieldValues, FieldOperator, FilterByCondition, - FilterByQuery, + FilterByExpression, FlattenInstances, FromIterables, IndexOf, @@ -277,7 +277,7 @@ def test_filter_by_values_with_required_values(self): tester=self, ) check_operator( - operator=FilterByQuery(query="a == 1 and b == 3"), + operator=FilterByExpression(expression="a == 1 and b == 3"), inputs=inputs, targets=targets, tester=self, @@ -291,7 +291,7 @@ def test_filter_by_values_with_required_values(self): tester=self, ) check_operator_exception( - operator=FilterByQuery(query="c == 5"), + operator=FilterByExpression(expression="c == 5"), inputs=inputs, exception_text="name 'c' is not defined", tester=self, @@ -311,7 +311,7 @@ def test_filter_by_condition_ne(self): tester=self, ) check_operator( - operator=FilterByQuery(query="a != 1 and b != 2"), + operator=FilterByExpression(expression="a != 1 and b != 2"), inputs=inputs, targets=targets, tester=self, @@ -331,7 +331,7 @@ def test_filter_by_condition_gt(self): tester=self, ) check_operator( - operator=FilterByQuery(query="a>1"), + operator=FilterByExpression(expression="a>1"), inputs=inputs, targets=targets, tester=self, @@ -359,7 +359,7 @@ def test_filter_by_condition_not_in(self): tester=self, ) check_operator( - operator=FilterByQuery(query="b not in [3, 4]"), + operator=FilterByExpression(expression="b not in [3, 4]"), inputs=inputs, targets=targets, tester=self, @@ -385,8 +385,8 @@ def test_filter_by_condition_not_in_multiple(self): tester=self, ) check_operator( - operator=FilterByQuery( - query="b not in [3, 4] and a not in [1]", + operator=FilterByExpression( + expression="b not in [3, 4] and a not in [1]", error_on_filtered_all=False, ), inputs=inputs, @@ -394,12 +394,12 @@ def test_filter_by_condition_not_in_multiple(self): tester=self, ) check_operator_exception( - operator=FilterByQuery( - query="b not in [3, 4] and a not in [1]", + operator=FilterByExpression( + expression="b not in [3, 4] and a not in [1]", 
error_on_filtered_all=True, ), inputs=inputs, - exception_text="FilterByQuery filtered out every instance in stream 'test'. If this is intended set error_on_filtered_all=False", + exception_text="FilterByExpression filtered out every instance in stream 'test'. If this is intended set error_on_filtered_all=False", tester=self, ) @@ -422,7 +422,7 @@ def test_filter_by_condition_in(self): tester=self, ) check_operator( - operator=FilterByQuery(query="b in [3, 4]"), + operator=FilterByExpression(expression="b in [3, 4]"), inputs=inputs, targets=targets, tester=self, @@ -452,7 +452,7 @@ def test_filter_by_condition_in(self): ) with self.assertRaises(Exception) as ne: check_operator( - operator=FilterByQuery(query="c in ['5']"), + operator=FilterByExpression(expression="c in ['5']"), inputs=inputs, targets=targets, tester=self, @@ -475,16 +475,16 @@ def test_filter_by_condition_error_when_the_entire_stream_is_filtered(self): "FilterByCondition filtered out every instance in stream 'test'. If this is intended set error_on_filtered_all=False", ) - def test_execute_query(self): + def test_execute_expression(self): inputs = [{"a": 2, "b": 3}] - operator = ExecuteQuery(query="a+b", to_field="c") + operator = AssignExpression(to_field="c", expression="a+b") targets = [{"a": 2, "b": 3, "c": 5}] check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) inputs = [{"a": "hello", "b": "world"}] - operator = ExecuteQuery(query="a+' '+b", to_field="c") + operator = AssignExpression(expression="a+' '+b", to_field="c") targets = [{"a": "hello", "b": "world", "c": "hello world"}] check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) - operator = ExecuteQuery(query="f'{a} {b}'", to_field="c") + operator = AssignExpression(expression="f'{a} {b}'", to_field="c") check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) with self.assertRaises(ValueError) as ve: check_operator( @@ -494,13 +494,15 @@ def 
test_execute_query(self): tester=self, ) self.assertEqual( - "Error processing instance '0' from stream 'test' in ExecuteQuery due to: name 'a' is not defined", + "Error processing instance '0' from stream 'test' in AssignExpression due to: name 'a' is not defined", str(ve.exception), ) inputs = [{"json_string": '{"A":"a_value", "B":"b_value"}'}] - operator = ExecuteQuery( - query='json.loads(json_string)["A"]', imports_list=["json"], to_field="c" + operator = AssignExpression( + expression='json.loads(json_string)["A"]', + imports_list=["json"], + to_field="c", ) self.assertEqual("a_value", operator.process(inputs[0])["c"]) @@ -508,8 +510,10 @@ def test_execute_query(self): string = "Account Number - 12345, Amount - 586.32" repl = "NN" inputs = [{"pattern": pattern, "string": string, "repl": repl}] - operator = ExecuteQuery( - query="re.sub(pattern, repl, string)", imports_list=["re"], to_field="c" + operator = AssignExpression( + expression="re.sub(pattern, repl, string)", + imports_list=["re"], + to_field="c", ) self.assertEqual( "Account Number - NN, Amount - NN.NN", operator.process(inputs[0])["c"] From fc9dd30edd115129627c31e4440738b06e0866c1 Mon Sep 17 00:00:00 2001 From: dafnapension Date: Thu, 1 Feb 2024 11:46:41 +0200 Subject: [PATCH 3/7] changed to OptionalField Signed-off-by: dafnapension --- src/unitxt/operators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index 114ade9fa..a52d780b0 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -59,7 +59,7 @@ import requests from .artifact import Artifact, fetch_artifact -from .dataclass import NonPositionalField +from .dataclass import NonPositionalField, OptionalField from .dict_utils import dict_delete, dict_get, dict_set, is_subpath from .operator import ( MultiStream, @@ -1225,7 +1225,7 @@ class ComputeExpressionMixin(Artifact): """ expression: str - imports_list: List[str] = [] + imports_list: List[str] = 
OptionalField(default_factory=list) def prepare(self): self.globals = {name: import_module(name) for name in self.imports_list} From 98b8b98086b623425bf13bbe55206d1d99defe2d Mon Sep 17 00:00:00 2001 From: dafnapension Date: Fri, 2 Feb 2024 18:38:09 +0200 Subject: [PATCH 4/7] get import_module out of prepare(), as importing there prevents pickling Signed-off-by: dafnapension --- src/unitxt/operators.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index a52d780b0..dd512b5ad 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -1228,10 +1228,17 @@ class ComputeExpressionMixin(Artifact): imports_list: List[str] = OptionalField(default_factory=list) def prepare(self): - self.globals = {name: import_module(name) for name in self.imports_list} + # can not do the imports here, because object does not pickle with imports + self.globs = {} + self.to_import = True def compute_expression(self, instance: dict) -> Any: - return eval(self.expression, self.globals, instance) + if self.to_import: + for module_name in self.imports_list: + self.globs[module_name] = import_module(module_name) + self.to_import = False + + return eval(self.expression, self.globs, instance) class FilterByExpression(SingleStreamOperator, ComputeExpressionMixin): From f78f7880ef5ef3ad8631ad29046e9b22a79d0ddf Mon Sep 17 00:00:00 2001 From: dafnapension Date: Sun, 4 Feb 2024 12:14:44 +0200 Subject: [PATCH 5/7] renamed AssignExpression -> ExecuteExpression Signed-off-by: dafnapension --- prepare/cards/multidoc2dial.py | 6 +++--- .../catalog/cards/multidoc2dial/abstractive.json | 2 +- .../catalog/cards/multidoc2dial/extractive.json | 2 +- src/unitxt/operators.py | 6 +++--- tests/test_operators.py | 14 +++++++------- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/prepare/cards/multidoc2dial.py b/prepare/cards/multidoc2dial.py index cbc7dedd8..75f516240 100644 --- a/prepare/cards/multidoc2dial.py +++ 
b/prepare/cards/multidoc2dial.py @@ -1,6 +1,6 @@ from src.unitxt.blocks import LoadHF, TaskCard from src.unitxt.catalog import add_to_catalog -from src.unitxt.operators import AssignExpression, ListFieldValues, RenameFields +from src.unitxt.operators import ExecuteExpression, ListFieldValues, RenameFields from src.unitxt.test_utils.card import test_card card_abstractive = TaskCard( @@ -11,7 +11,7 @@ use_query=True, ), ListFieldValues(fields=["utterance"], to_field="answer"), - AssignExpression(expression="question.split('[SEP]')[0]", to_field="question"), + ExecuteExpression(expression="question.split('[SEP]')[0]", to_field="question"), ], task="tasks.qa.contextual.abstractive", templates="templates.qa.contextual.all", @@ -25,7 +25,7 @@ use_query=True, ), ListFieldValues(fields=["relevant_context"], to_field="answer"), - AssignExpression(expression="question.split('[SEP]')[0]", to_field="question"), + ExecuteExpression(expression="question.split('[SEP]')[0]", to_field="question"), ], task="tasks.qa.contextual.extractive", templates="templates.qa.contextual.all", diff --git a/src/unitxt/catalog/cards/multidoc2dial/abstractive.json b/src/unitxt/catalog/cards/multidoc2dial/abstractive.json index 2f875818d..ccda88e78 100644 --- a/src/unitxt/catalog/cards/multidoc2dial/abstractive.json +++ b/src/unitxt/catalog/cards/multidoc2dial/abstractive.json @@ -20,7 +20,7 @@ "to_field": "answer" }, { - "type": "assign_expression", + "type": "execute_expression", "expression": "question.split('[SEP]')[0]", "to_field": "question" } diff --git a/src/unitxt/catalog/cards/multidoc2dial/extractive.json b/src/unitxt/catalog/cards/multidoc2dial/extractive.json index befbfe3c8..5920d57d5 100644 --- a/src/unitxt/catalog/cards/multidoc2dial/extractive.json +++ b/src/unitxt/catalog/cards/multidoc2dial/extractive.json @@ -20,7 +20,7 @@ "to_field": "answer" }, { - "type": "assign_expression", + "type": "execute_expression", "expression": "question.split('[SEP]')[0]", "to_field": "question" } 
diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index dd512b5ad..104ba8541 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -1274,7 +1274,7 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato ) -class AssignExpression(StreamInstanceOperator, ComputeExpressionMixin): +class ExecuteExpression(StreamInstanceOperator, ComputeExpressionMixin): """Compute an expression, specified as a string to be eval-uated, over the instance's fields, and store the result in field to_field. Raises an error if a field mentioned in the query is missing from the instance. @@ -1286,11 +1286,11 @@ class AssignExpression(StreamInstanceOperator, ComputeExpressionMixin): Examples: When instance {"a": 2, "b": 3} is process-ed by operator - AssignExpression(expression="a+b", to_field = "c") + ExecuteExpression(expression="a+b", to_field = "c") the result is {"a": 2, "b": 3, "c": 5} When instance {"a": "hello", "b": "world"} is process-ed by operator - AssignExpression(expression = "a+' '+b", to_field = "c") + ExecuteExpression(expression = "a+' '+b", to_field = "c") the result is {"a": "hello", "b": "world", "c": "hello world"} """ diff --git a/tests/test_operators.py b/tests/test_operators.py index 63586cfe8..947b8c8a6 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -11,7 +11,6 @@ ApplyMetric, ApplyOperatorsField, ApplyStreamOperatorsField, - AssignExpression, Augmentor, AugmentPrefixSuffix, AugmentWhitespace, @@ -20,6 +19,7 @@ DeterministicBalancer, DivideAllFieldsBy, EncodeLabels, + ExecuteExpression, ExtractFieldValues, ExtractMostCommonFieldValues, FieldOperator, @@ -477,14 +477,14 @@ def test_filter_by_condition_error_when_the_entire_stream_is_filtered(self): def test_execute_expression(self): inputs = [{"a": 2, "b": 3}] - operator = AssignExpression(to_field="c", expression="a+b") + operator = ExecuteExpression(to_field="c", expression="a+b") targets = [{"a": 2, "b": 3, "c": 5}] 
check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) inputs = [{"a": "hello", "b": "world"}] - operator = AssignExpression(expression="a+' '+b", to_field="c") + operator = ExecuteExpression(expression="a+' '+b", to_field="c") targets = [{"a": "hello", "b": "world", "c": "hello world"}] check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) - operator = AssignExpression(expression="f'{a} {b}'", to_field="c") + operator = ExecuteExpression(expression="f'{a} {b}'", to_field="c") check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) with self.assertRaises(ValueError) as ve: check_operator( @@ -494,12 +494,12 @@ def test_execute_expression(self): tester=self, ) self.assertEqual( - "Error processing instance '0' from stream 'test' in AssignExpression due to: name 'a' is not defined", + "Error processing instance '0' from stream 'test' in ExecuteExpression due to: name 'a' is not defined", str(ve.exception), ) inputs = [{"json_string": '{"A":"a_value", "B":"b_value"}'}] - operator = AssignExpression( + operator = ExecuteExpression( expression='json.loads(json_string)["A"]', imports_list=["json"], to_field="c", @@ -510,7 +510,7 @@ def test_execute_expression(self): string = "Account Number - 12345, Amount - 586.32" repl = "NN" inputs = [{"pattern": pattern, "string": string, "repl": repl}] - operator = AssignExpression( + operator = ExecuteExpression( expression="re.sub(pattern, repl, string)", imports_list=["re"], to_field="c", From 9518248266b960570e417676046910c151b0935b Mon Sep 17 00:00:00 2001 From: dafnapension Date: Sun, 4 Feb 2024 17:34:17 +0200 Subject: [PATCH 6/7] refresh operators before the second invocation of StandardRecipe Signed-off-by: dafnapension --- src/unitxt/test_utils/card.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/unitxt/test_utils/card.py b/src/unitxt/test_utils/card.py index 09afcf5ce..a1ac6fb5d 100644 --- a/src/unitxt/test_utils/card.py +++ 
b/src/unitxt/test_utils/card.py @@ -62,6 +62,12 @@ def load_examples_from_standard_recipe(card, template_card_index, debug, **kwarg if "loader_limit" not in kwargs: kwargs["loader_limit"] = 200 + # restore the operators on the card, so that they are fresh for second invocation of + # StandardRecipe as they are for the first one + if card.preprocess_steps is not None: + for step in card.preprocess_steps: + step.prepare() + recipe = StandardRecipe( card=card, template_card_index=template_card_index, **kwargs ) From f389889d4572df01da623ad7d15d1f9fbf5ac841 Mon Sep 17 00:00:00 2001 From: dafnapension Date: Sun, 4 Feb 2024 18:57:51 +0200 Subject: [PATCH 7/7] trying flake8 Signed-off-by: dafnapension --- src/unitxt/test_utils/card.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/unitxt/test_utils/card.py b/src/unitxt/test_utils/card.py index a1ac6fb5d..61b19c061 100644 --- a/src/unitxt/test_utils/card.py +++ b/src/unitxt/test_utils/card.py @@ -62,12 +62,6 @@ def load_examples_from_standard_recipe(card, template_card_index, debug, **kwarg if "loader_limit" not in kwargs: kwargs["loader_limit"] = 200 - # restore the operators on the card, so that they are fresh for second invocation of - # StandardRecipe as they are for the first one - if card.preprocess_steps is not None: - for step in card.preprocess_steps: - step.prepare() - recipe = StandardRecipe( card=card, template_card_index=template_card_index, **kwargs ) @@ -146,6 +140,7 @@ def print_recipe_output( return examples +# flake8: noqa: C901 def test_with_eval( card, debug=False, @@ -156,12 +151,22 @@ def test_with_eval( ): if type(card.templates) is TemplatesDict: for template_card_index in card.templates.keys(): + # restore the operators on the card, so that they are fresh for second invocation of + # StandardRecipe as they are for the first one + if card.preprocess_steps is not None: + for step in card.preprocess_steps: + step.prepare() examples = 
load_examples_from_standard_recipe( card, template_card_index=template_card_index, debug=debug, **kwargs ) else: num_templates = len(card.templates) for template_card_index in range(0, num_templates): + # restore the operators on the card, so that they are fresh for second invocation of + # StandardRecipe as they are for the first one + if card.preprocess_steps is not None: + for step in card.preprocess_steps: + step.prepare() examples = load_examples_from_standard_recipe( card, template_card_index=template_card_index, debug=debug, **kwargs )