Pass is_in and stop_at to the model call directly
rlouf committed May 2, 2023
1 parent bc832c5 commit 9e7bd5d
Showing 3 changed files with 58 additions and 64 deletions.
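In short, the generation constraints move from model construction to the call site: instead of building a separate model for every constraint, a single model is built once and `is_in` or `stop_at` is passed on each call. A minimal sketch contrasting the two styles, using the same calls that appear in examples/react.py below (the import and the `prompt` string are filled in for illustration; the "before" form no longer works after this commit):

import outlines.models as models  # assumed import, as in the example scripts

prompt = "Where is Apple Computers headquarted? "

# Before: one model per constraint, fixed at construction time.
action_model = models.text_completion.openai(
    "text-davinci-003", is_in=["Search", "Finish"], max_tokens=2
)
action = action_model(prompt)

# After: one model, with the constraint supplied per call.
complete = models.text_completion.openai(
    "gpt-3.5-turbo", max_tokens=128, temperature=1.0
)
action = complete(prompt, is_in=["Search", "Finish"])
thought = complete(prompt, stop_at="\n")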
33 changes: 13 additions & 20 deletions examples/react.py
@@ -47,33 +47,26 @@ def search_wikipedia(query: str):
     return ".".join(list(page.values())[0]["extract"].split(".")[:2])
 
 
-mode_model = models.text_completion.openai(
-    "gpt-3.5-turbo", is_in=["Tho", "Act"], max_tokens=2
-)
-action_model = models.text_completion.openai(
-    "text-davinci-003", is_in=["Search", "Finish"], max_tokens=2
-)
-thought_model = models.text_completion.openai(
-    "text-davinci-003", stop_at=["\n"], max_tokens=128
-)
-subject_model = models.text_completion.openai(
-    "text-davinci-003", stop_at=["'"], max_tokens=128
-)
-
 prompt = build_reAct_prompt("Where is Apple Computers headquarted? ")
+complete = models.text_completion.openai(
+    "gpt-3.5-turbo", max_tokens=128, temperature=1.0
+)
 
 for i in range(1, 10):
-    mode = mode_model(prompt)
+    mode = complete(prompt, is_in=["Tho", "Act"])
+    prompt = add_mode(i, mode, "", prompt)
 
     if mode == "Tho":
-        prompt = add_mode(i, mode, "", prompt)
-        thought = thought_model(prompt)
+        thought = complete(prompt, stop_at="\n")
         prompt += f"{thought}"
-    if mode == "Act":
-        prompt = add_mode(i, mode, "", prompt)
-        action = action_model(prompt)
+    elif mode == "Act":
+        action = complete(prompt, is_in=["Search", "Finish"])
         prompt += f"{action} '"
-        subject = " ".join(subject_model(prompt).split()[:2])
+
+        subject = complete(prompt, stop_at=["'"])  # Apple Computers headquartered
+        subject = " ".join(subject.split()[:2])
         prompt += f"{subject}'"
 
         if action == "Search":
             result = search_wikipedia(subject)
             prompt = add_mode(i, "Obs", result, prompt)
83 changes: 43 additions & 40 deletions outlines/models/openai.py
@@ -23,10 +23,8 @@
 
 def OpenAITextCompletion(
     model_name: str,
-    stop_at: Optional[List[str]] = None,
-    is_in: Optional[List[str]] = None,
-    max_tokens: Optional[int] = None,
-    temperature: Optional[float] = None,
+    max_tokens: Optional[int] = 216,
+    temperature: Optional[float] = 1.0,
 ) -> Callable:
     """Create a function that will call the OpenAI conmpletion API.
@@ -37,10 +35,6 @@ def OpenAITextCompletion(
     ----------
     model_name: str
         The name of the model as listed in the OpenAI documentation.
-    stop_at
-        A list of tokens which, when found, stop the generation.
-    is_in
-        A list of strings among which the results will be chosen.
     max_tokens
         The maximum number of tokens to generate.
     temperature
@@ -53,8 +47,6 @@
     """
 
-    parameters = validate_completion_parameters(stop_at, is_in, max_tokens, temperature)
-
     @error_handler
     @memory.cache
     def call_completion_api(
@@ -78,11 +70,24 @@ def call_completion_api(
 
         return response
 
-    def generate(prompt: str) -> str:
-        response = call_completion_api(model_name, prompt, **parameters)
+    def generate(prompt: str, *, stop_at=None, is_in=None):
+        if stop_at is not None:
+            stop_at = tuple(stop_at)
+
+        if is_in is not None and stop_at is not None:
+            raise TypeError("You cannot set `is_in` and `stop_at` at the same time.")
+        elif is_in is not None:
+            return generate_choice(prompt, is_in)
+        else:
+            return generate_base(prompt, stop_at)
+
+    def generate_base(prompt: str, stop_at: Optional[Tuple[str]]) -> str:
+        response = call_completion_api(
+            model_name, prompt, stop_at, {}, max_tokens, temperature
+        )
         return response["choices"][0]["text"]
 
-    def generate_choice(prompt: str) -> str:
+    def generate_choice(prompt: str, is_in: List[str]) -> str:
         """Generate a a sequence that must be one of many options.
 
         We tokenize every choice, iterate over the token lists, create a mask
@@ -109,28 +114,21 @@ def generate_choice(prompt: str) -> str:
            if len(mask) == 0:
                break
 
-            parameters["logit_bias"] = mask
-            parameters["max_tokens"] = 1
-            response = call_completion_api(model_name, prompt, **parameters)
+            response = call_completion_api(
+                model_name, prompt, None, mask, 1, temperature
+            )
            decoded.append(response["choices"][0]["text"])
            prompt = prompt + "".join(decoded)
 
        return "".join(decoded)
 
-    if is_in is not None:
-        return generate_choice
-    else:
-        return generate
+    return generate
 
 
 def OpenAIChatCompletion(
     model_name: str,
-    stop_at: Optional[List[str]] = None,
-    is_in: Optional[List[str]] = None,
-    max_tokens: Optional[int] = None,
-    temperature: Optional[float] = None,
+    max_tokens: Optional[int] = 128,
+    temperature: Optional[float] = 1.0,
 ) -> Callable:
     """Create a function that will call the chat completion OpenAI API.
@@ -141,10 +139,6 @@ def OpenAIChatCompletion(
     ----------
     model_name: str
         The name of the model as listed in the OpenAI documentation.
-    stop_at
-        A list of tokens which, when found, stop the generation.
-    is_in
-        A list of strings among which the results will be chosen.
     max_tokens
         The maximum number of tokens to generate.
     temperature
@@ -156,7 +150,6 @@
        parameters when passed a prompt.
     """
-    parameters = validate_completion_parameters(stop_at, is_in, max_tokens, temperature)
 
     @error_handler
     @memory.cache
@@ -181,13 +174,26 @@ def call_chat_completion_api(
 
         return response
 
-    def generate(query: str) -> str:
+    def generate(prompt: str, *, stop_at=None, is_in=None):
+        if stop_at is not None:
+            stop_at = tuple(stop_at)
+
+        if is_in is not None and stop_at is not None:
+            raise TypeError("You cannot set `is_in` and `stop_at` at the same time.")
+        elif is_in is not None:
+            return generate_choice(prompt, is_in)
+        else:
+            return generate_base(prompt, stop_at)
+
+    def generate_base(query: str, stop_at: Optional[Tuple[str]]) -> str:
         messages = [{"role": "user", "content": query}]
-        response = call_chat_completion_api(model_name, messages, *parameters)
+        response = call_chat_completion_api(
+            model_name, messages, stop_at, {}, max_tokens, temperature
+        )
         answer = response["choices"][0]["message"]["content"]
         return answer
 
-    def generate_choice(prompt: str) -> str:
+    def generate_choice(prompt: str, is_in=List[str]) -> str:
         """Generate a a sequence that must be one of many options.
 
         We tokenize every choice, iterate over the token lists, create a mask
@@ -214,19 +220,16 @@ def generate_choice(prompt: str) -> str:
            if len(mask) == 0:
                break
 
-            parameters["logit_bias"] = mask
-            parameters["max_tokens"] = 1
            messages = [{"role": "user", "content": prompt}]
-            response = call_chat_completion_api(model_name, messages, **parameters)
+            response = call_chat_completion_api(
+                model_name, messages, None, mask, 1, temperature
+            )
            decoded.append(response["choices"][0]["message"]["content"])
            prompt = prompt + "".join(decoded)
 
        return "".join(decoded)
 
-    if is_in is not None:
-        return generate_choice
-    else:
-        return generate
+    return generate
 
 
 def validate_completion_parameters(
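The two `generate_choice` helpers above constrain the output by tokenizing every allowed string and sending a `logit_bias` mask that only permits the tokens that can appear at the current position, requesting one token per API call (the `1` passed as `max_tokens` above) until the mask is empty. A rough, hypothetical sketch of that masking step (the function name and the use of tiktoken are assumptions for illustration, not the library's code):

import tiktoken  # assumed tokenizer; the actual implementation may differ

def choice_mask(choices, position, model_name="gpt-3.5-turbo"):
    """Build a logit_bias dict allowing only the tokens found at `position`
    in one of the tokenized choices."""
    enc = tiktoken.encoding_for_model(model_name)
    mask = {}
    for choice in choices:
        tokens = enc.encode(choice)
        if position < len(tokens):
            # 100 is the largest bias the OpenAI API accepts; it effectively
            # restricts sampling to this set of tokens.
            mask[tokens[position]] = 100
    return mask

# First-token mask for the ReAct modes used in examples/react.py:
print(choice_mask(["Tho", "Act"], position=0))

Each decoded token is appended to the prompt and the loop repeats, which is why `call_completion_api` and `call_chat_completion_api` now receive the mask, `max_tokens`, and `temperature` as explicit arguments instead of reading them from a captured `parameters` dict.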
6 changes: 2 additions & 4 deletions outlines/text/completions.py
@@ -76,9 +76,7 @@ def completion(
     except KeyError:
         raise ValueError(f"The model provider {provider_name} is not available.")
 
-    llm = model_cls(
-        model_name, stop_at=stop_at, max_tokens=max_tokens, temperature=temperature
-    )
+    llm = model_cls(model_name, max_tokens=max_tokens, temperature=temperature)
 
     def decorator(fn: Callable):
         prompt_fn = text.prompt(fn)
@@ -98,7 +96,7 @@ def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Tuple[str, str]:
             """
             prompt = prompt_fn(*args, **kwargs)
-            result = llm(prompt)
+            result = llm(prompt, stop_at=stop_at)
             return result, prompt + result
 
         return wrapper
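Because the model is now constructed without `stop_at`, the `completion` decorator simply forwards the constraint on every call (`llm(prompt, stop_at=stop_at)` above). A minimal, self-contained sketch of that wrapper pattern, with a stand-in model so it runs without an API key; the names here are hypothetical, not the library's API:

from typing import Callable, List, Optional, Tuple


def completion_like(llm: Callable, stop_at: Optional[List[str]] = None) -> Callable:
    """Decorate a prompt-building function so that calling it runs the model."""

    def decorator(prompt_fn: Callable) -> Callable:
        def wrapper(*args, **kwargs) -> Tuple[str, str]:
            prompt = prompt_fn(*args, **kwargs)
            # The constraint is passed at call time, not stored in the model.
            result = llm(prompt, stop_at=stop_at)
            return result, prompt + result

        return wrapper

    return decorator


def fake_llm(prompt: str, *, stop_at=None) -> str:
    """Stand-in model that just reports the constraint it received."""
    return f" [completion, stop_at={stop_at}]"


@completion_like(fake_llm, stop_at=["\n"])
def answer(question: str):
    return f"Q: {question}\nA:"

result, full_text = answer("Where is Apple Computers headquartered?")
print(full_text)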
