diff --git a/docs/reference/format.md b/docs/reference/format.md new file mode 100644 index 000000000..52b0dc5ec --- /dev/null +++ b/docs/reference/format.md @@ -0,0 +1,23 @@ +# Type constraints + +We can ask completions to be restricted to valid python types: + +```python +from outlines import models, generate + +model = models.transformers("mistralai/Mistral-7B-v0.1") +generator = generate.format(model, int) +answer = generator("When I was 6 my sister was half my age. Now I’m 70 how old is my sister?") +print(answer) +# 67 +``` + +The following types are currently available: + +- int +- float +- bool +- datetime.date +- datetime.time +- datetime.datetime +- We also provide [custom types](types.md) diff --git a/docs/reference/json.md b/docs/reference/json.md index b3c89fa90..3b5976f19 100644 --- a/docs/reference/json.md +++ b/docs/reference/json.md @@ -46,6 +46,9 @@ print(result) `generation.json` computes an index that helps Outlines guide generation. This can take some time, but only needs to be done once. If you want to generate several times with the same schema make sure that you only call `generate.json` once. +!!! Tip "Custom types" + + Outlines provides [custom Pydantic types](types.md) so you do not have to write regular expressions for common types, such as phone numbers or zip codes. ## Using a JSON Schema diff --git a/docs/reference/types.md b/docs/reference/types.md index a82d521ae..82197ae5f 100644 --- a/docs/reference/types.md +++ b/docs/reference/types.md @@ -1,22 +1,49 @@ -# Type constraints +# Custom types -We can ask completions to be restricted to valid python types: +Outlines provides custom Pydantic types so you can focus on your use case rather than on writing regular expressions: + +- Using `outlines.types.ZipCode` will generate valid US Zip(+4) codes. +- Using `outlines.types.PhoneNumber` will generate valid US phone numbers. + +You can use these types in Pydantic schemas for JSON-structured generation: + +```python +from pydantic import BaseModel + +from outlines import models, generate, types + + +class Client(BaseModel): + name: str + phone_number: types.PhoneNumber + zip_code: types.ZipCode + + +model = models.transformers("mistralai/Mistral-7B-v0.1") +generator = generate.json(model, Client) +result = generator( + "Create a client profile with the fields name, phone_number and zip_code" +) +print(result) +# name='Tommy' phone_number='129-896-5501' zip_code='50766' +``` + +Or simply with `outlines.generate.format`: ```python -from outlines import models, generate +from pydantic import BaseModel + +from outlines import models, generate, types + model = models.transformers("mistralai/Mistral-7B-v0.1") -generator = generate.format(model, int) -answer = generator("When I was 6 my sister was half my age. Now I’m 70 how old is my sister?") -print(answer) -# 67 +generator = generate.format(model, types.PhoneNumber) +result = generator( + "Return a US Phone number: " +) +print(result) +# 334-253-2630 ``` -The following types are currently available: -- int -- float -- bool -- datetime.date -- datetime.time -- datetime.datetime +We plan on adding many more custom types. If you have found yourself writing regular expressions to generate fields of a given type, or if you could benefit from more specific types don't hesite to [submit a PR](https://github.com/outlines-dev/outlines/pulls) or [open an issue](https://github.com/outlines-dev/outlines/issues/new/choose). diff --git a/mkdocs.yml b/mkdocs.yml index cf1bd9163..ab8bd4dd0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -109,13 +109,14 @@ nav: - Structured generation: - Classification: reference/choices.md - Regex: reference/regex.md - - Type constraints: reference/types.md + - Type constraints: reference/format.md - JSON (function calling): reference/json.md - JSON mode: reference/json_mode.md - Grammar: reference/cfg.md - Custom FSM operations: reference/custom_fsm_ops.md - Utilities: - Serve with vLLM: reference/serve/vllm.md + - Custom types: reference/types.md - Prompt templating: reference/prompting.md - Outlines functions: reference/functions.md - Models: diff --git a/outlines/__init__.py b/outlines/__init__.py index 4fe24e4ef..3eb6a2f94 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -2,6 +2,7 @@ import outlines.generate import outlines.grammars import outlines.models +import outlines.types from outlines.base import vectorize from outlines.caching import clear_cache, disable_cache, get_cache from outlines.function import Function diff --git a/outlines/fsm/types.py b/outlines/fsm/types.py index bcf091854..929b48e87 100644 --- a/outlines/fsm/types.py +++ b/outlines/fsm/types.py @@ -1,6 +1,8 @@ import datetime from typing import Protocol, Tuple, Type, Union +from typing_extensions import _AnnotatedAlias, get_args + INTEGER = r"[+-]?(0|[1-9][0-9]*)" BOOLEAN = "(True|False)" FLOAT = rf"{INTEGER}(\.[0-9]+)?([eE][+-][0-9]+)?" @@ -17,6 +19,16 @@ def __call__( def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]: + # If it is a custom type + if isinstance(python_type, _AnnotatedAlias): + json_schema = get_args(python_type)[1].json_schema + type_class = get_args(python_type)[0] + + regex_str = json_schema["pattern"] + format_fn = lambda x: type_class(x) + + return regex_str, format_fn + if python_type == float: def float_format_fn(sequence: str) -> float: diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py new file mode 100644 index 000000000..f25e20846 --- /dev/null +++ b/outlines/types/__init__.py @@ -0,0 +1,2 @@ +from .phone_numbers import PhoneNumber +from .zip_codes import ZipCode diff --git a/outlines/types/phone_numbers.py b/outlines/types/phone_numbers.py new file mode 100644 index 000000000..0b27c7890 --- /dev/null +++ b/outlines/types/phone_numbers.py @@ -0,0 +1,16 @@ +"""Phone number types. + +We currently only support US phone numbers. We can however imagine having custom types +for each country, for instance leveraging the `phonenumbers` library. + +""" +from pydantic import WithJsonSchema +from typing_extensions import Annotated + +US_PHONE_NUMBER = r"(\([0-9]{3}\) |[0-9]{3}-)[0-9]{3}-[0-9]{4}" + + +PhoneNumber = Annotated[ + str, + WithJsonSchema({"type": "string", "pattern": US_PHONE_NUMBER}), +] diff --git a/outlines/types/zip_codes.py b/outlines/types/zip_codes.py new file mode 100644 index 000000000..981efdd03 --- /dev/null +++ b/outlines/types/zip_codes.py @@ -0,0 +1,13 @@ +"""Zip code types. + +We currently only support US Zip Codes. + +""" +from pydantic import WithJsonSchema +from typing_extensions import Annotated + +# This matches Zip and Zip+4 codes +US_ZIP_CODE = r"\d{5}(?:-\d{4})?" + + +ZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})] diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 000000000..928dde723 --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,34 @@ +import re + +import pytest +from pydantic import BaseModel + +from outlines import types +from outlines.fsm.types import python_types_to_regex + + +@pytest.mark.parametrize( + "custom_type,test_string,should_match", + [ + (types.PhoneNumber, "12", False), + (types.PhoneNumber, "(123) 123-1234", True), + (types.PhoneNumber, "123-123-1234", True), + (types.ZipCode, "12", False), + (types.ZipCode, "12345", True), + (types.ZipCode, "12345-1234", True), + ], +) +def test_phone_number(custom_type, test_string, should_match): + class Model(BaseModel): + attr: custom_type + + schema = Model.model_json_schema() + assert schema["properties"]["attr"]["type"] == "string" + regex_str = schema["properties"]["attr"]["pattern"] + does_match = re.match(regex_str, test_string) is not None + assert does_match is should_match + + regex_str, format_fn = python_types_to_regex(custom_type) + assert isinstance(format_fn(1), str) + does_match = re.match(regex_str, test_string) is not None + assert does_match is should_match