Skip to content

Commit

Permalink
Add integration test for JSON schema
Browse files (browse the repository at this point in the history)
  • Loading branch information
rlouf committed Dec 19, 2023
1 parent eb748a1 commit dce9265
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
6 changes: 6 additions & 0 deletions outlines/generate/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,5 +248,11 @@ def json(
regex_str = build_regex_from_object(schema)
generator = regex(model, regex_str, max_tokens, sampler)
generator.format_sequence = lambda x: pyjson.loads(x)
else:
raise ValueError(
f"Cannot parse schema {schema_object}. The schema must be either "
+ "a Pydantic object, a function or a string that contains the JSON "
+ "Schema specification"
)

return generator
4 changes: 2 additions & 2 deletions tests/generate/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

def test_sequence_generator_class():
class MockFSM:
def next_state(self, state, next_token_ids):
def next_state(self, state, next_token_ids, _):
return 4

def allowed_token_ids(self, _):
def allowed_token_ids(self, *_):
return [4]

def is_final_state(self, _, idx=0):
Expand Down
24 changes: 24 additions & 0 deletions tests/generate/test_integration_transfomers.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,30 @@ class Spam(BaseModel):
assert len(result.spam) <= 10


def test_transformers_json_schema():
    """Integration test: `generate.json` driven by a JSON Schema string.

    Loads the `hf-internal-testing/tiny-random-GPTJForCausalLM` checkpoint on
    CPU, builds a JSON generator from the schema string below, and asserts
    that the generated result is a dict whose `foo` entry is an int and whose
    `bar` entry is a str.
    """
    # The schema is handed over as a plain string, not as a Pydantic model or
    # a function.
    schema = """{
"title": "spam",
"type": "object",
"properties": {
"foo" : {"type": "integer"},
"bar": {"type": "string", "maxLength": 4}
}
}
"""

    checkpoint = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    model = models.transformers(checkpoint, device="cpu")

    # Fixed seed: make sure that `bar` is not an int
    rng = torch.Generator()
    rng.manual_seed(0)

    json_generator = generate.json(model, schema, max_tokens=500)
    result = json_generator("Output some JSON ", rng=rng)

    assert isinstance(result, dict)
    assert isinstance(result["foo"], int)
    assert isinstance(result["bar"], str)


def test_transformers_json_batch():
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
model = models.transformers(model_name, device="cpu")
Expand Down

0 comments on commit dce9265

Please sign in to comment.