Skip to content

Commit

Permalink
Introducelf.structured.complete with enhanced Python schema support.
Browse files Browse the repository at this point in the history
Usage:

```python
class Person:
  name: str
  birth_place: str
  birth_year: int

lf.complete(Person.partial('Larry Page'), lm=lf.llms.Gpt35())
```

This CL also:
- Update the default LM timeout from 30 to 120 seconds.
- `lf.MISSING` and `lf.UNKNOWN` for representing missing and unknown fields.

PiperOrigin-RevId: 566746300
  • Loading branch information
daiyip authored and langfun authors committed Sep 19, 2023
1 parent 03643b4 commit 0d572d0
Show file tree
Hide file tree
Showing 12 changed files with 1,239 additions and 255 deletions.
6 changes: 5 additions & 1 deletion langfun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@
from langfun.core import *
from langfun.core import structured

Schema = structured.Schema
MISSING = structured.MISSING
UNKNOWN = structured.UNKNOWN

parse = structured.parse
query = structured.query
describe = structured.describe

complete = structured.complete

from langfun.core import templates
from langfun.core import transforms
Expand Down
2 changes: 1 addition & 1 deletion langfun/core/langfunc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_call(self):
"LangFunc(template_str='Hello', clean=True, returns=None, "
'lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0, '
'max_tokens=1024, n=1, top_k=40, top_p=None, random_seed=None), '
'timeout=30.0, max_attempts=5, debug=False), input_transform=None, '
'timeout=120.0, max_attempts=5, debug=False), input_transform=None, '
'output_transform=None)',
)

Expand Down
2 changes: 1 addition & 1 deletion langfun/core/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class LanguageModel(component.Component):

timeout: Annotated[
float | None, 'Timeout in seconds. If None, there is no timeout.'
] = 30.0
] = 120.0

max_attempts: Annotated[
int,
Expand Down
19 changes: 19 additions & 0 deletions langfun/core/structured/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,20 @@
# pylint: disable=g-bad-import-order
# pylint: disable=g-importing-member

from langfun.core.structured.schema import Missing
from langfun.core.structured.schema import MISSING
from langfun.core.structured.schema import Unknown
from langfun.core.structured.schema import UNKNOWN

from langfun.core.structured.schema import Schema
from langfun.core.structured.schema import SchemaProtocol
from langfun.core.structured.schema import schema_spec

from langfun.core.structured.schema import class_dependencies
from langfun.core.structured.schema import class_definition
from langfun.core.structured.schema import class_definitions
from langfun.core.structured.schema import annotation
from langfun.core.structured.schema import structure_from_python

from langfun.core.structured.schema import SchemaRepr
from langfun.core.structured.schema import SchemaJsonRepr
Expand All @@ -27,6 +40,7 @@
from langfun.core.structured.schema import schema_repr
from langfun.core.structured.schema import value_repr


from langfun.core.structured.mapping import Mapping
from langfun.core.structured.mapping import MappingExample
from langfun.core.structured.mapping import MappingError
Expand All @@ -45,5 +59,10 @@
from langfun.core.structured.structure2nl import DescribeStructure
from langfun.core.structured.structure2nl import describe

from langfun.core.structured.structure2structure import StructureToStructure
from langfun.core.structured.structure2structure import CompleteStructure
from langfun.core.structured.structure2structure import complete


# pylint: enable=g-importing-member
# pylint: enable=g-bad-import-order
25 changes: 10 additions & 15 deletions langfun/core/structured/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Mapping interfaces."""

import io
from typing import Annotated, Any, Literal
from typing import Annotated
import langfun.core as lf
from langfun.core.structured import schema as schema_lib
import pyglove as pg
Expand All @@ -36,9 +36,6 @@ def __ne__(self, other):
class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
"""Mapping example between text, schema and structured value."""

# Value marker for missing value in Mapping.
MISSING_VALUE = (pg.MISSING_VALUE,)

nl_context: Annotated[
str | None,
(
Expand Down Expand Up @@ -78,29 +75,27 @@ class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
),
] = lf.contextual(default=None)

value: Annotated[
Any,
value: pg.typing.Annotated[
pg.typing.Any(transform=schema_lib.mark_missing),
(
'The structured representation for `nl_text` (or directly prompted '
'from `nl_context`). It should align with schema.'
'`value` could be used as input when we map a structured value to '
'natural language, or as output when we map it reversely.'
)
] = MISSING_VALUE
),
] = schema_lib.MISSING

def schema_str(
self,
protocol: Literal['json', 'python'] = 'json',
**kwargs) -> str:
self, protocol: schema_lib.SchemaProtocol = 'json', **kwargs
) -> str:
"""Returns the string representation of schema based on protocol."""
if self.schema is None:
return ''
return self.schema.schema_str(protocol, **kwargs)

def value_str(
self,
protocol: Literal['json', 'python'] = 'json',
**kwargs) -> str:
self, protocol: schema_lib.SchemaProtocol = 'json', **kwargs
) -> str:
"""Returns the string representation of value based on protocol."""
return schema_lib.value_repr(
protocol).repr(self.value, self.schema, **kwargs)
Expand All @@ -122,7 +117,7 @@ def natural_language_format(self) -> str:
result.write(lf.colored(self.schema_str(), color='red'))
result.write('\n\n')

if MappingExample.MISSING_VALUE != self.value:
if schema_lib.MISSING != self.value:
result.write(lf.colored('[VALUE]\n', styles=['bold']))
result.write(lf.colored(self.value_str(), color='blue'))
return result.getvalue().strip()
Expand Down
12 changes: 6 additions & 6 deletions langfun/core/structured/mapping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest

from langfun.core.structured import mapping
from langfun.core.structured import schema as schema_lib
import pyglove as pg


Expand Down Expand Up @@ -116,11 +117,11 @@ def test_str_no_schema(self):

def test_str_no_value(self):
self.assertEqual(
str(mapping.MappingExample(
'Give the answer.',
'1 + 1 = 2',
mapping.MappingExample.MISSING_VALUE,
int)),
str(
mapping.MappingExample(
'Give the answer.', '1 + 1 = 2', schema_lib.MISSING, int
)
),
inspect.cleandoc("""
\x1b[1m[CONTEXT]
\x1b[0m\x1b[35mGive the answer.\x1b[0m
Expand All @@ -144,4 +145,3 @@ def test_serialization(self):

if __name__ == '__main__':
unittest.main()

19 changes: 9 additions & 10 deletions langfun/core/structured/nl2structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class NaturalLanguageToStructure(mapping.Mapping):

preamble: Annotated[
lf.LangFunc,
'Preamble used for zeroshot jsonify.',
'Preamble used for natural language-to-structure mapping.',
]

nl_context_title: Annotated[
Expand All @@ -98,8 +98,8 @@ class NaturalLanguageToStructure(mapping.Mapping):
value_title: Annotated[str, 'The section title for schema.']

protocol: Annotated[
Literal['json', 'python'],
'The protocol for representing the schema and value.'
schema_lib.SchemaProtocol,
'The protocol for representing the schema and value.',
]

@property
Expand Down Expand Up @@ -223,7 +223,7 @@ def parse(
*,
user_prompt: str | None = None,
examples: list[mapping.MappingExample] | None = None,
protocol: Literal['json', 'python'] = 'json',
protocol: schema_lib.SchemaProtocol = 'json',
**kwargs,
) -> Any:
"""Parse a natural langugage message based on schema.
Expand Down Expand Up @@ -293,7 +293,7 @@ class Flight(pg.Object):


def _parse_structure_cls(
protocol: Literal['json', 'python']
protocol: schema_lib.SchemaProtocol,
) -> Type[ParseStructure]:
if protocol == 'json':
return ParseStructureJson
Expand All @@ -309,7 +309,7 @@ def as_structured(
default: Any = lf.message_transform.RAISE_IF_HAS_ERROR,
examples: list[mapping.Mapping] | None = None,
*,
protocol: Literal['json', 'python'] = 'json',
protocol: schema_lib.SchemaProtocol = 'json',
**kwargs,
):
"""Returns the structured representation of the message text.
Expand Down Expand Up @@ -441,7 +441,7 @@ class QueryStructurePython(QueryStructure):


def _query_structure_cls(
protocol: Literal['json', 'python']
protocol: schema_lib.SchemaProtocol,
) -> Type[QueryStructure]:
if protocol == 'json':
return QueryStructureJson
Expand Down Expand Up @@ -483,7 +483,7 @@ def query(
default: Any = lf.message_transform.RAISE_IF_HAS_ERROR,
*,
examples: list[mapping.MappingExample] | None = None,
protocol: Literal['json', 'python'] = 'json',
protocol: schema_lib.SchemaProtocol = 'json',
**kwargs,
) -> Any:
"""Parse a natural langugage message based on schema.
Expand Down Expand Up @@ -528,8 +528,7 @@ class Flight(pg.Object):
are 'json' and 'python'. By default the JSON representation will be used.
**kwargs: Keyword arguments passed to the
`lf.structured.NaturalLanguageToStructureed` transform, e.g. `lm` for
specifying the language model for structured parsing, `jsonify` for
customizing the prompt for structured parsing, and etc.
specifying the language model for structured parsing.
Returns:
The result based on the schema.
Expand Down
1 change: 0 additions & 1 deletion langfun/core/structured/nl2structure_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def test_render_no_examples(self):
m = lf.AIMessage('Bla bla bla 12 / 6 + 2 = 4.', result='12 / 6 + 2 = 4')
m.source = lf.UserMessage('Compute 12 / 6 + 2.', tags=['lm-input'])

print(l.render(message=m).text)
self.assertEqual(
l.render(message=m).text,
inspect.cleandoc("""
Expand Down
Loading

0 comments on commit 0d572d0

Please sign in to comment.