Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow dicts in overwrite args of fetched artifact #598

Merged
merged 8 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/unitxt/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from .dataclass import AbstractField, Dataclass, Field, InternalField, fields
from .logging_utils import get_logger
from .parsing_utils import (
parse_key_equals_value_string_to_dict,
separate_inside_and_outside_square_brackets,
)
from .settings_utils import get_settings
Expand Down Expand Up @@ -313,8 +312,6 @@ def get_artifactory_name_and_args(
name: str, artifactories: Optional[List[Artifactory]] = None
):
name, args = separate_inside_and_outside_square_brackets(name)
if args is not None:
args = parse_key_equals_value_string_to_dict(args)

if artifactories is None:
artifactories = list(Artifactories())
Expand Down
295 changes: 198 additions & 97 deletions src/unitxt/parsing_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,191 @@
def separate_inside_and_outside_square_brackets(s):
# An artifact is fetched from the catalog through the name by which it was added to that catalog.
# When first instantiated as an object (of class Artifact, or an extension thereof), values are set to the artifact's
# class arguments.
# When added to the catalog, the type (name of immediate class) along with the values of its arguments, are specified
# in the catalog record (json file) of that artifact, associated with the name given to that record upon adding to
# the catalog.
# When the artifact is later fetched from the catalog by that name, the values of its arguments are fetched as well,
# and used to instantiate the artifact into the object it was when added to the catalog.
#
# An addendum to this fetching process is a specification of alternative values to replace, upon fetching, (some of)
# the fetched values of the artifact's arguments, and to use the thus updated values for the instantiation following
# this fetching.
#
# The alternative argument values, aka overwrites, are expressed as a stringified series of key-value pairs, enclosed
# in square brackets, appended to the name of the artifact.
#
# Overall, the formal definition of a query string, by which an artifact is fetched and instantiated, is as follows:
#
# query -> name | name [overwrites]
# overwrites -> assignment (, assignment)*
# assignment -> name = term
# term -> [ term (, term)* ] | { assignment (, assignment)* } | name_value | query
#
# A name_value starting at a given point in the query string is the longest substring of the query string
# that starts at that point and ends upon reaching the end of the query string, or one of these chars: [],:{}=
# Spaces are allowed inside a name_value; leading and trailing spaces are stripped.
# A name is a non-empty name_value that does not evaluate to an int, a float, or a boolean.
#
# The following code processes a given query string, verifies that it conforms with the above format syntax, throwing
# exceptions otherwise, and returns a pair of: (a) artifact name in the catalog, and (b) a (potentially empty)
# dictionary whose keys are names of (some of) the class arguments of the artifact, associated with the alternative
# values to set to these arguments, upon the instantiation of the artifact as a response to this query.
#
# Note: the code does not verify that a name of an artifact's argument is indeed a name of an argument of that
# artifact. The instantiation process that follows the parsing will verify that.
# Also, if an alternative value of an argument is specified, in turn, as a query with overwrites, the conforming
# of that query with the above syntax is done when processing the major query of the artifact, but the parsing of the
# overwrites (in the argument's query) is delayed to the stage when the recursive instantiation of the major artifact
# reaches the instantiation of that argument.
#
#
from typing import Any, Tuple


def consume_name_val(instring: str) -> Tuple[Any, str]:
    """Consume a name_value from the start of ``instring``.

    A name_value is the longest prefix of ``instring`` that contains none of
    the delimiter characters ``[],:{}=``. Surrounding whitespace is stripped
    from both the name_value and the remaining text.

    The stripped name_value is converted as follows:

    * ``"True"`` / ``"False"`` become the corresponding ``bool``;
    * an optionally ``-``-prefixed run of decimal digits becomes an ``int``;
    * an optionally ``-``-prefixed run of decimal digits containing exactly
      one ``.`` becomes a ``float``;
    * anything else (including the empty string) is returned as a ``str``.

    :param instring: the not-yet-parsed part of the query string.
    :return: a pair of the (possibly converted) name_value and the remaining
        unparsed text.
    """
    # Index of the first delimiter, or the end of the string if there is none.
    end = next(
        (index for index, char in enumerate(instring) if char in "[],:{}="),
        len(instring),
    )
    name_val = instring[:end].strip()
    instring = instring[end:].strip()

    if name_val == "True":
        return (True, instring)
    if name_val == "False":
        return (False, instring)

    negative = name_val.startswith("-")
    digits = name_val[1:] if negative else name_val
    sign = -1 if negative else 1

    # isdecimal() rather than isdigit(): isdigit() also accepts characters such
    # as superscript digits, which int() and float() refuse to convert.
    if digits.isdecimal():
        return (sign * int(digits), instring)
    if digits.count(".") == 1 and digits.replace(".", "", 1).isdecimal():
        return (sign * float(digits), instring)

    # Not a number: hand back the text as is, with any leading "-" intact.
    return (name_val, instring)


def consume_name(instring: str) -> Tuple[Any, str]:
    """Consume a name from the start of ``instring``.

    A name is a name_value that ``consume_name_val`` left as a non-empty
    string, i.e. did not convert to an int, a float or a bool.

    :param instring: the not-yet-parsed part of the query string.
    :return: a pair of the name and the remaining unparsed text.
    :raises ValueError: if the leading name_value is empty, a number or a boolean.
    """
    orig_instring = instring
    parsed, remainder = consume_name_val(instring)
    is_usable_name = (
        parsed is not None
        and not isinstance(parsed, (int, float, bool))
        and len(parsed) != 0
    )
    if not is_usable_name:
        raise ValueError(f"malformed name at the beginning of: {orig_instring}")
    return (parsed, remainder)


# flake8: noqa: C901
def consume_term(instring: str) -> Tuple[Any, str]:
    """Consume one term (the right-hand side of an assignment) from ``instring``.

    Depending on how ``instring`` starts, the term is parsed as:

    * ``[ term (, term)* ]`` -> a ``list`` of recursively parsed terms;
    * ``{ assignment (, assignment)* }`` -> a ``dict`` built by ``consume_overwrites``;
    * ``name[overwrites]`` -> the query text itself, returned unparsed as a
      ``str`` (only its syntax is checked here);
    * otherwise -> whatever ``consume_name_val`` returns (bool, int, float or str).

    NOTE(review): an empty list ``[]`` is not rejected here; the recursive call
    yields an empty name_value, so the result is ``[""]``.

    :param instring: the not-yet-parsed part of the query string, expected to
        have no leading whitespace.
    :return: a pair of the parsed term and the remaining unparsed text.
    :raises ValueError: if a list, dict or nested query is not properly closed,
        or if a nested assignment is malformed.
    """
    orig_instring = instring

    if instring.startswith("["):
        toret = []
        instring = instring[1:].strip()
        (term, instring) = consume_term(instring)
        toret.append(term)
        while instring.startswith(","):
            (term, instring) = consume_term(instring[1:].strip())
            toret.append(term)
        if not instring.startswith("]"):
            raise ValueError(f"malformed list in: {orig_instring}")
        instring = instring[1:].strip()
        return (toret, instring)

    if instring.startswith("{"):
        instring = instring[1:].strip()
        (items, instring) = consume_overwrites(instring, "}")
        if not instring.startswith("}"):
            raise ValueError(f"malformed dict in: {orig_instring}")
        instring = instring[1:].strip()
        return (items, instring)

    (name_val, instring) = consume_name_val(instring)
    instring = instring.strip()
    is_name = not (
        name_val is None
        or isinstance(name_val, (int, float, bool))
        or len(name_val) == 0
    )
    if is_name and instring.startswith("["):
        # term is a query with args
        (overwrites, instring) = consume_overwrites(instring[1:].strip(), "]")
        instring = instring.strip()
        if not instring.startswith("]"):
            raise ValueError(f"malformed query as a term in: {orig_instring}")
        instring = instring[1:].strip()
        # The argument's alternative value is specified by a query with overwrites.
        # The parsing of that query is delayed, to be synchronized with the recursive
        # loading of the artifact from the catalog, so the consumed text itself is
        # returned. Whitespace between the closing "]" and the next token is not
        # part of the query, hence the rstrip().
        toret = orig_instring[: len(orig_instring) - len(instring)].rstrip()
        return (toret, instring)

    return (name_val, instring)


def consume_assignment(instring: str) -> Tuple[Any, str]:
    """Consume a single ``name = term`` assignment from the start of ``instring``.

    :param instring: the not-yet-parsed part of the query string.
    :return: a pair of a one-entry dict ``{name: term}`` and the remaining
        unparsed text.
    :raises ValueError: if the name is malformed, the ``=`` is missing, or the
        assigned term is empty (empty string, list or dict).
    """
    orig_instring = instring
    key, rest = consume_name(instring)
    if not rest.startswith("="):
        raise ValueError(f"malformed assignment in: {orig_instring}")

    value, rest = consume_term(rest[1:].strip())
    # Numbers and booleans are always acceptable; sized values must be non-empty.
    has_value = value is not None and (
        isinstance(value, (int, float, bool)) or len(value) > 0
    )
    if not has_value:
        raise ValueError(f"malformed assigned value in: {orig_instring}")
    return ({key: value}, rest)


def consume_overwrites(instring: str, valid_follower: str) -> Tuple[Any, str]:
    """Consume a comma-separated series of assignments from ``instring``.

    :param instring: the not-yet-parsed part of the query string.
    :param valid_follower: the text that may legally follow the series; if
        ``instring`` already starts with it, the series is empty.
    :return: a pair of a dict merging all consumed assignments (a later
        assignment to the same name wins) and the remaining unparsed text.
    :raises ValueError: if any of the assignments is malformed.
    """
    if instring.startswith(valid_follower):
        return ({}, instring)

    overwrites, rest = consume_assignment(instring.strip())
    while rest.startswith(","):
        pair, rest = consume_assignment(rest[1:].strip())
        overwrites.update(pair)
    return (overwrites, rest)


def consume_query(instring: str) -> Tuple[Tuple[str, Any], str]:
    """Consume a query, ``name`` or ``name[overwrites]``, from the start of ``instring``.

    :param instring: the not-yet-parsed part of the query string.
    :return: a pair of ``(name, overwrites)`` and the remaining unparsed text.
        ``overwrites`` is ``None`` when the name is not followed by ``[``, and
        a (possibly empty) dict otherwise.
    :raises ValueError: if the name or any assignment is malformed, or if the
        overwrites are not closed by ``]``.
    """
    orig_instring = instring
    (name, instring) = consume_name(instring)
    instring = instring.strip()
    if not instring.startswith("["):
        return ((name, None), instring)

    # Strip after the opening bracket, as consume_term does for a nested query,
    # so that "name[ ]" is accepted as empty overwrites just like "name[]".
    (overwrites, instring) = consume_overwrites(instring[1:].strip(), "]")
    instring = instring.strip()
    if not instring.startswith("]"):
        raise ValueError(
            f"malformed end of query: overwrites not closed by ] in: {orig_instring}"
        )
    return ((name, overwrites), instring[1:].strip())


def parse_key_equals_value_string_to_dict(args: str) -> dict:
    """Parse a string of the form 'key1=value1,key2=value2,...' into a dictionary.

    The parsing is delegated to ``consume_overwrites``, so values follow the
    term syntax handled there. A blank string yields an empty dictionary.

    :param args: The string of comma-separated assignments to be parsed.
    :return: A dictionary with the keys and values extracted from ``args``.
    :raises ValueError: if an assignment is malformed, or if unparsed text
        remains after the last assignment.
    """
    stripped = args.strip()
    if not stripped:
        return {}

    args_dict, leftover = consume_overwrites(stripped, " ")
    if leftover.strip():
        raise ValueError(f"Illegal key-values structure in {args}")
    return args_dict


def separate_inside_and_outside_square_brackets(s: str) -> Tuple[str, any]:
"""Separates the content inside and outside the first level of square brackets in a string.

Allows text before the first bracket and nested brackets within the first level. Raises a ValueError for:
Expand All @@ -22,8 +209,6 @@
if start == -1 or end == -1 or start > end:
raise ValueError("Illegal structure: unmatched square brackets.")

outside = s[:start]
inside = s[start + 1 : end]
after = s[end + 1 :]

# Check for text after the closing bracket
Expand All @@ -32,99 +217,15 @@
"Illegal structure: text follows after the closing square bracket."
)

return outside, inside


def parse_key_equals_value_string_to_dict(query: str):
    """Parses a query string of the form 'key1=value1,key2=value2,...' into a dictionary.

    Values made of decimal digits become ints, values made of decimal digits
    with a single '.' become floats, bracketed values are parsed into lists by
    ``parse_list_string``, and anything else is kept as a stripped string.

    :param query: The query string to be parsed.
    :return: A dictionary with keys and values extracted from the query string, with spaces stripped from keys.
    :raises ValueError: if the query holds no assignment, or an assignment is not of the form key=value.
    """
    result = {}
    kvs = split_within_depth(query, dellimiter=",")
    if len(kvs) == 0:
        raise ValueError(
            f'Illegal query: "{query}" should contain at least one assignment of the form: key1=value1,key2=value2'
        )
    for kv in kvs:
        kv = kv.strip()
        key_val = split_within_depth(kv, dellimiter="=")
        if (
            len(key_val) != 2
            or len(key_val[0].strip()) == 0
            or len(key_val[1].strip()) == 0
        ):
            raise ValueError(
                f'Illegal query: "{query}" with wrong assignment "{kv}" should be of the form: key=value.'
            )
        key, val = key_val[0].strip(), key_val[1].strip()
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # (e.g. superscript digits) that int() and float() cannot convert.
        if val.isdecimal():
            result[key] = int(val)
        elif val.count(".") == 1 and val.replace(".", "", 1).isdecimal():
            result[key] = float(val)
        else:
            try:
                result[key] = parse_list_string(val)
            except ValueError:
                # Not a well-formed list: keep the raw text as the value.
                result[key] = val

    return result


def split_within_depth(
    s, dellimiter=",", depth=0, forbbiden_chars=None, level_start="[", level_end="]"
):
    """Split ``s`` on ``dellimiter``, ignoring delimiters nested deeper than ``depth``.

    The nesting level is increased by each ``level_start`` character and
    decreased by each ``level_end`` character; both are kept in the output.
    A delimiter splits only while the current nesting level is ``<= depth``.
    An empty part before a delimiter is kept, but an empty final part is dropped.

    :param s: the string to split.
    :param dellimiter: the single character to split on.
    :param depth: the deepest nesting level at which a delimiter still splits.
    :param forbbiden_chars: optional collection of characters that must not
        appear at a nesting level ``<= depth``.
    :param level_start: the character that opens a nesting level.
    :param level_end: the character that closes a nesting level.
    :return: the list of parts.
    :raises ValueError: if a forbidden character appears at a nesting level ``<= depth``.
    """
    result = []
    part = []  # characters of the part currently being accumulated
    current_depth = 0
    for char in s:
        if char == level_start:
            current_depth += 1
        elif char == level_end:
            current_depth -= 1
        elif current_depth <= depth:
            if forbbiden_chars is not None and char in forbbiden_chars:
                raise ValueError(
                    f'Illegal character "{char}" at bracket depth {current_depth} in: "{s}"'
                )
            if char == dellimiter:
                result.append("".join(part))
                part = []
                continue
        part.append(char)
    if part:
        result.append("".join(part))
    return result


def parse_list_string(s: str):
"""Parses a query string of the form 'val1,val2,...' into a list."""
start = s.find("[")
end = s.rfind("]")

# Handle no brackets
if start == -1 and end == -1:
return s

# Validate brackets
if start == -1 or end == -1 or start > end:
raise ValueError("Illegal structure: unmatched square brackets.")

before = s[:start].strip()
inside = s[start + 1 : end].strip()
after = s[end + 1 :].strip()

# Check for text after the closing bracket
if len(before) != 0 or len(after) != 0:
instring = s.strip()
orig_instring = instring
if "[" not in instring:
# no alternative values to artifact: consider the whole input string as an artifact name
return (instring, None)

Check warning on line 224 in src/unitxt/parsing_utils.py

View check run for this annotation

Codecov / codecov/patch

src/unitxt/parsing_utils.py#L224

Added line #L224 was not covered by tests
# parse to identify artifact name and alternative values to artifact's arguments
(query, instring) = consume_query(instring)
if len(instring) > 0:
raise ValueError(
"Illegal structure: text follows before or after the closing square bracket."
f"malformed end of query: excessive text following the ] that closes the overwrites in: {orig_instring}"
)
splitted = split_within_depth(inside.strip(), dellimiter=",", forbbiden_chars=["="])
return [s.strip() for s in splitted]
return query
42 changes: 42 additions & 0 deletions tests/library/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from src.unitxt.dataclass import UnexpectedArgumentError
from src.unitxt.logging_utils import get_logger
from src.unitxt.operator import SequentialOperator
from src.unitxt.operators import AddFields, RenameFields
from src.unitxt.processors import StringOrNotString
from src.unitxt.test_utils.catalog import temp_catalog
from tests.utils import UnitxtTestCase
Expand Down Expand Up @@ -70,3 +71,44 @@ def test_artifact_loading_with_overwrite_args_list(self):
)
artifact, _ = fetch_artifact(artifact_identifier)
self.assertEqual(artifact.metrics, ["metrics.rouge", "metrics.accuracy"])

def test_artifact_loading_with_overwrite_args_dict(self):
    # A dict-valued overwrite (holding a nested list) given in the query string
    # must replace the catalog-stored value of the ``fields`` argument.
    with temp_catalog() as catalog_path:
        add_to_catalog(
            AddFields(
                fields={
                    "classes": ["war", "peace"],
                    "text_type": "text",
                    "type_of_class": "topic",
                }
            ),
            "addfields.for.test.dict",
            catalog_path=catalog_path,
        )
        add_to_catalog(
            RenameFields(field_to_field={"label_text": "label"}),
            "renamefields.for.test.dict",
            catalog_path=catalog_path,
        )

        # without overwrite: the fields come back exactly as stored
        stored = get_from_catalog(
            "addfields.for.test.dict",
            catalog_path=catalog_path,
        )
        expected_stored = {
            "classes": ["war", "peace"],
            "text_type": "text",
            "type_of_class": "topic",
        }
        self.assertDictEqual(expected_stored, stored.fields)

        # with overwrite
        query_with_overwrite = "addfields.for.test.dict[fields={classes=[war_test, peace_test],text_type= text_test, type_of_class= topic_test}]"
        overwritten = get_from_catalog(
            query_with_overwrite,
            catalog_path=catalog_path,
        )
        expected_overwritten = {
            "classes": ["war_test", "peace_test"],
            "text_type": "text_test",
            "type_of_class": "topic_test",
        }
        self.assertDictEqual(expected_overwritten, overwritten.fields)
Loading
Loading