Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow imports list for executequery and filterbyquery and rename to ExecuteExpression and FilterByExpression #542

Merged
merged 7 commits into from
Feb 5, 2024
6 changes: 3 additions & 3 deletions prepare/cards/multidoc2dial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from src.unitxt.blocks import LoadHF, TaskCard
from src.unitxt.catalog import add_to_catalog
from src.unitxt.operators import ExecuteQuery, ListFieldValues, RenameFields
from src.unitxt.operators import ExecuteExpression, ListFieldValues, RenameFields
from src.unitxt.test_utils.card import test_card

card_abstractive = TaskCard(
Expand All @@ -11,7 +11,7 @@
use_query=True,
),
ListFieldValues(fields=["utterance"], to_field="answer"),
ExecuteQuery(query="question.split('[SEP]')[0]", to_field="question"),
ExecuteExpression(expression="question.split('[SEP]')[0]", to_field="question"),
],
task="tasks.qa.contextual.abstractive",
templates="templates.qa.contextual.all",
Expand All @@ -25,7 +25,7 @@
use_query=True,
),
ListFieldValues(fields=["relevant_context"], to_field="answer"),
ExecuteQuery(query="question.split('[SEP]')[0]", to_field="question"),
ExecuteExpression(expression="question.split('[SEP]')[0]", to_field="question"),
],
task="tasks.qa.contextual.extractive",
templates="templates.qa.contextual.all",
Expand Down
4 changes: 2 additions & 2 deletions src/unitxt/catalog/cards/multidoc2dial/abstractive.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
"to_field": "answer"
},
{
"type": "execute_query",
"query": "question.split('[SEP]')[0]",
"type": "execute_expression",
"expression": "question.split('[SEP]')[0]",
"to_field": "question"
}
],
Expand Down
4 changes: 2 additions & 2 deletions src/unitxt/catalog/cards/multidoc2dial/extractive.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
"to_field": "answer"
},
{
"type": "execute_query",
"query": "question.split('[SEP]')[0]",
"type": "execute_expression",
"expression": "question.split('[SEP]')[0]",
"to_field": "question"
}
],
Expand Down
60 changes: 43 additions & 17 deletions src/unitxt/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from collections import Counter
from copy import deepcopy
from dataclasses import field
from importlib import import_module
from itertools import zip_longest
from random import Random
from typing import (
Expand All @@ -58,7 +59,7 @@
import requests

from .artifact import Artifact, fetch_artifact
from .dataclass import NonPositionalField
from .dataclass import NonPositionalField, OptionalField
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
from .operator import (
MultiStream,
Expand Down Expand Up @@ -1094,7 +1095,7 @@ def get_artifact(cls, artifact_identifier: str) -> Artifact:
return cls.cache[artifact_identifier]


class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
class ApplyOperatorsField(StreamInstanceOperator):
"""Applies value operators to each instance in a stream based on specified fields.

Args:
Expand Down Expand Up @@ -1215,30 +1216,55 @@ def _is_required(self, instance: dict) -> bool:
return True


class FilterByQuery(SingleStreamOperator):
class ComputeExpressionMixin(Artifact):
    """Computes an expression expressed over fields of an instance.

    The expression is a string handed to Python's ``eval``; the instance dict
    serves as the local namespace, so field names can be used directly as
    variables. Modules named in ``imports_list`` are imported lazily, on the
    first call to ``compute_expression``, and exposed to the expression.

    Args:
        expression (str): the expression, in terms of names of fields of an instance
        imports_list (List[str]): list of names of imports needed for the evaluation
            of the expression (e.g. 're', 'json'). A dotted name such as 'os.path'
            also makes its top-level package ('os') available, mirroring what the
            statement ``import os.path`` does.
    """

    expression: str
    imports_list: List[str] = OptionalField(default_factory=list)

    def prepare(self):
        # can not do the imports here, because object does not pickle with imports
        self.globs = {}
        # set back to True so the next compute_expression call re-imports the modules
        self.to_import = True

    def compute_expression(self, instance: dict) -> Any:
        """Evaluate ``self.expression`` with the fields of ``instance`` as variables.

        Args:
            instance (dict): the instance whose fields the expression refers to.

        Returns:
            Any: whatever the expression evaluates to.

        Raises:
            NameError: propagated from ``eval`` when the expression mentions a name
                that is neither a field of ``instance`` nor an imported module.
        """
        if self.to_import:
            for module_name in self.imports_list:
                # keep the original binding under the full (possibly dotted) name
                self.globs[module_name] = import_module(module_name)
                # a dotted key can never be referenced as an identifier inside the
                # expression, so also bind the top-level package, like `import a.b`
                top_level_name = module_name.partition(".")[0]
                if top_level_name != module_name:
                    self.globs[top_level_name] = import_module(top_level_name)
            self.to_import = False

        # NOTE(security): eval executes arbitrary Python; the expression must come
        # from a trusted source (e.g. a reviewed card), never from untrusted input.
        return eval(self.expression, self.globs, instance)


class FilterByExpression(SingleStreamOperator, ComputeExpressionMixin):
"""Filters a stream, yielding only instances which fulfil a condition specified as a string to be python's eval-uated.

Raises an error if a field participating in the specified condition is missing from the instance

Args:
query (str): a condition over fields of the instance, to be processed by python's eval()
expression (str): a condition over fields of the instance, to be processed by python's eval()
imports_list (List[str]): names of imports needed for the eval of the query (e.g. 're', 'json')
error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

Examples:
FilterByQuery(query = "a > 4") will yield only instances where "a">4
FilterByQuery(query = "a <= 4 and b > 5") will yield only instances where the value of field "a" is not exceeding 4 and in field "b" -- greater than 5
FilterByQuery(query = "a in [4, 8]") will yield only instances where "a" is 4 or 8
FilterByQuery(query = "a not in [4, 8]") will yield only instances where "a" is neither 4 nor 8
FilterByExpression(expression = "a > 4") will yield only instances where "a">4
FilterByExpression(expression = "a <= 4 and b > 5") will yield only instances where the value of field "a" is not exceeding 4 and in field "b" -- greater than 5
FilterByExpression(expression = "a in [4, 8]") will yield only instances where "a" is 4 or 8
FilterByExpression(expression = "a not in [4, 8]") will yield only instances where "a" is neither 4 nor 8

"""

query: str
error_on_filtered_all: bool = True

def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
yielded = False
for instance in stream:
if eval(self.query, None, instance):
if self.compute_expression(instance):
yielded = True
yield instance

Expand All @@ -1248,33 +1274,33 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato
)


class ExecuteQuery(StreamInstanceOperator):
"""Compute an expression (query), expressed as a string to be eval-uated, over the instance's fields, and store the result in field to_field.
class ExecuteExpression(StreamInstanceOperator, ComputeExpressionMixin):
"""Compute an expression, specified as a string to be eval-uated, over the instance's fields, and store the result in field to_field.

Raises an error if a field mentioned in the query is missing from the instance.

Args:
query (str): an expression to be evaluated over the fields of the instance
expression (str): an expression to be evaluated over the fields of the instance
to_field (str): the field where the result is to be stored into
imports_list (List[str]): names of imports needed for the eval of the query (e.g. 're', 'json')

Examples:
When instance {"a": 2, "b": 3} is process-ed by operator
ExecuteQuery(query="a+b", to_field = "c")
ExecuteExpression(expression="a+b", to_field = "c")
the result is {"a": 2, "b": 3, "c": 5}

When instance {"a": "hello", "b": "world"} is process-ed by operator
ExecuteQuery(query = "a+' '+b", to_field = "c")
ExecuteExpression(expression = "a+' '+b", to_field = "c")
the result is {"a": "hello", "b": "world", "c": "hello world"}

"""

query: str
to_field: str

def process(
self, instance: Dict[str, Any], stream_name: Optional[str] = None
) -> Dict[str, Any]:
instance[self.to_field] = eval(self.query, None, instance)
instance[self.to_field] = self.compute_expression(instance)
return instance


Expand Down
11 changes: 11 additions & 0 deletions src/unitxt/test_utils/card.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def print_recipe_output(
return examples


# flake8: noqa: C901
def test_with_eval(
card,
debug=False,
Expand All @@ -150,12 +151,22 @@ def test_with_eval(
):
if type(card.templates) is TemplatesDict:
for template_card_index in card.templates.keys():
# restore the operators on the card, so that they are fresh for second invocation of
# StandardRecipe as they are for the first one
if card.preprocess_steps is not None:
for step in card.preprocess_steps:
step.prepare()
examples = load_examples_from_standard_recipe(
card, template_card_index=template_card_index, debug=debug, **kwargs
)
else:
num_templates = len(card.templates)
for template_card_index in range(0, num_templates):
# restore the operators on the card, so that they are fresh for second invocation of
# StandardRecipe as they are for the first one
if card.preprocess_steps is not None:
for step in card.preprocess_steps:
step.prepare()
examples = load_examples_from_standard_recipe(
card, template_card_index=template_card_index, debug=debug, **kwargs
)
Expand Down
59 changes: 40 additions & 19 deletions tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
DeterministicBalancer,
DivideAllFieldsBy,
EncodeLabels,
ExecuteQuery,
ExecuteExpression,
ExtractFieldValues,
ExtractMostCommonFieldValues,
FieldOperator,
FilterByCondition,
FilterByQuery,
FilterByExpression,
FlattenInstances,
FromIterables,
IndexOf,
Expand Down Expand Up @@ -277,7 +277,7 @@ def test_filter_by_values_with_required_values(self):
tester=self,
)
check_operator(
operator=FilterByQuery(query="a == 1 and b == 3"),
operator=FilterByExpression(expression="a == 1 and b == 3"),
inputs=inputs,
targets=targets,
tester=self,
Expand All @@ -291,7 +291,7 @@ def test_filter_by_values_with_required_values(self):
tester=self,
)
check_operator_exception(
operator=FilterByQuery(query="c == 5"),
operator=FilterByExpression(expression="c == 5"),
inputs=inputs,
exception_text="name 'c' is not defined",
tester=self,
Expand All @@ -311,7 +311,7 @@ def test_filter_by_condition_ne(self):
tester=self,
)
check_operator(
operator=FilterByQuery(query="a != 1 and b != 2"),
operator=FilterByExpression(expression="a != 1 and b != 2"),
inputs=inputs,
targets=targets,
tester=self,
Expand All @@ -331,7 +331,7 @@ def test_filter_by_condition_gt(self):
tester=self,
)
check_operator(
operator=FilterByQuery(query="a>1"),
operator=FilterByExpression(expression="a>1"),
inputs=inputs,
targets=targets,
tester=self,
Expand Down Expand Up @@ -359,7 +359,7 @@ def test_filter_by_condition_not_in(self):
tester=self,
)
check_operator(
operator=FilterByQuery(query="b not in [3, 4]"),
operator=FilterByExpression(expression="b not in [3, 4]"),
inputs=inputs,
targets=targets,
tester=self,
Expand All @@ -385,21 +385,21 @@ def test_filter_by_condition_not_in_multiple(self):
tester=self,
)
check_operator(
operator=FilterByQuery(
query="b not in [3, 4] and a not in [1]",
operator=FilterByExpression(
expression="b not in [3, 4] and a not in [1]",
error_on_filtered_all=False,
),
inputs=inputs,
targets=targets,
tester=self,
)
check_operator_exception(
operator=FilterByQuery(
query="b not in [3, 4] and a not in [1]",
operator=FilterByExpression(
expression="b not in [3, 4] and a not in [1]",
error_on_filtered_all=True,
),
inputs=inputs,
exception_text="FilterByQuery filtered out every instance in stream 'test'. If this is intended set error_on_filtered_all=False",
exception_text="FilterByExpression filtered out every instance in stream 'test'. If this is intended set error_on_filtered_all=False",
tester=self,
)

Expand All @@ -422,7 +422,7 @@ def test_filter_by_condition_in(self):
tester=self,
)
check_operator(
operator=FilterByQuery(query="b in [3, 4]"),
operator=FilterByExpression(expression="b in [3, 4]"),
inputs=inputs,
targets=targets,
tester=self,
Expand Down Expand Up @@ -452,7 +452,7 @@ def test_filter_by_condition_in(self):
)
with self.assertRaises(Exception) as ne:
check_operator(
operator=FilterByQuery(query="c in ['5']"),
operator=FilterByExpression(expression="c in ['5']"),
inputs=inputs,
targets=targets,
tester=self,
Expand All @@ -475,16 +475,16 @@ def test_filter_by_condition_error_when_the_entire_stream_is_filtered(self):
"FilterByCondition filtered out every instance in stream 'test'. If this is intended set error_on_filtered_all=False",
)

def test_execute_query(self):
def test_execute_expression(self):
inputs = [{"a": 2, "b": 3}]
operator = ExecuteQuery(query="a+b", to_field="c")
operator = ExecuteExpression(to_field="c", expression="a+b")
targets = [{"a": 2, "b": 3, "c": 5}]
check_operator(operator=operator, inputs=inputs, targets=targets, tester=self)
inputs = [{"a": "hello", "b": "world"}]
operator = ExecuteQuery(query="a+' '+b", to_field="c")
operator = ExecuteExpression(expression="a+' '+b", to_field="c")
targets = [{"a": "hello", "b": "world", "c": "hello world"}]
check_operator(operator=operator, inputs=inputs, targets=targets, tester=self)
operator = ExecuteQuery(query="f'{a} {b}'", to_field="c")
operator = ExecuteExpression(expression="f'{a} {b}'", to_field="c")
check_operator(operator=operator, inputs=inputs, targets=targets, tester=self)
with self.assertRaises(ValueError) as ve:
check_operator(
Expand All @@ -494,10 +494,31 @@ def test_execute_query(self):
tester=self,
)
self.assertEqual(
"Error processing instance '0' from stream 'test' in ExecuteQuery due to: name 'a' is not defined",
"Error processing instance '0' from stream 'test' in ExecuteExpression due to: name 'a' is not defined",
str(ve.exception),
)

inputs = [{"json_string": '{"A":"a_value", "B":"b_value"}'}]
operator = ExecuteExpression(
expression='json.loads(json_string)["A"]',
imports_list=["json"],
to_field="c",
)
self.assertEqual("a_value", operator.process(inputs[0])["c"])

pattern = "[0-9]+"
string = "Account Number - 12345, Amount - 586.32"
repl = "NN"
inputs = [{"pattern": pattern, "string": string, "repl": repl}]
operator = ExecuteExpression(
expression="re.sub(pattern, repl, string)",
imports_list=["re"],
to_field="c",
)
self.assertEqual(
"Account Number - NN, Amount - NN.NN", operator.process(inputs[0])["c"]
)

def test_intersect(self):
inputs = [
{"label": ["a", "b"]},
Expand Down
Loading