Add LanguageModel Op and integration with HuggingFace's GPT2 implementation #30

Merged: 7 commits, Mar 27, 2023
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -28,7 +28,7 @@ jobs:
      - name: Set up test environment
        run: |
          python -m pip install --upgrade pip
-         pip install -r requirements.txt
+         pip install .[test]
      - name: Run tests
        run: |
          pytest
1 change: 1 addition & 0 deletions outlines/text/models/__init__.py
@@ -0,0 +1 @@
from .model import LanguageModel
76 changes: 76 additions & 0 deletions outlines/text/models/hugging_face.py
@@ -0,0 +1,76 @@
import random
from typing import Dict

from outlines.text.models.model import LanguageModel

try:
    import jax
    from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
except ImportError:
    raise ImportError(
        "You need to install `transformers` and `flax` to run the GPT2 model."
    )


class GPT2(LanguageModel):
    def __init__(self):
        """Initialize the GPT2 model.

        We use HuggingFace's Flax implementation of GPT2. This method will download
        the model's weights if they are not yet cached on your machine.

        # TODO: Download the pre-trained weights when the model is executed instead
        # of when the graph is built.

        """
        random.seed()
        self.seed = random.randint(0, 2**32)
        super().__init__()

    def sample(self, prompt_tokens: Dict[str, jax.Array]) -> jax.Array:
        """Sample new tokens given the tokenized prompt.

        Since HuggingFace's `generate` method returns the prompt along with the
        generated tokens, we need to truncate the returned array of tokens.

        Parameters
        ----------
        prompt_tokens
            A dictionary that contains the ids of the tokens contained in the input
            prompt and the input mask. This is the default output of HuggingFace's
            tokenizers.

        """
        self.model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
        returned_tokens = self.model.generate(
            **prompt_tokens,
            do_sample=True,
            max_new_tokens=100,
            prng_key=jax.random.PRNGKey(self.seed),
            pad_token_id=self.tokenizer.eos_token_id,
        ).sequences
        new_tokens = returned_tokens[:, prompt_tokens["input_ids"].shape[1] :]
        new_tokens = new_tokens.squeeze()

        return new_tokens

    def encode(self, sequence: str) -> Dict[str, jax.Array]:
        """Return the token ids and input mask for a text sequence.

        Parameters
        ----------
        sequence
            The text sequence to tokenize.

        Returns
        -------
        A dictionary that contains the token ids and the input mask.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        return self.tokenizer(sequence, return_tensors="jax")

    def decode(self, ids: jax.Array) -> str:
        """Return a text sequence from an array of token ids."""
        return self.tokenizer.decode(ids, skip_special_tokens=True)
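For orientation, here is a minimal sketch of how these three methods chain together when the Op is executed, mirroring what `perform` does in `model.py` below. It assumes `transformers` and `flax` are installed and that the `gpt2` weights can be downloaded; it is an illustration, not part of the diff.

    from outlines.text.models.hugging_face import GPT2

    gpt2 = GPT2()

    # `encode` tokenizes the prompt and caches the tokenizer on the instance;
    # `sample` relies on it for `pad_token_id`, so the order of calls matters.
    tokens = gpt2.encode("The capital of France is")

    # `sample` loads the Flax GPT2 weights, draws up to 100 new tokens by
    # ancestral sampling, and strips the prompt from the returned sequence.
    new_tokens = gpt2.sample(tokens)

    # `decode` maps the sampled token ids back to text.
    print(gpt2.decode(new_tokens))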
77 changes: 77 additions & 0 deletions outlines/text/models/model.py
@@ -0,0 +1,77 @@
from outlines.graph import Apply, Op
from outlines.text.var import StringVariable, as_string


class LanguageModel(Op):
    """An `Op` that produces a sample from a language model.

    The output of language models in outlines is modeled as a random variable.
    Therefore, calling a language model will return a random sequence (via
    ancestral sampling) by default. Other decoding methods are constructed
    as graph transformations.

    """

    def __init__(self, name=None):
        super().__init__()
        self.name = name

    def make_node(self, prompt):
        prompt = as_string(prompt)
        out = StringVariable()
        if self.name is not None:
            out.name = self.name

        return Apply(self, [prompt], [out])

    def perform(self, prompt):
        tokens = self.encode(prompt)
        sampled_tokens = self.sample(tokens)
        outputs = self.decode(sampled_tokens)
        return (outputs,)

    def sample(self, tokens):
        raise NotImplementedError

    def logprob(self, prompt, context):
        """Return the log-probability of each token in the vocabulary given the
        input prompt and the current context (previously generated tokens).

        # TODO: Implement `logprob` as a graph transformation?

        Parameters
        ----------
        prompt
            The input to the language model, parameter of the distribution.
        context
            A sequence that contains the previously generated tokens that
            are part of the context window. This sequence can be shorter
            than the total sequence generated so far if the context length
            has been reached.

        Returns
        -------
        A sequence that represents the log-probability distribution over the
        tokens.

        """
        raise NotImplementedError

    def encode(self, sequence: str):
        """Encode the given sequence.

        Defaults to a pass-through so it does not have to be implemented by
        subclasses that integrate an API that takes text as input.

        """
        return sequence

    def decode(self, ids) -> str:
        """Decode a list of ids to a string.

        Defaults to a pass-through so it does not have to be implemented by
        subclasses that integrate an API that returns text.

        """
        return ids
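Because `encode` and `decode` default to pass-throughs, a text-in/text-out integration only needs to override `sample`. A minimal hypothetical subclass, in the spirit of the `MockLM` used in the tests below:

    from outlines.text.models.model import LanguageModel


    class EchoLM(LanguageModel):
        # `encode` and `decode` are inherited pass-throughs, so `sample`
        # receives and returns plain strings.
        def sample(self, tokens):
            return tokens + "!"


    llm = EchoLM(name="echo")
    # `perform` chains encode -> sample -> decode and wraps the result in a tuple.
    assert llm.perform("hi")[0] == "hi!"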
46 changes: 40 additions & 6 deletions outlines/text/script.py
@@ -1,10 +1,12 @@
import textwrap
from functools import singledispatchmethod
from typing import Dict, Union

from mako import lexer
from mako.parsetree import Expression, Text

from outlines.graph import Op
from outlines.text.models import LanguageModel
from outlines.text.var import StringVariable, as_string


@@ -18,28 +20,60 @@ class Script:
"""

def __init__(self, script):
script = textwrap.dedent(script).lstrip().rstrip()
self.parsetree = lexer.Lexer(script).parse()
self.model_outputs = {}

def __call__(self, **inputs: Dict[str, Union[StringVariable, Op]]):
"""Create an Outlines graph from a Mako template.

When one calls a `Script` instance with arguments that represent
variables in the template, Outlines parses the template and iteratively
builds the graph it represents before returning it.

"""
nodes = self.parsetree.nodes
graph = self.parse_node(nodes[0], inputs)
graph = self.parse_node(nodes[0], inputs, "")
for node in self.parsetree.nodes[1:]:
graph = graph + self.parse_node(node, inputs)
graph = graph + self.parse_node(node, inputs, graph)

return graph

@singledispatchmethod
def parse_node(self, node, inputs):
def parse_node(self, node, inputs, graph):
raise NotImplementedError(f"Cannot transpile {node} to an Outlines graph.")

@parse_node.register(Text)
def parse_Text(self, node, inputs):
def parse_Text(self, node, inputs, graph):
"""Parse Mako's `Text` nodes.

`Text` nodes corresponds to `StringConstants` in Outline's language.

"""
return as_string(node.content)

@parse_node.register(Expression)
def parse_Expression(self, node, inputs):
def parse_Expression(self, node, inputs, graph):
"""Parse Mako's `Expression` nodes.

We first fetch the argument that the user passed to the `__call__`
method that corresponds to the current variable name. Then we check if
this argument has already been seen; if that's the case we assume the
user is referencing the output of a previously-run LM and add the
corresponding node.

"""
try:
return as_string(inputs[node.text])
user_input = inputs[node.text]
if isinstance(user_input, LanguageModel):
try:
return self.model_outputs[node.text]
except KeyError:
output = user_input(graph)
self.model_outputs[node.text] = output
return output
else:
return as_string(inputs[node.text])
except KeyError:
raise TypeError(
f"Prompt evaluation missing 1 required argument: '{node.text}'"
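To illustrate the caching behaviour `parse_Expression` implements, here is a hedged sketch (`ConstLM` is a hypothetical stand-in; `test_template_language_model` below exercises the same path): referencing the same model variable twice adds a single model node whose output is reused.

    from outlines.text import script
    from outlines.text.models import LanguageModel


    class ConstLM(LanguageModel):
        def sample(self, tokens):
            return "a joke"


    lm = ConstLM(name="lm")
    # Building the graph: the first ${lm} applies the model to the text that
    # precedes it; the second ${lm} is found in `model_outputs` and reuses the
    # same output variable instead of adding a second model call.
    graph = script("Tell me a joke: ${lm}. Again: ${lm}")(lm=lm)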
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -28,5 +28,11 @@ dependencies = [
]
dynamic = ["version"]

[project.optional-dependencies]
test = [
"pre-commit",
"pytest"
]

[tool.setuptools_scm]
write_to = "outlines/_version.py"
3 changes: 0 additions & 3 deletions requirements.txt

This file was deleted.

19 changes: 19 additions & 0 deletions tests/test_compile.py
@@ -1,3 +1,5 @@
import pytest

from outlines import compile, script, string


@@ -30,3 +32,20 @@ def test_compile_scripts():
    o = script("This is a ${var}")(var=s)
    out = compile([s], [o])
    assert out("test") == "This is a test"


@pytest.mark.skip
def test_compile_hf():
    """Move when we have found a better way to run these slow examples."""
    import outlines
    import outlines.text.models.hugging_face

    gpt2 = outlines.text.models.hugging_face.GPT2()
    o = script(
        """
        Here is a good joke: ${joke}
        And a random fact: ${fact}
        """
    )(joke=gpt2, fact=gpt2)
    fn = compile([], [o])
    print(fn())
22 changes: 22 additions & 0 deletions tests/text/test_model.py
@@ -0,0 +1,22 @@
from outlines import string
from outlines.text.models.model import LanguageModel


def test_initialize_model():
    llm = LanguageModel(name="llm")

    prompt = string()
    out = llm(prompt)
    assert isinstance(out.owner.op, LanguageModel)
    assert out.owner.inputs[0] == prompt
    assert out.name == "llm"


class MockLM(LanguageModel):
    def sample(self, _):
        return "test"


def test_sample():
    llm = MockLM()
    assert llm.perform("")[0] == "test"
29 changes: 29 additions & 0 deletions tests/text/test_script.py
@@ -2,6 +2,7 @@

from outlines.text import script, string
from outlines.text.basic import Add
from outlines.text.models import LanguageModel
from outlines.text.var import StringConstant, StringVariable


@@ -47,3 +48,31 @@ def test_template_string_variable():
    assert isinstance(t.owner.inputs[0], StringVariable)
    assert isinstance(t.owner.inputs[1], StringConstant)
    assert t.owner.inputs[1].value == " test"


class MockLanguageModel(LanguageModel):
    def sample(self, prompt):
        return f"2x{prompt}"


def test_template_language_model():
    r"""Test the transpilation of scripts that contain one or
    several `LanguageModel`\s.
    """
    # Single occurrence
    lm = MockLanguageModel()
    t = script("Test ${lm}")(lm=lm)
    assert isinstance(t.owner.op, Add)
    assert isinstance(t.owner.inputs[1].owner.op, LanguageModel)

    lm_input = t.owner.inputs[1].owner.inputs[0].value
    assert lm_input == "Test "

    # The first reference to the language model should trigger decoding;
    # subsequent references are replaced by the result of that evaluation.
    lm = MockLanguageModel(name="lm")
    t = script("Test ${lm} more text ${lm}")(lm=lm)
    assert isinstance(t.owner.inputs[1].owner.op, MockLanguageModel)
    assert t.owner.inputs[1].owner.inputs[0].value == "Test "