diff --git a/langfun/__init__.py b/langfun/__init__.py index 64b1d13..66dec9a 100644 --- a/langfun/__init__.py +++ b/langfun/__init__.py @@ -19,10 +19,14 @@ from langfun.core import * from langfun.core import structured +Schema = structured.Schema +MISSING = structured.MISSING +UNKNOWN = structured.UNKNOWN + parse = structured.parse query = structured.query describe = structured.describe - +complete = structured.complete from langfun.core import templates from langfun.core import transforms diff --git a/langfun/core/langfunc_test.py b/langfun/core/langfunc_test.py index 4643669..c1423fa 100644 --- a/langfun/core/langfunc_test.py +++ b/langfun/core/langfunc_test.py @@ -94,7 +94,7 @@ def test_call(self): "LangFunc(template_str='Hello', clean=True, returns=None, " 'lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0, ' 'max_tokens=1024, n=1, top_k=40, top_p=None, random_seed=None), ' - 'timeout=30.0, max_attempts=5, debug=False), input_transform=None, ' + 'timeout=120.0, max_attempts=5, debug=False), input_transform=None, ' 'output_transform=None)', ) diff --git a/langfun/core/language_model.py b/langfun/core/language_model.py index 48bae58..9a9fe11 100644 --- a/langfun/core/language_model.py +++ b/langfun/core/language_model.py @@ -93,7 +93,7 @@ class LanguageModel(component.Component): timeout: Annotated[ float | None, 'Timeout in seconds. If None, there is no timeout.' - ] = 30.0 + ] = 120.0 max_attempts: Annotated[ int, diff --git a/langfun/core/structured/__init__.py b/langfun/core/structured/__init__.py index afa43a4..e4b7c64 100644 --- a/langfun/core/structured/__init__.py +++ b/langfun/core/structured/__init__.py @@ -16,7 +16,20 @@ # pylint: disable=g-bad-import-order # pylint: disable=g-importing-member +from langfun.core.structured.schema import Missing +from langfun.core.structured.schema import MISSING +from langfun.core.structured.schema import Unknown +from langfun.core.structured.schema import UNKNOWN + from langfun.core.structured.schema import Schema +from langfun.core.structured.schema import SchemaProtocol +from langfun.core.structured.schema import schema_spec + +from langfun.core.structured.schema import class_dependencies +from langfun.core.structured.schema import class_definition +from langfun.core.structured.schema import class_definitions +from langfun.core.structured.schema import annotation +from langfun.core.structured.schema import structure_from_python from langfun.core.structured.schema import SchemaRepr from langfun.core.structured.schema import SchemaJsonRepr @@ -27,6 +40,7 @@ from langfun.core.structured.schema import schema_repr from langfun.core.structured.schema import value_repr + from langfun.core.structured.mapping import Mapping from langfun.core.structured.mapping import MappingExample from langfun.core.structured.mapping import MappingError @@ -45,5 +59,10 @@ from langfun.core.structured.structure2nl import DescribeStructure from langfun.core.structured.structure2nl import describe +from langfun.core.structured.structure2structure import StructureToStructure +from langfun.core.structured.structure2structure import CompleteStructure +from langfun.core.structured.structure2structure import complete + + # pylint: enable=g-importing-member # pylint: enable=g-bad-import-order diff --git a/langfun/core/structured/mapping.py b/langfun/core/structured/mapping.py index 300ecdf..8df144c 100644 --- a/langfun/core/structured/mapping.py +++ b/langfun/core/structured/mapping.py @@ -14,7 +14,7 @@ """Mapping interfaces.""" import io -from typing 
import Annotated, Any, Literal +from typing import Annotated import langfun.core as lf from langfun.core.structured import schema as schema_lib import pyglove as pg @@ -36,9 +36,6 @@ def __ne__(self, other): class MappingExample(lf.NaturalLanguageFormattable, lf.Component): """Mapping example between text, schema and structured value.""" - # Value marker for missing value in Mapping. - MISSING_VALUE = (pg.MISSING_VALUE,) - nl_context: Annotated[ str | None, ( @@ -78,29 +75,27 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component): ), ] = lf.contextual(default=None) - value: Annotated[ - Any, + value: pg.typing.Annotated[ + pg.typing.Any(transform=schema_lib.mark_missing), ( 'The structured representation for `nl_text` (or directly prompted ' 'from `nl_context`). It should align with schema.' '`value` could be used as input when we map a structured value to ' 'natural language, or as output when we map it reversely.' - ) - ] = MISSING_VALUE + ), + ] = schema_lib.MISSING def schema_str( - self, - protocol: Literal['json', 'python'] = 'json', - **kwargs) -> str: + self, protocol: schema_lib.SchemaProtocol = 'json', **kwargs + ) -> str: """Returns the string representation of schema based on protocol.""" if self.schema is None: return '' return self.schema.schema_str(protocol, **kwargs) def value_str( - self, - protocol: Literal['json', 'python'] = 'json', - **kwargs) -> str: + self, protocol: schema_lib.SchemaProtocol = 'json', **kwargs + ) -> str: """Returns the string representation of value based on protocol.""" return schema_lib.value_repr( protocol).repr(self.value, self.schema, **kwargs) @@ -122,7 +117,7 @@ def natural_language_format(self) -> str: result.write(lf.colored(self.schema_str(), color='red')) result.write('\n\n') - if MappingExample.MISSING_VALUE != self.value: + if schema_lib.MISSING != self.value: result.write(lf.colored('[VALUE]\n', styles=['bold'])) result.write(lf.colored(self.value_str(), color='blue')) return result.getvalue().strip() diff --git a/langfun/core/structured/mapping_test.py b/langfun/core/structured/mapping_test.py index 528d1e7..b0d9227 100644 --- a/langfun/core/structured/mapping_test.py +++ b/langfun/core/structured/mapping_test.py @@ -17,6 +17,7 @@ import unittest from langfun.core.structured import mapping +from langfun.core.structured import schema as schema_lib import pyglove as pg @@ -116,11 +117,11 @@ def test_str_no_schema(self): def test_str_no_value(self): self.assertEqual( - str(mapping.MappingExample( - 'Give the answer.', - '1 + 1 = 2', - mapping.MappingExample.MISSING_VALUE, - int)), + str( + mapping.MappingExample( + 'Give the answer.', '1 + 1 = 2', schema_lib.MISSING, int + ) + ), inspect.cleandoc(""" \x1b[1m[CONTEXT] \x1b[0m\x1b[35mGive the answer.\x1b[0m @@ -144,4 +145,3 @@ def test_serialization(self): if __name__ == '__main__': unittest.main() - diff --git a/langfun/core/structured/nl2structure.py b/langfun/core/structured/nl2structure.py index e93b19c..20e3257 100644 --- a/langfun/core/structured/nl2structure.py +++ b/langfun/core/structured/nl2structure.py @@ -80,7 +80,7 @@ class NaturalLanguageToStructure(mapping.Mapping): preamble: Annotated[ lf.LangFunc, - 'Preamble used for zeroshot jsonify.', + 'Preamble used for natural language-to-structure mapping.', ] nl_context_title: Annotated[ @@ -98,8 +98,8 @@ class NaturalLanguageToStructure(mapping.Mapping): value_title: Annotated[str, 'The section title for schema.'] protocol: Annotated[ - Literal['json', 'python'], - 'The protocol for representing the schema and 
value.' + schema_lib.SchemaProtocol, + 'The protocol for representing the schema and value.', ] @property @@ -223,7 +223,7 @@ def parse( *, user_prompt: str | None = None, examples: list[mapping.MappingExample] | None = None, - protocol: Literal['json', 'python'] = 'json', + protocol: schema_lib.SchemaProtocol = 'json', **kwargs, ) -> Any: """Parse a natural langugage message based on schema. @@ -293,7 +293,7 @@ class Flight(pg.Object): def _parse_structure_cls( - protocol: Literal['json', 'python'] + protocol: schema_lib.SchemaProtocol, ) -> Type[ParseStructure]: if protocol == 'json': return ParseStructureJson @@ -309,7 +309,7 @@ def as_structured( default: Any = lf.message_transform.RAISE_IF_HAS_ERROR, examples: list[mapping.Mapping] | None = None, *, - protocol: Literal['json', 'python'] = 'json', + protocol: schema_lib.SchemaProtocol = 'json', **kwargs, ): """Returns the structured representation of the message text. @@ -441,7 +441,7 @@ class QueryStructurePython(QueryStructure): def _query_structure_cls( - protocol: Literal['json', 'python'] + protocol: schema_lib.SchemaProtocol, ) -> Type[QueryStructure]: if protocol == 'json': return QueryStructureJson @@ -483,7 +483,7 @@ def query( default: Any = lf.message_transform.RAISE_IF_HAS_ERROR, *, examples: list[mapping.MappingExample] | None = None, - protocol: Literal['json', 'python'] = 'json', + protocol: schema_lib.SchemaProtocol = 'json', **kwargs, ) -> Any: """Parse a natural langugage message based on schema. @@ -528,8 +528,7 @@ class Flight(pg.Object): are 'json' and 'python'. By default the JSON representation will be used. **kwargs: Keyword arguments passed to the `lf.structured.NaturalLanguageToStructureed` transform, e.g. `lm` for - specifying the language model for structured parsing, `jsonify` for - customizing the prompt for structured parsing, and etc. + specifying the language model for structured parsing. Returns: The result based on the schema. 
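
A minimal sketch of the two `SchemaProtocol` values these signatures now share; the `Flight` class is illustrative only and the snippet assumes the `lf.Schema` export added in this change:

```python
import langfun as lf
import pyglove as pg


class Flight(pg.Object):
  airline: str
  flight_number: str


# schema_str/value_str dispatch on the same SchemaProtocol values ('json' by
# default, or 'python') that parse(), query() and as_structured() accept.
schema = lf.Schema.from_value(Flight)
print(schema.schema_str(protocol='python'))  # Python class definition(s).
print(schema.value_str(
    Flight(airline='United Airlines', flight_number='UA2631'),
    protocol='json'))  # JSON representation of the value.
```
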
diff --git a/langfun/core/structured/nl2structure_test.py b/langfun/core/structured/nl2structure_test.py index 67d5b3f..1f2aad2 100644 --- a/langfun/core/structured/nl2structure_test.py +++ b/langfun/core/structured/nl2structure_test.py @@ -41,7 +41,6 @@ def test_render_no_examples(self): m = lf.AIMessage('Bla bla bla 12 / 6 + 2 = 4.', result='12 / 6 + 2 = 4') m.source = lf.UserMessage('Compute 12 / 6 + 2.', tags=['lm-input']) - print(l.render(message=m).text) self.assertEqual( l.render(message=m).text, inspect.cleandoc(""" diff --git a/langfun/core/structured/schema.py b/langfun/core/structured/schema.py index 3f6f022..ffc5594 100644 --- a/langfun/core/structured/schema.py +++ b/langfun/core/structured/schema.py @@ -14,9 +14,10 @@ """Schema for structured data.""" import abc +import inspect import io import typing -from typing import Any, Literal, Type +from typing import Any, Literal, Sequence, Type, Union import langfun.core as lf from langfun.core import coding as lf_coding import pyglove as pg @@ -62,6 +63,9 @@ def _parse_node(v) -> pg.typing.ValueSpec: return _parse_node(value) +SchemaProtocol = Literal['json', 'python'] + + class Schema(lf.NaturalLanguageFormattable, pg.Object): """Base class for structured data schema.""" @@ -73,26 +77,19 @@ class Schema(lf.NaturalLanguageFormattable, pg.Object): ), ] - def schema_str( - self, - protocol: Literal['json', 'python'] = 'json', - **kwargs) -> str: + def schema_str(self, protocol: SchemaProtocol = 'json', **kwargs) -> str: """Returns the representation of the schema.""" return schema_repr(protocol).repr(self, **kwargs) def value_str( - self, - value: Any, - protocol: Literal['json', 'python'] = 'json', - **kwargs) -> str: + self, value: Any, protocol: SchemaProtocol = 'json', **kwargs + ) -> str: """Returns the representation of a structured value.""" return value_repr(protocol).repr(value, self, **kwargs) def parse( - self, - text: str, - protocol: Literal['json', 'python'] = 'json', - **kwargs) -> Any: + self, text: str, protocol: SchemaProtocol = 'json', **kwargs + ) -> Any: """Parse a LM generated text into a structured value.""" return self.spec.apply(value_repr(protocol).parse(text, self, **kwargs)) @@ -129,58 +126,8 @@ def _node(vs: pg.typing.ValueSpec) -> Any: def class_dependencies( self, include_subclasses: bool = True) -> list[Type[Any]]: - """Returns a set of dependent classes for a ValueSpec.""" - seen = set() - dependencies = [] - - def _add_dependency(cls_or_classes): - if isinstance(cls_or_classes, type): - cls_or_classes = [cls_or_classes] - for cls in cls_or_classes: - if cls not in seen: - dependencies.append(cls) - seen.add(cls) - - def _fill_dependencies(vs: pg.typing.ValueSpec, include_subclasses: bool): - if isinstance(vs, pg.typing.Object): - if issubclass(vs.cls, pg.Object) and vs.cls not in seen: - # Add base classes as dependencies. - for base_cls in vs.cls.__bases__: - # We only keep track of user-defined symbolic classes. - if issubclass( - base_cls, pg.Object - ) and not base_cls.__module__.startswith('pyglove'): - _fill_dependencies( - pg.typing.Object(base_cls), include_subclasses=False) - - # Add members as dependencies. - if hasattr(vs.cls, '__schema__'): - for field in vs.cls.__schema__.values(): - _fill_dependencies(field.value, include_subclasses) - _add_dependency(vs.cls) - - # Check subclasses if available. 
- if include_subclasses: - for _, cls in pg.object_utils.registered_types(): - if issubclass(cls, vs.cls) and cls not in dependencies: - _fill_dependencies( - pg.typing.Object(cls), include_subclasses=True) - - if isinstance(vs, pg.typing.List): - _fill_dependencies( - vs.element.value, include_subclasses) - elif isinstance(vs, pg.typing.Tuple): - for elem in vs.elements: - _fill_dependencies(elem.value, include_subclasses) - elif isinstance(vs, pg.typing.Dict) and vs.schema: - for v in vs.schema.values(): - _fill_dependencies(v, include_subclasses) - elif isinstance(vs, pg.typing.Union): - for v in vs.candidates: - _fill_dependencies(v, include_subclasses) - - _fill_dependencies(self.spec, include_subclasses) - return dependencies + """Returns a list of class dependencies for current schema.""" + return class_dependencies(self.spec, include_subclasses) @classmethod def from_value(cls, value) -> 'Schema': @@ -190,6 +137,103 @@ def from_value(cls, value) -> 'Schema': return cls(parse_value_spec(value)) +def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]: + """Returns a list of top level value specs from a symbolic value.""" + top_level_object_specs = [] + + def _collect_top_level_object_specs(k, v, p): + del k, p + if isinstance(v, pg.Object): + top_level_object_specs.append(pg.typing.Object(v.__class__)) + return pg.TraverseAction.CONTINUE + return pg.TraverseAction.ENTER + + pg.traverse(value, _collect_top_level_object_specs) + return top_level_object_specs + + +def class_dependencies( + value_or_spec: Union[ + pg.Symbolic, + Schema, + pg.typing.ValueSpec, + Type[pg.Object], + tuple[Union[pg.typing.ValueSpec, Type[pg.Object]], ...], + ], + include_subclasses: bool = True, +) -> list[Type[Any]]: + """Returns a list of class dependencies from a value or specs.""" + if isinstance(value_or_spec, Schema): + return value_or_spec.class_dependencies(include_subclasses) + + if isinstance(value_or_spec, (pg.typing.ValueSpec, pg.symbolic.ObjectMeta)): + value_or_spec = (value_or_spec,) + + if isinstance(value_or_spec, tuple): + value_specs = [] + for v in value_or_spec: + if isinstance(v, pg.typing.ValueSpec): + value_specs.append(v) + elif inspect.isclass(v) and issubclass(v, pg.Object): + value_specs.append(pg.typing.Object(v)) + else: + raise TypeError(f'Unsupported spec type: {v!r}') + else: + value_specs = _top_level_object_specs_from_value(value_or_spec) + + seen = set() + dependencies = [] + + def _add_dependency(cls_or_classes): + if isinstance(cls_or_classes, type): + cls_or_classes = [cls_or_classes] + for cls in cls_or_classes: + if cls not in seen: + dependencies.append(cls) + seen.add(cls) + + def _fill_dependencies(vs: pg.typing.ValueSpec, include_subclasses: bool): + if isinstance(vs, pg.typing.Object): + if issubclass(vs.cls, pg.Object) and vs.cls not in seen: + # Add base classes as dependencies. + for base_cls in vs.cls.__bases__: + # We only keep track of user-defined symbolic classes. + if issubclass( + base_cls, pg.Object + ) and not base_cls.__module__.startswith('pyglove'): + _fill_dependencies( + pg.typing.Object(base_cls), include_subclasses=False + ) + + # Add members as dependencies. + if hasattr(vs.cls, '__schema__'): + for field in vs.cls.__schema__.values(): + _fill_dependencies(field.value, include_subclasses) + _add_dependency(vs.cls) + + # Check subclasses if available. 
+ if include_subclasses: + for _, cls in pg.object_utils.registered_types(): + if issubclass(cls, vs.cls) and cls not in dependencies: + _fill_dependencies(pg.typing.Object(cls), include_subclasses=True) + + if isinstance(vs, pg.typing.List): + _fill_dependencies(vs.element.value, include_subclasses) + elif isinstance(vs, pg.typing.Tuple): + for elem in vs.elements: + _fill_dependencies(elem.value, include_subclasses) + elif isinstance(vs, pg.typing.Dict) and vs.schema: + for v in vs.schema.values(): + _fill_dependencies(v, include_subclasses) + elif isinstance(vs, pg.typing.Union): + for v in vs.candidates: + _fill_dependencies(v, include_subclasses) + + for value_spec in value_specs: + _fill_dependencies(value_spec, include_subclasses) + return dependencies + + def schema_spec(noneable: bool = False) -> pg.typing.ValueSpec: # pylint: disable=unused-argument if typing.TYPE_CHECKING: return Any @@ -221,142 +265,147 @@ def repr(self, schema: Schema) -> str: ret += f'\n\n```python\n{class_definition_str}```' return ret - def class_definitions( - self, - schema: Schema, - include_subclasses: bool = True) -> str | None: - """Returns a str for class definitions from a schema.""" - class_dependencies = schema.class_dependencies( - include_subclasses=include_subclasses) - if not class_dependencies: - return None - def_str = io.StringIO() - for i, cls in enumerate(class_dependencies): - if i > 0: - def_str.write('\n') - def_str.write(self.class_def(cls)) - return def_str.getvalue() + def class_definitions(self, schema: Schema) -> str | None: + deps = schema.class_dependencies(include_subclasses=True) + return class_definitions(deps) def result_definition(self, schema: Schema) -> str: - return self.annotate(schema.spec) - - def class_def(self, cls, strict: bool = False) -> str: - """Returns the Python class definition.""" - out = io.StringIO() - if not issubclass(cls, pg.Object): - raise TypeError( - 'Classes must be `pg.Object` subclasses to be used as schema. ' - f'Encountered: {cls}.' - ) - schema = cls.__schema__ - eligible_bases = [] - for base_cls in cls.__bases__: - if issubclass( - base_cls, pg.Symbolic - ) and not base_cls.__module__.startswith('pyglove'): - eligible_bases.append(base_cls.__name__) - if eligible_bases: - base_cls_str = ', '.join(eligible_bases) - out.write(f'class {cls.__name__}({base_cls_str}):\n') - else: - out.write(f'class {cls.__name__}:\n') - - if schema.fields: - for key, field in schema.items(): - if not isinstance(key, pg.typing.ConstStrKey): - raise TypeError( - 'Variable-length keyword arguments is not supported in ' - f'structured parsing or query. 
Encountered: {field}' - ) - out.write( - f' {field.key}: {self.annotate(field.value, strict=strict)}\n') - else: - out.write(' pass\n') - return out.getvalue() - - def annotate( - self, - vs: pg.typing.ValueSpec, - annotate_optional: bool = True, - strict: bool = False, - ) -> str: - """Returns the annotation string for a value spec.""" - if isinstance(vs, pg.typing.Any): - return 'Any' - elif isinstance(vs, pg.typing.Enum): - candidate_str = ', '.join([repr(v) for v in vs.values]) - return f'Literal[{candidate_str}]' - elif isinstance(vs, pg.typing.Union): - candidate_str = ', '.join( - [ - self.annotate(c, annotate_optional=False, strict=strict) - for c in vs.candidates - ] - ) - if vs.is_noneable: - candidate_str += ', None' - return f'Union[{candidate_str}]' - - if isinstance(vs, pg.typing.Bool): - x = 'bool' - elif isinstance(vs, pg.typing.Str): - if vs.regex is None: - x = 'str' - else: - if strict: - x = f"pg.typing.Str(regex='{vs.regex.pattern}')" - else: - x = f"str(regex='{vs.regex.pattern}')" - elif isinstance(vs, pg.typing.Number): - constraints = [] - min_label = 'min_value' if strict else 'min' - max_label = 'max_value' if strict else 'max' - if vs.min_value is not None: - constraints.append(f'{min_label}={vs.min_value}') - if vs.max_value is not None: - constraints.append(f'{max_label}={vs.max_value}') - x = 'int' if isinstance(vs, pg.typing.Int) else 'float' - if constraints: - if strict: - x = ( - 'pg.typing.Int' - if isinstance(vs, pg.typing.Int) - else 'pg.typing.Float' - ) - x += '(' + ', '.join(constraints) + ')' - elif isinstance(vs, pg.typing.Object): - x = vs.cls.__name__ - elif isinstance(vs, pg.typing.List): - item_str = self.annotate(vs.element.value, strict=strict) - x = f'list[{item_str}]' - elif isinstance(vs, pg.typing.Tuple): - elem_str = ', '.join( - [self.annotate(el.value, strict=strict) for el in vs.elements] - ) - x = f'tuple[{elem_str}]' - elif isinstance(vs, pg.typing.Dict): - kv_pairs = None - if vs.schema is not None: - kv_pairs = [ - (k, self.annotate(f.value, strict=strict)) - for k, f in vs.schema.items() - if isinstance(k, pg.typing.ConstStrKey) + return annotation(schema.spec) + + +def class_definitions( + classes: Sequence[Type[Any]], strict: bool = False, markdown: bool = False +) -> str | None: + """Returns a str for class definitions.""" + if not classes: + return None + def_str = io.StringIO() + for i, cls in enumerate(classes): + if i > 0: + def_str.write('\n') + def_str.write(class_definition(cls, strict)) + ret = def_str.getvalue() + if markdown and ret: + ret = f'```python\n{ret}```' + return ret + + +def class_definition(cls, strict: bool = False) -> str: + """Returns the Python class definition.""" + out = io.StringIO() + if not issubclass(cls, pg.Object): + raise TypeError( + 'Classes must be `pg.Object` subclasses to be used as schema. ' + f'Encountered: {cls}.' + ) + schema = cls.__schema__ + eligible_bases = [] + for base_cls in cls.__bases__: + if issubclass(base_cls, pg.Symbolic) and not base_cls.__module__.startswith( + 'pyglove' + ): + eligible_bases.append(base_cls.__name__) + if eligible_bases: + base_cls_str = ', '.join(eligible_bases) + out.write(f'class {cls.__name__}({base_cls_str}):\n') + else: + out.write(f'class {cls.__name__}:\n') + + if schema.fields: + for key, field in schema.items(): + if not isinstance(key, pg.typing.ConstStrKey): + raise TypeError( + 'Variable-length keyword arguments is not supported in ' + f'structured parsing or query. 
Encountered: {field}' + ) + out.write(f' {field.key}: {annotation(field.value, strict=strict)}\n') + else: + out.write(' pass\n') + return out.getvalue() + + +def annotation( + vs: pg.typing.ValueSpec, + annotate_optional: bool = True, + strict: bool = False, +) -> str: + """Returns the annotation string for a value spec.""" + if isinstance(vs, pg.typing.Any): + return 'Any' + elif isinstance(vs, pg.typing.Enum): + candidate_str = ', '.join([repr(v) for v in vs.values]) + return f'Literal[{candidate_str}]' + elif isinstance(vs, pg.typing.Union): + candidate_str = ', '.join( + [ + annotation(c, annotate_optional=False, strict=strict) + for c in vs.candidates ] - - if kv_pairs: - kv_str = ', '.join(f'\'{k}\': {v}' for k, v in kv_pairs) - x = '{' + kv_str + '}' - if strict: - x = f'pg.typing.Dict({x})' + ) + if vs.is_noneable: + candidate_str += ', None' + return f'Union[{candidate_str}]' + + if isinstance(vs, pg.typing.Bool): + x = 'bool' + elif isinstance(vs, pg.typing.Str): + if vs.regex is None: + x = 'str' + else: + if strict: + x = f"pg.typing.Str(regex='{vs.regex.pattern}')" else: - x = 'dict[str, Any]' - + x = f"str(regex='{vs.regex.pattern}')" + elif isinstance(vs, pg.typing.Number): + constraints = [] + min_label = 'min_value' if strict else 'min' + max_label = 'max_value' if strict else 'max' + if vs.min_value is not None: + constraints.append(f'{min_label}={vs.min_value}') + if vs.max_value is not None: + constraints.append(f'{max_label}={vs.max_value}') + x = 'int' if isinstance(vs, pg.typing.Int) else 'float' + if constraints: + if strict: + x = ( + 'pg.typing.Int' + if isinstance(vs, pg.typing.Int) + else 'pg.typing.Float' + ) + x += '(' + ', '.join(constraints) + ')' + elif isinstance(vs, pg.typing.Object): + x = vs.cls.__name__ + elif isinstance(vs, pg.typing.List): + item_str = annotation(vs.element.value, strict=strict) + x = f'list[{item_str}]' + elif isinstance(vs, pg.typing.Tuple): + elem_str = ', '.join( + [annotation(el.value, strict=strict) for el in vs.elements] + ) + x = f'tuple[{elem_str}]' + elif isinstance(vs, pg.typing.Dict): + kv_pairs = None + if vs.schema is not None: + kv_pairs = [ + (k, annotation(f.value, strict=strict)) + for k, f in vs.schema.items() + if isinstance(k, pg.typing.ConstStrKey) + ] + + if kv_pairs: + kv_str = ', '.join(f"'{k}': {v}" for k, v in kv_pairs) + x = '{' + kv_str + '}' + if strict: + x = f'pg.typing.Dict({x})' else: - raise TypeError(f'Unsupported value spec being used as schema: {vs}.') + x = 'dict[str, Any]' - if annotate_optional and vs.is_noneable: - x += ' | None' - return x + else: + raise TypeError(f'Unsupported value spec being used as schema: {vs}.') + + if annotate_optional and vs.is_noneable: + x += ' | None' + return x class SchemaJsonRepr(SchemaRepr): @@ -442,25 +491,41 @@ def repr(self, return f'```python\n{ object_code }\n```' return object_code - def parse(self, - text: str, - schema: Schema | None = None, - **kwargs) -> Any: - context = { - 'pg': pg, - 'Any': typing.Any, - 'List': typing.List, - 'Tuple': typing.Tuple, - 'Dict': typing.Dict, - 'Sequence': typing.Sequence, - 'Optional': typing.Optional, - 'Union': typing.Union, - } + def parse( + self, + text: str, + schema: Schema | None = None, + *, + additional_context: dict[str, Type[Any]] | None = None, + **kwargs, + ) -> Any: + """Parse a Python string into a structured object.""" + del kwargs + context = additional_context or {} if schema is not None: dependencies = schema.class_dependencies() context.update({d.__name__: d for d in dependencies}) - with 
lf_coding.context(**context): - return lf_coding.run(text)['__result__'] + return structure_from_python(text, **context) + + +def structure_from_python(code: str, **symbols) -> Any: + """Evaluates structure from Python code with access to symbols.""" + context = { + 'pg': pg, + 'Any': typing.Any, + 'List': typing.List, + 'Tuple': typing.Tuple, + 'Dict': typing.Dict, + 'Sequence': typing.Sequence, + 'Optional': typing.Optional, + 'Union': typing.Union, + # Special value markers. + 'UNKNOWN': UNKNOWN, + } + if symbols: + context.update(symbols) + with lf_coding.context(**context): + return lf_coding.run(code)['__result__'] class ValueJsonRepr(ValueRepr): @@ -532,7 +597,7 @@ def _cleanup_json(self, json_str: str) -> str: return cleaned.getvalue() -def schema_repr(protocol: Literal['json', 'python']) -> SchemaRepr: +def schema_repr(protocol: SchemaProtocol) -> SchemaRepr: """Gets a SchemaRepr object from protocol.""" if protocol == 'json': return SchemaJsonRepr() @@ -541,9 +606,92 @@ def schema_repr(protocol: Literal['json', 'python']) -> SchemaRepr: raise ValueError(f'Unsupported protocol: {protocol}.') -def value_repr(protocol: Literal['json', 'python']) -> ValueRepr: +def value_repr(protocol: SchemaProtocol) -> ValueRepr: if protocol == 'json': return ValueJsonRepr() elif protocol == 'python': return ValuePythonRepr() raise ValueError(f'Unsupported protocol: {protocol}.') + + +# +# Special value markers. +# + + +class Missing(pg.Object, pg.typing.CustomTyping): + """Value marker for a missing field. + + This class differs from pg.MISSING_VALUE in two aspects: + * When a field is assigned with lf.Missing(), it's considered non-partial. + * lf.Missing() could format the value spec as Python annotations that are + consistent with `lf.structured.Schema.schema_repr()`. 
+ """ + + def _on_bound(self): + super()._on_bound() + self._value_spec = None + + @property + def value_spec(self) -> pg.ValueSpec | None: + """Returns the value spec that applies to the current missing value.""" + return self._value_spec + + def custom_apply( + self, path: pg.KeyPath, value_spec: pg.ValueSpec, *args, **kwargs + ) -> tuple[bool, Any]: + self._value_spec = value_spec + return (False, self) + + def format(self, *args, **kwargs) -> str: + if self._value_spec is None: + return 'MISSING' + return f'MISSING({annotation(self._value_spec)})' + + @classmethod + def find_missing(cls, value: Any) -> dict[str, 'Missing']: + """Lists all missing values contained in the value.""" + missing = {} + + def _visit(k, v, p): + del p + if isinstance(v, Missing): + missing[k] = v + return pg.TraverseAction.ENTER + + pg.traverse(value, _visit) + return missing + + +MISSING = Missing() + + +def mark_missing(value: Any) -> Any: + """Replaces pg.MISSING within the value with lf.structured.Missing objects.""" + if isinstance(value, list): + value = pg.List(value) + elif isinstance(value, dict): + value = pg.Dict(value) + if isinstance(value, pg.Symbolic): + + def _mark_missing(k, v, p): + del k, p + if pg.MISSING_VALUE == v: + v = Missing() + return v + + return value.rebind(_mark_missing, raise_on_no_change=False) + return value + + +class Unknown(pg.Object, pg.typing.CustomTyping): + """Value marker for a field that LMs could not provide.""" + + def custom_apply(self, *args, **kwargs) -> tuple[bool, Any]: + return (False, self) + + def format(self, *args, **kwargs) -> str: + return 'UNKNOWN' + + +UNKNOWN = Unknown() diff --git a/langfun/core/structured/schema_test.py b/langfun/core/structured/schema_test.py index ff696ff..a9a03d9 100644 --- a/langfun/core/structured/schema_test.py +++ b/langfun/core/structured/schema_test.py @@ -160,6 +160,70 @@ def test_parse(self): schema.parse('1', protocol='text') +class ClassDependenciesTest(unittest.TestCase): + + def test_class_dependencies_from_specs(self): + class Foo(pg.Object): + x: int + + class Bar(pg.Object): + y: str + + class A(pg.Object): + foo: tuple[Foo, int] + + class X(pg.Object): + k: int + + class B(A): + bar: Bar + foo2: Foo | X + + self.assertEqual(schema_lib.class_dependencies(Foo), [Foo]) + + self.assertEqual( + schema_lib.class_dependencies((A,), include_subclasses=False), [Foo, A] + ) + + self.assertEqual( + schema_lib.class_dependencies(A, include_subclasses=True), + [Foo, A, Bar, X, B], + ) + + self.assertEqual( + schema_lib.class_dependencies(schema_lib.Schema(A)), [Foo, A, Bar, X, B] + ) + + self.assertEqual( + schema_lib.class_dependencies(pg.typing.Object(A)), [Foo, A, Bar, X, B] + ) + + with self.assertRaisesRegex(TypeError, 'Unsupported spec type'): + schema_lib.class_dependencies((Foo, 1)) + + def test_class_dependencies_from_value(self): + class Foo(pg.Object): + x: int + + class Bar(pg.Object): + y: str + + class A(pg.Object): + foo: tuple[Foo, int] + + class X(pg.Object): + k: int + + class B(A): + bar: Bar + foo2: Foo | X + + a = A(foo=(Foo(1), 0)) + self.assertEqual(schema_lib.class_dependencies(a), [Foo, A, Bar, X, B]) + + self.assertEqual(schema_lib.class_dependencies(1), []) + + class SchemaPythonReprTest(unittest.TestCase): def assert_annotation( @@ -169,7 +233,7 @@ def assert_annotation( strict: bool = False, ) -> None: self.assertEqual( - schema_lib.SchemaPythonRepr().annotate(value_spec, strict=strict), + schema_lib.annotation(value_spec, strict=strict), expected_annotation, ) @@ -300,13 +364,13 @@ def 
test_annotation(self): self.assert_annotation(pg.typing.Any(), 'Any') self.assert_annotation(pg.typing.Any().noneable(), 'Any') - def test_class_def(self): + def test_class_definition(self): self.assertEqual( - schema_lib.SchemaPythonRepr().class_def(Activity), + schema_lib.class_definition(Activity), 'class Activity:\n description: str\n', ) self.assertEqual( - schema_lib.SchemaPythonRepr().class_def(Itinerary), + schema_lib.class_definition(Itinerary), inspect.cleandoc(""" class Itinerary: day: int(min=1) @@ -320,7 +384,7 @@ class A(pg.Object): pass self.assertEqual( - schema_lib.SchemaPythonRepr().class_def(A), + schema_lib.class_definition(A), 'class A:\n pass\n', ) @@ -329,7 +393,7 @@ class B: with self.assertRaisesRegex( TypeError, 'Classes must be `pg.Object` subclasses.*'): - schema_lib.SchemaPythonRepr().class_def(B) + schema_lib.class_definition(B) class C(pg.Object): x: str @@ -337,7 +401,7 @@ class C(pg.Object): with self.assertRaisesRegex( TypeError, 'Variable-length keyword arguments is not supported'): - schema_lib.SchemaPythonRepr().class_def(C) + schema_lib.class_definition(C) def test_repr(self): class Foo(pg.Object): @@ -551,5 +615,77 @@ def test_value_repr(self): schema_lib.value_repr('text') +class MissingTest(unittest.TestCase): + + def test_basics(self): + a = Itinerary( + day=1, + type=schema_lib.Missing(), + activities=schema_lib.Missing(), + hotel=schema_lib.Missing(), + ) + self.assertFalse(a.is_partial) + self.assertEqual(str(schema_lib.Missing()), 'MISSING') + self.assertEqual(str(a.type), "MISSING(Literal['daytime', 'nighttime'])") + self.assertEqual(str(a.activities), 'MISSING(list[Activity])') + self.assertEqual(str(a.hotel), "MISSING(str(regex='.*Hotel') | None)") + + def assert_missing(self, value, expected_missing): + value = schema_lib.mark_missing(value) + self.assertEqual(schema_lib.Missing.find_missing(value), expected_missing) + + def test_find_missing(self): + self.assert_missing( + Itinerary.partial(), + { + 'day': schema_lib.MISSING, + 'type': schema_lib.MISSING, + 'activities': schema_lib.MISSING, + }, + ) + + self.assert_missing( + Itinerary.partial( + day=1, type='daytime', activities=[Activity.partial()] + ), + { + 'activities[0].description': schema_lib.MISSING, + }, + ) + + def test_mark_missing(self): + class A(pg.Object): + x: typing.Any + + self.assertEqual(schema_lib.mark_missing(1), 1) + self.assertEqual( + schema_lib.mark_missing(pg.MISSING_VALUE), pg.MISSING_VALUE + ) + self.assertEqual( + schema_lib.mark_missing(A.partial(A.partial(A.partial()))), + A(A(A(schema_lib.MISSING))), + ) + self.assertEqual( + schema_lib.mark_missing(dict(a=A.partial())), + dict(a=A(schema_lib.MISSING)), + ) + self.assertEqual( + schema_lib.mark_missing([1, dict(a=A.partial())]), + [1, dict(a=A(schema_lib.MISSING))], + ) + + +class UnknownTest(unittest.TestCase): + + def test_basics(self): + class A(pg.Object): + x: int + + a = A(x=schema_lib.Unknown()) + self.assertFalse(a.is_partial) + self.assertEqual(a.x, schema_lib.UNKNOWN) + self.assertEqual(schema_lib.UNKNOWN, schema_lib.Unknown()) + + if __name__ == '__main__': unittest.main() diff --git a/langfun/core/structured/structure2structure.py b/langfun/core/structured/structure2structure.py new file mode 100644 index 0000000..00ff514 --- /dev/null +++ b/langfun/core/structured/structure2structure.py @@ -0,0 +1,244 @@ +# Copyright 2023 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Structure-to-structure mappings."""
+
+from typing import Annotated, Any, Literal, Type
+
+import langfun.core as lf
+from langfun.core.structured import mapping
+from langfun.core.structured import schema as schema_lib
+import pyglove as pg
+
+
+class Pair(pg.Object):
+  """Value pair used for expressing structure-to-structure mapping."""
+
+  left: pg.typing.Annotated[
+      pg.typing.Any(transform=schema_lib.mark_missing), 'The left-side value.'
+  ]
+  right: pg.typing.Annotated[
+      pg.typing.Any(transform=schema_lib.mark_missing), 'The right-side value.'
+  ]
+
+
+class StructureToStructure(mapping.Mapping):
+  """Base class for structure-to-structure mapping.
+
+  {{ preamble }}
+
+  {% if examples -%}
+  {% for example in examples -%}
+  {{ input_value_title }}:
+  {{ value_str(example.value.left) | indent(2, True) }}
+
+  {%- if missing_type_dependencies(example.value) %}
+
+  {{ type_definitions_title }}:
+  {{ type_definitions_str(example.value) | indent(2, True) }}
+  {%- endif %}
+
+  {{ output_value_title }}:
+  {{ value_str(example.value.right) | indent(2, True) }}
+
+  {% endfor %}
+  {% endif -%}
+  {{ input_value_title }}:
+  {{ value_str(input_value) | indent(2, True) }}
+  {%- if missing_type_dependencies(input_value) %}
+
+  {{ type_definitions_title }}:
+  {{ type_definitions_str(input_value) | indent(2, True) }}
+  {%- endif %}
+
+  {{ output_value_title }}:
+  """
+
+  default: Annotated[
+      Any,
+      (
+          'The default value to use if mapping failed. '
+          'If unspecified, an error will be raised.'
+      ),
+  ] = lf.message_transform.RAISE_IF_HAS_ERROR
+
+  preamble: Annotated[
+      lf.LangFunc,
+      'Preamble used for structure-to-structure mapping.',
+  ]
+
+  type_definitions_title: Annotated[
+      str, 'The section title for type definitions.'
+  ] = 'CLASS_DEFINITIONS'
+
+  input_value_title: Annotated[str, 'The section title for input value.']
+  output_value_title: Annotated[str, 'The section title for output value.']
+
+  def _on_bound(self):
+    super()._on_bound()
+    if self.examples:
+      for example in self.examples:
+        if not isinstance(example.value, Pair):
+          raise ValueError(
+              'The value of example must be a `lf.structured.Pair` object. '
+              f'Encountered: { example.value }.'
+          )
+
+  @property
+  def input_value(self) -> Any:
+    return schema_lib.mark_missing(self.message.result)
+
+  def value_str(self, value: Any) -> str:
+    return schema_lib.value_repr('python').repr(value, compact=False)
+
+  def missing_type_dependencies(self, value: Any) -> list[Type[Any]]:
+    value_specs = tuple(
+        [v.value_spec for v in schema_lib.Missing.find_missing(value).values()]
+    )
+    return schema_lib.class_dependencies(value_specs, include_subclasses=True)
+
+  def type_definitions_str(self, value: Any) -> str | None:
+    return schema_lib.class_definitions(
+        self.missing_type_dependencies(value), markdown=True
+    )
+
+  def _value_context(self):
+    classes = schema_lib.class_dependencies(self.input_value)
+    return {cls.__name__: cls for cls in classes}
+
+  def transform_output(self, lm_output: lf.Message) -> lf.Message:
+    try:
+      result = schema_lib.value_repr('python').parse(
+          lm_output.text, additional_context=self._value_context()
+      )
+    except Exception as e:  # pylint: disable=broad-exception-caught
+      if self.default == lf.message_transform.RAISE_IF_HAS_ERROR:
+        raise mapping.MappingError(
+            'Cannot parse message text into structured output. '
+            f'Error={e}. Text={lm_output.text!r}.'
+        ) from e
+      result = self.default
+    lm_output.result = result
+    return lm_output
+
+
+class CompleteStructure(StructureToStructure):
+  """Complete structure by filling the missing fields."""
+
+  preamble = lf.LangFunc("""
+      Please generate the OUTPUT_OBJECT by completing the MISSING fields from the INPUT_OBJECT.
+
+      INSTRUCTIONS:
+      1. Each MISSING field contains a Python annotation, please fill the value based on the annotation.
+      2. Classes for the MISSING fields are defined under CLASS_DEFINITIONS.
+      3. If a MISSING field could not be filled, mark it as `UNKNOWN`.
+      """)
+
+  # NOTE(daiyip): Set the input path of the transform to root, so this transform
+  # could access the input via the `message.result` field.
+  input_path = ''
+
+  input_value_title = 'INPUT_OBJECT'
+  output_value_title = 'OUTPUT_OBJECT'
+
+
+class _Date(pg.Object):
+  year: int
+  month: int
+  day: int
+
+
+class _Country(pg.Object):
+  """Country."""
+
+  name: str
+  founding_date: _Date
+  continent: Literal[
+      'Africa', 'Asia', 'Europe', 'Oceania', 'North America', 'South America'
+  ]
+  population: int
+  hobby: str
+
+
+def _default_complete_examples() -> list[mapping.MappingExample]:
+  return [
+      mapping.MappingExample(
+          value=Pair(
+              left=_Country.partial(name='United States of America'),
+              right=_Country(
+                  name='United States of America',
+                  founding_date=_Date(year=1776, month=7, day=4),
+                  continent='North America',
+                  population=33_000_000,
+                  hobby=schema_lib.UNKNOWN,
+              ),
+          )
+      )
+  ]
+
+
+def complete(
+    value: pg.Symbolic,
+    default: Any = lf.message_transform.RAISE_IF_HAS_ERROR,
+    *,
+    examples: list[mapping.MappingExample] | None = None,
+    **kwargs,
+) -> Any:
+  """Complete a symbolic value by filling its missing fields.
+
+  Examples:
+
+    ```
+    class FlightDuration(pg.Object):
+      hours: int
+      minutes: int
+
+    class Flight(pg.Object):
+      airline: str
+      flight_number: str
+      departure_airport_code: str
+      arrival_airport_code: str
+      departure_time: str
+      arrival_time: str
+      duration: FlightDuration
+      stops: int
+      price: float
+
+    r = lf.complete(
+        Flight.partial(airline='United Airlines', flight_number='UA2631')
+    )
+    assert isinstance(r, Flight)
+    assert r.departure_airport_code == 'SFO'
+    assert r.duration.hours == 7
+    ```
+
+  Args:
+    value: A symbolic value that may contain missing values.
+ default: The default value if parsing failed. If not specified, error will + be raised. + examples: An optional list of fewshot examples for helping parsing. If None, + the default one-shot example will be added. + **kwargs: Keyword arguments passed to the + `lf.structured.NaturalLanguageToStructureed` transform, e.g. `lm` for + specifying the language model for structured parsing. + + Returns: + The result based on the schema. + """ + if examples is None: + examples = _default_complete_examples() + t = CompleteStructure(default=default, examples=examples, **kwargs) + return t.transform(message=lf.UserMessage(text='', result=value)).result diff --git a/langfun/core/structured/structure2structure_test.py b/langfun/core/structured/structure2structure_test.py new file mode 100644 index 0000000..0134887 --- /dev/null +++ b/langfun/core/structured/structure2structure_test.py @@ -0,0 +1,440 @@ +# Copyright 2023 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for langfun.core.structured.structure2structure.""" + +import inspect +import unittest +import langfun.core as lf +from langfun.core.llms import fake +from langfun.core.structured import mapping +from langfun.core.structured import schema as schema_lib +from langfun.core.structured import structure2structure +import pyglove as pg + + +class Activity(pg.Object): + description: str + + +class Itinerary(pg.Object): + day: pg.typing.Int[1, None] + type: pg.typing.Enum['daytime', 'nighttime'] + activities: list[Activity] + hotel: pg.typing.Str['.*Hotel'] | None + + +class TripPlan(pg.Object): + place: str + itineraries: list[Itinerary] + + +class PairTest(unittest.TestCase): + + def test_partial(self): + p = structure2structure.Pair( + TripPlan.partial(place='San Francisco'), + TripPlan.partial(itineraries=[Itinerary.partial(day=1)]), + ) + self.assertEqual(p.left.itineraries, schema_lib.MISSING) + self.assertEqual(p.right.place, schema_lib.MISSING) + self.assertEqual(p.right.itineraries[0].activities, schema_lib.MISSING) + + +class CompleteStructureTest(unittest.TestCase): + + def test_render_no_examples(self): + l = structure2structure.CompleteStructure() + m = lf.UserMessage( + '', + result=TripPlan.partial( + place='San Francisco', + itineraries=[ + Itinerary.partial(day=1), + Itinerary.partial(day=2), + Itinerary.partial(day=3), + ], + ), + ) + + self.assertEqual( + l.render(message=m).text, + inspect.cleandoc(""" + Please generate the OUTPUT_OBJECT by completing the MISSING fields from the INPUT_OBJECT. + + INSTRUCTIONS: + 1. Each MISSING field contains a Python annotation, please fill the value based on the annotation. + 2. Classes for the MISSING fields are defined under CLASS_DEFINITIONS. + 3. If a MISSING field could not be filled, mark it as `UNKNOWN`. 
+ + INPUT_OBJECT: + ```python + TripPlan( + place='San Francisco', + itineraries=[ + Itinerary( + day=1, + type=MISSING(Literal['daytime', 'nighttime']), + activities=MISSING(list[Activity]), + hotel=None + ), + Itinerary( + day=2, + type=MISSING(Literal['daytime', 'nighttime']), + activities=MISSING(list[Activity]), + hotel=None + ), + Itinerary( + day=3, + type=MISSING(Literal['daytime', 'nighttime']), + activities=MISSING(list[Activity]), + hotel=None + ) + ] + ) + ``` + + CLASS_DEFINITIONS: + ```python + class Activity: + description: str + ``` + + OUTPUT_OBJECT: + """), + ) + + def test_render_no_class_definitions(self): + l = structure2structure.CompleteStructure() + m = lf.UserMessage( + '', + result=TripPlan.partial( + place='San Francisco', + itineraries=[ + Itinerary.partial(day=1, activities=[Activity.partial()]), + Itinerary.partial(day=2, activities=[Activity.partial()]), + Itinerary.partial(day=3, activities=[Activity.partial()]), + ], + ), + ) + + self.assertEqual( + l.render(message=m).text, + inspect.cleandoc(""" + Please generate the OUTPUT_OBJECT by completing the MISSING fields from the INPUT_OBJECT. + + INSTRUCTIONS: + 1. Each MISSING field contains a Python annotation, please fill the value based on the annotation. + 2. Classes for the MISSING fields are defined under CLASS_DEFINITIONS. + 3. If a MISSING field could not be filled, mark it as `UNKNOWN`. + + INPUT_OBJECT: + ```python + TripPlan( + place='San Francisco', + itineraries=[ + Itinerary( + day=1, + type=MISSING(Literal['daytime', 'nighttime']), + activities=[ + Activity( + description=MISSING(str) + ) + ], + hotel=None + ), + Itinerary( + day=2, + type=MISSING(Literal['daytime', 'nighttime']), + activities=[ + Activity( + description=MISSING(str) + ) + ], + hotel=None + ), + Itinerary( + day=3, + type=MISSING(Literal['daytime', 'nighttime']), + activities=[ + Activity( + description=MISSING(str) + ) + ], + hotel=None + ) + ] + ) + ``` + + OUTPUT_OBJECT: + """), + ) + + def test_render_with_examples(self): + l = structure2structure.CompleteStructure( + examples=structure2structure._default_complete_examples() + ) + m = lf.UserMessage( + '', + result=TripPlan.partial( + place='San Francisco', + itineraries=[ + Itinerary.partial(day=1), + Itinerary.partial(day=2), + Itinerary.partial(day=3), + ], + ), + ) + + self.assertEqual( + l.render(message=m).text, + inspect.cleandoc(""" + Please generate the OUTPUT_OBJECT by completing the MISSING fields from the INPUT_OBJECT. + + INSTRUCTIONS: + 1. Each MISSING field contains a Python annotation, please fill the value based on the annotation. + 2. Classes for the MISSING fields are defined under CLASS_DEFINITIONS. + 3. If a MISSING field could not be filled, mark it as `UNKNOWN`. 
+ + INPUT_OBJECT: + ```python + _Country( + name='United States of America', + founding_date=MISSING(_Date), + continent=MISSING(Literal['Africa', 'Asia', 'Europe', 'Oceania', 'North America', 'South America']), + population=MISSING(int), + hobby=MISSING(str) + ) + ``` + + CLASS_DEFINITIONS: + ```python + class _Date: + year: int + month: int + day: int + ``` + + OUTPUT_OBJECT: + ```python + _Country( + name='United States of America', + founding_date=_Date( + year=1776, + month=7, + day=4 + ), + continent='North America', + population=33000000, + hobby=UNKNOWN + ) + ``` + + + INPUT_OBJECT: + ```python + TripPlan( + place='San Francisco', + itineraries=[ + Itinerary( + day=1, + type=MISSING(Literal['daytime', 'nighttime']), + activities=MISSING(list[Activity]), + hotel=None + ), + Itinerary( + day=2, + type=MISSING(Literal['daytime', 'nighttime']), + activities=MISSING(list[Activity]), + hotel=None + ), + Itinerary( + day=3, + type=MISSING(Literal['daytime', 'nighttime']), + activities=MISSING(list[Activity]), + hotel=None + ) + ] + ) + ``` + + CLASS_DEFINITIONS: + ```python + class Activity: + description: str + ``` + + OUTPUT_OBJECT: + """), + ) + + def test_transform(self): + structured_response = inspect.cleandoc(""" + ```python + TripPlan( + place='San Francisco', + itineraries=[ + Itinerary( + day=1, + type='daytime', + activities=[ + Activity(description='Arrive in San Francisco and check into your hotel.'), + Activity(description='Take a walk around Fisherman\\'s Wharf and have dinner at one of the many seafood restaurants.'), + Activity(description='Visit Pier 39 and see the sea lions.'), + ], + hotel=None), + Itinerary( + day=2, + type='daytime', + activities=[ + Activity(description='Take a ferry to Alcatraz Island and tour the infamous prison.'), + Activity(description='Take a walk across the Golden Gate Bridge.'), + Activity(description='Visit the Japanese Tea Garden in Golden Gate Park.'), + ], + hotel=None), + Itinerary( + day=3, + type='daytime', + activities=[ + Activity(description='Visit the de Young Museum and see the collection of American art.'), + Activity(description='Visit the San Francisco Museum of Modern Art.'), + Activity(description='Take a cable car ride.'), + ], + hotel=None), + ] + ) + ``` + """) + + with lf.context( + lm=fake.StaticSequence( + [structured_response], + ), + override_attrs=True, + ): + r = structure2structure.complete( + TripPlan.partial( + place='San Francisco', + itineraries=[ + Itinerary.partial(day=1), + Itinerary.partial(day=2), + Itinerary.partial(day=3), + ], + ) + ) + # pylint: disable=line-too-long + self.assertEqual( + r, + TripPlan( + place='San Francisco', + itineraries=[ + Itinerary( + day=1, + type='daytime', + activities=[ + Activity( + description=( + 'Arrive in San Francisco and check into your' + ' hotel.' + ) + ), + Activity( + description=( + "Take a walk around Fisherman's Wharf and" + ' have dinner at one of the many seafood' + ' restaurants.' + ) + ), + Activity( + description='Visit Pier 39 and see the sea lions.' + ), + ], + hotel=None, + ), + Itinerary( + day=2, + type='daytime', + activities=[ + Activity( + description=( + 'Take a ferry to Alcatraz Island and tour the' + ' infamous prison.' + ) + ), + Activity( + description=( + 'Take a walk across the Golden Gate Bridge.' + ) + ), + Activity( + description=( + 'Visit the Japanese Tea Garden in Golden Gate' + ' Park.' 
+ ) + ), + ], + hotel=None, + ), + Itinerary( + day=3, + type='daytime', + activities=[ + Activity( + description=( + 'Visit the de Young Museum and see the' + ' collection of American art.' + ) + ), + Activity( + description=( + 'Visit the San Francisco Museum of Modern' + ' Art.' + ) + ), + Activity(description='Take a cable car ride.'), + ], + hotel=None, + ), + ], + ), + ) + # pylint: enable=line-too-long + + def test_bad_init(self): + with self.assertRaisesRegex(ValueError, '.*must be.*Pair'): + structure2structure.CompleteStructure( + examples=[mapping.MappingExample(value=1)] + ) + + def test_bad_transform(self): + with lf.context( + lm=fake.StaticSequence(['Activity(description=1)']), + override_attrs=True, + ): + with self.assertRaisesRegex( + mapping.MappingError, + 'Cannot parse message text into structured output', + ): + structure2structure.complete(Activity.partial()) + + def test_default(self): + with lf.context( + lm=fake.StaticSequence(['Activity(description=1)']), + override_attrs=True, + ): + self.assertIsNone(structure2structure.complete(Activity.partial(), None)) + + +if __name__ == '__main__': + unittest.main()
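
A usage sketch for the newly exported `complete` API, following the same `fake.StaticSequence` pattern the tests above use so no real model is required; the `Activity` class and the canned response are illustrative only:

```python
import langfun as lf
from langfun.core.llms import fake
import pyglove as pg


class Activity(pg.Object):
  description: str


# A canned LM response in the Python protocol; CompleteStructure parses it and
# returns the parsed object as the completed value.
canned_response = "```python\nActivity(description='Hike Twin Peaks.')\n```"

with lf.context(
    lm=fake.StaticSequence([canned_response]),
    override_attrs=True,
):
  completed = lf.complete(Activity.partial())

assert completed == Activity(description='Hike Twin Peaks.')
```
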