Skip to content

Commit

Permalink
Add country types
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed May 2, 2024
1 parent 4b8cc6d commit 8c17bee
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 5 deletions.
8 changes: 5 additions & 3 deletions outlines/fsm/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ def custom_format_fn(sequence: str) -> Any:

if isinstance(python_type, EnumMeta):
values = python_type.__members__.keys()
regex_str = "(" + "|".join(values) + ")"
format_fn = lambda x: str(x)
enum_regex_str: str = "(" + "|".join(values) + ")"

return regex_str, format_fn
def enum_format_fn(sequence: str) -> str:
return str(sequence)

return enum_regex_str, enum_format_fn

if python_type == float:

Expand Down
2 changes: 1 addition & 1 deletion outlines/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import airports, countries
from .isbn import ISBN
from .phone_numbers import PhoneNumber
from .zip_codes import ZipCode
from . import airports
24 changes: 24 additions & 0 deletions outlines/types/countries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Generate valid country codes and names."""
from enum import Enum

try:
    import pycountry
except ImportError as exc:
    # Re-raise with an actionable installation hint while keeping the
    # original failure attached as the explicit cause.
    raise ImportError(
        'The `countries` module requires "pycountry" to be installed. You can install it with "pip install pycountry"'
    ) from exc

# Every enumeration below is built with the functional ``Enum`` API from a
# list of ``(member_name, member_value)`` pairs, one pair per record in
# ``pycountry.countries``. In each pair the name and the value are the same
# string.

# Built from the ``alpha_2`` attribute of each country record.
ALPHA_2_CODE = [(country.alpha_2, country.alpha_2) for country in pycountry.countries]
Alpha2 = Enum("Alpha_2", ALPHA_2_CODE)  # type:ignore

# Built from the ``alpha_3`` attribute of each country record.
# The enum is named "Alpha_3" (it was previously created as "Alpha_2", a
# copy-paste slip) so that ``Alpha3.__name__`` no longer collides with
# ``Alpha2.__name__``.
ALPHA_3_CODE = [(country.alpha_3, country.alpha_3) for country in pycountry.countries]
Alpha3 = Enum("Alpha_3", ALPHA_3_CODE)  # type:ignore

# Built from the ``numeric`` attribute of each country record.
NUMERIC_CODE = [(country.numeric, country.numeric) for country in pycountry.countries]
Numeric = Enum("Numeric_code", NUMERIC_CODE)  # type:ignore

# Built from the ``name`` attribute of each country record.
NAME = [(country.name, country.name) for country in pycountry.countries]
Name = Enum("Name", NAME)  # type:ignore

# Built from the ``flag`` attribute of each country record.
FLAG = [(country.flag, country.flag) for country in pycountry.countries]
Flag = Enum("Flag", FLAG)  # type:ignore
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ test = [
"vllm",
"torch",
"transformers",
"pycountry",
"pyairports",
]
serve = [
"vllm>=0.3.0",
Expand Down Expand Up @@ -126,6 +128,7 @@ module = [
"vllm.*",
"uvicorn.*",
"fastapi.*",
"pycountry.*",
"pyairports.*",
]
ignore_missing_imports = true
Expand Down
11 changes: 10 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,19 @@ class Model(BaseModel):
[
(types.airports.IATA, "CDG", True),
(types.airports.IATA, "XXX", False),
(types.countries.Alpha2, "FR", True),
(types.countries.Alpha2, "XX", False),
(types.countries.Alpha3, "UKR", True),
(types.countries.Alpha3, "XXX", False),
(types.countries.Numeric, "004", True),
(types.countries.Numeric, "900", False),
(types.countries.Name, "Ukraine", True),
(types.countries.Name, "Wonderland", False),
(types.countries.Flag, "🇿🇼", True),
(types.countries.Flag, "🤗", False),
],
)
def test_type_enum(custom_type, test_string, should_match):

type_name = custom_type.__name__

class Model(BaseModel):
Expand Down

0 comments on commit 8c17bee

Please sign in to comment.