
Add Transformers model and Completion sequence generation #139

Merged — 5 commits merged into outlines-dev:main on Jun 29, 2023

Conversation

@rlouf (Member) commented Jun 9, 2023

Some notes on sequence generation

If we denote by $\boldsymbol{t} = (t_1, \dots, t_n)$ a sequence of $n$ tokens that represents a string, then we can define the next token $t_{n+1}$ in the sequence as the following random variable:

$$ \begin{align*} \boldsymbol{\alpha} &= \mathrm{LM}(\boldsymbol{t}, \boldsymbol{\theta})\\ t_{n+1} &\sim \mathrm{Categorical}({\boldsymbol{\alpha}}) \end{align*} $$

where $\boldsymbol{\theta}$ is the set of trained parameters. This random variable's support is the entire token vocabulary $\mathcal{V}$ (~ $10^3$ to $10^4$ tokens). In the following, the $\mathrm{LM}$ function will typically refer to a deep neural network trained on next-token-completion tasks, but it need not be; it can be as simple as a function that returns a constant vector with equal values.
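Concretely, sampling this random variable with a transformers model and PyTorch amounts to something like the following sketch (illustrative only, not the API introduced in this PR):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

token_ids = tokenizer("This is a prompt", return_tensors="pt")["input_ids"]
with torch.no_grad():
    logits = model(token_ids).logits[:, -1, :]        # alpha = LM(t, theta)
probs = torch.softmax(logits, dim=-1)                 # parameters of the Categorical
next_token = torch.multinomial(probs, num_samples=1)  # t_{n+1} ~ Categorical(alpha)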

This random variable is the basic building block that allows us to sample text sequences. Representing these random variables explicitly is thus the first step in an effort to refactor the generation process in Outlines to make it more flexible and to allow a larger variety of sequence-generating processes and ways to sample them.

Using the same model call we can define many other random variables by manipulating the output logits $\boldsymbol{\alpha}$, randomly or deterministically. A particularly interesting example is when one applies a boolean mask $m$ to restrict the support of the distribution like so:

$$ \begin{align*} \boldsymbol{\alpha} &= \mathrm{LM}(\boldsymbol{t}, \boldsymbol{\theta})\\ \tilde{\boldsymbol{\alpha}} &= m \odot \boldsymbol{\alpha}\\ t_{n+1} &\sim \mathrm{Categorical}({\tilde{\boldsymbol{\alpha}}}) \end{align*} $$

This mask can encode some a priori knowledge about what the support should be. For instance:

  • We want the random variable to represent a digit;
  • We want to exclude the <EOS> (end-of-sequence) token from the support;
  • We want the random variable to match the [a-zA-Z] regular expression.

We can summarize the above in the following notation:

$$ t_{n+1} \sim \mathrm{Token}_m(\boldsymbol{t}) $$

where $\mathrm{Token}_m$ is a random variable with support $\mathcal{V} \setminus \left\{ i \mid m_i = 0 \right\}$ and parametrized by $\boldsymbol{t}$.
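Continuing the sketch above, restricting the support to digit tokens could look like this (illustrative; `logits` and `tokenizer` refer to the previous sketch, and the mask construction is tokenizer-specific):

import torch

# Illustrative sketch: reuse `logits` and `tokenizer` from the previous example.
# In log-space the mask is applied additively: -inf removes a token from the support.
digit_token_ids = [
    token_id
    for token, token_id in tokenizer.get_vocab().items()
    if token.strip().isdigit()
]
mask = torch.full_like(logits, float("-inf"))
mask[:, digit_token_ids] = 0
probs = torch.softmax(logits + mask, dim=-1)           # support restricted to digits
next_token = torch.multinomial(probs, num_samples=1)   # t_{n+1} ~ Token_m(t)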

In practice

Generating a text sequence using language models requires:

  • A language model that returns next-token logits given an input sequence;
  • A model that describes how a sequence is to be generated.

Models

In this PR we introduce the Transformers object, which is responsible for initializing and calling the model. This implicitly defines the interface that models have with the rest of the library (a parent class will be added). In particular:

  • Models accept a NumPy array of token ids and return a NumPy array that contains the logits. The NumPy interface simplifies the rest of the library, and we make the assumption that transferring these arrays back-and-forth between the CPU and the device is not the limiting factor in performance.
  • Caching is handled at the model level. We need to cache and persist the input-ids-to-logits calls (for repeated runs of the same workflow during development), and the model's k-v cache for repeated calls with overlapping prompts.

The k-v cache will require some careful thought, and will likely have to be customized to the particular class of models for which it is implemented (transformers, llama.cpp, etc.). My initial thought is to build a trie that we query each time the model is called; this cache is attached to the model instance.
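To make the idea concrete, a very rough sketch of such a trie (purely illustrative, not part of this PR) could look like:

from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple


@dataclass
class KVTrieNode:
    # One node per token id; `kv` stores the model's past key/values for the
    # prefix that ends at this node.
    children: Dict[int, "KVTrieNode"] = field(default_factory=dict)
    kv: Optional[Any] = None


class KVTrie:
    def __init__(self) -> None:
        self.root = KVTrieNode()

    def insert(self, token_ids: Tuple[int, ...], kv: Any) -> None:
        node = self.root
        for token_id in token_ids:
            node = node.children.setdefault(token_id, KVTrieNode())
        node.kv = kv

    def longest_prefix(self, token_ids: Tuple[int, ...]) -> Tuple[Optional[Any], int]:
        # Return the cached kv of the longest cached prefix of `token_ids`,
        # along with the length of that prefix.
        node, best_kv, best_len = self.root, None, 0
        for i, token_id in enumerate(token_ids, start=1):
            node = node.children.get(token_id)
            if node is None:
                break
            if node.kv is not None:
                best_kv, best_len = node.kv, i
        return best_kv, best_len

The model would then only need to run the forward pass on the tokens that come after the longest cached prefix.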

Another difficulty arises when the workflow uses different local models, since we cannot hold all of the models' weights in memory. We may need a process that supervises the models so that when a model is called we know whether it is currently loaded in memory.

We don't necessarily need to solve both these problems now, and they can be turned into issues.

Sequences

Sequences are objects that represent a sequence-generation model; when called with a prompt or a list of prompts they generate the sequence(s). Here we implement the simplest possible sequence: completion until an EOS token is found. The proposed API is as follows:

import outlines.models as models
import outlines.text as text

model = models.transformers("gpt2")                    # local Transformers model
completion = text.completion(model)("say something")   # generate until EOS is found
new_completion = text.completion(model)(completion)    # chain a second completion

When sampling generations we can simply return a (list of) string(s), but for more advanced generation mechanisms we will need to return a state object that contains both the completion and the corresponding probabilities. This also means that we will need to unpack/repack the state when calling Python functions in the middle of a chain.
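For instance, the state could look something like this (a hypothetical sketch; names are not final):

from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class GenerationState:
    # Hypothetical container for advanced generation mechanisms (names not final).
    token_ids: np.ndarray    # shape (n_samples, n_tokens)
    completions: List[str]   # decoded text, one per sample
    logprobs: np.ndarray     # cumulative log-probability of each sample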

Local models and API calls

Models based on API calls are less flexible; for instance, we cannot fully shape the proposal distribution (when we can shape it at all) and we don't have access to the logits. They have a different interface and will thus need to inherit from a different base class.

This also means that the Sequence implementation needs to be different; to make this transparent to the user we implement a completion function which dispatches to Completion when the logits are available, and to a custom implementation when they are not. It may fail when some generation constraints cannot be applied through the API.
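Roughly, the dispatch could look like the following sketch (`has_logits` and `APICompletion` are hypothetical names used for illustration, not part of this PR):

class Completion:           # implemented in this PR (simplified stub here)
    def __init__(self, model, **kwargs): ...

class APICompletion:        # hypothetical counterpart for API-based models
    def __init__(self, model, **kwargs): ...

def completion(model, **kwargs):
    # Dispatch: local models expose logits, API models do not.
    if getattr(model, "has_logits", False):  # `has_logits` is an illustrative attribute
        return Completion(model, **kwargs)   # local model: full control over the logits
    return APICompletion(model, **kwargs)    # API model: constraints applied best-effort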

@rlouf force-pushed the add-token-rv branch 2 times, most recently from 9fa796f to 137c64a on June 9, 2023 11:49
@rlouf marked this pull request as ready for review on June 9, 2023 14:20
@rlouf force-pushed the add-token-rv branch 4 times, most recently from 361c022 to f73a41e on June 15, 2023 09:07
@rlouf changed the title from "Add LLMs random variable interface" to "Add Transformers model and Completion sequence generation" on Jun 15, 2023
@rlouf force-pushed the add-token-rv branch 2 times, most recently from 76879da to cae171b on June 15, 2023 09:45
@rlouf (Member, Author) commented Jun 15, 2023

Left to implement:

Models

  • Add cache to the __call__ method and persist to disk
  • Add cache for kv values
  • Abstract tokenizer calls to unify interface and add TransformersTokenizer class

Completion

  • Add Sequence parent class, which contains the step and __call__ implementations
  • Test vectorized_choice
  • Test Sequence's step with mock model and tokenizers
  • Test Sequence's __call__ with mock model and tokenizers
  • Add samples argument to Sequence's __call__ method
  • Implement max_tokens constraints in Sequence
  • Test is_finished from Completion in isolation
  • Update the attention mask

@thomasahle commented

> The k-v cache will require some careful thought, and will likely have to be customized to the particular class of models for which it is implemented (transformers, llama.cpp, etc.). My initial thought is to build a trie that we query each time the model is called; this cache is attached to the model instance.

It might be easier to allow the user to handle the caching. That way the model doesn't have to save all the text it has ever been prompted with. E.g.

prompt = "Here's a bunch of examples"
_text, kv_cache = model.generate(prompt, tokens=0, return_kv=True)  # Don't generate any text, just return the kv-cache
output1 = model.generate("", context=kv_cache)  # Use the kv-cache from the previous generation
output2 = model.generate("", context=kv_cache)  # Use the kv-cache from the original generation

I think this is a bit similar to how Hugging Face handles kv-caches right now.

It would also allow using a kv_cache that has been "learned" by fine-tuning, rather than representing an actual prefix.

At some point we'd like to give the model access to a vector db of "external" key-value pairs. I wonder if you are interested in having such a feature in outlines as well.

@rlouf (Member, Author) commented Jun 27, 2023

> I couldn't find documentation for this code. Is there any way to generate the docs?

It's not documented yet; the API might still change.

> I find a large difference in native hf generation vs this api generation. Is it because of the constant device-cpu tensor-NDArray conversions?

Could be, or I'm doing something wrong with the arrays. Given the size of the token id arrays I didn't think it would make a huge difference, but I will need to benchmark this.

One question before I start benchmarking: did you set max_tokens to the same value as the one in transformers' generate?

> Can we also provide a warning message if the number of tokens is not specified? This would not terminate the while True loop if a stopping_criterion is also not specified.

The default stopping criterion is EOS, which I think GPT2 does not have? It may be a good idea to set a default value for max_tokens or a warning as you suggested.
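Something along these lines could work (an illustrative sketch; `max_tokens` and `tokenizer` are stand-ins for the corresponding Sequence attributes, not code from this PR):

import warnings

# Illustrative sketch: warn when generation may not terminate.
if max_tokens is None and tokenizer.eos_token_id is None:
    warnings.warn(
        "`max_tokens` is not set and the tokenizer has no EOS token; "
        "generation may never terminate."
    )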

@rlouf (Member, Author) commented Jun 27, 2023

I timed Outlines and transformers using the following code. I include encoding/decoding for a fair comparison:

def test_time_outlines():
    import time
    import outlines.models as models
    from outlines.text.sequences.continuation import continuation

    now = time.time()
    model = models.transformers("gpt2")
    sequence = continuation(model, max_tokens=100)("This is a prompt")
    print(f"Outlines: {time.time()-now:.2f}")

def test_time_hf():
    import time
    from transformers import AutoModelForCausalLM, AutoTokenizer

    now = time.time()
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    inputs = tokenizer(["This is a prompt"], return_tensors="pt")
    out = model.generate(**inputs, do_sample=True, max_new_tokens=100)
    sequence = tokenizer.batch_decode(out)
    print(f"HuggingFace: {time.time()-now:.2f}")

On CPU it is a wash; however, setting samples=10 and num_return_sequences=10 respectively yields an order of magnitude difference. The reason is that we currently don't cache the KV values. Indeed, the following functions take the same time to run in Outlines and using generate:

def test_time_outlines():
    import time

    import outlines.models as models
    from outlines.text.sequences.continuation import continuation

    now = time.time()
    model = models.transformers("gpt2")
    sequence = continuation(model, max_tokens=100)("This is a prompt", samples=10)
    print(f"Outlines: {time.time()-now:.2f}")


def test_time_hf():
    import time

    from transformers import AutoModelForCausalLM, AutoTokenizer

    now = time.time()
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    inputs = tokenizer(["This is a prompt"], return_tensors="pt")
    out = model.generate(**inputs, do_sample=True, max_new_tokens=100, num_return_sequences=10, use_cache=False)
    sequence = tokenizer.batch_decode(out)
    print(f"HuggingFace: {time.time()-now:.2f}")

Passing the KV cache around is thus very important. However, I think managing the cache won't change the design introduced here dramatically, and it would add to an already big PR. #150 already tracks this issue. I therefore suggest we address it separately, if anything so that it can be solved in parallel with #151 - #156, which are also blocked by this PR.

@arunpatro Would you mind comparing the runtimes on GPU with use_cache=False in transformers?

Update: Using torch.no_grad I find that (when cache is disabled in generate) generation in outlines is ~7% faster on average.

@arunpatro (Contributor) commented

> @arunpatro Would you mind comparing the runtimes on GPU with use_cache=False in transformers?

Yeah surely.

> Update: Using torch.no_grad I find that (when cache is disabled in generate) generation in outlines is ~7% faster on average.

What is the intuition here? How can outlines be faster than native, considering we do a lot of work on top? Are we avoiding extra boilerplate code that HF has inside their generate function?

@rlouf (Member, Author) commented Jun 27, 2023

> What is the intuition here? How can outlines be faster than native, considering we do a lot of work on top? Are we avoiding extra boilerplate code that HF has inside their generate function?

We do fewer things than native generation in completion: I don't use the models' generate method but call their forward method to get the logits and then do the sampling myself. On top of avoiding the boilerplate you mentioned, there are several other benefits:

  1. To draw several samples, transformers duplicates the input as many times as there are samples and then does the forward pass with this batched input. Outlines runs the forward pass with the single input and then draws several samples using the logits (see the sketch at the end of this comment);
  2. When doing batched inference, transformers keeps feeding finished sequences to the model; Outlines only feeds the unfinished sequences.

The reason I decouple the model call / sampling this way is to be able to support different model providers and sampling methods (like SMC) in the future.
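For point 1, the sampling step roughly amounts to the following (simplified, self-contained sketch, not the actual implementation):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Simplified sketch: a single forward pass over the prompt, then several
# next-token samples drawn from the same distribution.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
token_ids = tokenizer(["This is a prompt"], return_tensors="pt")["input_ids"]

with torch.no_grad():
    logits = model(token_ids).logits[:, -1, :]                            # one forward pass
probs = torch.softmax(logits, dim=-1)
samples = torch.multinomial(probs, num_samples=10, replacement=True)      # 10 samples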

@arunpatro (Contributor) commented Jun 27, 2023

I see, that makes sense.

> The reason I decouple the model call / sampling this way is to be able to support different model providers and sampling methods (like SMC) in the future.

I like this design decision of decoupling the model and the generation process.

Question 1: Why does the outlines code not load the model onto the GPU automatically? According to the code, if device=None it should be auto-inferred, right? Instead of device=None, maybe we can keep the default as device='auto', which would convey the meaning more clearly.

Question 2: Why are we moving back and forth between torch and numpy arrays? I modified this branch to only use torch.Tensors and it can be a drop-in replacement (except on MPS devices, because torch.searchsorted is not implemented for MPS).

I also ran your test cases and I can confirm that outlines is faster on CUDA for gpt2; however, I cannot say the same for other models.

# model_name = "gpt2"
model_name = "togethercomputer/RedPajama-INCITE-Instruct-3B-v1"
prompt = "This is a prompt"

def test_time_outlines():
    import time

    import outlines.models as models
    from outlines.text.sequences.continuation import continuation

    now = time.time()
    model = models.transformers(model_name)
    sequence = continuation(model, max_tokens=100)(prompt, samples=10)
    print(f"Outlines: {time.time()-now:.2f}")


def test_time_hf():
    import time

    from transformers import AutoModelForCausalLM, AutoTokenizer

    now = time.time()
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer([prompt], return_tensors="pt")
    out = model.generate(**inputs, do_sample=True, max_new_tokens=100, num_return_sequences=10, use_cache=False)
    sequence = tokenizer.batch_decode(out)
    print(f"HuggingFace: {time.time()-now:.2f}")
    
  
def test_time_outlines_cuda():
    import time

    import outlines.models as models
    from outlines.text.sequences.continuation import continuation

    now = time.time()
    model = models.transformers(model_name, 'cuda')
    sequence = continuation(model, max_tokens=100)(prompt, samples=10)
    print(f"Outlines CUDA: {time.time()-now:.2f}")

 
def test_time_hf_cuda():
    import time

    from transformers import AutoModelForCausalLM, AutoTokenizer

    now = time.time()
    model = AutoModelForCausalLM.from_pretrained(model_name).to('cuda')
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer([prompt], return_tensors="pt")['input_ids'].to('cuda')
    out = model.generate(inputs, do_sample=True, max_new_tokens=100, num_return_sequences=10, use_cache=False)
    sequence = tokenizer.batch_decode(out)
    print(f"HuggingFace CUDA: {time.time()-now:.2f}")

The times show that HF is faster than Outlines on both CUDA and CPU (for use_cache=False):

HuggingFace CUDA: 31.83
Outlines CUDA: 47.30
HuggingFace: 88.90
Outlines: 230.86

Tested on an NVIDIA A100-SXM4-40GB with CUDA 11.8

@rlouf (Member, Author) commented Jun 27, 2023

Thanks! Are those average times over several runs? I'm asking because with any other model than GPT2 the sequence is likely to terminate before max_tokens is reached (when EOS is found).

A possibility for the CUDA case is that with a bigger vocabulary we're paying for the memory transfer, which can be avoided by delegating next-token-sampling to the model class as well. This way we're only transferring a few tokens instead of the full logits. The reason for keeping NumPy is to keep as much generality as we can if we want to include llamacpp, JAX or TF transformers models and do logits manipulation to change the next-token proposal distribution. We need to evaluate the extent to which this is a performance bottleneck.

I don't have an explanation for the CPU case...

PS: do you need an A100 for a 3B model?

@arunpatro (Contributor) commented

> Thanks! Are those average times over several runs? I'm asking because with any other model than GPT2 the sequence is likely to terminate before max_tokens is reached (when EOS is found).

No, these are not averaged times, but I ran this multiple times and the results are similar.

> PS: do you need an A100 for a 3B model?

Not really, but I am experimenting with models up to 13B, and sometimes I need a lot of GPU RAM because of batch sizes (vLLM requirements, etc.).

@rlouf (Member, Author) commented Jun 28, 2023

> Question 1: Why does the outlines code not load the model onto the GPU automatically? According to the code, if device=None it should be auto-inferred, right? Instead of device=None, maybe we can keep the default as device='auto', which would convey the meaning more clearly.

Because HF models are loaded on CPU by default, and you need to explicitly move them to another device. I'm following their conventions as much as I can. Do you think we should do it differently?

> Question 2: Why are we moving back and forth between torch and numpy arrays? I modified this branch to only use torch.Tensors and it can be a drop-in replacement (except on MPS devices, because torch.searchsorted is not implemented for MPS).

I have thought more about this, and there's probably no good reason not to use PyTorch as the default. This would allow us to keep a strict separation between model calls and logit manipulation without moving memory around. Arrays output by llamacpp can be converted into torch.Tensor at virtually no cost, and apparently so can JAX and TF arrays (although that's slightly more complicated).
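For example (an illustrative sketch):

import numpy as np
import torch

# Converting a NumPy array (e.g. logits returned by llama.cpp) to a
# torch.Tensor is zero-copy; both objects share the same memory.
logits_np = np.random.rand(1, 50257).astype(np.float32)
logits_pt = torch.from_numpy(logits_np)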

Since switching to PyTorch requires a little more exploration and doesn't change the overall structure of the code introduced here, I suggest we merge this PR and track this in #164. Wdyt @brandonwillard?

@brandonwillard (Contributor) commented

> Since switching to PyTorch requires a little more exploration and doesn't change the overall structure of the code introduced here, I suggest we merge this PR and track this in #164. Wdyt @brandonwillard?

Yeah, let's do that in a quick follow-up.

@arunpatro (Contributor) commented

> Because HF models are loaded on CPU by default, and you need to explicitly move them to another device. I'm following their conventions as much as I can. Do you think we should do it differently?

No, this is good. We should stick to HF defaults wherever we can. The correct and fastest way to load a model onto a device is to pass device_map='auto', which HF also accepts via **kwargs.
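For example (an illustrative sketch; requires the accelerate package):

from transformers import AutoModelForCausalLM

# `device_map="auto"` lets HF/accelerate place the weights on the available
# devices automatically.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")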

@brandonwillard (Contributor) left a review comment

Minor questions and comments; otherwise, I'm good with moving forward on this.

Review threads (all resolved) on outlines/models/tokenizer.py and outlines/text/generate/sequence.py.
@rlouf force-pushed the add-token-rv branch 2 times, most recently from b487a10 to 24ef978 on June 29, 2023 06:17
@rlouf merged commit 759ce89 into outlines-dev:main on Jun 29, 2023 (4 checks passed)
@rlouf deleted the add-token-rv branch on June 29, 2023 07:57