Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more custom types #857

Merged
merged 3 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions docs/reference/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,18 @@

Outlines provides custom Pydantic types so you can focus on your use case rather than on writing regular expressions:

- Using `outlines.types.ZipCode` will generate valid US Zip(+4) codes.
- Using `outlines.types.PhoneNumber` will generate valid US phone numbers.

| Category | Type | Import | Description |
|:--------:|:----:|:-------|:------------|
| Zip code | US | `outlines.types.ZipCode` | Generate US Zip(+4) codes |
| Phone number | US | `outlines.types.PhoneNumber` | Generate valid US phone numbers |
| ISBN | 10 & 13 | `outlines.types.ISBN` | There is no guarantee that the [check digit][wiki-isbn] will be correct |
| Airport | IATA | `outlines.types.airports.IATA` | Valid [airport IATA codes][wiki-airport-iata] |
| Country | alpha-2 code | `outlines.types.countries.Alpha2` | Valid [country alpha-2 codes][wiki-country-alpha-2] |
| | alpha-3 code | `outlines.types.countries.Alpha3` | Valid [country alpha-3 codes][wiki-country-alpha-3] |
| | numeric code | `outlines.types.countries.Numeric` | Valid [country numeric codes][wiki-country-numeric] |
| | name | `outlines.types.countries.Name` | Valid country names |
| | flag | `outlines.types.countries.Flag` | Valid flag emojis |

You can use these types in Pydantic schemas for JSON-structured generation:

Expand Down Expand Up @@ -47,3 +57,10 @@ print(result)


We plan on adding many more custom types. If you have found yourself writing regular expressions to generate fields of a given type, or if you could benefit from more specific types, don't hesitate to [submit a PR](https://github.com/outlines-dev/outlines/pulls) or [open an issue](https://github.com/outlines-dev/outlines/issues/new/choose).


[wiki-isbn]: https://en.wikipedia.org/wiki/ISBN#Check_digits
[wiki-airport-iata]: https://en.wikipedia.org/wiki/IATA_airport_code
[wiki-country-alpha-2]: https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2
[wiki-country-alpha-3]: https://en.wikipedia.org/wiki/ISO_3166-1_alpha-3
[wiki-country-numeric]: https://en.wikipedia.org/wiki/ISO_3166-1_numeric
24 changes: 17 additions & 7 deletions outlines/fsm/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
from typing import Protocol, Tuple, Type, Union
from enum import EnumMeta
from typing import Any, Protocol, Tuple, Type

from typing_extensions import _AnnotatedAlias, get_args

Expand All @@ -12,9 +13,7 @@


class FormatFunction(Protocol):
def __call__(
self, sequence: str
) -> Union[int, float, bool, datetime.date, datetime.time, datetime.datetime]:
def __call__(self, sequence: str) -> Any:
...


Expand All @@ -24,10 +23,21 @@ def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]:
json_schema = get_args(python_type)[1].json_schema
type_class = get_args(python_type)[0]

regex_str = json_schema["pattern"]
format_fn = lambda x: type_class(x)
custom_regex_str = json_schema["pattern"]

return regex_str, format_fn
def custom_format_fn(sequence: str) -> Any:
return type_class(sequence)

return custom_regex_str, custom_format_fn

if isinstance(python_type, EnumMeta):
values = python_type.__members__.keys()
enum_regex_str: str = "(" + "|".join(values) + ")"

def enum_format_fn(sequence: str) -> str:
return str(sequence)

return enum_regex_str, enum_format_fn

if python_type == float:

Expand Down
2 changes: 2 additions & 0 deletions outlines/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from . import airports, countries
from .isbn import ISBN
from .phone_numbers import PhoneNumber
from .zip_codes import ZipCode
16 changes: 16 additions & 0 deletions outlines/types/airports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Generate valid airport codes."""
from enum import Enum

try:
    from pyairports.airports import AIRPORT_LIST
except ImportError:
    raise ImportError(
        'The `airports` module requires "pyairports" to be installed. You can install it with "pip install pyairports"'
    )


# Field 3 of each `AIRPORT_LIST` record is used as the code (presumably the
# IATA code, per the names below); records where it is empty are dropped.
# Duplicates are removed through a set, and the result is sorted because the
# iteration order of a set of strings changes between interpreter runs (hash
# randomization). Without sorting, the member order of the enum, and anything
# derived from it, would differ from one run to the next.
AIRPORT_IATA_LIST = sorted(
    {(airport[3], airport[3]) for airport in AIRPORT_LIST if airport[3] != ""}
)

# Functional `Enum` API: each (name, value) pair becomes one member whose name
# and value are the same code string.
IATA = Enum("Airport", AIRPORT_IATA_LIST)  # type:ignore
24 changes: 24 additions & 0 deletions outlines/types/countries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Generate valid country codes and names."""
from enum import Enum

try:
    import pycountry
except ImportError:
    raise ImportError(
        'The `countries` module requires "pycountry" to be installed. You can install it with "pip install pycountry"'
    )

# Every enum below is built with the functional `Enum` API from (name, value)
# pairs in which the member name and the member value are the same string,
# taken from one attribute of each entry in `pycountry.countries`.

ALPHA_2_CODE = [(country.alpha_2, country.alpha_2) for country in pycountry.countries]
Alpha2 = Enum("Alpha_2", ALPHA_2_CODE)  # type:ignore

ALPHA_3_CODE = [(country.alpha_3, country.alpha_3) for country in pycountry.countries]
# The class name must be "Alpha_3": reusing "Alpha_2" here would give `Alpha3`
# the same `__name__` as `Alpha2`.
Alpha3 = Enum("Alpha_3", ALPHA_3_CODE)  # type:ignore

NUMERIC_CODE = [(country.numeric, country.numeric) for country in pycountry.countries]
Numeric = Enum("Numeric_code", NUMERIC_CODE)  # type:ignore

NAME = [(country.name, country.name) for country in pycountry.countries]
Name = Enum("Name", NAME)  # type:ignore

FLAG = [(country.flag, country.flag) for country in pycountry.countries]
Flag = Enum("Flag", FLAG)  # type:ignore
12 changes: 12 additions & 0 deletions outlines/types/isbn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""ISBN type"""
from pydantic import WithJsonSchema
from typing_extensions import Annotated

# Pattern accepting ISBN-10 and ISBN-13 strings, with an optional "ISBN",
# "ISBN-10" or "ISBN-13" prefix and optional hyphen/space separators.
# The match is syntactic only: not every 10- or 13-digit number is a valid
# ISBN (see https://en.wikipedia.org/wiki/ISBN).
# Source: O'Reilly's Regular Expression Cookbook,
# https://www.oreilly.com/library/view/regular-expressions-cookbook/9781449327453/ch04s13.html
# TODO: Can this be represented by a grammar or do we need semantic checks?
ISBN_REGEX = r"(?:ISBN(?:-1[03])?:? )?(?=[0-9X]{10}$|(?=(?:[0-9]+[- ]){3})[- 0-9X]{13}$|97[89][0-9]{10}$|(?=(?:[0-9]+[- ]){4})[- 0-9]{17}$)(?:97[89][- ]?)?[0-9]{1,5}[- ]?[0-9]+[- ]?[0-9]+[- ]?[0-9X]"

# A `str` annotated with a JSON schema that carries the pattern above.
ISBN = Annotated[
    str,
    WithJsonSchema({"type": "string", "pattern": ISBN_REGEX}),
]
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ test = [
"vllm",
"torch",
"transformers",
"pycountry",
"pyairports",
]
serve = [
"vllm>=0.3.0",
Expand Down Expand Up @@ -126,6 +128,8 @@ module = [
"vllm.*",
"uvicorn.*",
"fastapi.*",
"pycountry.*",
"pyairports.*",
]
ignore_missing_imports = true

Expand Down
44 changes: 43 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,17 @@
(types.ZipCode, "12", False),
(types.ZipCode, "12345", True),
(types.ZipCode, "12345-1234", True),
(types.ISBN, "ISBN 0-1-2-3-4-5", False),
(types.ISBN, "ISBN 978-0-596-52068-7", True),
# (types.ISBN, "ISBN 978-0-596-52068-1", True), wrong check digit
(types.ISBN, "ISBN-13: 978-0-596-52068-7", True),
(types.ISBN, "978 0 596 52068 7", True),
(types.ISBN, "9780596520687", True),
(types.ISBN, "ISBN-10: 0-596-52068-9", True),
(types.ISBN, "0-596-52068-9", True),
],
)
def test_phone_number(custom_type, test_string, should_match):
def test_type_regex(custom_type, test_string, should_match):
class Model(BaseModel):
attr: custom_type

Expand All @@ -32,3 +40,37 @@ class Model(BaseModel):
assert isinstance(format_fn(1), str)
does_match = re.match(regex_str, test_string) is not None
assert does_match is should_match


@pytest.mark.parametrize(
    "custom_type,test_string,should_match",
    [
        (types.airports.IATA, "CDG", True),
        (types.airports.IATA, "XXX", False),
        (types.countries.Alpha2, "FR", True),
        (types.countries.Alpha2, "XX", False),
        (types.countries.Alpha3, "UKR", True),
        (types.countries.Alpha3, "XXX", False),
        (types.countries.Numeric, "004", True),
        (types.countries.Numeric, "900", False),
        (types.countries.Name, "Ukraine", True),
        (types.countries.Name, "Wonderland", False),
        (types.countries.Flag, "🇿🇼", True),
        (types.countries.Flag, "🤗", False),
    ],
)
def test_type_enum(custom_type, test_string, should_match):
    """Check enum-backed custom types through both the schema and the regex.

    For each case, `test_string` must be listed (or not) among the enum values
    of the Pydantic JSON schema, and must match (or not) the regex returned by
    `python_types_to_regex`, according to `should_match`.
    """

    class Model(BaseModel):
        attr: custom_type

    # The enum definition is stored in the schema under the enum class's name.
    enum_values = Model.model_json_schema()["$defs"][custom_type.__name__]["enum"]
    assert isinstance(enum_values, list)
    assert (test_string in enum_values) is should_match

    regex_str, format_fn = python_types_to_regex(custom_type)
    # The formatter is expected to hand back a string.
    assert isinstance(format_fn(1), str)
    assert (re.match(regex_str, test_string) is not None) is should_match
Loading