Skip to content

Commit

Permalink
changed query syntax of dicts: from ':' to '=' and simplified some
Browse files Browse the repository at this point in the history
Signed-off-by: dafnapension <dafnashein@yahoo.com>
  • Loading branch information
dafnapension committed Feb 28, 2024
1 parent 4163992 commit 9550b32
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 123 deletions.
3 changes: 0 additions & 3 deletions src/unitxt/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from .dataclass import AbstractField, Dataclass, Field, InternalField, fields
from .logging_utils import get_logger
from .parsing_utils import (
parse_key_equals_value_string_to_dict,
separate_inside_and_outside_square_brackets,
)
from .settings_utils import get_settings
Expand Down Expand Up @@ -313,8 +312,6 @@ def get_artifactory_name_and_args(
name: str, artifactories: Optional[List[Artifactory]] = None
):
name, args = separate_inside_and_outside_square_brackets(name)
if args is not None:
args = parse_key_equals_value_string_to_dict(args)

if artifactories is None:
artifactories = list(Artifactories())
Expand Down
246 changes: 155 additions & 91 deletions src/unitxt/parsing_utils.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,48 @@
def separate_inside_and_outside_square_brackets(s):
    """Split a string into the text before and inside its outermost square brackets.

    Text may precede the opening bracket, and balanced nested brackets may
    appear inside the outermost pair.

    :param s: The input string to be parsed.
    :return: A tuple (outside, inside) where 'outside' is the text before the
        first '[' and 'inside' is the text between that '[' and the last ']'.
        If the string contains no brackets at all, returns (s, None).
    :raises ValueError: if the brackets are unmatched, if non-whitespace text
        follows the closing bracket, or if more than one bracket pair appears
        at the top level.
    """
    start = s.find("[")
    end = s.rfind("]")

    # No brackets at all: the whole string is the "outside" part.
    if start == -1 and end == -1:
        return s, None

    # Only one kind of bracket present, or the last ']' precedes the first '['.
    if start == -1 or end == -1 or start > end:
        raise ValueError("Illegal structure: unmatched square brackets.")

    # Only whitespace may follow the closing bracket.
    if s[end + 1 :].strip():
        raise ValueError(
            "Illegal structure: text follows after the closing square bracket."
        )

    # Comparing the first '[' with the last ']' is not enough: walk the span and
    # track nesting depth so that "a[b]c[d]" and "a[[b]" are rejected as documented.
    depth = 0
    for position in range(start, end + 1):
        char = s[position]
        if char == "[":
            depth += 1
        elif char == "]":
            depth -= 1
            if depth == 0 and position < end:
                # The first top-level pair closed before the last ']'.
                raise ValueError(
                    "Illegal structure: multiple square bracket pairs at the same level."
                )
    if depth != 0:
        raise ValueError("Illegal structure: unmatched square brackets.")

    return s[:start], s[start + 1 : end]


# Formal definition of query:
# query -> assignment(=) (, assignment(=))*
# assignment(delimeter) -> name_value delimeter term
# term -> name_value | name_value[query] | [ term (, term)* ] | { assignment(:) (, assignment(:))* }
# An artifact is fetched from the catalog through the name by which it was added to that catalog.
# When first instantiated as an object (of class Artifact, or an extension thereof), values are set to the artifact's
# class arguments.
# When added to the catalog, the type (name of immediate class) along with the values of its arguments, are specified
# in the catalog record (json file) of that artifact, associated with the name given to that record upon adding to
# the catalog.
# When the artifact is later fetched from the catalog by that name, the values of its arguments are fetched as well,
# and used to instantiate the artifact into the object it was when added to the catalog.
#
# An addendum to this fetching process is a specification of alternative values to replace, upon fetching, (some of)
# the fetched values of the artifact's arguments, and to use the thus updated values for the instantiation following
# this fetching.
#
# The alternative argument values, aka overwrites, are expressed as a stringified series of key-value pairs, enclosed
# in square brackets, appended to the name of the artifact.
#
# A boolean parameter, return_dict, controls whether a nested query (a query that appears, possibly indirectly, within
# another query) is kept in the form of a string, so that the process of parsing coheres with the recursive process of
# fetching the artifact.
# Overall, the formal definition of a query string, by which an artifact is fetched and instantiated, is as follows:
#
# query -> name | name [overwrites]
# overwrites -> assignment (, assignment)*
# assignment -> name = term
# term -> [ term (, term)* ] | { assignment (, assignment)* } | name_value | query
#
# name_value starting at a given point in the query string is the longest substring of the query string,
# that starts at that point, and ends upon reaching the end of the query string, or one of these chars: [],:{}=
# spaces are allowed.
# name is a name_value that is not evaluated to an int, or float, or boolean.
#
# The following code processes a given query string, verifies that it conforms with the above format syntax, throwing
# exceptions otherwise, and returns a pair of: (a) artifact name in the catalog, and (b) a (potentially empty)
# dictionary whose keys are names of (some of) the class arguments of the artifact, associated with the alternative
# values to set to these arguments, upon the instantiation of the artifact as a response to this query.
#
# Note: the code does not verify that a name of an artifact's argument is indeed a name of an argument of that
# artifact. The instantiation process that follows the parsing will verify that.
# Also, if an alternative value of an argument is itself specified as a query with overwrites, the conformance
# of that query with the above syntax is verified when processing the main query of the artifact, but the parsing of the
# overwrites (in the argument's query) is delayed to the stage when the recursive instantiation of the main artifact
# reaches the instantiation of that argument.
#
#
from typing import Any, Tuple


def consume_name_val(instring: str) -> tuple:
def consume_name_val(instring: str) -> Tuple[Any, str]:
name_val = ""
for char in instring:
if char in "[],:{}=":
Expand Down Expand Up @@ -72,86 +70,104 @@ def consume_name_val(instring: str) -> tuple:
return (name_val, instring)


def consume_name(instring: str) -> Tuple[Any, str]:
    """Consume a name from the start of ``instring``.

    Delegates to ``consume_name_val`` and then rejects results that cannot
    serve as a name: ``None``, an int/float/bool value, or an empty value.

    :param instring: The text to parse, starting at the expected name.
    :return: A tuple (name, rest) where 'rest' is the remaining text as
        returned by ``consume_name_val``.
    :raises ValueError: if no valid name starts ``instring``.
    """
    name_val, remainder = consume_name_val(instring)
    is_literal = isinstance(name_val, (int, float, bool))
    # len() is only reached for non-None, non-literal values.
    if name_val is None or is_literal or len(name_val) == 0:
        raise ValueError(f"malformed name at the beginning of: {instring}")
    return (name_val, remainder)


# flake8: noqa: C901
def consume_term(instring: str, return_dict: bool) -> tuple:
def consume_term(instring: str) -> Tuple[Any, str]:
orig_instring = instring
if instring.startswith("["):
toret = []
instring = instring[1:].strip()
(term, instring) = consume_term(instring, return_dict)
(term, instring) = consume_term(instring)
toret.append(term)
while instring.startswith(","):
(term, instring) = consume_term(instring[1:].strip(), return_dict)
(term, instring) = consume_term(instring[1:].strip())
toret.append(term)
if not instring.startswith("]"):
raise ValueError(f"malformed list in: {orig_instring}")

Check warning on line 97 in src/unitxt/parsing_utils.py

View check run for this annotation

Codecov / codecov/patch

src/unitxt/parsing_utils.py#L97

Added line #L97 was not covered by tests
instring = instring[1:].strip()
if not return_dict:
toret = orig_instring[: len(orig_instring) - len(instring)]
return (toret, instring)

if instring.startswith("{"):
instring = instring[1:].strip()
(assignment, instring) = consume_assignment(instring, return_dict, ":")
toret = assignment
while instring.startswith(","):
(assignment, instring) = consume_assignment(
instring[1:].strip(), return_dict, ":"
)
if return_dict:
toret = {**toret, **assignment}
(items, instring) = consume_overwrites(instring, "}")
if not instring.startswith("}"):
raise ValueError(f"malformed dict in: {orig_instring}")

Check warning on line 105 in src/unitxt/parsing_utils.py

View check run for this annotation

Codecov / codecov/patch

src/unitxt/parsing_utils.py#L105

Added line #L105 was not covered by tests
instring = instring[1:].strip()
if not return_dict:
toret = orig_instring[: len(orig_instring) - len(instring)]
return (toret, instring)
return (items, instring)

(name_val, instring) = consume_name_val(instring)
if instring.startswith("["):
(quey, instring) = consume_query(instring[1:].strip(), False)
instring = instring.strip()
if not (
name_val is None
or isinstance(name_val, (int, float, bool))
or len(name_val) == 0
) and instring.startswith("["):
# term is a query with args
(overwrites, instring) = consume_overwrites(instring[1:].strip(), "]")
instring = instring.strip()
if not instring.startswith("]"):
raise ValueError(f"malformed query in: {orig_instring}")
raise ValueError(f"malformed query as a term in: {orig_instring}")

Check warning on line 120 in src/unitxt/parsing_utils.py

View check run for this annotation

Codecov / codecov/patch

src/unitxt/parsing_utils.py#L120

Added line #L120 was not covered by tests
instring = instring[1:].strip()
toret = orig_instring[: len(orig_instring) - len(instring)]
# The argument's alternative value is specified by a query with overwrites.
# The parsing of that query is delayed, to be synchronized with the recursive loading
# of the artifact from the catalog.
return (toret, instring)

return (name_val, instring)


def consume_assignment(instring: str, return_dict: bool, delimeter: str) -> tuple:
def consume_assignment(instring: str) -> Tuple[Any, str]:
orig_instring = instring
(name_val, instring) = consume_name_val(instring)
if (
name_val is None
or isinstance(name_val, (int, float, bool))
or len(name_val) == 0
):
raise ValueError(f"malformed key in assignment that starts: {orig_instring}")
if not instring.startswith(delimeter):
(name, instring) = consume_name(instring)

if not instring.startswith("="):
raise ValueError(f"malformed assignment in: {orig_instring}")
(term, instring) = consume_term(instring[1:].strip(), return_dict)
(term, instring) = consume_term(instring[1:].strip())
if (term is None) or not (isinstance(term, (int, float, bool)) or len(term) > 0):
raise ValueError(f"malformed assignment in: {orig_instring}")
if return_dict:
return ({name_val: term}, instring)
toret = orig_instring[: len(orig_instring) - len(instring)]
return (toret, instring)
raise ValueError(f"malformed assigned value in: {orig_instring}")
return ({name: term}, instring)


def consume_query(instring: str, return_dict: bool) -> tuple:
orig_instring = instring
(toret, instring) = consume_assignment(instring.strip(), return_dict, "=")
def consume_overwrites(instring: str, valid_follower: str) -> Tuple[Any, str]:
if instring.startswith(valid_follower):
return ({}, instring)
(toret, instring) = consume_assignment(instring.strip())
while instring.startswith(","):
instring = instring[1:].strip()
(assignment, instring) = consume_assignment(instring.strip(), return_dict, "=")
if return_dict:
toret = {**toret, **assignment}
else:
toret = orig_instring[: len(orig_instring) - len(instring)]
(assignment, instring) = consume_assignment(instring.strip())
toret = {**toret, **assignment}
return (toret, instring)


def parse_key_equals_value_string_to_dict(query: str) -> dict:
def consume_query(instring: str) -> Tuple[Tuple[str, any], str]:
orig_instring = instring
(name, instring) = consume_name(instring)
instring = instring.strip()
if len(instring) == 0 or not instring.startswith("["):
return ((name, None), instring)

Check warning on line 159 in src/unitxt/parsing_utils.py

View check run for this annotation

Codecov / codecov/patch

src/unitxt/parsing_utils.py#L159

Added line #L159 was not covered by tests

(overwrites, instring) = consume_overwrites(instring[1:], "]")
instring = instring.strip()
if len(instring) == 0 or not instring.startswith("]"):
raise ValueError(

Check warning on line 164 in src/unitxt/parsing_utils.py

View check run for this annotation

Codecov / codecov/patch

src/unitxt/parsing_utils.py#L164

Added line #L164 was not covered by tests
f"malformed end of query: overwrites not closed by ] in: {orig_instring}"
)
return ((name, overwrites), instring[1:].strip())


def parse_key_equals_value_string_to_dict(args: str) -> dict:
"""Parses a query string of the form 'key1=value1,key2=value2,...' into a dictionary.
The function converts numeric values into integers or floats as appropriate, and raises an
Expand All @@ -160,8 +176,56 @@ def parse_key_equals_value_string_to_dict(query: str) -> dict:
:param query: The query string to be parsed.
:return: A dictionary with keys and values extracted from the query string, with spaces stripped from keys.
"""
instring = query
qu, instring = consume_query(instring, True)
instring = args.strip()
if len(instring) == 0:
return {}
args_dict, instring = consume_overwrites(instring, " ")
if len(instring.strip()) > 0:
raise ValueError(f"Illegal query structure in {query}")
return qu
raise ValueError(f"Illegal key-values structure in {args}")
return args_dict


def separate_inside_and_outside_square_brackets(s: str) -> Tuple[str, Any]:
    """Split a query string into an artifact name and its bracketed overwrites.

    A cheap structural check on the first '[' and the last ']' runs first;
    the string is then parsed by ``consume_query``.

    :param s: The input string to be parsed.
    :return: ``(s, None)`` (``s`` unchanged) when the string contains no square
        brackets. Otherwise the ``(name, overwrites)`` pair produced by
        ``consume_query`` for the stripped string. NOTE(review): overwrites
        looks like a dict of parsed assignments rather than the raw text
        between the brackets - confirm against ``consume_query``.
    :raises ValueError: if the brackets are unmatched, if non-whitespace text
        follows the last ']', or if text remains after ``consume_query`` has
        consumed the query. ``consume_query`` may raise for other malformed
        input.
    """
    start = s.find("[")
    end = s.rfind("]")

    # Handle no brackets: the whole input is the artifact name.
    if start == -1 and end == -1:
        return s, None

    # Validate brackets
    if start == -1 or end == -1 or start > end:
        raise ValueError("Illegal structure: unmatched square brackets.")

    # Check for text after the closing bracket
    if s[end + 1 :].strip():
        raise ValueError(
            "Illegal structure: text follows after the closing square bracket."
        )

    # From here on a '[' is guaranteed to be present, so no further
    # "no brackets" shortcut is needed before parsing.
    instring = s.strip()
    orig_instring = instring

    # parse to identify artifact name and alternative values to artifact's arguments
    (query, instring) = consume_query(instring)
    if len(instring) > 0:
        raise ValueError(
            f"malformed end of query: excessive text following the ] that closes the overwrites in: {orig_instring}"
        )
    return query
23 changes: 20 additions & 3 deletions tests/library/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from src.unitxt.dataclass import UnexpectedArgumentError
from src.unitxt.logging_utils import get_logger
from src.unitxt.operator import SequentialOperator
from src.unitxt.operators import AddFields
from src.unitxt.operators import AddFields, RenameFields
from src.unitxt.processors import StringOrNotString
from src.unitxt.test_utils.catalog import temp_catalog
from tests.utils import UnitxtTestCase
Expand Down Expand Up @@ -82,11 +82,28 @@ def test_artifact_loading_with_overwrite_args_dict(self):
"type_of_class": "topic",
}
),
"test.for.dict",
"addfields.for.test.dict",
catalog_path=catalog_path,
)
add_to_catalog(
RenameFields(field_to_field={"label_text": "label"}),
"renamefields.for.test.dict",
catalog_path=catalog_path,
)
artifact = get_from_catalog(
"addfields.for.test.dict",
catalog_path=catalog_path,
)
expected = {
"classes": ["war", "peace"],
"text_type": "text",
"type_of_class": "topic",
}
self.assertDictEqual(expected, artifact.fields)

# with overwrite
artifact = get_from_catalog(
"test.for.dict[fields={classes:[war_test, peace_test],text_type: text_test, type_of_class: topic_test}]",
"addfields.for.test.dict[fields={classes=[war_test, peace_test],text_type= text_test, type_of_class= topic_test}]",
catalog_path=catalog_path,
)
expected = {
Expand Down
Loading

0 comments on commit 9550b32

Please sign in to comment.