diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py
index d96597d4c..2c53fd240 100644
--- a/outlines/fsm/json_schema.py
+++ b/outlines/fsm/json_schema.py
@@ -195,13 +195,10 @@ def to_regex(
             to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"]
         ]
 
-        xor_patterns = []
-        # json schema validation ensured there is no overlapping schemas in oneOf
-        for subregex in subregexes:
-            other_subregexes = filter(lambda r: r != subregex, subregexes)
-            other_subregexes_str = "|".join([f"{s}" for s in other_subregexes])
-            negative_lookahead = f"(?!.*({other_subregexes_str}))"
-            xor_patterns.append(f"({subregex}){negative_lookahead}")
+        # Plain alternation of non-capturing groups. The former per-branch
+        # negative lookahead is dropped: lookarounds are incompatible with
+        # interegular's Pattern.to_fsm() (see issue #823).
+        xor_patterns = [f"(?:{subregex})" for subregex in subregexes]
 
         return rf"({'|'.join(xor_patterns)})"
 
diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py
index 5b3ad9e39..b992f7aa5 100644
--- a/tests/fsm/test_json_schema.py
+++ b/tests/fsm/test_json_schema.py
@@ -1,9 +1,10 @@
 import json
 import re
-from typing import List
+from typing import List, Literal, Union
 
+import interegular
 import pytest
-from pydantic import BaseModel, constr
+from pydantic import BaseModel, Field, constr
 
 from outlines.fsm.json_schema import (
     BOOLEAN,
@@ -321,7 +322,7 @@ def test_match_number(pattern, does_match):
                 "title": "Foo",
                 "oneOf": [{"type": "string"}, {"type": "number"}, {"type": "boolean"}],
             },
-            rf"(({STRING})(?!.*({NUMBER}|{BOOLEAN}))|({NUMBER})(?!.*({STRING}|{BOOLEAN}))|({BOOLEAN})(?!.*({STRING}|{NUMBER})))",
+            rf'((?:"{STRING_INNER}*")|(?:{NUMBER})|(?:{BOOLEAN}))',
             [
                 ("12.3", True),
                 ("true", True),
@@ -750,3 +751,25 @@ class MockModel(BaseModel):
     assert match_default_ws is None
 
     assert re.fullmatch(pattern, mock_result_maybe_ws)
+
+
+def test_one_of_doesnt_produce_illegal_lookaround():
+    """Reproduces failure in https://github.com/outlines-dev/outlines/issues/823"""
+
+    class Cat(BaseModel):
+        pet_type: Literal["cat"]
+        meows: int
+
+    class Dog(BaseModel):
+        pet_type: Literal["dog"]
+        barks: float
+
+    class Model(BaseModel):
+        pet: Union[Cat, Dog] = Field(..., discriminator="pet_type")
+        n: int
+
+    json_schema = Model.schema_json()
+    pattern = build_regex_from_schema(json_schema, whitespace_pattern=None)
+
+    # check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm()
+    interegular.parse_pattern(pattern).to_fsm()