Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds option for JSON schema optimization #863

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions outlines/fsm/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import json
import re
import warnings
from typing import Callable, Optional
from copy import deepcopy
from typing import Callable, List, Optional

from jsonschema.protocols import Validator
from pydantic import create_model
Expand Down Expand Up @@ -39,7 +40,11 @@
}


def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None):
def build_regex_from_schema(
schema: str,
whitespace_pattern: Optional[str] = None,
enable_schema_optimization: bool = False,
):
"""Turn a JSON schema into a regex that matches any JSON object that follows
this schema.

Expand All @@ -58,6 +63,12 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
enable_schema_optimization:
If True, this will speed up generation by not requiring optional keys to be
present in the output. This is especially useful for large schemas with many
optional keys. Note though that this further restricts the support
distribution. Thus, it is necessary to remove the optional keys from the
finetuning dataset as well if needed. Hence, we set this to False by default.

Returns
-------
Expand All @@ -81,9 +92,76 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non
resolver = registry.resolver()

content = schema.contents
if enable_schema_optimization:
content = optimize_schema(content)
return to_regex(resolver, content, whitespace_pattern)


def _is_null_type(instance: dict):
if "type" in instance and (instance["type"] == "null" or instance["type"] is None):
return True
if "const" in instance and (
instance["const"] == "null" or instance["const"] is None
):
return True
return False


def _has_null_type(instance_list: List[dict]):
for instance in instance_list:
if _is_null_type(instance):
return True
return False


def optimize_schema(instance: dict) -> dict:
    """Return a copy of the JSON schema *instance* optimized for generation.

    The optimization removes properties that can only be ``null`` and demotes
    nullable properties (``anyOf`` containing a null alternative) and
    possibly-empty arrays from ``required`` to optional, shrinking the regex
    built from the schema. The input schema is never mutated.

    Parameters
    ----------
    instance
        The JSON schema (or sub-schema) as a dictionary.

    Returns
    -------
    dict
        A new, optimized schema dictionary.
    """
    instance_copy = deepcopy(instance)

    # Recurse into shared definitions so $ref targets are optimized too.
    if "$defs" in instance_copy:
        instance_copy["$defs"] = {
            key: optimize_schema(subschema)
            for key, subschema in instance_copy["$defs"].items()
        }

    if "properties" in instance_copy:
        new_optional_keys = set()
        keys_to_remove = set()
        for key, subinstance in instance_copy["properties"].items():
            subinstance = optimize_schema(subinstance)
            # Persist the recursively optimized subschema. The previous
            # version only wrote it back in the single-alternative anyOf
            # branch, silently discarding the optimization of nested objects
            # and the filtered anyOf list in the multi-alternative branch.
            # (Reassigning an existing key while iterating is safe: the dict
            # size does not change.)
            instance_copy["properties"][key] = subinstance
            if "type" in subinstance:
                subinstance_type = subinstance["type"]
                if subinstance_type == "null":
                    # A property that can only be null carries no information.
                    keys_to_remove.add(key)
                elif (
                    subinstance_type == "array" and subinstance.get("minItems", 0) == 0
                ):
                    # A possibly-empty array can be omitted from the output.
                    new_optional_keys.add(key)
            elif "anyOf" in subinstance and _has_null_type(subinstance["anyOf"]):
                any_of_list = subinstance.pop("anyOf")
                filtered_any_of_list = [
                    alternative
                    for alternative in any_of_list
                    if not _is_null_type(alternative)
                ]
                if len(filtered_any_of_list) == 0:
                    # Every alternative was null: drop the key entirely.
                    keys_to_remove.add(key)
                elif len(filtered_any_of_list) == 1:
                    # A single non-null alternative: inline it and make the
                    # key optional instead of keeping a one-element anyOf.
                    subinstance = {**subinstance, **filtered_any_of_list[0]}
                    instance_copy["properties"][key] = subinstance
                    new_optional_keys.add(key)
                else:
                    subinstance["anyOf"] = filtered_any_of_list
                    new_optional_keys.add(key)
        if "required" in instance_copy:
            instance_copy["required"] = [
                key
                for key in instance_copy["required"]
                if key not in new_optional_keys and key not in keys_to_remove
            ]
        instance_copy["properties"] = {
            key: value
            for key, value in instance_copy["properties"].items()
            if key not in keys_to_remove
        }
    # NOTE(review): array "items" and "$ref"-only subschemas are not recursed
    # into here — presumably handled via $defs; confirm against callers.
    return instance_copy


def _get_num_items_pattern(min_items, max_items, whitespace_pattern):
# Helper function for arrays and objects
min_items = int(min_items or 0)
Expand Down
21 changes: 17 additions & 4 deletions outlines/generate/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def json(
schema_object: Union[str, object, Callable],
sampler: Sampler = multinomial(),
whitespace_pattern: Optional[str] = None,
enable_schema_optimization: bool = False,
) -> SequenceGenerator:
"""
Generate structured JSON data with a `Transformer` model based on a specified JSON Schema.
Expand All @@ -33,9 +34,15 @@ def json(
sampler:
The sampling algorithm to use to generate token ids from the logits
distribution.
whitespace_pattern
whitespace_pattern:
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
enable_schema_optimization:
If True, this will speed up generation by not requiring optional keys to be
present in the output. This is especially useful for large schemas with many
optional keys. Note though that this further restricts the support
distribution. Thus, it is necessary to remove the optional keys from the
finetuning dataset as well if needed. Hence, we set this to False by default.

Returns
-------
Expand All @@ -45,17 +52,23 @@ def json(
"""
if isinstance(schema_object, type(BaseModel)):
schema = pyjson.dumps(schema_object.model_json_schema())
regex_str = build_regex_from_schema(schema, whitespace_pattern)
regex_str = build_regex_from_schema(
schema, whitespace_pattern, enable_schema_optimization
)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: schema_object.parse_raw(x)
elif callable(schema_object):
schema = pyjson.dumps(get_schema_from_signature(schema_object))
regex_str = build_regex_from_schema(schema, whitespace_pattern)
regex_str = build_regex_from_schema(
schema, whitespace_pattern, enable_schema_optimization
)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: pyjson.loads(x)
elif isinstance(schema_object, str):
schema = schema_object
regex_str = build_regex_from_schema(schema, whitespace_pattern)
regex_str = build_regex_from_schema(
schema, whitespace_pattern, enable_schema_optimization
)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: pyjson.loads(x)
else:
Expand Down
11 changes: 10 additions & 1 deletion outlines/integrations/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def __init__(
schema: Union[dict, Type[BaseModel], str],
llm: "Llama",
whitespace_pattern: Optional[str] = None,
enable_schema_optimization: bool = False,
):
"""Compile the FSM that drives the JSON-guided generation.

Expand All @@ -184,9 +185,17 @@ def __init__(
Pattern to use for JSON syntactic whitespace (doesn't impact string
literals). For example, to allow only a single space or newline with
`whitespace_pattern=r"[\n ]?"`
enable_schema_optimization
If True, this will speed up generation by not requiring optional keys to be
present in the output. This is especially useful for large schemas with many
optional keys. Note though that this further restricts the support
distribution. Thus, it is necessary to remove the optional keys from the
finetuning dataset as well if needed. Hence, we set this to False by default.
"""
schema_str = convert_json_schema_to_str(json_schema=schema)
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
regex_string = build_regex_from_schema(
schema_str, whitespace_pattern, enable_schema_optimization
)
super().__init__(regex_string=regex_string, llm=llm)


Expand Down
11 changes: 10 additions & 1 deletion outlines/integrations/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(
schema: Union[dict, Type[BaseModel], str],
tokenizer_or_pipe: Union[PreTrainedTokenizerBase, Pipeline],
whitespace_pattern: Optional[str] = None,
enable_schema_optimization: bool = False,
):
"""Compile the FSM that drives the JSON-guided generation.

Expand All @@ -153,7 +154,15 @@ def __init__(
Pattern to use for JSON syntactic whitespace (doesn't impact string
literals). For example, to allow only a single space or newline with
`whitespace_pattern=r"[\n ]?"`
enable_schema_optimization:
If True, this will speed up generation by not requiring optional keys to be
present in the output. This is especially useful for large schemas with many
optional keys. Note though that this further restricts the support
distribution. Thus, it is necessary to remove the optional keys from the
finetuning dataset as well if needed. Hence, we set this to False by default.
"""
schema_str = convert_json_schema_to_str(json_schema=schema)
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
regex_string = build_regex_from_schema(
schema_str, whitespace_pattern, enable_schema_optimization
)
super().__init__(regex_string=regex_string, tokenizer_or_pipe=tokenizer_or_pipe)
11 changes: 10 additions & 1 deletion outlines/integrations/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(
schema: Union[dict, Type[BaseModel], str],
llm: "LLM",
whitespace_pattern: Optional[str] = None,
enable_schema_optimization: bool = False,
):
"""Compile the FSM that drives the JSON-guided generation.

Expand All @@ -145,7 +146,15 @@ def __init__(
Pattern to use for JSON syntactic whitespace (doesn't impact string
literals). For example, to allow only a single space or newline with
`whitespace_pattern=r"[\n ]?"`
enable_schema_optimization:
If True, this will speed up generation by not requiring optional keys to be
present in the output. This is especially useful for large schemas with many
optional keys. Note though that this further restricts the support
distribution. Thus, it is necessary to remove the optional keys from the
finetuning dataset as well if needed. Hence, we set this to False by default.
"""
schema_str = convert_json_schema_to_str(json_schema=schema)
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
regex_string = build_regex_from_schema(
schema_str, whitespace_pattern, enable_schema_optimization
)
super().__init__(regex_string=regex_string, llm=llm)
11 changes: 10 additions & 1 deletion outlines/serve/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,17 @@ async def generate(request: Request) -> Response:

json_schema = request_dict.pop("schema", None)
regex_string = request_dict.pop("regex", None)
whitespace_pattern = request_dict.pop("whitespace_pattern", None)
enable_schema_optimization = request_dict.pop("enable_schema_optimization", False)
if json_schema is not None:
logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)]
logits_processors = [
JSONLogitsProcessor(
json_schema,
engine.engine,
whitespace_pattern,
enable_schema_optimization,
)
]
elif regex_string is not None:
logits_processors = [RegexLogitsProcessor(regex_string, engine.engine)]
else:
Expand Down
Loading
Loading