Skip to content

Commit

Permalink
Add HuggingFace's GPT2 model
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Mar 26, 2023
1 parent c38e00a commit 8edb5b6
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 6 deletions.
68 changes: 68 additions & 0 deletions outlines/text/models/hugging_face.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Dict

from outlines.text.models.model import LanguageModel

try:
import jax
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
except ImportError:
raise ImportError(
"You need to install `transformers` and `flax` to run the GTP2 model."
)


class GPT2(LanguageModel):
def __init__(self):
"""Initialize the GPT2 model.
We use HuggingFace's Flax implementation of GPT2. This method will download
the model's weights if they are not yet cached on your machine.
# TODO: Download the pre-trained weight when the model is executed instead of
# when the graph is built.
"""
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
self.model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
super().__init__()

def sample(self, prompt_tokens: Dict[str, jax.Array]) -> jax.Array:
"""Sample new tokens give the tokenized prompt.
Since HuggingFace's `generate` method returns the prompt along with the
generated token we need to truncate the returned array of tokens.
Parameters
----------
prompt_tokens
A dictionary that contains the ids of the tokens contained in the input
prompt and the input mask. This is the default output of HuggingFace's
tokenizers.
"""
returned_tokens = self.model.generate(
**prompt_tokens, do_sample=True, max_new_tokens=10
).sequences
new_tokens = returned_tokens[:, prompt_tokens["input_ids"].shape[1] + 1 :]
new_tokens = new_tokens.squeeze()

return new_tokens

def encode(self, sequence: str) -> Dict[str, jax.Array]:
"""Return a list of token ids from a text sequence.
Parameters
----------
sequence
The text sequence to tokenize.
Returns
-------
A dictionary that contains the token ids and the input mask.
"""
return self.tokenizer(sequence, return_tensors="jax")

def decode(self, ids: jax.Array) -> str:
"""Return a text sequence from a array of token ids."""
return self.tokenizer.decode(ids, skip_special_tokens=True)
27 changes: 24 additions & 3 deletions outlines/text/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,42 @@


class LanguageModel(Op):
"""An `Op` that produces a sample from a language model.
The output of language models in outlines is modeled as a random variable.
Therefore, calling a language model will return a random sequence (via
ancestral sampling) by default. Other decoding methods are constructed
as graph transformations.
"""

def __init__(self, name=None):
super().__init__()
self.name = name

def make_node(self, prompt):
prompt = as_string(prompt)
out = StringVariable()
if self.name is not None:
out.name = self.name

return Apply(self, [prompt], [out])

def perform(self, prompt):
return self.sample(prompt)
tokens = self.encode(prompt)
sampled_tokens = self.sample(tokens)
outputs = self.decode(sampled_tokens)
return (outputs,)

def sample(self, prompt):
return (f"2x{prompt}",)
def sample(self, tokens):
raise NotImplementedError

def logprob(self, prompt, context):
"""Return the log-probability of each token in the vocabulary given the
input prompt and the current context (previously generated tokens).
# TODO: Implement `logprob` as a graph transformation?
Parameters
----------
prompt
Expand Down
3 changes: 0 additions & 3 deletions requirements.txt

This file was deleted.

19 changes: 19 additions & 0 deletions tests/test_compile.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from outlines import compile, script, string


Expand Down Expand Up @@ -30,3 +32,20 @@ def test_compile_scripts():
o = script("This is a ${var}")(var=s)
out = compile([s], [o])
assert out("test") == "This is a test"


@pytest.mark.skip
def test_compile_hf():
import outlines
import outlines.text.models.hugging_face

gpt2 = outlines.text.models.hugging_face.GPT2()
o = script(
"""
Here is a good joke: ${joke}
And a random fact: ${fact}
"""
)(joke=gpt2, fact=gpt2)

fn = compile([], [o])
print(fn())

0 comments on commit 8edb5b6

Please sign in to comment.