Skip to content

Commit

Permalink
Allow users to pass custom whitespace pattern for JSON-structured gen…
Browse files Browse the repository at this point in the history
…eration
  • Loading branch information
Andrew Lapp authored and rlouf committed Feb 10, 2024
1 parent 7c71199 commit 9c74d7c
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 35 deletions.
12 changes: 10 additions & 2 deletions docs/reference/json.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,21 @@ class User(BaseModel):
id: int


model = models.transformers("mistralai/Mistral-7B")
model = models.transformers("mistralai/Mistral-7B-v0.1")
generator = text.generate.json(model, User)
result = generator("Create a user profile with the fields name, last_name and id")
print(result)
# User(name="John", last_name="Doe", id=11)
```

!!! warning "JSON and whitespaces"

By default Outlines lets model choose the number of linebreaks and white spaces used to structure the JSON. Small models tend to struggle with this, in which case we recommend to set the value of the parameter `whitespace_pattern` to the empty string:

```python
generator = text.generate.json(model, User, whitespace_pattern="")
```

## From a function's signature

Outlines can infer the structure of the output from the signature of a function. The result is a dictionary, and can be passed directly to the function using the usual dictionary expansion syntax `**`:
Expand All @@ -44,7 +52,7 @@ from outlines import text
def add(a: int, b: int):
return a + b

model = models.transformers("mistralai/Mistral-7B")
model = models.transformers("mistralai/Mistral-7B-v0.1")
generator = text.generate.json(model, add)
result = generator("Return two integers named a and b respectively. a is odd and b even.")

Expand Down
68 changes: 45 additions & 23 deletions outlines/fsm/json_schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import json
import re
from typing import Callable, Union
from typing import Callable, Optional, Union

from jsonschema.protocols import Validator
from pydantic import BaseModel, create_model
Expand Down Expand Up @@ -38,7 +38,9 @@
}


def build_regex_from_object(object: Union[str, Callable, BaseModel]):
def build_regex_from_object(
object: Union[str, Callable, BaseModel], whitespace_pattern: Optional[str] = None
):
"""Turn a JSON schema into a regex that matches any JSON object that follows
this schema.
Expand All @@ -54,6 +56,9 @@ def build_regex_from_object(object: Union[str, Callable, BaseModel]):
----------
schema
A string that represents a JSON Schema.
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
Returns
-------
Expand Down Expand Up @@ -83,10 +88,12 @@ def build_regex_from_object(object: Union[str, Callable, BaseModel]):
resolver = registry.resolver()

content = schema.contents
return to_regex(resolver, content)
return to_regex(resolver, content, whitespace_pattern)


def to_regex(resolver: Resolver, instance: dict):
def to_regex(
resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None
):
"""Translate a JSON Schema instance into a regex that validates the schema.
Note
Expand All @@ -105,8 +112,15 @@ def to_regex(resolver: Resolver, instance: dict):
An object that resolves references to other instances within a schema
instance
The instance to translate
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
"""

# set whitespace pattern
if whitespace_pattern is None:
whitespace_pattern = WHITESPACE

if "properties" in instance:
regex = ""
regex += r"\{"
Expand All @@ -120,12 +134,12 @@ def to_regex(resolver: Resolver, instance: dict):
if any(is_required):
last_required_pos = max([i for i, value in enumerate(is_required) if value])
for i, (name, value) in enumerate(properties.items()):
subregex = f'{WHITESPACE}"{name}"{WHITESPACE}:{WHITESPACE}'
subregex += to_regex(resolver, value)
subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}'
subregex += to_regex(resolver, value, whitespace_pattern)
if i < last_required_pos:
subregex = f"{subregex}{WHITESPACE},"
subregex = f"{subregex}{whitespace_pattern},"
elif i > last_required_pos:
subregex = f"{WHITESPACE},{subregex}"
subregex = f"{whitespace_pattern},{subregex}"
regex += subregex if is_required[i] else f"({subregex})?"
# If no property is required, we have to create a possible pattern for each property in which
# it's the last one necessarilly present. Then, we add the others as optional before and after
Expand All @@ -134,41 +148,47 @@ def to_regex(resolver: Resolver, instance: dict):
else:
property_subregexes = []
for i, (name, value) in enumerate(properties.items()):
subregex = f'{WHITESPACE}"{name}"{WHITESPACE}:{WHITESPACE}'
subregex += to_regex(resolver, value)
subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}'
subregex += to_regex(resolver, value, whitespace_pattern)
property_subregexes.append(subregex)
possible_patterns = []
for i in range(len(property_subregexes)):
pattern = ""
for subregex in property_subregexes[:i]:
pattern += f"({subregex}{WHITESPACE},)?"
pattern += f"({subregex}{whitespace_pattern},)?"
pattern += property_subregexes[i]
for subregex in property_subregexes[i + 1 :]:
pattern += f"({WHITESPACE},{subregex})?"
pattern += f"({whitespace_pattern},{subregex})?"
possible_patterns.append(pattern)
regex += f"({'|'.join(possible_patterns)})?"

regex += f"{WHITESPACE}" + r"\}"
regex += f"{whitespace_pattern}" + r"\}"

return regex

# To validate against allOf, the given data must be valid against all of the
# given subschemas.
elif "allOf" in instance:
subregexes = [to_regex(resolver, t) for t in instance["allOf"]]
subregexes = [
to_regex(resolver, t, whitespace_pattern) for t in instance["allOf"]
]
subregexes_str = [f"{subregex}" for subregex in subregexes]
return rf"({''.join(subregexes_str)})"

# To validate against `anyOf`, the given data must be valid against
# any (one or more) of the given subschemas.
elif "anyOf" in instance:
subregexes = [to_regex(resolver, t) for t in instance["anyOf"]]
subregexes = [
to_regex(resolver, t, whitespace_pattern) for t in instance["anyOf"]
]
return rf"({'|'.join(subregexes)})"

# To validate against oneOf, the given data must be valid against exactly
# one of the given subschemas.
elif "oneOf" in instance:
subregexes = [to_regex(resolver, t) for t in instance["oneOf"]]
subregexes = [
to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"]
]

xor_patterns = []
# json schema validation ensured there is no overlapping schemas in oneOf
Expand All @@ -195,7 +215,7 @@ def to_regex(resolver: Resolver, instance: dict):
elif "$ref" in instance:
path = f"{instance['$ref']}"
instance = resolver.lookup(path).contents
return to_regex(resolver, instance)
return to_regex(resolver, instance, whitespace_pattern)

# The type keyword may either be a string or an array:
# - If it's a string, it is the name of one of the basic types.
Expand Down Expand Up @@ -254,14 +274,14 @@ def to_regex(resolver: Resolver, instance: dict):
num_repeats = rf"{{{max(min_items - 1, 0)},}}"
else:
if max_items < 1:
return rf"\[{WHITESPACE}\]"
return rf"\[{whitespace_pattern}\]"
num_repeats = rf"{{{max(min_items - 1, 0)},{max_items - 1}}}"

allow_empty = "?" if min_items == 0 else ""

if "items" in instance:
items_regex = to_regex(resolver, instance["items"])
return rf"\[{WHITESPACE}(({items_regex})(,{WHITESPACE}({items_regex})){num_repeats}){allow_empty}{WHITESPACE}\]"
items_regex = to_regex(resolver, instance["items"], whitespace_pattern)
return rf"\[{whitespace_pattern}(({items_regex})(,{whitespace_pattern}({items_regex})){num_repeats}){allow_empty}{whitespace_pattern}\]"
else:
# Here we need to make the choice to exclude generating list of objects
# if the specification of the object is not given, even though a JSON
Expand All @@ -273,8 +293,8 @@ def to_regex(resolver: Resolver, instance: dict):
{"type": "integer"},
{"type": "string"},
]
regexes = [to_regex(resolver, t) for t in types]
return rf"\[{WHITESPACE}({'|'.join(regexes)})(,{WHITESPACE}({'|'.join(regexes)})){num_repeats}){allow_empty}{WHITESPACE}\]"
regexes = [to_regex(resolver, t, whitespace_pattern) for t in types]
return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}){allow_empty}{whitespace_pattern}\]"

elif instance_type == "boolean":
return type_to_regex["boolean"]
Expand All @@ -287,7 +307,9 @@ def to_regex(resolver: Resolver, instance: dict):
# if the specification of the object is not give, even though a JSON
# object that contains an object here would be valid under the specification.
regexes = [
to_regex(resolver, {"type": t}) for t in instance_type if t != "object"
to_regex(resolver, {"type": t}, whitespace_pattern)
for t in instance_type
if t != "object"
]
return rf"({'|'.join(regexes)})"

Expand Down
38 changes: 33 additions & 5 deletions outlines/generate/json.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json as pyjson
from functools import singledispatch
from typing import Callable, Union
from typing import Callable, Optional, Union

from pydantic import BaseModel

Expand All @@ -14,21 +14,49 @@

@singledispatch
def json(
model, schema_object: Union[str, object, Callable], sampler: Sampler = multinomial()
model,
schema_object: Union[str, object, Callable],
sampler: Sampler = multinomial(),
whitespace_pattern: Optional[str] = None,
) -> SequenceGenerator:
"""
Generate structured JSON data with a `Transformer` model based on a specified JSON Schema.
Parameters
----------
model:
An instance of `Transformer` that represents a model from the
`transformers` library.
schema_object:
The JSON Schema to generate data for. Can be a JSON string, a Pydantic model, or a callable
that returns a JSON schema.
max_tokens:
The maximum number of tokens to generate.
sampler:
The sampling algorithm to use to generate token ids from the logits
distribution.
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
Returns
-------
A `SequenceGenerator` instance that generates text constrained by the schema_object and
transforms the result if BaseModel is used.
"""
if isinstance(schema_object, type(BaseModel)):
schema = pyjson.dumps(schema_object.model_json_schema())
regex_str = build_regex_from_object(schema)
regex_str = build_regex_from_object(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: schema_object.parse_raw(x)
elif callable(schema_object):
schema = pyjson.dumps(get_schema_from_signature(schema_object))
regex_str = build_regex_from_object(schema)
regex_str = build_regex_from_object(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: pyjson.loads(x)
elif isinstance(schema_object, str):
schema = schema_object
regex_str = build_regex_from_object(schema)
regex_str = build_regex_from_object(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: pyjson.loads(x)
else:
Expand Down
12 changes: 7 additions & 5 deletions outlines/serve/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import math
from collections import defaultdict
from typing import DefaultDict, List
from typing import DefaultDict, List, Optional

import torch

Expand Down Expand Up @@ -105,18 +105,20 @@ def convert_token_to_string(token: str) -> str:


class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema, llm):
"""Compile the FSM that drives the JSON-structured generation.
def __init__(self, schema, llm, whitespace_pattern: Optional[str] = None):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
----------
schema
A JSON schema that encodes the structure we want the model to generate
llm
An instance of `vllm.LLM`
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
"""
if isinstance(schema, dict):
schema = json.dumps(schema)
regex_string = build_regex_from_object(schema)
regex_string = build_regex_from_object(schema, whitespace_pattern)
super().__init__(regex_string, llm)
29 changes: 29 additions & 0 deletions tests/fsm/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,3 +540,32 @@ def test_format(schema, regex, examples):
assert match.span() == (0, len(string))
else:
assert match is None


@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]?", "abc"])
def test_json_schema_custom_whitespace_pattern(whitespace_pattern):
"""assert whitespace_pattern setting respected"""

class MockModel(BaseModel):
foo: int
bar: str

# assert any ws pattern can be used
if whitespace_pattern == "abc":
build_regex_from_object(MockModel, whitespace_pattern)
return

pattern = build_regex_from_object(MockModel, whitespace_pattern)

mock_result_mult_ws = (
"""{ "foo" : 4, \n\n\n "bar": "baz baz baz bar"\n\n}"""
)
mock_result_maybe_ws = """{"foo" : 4 ,"bar":"baz baz baz bar"}"""

match_default_ws = re.fullmatch(pattern, mock_result_mult_ws)
if whitespace_pattern is None:
assert match_default_ws
else:
assert match_default_ws is None

assert re.fullmatch(pattern, mock_result_maybe_ws)
24 changes: 24 additions & 0 deletions tests/generate/test_integration_transfomers.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,30 @@ def test_transformers_logits_vocab_size():
assert sequence == "False"


def test_transformers_json_custom_ws():
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
model = models.transformers(model_name, device="cpu")
prompt = "Output some JSON with newlines" # try to force model to use newlines

schema = """{
"title": "spam",
"type": "object",
"properties": {
"foo" : {"type": "integer"},
"bar": {"type": "integer"}
},
"required": ["foo", "bar"]
}
"""

rng = torch.Generator()
rng.manual_seed(0)

generator = generate.json(model, schema, whitespace_pattern=r"[ ]?")
generator.format_sequence = lambda x: x # patch to return raw text
assert "\n" not in generator(prompt, max_tokens=500, rng=rng)


def test_transformers_reduced_vocabulary_caching():
tokenizer = TransformerTokenizer("gpt2")
tokenizer2 = TransformerTokenizer("gpt2")
Expand Down

0 comments on commit 9c74d7c

Please sign in to comment.