Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use less problematic whitespace token #916

Merged
merged 3 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/reference/json.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ print(result)

!!! Note "JSON and whitespaces"

By default Outlines lets model choose the number of linebreaks and white spaces used to structure the JSON. Small models tend to struggle with this, in which case we recommend to set the value of the parameter `whitespace_pattern` to the empty string:
By default Outlines prevents the model from generating json with syntactic newlines, tabs, or multiple spaces. The default `whitespace_pattern` is `r"[ ]?"`. Small models tend to enter an infinite repetition loop if the `whitespace_pattern` allows infinite spacing. If you would like to allow the model to generate multiple tabs, newlines, and spaces, you can set the whitespace pattern as follows:

```python
generator = generate.json(model, User, whitespace_pattern="")
generator = generate.json(model, User, whitespace_pattern=r"[\n\t ]*")
```

!!! Note "Performance"
Expand Down
2 changes: 1 addition & 1 deletion outlines/fsm/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
BOOLEAN = r"(true|false)"
NULL = r"null"
WHITESPACE = r"[\n ]*"
WHITESPACE = r"[ ]?"

type_to_regex = {
"string": STRING,
Expand Down
45 changes: 15 additions & 30 deletions tests/fsm/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ def test_match_number(pattern, does_match):
"properties": {"count": {"title": "Count", "type": "integer"}},
"required": ["count"],
},
'\\{[\\n ]*"count"[\\n ]*:[\\n ]*(-)?(0|[1-9][0-9]*)[\\n ]*\\}',
[('{\n "count": 100\n}', True)],
'\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?\\}',
[('{ "count": 100 }', True)],
),
# array
(
Expand Down Expand Up @@ -277,7 +277,7 @@ def test_match_number(pattern, does_match):
rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""",
[
("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True),
("""{ "test_dict":{"foo":"bar"\n}}""", True),
("""{ "test_dict":{"foo":"bar" }}""", True),
("""{ "test_dict":{}}""", True),
("""{ "WRONG_KEY":{}}""", False),
("""{ "test_dict":{"wrong_type" 1}}""", False),
Expand Down Expand Up @@ -369,8 +369,8 @@ def test_match_number(pattern, does_match):
},
"required": ["fuzz"],
},
f'\\{{[\\n ]*"fuzz"[\\n ]*:[\\n ]*\\{{[\\n ]*"spam"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*\\}}[\\n ]*\\}}',
[('{\n "fuzz": {\n "spam": 100\n }\n}', True)],
f'\\{{[ ]?"fuzz"[ ]?:[ ]?\\{{[ ]?"spam"[ ]?:[ ]?{INTEGER}[ ]?\\}}[ ]?\\}}',
[('{ "fuzz": { "spam": 100 }}', True)],
),
# Schema with a reference
(
Expand All @@ -384,7 +384,7 @@ def test_match_number(pattern, does_match):
},
"required": ["user_id", "name", "a"],
},
f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"a"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}',
f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"a"[ ]?:[ ]?{STRING}[ ]?\\}}',
[('{"user_id": 100, "name": "John", "a": "Marc"}', True)],
),
(
Expand All @@ -399,7 +399,7 @@ def test_match_number(pattern, does_match):
},
"required": ["user_id", "name", "name2"],
},
f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"name2"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}',
f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"name2"[ ]?:[ ]?{STRING}[ ]?\\}}',
[('{"user_id": 100, "name": "John", "name2": "Marc"}', True)],
),
(
Expand Down Expand Up @@ -441,7 +441,7 @@ def test_match_number(pattern, does_match):
}
},
},
f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"last_name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"address"[\\n ]*:[\\n ]*\\{{[\\n ]*"city"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}[\\n ]*\\}}',
f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"last_name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"address"[ ]?:[ ]?\\{{[ ]?"city"[ ]?:[ ]?{STRING}[ ]?\\}}[ ]?\\}}',
[
(
'{"name": "John", "last_name": "Doe", "address": {"city": "Paris"}}',
Expand All @@ -462,7 +462,7 @@ def test_match_number(pattern, does_match):
"title": "Character",
"type": "object",
},
f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null))?([\\n ]*,[\\n ]*"weapon"[\\n ]*:[\\n ]*({STRING}|null))?[\\n ]*\\}}',
f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"weapon"[ ]?:[ ]?({STRING}|null))?[ ]?\\}}',
[
('{ "name" : "Player" }', True),
('{ "name" : "Player", "weapon" : "sword" }', True),
Expand All @@ -482,7 +482,7 @@ def test_match_number(pattern, does_match):
"title": "Character",
"type": "object",
},
f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?[\\n ]*\\}}',
f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}',
[
('{ "name" : "Player" , "weapon" : "sword" }', True),
(
Expand All @@ -506,7 +506,7 @@ def test_match_number(pattern, does_match):
"title": "Character",
"type": "object",
},
f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"armor"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}',
f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"armor"[ ]?:[ ]?{STRING}[ ]?,([ ]?"strength"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}[ ]?\\}}',
[
(
'{ "name" : "Player", "age" : 10, "armor" : "plate", "strength" : 11, "weapon" : "sword" }',
Expand All @@ -530,7 +530,7 @@ def test_match_number(pattern, does_match):
"title": "Character",
"type": "object",
},
f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null))?([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?|([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?|([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?([\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?[\\n ]*\\}}',
f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?({INTEGER}|null)([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}',
[
('{ "name" : "Player" }', True),
('{ "name" : "Player", "age" : 10, "strength" : 10 }', True),
Expand Down Expand Up @@ -710,19 +710,6 @@ def test_format(schema, regex, examples):
('{"time":20:20:39Z}', False), # missing quotes for value
],
),
# Unconstrained Object
(
{
"title": "Foo",
"type": "object",
},
[
("{}", True),
('{"a": 1, "b": null}', True),
('{"a": {"z": {"g": 4}}, "b": null}', True),
("1234", False), # not an object
],
),
],
)
def test_format_without_regex(schema, examples):
Expand All @@ -737,7 +724,7 @@ def test_format_without_regex(schema, examples):
assert match is None


@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]?", "abc"])
@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]*", "abc"])
def test_json_schema_custom_whitespace_pattern(whitespace_pattern):
"""assert whitespace_pattern setting respected"""

Expand All @@ -759,13 +746,11 @@ class MockModel(BaseModel):
)
mock_result_maybe_ws = """{"foo" : 4 ,"bar":"baz baz baz bar"}"""

match_default_ws = re.fullmatch(pattern, mock_result_mult_ws)
match_default_ws = re.fullmatch(pattern, mock_result_maybe_ws)
if whitespace_pattern is None:
assert match_default_ws
else:
assert match_default_ws is None

assert re.fullmatch(pattern, mock_result_maybe_ws)
assert re.fullmatch(pattern, mock_result_mult_ws)


def test_one_of_doesnt_produce_illegal_lookaround():
Expand Down
Loading