Skip to content

Commit

Permalink
Merge branch 'main' into issue-838
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 authored May 17, 2024
2 parents 5babeee + 499d19d commit ea95069
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 19 deletions.
8 changes: 1 addition & 7 deletions outlines/fsm/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,7 @@ def to_regex(
to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"]
]

xor_patterns = []
# json schema validation ensured there is no overlapping schemas in oneOf
for subregex in subregexes:
other_subregexes = filter(lambda r: r != subregex, subregexes)
other_subregexes_str = "|".join([f"{s}" for s in other_subregexes])
negative_lookahead = f"(?!.*({other_subregexes_str}))"
xor_patterns.append(f"({subregex}){negative_lookahead}")
xor_patterns = [f"(?:{subregex})" for subregex in subregexes]

return rf"({'|'.join(xor_patterns)})"

Expand Down
28 changes: 19 additions & 9 deletions outlines/generate/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,15 +340,6 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
return
generated_token_ids = sequence.token_ids[:, -num_generated:]
generated_sequences = self.tokenizer.decode(generated_token_ids)
next_tokens = [
token[len(sequence) :] if not stop else ""
for token, sequence, stop in zip(
generated_sequences,
previously_generated_sequences,
is_stop_at_reached,
)
]
previously_generated_sequences = generated_sequences
if stop_sequences:
is_stop_at_reached = [
stop
Expand All @@ -360,6 +351,25 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
)
]

generated_sequences = [
self.format_sequence(
self.strip_stop_sequences(sequence, stop_sequences)
)
if stop
else sequence
for sequence, stop in zip(
generated_sequences, is_stop_at_reached
)
]
next_tokens = [
token[len(sequence) :]
for token, sequence, stop in zip(
generated_sequences,
previously_generated_sequences,
is_stop_at_reached,
)
]
previously_generated_sequences = generated_sequences
# We reshape the output to (batch_size, sample_size)
output: List[List[str]] = list()
for i in range(batch_size):
Expand Down
31 changes: 28 additions & 3 deletions tests/fsm/test_json_schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import re
from typing import List
from typing import List, Literal, Union

import interegular
import pytest
from pydantic import BaseModel, constr
from pydantic import BaseModel, Field, constr

from outlines.fsm.json_schema import (
BOOLEAN,
Expand Down Expand Up @@ -323,7 +324,7 @@ def test_match_number(pattern, does_match):
"title": "Foo",
"oneOf": [{"type": "string"}, {"type": "number"}, {"type": "boolean"}],
},
rf"(({STRING})(?!.*({NUMBER}|{BOOLEAN}))|({NUMBER})(?!.*({STRING}|{BOOLEAN}))|({BOOLEAN})(?!.*({STRING}|{NUMBER})))",
rf'((?:"{STRING_INNER}*")|(?:{NUMBER})|(?:{BOOLEAN}))',
[
("12.3", True),
("true", True),
Expand Down Expand Up @@ -752,3 +753,27 @@ class MockModel(BaseModel):
assert match_default_ws is None
assert re.fullmatch(pattern, mock_result_maybe_ws)
def test_one_of_doesnt_produce_illegal_lookaround():
    """Regression test for https://github.com/outlines-dev/outlines/issues/823.

    A discriminated union (``oneOf`` in JSON schema) must compile to a regex
    that ``interegular`` can convert to an FSM, i.e. one free of negative
    lookaheads, which ``interegular.Pattern.to_fsm()`` cannot handle.
    """

    class Cat(BaseModel):
        pet_type: Literal["cat"]
        meows: int

    class Dog(BaseModel):
        pet_type: Literal["dog"]
        barks: float

    class Model(BaseModel):
        # discriminated union -> emitted as `oneOf` in the JSON schema
        pet: Union[Cat, Dog] = Field(..., discriminator="pet_type")
        n: int

    json_schema = Model.schema_json()
    pattern = build_regex_from_schema(json_schema, whitespace_pattern=None)

    # Raises if the pattern uses lookarounds incompatible with
    # interegular.Pattern.to_fsm()
    interegular.parse_pattern(pattern).to_fsm()

0 comments on commit ea95069

Please sign in to comment.