diff --git a/docs/reference/json.md b/docs/reference/json.md index 3b5976f19..85e1a846a 100644 --- a/docs/reference/json.md +++ b/docs/reference/json.md @@ -36,10 +36,10 @@ print(result) !!! Note "JSON and whitespaces" - By default Outlines lets model choose the number of linebreaks and white spaces used to structure the JSON. Small models tend to struggle with this, in which case we recommend to set the value of the parameter `whitespace_pattern` to the empty string: + By default Outlines prevents the model from generating json with syntactic newlines, tabs, or multiple spaces. The default `whitespace_pattern` is `r"[ ]?"`. Small models tend to enter an infinite repetition loop if the `whitespace_pattern` allows infinite spacing. If you would like to allow the model to generate multiple tabs, newlines, and spaces, you can set the whitespace pattern as follows: ```python - generator = generate.json(model, User, whitespace_pattern="") + generator = generate.json(model, User, whitespace_pattern=r"[\n\t ]*") ``` !!! Note "Performance" diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index dbd2baa40..0e0d25bfc 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -16,7 +16,7 @@ NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" BOOLEAN = r"(true|false)" NULL = r"null" -WHITESPACE = r"[\n ]*" +WHITESPACE = r"[ ]?" type_to_regex = { "string": STRING, diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index b12f9576e..bc836ac8b 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -215,8 +215,8 @@ def test_match_number(pattern, does_match): "properties": {"count": {"title": "Count", "type": "integer"}}, "required": ["count"], }, - '\\{[\\n ]*"count"[\\n ]*:[\\n ]*(-)?(0|[1-9][0-9]*)[\\n ]*\\}', - [('{\n "count": 100\n}', True)], + '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?\\}', + [('{ "count": 100 }', True)], ), # array ( @@ -277,7 +277,7 @@ def test_match_number(pattern, does_match): rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", [ ("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True), - ("""{ "test_dict":{"foo":"bar"\n}}""", True), + ("""{ "test_dict":{"foo":"bar" }}""", True), ("""{ "test_dict":{}}""", True), ("""{ "WRONG_KEY":{}}""", False), ("""{ "test_dict":{"wrong_type" 1}}""", False), @@ -369,8 +369,8 @@ def test_match_number(pattern, does_match): }, "required": ["fuzz"], }, - f'\\{{[\\n ]*"fuzz"[\\n ]*:[\\n ]*\\{{[\\n ]*"spam"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*\\}}[\\n ]*\\}}', - [('{\n "fuzz": {\n "spam": 100\n }\n}', True)], + f'\\{{[ ]?"fuzz"[ ]?:[ ]?\\{{[ ]?"spam"[ ]?:[ ]?{INTEGER}[ ]?\\}}[ ]?\\}}', + [('{ "fuzz": { "spam": 100 }}', True)], ), # Schema with a reference ( @@ -384,7 +384,7 @@ def test_match_number(pattern, does_match): }, "required": ["user_id", "name", "a"], }, - f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"a"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"a"[ ]?:[ ]?{STRING}[ ]?\\}}', [('{"user_id": 100, "name": "John", "a": "Marc"}', True)], ), ( @@ -399,7 +399,7 @@ def test_match_number(pattern, does_match): }, "required": ["user_id", "name", "name2"], }, - f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"name2"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"name2"[ ]?:[ ]?{STRING}[ ]?\\}}', [('{"user_id": 100, "name": "John", "name2": "Marc"}', True)], ), ( @@ -441,7 +441,7 @@ def test_match_number(pattern, does_match): } }, }, - f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"last_name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"address"[\\n ]*:[\\n ]*\\{{[\\n ]*"city"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}[\\n ]*\\}}', + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"last_name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"address"[ ]?:[ ]?\\{{[ ]?"city"[ ]?:[ ]?{STRING}[ ]?\\}}[ ]?\\}}', [ ( '{"name": "John", "last_name": "Doe", "address": {"city": "Paris"}}', @@ -462,7 +462,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null))?([\\n ]*,[\\n ]*"weapon"[\\n ]*:[\\n ]*({STRING}|null))?[\\n ]*\\}}', + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"weapon"[ ]?:[ ]?({STRING}|null))?[ ]?\\}}', [ ('{ "name" : "Player" }', True), ('{ "name" : "Player", "weapon" : "sword" }', True), @@ -482,7 +482,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?[\\n ]*\\}}', + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', [ ('{ "name" : "Player" , "weapon" : "sword" }', True), ( @@ -506,7 +506,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"armor"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"armor"[ ]?:[ ]?{STRING}[ ]?,([ ]?"strength"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}[ ]?\\}}', [ ( '{ "name" : "Player", "age" : 10, "armor" : "plate", "strength" : 11, "weapon" : "sword" }', @@ -530,7 +530,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null))?([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?|([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?|([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?([\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?[\\n ]*\\}}', + f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?({INTEGER}|null)([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', [ ('{ "name" : "Player" }', True), ('{ "name" : "Player", "age" : 10, "strength" : 10 }', True), @@ -710,19 +710,6 @@ def test_format(schema, regex, examples): ('{"time":20:20:39Z}', False), # missing quotes for value ], ), - # Unconstrained Object - ( - { - "title": "Foo", - "type": "object", - }, - [ - ("{}", True), - ('{"a": 1, "b": null}', True), - ('{"a": {"z": {"g": 4}}, "b": null}', True), - ("1234", False), # not an object - ], - ), ], ) def test_format_without_regex(schema, examples): @@ -737,7 +724,7 @@ def test_format_without_regex(schema, examples): assert match is None -@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]?", "abc"]) +@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]*", "abc"]) def test_json_schema_custom_whitespace_pattern(whitespace_pattern): """assert whitespace_pattern setting respected""" @@ -759,13 +746,11 @@ class MockModel(BaseModel): ) mock_result_maybe_ws = """{"foo" : 4 ,"bar":"baz baz baz bar"}""" - match_default_ws = re.fullmatch(pattern, mock_result_mult_ws) + match_default_ws = re.fullmatch(pattern, mock_result_maybe_ws) if whitespace_pattern is None: assert match_default_ws else: - assert match_default_ws is None - - assert re.fullmatch(pattern, mock_result_maybe_ws) + assert re.fullmatch(pattern, mock_result_mult_ws) def test_one_of_doesnt_produce_illegal_lookaround():