From 143ed13c80ed8f889e79fbbd7563062efdf0e13b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 6 Nov 2023 18:34:31 +0100 Subject: [PATCH 1/6] Improve support for JSON schema We use the `referencing` library to dereference fields in the JSON Schema, which simplifies the codebase a lot and prevents reference errors. We also support combination of `minLength` and `maxLength` as well as the `pattern` keyword. --- environment.yml | 1 + outlines/text/json_schema.py | 337 ++++++--------------- pyproject.toml | 2 + tests/text/test_json_schema.py | 527 +++++++++------------------------ 4 files changed, 221 insertions(+), 646 deletions(-) diff --git a/environment.yml b/environment.yml index b0b3b62cf..14629af3e 100644 --- a/environment.yml +++ b/environment.yml @@ -16,6 +16,7 @@ dependencies: - scipy - pytest - pre-commit + - referencing - transformers - pip - pip: diff --git a/outlines/text/json_schema.py b/outlines/text/json_schema.py index b9d6a84c8..fe1578071 100644 --- a/outlines/text/json_schema.py +++ b/outlines/text/json_schema.py @@ -1,7 +1,9 @@ -import itertools import json import re -from typing import Callable, Dict + +from referencing import Registry, Resource +from referencing._core import Resolver +from referencing.jsonschema import DRAFT202012 STRING_INNER = r'(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)' STRING = f'"{STRING_INNER}*"' @@ -33,276 +35,109 @@ def build_regex_from_schema(schema: str): A string that contains a regular expression that matches any JSON object that follows the schema. 
- """ - schedule = build_schedule_from_schema(schema) - - regex = "" - for step in schedule: - regex += match_step_to_regex(step) - - return regex - - -def _ref_resolver(schema: Dict) -> Callable[[str], Dict]: - cache: Dict[str, Dict] = dict() - - if "$id" in schema: - cache[schema["$id"]] = schema - - if "$defs" in schema: - for definition, annotation in schema["$defs"].items(): - cache[f"#/$defs/{definition}"] = annotation - - if "$id" in annotation: - cache[annotation["$id"]] = annotation - - def resolver(reference: str) -> Dict: - """Resolve a $ref reference in the context of the top-level schema.""" - - if reference in cache: - return cache[reference] - - path = reference.split("/") - - # Navigate through the top-level schema based on the path - subschema = schema - - if path[0] != "#": - raise ValueError(f"Unable to resolve reference: {reference}") - - for step in path[1:]: - if step in subschema: - subschema = subschema[step] - else: - raise ValueError(f"Unable to resolve reference: {reference}") - - cache[reference] = subschema - return subschema - - return resolver - - -def build_schedule_from_schema(schema: str): - """Turn a JSON schema into a regex that matches any JSON object that follows - this schema. - - JSON Schema is a declarative language that allows to annotate JSON documents - with types and descriptions. These schemas can be generated from any Python - datastructure that has type annotation: namedtuples, dataclasses, Pydantic - models. And by ensuring that the generation respects the schema we ensure - that the output can be parsed into these objects. - This function parses the provided schema and builds a generation schedule which - mixes deterministic generation (fixed strings), and sampling with constraints. - - Parameters - ---------- - schema - A string that represents a JSON Schema. - - Returns - ------- - A generation schedule. 
A list of strings that represent the JSON - schema's structure and regular expression that define the structure of - the fields. - - References - ---------- - .. [0] JSON Schema. https://json-schema.org/ - """ schema = json.loads(schema) - schema = expand_json_schema(schema, resolver=_ref_resolver(schema)) - schedule = build_schedule_from_instance(schema) - - # Concatenate adjacent strings - reduced_schedule = [ - x - for cls, grp in itertools.groupby(schedule, type) - for x in (("".join(grp),) if cls is str else grp) - ] - - return reduced_schedule - - -def expand_item_json_schema(expanded_property: Dict, resolver: Callable[[str], Dict]): - """Recursively expand "$ref"s in "item"s.""" - if "items" not in expanded_property.keys(): - return - elif "$ref" in expanded_property["items"]: - expanded_property["items"] = expand_json_schema( - resolver(expanded_property["items"]["$ref"]), resolver - ) - else: - expand_item_json_schema(expanded_property["items"], resolver) - - -def expand_json_schema( - raw_schema: Dict, - resolver: Callable[[str], Dict], -): - """Replace references by their value in the JSON Schema. - - This recursively follows the references to other schemas in case - of nested models. Other schemas that may exist at a higher level - within the overall schema may be referenced via the `$ref` keyword - according to the JSON Schema specification. - - - Parameters - --------- - raw_schema - The raw JSON schema as a Python dictionary, possibly with definitions - and references. - resolver - A function that takes a reference and returns the corresponding schema - or subschema from the currently scoped top-level schema. - - Returns - ------- - A dictionary that represents the flattened equivalent of the input - JSON schema. 
- - """ - expanded_properties = {} - - if "properties" in raw_schema: - if "$id" in raw_schema: - # see https://json-schema.org/understanding-json-schema/structuring#bundling - resolver = _ref_resolver(raw_schema) - - for name, value in raw_schema["properties"].items(): - if "$ref" in value: # if item is a single element - expanded_properties[name] = expand_json_schema( - resolver(value["$ref"]), resolver - ) - elif "type" in value and value["type"] == "array": # if item is a list - expanded_properties[name] = value - - if "$ref" in value["items"] or ( - "type" in value["items"] and value["items"]["type"] == "array" - ): - expand_item_json_schema(expanded_properties[name], resolver) - else: - expanded_properties[name]["items"] = value["items"] - - else: - expanded_properties[name] = value - - return { - **({"title": raw_schema["title"]} if "title" in raw_schema else {}), - "type": raw_schema["type"], - "properties": expanded_properties, - } - - else: - return raw_schema + # Build reference resolver + schema = Resource(contents=schema, specification=DRAFT202012) + uri = schema.id() if schema.id() is not None else "" + registry = Registry().with_resource(uri=uri, resource=schema) + resolver = registry.resolver() + content = schema.contents + regex = to_regex(resolver, content) + return regex -def build_schedule_from_instance(instance: Dict): - """Build a generation schedule from a instance. - This recursively follows the references to other instances. +def to_regex(resolver: Resolver, instance: dict): + whitespace = r"[\n ]*" - Parameters - ---------- - instance - An instance, can be the JSON schema itself. - indent - The current indentation level - - Returns - ------- - A generation schedule for the instance, a list of strings that represent - the structure of the JSON schema and dictionaries that contain the - instance definition. 
- - """ - schedule = [] if "properties" in instance: - schedule.append(r"\{") - schedule += build_schedule_from_instance(instance["properties"]) - schedule.append(r"\}") - else: - for i, (name, annotation) in enumerate(instance.items()): - whitespace = r"[\n ]*" - schedule.append(f'{whitespace}"{name}"{whitespace}:{whitespace}') + regex = "" + regex += r"\{" + for i, (name, value) in enumerate(instance["properties"].items()): + regex += f'{whitespace}"{name}"{whitespace}:{whitespace}' + regex += to_regex(resolver, value) - if "anyOf" in annotation: - schedule.append(annotation) - elif annotation["type"] == "object": - schedule += build_schedule_from_instance(annotation) - else: - schedule.append(annotation) + # No comma after the last key-value pair in JSON + if i < len(instance["properties"]) - 1: + regex += f"{whitespace}," - # We cannot add commas after the last key-value pair in JSON - if i == len(instance) - 1: - schedule.append(whitespace) - else: - schedule.append(f"{whitespace},") + regex += f"{whitespace}" + r"\}" - return schedule + return regex + elif "oneOf" in instance: + print(instance) -def match_step_to_regex(step): - """Translate an element of a JSON schema to a regex that defines its content. + elif "allOf" in instance: + print(instance) - Parameters - ---------- - step: - A string that represents the schema's structure, or a dictionary - that represents a field in the schema. + elif "anyOf" in instance: + subregexes = [to_regex(resolver, t) for t in instance["anyOf"]] + return rf"({'|'.join(subregexes)})" - Returns - ------- - A string that represents a regular expression that defines the value of the - schedule's step. 
- - """ - if isinstance(step, str): - return step - - if isinstance(step, dict): - keys = set(step.keys()) - - if all(key in keys for key in ("enum", "type")) and step["type"] == "string": - choices = [f'"{re.escape(choice)}"' for choice in step["enum"]] + elif "enum" in instance: + if instance["type"] == "string": + choices = [f'"{re.escape(choice)}"' for choice in instance["enum"]] return f"({'|'.join(choices)})" - - elif "enum" in keys: - choices = [re.escape(str(choice)) for choice in step["enum"]] + else: + choices = [re.escape(str(choice)) for choice in instance["enum"]] return f"({'|'.join(choices)})" - elif all(key in keys for key in ("type", "items")) and step["type"] == "array": - item_regexes = match_step_to_regex(step["items"]) - return rf"\[({item_regexes})(,({item_regexes}))*\]" + elif "$ref" in instance: + path = f"{instance['$ref']}" + instance = resolver.lookup(path).contents + return to_regex(resolver, instance) + + elif "type" in instance: + type = instance["type"] + + if type == "string": + if "maxLength" in instance or "minLength" in instance: + max_length = instance.get("maxLength", "") + min_length = instance.get("minLength", "") + try: + if int(max_length) < int(min_length): + raise ValueError( + "maxLength must be greater than or equal to minLength" + ) + except ValueError: + pass + return f'"{STRING_INNER}{{{min_length},{max_length}}}"' + elif "pattern" in instance: + pattern = instance["pattern"] + if pattern[0] == "^" and pattern[-1] == "$": + return rf'(^"{pattern[1:-1]}"$)' + else: + return rf'("{pattern}")' + else: + return type_to_regex["string"] + + elif type == "number": + return type_to_regex["number"] - elif "type" in keys and step["type"] == "object": - steps = build_schedule_from_schema(json.dumps(step)) - regex_str = "" - for step in steps: - regex_str += match_step_to_regex(step) - return regex_str + elif type == "integer": + return type_to_regex["integer"] - elif ( - all(key in keys for key in ("type", "maxLength")) - and 
step["type"] == "string" - ): - max_length = step["maxLength"] - return f'"{STRING_INNER}{{,{max_length}}}"' + elif type == "array": + items_regex = to_regex(resolver, instance["items"]) + return rf"\[({items_regex})(,({items_regex}))*\]" - elif ( - all(key in keys for key in ("type", "minLength")) - and step["type"] == "string" - ): - min_length = step["minLength"] - return f'"{STRING_INNER}{{{min_length},}}"' + elif type == "boolean": + return type_to_regex["boolean"] - elif "type" in keys: - return type_to_regex[step["type"]] + elif type == "null": + return type_to_regex["null"] - elif "anyOf" in keys: - regexes = [match_step_to_regex(choice) for choice in step["anyOf"]] - return rf"({'|'.join(regexes)})" + # elif isinstance(type, list): + # if "object" in type: + # expanded = to_regex(resolver, instance) + # return "" + # return "" - raise NotImplementedError + raise NotImplementedError( + f"""Could not translate the instance {instance} to a + regular expression. Make sure it is valid to the JSON Schema specification. 
If + it is, please open an issue on the Outlines repository""" + ) diff --git a/pyproject.toml b/pyproject.toml index c769baf17..6ab6cf61a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "torch", "numba", "joblib", + "referencing", ] dynamic = ["version"] @@ -92,6 +93,7 @@ module = [ "PIL.Image", "pydantic", "pytest", + "referencing.*", "scipy.*", "tenacity.*", "tiktoken.*", diff --git a/tests/text/test_json_schema.py b/tests/text/test_json_schema.py index 239effb61..0069ca841 100644 --- a/tests/text/test_json_schema.py +++ b/tests/text/test_json_schema.py @@ -1,7 +1,5 @@ import json import re -from enum import Enum -from typing import List, Optional, Union import pytest from pydantic import BaseModel, constr @@ -13,12 +11,12 @@ NUMBER, STRING, STRING_INNER, - build_schedule_from_schema, - match_step_to_regex, + build_regex_from_schema, + to_regex, ) -def test_pydantic_basic(): +def test_from_pydantic(): class User(BaseModel): user_id: int name: str @@ -28,382 +26,31 @@ class User(BaseModel): is_true: bool schema = json.dumps(User.model_json_schema()) - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*', - {"title": "User Id", "type": "integer"}, - '[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*', - {"title": "Name", "type": "string"}, - '[\\n ]*,[\\n ]*"maxlength_name"[\\n ]*:[\\n ]*', - {"title": "Maxlength Name", "type": "string", "maxLength": 10}, - '[\\n ]*,[\\n ]*"minlength_name"[\\n ]*:[\\n ]*', - {"title": "Minlength Name", "type": "string", "minLength": 10}, - '[\\n ]*,[\\n ]*"value"[\\n ]*:[\\n ]*', - {"title": "Value", "type": "number"}, - '[\\n ]*,[\\n ]*"is_true"[\\n ]*:[\\n ]*', - {"title": "Is True", "type": "boolean"}, - "[\\n ]*\\}", - ] - - -def test_pydantic_optional(): - class Foo(BaseModel): - bar: Optional[str] - - schema = json.dumps(Foo.model_json_schema()) - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"bar"[\\n ]*:[\\n 
]*', - {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Bar"}, - "[\\n ]*\\}", - ] - - -def test_pydantic_array(): - class User(BaseModel): - user_id: int - value: List[float] - - schema = json.dumps(User.model_json_schema()) - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*', - {"title": "User Id", "type": "integer"}, - '[\\n ]*,[\\n ]*"value"[\\n ]*:[\\n ]*', - {"title": "Value", "type": "array", "items": {"type": "number"}}, - "[\\n ]*\\}", - ] - - -def test_pydantic_enum(): - class Name(str, Enum): - john = "John" - marc = "Marc" - michel = "Michel" - - class User(BaseModel): - user_id: int - name: Name - - schema = json.dumps(User.model_json_schema()) - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*', - {"title": "User Id", "type": "integer"}, - '[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*', - { - "title": "Name", - "enum": ["John", "Marc", "Michel"], - "type": "string", - }, - "[\\n ]*\\}", - ] - - -def test_pydantic_nested(): - """Arbitrarily nested schema.""" - - class Fizz(BaseModel): - buzz: str - - class Foo(BaseModel): - count: int - size: Fizz - - class Bar(BaseModel): - apple: str - banana: str - - class Spam(BaseModel): - foo: Foo - bars: Bar - - # We need to a recursive function to parse nested schemas - schema = json.dumps(Spam.model_json_schema()) - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"foo"[\\n ]*:[\\n ]*\\{[\\n ]*"count"[\\n ]*:[\\n ]*', - {"title": "Count", "type": "integer"}, - '[\\n ]*,[\\n ]*"size"[\\n ]*:[\\n ]*\\{[\\n ]*"buzz"[\\n ]*:[\\n ]*', - {"title": "Buzz", "type": "string"}, - '[\\n ]*\\}[\\n ]*\\}[\\n ]*,[\\n ]*"bars"[\\n ]*:[\\n ]*\\{[\\n ]*"apple"[\\n ]*:[\\n ]*', - {"title": "Apple", "type": "string"}, - '[\\n ]*,[\\n ]*"banana"[\\n ]*:[\\n ]*', - {"title": "Banana", "type": "string"}, - "[\\n ]*\\}[\\n ]*\\}", - ] - - -def test_pydantic_list_object(): - class 
Foo(BaseModel): - count: int - - class Spam(BaseModel): - foo: List[Foo] - - # We need to a recursive function to parse nested schemas - schema = json.dumps(Spam.model_json_schema()) - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"foo"[\\n ]*:[\\n ]*', - { - "items": { - "title": "Foo", - "type": "object", - "properties": {"count": {"title": "Count", "type": "integer"}}, - }, - "title": "Foo", - "type": "array", - }, - "[\\n ]*\\}", - ] - - -def test_pydantic_recursive_list_object(): - class ItemModel(BaseModel): - name: str - - class ArrayModel1(BaseModel): - item_model_lists: List[List[ItemModel]] - - class ArrayModel2(BaseModel): - nums: List[List[int]] - - class ArrayModel3(BaseModel): - array_model_lists: List[List[ArrayModel1]] - - schema = json.dumps(ArrayModel1.model_json_schema()) - schedule = build_schedule_from_schema(schema) - array_model_1_schema = { - "items": { - "items": { - "title": "ItemModel", - "type": "object", - "properties": {"name": {"title": "Name", "type": "string"}}, - }, - "type": "array", - }, - "title": "Item Model Lists", - "type": "array", - } - assert schedule == [ - '\\{[\\n ]*"item_model_lists"[\\n ]*:[\\n ]*', - array_model_1_schema, - "[\\n ]*\\}", - ] - - schema = json.dumps(ArrayModel2.model_json_schema()) - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"nums"[\\n ]*:[\\n ]*', - { - "items": {"items": {"type": "integer"}, "type": "array"}, - "title": "Nums", - "type": "array", - }, - "[\\n ]*\\}", - ] - - schema = json.dumps(ArrayModel3.model_json_schema()) - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"array_model_lists"[\\n ]*:[\\n ]*', - { - "items": { - "items": { - "title": "ArrayModel1", - "type": "object", - "properties": {"item_model_lists": array_model_1_schema}, - }, - "type": "array", - }, - "title": "Array Model Lists", - "type": "array", - }, - "[\\n ]*\\}", - ] - - -def test_pydantic_union(): - 
"""Schemas with Union types.""" - - class Spam(BaseModel): - foo: int - bar: Union[float, str] - - schema = json.dumps(Spam.model_json_schema()) - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"foo"[\\n ]*:[\\n ]*', - {"title": "Foo", "type": "integer"}, - '[\\n ]*,[\\n ]*"bar"[\\n ]*:[\\n ]*', - {"title": "Bar", "anyOf": [{"type": "number"}, {"type": "string"}]}, - "[\\n ]*\\}", - ] - - -def test_json_schema(): - schema = '{"title": "User", "type": "object", "properties": {"user_id": {"title": "User Id", "type": "integer"}, "name": {"title": "Name", "type": "string"}}, "required": ["user_id", "name"]}' - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*', - {"title": "User Id", "type": "integer"}, - '[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*', - {"title": "Name", "type": "string"}, - "[\\n ]*\\}", - ] - - -def test_json_schema_no_titles(): - schema = '{"type": "object", "properties": {"user_id": {"type": "integer"}, "name": {"type": "string"}}, "required": ["user_id", "name"]}' - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*', - {"type": "integer"}, - '[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*', - {"type": "string"}, - "[\\n ]*\\}", - ] - - -def test_json_schema_with_property_ref(): - schema = """{ - "title": "User", - "type": "object", - "properties": { - "user_id": {"title": "User Id", "type": "integer"}, - "name": {"title": "Name", "type": "string"}, - "a": {"$ref": "#/properties/name"}, - "b": {"$ref": "#/properties/name"}, - "c": {"$ref": "#/properties/name"} - }, - "required": ["user_id", "name"]} - """ - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*', - {"title": "User Id", "type": "integer"}, - '[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*', - {"title": "Name", "type": "string"}, - '[\\n ]*,[\\n ]*"a"[\\n ]*:[\\n ]*', - {"title": "Name", "type": "string"}, - 
'[\\n ]*,[\\n ]*"b"[\\n ]*:[\\n ]*', - {"title": "Name", "type": "string"}, - '[\\n ]*,[\\n ]*"c"[\\n ]*:[\\n ]*', - {"title": "Name", "type": "string"}, - "[\\n ]*\\}", - ] - - -def test_json_schema_with_def_ref(): - schema = """{ - "title": "User", - "type": "object", - "$defs": { - "name": {"title": "Name2", "type": "string"} - }, - "properties": { - "user_id": {"title": "User Id", "type": "integer"}, - "name": {"title": "Name", "type": "string"}, - "name2": {"$ref": "#/$defs/name"} - }, - "required": ["user_id", "name"]} - """ - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*', - {"title": "User Id", "type": "integer"}, - '[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*', - {"title": "Name", "type": "string"}, - '[\\n ]*,[\\n ]*"name2"[\\n ]*:[\\n ]*', - {"title": "Name2", "type": "string"}, - "[\\n ]*\\}", - ] - - -def test_json_schema_with_bundled_ref(): - schema = """{ - "$id": "https://example.com/schemas/customer", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "title": "Customer", - "type": "object", - "properties": { - "first_name": { "type": "string" }, - "last_name": { "type": "string" }, - "shipping_address": { "$ref": "/schemas/address" }, - "billing_address": { "$ref": "/schemas/address" } - }, - "required": ["first_name", "last_name", "shipping_address", "billing_address"], - "$defs": { - "address": { - "title": "Address", - "$id": "/schemas/address", - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": { - "street_address": { "type": "string" }, - "city": { "type": "string" }, - "state": { "$ref": "#/definitions/state" } - }, - "required": ["street_address", "city", "state"], - "definitions": { - "state": { "type": "object", "title": "State", "properties": { "name": { "type": "string" } }, "required": ["name"] } - } - } - } - }""" - schedule = build_schedule_from_schema(schema) - assert schedule == [ - '\\{[\\n ]*"first_name"[\\n ]*:[\\n ]*', - 
{"type": "string"}, - '[\\n ]*,[\\n ]*"last_name"[\\n ]*:[\\n ]*', - {"type": "string"}, - '[\\n ]*,[\\n ]*"shipping_address"[\\n ]*:[\\n ]*\\{[\\n ]*"street_address"[\\n ]*:[\\n ]*', - {"type": "string"}, - '[\\n ]*,[\\n ]*"city"[\\n ]*:[\\n ]*', - {"type": "string"}, - '[\\n ]*,[\\n ]*"state"[\\n ]*:[\\n ]*\\{[\\n ]*"name"[\\n ]*:[\\n ]*', - {"type": "string"}, - '[\\n ]*\\}[\\n ]*\\}[\\n ]*,[\\n ]*"billing_address"[\\n ]*:[\\n ]*\\{[\\n ]*"street_address"[\\n ]*:[\\n ]*', - {"type": "string"}, - '[\\n ]*,[\\n ]*"city"[\\n ]*:[\\n ]*', - {"type": "string"}, - '[\\n ]*,[\\n ]*"state"[\\n ]*:[\\n ]*\\{[\\n ]*"name"[\\n ]*:[\\n ]*', - {"type": "string"}, - "[\\n ]*\\}[\\n ]*\\}[\\n ]*\\}", - ] - - -class MockTokenizer: - pad_token_id = 0 - eos_token_id = 0 - - -class MockModel: - tokenizer = MockTokenizer() - device = "cpu" + schedule = build_regex_from_schema(schema) + assert isinstance(schedule, str) @pytest.mark.parametrize( "pattern,does_match", [ - ("0", True), - ("1", True), - ("-1", False), - ("01", False), - ("1.3", False), - ("t", False), + ({"integer": "0"}, True), + ({"integer": "1"}, True), + ({"integer": "-1"}, False), + ({"integer": "01"}, False), + ({"integer": "1.3"}, False), + ({"integer": "t"}, False), ], ) def test_match_integer(pattern, does_match): step = {"title": "Foo", "type": "integer"} - regex = match_step_to_regex(step) + regex = to_regex(None, step) assert regex == INTEGER - match = re.fullmatch(regex, pattern) + value = pattern["integer"] + match = re.fullmatch(regex, value) if does_match: - assert match[0] == pattern - assert match.span() == (0, len(pattern)) + assert match[0] == value + assert match.span() == (0, len(value)) else: assert match is None @@ -411,47 +58,64 @@ def test_match_integer(pattern, does_match): @pytest.mark.parametrize( "pattern,does_match", [ - ("1", True), - ("0", True), - ("01", False), - (".3", False), - ("1.3", True), - ("-1.3", True), - ("1.3e9", False), - ("1.3e+9", True), + ({"number": "1"}, True), + 
({"number": "0"}, True), + ({"number": "01"}, False), + ({"number": ".3"}, False), + ({"number": "1.3"}, True), + ({"number": "-1.3"}, True), + ({"number": "1.3e9"}, False), + ({"number": "1.3e+9"}, True), ], ) def test_match_number(pattern, does_match): step = {"title": "Foo", "type": "number"} - regex = match_step_to_regex(step) + regex = to_regex(None, step) assert regex == NUMBER - match = re.fullmatch(regex, pattern) + value = pattern["number"] + match = re.fullmatch(regex, value) if does_match: - assert match[0] == pattern - assert match.span() == (0, len(pattern)) + assert match[0] == value + assert match.span() == (0, len(value)) else: assert match is None @pytest.mark.parametrize( - "step,regex,examples", + "schema,regex,examples", [ + # String ( {"title": "Foo", "type": "string"}, STRING, [("unquotedstring", False), ('"quoted_string"', True)], ), + # String with maximum length ( {"title": "Foo", "type": "string", "maxLength": 3}, f'"{STRING_INNER}{{,3}}"', [('"ab"', True), ('"a""', False), ('"abcd"', False)], ), + # String with minimum length ( {"title": "Foo", "type": "string", "minLength": 3}, f'"{STRING_INNER}{{3,}}"', [('"ab"', False), ('"abcd"', True), ('"abc""', False)], ), + # String with both minimum and maximum length + ( + {"title": "Foo", "type": "string", "minLength": 3, "maxLength": 5}, + f'"{STRING_INNER}{{3,5}}"', + [('"ab"', False), ('"abcd"', True), ('"abcdef""', False)], + ), + # String defined by a regular expression + ( + {"title": "Foo", "type": "string", "pattern": r"^[a-z]$"}, + r'(^"[a-z]"$)', + [('"a"', True), ('"1"', False)], + ), + # Boolean ( {"title": "Foo", "type": "boolean"}, BOOLEAN, @@ -462,6 +126,7 @@ def test_match_number(pattern, does_match): ("0", False), ], ), + # Null ( {"title": "Foo", "type": "null"}, NULL, @@ -471,31 +136,25 @@ def test_match_number(pattern, does_match): ("0", False), ], ), - ( - {"title": "Foo", "anyOf": [{"type": "string"}, {"type": "number"}]}, - f"({STRING}|{NUMBER})", - [ - ('"string"', 
True), - ('"st"ring"', False), - ("1000", True), - ("true", False), - ], - ), + # Enum string ( {"title": "Foo", "enum": ["Marc", "Jean"], "type": "string"}, '("Marc"|"Jean")', [('"Marc"', True), ('"Jean"', True), ('"John"', False)], ), + # Make sure strings are escaped ( {"title": "Foo", "enum": [".*", r"\s*"], "type": "string"}, r'("\.\*"|"\\s\*")', [('".*"', True), (r'"\s*"', True), (r'"\.\*"', False)], ), + # Enum integer ( {"title": "Foo", "enum": [0, 1], "type": "integer"}, "(0|1)", [("0", True), ("1", True), ("a", False)], ), + # integer ( { "title": "Foo", @@ -505,11 +164,13 @@ def test_match_number(pattern, does_match): '\\{[\\n ]*"count"[\\n ]*:[\\n ]*(0|[1-9][0-9]*)[\\n ]*\\}', [('{\n "count": 100\n}', True)], ), + # array ( {"title": "Foo", "type": "array", "items": {"type": "number"}}, rf"\[({NUMBER})(,({NUMBER}))*\]", [("[1e+9,1.3]", True)], ), + # anyOf ( { "title": "Foo", @@ -519,6 +180,7 @@ def test_match_number(pattern, does_match): r"\[(((true|false)|null))(,(((true|false)|null)))*\]", [("[true,null,false]", True)], ), + # Nested schema ( { "title": "Bar", @@ -534,11 +196,86 @@ def test_match_number(pattern, does_match): f'\\{{[\\n ]*"fuzz"[\\n ]*:[\\n ]*\\{{[\\n ]*"spam"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*\\}}[\\n ]*\\}}', [('{\n "fuzz": {\n "spam": 100\n }\n}', True)], ), + # Schema with a reference + ( + { + "title": "User", + "type": "object", + "properties": { + "user_id": {"title": "User Id", "type": "integer"}, + "name": {"title": "Name", "type": "string"}, + "a": {"$ref": "#/properties/name"}, + }, + "required": ["user_id", "name"], + }, + f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"a"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + [('{"user_id": 100, "name": "John", "a": "Marc"}', True)], + ), + ( + { + "title": "User", + "type": "object", + "$defs": {"name": {"title": "Name2", "type": "string"}}, + "properties": { + "user_id": {"title": "User Id", "type": "integer"}, + "name": {"title": 
"Name", "type": "string"}, + "name2": {"$ref": "#/$defs/name"}, + }, + "required": ["user_id", "name"], + }, + f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"name2"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + [('{"user_id": 100, "name": "John", "name2": "Marc"}', True)], + ), + ( + { + "$id": "customer", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "Customer", + "type": "object", + "properties": { + "name": {"type": "string"}, + "last_name": {"type": "string"}, + "address": {"$ref": "customer#/$defs/address"}, + }, + "required": [ + "first_name", + "last_name", + "shipping_address", + "billing_address", + ], + "$defs": { + "address": { + "title": "Address", + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "city": {"type": "string"}, + }, + "required": ["street_address", "city", "state"], + "definitions": { + "state": { + "type": "object", + "title": "State", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + }, + } + }, + }, + f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"last_name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"address"[\\n ]*:[\\n ]*\\{{[\\n ]*"city"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}[\\n ]*\\}}', + [ + ( + '{"name": "John", "last_name": "Doe", "address": {"city": "Paris"}}', + True, + ) + ], + ), ], ) -def test_match(step, regex, examples): - test_regex = match_step_to_regex(step) - +def test_match(schema, regex, examples): + schema = json.dumps(schema) + test_regex = build_regex_from_schema(schema) assert test_regex == regex for string, does_match in examples: From 39cf75512b480e56aeaf765a6276d868ff98ceff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 8 Nov 2023 14:34:31 +0100 Subject: [PATCH 2/6] Support `oneOf`, `anyOf` and `allOf` --- outlines/text/json_schema.py | 57 +++++++++++++++++++++++++++------- tests/text/test_json_schema.py | 25 ++++++++++++--- 2 files 
changed, 67 insertions(+), 15 deletions(-) diff --git a/outlines/text/json_schema.py b/outlines/text/json_schema.py index fe1578071..da7f4f377 100644 --- a/outlines/text/json_schema.py +++ b/outlines/text/json_schema.py @@ -1,3 +1,4 @@ +import itertools as it import json import re @@ -50,6 +51,25 @@ def build_regex_from_schema(schema: str): def to_regex(resolver: Resolver, instance: dict): + """Translate a JSON Schema instance into a regex that validates the schema. + + Note + ---- + Many features of JSON schema are missing: + - Support the fact that fields in an object are optional by default + - Handle `required` keyword + - Handle `additionalProperties` keyword + - Handle types defined as a list + - Handle constraints on numbers + - Handle special patterns: `date`, `uri`, etc. + + Parameters + ---------- + resolver + An object that resolves references to other instances within a schema + instance + The instance to translate + """ whitespace = r"[\n ]*" if "properties" in instance: @@ -67,16 +87,33 @@ def to_regex(resolver: Resolver, instance: dict): return regex - elif "oneOf" in instance: - print(instance) - + # To validate against allOf, the given data must be valid against all of the + # given subschemas. elif "allOf" in instance: - print(instance) + subregexes = [to_regex(resolver, t) for t in instance["allOf"]] + subregexes_str = [f"{subregex}" for subregex in subregexes] + return rf"({''.join(subregexes_str)})" + # To validate against `anyOf`, the given data must be valid against + # any (one or more) of the given subschemas. elif "anyOf" in instance: subregexes = [to_regex(resolver, t) for t in instance["anyOf"]] + combinations = [ + "(" + "".join(c) + ")" + for r in range(1, len(subregexes) + 1) + for c in it.permutations(subregexes, r) + ] + + return rf"({'|'.join(combinations)})" + + # To validate against oneOf, the given data must be valid against exactly + # one of the given subschemas. 
+ elif "oneOf" in instance: + subregexes = [to_regex(resolver, t) for t in instance["oneOf"]] return rf"({'|'.join(subregexes)})" + # The enum keyword is used to restrict a value to a fixed set of values. It + # must be an array with at least one element, where each element is unique. elif "enum" in instance: if instance["type"] == "string": choices = [f'"{re.escape(choice)}"' for choice in instance["enum"]] @@ -90,9 +127,13 @@ def to_regex(resolver: Resolver, instance: dict): instance = resolver.lookup(path).contents return to_regex(resolver, instance) + # The type keyword may either be a string or an array: + # - If it's a string, it is the name of one of the basic types. + # - If it is an array, it must be an array of strings, where each string is + # the name of one of the basic types, and each element is unique. In this + # case, the JSON snippet is valid if it matches any of the given types. elif "type" in instance: type = instance["type"] - if type == "string": if "maxLength" in instance or "minLength" in instance: max_length = instance.get("maxLength", "") @@ -130,12 +171,6 @@ def to_regex(resolver: Resolver, instance: dict): elif type == "null": return type_to_regex["null"] - # elif isinstance(type, list): - # if "object" in type: - # expanded = to_regex(resolver, instance) - # return "" - # return "" - raise NotImplementedError( f"""Could not translate the instance {instance} to a regular expression. Make sure it is valid to the JSON Schema specification. 
If diff --git a/tests/text/test_json_schema.py b/tests/text/test_json_schema.py index 0069ca841..f3f2bc827 100644 --- a/tests/text/test_json_schema.py +++ b/tests/text/test_json_schema.py @@ -170,15 +170,32 @@ def test_match_number(pattern, does_match): rf"\[({NUMBER})(,({NUMBER}))*\]", [("[1e+9,1.3]", True)], ), + # oneOf + ( + { + "title": "Foo", + "oneOf": [{"type": "string"}, {"type": "number"}], + }, + rf"({STRING}|{NUMBER})", + [("12.3", True), ('"a"', True), ('1.3"a"', False)], + ), # anyOf ( { "title": "Foo", - "type": "array", - "items": {"anyOf": [{"type": "boolean"}, {"type": "null"}]}, + "anyOf": [{"type": "string"}, {"type": "integer"}], + }, + rf'(("(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*")|((0|[1-9][0-9]*))|("(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"(0|[1-9][0-9]*))|((0|[1-9][0-9]*)"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"))', + [("12", True), ('"a"', True), ('1"a"', True)], + ), + # allOf + ( + { + "title": "Foo", + "allOf": [{"type": "string"}, {"type": "integer"}], }, - r"\[(((true|false)|null))(,(((true|false)|null)))*\]", - [("[true,null,false]", True)], + rf"({STRING}{INTEGER})", + [('"a"1', True), ('"a"', False), ('"1"', False)], ), # Nested schema ( From e074a804f60710ec179e7bfd25e4bd61b1d00d0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 8 Nov 2023 17:23:01 +0100 Subject: [PATCH 3/6] Check the JSON Schema before translating it --- environment.yml | 1 + outlines/text/json_schema.py | 2 ++ pyproject.toml | 2 ++ 3 files changed, 5 insertions(+) diff --git a/environment.yml b/environment.yml index 14629af3e..56c58c5e0 100644 --- a/environment.yml +++ b/environment.yml @@ -17,6 +17,7 @@ dependencies: - pytest - pre-commit - referencing + - jsonschema - transformers - pip - pip: diff --git a/outlines/text/json_schema.py b/outlines/text/json_schema.py index da7f4f377..bac3e930d 100644 --- a/outlines/text/json_schema.py +++ b/outlines/text/json_schema.py @@ -2,6 +2,7 @@ import json import re +from jsonschema.protocols import Validator from 
referencing import Registry, Resource from referencing._core import Resolver from referencing.jsonschema import DRAFT202012 @@ -37,6 +38,7 @@ def build_regex_from_schema(schema: str): follows the schema. """ + Validator.check_schema(schema) schema = json.loads(schema) # Build reference resolver diff --git a/pyproject.toml b/pyproject.toml index 6ab6cf61a..7fc1dceba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "numba", "joblib", "referencing", + "jsonschema", ] dynamic = ["version"] @@ -86,6 +87,7 @@ exclude=["examples"] module = [ "jinja2", "joblib.*", + "jsonschema.*", "openai", "numpy.*", "perscache.*", From d2f1c9c4886e2d2bfd1a23b90c9be60049a7fde2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 8 Nov 2023 17:27:31 +0100 Subject: [PATCH 4/6] Support array of types and arrays without specified types --- outlines/text/json_schema.py | 28 ++++++++++++++++++++++++---- tests/text/test_json_schema.py | 2 +- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/outlines/text/json_schema.py b/outlines/text/json_schema.py index bac3e930d..0c597b84d 100644 --- a/outlines/text/json_schema.py +++ b/outlines/text/json_schema.py @@ -48,8 +48,7 @@ def build_regex_from_schema(schema: str): resolver = registry.resolver() content = schema.contents - regex = to_regex(resolver, content) - return regex + return to_regex(resolver, content) def to_regex(resolver: Resolver, instance: dict): @@ -164,8 +163,22 @@ def to_regex(resolver: Resolver, instance: dict): return type_to_regex["integer"] elif type == "array": - items_regex = to_regex(resolver, instance["items"]) - return rf"\[({items_regex})(,({items_regex}))*\]" + if "items" in instance: + items_regex = to_regex(resolver, instance["items"]) + return rf"\[({items_regex})(,({items_regex}))*\]" + else: + # Here we need to make the choice to exclude generating list of objects + # if the specification of the object is not give, even though a JSON + # object that 
contains an object here would be valid under the specification. + types = [ + {"type": "boolean"}, + {"type": "null"}, + {"type": "number"}, + {"type": "integer"}, + {"type": "string"}, + ] + regexes = [to_regex(resolver, t) for t in types] + return rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)}))*\]" elif type == "boolean": return type_to_regex["boolean"] @@ -173,6 +186,13 @@ def to_regex(resolver: Resolver, instance: dict): elif type == "null": return type_to_regex["null"] + elif isinstance(type, list): + # Here we need to make the choice to exclude generating an object + # if the specification of the object is not give, even though a JSON + # object that contains an object here would be valid under the specification. + regexes = [to_regex(resolver, {"type": t}) for t in type if t != "object"] + return rf"({'|'.join(regexes)})" + raise NotImplementedError( f"""Could not translate the instance {instance} to a regular expression. Make sure it is valid to the JSON Schema specification. 
If diff --git a/tests/text/test_json_schema.py b/tests/text/test_json_schema.py index f3f2bc827..05255dc60 100644 --- a/tests/text/test_json_schema.py +++ b/tests/text/test_json_schema.py @@ -185,7 +185,7 @@ def test_match_number(pattern, does_match): "title": "Foo", "anyOf": [{"type": "string"}, {"type": "integer"}], }, - rf'(("(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*")|((0|[1-9][0-9]*))|("(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"(0|[1-9][0-9]*))|((0|[1-9][0-9]*)"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"))', + r'(("(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*")|((0|[1-9][0-9]*))|("(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"(0|[1-9][0-9]*))|((0|[1-9][0-9]*)"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"))', [("12", True), ('"a"', True), ('1"a"', True)], ), # allOf From 2f877cfae77e6cfed99f17df6f9c05f8000cdc5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 8 Nov 2023 21:03:15 +0100 Subject: [PATCH 5/6] Support enums with different types --- outlines/text/json_schema.py | 37 +++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/outlines/text/json_schema.py b/outlines/text/json_schema.py index 0c597b84d..614227531 100644 --- a/outlines/text/json_schema.py +++ b/outlines/text/json_schema.py @@ -63,6 +63,9 @@ def to_regex(resolver: Resolver, instance: dict): - Handle types defined as a list - Handle constraints on numbers - Handle special patterns: `date`, `uri`, etc. + - Handle optional fields (not in `required`) + + This does not support recursive definitions. Parameters ---------- @@ -116,12 +119,14 @@ def to_regex(resolver: Resolver, instance: dict): # The enum keyword is used to restrict a value to a fixed set of values. It # must be an array with at least one element, where each element is unique. 
elif "enum" in instance: - if instance["type"] == "string": - choices = [f'"{re.escape(choice)}"' for choice in instance["enum"]] - return f"({'|'.join(choices)})" - else: - choices = [re.escape(str(choice)) for choice in instance["enum"]] - return f"({'|'.join(choices)})" + choices = [] + for choice in instance["enum"]: + if type(choice) in [int, float, bool, None]: + choices.append(re.escape(str(choice))) + elif type(choice) == str: + choices.append(f'"{re.escape(choice)}"') + + return f"({'|'.join(choices)})" elif "$ref" in instance: path = f"{instance['$ref']}" @@ -134,8 +139,8 @@ def to_regex(resolver: Resolver, instance: dict): # the name of one of the basic types, and each element is unique. In this # case, the JSON snippet is valid if it matches any of the given types. elif "type" in instance: - type = instance["type"] - if type == "string": + instance_type = instance["type"] + if instance_type == "string": if "maxLength" in instance or "minLength" in instance: max_length = instance.get("maxLength", "") min_length = instance.get("minLength", "") @@ -156,13 +161,13 @@ def to_regex(resolver: Resolver, instance: dict): else: return type_to_regex["string"] - elif type == "number": + elif instance_type == "number": return type_to_regex["number"] - elif type == "integer": + elif instance_type == "integer": return type_to_regex["integer"] - elif type == "array": + elif instance_type == "array": if "items" in instance: items_regex = to_regex(resolver, instance["items"]) return rf"\[({items_regex})(,({items_regex}))*\]" @@ -180,17 +185,19 @@ def to_regex(resolver: Resolver, instance: dict): regexes = [to_regex(resolver, t) for t in types] return rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)}))*\]" - elif type == "boolean": + elif instance_type == "boolean": return type_to_regex["boolean"] - elif type == "null": + elif instance_type == "null": return type_to_regex["null"] - elif isinstance(type, list): + elif isinstance(instance_type, list): # Here we need to make 
the choice to exclude generating an object # if the specification of the object is not give, even though a JSON # object that contains an object here would be valid under the specification. - regexes = [to_regex(resolver, {"type": t}) for t in type if t != "object"] + regexes = [ + to_regex(resolver, {"type": t}) for t in instance_type if t != "object" + ] return rf"({'|'.join(regexes)})" raise NotImplementedError( From e2a9aa637c0e0417bba854873a89a63f5007c092 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 10 Nov 2023 15:19:56 +0100 Subject: [PATCH 6/6] Support fixed-length arrays --- outlines/text/json_schema.py | 23 ++++++++++++++++------- tests/text/test_json_schema.py | 24 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/outlines/text/json_schema.py b/outlines/text/json_schema.py index 614227531..f2fc351ed 100644 --- a/outlines/text/json_schema.py +++ b/outlines/text/json_schema.py @@ -142,16 +142,16 @@ def to_regex(resolver: Resolver, instance: dict): instance_type = instance["type"] if instance_type == "string": if "maxLength" in instance or "minLength" in instance: - max_length = instance.get("maxLength", "") - min_length = instance.get("minLength", "") + max_items = instance.get("maxLength", "") + min_items = instance.get("minLength", "") try: - if int(max_length) < int(min_length): + if int(max_items) < int(min_items): raise ValueError( "maxLength must be greater than or equal to minLength" ) except ValueError: pass - return f'"{STRING_INNER}{{{min_length},{max_length}}}"' + return f'"{STRING_INNER}{{{min_items},{max_items}}}"' elif "pattern" in instance: pattern = instance["pattern"] if pattern[0] == "^" and pattern[-1] == "$": @@ -168,12 +168,19 @@ def to_regex(resolver: Resolver, instance: dict): return type_to_regex["integer"] elif instance_type == "array": + min_items = instance.get("minItems", "0") + max_items = instance.get("maxItems", "") + if min_items == max_items: + num_repeats = "{" 
+ str(int(min_items) - 1) + "}" + else: + num_repeats = "*" + if "items" in instance: items_regex = to_regex(resolver, instance["items"]) - return rf"\[({items_regex})(,({items_regex}))*\]" + return rf"\[({items_regex})(,({items_regex})){num_repeats}\]" else: # Here we need to make the choice to exclude generating list of objects - # if the specification of the object is not give, even though a JSON + # if the specification of the object is not given, even though a JSON # object that contains an object here would be valid under the specification. types = [ {"type": "boolean"}, @@ -183,7 +190,9 @@ def to_regex(resolver: Resolver, instance: dict): {"type": "string"}, ] regexes = [to_regex(resolver, t) for t in types] - return rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)}))*\]" + return ( + rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)})){num_repeats}\]" + ) elif instance_type == "boolean": return type_to_regex["boolean"] diff --git a/tests/text/test_json_schema.py b/tests/text/test_json_schema.py index 05255dc60..b84ac39fa 100644 --- a/tests/text/test_json_schema.py +++ b/tests/text/test_json_schema.py @@ -170,6 +170,30 @@ def test_match_number(pattern, does_match): rf"\[({NUMBER})(,({NUMBER}))*\]", [("[1e+9,1.3]", True)], ), + # array with a set length of 1 + ( + { + "title": "Foo", + "type": "array", + "items": {"type": "integer"}, + "minItems": 1, + "maxItems": 1, + }, + rf"\[({INTEGER})(,({INTEGER})){{0}}\]", + [("[1]", True), ("[1,2]", False), ('["a"]', False), ("[]", False)], + ), + # array with a set length greater than 1 + ( + { + "title": "Foo", + "type": "array", + "items": {"type": "integer"}, + "minItems": 3, + "maxItems": 3, + }, + rf"\[({INTEGER})(,({INTEGER})){{2}}\]", + [("[1]", False), ("[]", False), ("[1,2,3]", True), ("[1,2,3,4]", False)], + ), # oneOf ( {