From 9550b32d6482fddbb88b9ecce962340a47b113bb Mon Sep 17 00:00:00 2001 From: dafnapension Date: Mon, 26 Feb 2024 13:02:57 +0200 Subject: [PATCH] changed query syntax of dicts: from ':' to '=' and simplified some Signed-off-by: dafnapension --- src/unitxt/artifact.py | 3 - src/unitxt/parsing_utils.py | 246 ++++++++++++++++++---------- tests/library/test_artifact.py | 23 ++- tests/library/test_parsing_utils.py | 30 +--- tests/library/test_query.py | 6 +- 5 files changed, 185 insertions(+), 123 deletions(-) diff --git a/src/unitxt/artifact.py b/src/unitxt/artifact.py index cdba2e5b4c..fa93af4217 100644 --- a/src/unitxt/artifact.py +++ b/src/unitxt/artifact.py @@ -11,7 +11,6 @@ from .dataclass import AbstractField, Dataclass, Field, InternalField, fields from .logging_utils import get_logger from .parsing_utils import ( - parse_key_equals_value_string_to_dict, separate_inside_and_outside_square_brackets, ) from .settings_utils import get_settings @@ -313,8 +312,6 @@ def get_artifactory_name_and_args( name: str, artifactories: Optional[List[Artifactory]] = None ): name, args = separate_inside_and_outside_square_brackets(name) - if args is not None: - args = parse_key_equals_value_string_to_dict(args) if artifactories is None: artifactories = list(Artifactories()) diff --git a/src/unitxt/parsing_utils.py b/src/unitxt/parsing_utils.py index d0709a8134..753386322b 100644 --- a/src/unitxt/parsing_utils.py +++ b/src/unitxt/parsing_utils.py @@ -1,50 +1,48 @@ -def separate_inside_and_outside_square_brackets(s): - """Separates the content inside and outside the first level of square brackets in a string. - - Allows text before the first bracket and nested brackets within the first level. Raises a ValueError for: - - Text following the closing bracket of the first bracket pair - - Unmatched brackets - - Multiple bracket pairs at the same level - - :param s: The input string to be parsed. 
- :return: A tuple (outside, inside) where 'outside' is the content outside the first level of square brackets, - and 'inside' is the content inside the first level of square brackets. If there are no brackets, - 'inside' will be None. - """ - start = s.find("[") - end = s.rfind("]") - - # Handle no brackets - if start == -1 and end == -1: - return s, None - - # Validate brackets - if start == -1 or end == -1 or start > end: - raise ValueError("Illegal structure: unmatched square brackets.") - - outside = s[:start] - inside = s[start + 1 : end] - after = s[end + 1 :] - - # Check for text after the closing bracket - if len(after.strip()) != 0: - raise ValueError( - "Illegal structure: text follows after the closing square bracket." - ) - - return outside, inside - - -# Formal definition of query: -# query -> assignment(=) (, assignment(=))* -# assignment(delimeter) -> name_value delimeter term -# term -> name_value | name_value[query] | [ term (, term)* ] | { assignment(:) (, assignment(:))* } +# An artifact is fetched from the catalog through the name by which it was added to that catalog. +# When first instantiated as an object (of class Artifact, or an extension thereof), values are set to the artifact's +# class arguments. +# When added to the catalog, the type (name of immediate class) along with the values of its arguments, are specified +# in the catalog record (json file) of that artifact, associated with the name given to that record upon adding to +# the catalog. +# When the artifact is later fetched from the catalog by that name, the values of its arguments are fetched as well, +# and used to instantiate the artifact into the object it was when added to the catalog. +# +# An addendum to this fetching process is a specification of alternative values to replace, upon fetching, (some of) +# the fetched values of the artifact's arguments, and to use the thus updated values for the instantiation following +# this fetching. 
+# +# The alternative argument values, aka overwrites, are expressed as a string-ed series of key-value pairs, enclosed +# in square brackets, appended to the artifact name. # -# a boolean parameter, return_dict, maintains a nested query: a query (potentially, indirectly) within another query, -# in the form of a string, so that the process of parsing cohere with the recursive process of fetching the artifact. +# Overall, the formal definition of a query string, by which an artifact is fetched and instantiated, is as follows: +# +# query -> name | name [overwrites] +# overwrites -> assignment (, assignment)* +# assignment -> name = term +# term -> [ term (, term)* ] | { assignment (, assignment)* } | name_value | query +# +# name_value starting at a given point in the query string is the longest substring of the query string, +# that starts at that point, and ends upon reaching the end of the query string, or one of these chars: [],:{}= +# spaces are allowed. +# name is a name_value that is not evaluated to an int, or float, or boolean. +# +# The following code processes a given query string, verifies that it conforms with the above format syntax, throwing +# exceptions otherwise, and returns a pair of: (a) artifact name in the catalog, and (b) a (potentially empty) +# dictionary whose keys are names of (some of) the class arguments of the artifact, associated with the alternative +# values to set to these arguments, upon the instantiation of the artifact as a response to this query. +# +# Note: the code does not verify that a name of an artifact's argument is indeed a name of an argument of that +# artifact. The instantiation process that follows the parsing will verify that. 
+# Also, if an alternative value of an argument is specified, in turn, as a query with overwrites, the conforming +# of that query with the above syntax is done when processing the major query of the artifact, but the parsing of the +# overwrites (in the argument's query) is delayed to the stage when the recursive instantiation of the major artifact +# reaches the instantiation of that argument. +# +# +from typing import Any, Tuple -def consume_name_val(instring: str) -> tuple: +def consume_name_val(instring: str) -> Tuple[Any, str]: name_val = "" for char in instring: if char in "[],:{}=": @@ -72,86 +70,104 @@ def consume_name_val(instring: str) -> tuple: return (name_val, instring) +def consume_name(instring: str) -> Tuple[Any, str]: + orig_instring = instring + (name_val, instring) = consume_name_val(instring) + if ( + name_val is None + or isinstance(name_val, (int, float, bool)) + or len(name_val) == 0 + ): + raise ValueError(f"malformed name at the beginning of: {orig_instring}") + return (name_val, instring) + + # flake8: noqa: C901 -def consume_term(instring: str, return_dict: bool) -> tuple: +def consume_term(instring: str) -> Tuple[Any, str]: orig_instring = instring if instring.startswith("["): toret = [] instring = instring[1:].strip() - (term, instring) = consume_term(instring, return_dict) + (term, instring) = consume_term(instring) toret.append(term) while instring.startswith(","): - (term, instring) = consume_term(instring[1:].strip(), return_dict) + (term, instring) = consume_term(instring[1:].strip()) toret.append(term) if not instring.startswith("]"): raise ValueError(f"malformed list in: {orig_instring}") instring = instring[1:].strip() - if not return_dict: - toret = orig_instring[: len(orig_instring) - len(instring)] return (toret, instring) if instring.startswith("{"): instring = instring[1:].strip() - (assignment, instring) = consume_assignment(instring, return_dict, ":") - toret = assignment - while instring.startswith(","): - (assignment, 
instring) = consume_assignment( - instring[1:].strip(), return_dict, ":" - ) - if return_dict: - toret = {**toret, **assignment} + (items, instring) = consume_overwrites(instring, "}") if not instring.startswith("}"): raise ValueError(f"malformed dict in: {orig_instring}") instring = instring[1:].strip() - if not return_dict: - toret = orig_instring[: len(orig_instring) - len(instring)] - return (toret, instring) + return (items, instring) (name_val, instring) = consume_name_val(instring) - if instring.startswith("["): - (quey, instring) = consume_query(instring[1:].strip(), False) + instring = instring.strip() + if not ( + name_val is None + or isinstance(name_val, (int, float, bool)) + or len(name_val) == 0 + ) and instring.startswith("["): + # term is a query with args + (overwrites, instring) = consume_overwrites(instring[1:].strip(), "]") + instring = instring.strip() if not instring.startswith("]"): - raise ValueError(f"malformed query in: {orig_instring}") + raise ValueError(f"malformed query as a term in: {orig_instring}") instring = instring[1:].strip() toret = orig_instring[: len(orig_instring) - len(instring)] + # argument's alternative value specified by query with overwrites. 
+ # the parsing of that query is delayed, to be synchronized with the recursive loading + # of the artifact from the catalog return (toret, instring) + return (name_val, instring) -def consume_assignment(instring: str, return_dict: bool, delimeter: str) -> tuple: +def consume_assignment(instring: str) -> Tuple[Any, str]: orig_instring = instring - (name_val, instring) = consume_name_val(instring) - if ( - name_val is None - or isinstance(name_val, (int, float, bool)) - or len(name_val) == 0 - ): - raise ValueError(f"malformed key in assignment that starts: {orig_instring}") - if not instring.startswith(delimeter): + (name, instring) = consume_name(instring) + + if not instring.startswith("="): raise ValueError(f"malformed assignment in: {orig_instring}") - (term, instring) = consume_term(instring[1:].strip(), return_dict) + (term, instring) = consume_term(instring[1:].strip()) if (term is None) or not (isinstance(term, (int, float, bool)) or len(term) > 0): - raise ValueError(f"malformed assignment in: {orig_instring}") - if return_dict: - return ({name_val: term}, instring) - toret = orig_instring[: len(orig_instring) - len(instring)] - return (toret, instring) + raise ValueError(f"malformed assigned value in: {orig_instring}") + return ({name: term}, instring) -def consume_query(instring: str, return_dict: bool) -> tuple: - orig_instring = instring - (toret, instring) = consume_assignment(instring.strip(), return_dict, "=") +def consume_overwrites(instring: str, valid_follower: str) -> Tuple[Any, str]: + if instring.startswith(valid_follower): + return ({}, instring) + (toret, instring) = consume_assignment(instring.strip()) while instring.startswith(","): instring = instring[1:].strip() - (assignment, instring) = consume_assignment(instring.strip(), return_dict, "=") - if return_dict: - toret = {**toret, **assignment} - else: - toret = orig_instring[: len(orig_instring) - len(instring)] + (assignment, instring) = consume_assignment(instring.strip()) + toret = 
{**toret, **assignment} return (toret, instring) -def parse_key_equals_value_string_to_dict(query: str) -> dict: +def consume_query(instring: str) -> Tuple[Tuple[str, any], str]: + orig_instring = instring + (name, instring) = consume_name(instring) + instring = instring.strip() + if len(instring) == 0 or not instring.startswith("["): + return ((name, None), instring) + + (overwrites, instring) = consume_overwrites(instring[1:], "]") + instring = instring.strip() + if len(instring) == 0 or not instring.startswith("]"): + raise ValueError( + f"malformed end of query: overwrites not closed by ] in: {orig_instring}" + ) + return ((name, overwrites), instring[1:].strip()) + + +def parse_key_equals_value_string_to_dict(args: str) -> dict: """Parses a query string of the form 'key1=value1,key2=value2,...' into a dictionary. The function converts numeric values into integers or floats as appropriate, and raises an @@ -160,8 +176,56 @@ def parse_key_equals_value_string_to_dict(query: str) -> dict: :param query: The query string to be parsed. :return: A dictionary with keys and values extracted from the query string, with spaces stripped from keys. """ - instring = query - qu, instring = consume_query(instring, True) + instring = args.strip() + if len(instring) == 0: + return {} + args_dict, instring = consume_overwrites(instring, " ") if len(instring.strip()) > 0: - raise ValueError(f"Illegal query structure in {query}") - return qu + raise ValueError(f"Illegal key-values structure in {args}") + return args_dict + + +def separate_inside_and_outside_square_brackets(s: str) -> Tuple[str, any]: + """Separates the content inside and outside the first level of square brackets in a string. + + Allows text before the first bracket and nested brackets within the first level. Raises a ValueError for: + - Text following the closing bracket of the first bracket pair + - Unmatched brackets + - Multiple bracket pairs at the same level + + :param s: The input string to be parsed. 
+ :return: A tuple (outside, inside) where 'outside' is the content outside the first level of square brackets, + and 'inside' is the content inside the first level of square brackets. If there are no brackets, + 'inside' will be None. + """ + start = s.find("[") + end = s.rfind("]") + + # Handle no brackets + if start == -1 and end == -1: + return s, None + + # Validate brackets + if start == -1 or end == -1 or start > end: + raise ValueError("Illegal structure: unmatched square brackets.") + + after = s[end + 1 :] + + # Check for text after the closing bracket + if len(after.strip()) != 0: + raise ValueError( + "Illegal structure: text follows after the closing square bracket." + ) + + instring = s.strip() + orig_instring = instring + if "[" not in instring: + # no alternative values to artifact: consider the whole input string as an artifact name + return (instring, None) + # parse to identify artifact name and alternative values to artifact's arguments + (query, instring) = consume_query(instring) + if len(instring) > 0: + raise ValueError( + f"malformed end of query: excessive text following the ] that closes the overwrites in: {orig_instring}" + ) + return query diff --git a/tests/library/test_artifact.py b/tests/library/test_artifact.py index a69d0c4b9b..8c9f0f91d9 100644 --- a/tests/library/test_artifact.py +++ b/tests/library/test_artifact.py @@ -6,7 +6,7 @@ from src.unitxt.dataclass import UnexpectedArgumentError from src.unitxt.logging_utils import get_logger from src.unitxt.operator import SequentialOperator -from src.unitxt.operators import AddFields +from src.unitxt.operators import AddFields, RenameFields from src.unitxt.processors import StringOrNotString from src.unitxt.test_utils.catalog import temp_catalog from tests.utils import UnitxtTestCase @@ -82,11 +82,28 @@ def test_artifact_loading_with_overwrite_args_dict(self): "type_of_class": "topic", } ), - "test.for.dict", + "addfields.for.test.dict", catalog_path=catalog_path, ) + add_to_catalog( + 
RenameFields(field_to_field={"label_text": "label"}), + "renamefields.for.test.dict", + catalog_path=catalog_path, + ) + artifact = get_from_catalog( + "addfields.for.test.dict", + catalog_path=catalog_path, + ) + expected = { + "classes": ["war", "peace"], + "text_type": "text", + "type_of_class": "topic", + } + self.assertDictEqual(expected, artifact.fields) + + # with overwrite artifact = get_from_catalog( - "test.for.dict[fields={classes:[war_test, peace_test],text_type: text_test, type_of_class: topic_test}]", + "addfields.for.test.dict[fields={classes=[war_test, peace_test],text_type= text_test, type_of_class= topic_test}]", catalog_path=catalog_path, ) expected = { diff --git a/tests/library/test_parsing_utils.py b/tests/library/test_parsing_utils.py index d23a697867..919080b4ed 100644 --- a/tests/library/test_parsing_utils.py +++ b/tests/library/test_parsing_utils.py @@ -59,12 +59,12 @@ def test_parse_key_equals_value_string_to_dict_lists(self): expected = {"year": 2020, "score": [9.5, "nine", "and", "a half"], "count": 10} self.assertEqual(expected, parse_key_equals_value_string_to_dict(query)) - query = "year=2020,score=[9.5, artifactname[init=[nine, and, a, half], anotherinner=true]],count=10" + query = "year=2020,score=[9.5, artifactname[init=[nine, and, a, half], anotherinner=True]],count=10" expected = { "year": 2020, "score": [ 9.5, - "artifactname[init=[nine, and, a, half], anotherinner=true]", + "artifactname[init=[nine, and, a, half], anotherinner=True]", ], "count": 10, } @@ -96,37 +96,21 @@ def test_base_structure(self): def test_valid_structure(self): self.assertEqual( - separate_inside_and_outside_square_brackets("before[inside]"), - ("before", "inside"), + separate_inside_and_outside_square_brackets("before[inside=2]"), + ("before", {"inside": 2}), ) def test_valid_nested_structure(self): self.assertEqual( + ("before", {"inside_before": "inside_inside[iii=vvv]"}), separate_inside_and_outside_square_brackets( - 
"before[inside_before[inside_inside]]" + "before[inside_before=inside_inside[iii=vvv]]" ), - ("before", "inside_before[inside_inside]"), - ) - - def test_valid_nested_structure_with_broken_structre(self): - self.assertEqual( - separate_inside_and_outside_square_brackets( - "before[inside_before[inside_inside]" - ), - ("before", "inside_before[inside_inside"), - ) - - def test_valid_nested_structure_with_broken_structre_inside(self): - self.assertEqual( - separate_inside_and_outside_square_brackets( - "before[inside_a]between[inside_b]" - ), - ("before", "inside_a]between[inside_b"), ) def test_valid_empty_inside(self): self.assertEqual( - separate_inside_and_outside_square_brackets("before[]"), ("before", "") + ("before", {}), separate_inside_and_outside_square_brackets("before[]") ) def test_illegal_text_following_brackets(self): diff --git a/tests/library/test_query.py b/tests/library/test_query.py index 940ae806ae..3c1f3e88b9 100644 --- a/tests/library/test_query.py +++ b/tests/library/test_query.py @@ -16,9 +16,9 @@ def test_query_works(self): } self.assertDictEqual(parsed, target) - def test_empty_query_fail(self): - with self.assertRaises(ValueError): - parse("") + def test_empty_query_is_ok_but_fruitless_fail(self): + self.assertEqual({}, parse("")) + with self.assertRaises(ValueError): parse(",")