diff --git a/docs/reference/json.md b/docs/reference/json.md
index 02ed5a24a..5285aa24e 100644
--- a/docs/reference/json.md
+++ b/docs/reference/json.md
@@ -1 +1,55 @@
-# JSON
+# Make the LLM follow a JSON Schema
+
+Outlines can make any open source model return a JSON object that follows a user-specified structure. This is useful whenever the model's output needs to be processed by code downstream: code does not understand natural language, only the structured language it has been programmed to understand.
+
+There are two main reasons why someone would want to get JSON-formatted output from an LLM:
+
+1. Parse the answer (e.g. with Pydantic), store it somewhere, return it to a user, etc.
+2. Call a function with the result
+
+Outlines has you covered in both cases! To define the structure of the JSON you want the model to follow, you can provide either a Pydantic model or a function. No need to duplicate code!
+
+## Using Pydantic
+
+Outlines can infer the structure of the output from a Pydantic model. The result is an instance of the model that contains the values returned by the LLM:
+
+```python
+from pydantic import BaseModel
+
+from outlines import models
+from outlines import text
+
+
+class User(BaseModel):
+    name: str
+    last_name: str
+    id: int
+
+
+model = models.transformers("mistralai/Mistral-7B")
+generator = text.generate.json(model, User)
+result = generator("Create a user profile with the fields name, last_name and id")
+print(result)
+# User(name="John", last_name="Doe", id=11)
+```
+
+## From a function's signature
+
+Outlines can infer the structure of the output from the signature of a function. The result is a dictionary that can be passed directly to the function using the usual dictionary expansion syntax `**`:
+
+```python
+from outlines import models
+from outlines import text
+
+def add(a: int, b: int):
+    return a + b
+
+model = models.transformers("mistralai/Mistral-7B")
+generator = text.generate.json(model, add)
+result = generator("Return two integers named a and b respectively. a is odd and b even.")
+
+print(add(**result))
+# 3
+```
+
+A great advantage of passing a function directly to specify the structure is that the structure imposed on the LLM's output changes with the function's definition. No need to change the code in several places!
diff --git a/examples/dating_profile.py b/examples/dating_profile.py
index 485dfa7dd..228f13993 100644
--- a/examples/dating_profile.py
+++ b/examples/dating_profile.py
@@ -121,12 +121,9 @@ def dating_profile_prompt(description: str, examples: list[Example]):
 new_description = "I'm a laid-back lawyer who spends a lot of his free-time gaming. I work in a corporate office, but ended up here after the start-up I cofounded got acquired, so still play ping pong with my cool coworkers every day. I have a bar at home where I make cocktails, which is great for entertaining friends. I secretly like to wear suits and get a new one tailored every few months. I also like weddings because I get to wear those suits, and it's a good excuse for a date. I watch the latest series because I'm paying, with my hard-earned money, for every streaming service."
 prompt = dating_profile_prompt(description=new_description, examples=samples)
-profile = text.generate.json(model, DatingProfile)(prompt)
+profile = text.generate.json(model, DatingProfile)(prompt)  # type: ignore
 print(profile)
 
-parsed_profile = DatingProfile.model_validate_json(profile)
-print(parsed_profile)
-
 # Sample generated profiles
 """
 {
diff --git a/outlines/text/generate/regex.py b/outlines/text/generate/regex.py
index af87f9f51..24eabffc1 100644
--- a/outlines/text/generate/regex.py
+++ b/outlines/text/generate/regex.py
@@ -1,6 +1,6 @@
+import json as pyjson
 import math
-from json import dumps
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union
 
 import interegular
 import torch
@@ -8,7 +8,7 @@
 from outlines.text.fsm import create_fsm_index_tokenizer, make_deterministic_fsm
 from outlines.text.generate.continuation import Continuation
-from outlines.text.json_schema import build_regex_from_schema
+from outlines.text.json_schema import build_regex_from_object, get_schema_from_signature
 
 if TYPE_CHECKING:
     from outlines.text.generate.sample import Sampler
@@ -48,6 +48,7 @@ def __init__(
         final_states: Optional[Set[int]] = None,
         states_to_token_maps: Optional[Dict[int, Dict[int, int]]] = None,
         empty_token_ids: Optional[Set[int]] = None,
+        format_fn: Callable[[str], Union[BaseModel, dict, str]] = lambda x: x,
     ):
         """
@@ -73,6 +74,8 @@ def __init__(
             corresponding FSM end states.
         empty_token_ids
             Pre-computed set of token ids for tokens that are empty strings.
+        format_fn
+            The function to apply to the generated JSON.
 
         """
         super().__init__(model, max_tokens, sampler, stop)
@@ -113,6 +116,7 @@ def __init__(
         self.mask_cache: Dict[Tuple[int, int], torch.LongTensor] = {}
         self.regex_string = regex_string
         self.allow_empty_tokens = allow_empty_tokens
+        self.format_fn = format_fn
 
     def create_proposal(
         self, generated_token_ids: torch.LongTensor, logits: torch.DoubleTensor
@@ -215,9 +219,10 @@
 
         return mask
 
-    def postprocess_completions(self, completions: List[str]) -> List[str]:
+    def postprocess_completions(self, completions: List[str]):
         self.last_fsm_states.clear()
-        return super().postprocess_completions(completions)
+        results: List[str] = super().postprocess_completions(completions)
+        return [self.format_fn(result) for result in results]
 
 
 def regex(
@@ -386,25 +391,26 @@ def choice(
 
 def json(
     model,
-    schema: Union[str, BaseModel],
+    schema_object: Union[str, BaseModel, Callable],
     max_tokens: Optional[int] = None,
     *,
     sampler: Optional["Sampler"] = None,
     allow_empty_tokens: bool = True,
-):
+) -> Regex:
     """Generate a text sequence that follows a JSON schema or Pydantic model.
 
     .. note:
         Reuse instances of these guided generators whenever possible,
         because constructing them has more overhead than generating
-       token sequences from them. See the docstring for `Regex`.
+        token sequences from them. See the docstring for `Regex`.
 
     Parameters
     ----------
     model
         The language model to use to compute the next-token logits.
-    schema
-        The JSON schema or Pydantic model that guides the generation.
+    schema_object
+        The JSON schema, Pydantic model or function (signature) that guides the
+        generation.
     max_tokens
         The maximum number of tokens to generate.
     sampler
@@ -416,10 +422,18 @@ def json(
         Allow sampling of tokens corresponding to empty strings.
""" - if isinstance(schema, type(BaseModel)): - schema = dumps(schema.model_json_schema()) - - regex_str = build_regex_from_schema(schema) + if isinstance(schema_object, type(BaseModel)): + schema = pyjson.dumps(schema_object.model_json_schema()) + format_fn = lambda x: schema_object.model_validate(pyjson.loads(x)) + elif callable(schema_object): + schema = pyjson.dumps(get_schema_from_signature(schema_object)) + # TODO: Convert string fields to their respective types + format_fn = lambda x: pyjson.loads(x) + else: + format_fn = lambda x: x + + regex_str = build_regex_from_object(schema) return Regex( model, @@ -427,4 +440,5 @@ def json( max_tokens, sampler=sampler, allow_empty_tokens=allow_empty_tokens, + format_fn=format_fn, ) diff --git a/outlines/text/json_schema.py b/outlines/text/json_schema.py index f2fc351ed..4044d225a 100644 --- a/outlines/text/json_schema.py +++ b/outlines/text/json_schema.py @@ -1,8 +1,11 @@ +import inspect import itertools as it import json import re +from typing import Callable, Union from jsonschema.protocols import Validator +from pydantic import BaseModel, create_model from referencing import Registry, Resource from referencing._core import Resolver from referencing.jsonschema import DRAFT202012 @@ -23,23 +26,43 @@ } -def build_regex_from_schema(schema: str): +def build_regex_from_object(object: Union[str, Callable, BaseModel]): """Turn a JSON schema into a regex that matches any JSON object that follows this schema. + JSON Schema is a declarative language that allows to annotate JSON documents + with types and descriptions. These schemas can be generated from any Python + datastructure that has type annotation: namedtuples, dataclasses, Pydantic + models. And by ensuring that the generation respects the schema we ensure + that the output can be parsed into these objects. + This function parses the provided schema and builds a generation schedule which + mixes deterministic generation (fixed strings), and sampling with constraints. + Parameters ---------- schema - A string that contains the JSON schema. + A string that represents a JSON Schema. Returns ------- - A string that contains a regular expression that matches any JSON object that - follows the schema. + A generation schedule. A list of strings that represent the JSON + schema's structure and regular expression that define the structure of + the fields. + + References + ---------- + .. [0] JSON Schema. https://json-schema.org/ """ + + if isinstance(object, type(BaseModel)): + schema = object.model_json_schema() + elif callable(object): + schema = get_schema_from_signature(object) + else: + schema = json.loads(object) + Validator.check_schema(schema) - schema = json.loads(schema) # Build reference resolver schema = Resource(contents=schema, specification=DRAFT202012) @@ -214,3 +237,23 @@ def to_regex(resolver: Resolver, instance: dict): regular expression. Make sure it is valid to the JSON Schema specification. If it is, please open an issue on the Outlines repository""" ) + + +def get_schema_from_signature(fn: Callable) -> str: + """Turn a function signature into a JSON schema. + + Every JSON object valid to the output JSON Schema can be passed + to `fn` using the ** unpacking syntax. + + """ + signature = inspect.signature(fn) + arguments = {} + for name, arg in signature.parameters.items(): + if arg.annotation == inspect._empty: + raise ValueError("Each argument must have a type annotation") + else: + arguments[name] = (arg.annotation, ...) 
+
+    model = create_model("Arguments", **arguments)
+
+    return model.model_json_schema()
diff --git a/tests/text/generate/test_integration_transfomers.py b/tests/text/generate/test_integration_transfomers.py
index d54ca884d..18d5e9eb8 100644
--- a/tests/text/generate/test_integration_transfomers.py
+++ b/tests/text/generate/test_integration_transfomers.py
@@ -1,4 +1,3 @@
-import json
 import re
 from enum import Enum
 from typing import List, Union
@@ -75,7 +74,7 @@ def test_transformers_integration_integer():
     rng.manual_seed(0)
 
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
-    model = models.transformers(model_name, device="cpu")
+    model = models.transformers(model_name)
     prompt = "Write a short sentence"
     sequence = generate.integer(model, max_tokens=10)(prompt, rng=rng)
@@ -88,7 +87,7 @@ def test_transformers_integration_integer_array():
     rng.manual_seed(0)
 
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
-    model = models.transformers(model_name, device="cpu")
+    model = models.transformers(model_name)
     prompts = ["Give me a number", "And another one"]
     sequence = generate.integer(model, max_tokens=10)(prompts, rng=rng)
     assert isinstance(sequence, list)
@@ -102,7 +101,7 @@ def test_transformers_integration_float():
     rng.manual_seed(0)
 
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
-    model = models.transformers(model_name, device="cpu")
+    model = models.transformers(model_name)
     prompt = "Write a short sentence"
     sequence = generate.float(model, max_tokens=10)(prompt, rng=rng)
@@ -143,13 +142,13 @@ class Spam(BaseModel):
     rng = torch.Generator()
     rng.manual_seed(0)  # make sure that `bar` is not an int
 
-    sequence = generate.json(model, Spam, max_tokens=1000)(prompt, rng=rng)
-    parsed = json.loads(sequence)
-    assert isinstance(parsed["foo"], int)
-    assert isinstance(parsed["bar"], int)
-    assert isinstance(parsed["spam"], str)
-    assert isinstance(parsed["fuzz"], bool)
-    assert len(parsed["spam"]) == 10
+    result = generate.json(model, Spam, max_tokens=1000)(prompt, rng=rng)
+    assert isinstance(result, BaseModel)
+    assert isinstance(result.foo, int)
+    assert isinstance(result.bar, float)
+    assert isinstance(result.spam, str)
+    assert isinstance(result.fuzz, bool)
+    assert len(result.spam) == 10
 
 
 def test_transformers_json_str_enum():
@@ -169,10 +168,10 @@ class User(BaseModel):
         user_id: int
         name: Name
 
-    sequence = generate.json(model, User)(prompt, rng=rng)
-    parsed = json.loads(sequence)
-    assert isinstance(parsed["user_id"], int)
-    assert parsed["name"] in ["John", "Marc", "Michel"]
+    result = generate.json(model, User)(prompt, rng=rng)
+    assert isinstance(result, BaseModel)
+    assert isinstance(result.user_id, int)
+    assert result.name in ["John", "Marc", "Michel"]
 
 
 def test_transformers_json_int_enum():
@@ -190,10 +189,10 @@ class Id(int, Enum):
 
     class User(BaseModel):
         user_id: Id
 
-    sequence = generate.json(model, User)(prompt, rng=rng)
-    parsed = json.loads(sequence)
-    assert isinstance(parsed["user_id"], int)
-    assert parsed["user_id"] in [1, 2]
+    result = generate.json(model, User)(prompt, rng=rng)
+    assert isinstance(result, BaseModel)
+    assert isinstance(result.user_id, int)
+    assert result.user_id in [1, 2]
 
 
 def test_transformers_json_array():
@@ -208,11 +207,11 @@ class User(BaseModel):
         user_id: int
 
     rng = torch.Generator()
     rng.manual_seed(0)
 
-    sequence = generate.json(model, User)(prompt, rng=rng)
-    parsed = json.loads(sequence)
-    assert isinstance(parsed["user_id"], int)
-    assert isinstance(parsed["value"], list)
-    for value in parsed["value"]:
+    result = generate.json(model, User)(prompt, rng=rng)
+    assert isinstance(result, BaseModel)
+    assert isinstance(result.user_id, int)
+    assert isinstance(result.value, list)
+    for value in result.value:
         assert isinstance(value, float) or isinstance(value, int)
 
 
@@ -229,14 +228,30 @@ class Spam(BaseModel):
     rng.manual_seed(4)
 
     sequence = generate.json(model, Spam, max_tokens=100)(prompt, rng=rng)
-    parsed = json.loads(sequence)
+    assert isinstance(sequence, BaseModel)
     assert (
-        isinstance(parsed["bar"], int)
-        or isinstance(parsed["bar"], float)
-        or isinstance(parsed["bar"], str)
+        isinstance(sequence.bar, int)
+        or isinstance(sequence.bar, float)
+        or isinstance(sequence.bar, str)
     )
 
 
+def test_transformers_json_function():
+    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+    model = models.transformers(model_name)
+    prompt = "Output arguments for the function"
+
+    def function(foo: int, bar: List[int]):
+        return foo + sum(bar)
+
+    rng = torch.Generator()
+    rng.manual_seed(4)
+
+    sequence = generate.json(model, function, max_tokens=100)(prompt, rng=rng)
+    assert isinstance(sequence, dict)
+    assert isinstance(function(**sequence), int)
+
+
 def test_transformers_logits_vocab_size():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name, device="cpu")
diff --git a/tests/text/test_fsm.py b/tests/text/test_fsm.py
index ce4a3647b..f10f0f816 100644
--- a/tests/text/test_fsm.py
+++ b/tests/text/test_fsm.py
@@ -429,7 +429,7 @@ def test_json_index_performance():
     from pydantic import BaseModel, constr
 
     import outlines.models as models
-    from outlines.text.generate.regex import Regex, build_regex_from_schema
+    from outlines.text.generate.regex import Regex, build_regex_from_object
 
     class Weapon(str, Enum):
         sword = "sword"
@@ -457,7 +457,7 @@ class Character(BaseModel):
     json_schema = json.dumps(Character.model_json_schema())
 
     def build_regex():
-        regex_str = build_regex_from_schema(json_schema)
+        regex_str = build_regex_from_object(json_schema)
         Regex(model, regex_str, 100)
 
     profiler = LineProfiler(create_fsm_index_end_to_end)
diff --git a/tests/text/test_json_schema.py b/tests/text/test_json_schema.py
index b84ac39fa..380392980 100644
--- a/tests/text/test_json_schema.py
+++ b/tests/text/test_json_schema.py
@@ -1,5 +1,6 @@
 import json
 import re
+from typing import List
 
 import pytest
 from pydantic import BaseModel, constr
@@ -11,11 +12,32 @@
     NUMBER,
     STRING,
     STRING_INNER,
-    build_regex_from_schema,
+    build_regex_from_object,
+    get_schema_from_signature,
     to_regex,
 )
 
 
+def test_function_basic():
+    def test_function(foo: str, bar: List[int]):
+        ...
+
+    result = get_schema_from_signature(test_function)
+    assert result["type"] == "object"
+    assert list(result["properties"].keys()) == ["foo", "bar"]
+    assert result["properties"]["foo"]["type"] == "string"
+    assert result["properties"]["bar"]["type"] == "array"
+    assert result["properties"]["bar"]["items"]["type"] == "integer"
+
+
+def test_function_no_type():
+    def test_function(foo, bar: List[int]):
+        ...
+
+    with pytest.raises(ValueError):
+        get_schema_from_signature(test_function)
+
+
 def test_from_pydantic():
     class User(BaseModel):
         user_id: int
@@ -26,7 +48,7 @@ class User(BaseModel):
         is_true: bool
 
     schema = json.dumps(User.model_json_schema())
-    schedule = build_regex_from_schema(schema)
+    schedule = build_regex_from_object(schema)
 
     assert isinstance(schedule, str)
@@ -316,7 +338,7 @@ def test_match_number(pattern, does_match):
 )
 def test_match(schema, regex, examples):
     schema = json.dumps(schema)
-    test_regex = build_regex_from_schema(schema)
+    test_regex = build_regex_from_object(schema)
 
     assert test_regex == regex
     for string, does_match in examples:
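
---

The signature-to-schema round trip that `get_schema_from_signature` enables can be tried outside this repository. Below is a minimal, self-contained sketch of the same idea, assuming only Python 3.9+ and pydantic v2; `schema_from_signature` and `describe_trip` are hypothetical names used for illustration, not part of the Outlines API:

```python
import inspect
import json
from typing import Callable, List

from pydantic import create_model


def schema_from_signature(fn: Callable) -> dict:
    """Build a JSON Schema whose valid objects unpack cleanly into `fn`."""
    fields = {}
    for name, param in inspect.signature(fn).parameters.items():
        if param.annotation is inspect.Parameter.empty:
            raise ValueError("Each argument must have a type annotation")
        # (annotation, ...) marks the field as required with no default
        fields[name] = (param.annotation, ...)
    # A throwaway Pydantic model turns the annotations into a JSON Schema
    return create_model("Arguments", **fields).model_json_schema()


def describe_trip(city: str, days: int, stops: List[str]):
    return f"{days} days in {city}, visiting {', '.join(stops)}"


schema = schema_from_signature(describe_trip)
print(json.dumps(schema, indent=2))

# Any JSON object that validates against `schema` can be splatted into
# the function with `**`, which is what the new `generate.json` relies on:
arguments = {"city": "Paris", "days": 3, "stops": ["Louvre", "Orsay"]}
print(describe_trip(**arguments))
```

This mirrors the design choice in the diff: the schema's `required` list reflects the signature's parameters and each property's type reflects the annotation, so a generation constrained by the schema is guaranteed, modulo the string-to-type conversion noted in the TODO, to produce a dictionary the function accepts.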