From 3c6c458a21cd43547f90fe6155a695eab7f727fb Mon Sep 17 00:00:00 2001 From: dafnapension Date: Wed, 31 Jan 2024 21:38:51 +0200 Subject: [PATCH] allow imports list for executequery and filterbyquery Signed-off-by: dafnapension --- src/unitxt/operators.py | 17 +++++++++++++++-- tests/test_operators.py | 17 +++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index cf1b7718f8..a2e1eb8809 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -41,6 +41,7 @@ from collections import Counter from copy import deepcopy from dataclasses import field +from importlib import import_module from itertools import zip_longest from random import Random from typing import ( @@ -1222,6 +1223,7 @@ class FilterByQuery(SingleStreamOperator): Args: query (str): a condition over fields of the instance, to be processed by python's eval() + imports_list (List(str)): names of imports needed for the eval of the query (e.g. 're', 'json') error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True. Examples: @@ -1233,12 +1235,17 @@ class FilterByQuery(SingleStreamOperator): """ query: str + imports_list: List[str] = [] error_on_filtered_all: bool = True + def prepare(self): + self.globals = {name: import_module(name) for name in self.imports_list} + super().prepare() + def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator: yielded = False for instance in stream: - if eval(self.query, None, instance): + if eval(self.query, self.globals, instance): yielded = True yield instance @@ -1256,6 +1263,7 @@ class ExecuteQuery(StreamInstanceOperator): Args: query (str): an expression to be evaluated over the fields of the instance to_field (str): the field where the result is to be stored into + imports_list (List(str)): names of imports needed for the eval of the query (e.g. 're', 'json') Examples: When instance {"a": 2, "b": 3} is process-ed by operator @@ -1270,11 +1278,16 @@ class ExecuteQuery(StreamInstanceOperator): query: str to_field: str + imports_list: List[str] = [] + + def prepare(self): + self.globals = {name: import_module(name) for name in self.imports_list} + super().prepare() def process( self, instance: Dict[str, Any], stream_name: Optional[str] = None ) -> Dict[str, Any]: - instance[self.to_field] = eval(self.query, None, instance) + instance[self.to_field] = eval(self.query, self.globals, instance) return instance diff --git a/tests/test_operators.py b/tests/test_operators.py index bf24a0118d..b069704541 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -499,6 +499,23 @@ def test_execute_query(self): str(ve.exception), ) + inputs = [{"json_string": '{"A":"a_value", "B":"b_value"}'}] + operator = ExecuteQuery( + query='json.loads(json_string)["A"]', imports_list=["json"], to_field="c" + ) + self.assertEqual("a_value", operator.process(inputs[0])["c"]) + + pattern = "[0-9]+" + string = "Account Number - 12345, Amount - 586.32" + repl = "NN" + inputs = [{"pattern": pattern, "string": string, "repl": repl}] + operator = ExecuteQuery( + query="re.sub(pattern, repl, string)", imports_list=["re"], to_field="c" + ) + self.assertEqual( + "Account Number - NN, Amount - NN.NN", operator.process(inputs[0])["c"] + ) + def test_intersect(self): inputs = [ {"label": ["a", "b"]},