Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Localize types #868

Merged
merged 1 commit into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions docs/reference/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@

Outlines provides custom Pydantic types so you can focus on your use case rather than on writing regular expressions:


| Category | Type | Import | Description |
|:--------:|:----:|:-------|:------------|
| Zip code | US | `outlines.types.ZipCode` | Generate US Zip(+4) codes |
| Phone number | US | `outlines.types.PhoneNumber` | Generate valid US phone numbers |
| ISBN | 10 & 13 | `outlines.types.ISBN` | There is no guarantee that the [check digit][wiki-isbn] will be correct |
| Airport | IATA | `outlines.types.airports.IATA` | Valid [airport IATA codes][wiki-airport-iata] |
| Country | alpha-2 code | `outlines.types.countries.Alpha2` | Valid [country alpha-2 codes][wiki-country-alpha-2] |
Expand All @@ -15,18 +12,37 @@ Outlines provides custom Pydantic types so you can focus on your use case rather
| | name | `outlines.types.countries.Name` | Valid country names |
| | flag | `outlines.types.countries.Flag` | Valid flag emojis |

Some types require localization. We currently only support US types, but please don't hesitate to create localized versions of the different types and open a Pull Request. Localized types are specified using `types.locale` in the following way:

```python
from outlines import types

types.locale("us").ZipCode
types.locale("us").PhoneNumber
```

Here are the localized types that are currently available:

| Category | Locale | Import | Description |
|:--------:|:----:|:-------|:------------|
| Zip code | US | `ZipCode` | Generate US Zip(+4) codes |
| Phone number | US | `PhoneNumber` | Generate valid US phone numbers |


You can use these types in Pydantic schemas for JSON-structured generation:

```python
from pydantic import BaseModel

from outlines import models, generate, types

# Specify the locale for types
locale = types.locale("us")

class Client(BaseModel):
name: str
phone_number: types.PhoneNumber
zip_code: types.ZipCode
phone_number: locale.PhoneNumber
zip_code: locale.ZipCode


model = models.transformers("mistralai/Mistral-7B-v0.1")
Expand All @@ -47,7 +63,7 @@ from outlines import models, generate, types


model = models.transformers("mistralai/Mistral-7B-v0.1")
generator = generate.format(model, types.PhoneNumber)
generator = generate.format(model, types.locale("us").PhoneNumber)
result = generator(
"Return a US Phone number: "
)
Expand Down
3 changes: 1 addition & 2 deletions outlines/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from . import airports, countries
from .isbn import ISBN
from .phone_numbers import PhoneNumber
from .zip_codes import ZipCode
from .locales import locale
21 changes: 21 additions & 0 deletions outlines/types/locales.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from dataclasses import dataclass

from outlines.types.phone_numbers import USPhoneNumber
from outlines.types.zip_codes import USZipCode


@dataclass
class US:
    """Namespace grouping the custom types localized for the "us" locale.

    The attributes are class-level aliases, so they are read directly off
    the class (``US.ZipCode``) without creating an instance.
    """

    # Alias of the US phone number type from ``outlines.types.phone_numbers``.
    PhoneNumber = USPhoneNumber
    # Alias of the US zip code type from ``outlines.types.zip_codes``.
    ZipCode = USZipCode


def locale(locale_str: str):
    """Return the namespace of localized types registered for a locale.

    Parameters
    ----------
    locale_str
        Identifier of the locale; ``"us"`` is the only key currently
        registered.

    Returns
    -------
    The class whose attributes (e.g. ``ZipCode``, ``PhoneNumber``) are the
    types localized for ``locale_str``.

    Raises
    ------
    NotImplementedError
        If no types are registered for ``locale_str``.

    """
    locales = {"us": US}

    # Single lookup; translate the missing key into the documented error
    # and drop the internal KeyError from the traceback.
    try:
        return locales[locale_str]
    except KeyError:
        raise NotImplementedError(
            f"The locale {locale_str} is not supported yet. Please don't hesitate to create custom types for your locale and open a Pull Request."
        ) from None
2 changes: 1 addition & 1 deletion outlines/types/phone_numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
US_PHONE_NUMBER = r"(\([0-9]{3}\) |[0-9]{3}-)[0-9]{3}-[0-9]{4}"


PhoneNumber = Annotated[
USPhoneNumber = Annotated[
str,
WithJsonSchema({"type": "string", "pattern": US_PHONE_NUMBER}),
]
2 changes: 1 addition & 1 deletion outlines/types/zip_codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
US_ZIP_CODE = r"\d{5}(?:-\d{4})?"


ZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})]
USZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})]
33 changes: 27 additions & 6 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
@pytest.mark.parametrize(
"custom_type,test_string,should_match",
[
(types.PhoneNumber, "12", False),
(types.PhoneNumber, "(123) 123-1234", True),
(types.PhoneNumber, "123-123-1234", True),
(types.ZipCode, "12", False),
(types.ZipCode, "12345", True),
(types.ZipCode, "12345-1234", True),
(types.phone_numbers.USPhoneNumber, "12", False),
(types.phone_numbers.USPhoneNumber, "(123) 123-1234", True),
(types.phone_numbers.USPhoneNumber, "123-123-1234", True),
(types.zip_codes.USZipCode, "12", False),
(types.zip_codes.USZipCode, "12345", True),
(types.zip_codes.USZipCode, "12345-1234", True),
(types.ISBN, "ISBN 0-1-2-3-4-5", False),
(types.ISBN, "ISBN 978-0-596-52068-7", True),
# (types.ISBN, "ISBN 978-0-596-52068-1", True), wrong check digit
Expand All @@ -42,6 +42,27 @@ class Model(BaseModel):
assert does_match is should_match


def test_locale_not_implemented():
    """Asking for a locale that has no registered types must raise."""
    unsupported_locale = "fr"
    with pytest.raises(NotImplementedError):
        types.locale(unsupported_locale)


@pytest.mark.parametrize(
    "locale_str,base_types,locale_types",
    [
        (
            "us",
            ["ZipCode", "PhoneNumber"],
            [types.zip_codes.USZipCode, types.phone_numbers.USPhoneNumber],
        )
    ],
)
def test_locale(locale_str, base_types, locale_types):
    """Check each attribute of a locale namespace maps to the expected type.

    ``base_types`` holds the attribute names exposed by the locale and
    ``locale_types`` the concrete types they should resolve to, pairwise.
    """
    # `zip` stops at the shorter list, so a mismatched parametrization would
    # otherwise skip checks silently.
    assert len(base_types) == len(locale_types)

    # The namespace does not depend on the loop variable; resolve it once.
    localized = types.locale(locale_str)
    for base_type, locale_type in zip(base_types, locale_types):
        # Named `resolved_type` rather than `type` to avoid shadowing the builtin.
        resolved_type = getattr(localized, base_type)
        assert resolved_type == locale_type


@pytest.mark.parametrize(
"custom_type,test_string,should_match",
[
Expand Down
Loading