From 972e957a0f1d62ba9d56a3be6b09e9fddef51469 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Sat, 4 May 2024 13:12:09 +0800 Subject: [PATCH 1/5] add schema optimization --- outlines/fsm/json_schema.py | 82 ++++++++++++++++++++++++++- outlines/generate/json.py | 21 +++++-- outlines/integrations/llamacpp.py | 11 +++- outlines/integrations/transformers.py | 11 +++- outlines/integrations/vllm.py | 11 +++- outlines/serve/serve.py | 11 +++- 6 files changed, 138 insertions(+), 9 deletions(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index c57cea7cd..c2185c226 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -2,6 +2,7 @@ import json import re import warnings +from copy import deepcopy from typing import Callable, Optional from jsonschema.protocols import Validator @@ -39,7 +40,11 @@ } -def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None): +def build_regex_from_schema( + schema: str, + whitespace_pattern: Optional[str] = None, + enable_schema_optimization: bool = False, +): """Turn a JSON schema into a regex that matches any JSON object that follows this schema. @@ -58,6 +63,12 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non whitespace_pattern Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + enable_schema_optimization: + If True, this will speed up generation by not requiring optional keys to be + present in the output. This is especially useful for large schemas with many + optional keys. Note though that this further restricts the support + distribution. Thus, it is necessary to remove the optional keys from the + finetuning dataset as well if needed. Hence, we set this to False by default. Returns ------- @@ -81,9 +92,78 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non resolver = registry.resolver() content = schema.contents + if enable_schema_optimization: + content = optimize_schema(content) return to_regex(resolver, content, whitespace_pattern) +def is_null_type(instance: dict): + if "type" in instance and (instance["type"] == "null" or instance["type"] is None): + return True + if "const" in instance and ( + instance["const"] == "null" or instance["const"] is None + ): + return True + return False + + +def any_of_list_has_null_type(any_of_list: list[dict[str, str]]): + for subinstance in any_of_list: + if is_null_type(subinstance): + return True + return False + + +def optimize_schema(instance): + instance_copy = deepcopy(instance) + if "$defs" in instance_copy: + instance_copy["$defs"] = { + key: optimize_schema(subinstance) + for key, subinstance in instance_copy["$defs"].items() + } + if "properties" in instance_copy: + new_optional_keys = set() + keys_to_remove = set() + for key, subinstance in instance_copy["properties"].items(): + subinstance = optimize_schema(subinstance) + if "type" in subinstance: + subinstance_type = subinstance["type"] + if subinstance_type == "null": + keys_to_remove.add(key) + elif ( + subinstance_type == "array" and subinstance.get("minItems", 0) == 0 + ): + new_optional_keys.add(key) + elif "anyOf" in subinstance and any_of_list_has_null_type( + subinstance["anyOf"] + ): + any_of_list = subinstance.pop("anyOf") + filtered_any_of_list = list( + filter(lambda d: is_null_type(d), any_of_list) + ) + if len(filtered_any_of_list) == 0: + keys_to_remove.add(key) + elif len(filtered_any_of_list) == 1: + subinstance = {**subinstance, **filtered_any_of_list[0]} + instance_copy["properties"][key] = subinstance + new_optional_keys.add(key) + else: + subinstance["anyOf"] = filtered_any_of_list + new_optional_keys.add(key) + if "required" in instance_copy: + instance_copy["required"] = [ + key + for key in instance_copy["required"] + if key not in new_optional_keys and key not in keys_to_remove + ] + instance_copy["properties"] = { + key: value + for key, value in instance_copy["properties"].items() + if key not in keys_to_remove + } + return instance_copy + + def _get_num_items_pattern(min_items, max_items, whitespace_pattern): # Helper function for arrays and objects min_items = int(min_items or 0) diff --git a/outlines/generate/json.py b/outlines/generate/json.py index 3837f72b6..b18b958d4 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -18,6 +18,7 @@ def json( schema_object: Union[str, object, Callable], sampler: Sampler = multinomial(), whitespace_pattern: Optional[str] = None, + enable_schema_optimization: bool = False, ) -> SequenceGenerator: """ Generate structured JSON data with a `Transformer` model based on a specified JSON Schema. @@ -33,9 +34,15 @@ def json( sampler: The sampling algorithm to use to generate token ids from the logits distribution. - whitespace_pattern + whitespace_pattern: Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + enable_schema_optimization: + If True, this will speed up generation by not requiring optional keys to be + present in the output. This is especially useful for large schemas with many + optional keys. Note though that this further restricts the support + distribution. Thus, it is necessary to remove the optional keys from the + finetuning dataset as well if needed. Hence, we set this to False by default. Returns ------- @@ -45,17 +52,23 @@ def json( """ if isinstance(schema_object, type(BaseModel)): schema = pyjson.dumps(schema_object.model_json_schema()) - regex_str = build_regex_from_schema(schema, whitespace_pattern) + regex_str = build_regex_from_schema( + schema, whitespace_pattern, enable_schema_optimization + ) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: schema_object.parse_raw(x) elif callable(schema_object): schema = pyjson.dumps(get_schema_from_signature(schema_object)) - regex_str = build_regex_from_schema(schema, whitespace_pattern) + regex_str = build_regex_from_schema( + schema, whitespace_pattern, enable_schema_optimization + ) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: pyjson.loads(x) elif isinstance(schema_object, str): schema = schema_object - regex_str = build_regex_from_schema(schema, whitespace_pattern) + regex_str = build_regex_from_schema( + schema, whitespace_pattern, enable_schema_optimization + ) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: pyjson.loads(x) else: diff --git a/outlines/integrations/llamacpp.py b/outlines/integrations/llamacpp.py index 4041c54fb..763dd3a1e 100644 --- a/outlines/integrations/llamacpp.py +++ b/outlines/integrations/llamacpp.py @@ -171,6 +171,7 @@ def __init__( schema: Union[dict, Type[BaseModel], str], llm: "Llama", whitespace_pattern: Optional[str] = None, + enable_schema_optimization: bool = False, ): """Compile the FSM that drives the JSON-guided generation. @@ -184,9 +185,17 @@ def __init__( Pattern to use for JSON syntactic whitespace (doesn't impact string literals). For example, to allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + enable_schema_optimization + If True, this will speed up generation by not requiring optional keys to be + present in the output. This is especially useful for large schemas with many + optional keys. Note though that this further restricts the support + distribution. Thus, it is necessary to remove the optional keys from the + finetuning dataset as well if needed. Hence, we set this to False by default. """ schema_str = convert_json_schema_to_str(json_schema=schema) - regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + regex_string = build_regex_from_schema( + schema_str, whitespace_pattern, enable_schema_optimization + ) super().__init__(regex_string=regex_string, llm=llm) diff --git a/outlines/integrations/transformers.py b/outlines/integrations/transformers.py index 7c1bafd22..c01d3c86c 100644 --- a/outlines/integrations/transformers.py +++ b/outlines/integrations/transformers.py @@ -140,6 +140,7 @@ def __init__( schema: Union[dict, Type[BaseModel], str], tokenizer_or_pipe: Union[PreTrainedTokenizerBase, Pipeline], whitespace_pattern: Optional[str] = None, + enable_schema_optimization: bool = False, ): """Compile the FSM that drives the JSON-guided generation. @@ -153,7 +154,15 @@ def __init__( Pattern to use for JSON syntactic whitespace (doesn't impact string literals). For example, to allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + enable_schema_optimization: + If True, this will speed up generation by not requiring optional keys to be + present in the output. This is especially useful for large schemas with many + optional keys. Note though that this further restricts the support + distribution. Thus, it is necessary to remove the optional keys from the + finetuning dataset as well if needed. Hence, we set this to False by default. """ schema_str = convert_json_schema_to_str(json_schema=schema) - regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + regex_string = build_regex_from_schema( + schema_str, whitespace_pattern, enable_schema_optimization + ) super().__init__(regex_string=regex_string, tokenizer_or_pipe=tokenizer_or_pipe) diff --git a/outlines/integrations/vllm.py b/outlines/integrations/vllm.py index 6ed56d71b..69cb2a8f8 100644 --- a/outlines/integrations/vllm.py +++ b/outlines/integrations/vllm.py @@ -132,6 +132,7 @@ def __init__( schema: Union[dict, Type[BaseModel], str], llm: "LLM", whitespace_pattern: Optional[str] = None, + enable_schema_optimization: bool = False, ): """Compile the FSM that drives the JSON-guided generation. @@ -145,7 +146,15 @@ def __init__( Pattern to use for JSON syntactic whitespace (doesn't impact string literals). For example, to allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + enable_schema_optimization: + If True, this will speed up generation by not requiring optional keys to be + present in the output. This is especially useful for large schemas with many + optional keys. Note though that this further restricts the support + distribution. Thus, it is necessary to remove the optional keys from the + finetuning dataset as well if needed. Hence, we set this to False by default. """ schema_str = convert_json_schema_to_str(json_schema=schema) - regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + regex_string = build_regex_from_schema( + schema_str, whitespace_pattern, enable_schema_optimization + ) super().__init__(regex_string=regex_string, llm=llm) diff --git a/outlines/serve/serve.py b/outlines/serve/serve.py index fb8c80139..ef885988b 100644 --- a/outlines/serve/serve.py +++ b/outlines/serve/serve.py @@ -68,8 +68,17 @@ async def generate(request: Request) -> Response: json_schema = request_dict.pop("schema", None) regex_string = request_dict.pop("regex", None) + whitespace_pattern = request_dict.pop("whitespace_pattern", None) + enable_schema_optimization = request_dict.pop("enable_schema_optimization", False) if json_schema is not None: - logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)] + logits_processors = [ + JSONLogitsProcessor( + json_schema, + engine.engine, + whitespace_pattern, + enable_schema_optimization, + ) + ] elif regex_string is not None: logits_processors = [RegexLogitsProcessor(regex_string, engine.engine)] else: From 5b4ca39255cea275018926ad60dc94e22e4b6a14 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Sat, 4 May 2024 13:39:23 +0800 Subject: [PATCH 2/5] add test --- tests/fsm/test_json_schema.py | 172 ++++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index edc061bec..50b97d6b8 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -20,6 +20,7 @@ WHITESPACE, build_regex_from_schema, get_schema_from_signature, + optimize_schema, to_regex, ) @@ -777,3 +778,174 @@ class Model(BaseModel): # check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm() interegular.parse_pattern(pattern).to_fsm() + + +@pytest.mark.parametrize( + "schema,expected_schema", + [ + # No optimizations possible + ( + { + "properties": {"field_a": {"title": "Field A", "type": "integer"}}, + "required": ["field_a"], + "title": "Test", + "type": "object", + }, + { + "properties": {"field_a": {"title": "Field A", "type": "integer"}}, + "required": ["field_a"], + "title": "Test", + "type": "object", + }, + ), + # Makes fields with null type in anyOf optional + # and removes null fields + ( + { + "properties": { + "field_a": {"title": "Field A", "type": "integer"}, + "field_b": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "title": "Field B", + }, + "field_c": {"title": "Field C", "type": "null"}, + }, + "required": ["field_a", "field_b", "field_c"], + "title": "Test", + "type": "object", + }, + { + "properties": { + "field_a": {"title": "Field A", "type": "integer"}, + "field_b": {"title": "Field B", "type": "integer"}, + }, + "required": ["field_a"], + "title": "Test", + "type": "object", + }, + ), + # Multilevel example + ( + { + "$defs": { + "TestCell": { + "properties": { + "g": {"title": "G", "type": "integer"}, + "h": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "title": "H", + }, + }, + "required": ["g", "h"], + "title": "TestCell", + "type": "object", + }, + "TestLineItem": { + "properties": { + "a": {"title": "A", "type": "integer"}, + "b": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "title": "B", + }, + "c": {"title": "C", "type": "string"}, + "d": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "title": "D", + }, + "e": {"title": "E", "type": "null"}, + "f": { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "title": "F", + }, + "i": {"$ref": "#/$defs/TestCell"}, + }, + "required": ["a", "b", "c", "d", "e", "f", "i"], + "title": "TestLineItem", + "type": "object", + }, + }, + "properties": { + "line_items": { + "anyOf": [ + { + "items": {"$ref": "#/$defs/TestLineItem"}, + "type": "array", + }, + {"type": "null"}, + ], + "title": "Line Items", + } + }, + "required": ["line_items"], + "title": "TestTable", + "type": "object", + }, + { + "$defs": { + "TestCell": { + "properties": { + "g": {"title": "G", "type": "integer"}, + "h": {"title": "H", "type": "string"}, + }, + "required": ["g"], + "title": "TestCell", + "type": "object", + }, + "TestLineItem": { + "properties": { + "a": {"title": "A", "type": "integer"}, + "b": {"title": "B", "type": "integer"}, + "c": {"title": "C", "type": "string"}, + "d": {"title": "D", "type": "string"}, + "f": { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "title": "F", + }, + "i": {"$ref": "#/$defs/TestCell"}, + }, + "required": ["a", "c", "f", "i"], + "title": "TestLineItem", + "type": "object", + }, + }, + "properties": { + "line_items": { + "title": "Line Items", + "items": {"$ref": "#/$defs/TestLineItem"}, + "type": "array", + }, + }, + "required": [], + "title": "TestTable", + "type": "object", + }, + ), + # From function signature + ( + { + "properties": { + "a": {"title": "A", "type": "integer"}, + "b": { + "anyOf": [{"type": "integer"}, {"type": "null"}], + "title": "B", + }, + }, + "required": ["a", "b"], + "title": "Arguments", + "type": "object", + }, + { + "properties": { + "a": {"title": "A", "type": "integer"}, + "b": {"title": "B", "type": "integer"}, + }, + "required": ["a"], + "title": "Arguments", + "type": "object", + }, + ), + ], +) +def test_json_schema_optimization(schema: dict, expected_schema: dict): + optimized_schema = optimize_schema(schema) + assert optimized_schema == expected_schema From cb4606350097b0f1b1bdaef9c9b7515be50fc004 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Sat, 4 May 2024 13:51:13 +0800 Subject: [PATCH 3/5] improve names of util funcs --- outlines/fsm/json_schema.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index c2185c226..1e7e82fed 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -97,7 +97,7 @@ def build_regex_from_schema( return to_regex(resolver, content, whitespace_pattern) -def is_null_type(instance: dict): +def _is_null_type(instance: dict): if "type" in instance and (instance["type"] == "null" or instance["type"] is None): return True if "const" in instance and ( @@ -107,9 +107,9 @@ def is_null_type(instance: dict): return False -def any_of_list_has_null_type(any_of_list: list[dict[str, str]]): - for subinstance in any_of_list: - if is_null_type(subinstance): +def _has_null_type(instance_list: list[dict]): + for instance in instance_list: + if _is_null_type(instance): return True return False @@ -134,12 +134,10 @@ def optimize_schema(instance): subinstance_type == "array" and subinstance.get("minItems", 0) == 0 ): new_optional_keys.add(key) - elif "anyOf" in subinstance and any_of_list_has_null_type( - subinstance["anyOf"] - ): + elif "anyOf" in subinstance and _has_null_type(subinstance["anyOf"]): any_of_list = subinstance.pop("anyOf") filtered_any_of_list = list( - filter(lambda d: is_null_type(d), any_of_list) + filter(lambda d: _is_null_type(d), any_of_list) ) if len(filtered_any_of_list) == 0: keys_to_remove.add(key) From edf0320fd8758c98280348b9af09c33ae94e8d2e Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Sat, 4 May 2024 16:45:08 +0800 Subject: [PATCH 4/5] fix typing to support python versions < 3.9 --- outlines/fsm/json_schema.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 1e7e82fed..7efcd2b5f 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -3,7 +3,7 @@ import re import warnings from copy import deepcopy -from typing import Callable, Optional +from typing import Callable, List, Optional from jsonschema.protocols import Validator from pydantic import create_model @@ -107,7 +107,7 @@ def _is_null_type(instance: dict): return False -def _has_null_type(instance_list: list[dict]): +def _has_null_type(instance_list: List[dict]): for instance in instance_list: if _is_null_type(instance): return True From dbf193ecc04a08b63d3c12bfd3766b09be3f5222 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Sat, 4 May 2024 19:24:53 +0800 Subject: [PATCH 5/5] bugfix: exclude null types in anyOf list --- outlines/fsm/json_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 7efcd2b5f..aa8e9f79b 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -137,7 +137,7 @@ def optimize_schema(instance): elif "anyOf" in subinstance and _has_null_type(subinstance["anyOf"]): any_of_list = subinstance.pop("anyOf") filtered_any_of_list = list( - filter(lambda d: _is_null_type(d), any_of_list) + filter(lambda d: not _is_null_type(d), any_of_list) ) if len(filtered_any_of_list) == 0: keys_to_remove.add(key)