From e64cdf8f75add16d5d5765ee059f204db7fc02a6 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 May 2024 20:22:33 +0000 Subject: [PATCH 1/7] add args to prompt --- outlines/prompts.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/outlines/prompts.py b/outlines/prompts.py index b4e7288bb..375aa32ec 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -207,6 +207,7 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str: env.filters["source"] = get_fn_source env.filters["signature"] = get_fn_signature env.filters["schema"] = get_schema + env.filters["args"] = get_fn_args jinja_template = env.from_string(cleaned_template) @@ -226,6 +227,23 @@ def get_fn_name(fn: Callable): return name +def get_fn_args(fn: Callable): + """Returns the arguments of a function with annotations and default values if provided.""" + if not callable(fn): + raise TypeError("The `args` filter only applies to callables.") + + arg_str_list = [] + signature = inspect.signature(fn) + for name, param in signature.parameters.items(): + annotation_str = f":{param.annotation.__name__}" if param.annotation != inspect.Parameter.empty else "" + default_str = f"={param.default}" if param.default != inspect.Parameter.empty else "" + arg_str = f"{name}{annotation_str}{default_str}" + arg_str_list.append(arg_str) + + arg_str = ', '.join(arg_str_list) + return arg_str + + def get_fn_description(fn: Callable): """Returns the first line of a callable's docstring.""" if not callable(fn): From 6c56c681fbbde8514a690fd07a33b46024aa3d30 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 May 2024 20:32:20 +0000 Subject: [PATCH 2/7] init tests --- tests/test_prompts.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 65eeb2022..efd7fd07b 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Dict, List import pytest from pydantic import BaseModel, Field @@ -252,3 +252,29 @@ def source_ppt(model): prompt = source_ppt(response) assert prompt == '{\n "one": "a description",\n "two": ""\n}' + + +def test_prompt_args(): + + def no_args(): + pass + + def with_args(x, y, z): + pass + + def with_annotations(x:bool, y:str, z: Dict[int, List[str]]): + pass + + def with_defaults(x=True, y="Hi", z={4:["I", "love", "outlines"]}): + pass + + def with_annotations_and_defaults(x:bool=True, y:str="Hi", z:Dict[int, List[str]]={4:["I", "love", "outlines"]}): + pass + + def with_all( + x1, y1, z1, + x2:bool, y2:str, z2: Dict[int, List[str]], + x3=True, y3="Hi", z3={4:["I", "love", "outlines"]}, + x4:bool=True, y4:str="Hi", z4:Dict[int, List[str]]={4:["I", "love", "outlines"]}, + ): + pass From d2e27d84a2522b955a9e7ac8b7b4df9aa293505c Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 May 2024 20:38:41 +0000 Subject: [PATCH 3/7] simplify --- outlines/prompts.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/outlines/prompts.py b/outlines/prompts.py index 375aa32ec..249b2e4df 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -234,12 +234,7 @@ def get_fn_args(fn: Callable): arg_str_list = [] signature = inspect.signature(fn) - for name, param in signature.parameters.items(): - annotation_str = f":{param.annotation.__name__}" if param.annotation != inspect.Parameter.empty else "" - default_str = f"={param.default}" if param.default != inspect.Parameter.empty else "" - arg_str = f"{name}{annotation_str}{default_str}" - arg_str_list.append(arg_str) - + arg_str_list = [str(param) for param in signature.parameters.values()] arg_str = ', '.join(arg_str_list) return arg_str From 2e4571ccf28aef2b3b960d22ad82e17300579142 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 May 2024 20:41:14 +0000 Subject: [PATCH 4/7] update tests --- tests/test_prompts.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_prompts.py b/tests/test_prompts.py index efd7fd07b..4be386e4f 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -278,3 +278,16 @@ def with_all( x4:bool=True, y4:str="Hi", z4:Dict[int, List[str]]={4:["I", "love", "outlines"]}, ): pass + + @outlines.prompt + def args_prompt(fn): + """args: {{ fn | args }}""" + + + args_prompt(no_args) + args_prompt(with_args) + args_prompt(with_annotations) + args_prompt(with_defaults) + args_prompt(with_annotations_and_defaults) + args_prompt(with_all) + From c7ce09cfd33126c60187005867da4f41dc8f55d9 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 May 2024 20:43:54 +0000 Subject: [PATCH 5/7] add asserts --- tests/test_prompts.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 4be386e4f..62e41b707 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -283,11 +283,9 @@ def with_all( def args_prompt(fn): """args: {{ fn | args }}""" - - args_prompt(no_args) - args_prompt(with_args) - args_prompt(with_annotations) - args_prompt(with_defaults) - args_prompt(with_annotations_and_defaults) - args_prompt(with_all) - + assert args_prompt(no_args) == "args: " + assert args_prompt(with_args) == "args: x, y, z" + assert args_prompt(with_annotations) == "args: x: bool, y: str, z: Dict[int, List[str]]" + assert args_prompt(with_defaults) == "args: x=True, y='Hi', z={4: ['I', 'love', 'outlines']}" + assert args_prompt(with_annotations_and_defaults) == "args: x: bool = True, y: str = 'Hi', z: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" + assert args_prompt(with_all) == "args: x1, y1, z1, x2: bool, y2: str, z2: Dict[int, List[str]], x3=True, y3='Hi', z3={4: ['I', 'love', 'outlines']}, x4: bool = True, y4: str = 'Hi', z4: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" From 4aecc1188e77346fb9abc005b7c41154e933dd4d Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 May 2024 20:54:16 +0000 Subject: [PATCH 6/7] fix formatting --- outlines/prompts.py | 2 +- tests/test_prompts.py | 49 +++++++++++++++++++++++++++++++------------ 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/outlines/prompts.py b/outlines/prompts.py index 249b2e4df..01e900c96 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -235,7 +235,7 @@ def get_fn_args(fn: Callable): arg_str_list = [] signature = inspect.signature(fn) arg_str_list = [str(param) for param in signature.parameters.values()] - arg_str = ', '.join(arg_str_list) + arg_str = ", ".join(arg_str_list) return arg_str diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 62e41b707..a0433c0e5 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -255,28 +255,39 @@ def source_ppt(model): def test_prompt_args(): - def no_args(): pass def with_args(x, y, z): pass - def with_annotations(x:bool, y:str, z: Dict[int, List[str]]): + def with_annotations(x: bool, y: str, z: Dict[int, List[str]]): pass - def with_defaults(x=True, y="Hi", z={4:["I", "love", "outlines"]}): + def with_defaults(x=True, y="Hi", z={4: ["I", "love", "outlines"]}): pass - def with_annotations_and_defaults(x:bool=True, y:str="Hi", z:Dict[int, List[str]]={4:["I", "love", "outlines"]}): + def with_annotations_and_defaults( + x: bool = True, + y: str = "Hi", + z: Dict[int, List[str]] = {4: ["I", "love", "outlines"]}, + ): pass def with_all( - x1, y1, z1, - x2:bool, y2:str, z2: Dict[int, List[str]], - x3=True, y3="Hi", z3={4:["I", "love", "outlines"]}, - x4:bool=True, y4:str="Hi", z4:Dict[int, List[str]]={4:["I", "love", "outlines"]}, - ): + x1, + y1, + z1, + x2: bool, + y2: str, + z2: Dict[int, List[str]], + x3=True, + y3="Hi", + z3={4: ["I", "love", "outlines"]}, + x4: bool = True, + y4: str = "Hi", + z4: Dict[int, List[str]] = {4: ["I", "love", "outlines"]}, + ): pass @outlines.prompt @@ -285,7 +296,19 @@ def args_prompt(fn): assert args_prompt(no_args) == "args: " assert args_prompt(with_args) == "args: x, y, z" - assert args_prompt(with_annotations) == "args: x: bool, y: str, z: Dict[int, List[str]]" - assert args_prompt(with_defaults) == "args: x=True, y='Hi', z={4: ['I', 'love', 'outlines']}" - assert args_prompt(with_annotations_and_defaults) == "args: x: bool = True, y: str = 'Hi', z: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" - assert args_prompt(with_all) == "args: x1, y1, z1, x2: bool, y2: str, z2: Dict[int, List[str]], x3=True, y3='Hi', z3={4: ['I', 'love', 'outlines']}, x4: bool = True, y4: str = 'Hi', z4: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" + assert ( + args_prompt(with_annotations) + == "args: x: bool, y: str, z: Dict[int, List[str]]" + ) + assert ( + args_prompt(with_defaults) + == "args: x=True, y='Hi', z={4: ['I', 'love', 'outlines']}" + ) + assert ( + args_prompt(with_annotations_and_defaults) + == "args: x: bool = True, y: str = 'Hi', z: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" + ) + assert ( + args_prompt(with_all) + == "args: x1, y1, z1, x2: bool, y2: str, z2: Dict[int, List[str]], x3=True, y3='Hi', z3={4: ['I', 'love', 'outlines']}, x4: bool = True, y4: str = 'Hi', z4: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" + ) From b2a51b36ac43d9a2c625893406f043c040230e14 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 17 May 2024 21:02:31 +0000 Subject: [PATCH 7/7] update docs --- docs/reference/prompting.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/reference/prompting.md b/docs/reference/prompting.md index a7731ba0f..34860fce0 100644 --- a/docs/reference/prompting.md +++ b/docs/reference/prompting.md @@ -223,7 +223,7 @@ Several projects (e.g.[Toolformer](https://arxiv.org/abs/2302.04761), [ViperGPT] Can you do something? COMMANDS - 1. my_tool: Tool description, args: arg1:str, arg2:int + 1. my_tool: Tool description., args: arg1: str, arg2: int def my_tool(arg1: str, arg2: int): """Tool description.