Skip to content

Commit

Permalink
Support date,time,date-time,uuid formats
Browse files Browse the repository at this point in the history
  • Loading branch information
aschwa authored and rlouf committed Jan 23, 2024
1 parent 8a0bafc commit fda4ce4
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 0 deletions.
26 changes: 26 additions & 0 deletions outlines/fsm/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@
"null": NULL,
}

# Regular-expression fragments for the values of the JSON Schema "format"
# keyword handled by this module. They describe the bare string content and
# contain no surrounding double quotes.

# Date and time joined by a literal "T": YYYY-MM-DDTHH:MM:SS, optionally
# followed by exactly three fractional-second digits and an optional "Z".
# The year may carry a leading "-" and more than four digits. UTC offsets
# such as "+01:00" are not part of the pattern.
DATE_TIME = r"(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]{3})?(Z)?"
# Calendar date YYYY-MM-DD: month 01-12, day 01-31. The day range is not
# checked against the month (e.g. "02-31" is accepted).
DATE = r"(?:\d{4})-(?:0[1-9]|1[0-2])-(?:0[1-9]|[1-2][0-9]|3[0-1])"
# Time of day HH:MM:SS (hours 00-23) with an optional trailing "Z".
# NOTE(review): inside this raw string `\\.` is the regex "literal backslash
# followed by any character", not a literal dot (DATE_TIME uses `\.`). As
# written, the optional fraction group therefore cannot match ".123", so
# "15:30:00.000" and "15:30:00.000Z" are both rejected. This looks
# unintended; if it is changed to `\.`, any test that expects fractional
# times to be rejected has to be updated in the same change.
TIME = r"(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\\.[0-9]+)?(Z)?"
# UUID in 8-4-4-4-12 grouping. Only lowercase hexadecimal digits are
# accepted; uppercase A-F does not match.
UUID = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"

# Lookup table from a "format" value to the regex fragment defined above.
format_to_regex = {
"uuid": UUID,
"date-time": DATE_TIME,
"date": DATE,
"time": TIME,
}


def build_regex_from_object(object: Union[str, Callable, BaseModel]):
"""Turn a JSON schema into a regex that matches any JSON object that follows
Expand Down Expand Up @@ -210,6 +222,20 @@ def to_regex(resolver: Resolver, instance: dict):
return rf'(^"{pattern[1:-1]}"$)'
else:
return rf'("{pattern}")'
elif "format" in instance:
format = instance["format"]
if format == "date-time":
return format_to_regex["date-time"]
elif format == "uuid":
return format_to_regex["uuid"]
elif format == "date":
return format_to_regex["date"]
elif format == "time":
return format_to_regex["time"]
else:
raise NotImplementedError(
f"Format {format} is not supported by Outlines"
)
else:
return type_to_regex["string"]

Expand Down
76 changes: 76 additions & 0 deletions tests/fsm/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@

from outlines.fsm.json_schema import (
BOOLEAN,
DATE,
DATE_TIME,
INTEGER,
NULL,
NUMBER,
STRING,
STRING_INNER,
TIME,
UUID,
build_regex_from_object,
get_schema_from_signature,
to_regex,
Expand Down Expand Up @@ -451,3 +455,75 @@ def test_match(schema, regex, examples):
assert match.span() == (0, len(string))
else:
assert match is None


@pytest.mark.parametrize(
    "schema,regex,examples",
    [
        # "uuid" format: 8-4-4-4-12 lowercase hexadecimal groups
        (
            {"title": "Foo", "type": "string", "format": "uuid"},
            UUID,
            [
                ("123e4567-e89b-12d3-a456-426614174000", True),
                ("123e4567-e89b-12d3-a456-42661417400", False),  # too short
                ("123e4567-e89b-12d3-a456-42661417400g", False),  # non-hex digit
                ("123e4567-e89b-12d3-a456-42661417400-", False),  # trailing dash
                ("", False),  # empty string
            ],
        ),
        # "date-time" format
        (
            {"title": "Foo", "type": "string", "format": "date-time"},
            DATE_TIME,
            [
                ("2018-11-13T20:20:39Z", True),
                ("2016-09-18T17:34:02.666Z", True),
                ("2008-05-11T15:30:00Z", True),
                ("2021-01-01T00:00:00", True),  # trailing Z is optional
                ("2022-01-10 07:19:30", False),  # space instead of T
                ("2022-12-10T10-04-29", False),  # dashes in the time part
                ("2023-01-01", False),  # date only, no time part
            ],
        ),
        # "date" format
        (
            {"title": "Foo", "type": "string", "format": "date"},
            DATE,
            [
                ("2018-11-13", True),
                ("2016-09-18", True),
                ("2008-05-11", True),
                ("2015-13-01", False),  # month out of range
                ("2022-01", False),  # day missing
                ("2022/12/01", False),  # slashes instead of dashes
            ],
        ),
        # "time" format
        (
            {"title": "Foo", "type": "string", "format": "time"},
            TIME,
            [
                ("20:20:39Z", True),
                ("15:30:00Z", True),
                ("25:30:00", False),  # hour out of range
                ("15:30", False),  # seconds missing
                # Fractional seconds are rejected by TIME as currently
                # written (the trailing Z itself is optional in the pattern).
                ("15:30:00.000", False),
                ("15-30-00", False),  # dashes instead of colons
                ("15:30:00+01:00", False),  # UTC offset is not accepted
            ],
        ),
    ],
)
def test_format(schema, regex, examples):
    """Check the regex generated for a string schema with a ``format``.

    The schema is serialised to JSON and passed to
    ``build_regex_from_object``; the result must equal the expected
    constant. Each example is then checked with ``re.fullmatch``: strings
    flagged ``True`` must match in full, strings flagged ``False`` must not
    match at all.
    """
    generated = build_regex_from_object(json.dumps(schema))
    assert generated == regex

    for candidate, expected in examples:
        found = re.fullmatch(generated, candidate)
        if not expected:
            assert found is None
            continue
        assert found[0] == candidate
        assert found.span() == (0, len(candidate))

0 comments on commit fda4ce4

Please sign in to comment.