Skip to content

Commit

Permalink
Use 'python' as the default protocol for structured mapping.
Browse files Browse the repository at this point in the history
This will impact `as_structured`/`lf.query`/`lf.parse`.

PiperOrigin-RevId: 566808741
  • Loading branch information
daiyip authored and langfun authors committed Sep 20, 2023
1 parent 26e4f49 commit 1ddea46
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
12 changes: 6 additions & 6 deletions langfun/core/structured/nl2structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def parse(
*,
user_prompt: str | None = None,
examples: list[mapping.MappingExample] | None = None,
protocol: schema_lib.SchemaProtocol = 'json',
protocol: schema_lib.SchemaProtocol = 'python',
**kwargs,
) -> Any:
"""Parse a natural langugage message based on schema.
Expand Down Expand Up @@ -271,7 +271,7 @@ class Flight(pg.Object):
examples: An optional list of fewshot examples for helping parsing. If None,
the default one-shot example will be added.
protocol: The protocol for schema/value representation. Applicable values
are 'json' and 'python'. By default the JSON representation will be used.
are 'json' and 'python'. By default 'python' will be used.`
**kwargs: Keyword arguments passed to the `lf.structured.ParseStructure`
transform, e.g. `lm` for specifying the language model.
Expand Down Expand Up @@ -309,7 +309,7 @@ def as_structured(
default: Any = lf.message_transform.RAISE_IF_HAS_ERROR,
examples: list[mapping.Mapping] | None = None,
*,
protocol: schema_lib.SchemaProtocol = 'json',
protocol: schema_lib.SchemaProtocol = 'python',
**kwargs,
):
"""Returns the structured representation of the message text.
Expand All @@ -323,7 +323,7 @@ def as_structured(
examples: An optional list of fewshot examples for helping parsing. If None,
the default one-shot example will be added.
protocol: The protocol for schema/value representation. Applicable values
are 'json' and 'python'. By default the JSON representation will be used.
are 'json' and 'python'. By default 'python' will be used.
**kwargs: Additional keyword arguments that will be passed to
`lf.structured.NaturalLanguageToStructure`.
Expand Down Expand Up @@ -483,7 +483,7 @@ def query(
default: Any = lf.message_transform.RAISE_IF_HAS_ERROR,
*,
examples: list[mapping.MappingExample] | None = None,
protocol: schema_lib.SchemaProtocol = 'json',
protocol: schema_lib.SchemaProtocol = 'python',
**kwargs,
) -> Any:
"""Parse a natural langugage message based on schema.
Expand Down Expand Up @@ -525,7 +525,7 @@ class Flight(pg.Object):
examples: An optional list of fewshot examples for helping parsing. If None,
the default one-shot example will be added.
protocol: The protocol for schema/value representation. Applicable values
are 'json' and 'python'. By default the JSON representation will be used.
are 'json' and 'python'. By default `python` will be used.
**kwargs: Keyword arguments passed to the
`lf.structured.NaturalLanguageToStructureed` transform, e.g. `lm` for
specifying the language model for structured parsing.
Expand Down
26 changes: 14 additions & 12 deletions langfun/core/structured/nl2structure_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,12 @@ def test_bad_transform(self):
def test_parse(self):
lm = fake.StaticSequence(['1'])
self.assertEqual(
nl2structure.parse('the answer is 1', int, lm=lm, protocol='python'),
nl2structure.parse('the answer is 1', int, lm=lm),
1
)
self.assertEqual(
nl2structure.parse(
'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
protocol='python',
),
'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm),
1,
)

Expand Down Expand Up @@ -506,7 +504,7 @@ def test_transform(self):

def test_bad_transform(self):
with lf.context(
lm=fake.StaticSequence(['three', '3']),
lm=fake.StaticSequence(['three', '`3`']),
override_attrs=True,
):
with self.assertRaisesRegex(
Expand All @@ -517,10 +515,14 @@ def test_bad_transform(self):

def test_parse(self):
lm = fake.StaticSequence(['{"result": 1}'])
self.assertEqual(nl2structure.parse('the answer is 1', int, lm=lm), 1)
self.assertEqual(
nl2structure.parse('the answer is 1', int, lm=lm, protocol='json'),
1
)
self.assertEqual(
nl2structure.parse(
'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm
'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
protocol='json',
),
1,
)
Expand Down Expand Up @@ -673,12 +675,11 @@ def test_bad_transform(self):
mapping.MappingError,
'Cannot parse message text into structured output',
):
nl2structure.query('Compute 1 + 2', int, protocol='python')
nl2structure.query('Compute 1 + 2', int)

def test_query(self):
lm = fake.StaticSequence(['1'])
self.assertEqual(
nl2structure.query('what is 1 + 0', int, lm=lm, protocol='python'), 1)
self.assertEqual(nl2structure.query('what is 1 + 0', int, lm=lm), 1)


class QueryStructureJsonTest(unittest.TestCase):
Expand Down Expand Up @@ -870,11 +871,12 @@ def test_bad_transform(self):
mapping.MappingError,
'Cannot parse message text into structured output',
):
nl2structure.query('Compute 1 + 2', int)
nl2structure.query('Compute 1 + 2', int, protocol='json')

def test_query(self):
lm = fake.StaticSequence(['{"result": 1}'])
self.assertEqual(nl2structure.query('what is 1 + 0', int, lm=lm), 1)
self.assertEqual(
nl2structure.query('what is 1 + 0', int, lm=lm, protocol='json'), 1)


if __name__ == '__main__':
Expand Down

0 comments on commit 1ddea46

Please sign in to comment.