Build a JSON schema from a function's signature #355

Merged · 3 commits · Nov 13, 2023
56 changes: 55 additions & 1 deletion docs/reference/json.md
@@ -1 +1,55 @@
# JSON
# Make the LLM follow a JSON Schema

Outlines can make any open-source model return a JSON object that follows a structure specified by the user. This is useful whenever we want the model's output to be processed by code downstream: code does not understand natural language, only the structured formats it has been programmed to handle.

There are two main reasons why someone would want JSON-formatted output from an LLM:

1. Parse the answer (e.g. with Pydantic), store it somewhere, return it to a user, etc.
2. Call a function with the result

Outlines has you covered in both cases! To define the structure of the JSON you want the model to follow, you can provide either a Pydantic model or a function. No need to duplicate code!

## Using Pydantic

Outlines can infer the structure of the output from a Pydantic model. The result is an instance of the model that contains the values returned by the LLM:

```python
from pydantic import BaseModel

from outlines import models
from outlines import text


class User(BaseModel):
    name: str
    last_name: str
    id: int


model = models.transformers("mistralai/Mistral-7B")
generator = text.generate.json(model, User)
result = generator("Create a user profile with the fields name, last_name and id")
print(result)
# User(name="John", last_name="Doe", id=11)
```

## From a function's signature

Outlines can infer the structure of the output from the signature of a function. The result is a dictionary that can be passed directly to the function using the usual dictionary expansion syntax `**`:

```python
from outlines import models
from outlines import text

def add(a: int, b: int):
    return a + b

model = models.transformers("mistralai/Mistral-7B")
generator = text.generate.json(model, add)
result = generator("Return two integers named a and b respectively. a is odd and b even.")

print(add(**result))
# 3
```

A great advantage of passing a function directly to specify the structure is that the structure the LLM must follow changes along with the function's definition. No need to update the code in several places!
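The docs example above relies on the signature-to-schema conversion introduced in this PR. Here is a minimal sketch of that mechanism; it mirrors the `get_schema_from_signature` helper added to `outlines/text/json_schema.py` further down in this diff, and the exact output shown assumes Pydantic v2:

```python
import inspect

from pydantic import create_model


def add(a: int, b: int):
    return a + b


# Build a Pydantic model from the signature, then export its JSON schema.
signature = inspect.signature(add)
fields = {name: (p.annotation, ...) for name, p in signature.parameters.items()}
schema = create_model("Arguments", **fields).model_json_schema()

print(schema["properties"])
# {'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'integer'}}
```

Change the signature (add an argument, change a type hint) and the schema the generation is constrained by changes with it.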
5 changes: 1 addition & 4 deletions examples/dating_profile.py
@@ -121,12 +121,9 @@ def dating_profile_prompt(description: str, examples: list[Example]):
new_description = "I'm a laid-back lawyer who spends a lot of his free-time gaming. I work in a corporate office, but ended up here after the start-up I cofounded got acquired, so still play ping pong with my cool coworkers every day. I have a bar at home where I make cocktails, which is great for entertaining friends. I secretly like to wear suits and get a new one tailored every few months. I also like weddings because I get to wear those suits, and it's a good excuse for a date. I watch the latest series because I'm paying, with my hard-earned money, for every streaming service."

prompt = dating_profile_prompt(description=new_description, examples=samples)
profile = text.generate.json(model, DatingProfile)(prompt)
profile = text.generate.json(model, DatingProfile)(prompt) # type: ignore
print(profile)

parsed_profile = DatingProfile.model_validate_json(profile)
print(parsed_profile)

# Sample generated profiles
"""
{
40 changes: 27 additions & 13 deletions outlines/text/generate/regex.py
@@ -1,14 +1,14 @@
import json as pyjson
import math
from json import dumps
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union

import interegular
import torch
from pydantic import BaseModel

from outlines.text.fsm import create_fsm_index_tokenizer, make_deterministic_fsm
from outlines.text.generate.continuation import Continuation
from outlines.text.json_schema import build_regex_from_schema
from outlines.text.json_schema import build_regex_from_object, get_schema_from_signature

if TYPE_CHECKING:
from outlines.text.generate.sample import Sampler
@@ -48,6 +48,7 @@ def __init__(
final_states: Optional[Set[int]] = None,
states_to_token_maps: Optional[Dict[int, Dict[int, int]]] = None,
empty_token_ids: Optional[Set[int]] = None,
format_fn: Callable[[str], Union[BaseModel, dict, str]] = lambda x: x,
):
"""

@@ -73,6 +74,8 @@
corresponding FSM end states.
empty_token_ids
Pre-computed set of token ids for tokens that are empty strings.
format_fn
The function to apply to the generated JSON.

"""
super().__init__(model, max_tokens, sampler, stop)
@@ -113,6 +116,7 @@ def __init__(
self.mask_cache: Dict[Tuple[int, int], torch.LongTensor] = {}
self.regex_string = regex_string
self.allow_empty_tokens = allow_empty_tokens
self.format_fn = format_fn

def create_proposal(
self, generated_token_ids: torch.LongTensor, logits: torch.DoubleTensor
@@ -215,9 +219,10 @@ def _get_mask_for_state(

return mask

def postprocess_completions(self, completions: List[str]) -> List[str]:
def postprocess_completions(self, completions: List[str]):
self.last_fsm_states.clear()
return super().postprocess_completions(completions)
results: List[str] = super().postprocess_completions(completions)
return [self.format_fn(result) for result in results]


def regex(
@@ -386,25 +391,26 @@ def choice(

def json(
model,
schema: Union[str, BaseModel],
schema_object: Union[str, BaseModel, Callable],
max_tokens: Optional[int] = None,
*,
sampler: Optional["Sampler"] = None,
allow_empty_tokens: bool = True,
):
) -> Union[dict, BaseModel]:
"""Generate a text sequence that follows a JSON schema or Pydantic model.

.. note:
Reuse instances of these guided generators whenever possible,
because constructing them has more overhead than generating
token sequences from them. See the docstring for `Regex`.
token sequences from them. See the docstring for `Regex`.

Parameters
---------
model
The language model to use to compute the next-token logits.
schema
The JSON schema or Pydantic model that guides the generation.
The JSON schema, Pydantic model or function (signature) that guides the
generation.
max_tokens
The maximum number of tokens to generate.
sampler
@@ -416,15 +422,23 @@ def json(
Allow sampling of tokens corresponding to empty strings.

"""
if isinstance(schema, type(BaseModel)):
schema = dumps(schema.model_json_schema())

regex_str = build_regex_from_schema(schema)
    if isinstance(schema_object, type(BaseModel)):
        schema = pyjson.dumps(schema_object.model_json_schema())
        format_fn = lambda x: schema_object.model_validate(pyjson.loads(x))
    elif callable(schema_object):
        schema = pyjson.dumps(get_schema_from_signature(schema_object))
        # TODO: Convert string fields to their respective types
        format_fn = lambda x: pyjson.loads(x)
    else:
        schema = schema_object
        format_fn = lambda x: x

    regex_str = build_regex_from_object(schema)

return Regex(
model,
regex_str,
max_tokens,
sampler=sampler,
allow_empty_tokens=allow_empty_tokens,
format_fn=format_fn,
)
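For reference, a minimal standalone sketch of the `format_fn` dispatch introduced above: depending on what was passed to `json()`, the raw JSON completion is validated into a Pydantic instance, parsed into a plain dict, or left as a string. The `User` model and the `raw` completion are illustrative, not taken from the PR.

```python
import json as pyjson

from pydantic import BaseModel


class User(BaseModel):
    name: str
    id: int


def make_format_fn(schema_object):
    # Same three-way dispatch as in `json()` above.
    if isinstance(schema_object, type(BaseModel)):
        return lambda x: schema_object.model_validate(pyjson.loads(x))
    elif callable(schema_object):
        return lambda x: pyjson.loads(x)
    return lambda x: x


raw = '{"name": "John", "id": 11}'
print(make_format_fn(User)(raw))                   # User instance: name='John' id=11
print(make_format_fn(lambda name, id: None)(raw))  # plain dict: {'name': 'John', 'id': 11}
print(make_format_fn('{"type": "object"}')(raw))   # raw JSON string, unchanged
```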
53 changes: 48 additions & 5 deletions outlines/text/json_schema.py
@@ -1,8 +1,11 @@
import inspect
import itertools as it
import json
import re
from typing import Callable, Union

from jsonschema.protocols import Validator
from pydantic import BaseModel, create_model
from referencing import Registry, Resource
from referencing._core import Resolver
from referencing.jsonschema import DRAFT202012
@@ -23,23 +26,43 @@
}


def build_regex_from_schema(schema: str):
def build_regex_from_object(object: Union[str, Callable, BaseModel]):
"""Turn a JSON schema into a regex that matches any JSON object that follows
this schema.

JSON Schema is a declarative language that allows one to annotate JSON documents
with types and descriptions. These schemas can be generated from any Python
data structure that has type annotations: namedtuples, dataclasses, Pydantic
models. By ensuring that the generation respects the schema, we ensure
that the output can be parsed into these objects.
This function parses the provided schema and builds a generation schedule which
mixes deterministic generation (fixed strings) and sampling with constraints.

Parameters
----------
schema
A string that contains the JSON schema.
A string that represents a JSON Schema.

Returns
-------
A string that contains a regular expression that matches any JSON object that
follows the schema.
A generation schedule. A list of strings that represent the JSON
schema's structure and regular expressions that define the structure of
the fields.

References
----------
.. [0] JSON Schema. https://json-schema.org/

"""

if isinstance(object, type(BaseModel)):
schema = object.model_json_schema()
elif callable(object):
schema = get_schema_from_signature(object)
else:
schema = json.loads(object)

Validator.check_schema(schema)
schema = json.loads(schema)

# Build reference resolver
schema = Resource(contents=schema, specification=DRAFT202012)
@@ -214,3 +237,23 @@ def to_regex(resolver: Resolver, instance: dict):
regular expression. Make sure it is valid to the JSON Schema specification. If
it is, please open an issue on the Outlines repository"""
)


def get_schema_from_signature(fn: Callable) -> dict:
    """Turn a function signature into a JSON schema.

    Every JSON object valid under the output JSON Schema can be passed
    to `fn` using the `**` unpacking syntax.

    """
    signature = inspect.signature(fn)
    arguments = {}
    for name, arg in signature.parameters.items():
        if arg.annotation == inspect._empty:
            raise ValueError("Each argument must have a type annotation")
        else:
            arguments[name] = (arg.annotation, ...)

    model = create_model("Arguments", **arguments)

    return model.model_json_schema()
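A hedged usage sketch of the new helper (it assumes this branch of Outlines and Pydantic v2 are installed; `scale` and `untyped` are made-up example functions): fully annotated parameters produce a schema, while a missing annotation raises the `ValueError` defined above.

```python
from outlines.text.json_schema import get_schema_from_signature


def scale(value: float, factor: float) -> float:
    return value * factor


print(get_schema_from_signature(scale)["required"])
# ['value', 'factor']


def untyped(value, factor):
    return value * factor


try:
    get_schema_from_signature(untyped)
except ValueError as err:
    print(err)  # Each argument must have a type annotation
```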