Skip to content

Commit

Permalink
allow imports list for executequery and filterbyquery
Browse files Browse the repository at this point in the history
Signed-off-by: dafnapension <dafnashein@yahoo.com>
  • Loading branch information
dafnapension committed Jan 31, 2024
1 parent e15f4ad commit 3c6c458
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/unitxt/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from collections import Counter
from copy import deepcopy
from dataclasses import field
from importlib import import_module
from itertools import zip_longest
from random import Random
from typing import (
Expand Down Expand Up @@ -1222,6 +1223,7 @@ class FilterByQuery(SingleStreamOperator):
Args:
query (str): a condition over fields of the instance, to be processed by python's eval()
imports_list (List[str]): names of modules to import so that the query can reference them when eval-ed (e.g. 're', 'json')
error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.
Examples:
Expand All @@ -1233,12 +1235,17 @@ class FilterByQuery(SingleStreamOperator):
"""

query: str
imports_list: List[str] = []
error_on_filtered_all: bool = True

def prepare(self):
# Import every module named in imports_list once, up front, and keep the
# result as a {module name: module object} dict. process() hands this dict
# to eval() as the globals of the query, so a query can refer to e.g. `re`
# or `json` by name. import_module raises ImportError for an unknown name.
self.globals = {name: import_module(name) for name in self.imports_list}
super().prepare()

def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
yielded = False
for instance in stream:
if eval(self.query, None, instance):
if eval(self.query, self.globals, instance):
yielded = True
yield instance

Expand All @@ -1256,6 +1263,7 @@ class ExecuteQuery(StreamInstanceOperator):
Args:
query (str): an expression to be evaluated over the fields of the instance
to_field (str): the field where the result is to be stored into
imports_list (List[str]): names of modules to import so that the query can reference them when eval-ed (e.g. 're', 'json')
Examples:
When instance {"a": 2, "b": 3} is process-ed by operator
Expand All @@ -1270,11 +1278,16 @@ class ExecuteQuery(StreamInstanceOperator):

query: str
to_field: str
imports_list: List[str] = []

def prepare(self):
# Import every module named in imports_list once, up front, and keep the
# result as a {module name: module object} dict. process() hands this dict
# to eval() as the globals of the query, so a query can refer to e.g. `re`
# or `json` by name. import_module raises ImportError for an unknown name.
self.globals = {name: import_module(name) for name in self.imports_list}
super().prepare()

def process(
self, instance: Dict[str, Any], stream_name: Optional[str] = None
) -> Dict[str, Any]:
# Evaluate self.query with the instance's fields as the local names, store
# the result under instance[self.to_field], and return the same (mutated)
# instance. stream_name is not used here.
# NOTE(review): this excerpt is a rendered diff. The eval with None globals
# looks like the pre-change line and the eval with self.globals its
# replacement; confirm against the actual file that only the latter remains.
# NOTE(security): eval() executes arbitrary Python; the query must come
# from a trusted source.
instance[self.to_field] = eval(self.query, None, instance)
instance[self.to_field] = eval(self.query, self.globals, instance)
return instance


Expand Down
17 changes: 17 additions & 0 deletions tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,23 @@ def test_execute_query(self):
str(ve.exception),
)

inputs = [{"json_string": '{"A":"a_value", "B":"b_value"}'}]
operator = ExecuteQuery(
query='json.loads(json_string)["A"]', imports_list=["json"], to_field="c"
)
self.assertEqual("a_value", operator.process(inputs[0])["c"])

pattern = "[0-9]+"
string = "Account Number - 12345, Amount - 586.32"
repl = "NN"
inputs = [{"pattern": pattern, "string": string, "repl": repl}]
operator = ExecuteQuery(
query="re.sub(pattern, repl, string)", imports_list=["re"], to_field="c"
)
self.assertEqual(
"Account Number - NN, Amount - NN.NN", operator.process(inputs[0])["c"]
)

def test_intersect(self):
inputs = [
{"label": ["a", "b"]},
Expand Down

0 comments on commit 3c6c458

Please sign in to comment.