From 1ddea46f6791ebc24b4e73c28cbef89d6b6cde86 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Tue, 19 Sep 2023 18:27:39 -0700 Subject: [PATCH] Use 'python' as the default protocol for structured mapping. This will impact `as_structured`/`lf.query`/`lf.parse`. PiperOrigin-RevId: 566808741 --- langfun/core/structured/nl2structure.py | 12 ++++----- langfun/core/structured/nl2structure_test.py | 26 +++++++++++--------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/langfun/core/structured/nl2structure.py b/langfun/core/structured/nl2structure.py index 20e3257..0d1b2c5 100644 --- a/langfun/core/structured/nl2structure.py +++ b/langfun/core/structured/nl2structure.py @@ -223,7 +223,7 @@ def parse( *, user_prompt: str | None = None, examples: list[mapping.MappingExample] | None = None, - protocol: schema_lib.SchemaProtocol = 'json', + protocol: schema_lib.SchemaProtocol = 'python', **kwargs, ) -> Any: """Parse a natural langugage message based on schema. @@ -271,7 +271,7 @@ class Flight(pg.Object): examples: An optional list of fewshot examples for helping parsing. If None, the default one-shot example will be added. protocol: The protocol for schema/value representation. Applicable values - are 'json' and 'python'. By default the JSON representation will be used. + are 'json' and 'python'. By default 'python' will be used. **kwargs: Keyword arguments passed to the `lf.structured.ParseStructure` transform, e.g. `lm` for specifying the language model. @@ -309,7 +309,7 @@ def as_structured( default: Any = lf.message_transform.RAISE_IF_HAS_ERROR, examples: list[mapping.Mapping] | None = None, *, - protocol: schema_lib.SchemaProtocol = 'json', + protocol: schema_lib.SchemaProtocol = 'python', **kwargs, ): """Returns the structured representation of the message text. @@ -323,7 +323,7 @@ def as_structured( examples: An optional list of fewshot examples for helping parsing. If None, the default one-shot example will be added. 
protocol: The protocol for schema/value representation. Applicable values - are 'json' and 'python'. By default the JSON representation will be used. + are 'json' and 'python'. By default 'python' will be used. **kwargs: Additional keyword arguments that will be passed to `lf.structured.NaturalLanguageToStructure`. @@ -483,7 +483,7 @@ def query( default: Any = lf.message_transform.RAISE_IF_HAS_ERROR, *, examples: list[mapping.MappingExample] | None = None, - protocol: schema_lib.SchemaProtocol = 'json', + protocol: schema_lib.SchemaProtocol = 'python', **kwargs, ) -> Any: """Parse a natural langugage message based on schema. @@ -525,7 +525,7 @@ class Flight(pg.Object): examples: An optional list of fewshot examples for helping parsing. If None, the default one-shot example will be added. protocol: The protocol for schema/value representation. Applicable values - are 'json' and 'python'. By default the JSON representation will be used. + are 'json' and 'python'. By default `python` will be used. **kwargs: Keyword arguments passed to the `lf.structured.NaturalLanguageToStructureed` transform, e.g. `lm` for specifying the language model for structured parsing. 
diff --git a/langfun/core/structured/nl2structure_test.py b/langfun/core/structured/nl2structure_test.py index 1f2aad2..8857fa1 100644 --- a/langfun/core/structured/nl2structure_test.py +++ b/langfun/core/structured/nl2structure_test.py @@ -249,14 +249,12 @@ def test_bad_transform(self): def test_parse(self): lm = fake.StaticSequence(['1']) self.assertEqual( - nl2structure.parse('the answer is 1', int, lm=lm, protocol='python'), + nl2structure.parse('the answer is 1', int, lm=lm), 1 ) self.assertEqual( nl2structure.parse( - 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm, - protocol='python', - ), + 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm), 1, ) @@ -506,7 +504,7 @@ def test_transform(self): def test_bad_transform(self): with lf.context( - lm=fake.StaticSequence(['three', '3']), + lm=fake.StaticSequence(['three', '`3`']), override_attrs=True, ): with self.assertRaisesRegex( @@ -517,10 +515,14 @@ def test_bad_transform(self): def test_parse(self): lm = fake.StaticSequence(['{"result": 1}']) - self.assertEqual(nl2structure.parse('the answer is 1', int, lm=lm), 1) + self.assertEqual( + nl2structure.parse('the answer is 1', int, lm=lm, protocol='json'), + 1 + ) self.assertEqual( nl2structure.parse( - 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm + 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm, + protocol='json', ), 1, ) @@ -673,12 +675,11 @@ def test_bad_transform(self): mapping.MappingError, 'Cannot parse message text into structured output', ): - nl2structure.query('Compute 1 + 2', int, protocol='python') + nl2structure.query('Compute 1 + 2', int) def test_query(self): lm = fake.StaticSequence(['1']) - self.assertEqual( - nl2structure.query('what is 1 + 0', int, lm=lm, protocol='python'), 1) + self.assertEqual(nl2structure.query('what is 1 + 0', int, lm=lm), 1) class QueryStructureJsonTest(unittest.TestCase): @@ -870,11 +871,12 @@ def test_bad_transform(self): mapping.MappingError, 'Cannot parse message text 
into structured output', ): - nl2structure.query('Compute 1 + 2', int) + nl2structure.query('Compute 1 + 2', int, protocol='json') def test_query(self): lm = fake.StaticSequence(['{"result": 1}']) - self.assertEqual(nl2structure.query('what is 1 + 0', int, lm=lm), 1) + self.assertEqual( + nl2structure.query('what is 1 + 0', int, lm=lm, protocol='json'), 1) if __name__ == '__main__':