Use PyTorch instead of NumPy #167

Merged
merged 1 commit into outlines-dev:main from replace-numpy-with-pytorch on Jul 6, 2023

Conversation

rlouf
Member

@rlouf rlouf commented Jun 29, 2023

Closes #164. Closes #165.

@rlouf rlouf force-pushed the replace-numpy-with-pytorch branch from 687c7d4 to 6f253bc on June 29, 2023 at 15:56
@rlouf rlouf marked this pull request as ready for review June 29, 2023 15:57
@arunpatro
Contributor

In general, there are many places where tensors are not moved to the right device, like is_finished. I couldn't fix all of them as it got messy. Perhaps you can test it locally if you have an MPS device.
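A minimal sketch of the kind of fix this refers to (the device selection below is illustrative, not code from this branch):

import torch

# Tensors created during generation (e.g. the is_finished mask) must live on
# the same device as the model's weights, otherwise every step triggers a
# host<->device transfer or an outright device-mismatch error.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
is_finished = torch.zeros((1,), dtype=torch.bool, device=device)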

I just tested this branch vs. the main branch on CPU:

# model_name = "gpt2"
model_name = "togethercomputer/RedPajama-INCITE-Instruct-3B-v1"
prompt = "This is a prompt"

def test_time_outlines():
    import time

    import outlines.models as models
    from outlines.text import continuation

    now = time.time()
    model = models.transformers(model_name, 'cpu')
    sequence = continuation(model, max_tokens=100)(prompt, samples=10)
    print(f"Outlines: {time.time()-now:.2f}")


def test_time_hf():
    import time

    from transformers import AutoModelForCausalLM, AutoTokenizer

    now = time.time()
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer([prompt], return_tensors="pt")
    out = model.generate(**inputs, do_sample=True, min_new_tokens=100, max_new_tokens=100, num_return_sequences=10, use_cache=False)
    print(out.shape)
    sequence = tokenizer.batch_decode(out)
    print(f"HuggingFace: {time.time()-now:.2f}")
   

HuggingFace: 100s
Outlines: 412s (this branch)
Outlines: 437s (main branch)

There is a slight advantage to using only torch.

Tested on an NVIDIA RTX A6000 with 14 vCPUs.

@brandonwillard
Contributor

Can we get some high-level (line)profiling measures instead? That would be a much more direct approach.

@arunpatro
Contributor

Can we get some high-level (line)profiling measures instead? That would be a much more direct approach.

What does line profiling mean? Do you mean to inspect inside the model.generate or continuation(model, max_tokens=100)(prompt, samples=10) to check which tensor operations are bottlenecked?

@brandonwillard
Contributor

brandonwillard commented Jun 29, 2023

Can we get some high-level (line)profiling measures instead? That would be a much more direct approach.

What does line profiling mean? Do you mean to inspect inside the model.generate or continuation(model, max_tokens=100)(prompt, samples=10) to check which tensor operations are bottlenecked?

Something like the output of line_profiler (limited to outlines source, of course) to start with, for example.

The standard Python profiler might also be a good next step, since we really just need to get an idea of exactly where to focus.
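For instance, the usual standard-library pattern applied to the benchmark above (a generic cProfile sketch, not something already in this PR):

import cProfile
import pstats

# Profile the whole benchmark and print the 20 most expensive calls
cProfile.run("test_time_outlines()", "outlines.prof")
pstats.Stats("outlines.prof").sort_stats("cumulative").print_stats(20)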

I'm assuming that such profiler output will help tell us (implicitly) exactly where/when those missed/poorly managed transfers are happening, as well as anything else we might not have guessed (e.g. we're measuring one-time cost operations that aren't priorities).

@rlouf
Member Author

rlouf commented Jun 30, 2023

Also, if you could seed the generation (so we can reproduce it) and normalise by the length of the longest sequence, that would be great.

@arunpatro
Contributor

arunpatro commented Jun 30, 2023

import torch
import numpy
import random

SEED = 42

def set_seed(seed: int):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    numpy.random.seed(seed)
    random.seed(seed)
    
    
model_name = "gpt2"
# model_name = "togethercomputer/RedPajama-INCITE-Instruct-3B-v1"
prompt = "This is a prompt"

def test_time_outlines():
    set_seed(SEED)
    import time

    import outlines.models as models
    from outlines.text import continuation

    now = time.time()
    model = models.transformers(model_name, 'cpu')
    sequence = continuation(model, max_tokens=100)(prompt, samples=10)
    print(sequence[0])
    print(f"Outlines: {time.time()-now:.2f}")


def test_time_hf():
    set_seed(SEED)
    import time

    from transformers import AutoModelForCausalLM, AutoTokenizer

    now = time.time()
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer([prompt], return_tensors="pt")
    out = model.generate(**inputs, do_sample=True, min_new_tokens=100, max_new_tokens=100, num_return_sequences=10, use_cache=False)
    sequence = tokenizer.batch_decode(out)
    print(sequence[0])
    print(f"HuggingFace: {time.time()-now:.2f}")

Seeding reproduces results for HF but not for Outlines. I suspect this is because outlines instantiates torch.Generator() as a default argument in the call signature, so the seed is never taken into account. A simple fix is to avoid using rng at all and just call torch.rand(...).
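A possible sketch (untested on this branch) of passing an explicitly seeded generator instead of relying on the default torch.Generator() in the signature, with model and continuation as in test_time_outlines above:

import torch

rng = torch.Generator()
rng.manual_seed(SEED)  # SEED defined above

# Pass the seeded generator through the call rather than the implicit default
sequence = continuation(model, max_tokens=100)(prompt, samples=10, rng=rng)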

All sequences also produce exactly the same number of new tokens because min_new_tokens=100 and max_new_tokens=100 are set. Normalizing makes sense if we only time the generation process.
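One way to do that normalization (a hypothetical helper, not part of the benchmarks above) is to report time per generated token:

import time

def time_per_token(generate, num_sequences: int, tokens_per_sequence: int) -> float:
    """Run generate() and return the elapsed time per generated token, in seconds."""
    start = time.time()
    generate()
    return (time.time() - start) / (num_sequences * tokens_per_sequence)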

@arunpatro
Contributor

arunpatro commented Jun 30, 2023

I'm not fully sure how you want me to do the line profiling; maybe you can share an MWE? Here is something:

Code:

from line_profiler import LineProfiler
from outlines.text.generate.sequence import Sequence
profiler = LineProfiler()
profiler.add_function(test_time_outlines)
profiler.add_function(Sequence.step)
profiler.add_function(Sequence.expand_attention_mask)
profiler.add_function(Sequence.__call__)
profiler.run('test_time_outlines()')
profiler.print_stats()
Output:
File: /home/ubuntu/rloufout/outlines/text/generate/sequence.py
Function: step at line 33

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    33                                               def step(
    34                                                   self,
    35                                                   rng: torch.Generator,
    36                                                   token_ids: torch.LongTensor,
    37                                                   attention_mask: torch.LongTensor,
    38                                                   samples: int = 1,
    39                                               ) -> Tuple[torch.LongTensor, torch.FloatTensor]:
    40                                                   """Generate one or several tokens that complete the input sequence.
    41                                           
    42                                                   The sampling step consists in using a model to generate next-token
    43                                                   logits and then sample `samples`-many new tokens from a categorical
    44                                                   distribution parametrized by these logits.
    45                                           
    46                                                   Parameters
    47                                                   ----------
    48                                                   rng
    49                                                       NumPy random number Generator instance
    50                                                   token_ids
    51                                                       The token ids passed as an input to the model, of shape `batch_shape
    52                                                       + (num_tokens,)`, where `num_tokens` is the sequences' length.
    53                                                   samples
    54                                                       The number of continuations to sample from the next-token probability
    55                                                       distribution.
    56                                           
    57                                                   Returns
    58                                                   -------
    59                                                   A tuple with an array of shape `new_batch_shape + (num_tokens+1,)`that
    60                                                   contains the completed sequences (input token ids and generated token
    61                                                   ids) and an array of shape `new_batch_shape + (vocab_size,)` that
    62                                                   contains the next token probabilities.
    63                                                   `new_batch_shape` is computed by removing dimensions of size one in
    64                                                   `(samples,) + batch_shape`.
    65                                           
    66                                                   """
    67       100     740350.0   7403.5      0.0          num_input_dims = token_ids.ndim
    68       100 24723072973.0 247230729.7     99.4          probs = self.model(token_ids, attention_mask)
    69                                           
    70                                                   # Sample `samples`-many new tokens
    71       100  136680682.0 1366806.8      0.5          next_token_ids = vectorized_random_choice(rng, probs, samples)
    72                                           
    73                                                   # Add the missing `num_tokens` and `num_sample` dimensions
    74       100    1160340.0  11603.4      0.0          next_token_ids = torch.unsqueeze(next_token_ids, -1)
    75       100     854830.0   8548.3      0.0          token_ids = torch.unsqueeze(token_ids, 0)
    76                                           
    77                                                   # Expand the input `token_ids` array to be able to concatenate several
    78                                                   # samples.
    79        99     666521.0   6732.5      0.0          if samples > 1:
    80         1       8030.0   8030.0      0.0              repetitions = (samples,) + (1,) * num_input_dims
    81         1     191150.0 191150.0      0.0              token_ids = torch.tile(token_ids, repetitions)
    82         1     756511.0 756511.0      0.0              probs = torch.tile(probs, repetitions)
    83                                           
    84       100    4581333.0  45813.3      0.0          token_ids = torch.concatenate([token_ids, next_token_ids], axis=-1)
    85                                           
    86                                                   # Merge sample and batch dimensions by removing dimensions of length
    87                                                   # 1. The shape of the resulting arrays is `new_batch_shape + (num_tokens,)`
    88                                                   # and `new_batch_shape + (vocab_size,)` respectively.
    89       100    2205500.0  22055.0      0.0          token_ids = torch.atleast_2d(token_ids.squeeze())
    90       100    1473930.0  14739.3      0.0          probs = torch.atleast_2d(probs.squeeze())
    91                                           
    92       100     654513.0   6545.1      0.0          return token_ids, probs

Total time: 0.0064726 s
File: /home/ubuntu/rloufout/outlines/text/generate/sequence.py
Function: expand_attention_mask at line 94

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    94                                               def expand_attention_mask(
    95                                                   self, attention_mask: torch.LongTensor
    96                                               ) -> torch.LongTensor:
    97                                                   """Expand the attention mask after the last completion."""
    98       100     919580.0   9195.8     14.2          batch_shape = attention_mask.shape[:-1]
    99       100    2206791.0  22067.9     34.1          attention_mask = torch.concatenate(
   100       100    2017451.0  20174.5     31.2              [attention_mask, torch.broadcast_to(torch.tensor([1]), batch_shape + (1,))],
   101       100     684470.0   6844.7     10.6              axis=-1,
   102                                                   )
   103       100     644310.0   6443.1     10.0          return attention_mask

Total time: 24.9459 s
File: /home/ubuntu/rloufout/outlines/text/generate/sequence.py
Function: __call__ at line 155

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   155                                               def __call__(
   156                                                   self,
   157                                                   prompt: Union[str, List[str]],
   158                                                   samples: int = 1,
   159                                                   rng: torch.Generator = torch.Generator(),
   160                                               ) -> Union[str, List[str]]:
   161                                                   """Generate a new sequence given a prompt.
   162                                           
   163                                                   Parameters
   164                                                   ----------
   165                                                   prompt
   166                                                       The input prompt.
   167                                                   samples
   168                                                       The number of samples to generate for each prompt.
   169                                           
   170                                                   Returns
   171                                                   -------
   172                                                   The full sequence that contains the prompts and the generated string.
   173                                           
   174                                                   """
   175         1     298490.0 298490.0      0.0          token_ids, attention_mask = self.model.tokenizer.encode(prompt)
   176         1       7620.0   7620.0      0.0          num_prompt_tokens = token_ids.shape[-1]
   177                                           
   178         1       6640.0   6640.0      0.0          if samples > 1:
   179         1  224162180.0 224162180.0      0.9              token_ids, _ = self.step(rng, token_ids, attention_mask, samples)
   180         1      74730.0  74730.0      0.0              is_finished = self.is_finished(token_ids)
   181                                           
   182         1       7130.0   7130.0      0.0              num_batch_dims = token_ids.ndim - 1
   183         1       6640.0   6640.0      0.0              repetitions = (samples,) + (1,) * num_batch_dims
   184         1      22210.0  22210.0      0.0              attention_mask = torch.tile(attention_mask, repetitions)
   185         1     137590.0 137590.0      0.0              attention_mask = self.expand_attention_mask(attention_mask)
   186                                                   else:
   187                                                       batch_shape = token_ids.shape[:-1]
   188                                                       is_finished = torch.zeros(batch_shape, dtype=torch.bool)
   189                                           
   190       100     729560.0   7295.6      0.0          while True:
   191       100     792760.0   7927.6      0.0              num_generated_tokens = token_ids.shape[-1] - num_prompt_tokens
   192        99    2308561.0  23318.8      0.0              if torch.all(is_finished) or num_generated_tokens == self.max_tokens:
   193         1       6640.0   6640.0      0.0                  break
   194                                           
   195        99    2936623.0  29662.9      0.0              token_ids_unfinished = token_ids[~is_finished]
   196        99    2217671.0  22400.7      0.0              attention_mask_unfinished = attention_mask[~is_finished]
   197        99 24667614933.0 249167827.6     98.9              token_ids_unfinished, _ = self.step(
   198        99     611311.0   6174.9      0.0                  rng, token_ids_unfinished, attention_mask_unfinished
   199                                                       )
   200                                           
   201        99   15343259.0 154982.4      0.1              token_ids = self.update_token_ids(
   202        99     689170.0   6961.3      0.0                  is_finished, token_ids, token_ids_unfinished
   203                                                       )
   204        99   19206879.0 194008.9      0.1              attention_mask = self.expand_attention_mask(attention_mask)
   205        99    7651544.0  77288.3      0.0              is_finished[~is_finished] = self.is_finished(token_ids_unfinished).flatten()
   206                                           
   207         1     995861.0 995861.0      0.0          result = self.model.tokenizer.decode(token_ids)
   208         1      22280.0  22280.0      0.0          result = self.postprocess_completions(result)
   209                                           
   210         1       7610.0   7610.0      0.0          if len(result) == 1:
   211                                                       return result[0]
   212                                           
   213         1       6560.0   6560.0      0.0          return result

Total time: 26.8777 s
File: /tmp/ipykernel_81140/1788441446.py
Function: test_time_outlines at line 20

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================

@rlouf
Member Author

rlouf commented Jul 1, 2023

That's perfect, thanks! It looks like 99% of the time is spent in self.model (which is what we'd expect), so we need to figure out what's happening there and/or if the bottleneck is memory transfer. You can take a look at scalene for that.

In the meantime I'm going through the code line by line making sure that all tensors are on the same device.

@rlouf
Member Author

rlouf commented Jul 3, 2023

I re-ran the line_profiler locally, first with samples=1:

import outlines.models as models
import outlines.text as text
from line_profiler import LineProfiler
from outlines.text.generate.sequence import Sequence
from outlines.text.generate.continuation import Continuation


model_name = "togethercomputer/RedPajama-INCITE-Instruct-3B-v1"
prompt = "This is a prompt"


def outlines_run():
    model = models.transformers(model_name, device="cuda")
    a = text.generate.continuation(model, max_tokens=100)(prompt)


profiler = LineProfiler()
profiler.add_function(outlines_run)
profiler.add_function(Sequence.step.__wrapped__)
profiler.add_function(Sequence.expand_attention_mask.__wrapped__)
profiler.add_function(Sequence.update_token_ids.__wrapped__)
profiler.add_function(Sequence.__call__.__wrapped__)
profiler.add_function(Continuation.is_finished)
profiler.run('outlines_run()')
profiler.print_stats()

And the results are similar to what you obtained:

Full output
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    34                                               @torch.no_grad()
    35                                               def step(
    36                                                   self,
    37                                                   rng: torch.Generator,
    38                                                   token_ids: torch.LongTensor,
    39                                                   attention_mask: torch.LongTensor,
    40                                                   samples: int = 1,
    41                                               ) -> Tuple[torch.LongTensor, torch.FloatTensor]:
    42                                                   """Generate one or several tokens that complete the input sequence.
    43                                           
    44                                                   The sampling step consists in using a model to generate next-token
    45                                                   logits and then sample `samples`-many new tokens from a categorical
    46                                                   distribution parametrized by these logits.
    47                                           
    48                                                   Parameters
    49                                                   ----------
    50                                                   rng
    51                                                       NumPy random number Generator instance
    52                                                   token_ids
    53                                                       The token ids passed as an input to the model, of shape `batch_shape
    54                                                       + (num_tokens,)`, where `num_tokens` is the sequences' length.
    55                                                   samples
    56                                                       The number of continuations to sample from the next-token probability
    57                                                       distribution.
    58                                           
    59                                                   Returns
    60                                                   -------
    61                                                   A tuple with an array of shape `new_batch_shape + (num_tokens+1,)`that
    62                                                   contains the completed sequences (input token ids and generated token
    63                                                   ids) and an array of shape `new_batch_shape + (vocab_size,)` that
    64                                                   contains the next token probabilities.
    65                                                   `new_batch_shape` is computed by removing dimensions of size one in
    66                                                   `(samples,) + batch_shape`.
    67                                           
    68                                                   """
    69       100      66675.0    666.8      0.0          num_input_dims = token_ids.ndim
    70       100 10274542947.0 102745429.5     99.9          probs = self.model(token_ids, attention_mask)
    71                                           
    72                                                   # Sample `samples`-many new tokens
    73       100    9679926.0  96799.3      0.1          next_token_ids = vectorized_random_choice(rng, probs, samples)
    74                                           
    75                                                   # Add the missing `num_tokens` and `num_sample` dimensions
    76       100     272250.0   2722.5      0.0          next_token_ids = torch.unsqueeze(next_token_ids, -1)
    77       100     142457.0   1424.6      0.0          token_ids = torch.unsqueeze(token_ids, 0)
    78                                           
    79                                                   # Expand the input `token_ids` array to be able to concatenate several
    80                                                   # samples.
    81       100      29935.0    299.4      0.0          if samples > 1:
    82                                                       repetitions = (samples,) + (1,) * num_input_dims
    83                                                       token_ids = torch.tile(token_ids, repetitions)
    84                                                       probs = torch.tile(probs, repetitions)
    85                                           
    86       100    1563555.0  15635.5      0.0          token_ids = torch.concatenate([token_ids, next_token_ids], axis=-1)
    87                                           
    88                                                   # Merge sample and batch dimensions by removing dimensions of length
    89                                                   # 1. The shape of the resulting arrays is `new_batch_shape + (num_tokens,)`
    90                                                   # and `new_batch_shape + (vocab_size,)` respectively.
    91       100     809526.0   8095.3      0.0          token_ids = torch.atleast_2d(token_ids.squeeze())
    92       100     501089.0   5010.9      0.0          probs = torch.atleast_2d(probs.squeeze())
    93                                           
    94       100      21946.0    219.5      0.0          return token_ids, probs

Total time: 0.00414091 s
File: /home/remi/projects/normal/outlines/outlines/text/generate/sequence.py
Function: expand_attention_mask at line 96

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    96                                               @torch.no_grad()
    97                                               def expand_attention_mask(
    98                                                   self, attention_mask: torch.LongTensor
    99                                               ) -> torch.LongTensor:
   100                                                   """Expand the attention mask after the last completion."""
   101       100     186026.0   1860.3      4.5          batch_shape = attention_mask.shape[:-1]
   102       100    1343915.0  13439.1     32.5          attention_mask = torch.concatenate(
   103       100      21965.0    219.7      0.5              [
   104       100      15429.0    154.3      0.4                  attention_mask,
   105       100     399853.0   3998.5      9.7                  torch.broadcast_to(
   106       100    2142039.0  21420.4     51.7                      torch.tensor([1], device=self.device), batch_shape + (1,)
   107                                                           ),
   108                                                       ],
   109       100      14390.0    143.9      0.3              axis=-1,
   110                                                   )
   111       100      17288.0    172.9      0.4          return attention_mask

Total time: 0.551257 s
File: /home/remi/projects/normal/outlines/outlines/text/generate/sequence.py
Function: update_token_ids at line 113

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   113                                               @torch.no_grad()
   114                                               def update_token_ids(
   115                                                   self,
   116                                                   is_finished: torch.BoolTensor,
   117                                                   token_ids: torch.LongTensor,
   118                                                   token_ids_unfinished: torch.LongTensor,
   119                                               ) -> torch.LongTensor:
   120                                                   """Update the array of token ids after the last completion.
   121                                           
   122                                                   We only generate new tokens for the sequences that are not finished. We thus
   123                                                   update the array with the new tokens, and append pad tokens to the finished
   124                                                   sequences.
   125                                           
   126                                                   Parameters
   127                                                   ----------
   128                                                   is_finished
   129                                                       Boolean array that indicates which sequences are finished.
   130                                                   token_ids
   131                                                       Array that contains the sequences before the generation's last step.
   132                                                   token_ids_unfinished
   133                                                       Array that contains the sequences of the unfinished sequences
   134                                                       after the generation's last step.
   135                                           
   136                                                   Returns
   137                                                   -------
   138                                                   An array that contains the updated array that contains the sequences. We append
   139                                                   pad tokens to the finished sequences.
   140                                           
   141                                                   """
   142       100     167455.0   1674.5      0.0          batch_shape = token_ids.shape[:-1]
   143       100      59216.0    592.2      0.0          num_tokens = token_ids.shape[-1]
   144       100     850508.0   8505.1      0.2          new_token_ids = torch.empty(
   145       100     138924.0   1389.2      0.0              batch_shape + (num_tokens + 1,), dtype=torch.int64, device=self.device
   146                                                   )
   147       100  537280326.0 5372803.3     97.5          token_ids_finished = token_ids[is_finished]
   148       100     156720.0   1567.2      0.0          batch_shape_finished = token_ids_finished.shape[:-1]
   149       100     677681.0   6776.8      0.1          token_ids_finished = torch.concatenate(
   150       100      20005.0    200.1      0.0              [
   151       100      15914.0    159.1      0.0                  token_ids_finished,
   152       100     529063.0   5290.6      0.1                  torch.broadcast_to(
   153       100    2356605.0  23566.0      0.4                      torch.tensor(
   154       100     143268.0   1432.7      0.0                          [self.model.tokenizer.pad_token_id], device=self.device
   155                                                               ),
   156       100      75104.0    751.0      0.0                      batch_shape_finished + (1,),
   157                                                           ),
   158                                                       ],
   159       100      16499.0    165.0      0.0              axis=-1,
   160                                                   )
   161                                           
   162       100    6001185.0  60011.8      1.1          new_token_ids[~is_finished] = token_ids_unfinished
   163       100    2742972.0  27429.7      0.5          new_token_ids[is_finished] = token_ids_finished
   164                                           
   165       100      25665.0    256.6      0.0          return new_token_ids

Total time: 10.8702 s
File: /home/remi/projects/normal/outlines/outlines/text/generate/sequence.py
Function: __call__ at line 167

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   167                                               @torch.no_grad()
   168                                               def __call__(
   169                                                   self,
   170                                                   prompt: Union[str, List[str]],
   171                                                   samples: int = 1,
   172                                                   rng: Optional[torch.Generator] = None,
   173                                               ) -> Union[str, List[str]]:
   174                                                   """Generate a new sequence given a prompt.
   175                                           
   176                                                   Parameters
   177                                                   ----------
   178                                                   prompt
   179                                                       The input prompt.
   180                                                   samples
   181                                                       The number of samples to generate for each prompt.
   182                                           
   183                                                   Returns
   184                                                   -------
   185                                                   The full sequence that contains the prompts and the generated string.
   186                                           
   187                                                   """
   188         1    1962430.0 1962430.0      0.0          token_ids, attention_mask = self.model.tokenizer.encode(prompt)
   189                                           
   190         1      71424.0  71424.0      0.0          token_ids = token_ids.to(self.device)
   191         1      10630.0  10630.0      0.0          attention_mask = attention_mask.to(self.device)
   192                                           
   193         1        161.0    161.0      0.0          if rng is None:
   194         1      17423.0  17423.0      0.0              rng = torch.Generator(device=self.device)
   195                                           
   196         1        761.0    761.0      0.0          num_prompt_tokens = token_ids.shape[-1]
   197                                           
   198         1        240.0    240.0      0.0          if samples > 1:
   199                                                       token_ids, _ = self.step(rng, token_ids, attention_mask, samples)
   200                                                       is_finished = self.is_finished(token_ids)
   201                                           
   202                                                       num_batch_dims = token_ids.ndim - 1
   203                                                       repetitions = (samples,) + (1,) * num_batch_dims
   204                                                       attention_mask = torch.tile(attention_mask, repetitions)
   205                                                       attention_mask = self.expand_attention_mask(attention_mask)
   206                                                   else:
   207         1       2033.0   2033.0      0.0              batch_shape = token_ids.shape[:-1]
   208         1     103334.0 103334.0      0.0              is_finished = torch.zeros(batch_shape, dtype=torch.bool, device=self.device)
   209                                           
   210                                                   while True:
   211       101     134723.0   1333.9      0.0              num_generated_tokens = token_ids.shape[-1] - num_prompt_tokens
   212       100    2695081.0  26950.8      0.0              if torch.all(is_finished) or num_generated_tokens == self.max_tokens:
   213         1        230.0    230.0      0.0                  break
   214                                           
   215       100    4638624.0  46386.2      0.0              token_ids_unfinished = token_ids[~is_finished]
   216       100    3841868.0  38418.7      0.0              attention_mask_unfinished = attention_mask[~is_finished]
   217       100 10289776263.0 102897762.6     94.7              token_ids_unfinished, _ = self.step(
   218       100      17093.0    170.9      0.0                  rng, token_ids_unfinished, attention_mask_unfinished
   219                                                       )
   220                                           
   221       100  553695694.0 5536956.9      5.1              token_ids = self.update_token_ids(
   222       100      17300.0    173.0      0.0                  is_finished, token_ids, token_ids_unfinished
   223                                                       )
   224       100    5447549.0  54475.5      0.1              attention_mask = self.expand_attention_mask(attention_mask)
   225       100    7548630.0  75486.3      0.1              is_finished[~is_finished] = self.is_finished(token_ids_unfinished).flatten()
   226                                           
   227         1     204724.0 204724.0      0.0          result = self.model.tokenizer.decode(token_ids)
   228         1       3406.0   3406.0      0.0          result = self.postprocess_completions(result)
   229                                           
   230         1        541.0    541.0      0.0          if len(result) == 1:
   231         1        271.0    271.0      0.0              return result[0]
   232                                           
   233                                                   return result

The interesting lines are below, where we see that Sequence.update_token_ids takes 5% of the time:

   217       100 10289776263.0 102897762.6     94.7              token_ids_unfinished, _ = self.step(
   218       100      17093.0    170.9      0.0                  rng, token_ids_unfinished, attention_mask_unfinished
   219                                                       )
   220                                           
   221       100  553695694.0 5536956.9      5.1              token_ids = self.update_token_ids(
   222       100      17300.0    173.0      0.0                  is_finished, token_ids, token_ids_unfinished
   223                                                       )

And most of the time is spent on this line:

   147       100  537280326.0 5372803.3     97.5          token_ids_finished = token_ids[is_finished]

which tells us that the boolean indexing that we currently use is very inefficient, and likely explains the difference that we see (at least on GPU) with transformers.

The difference with transformers when taking several samples is even worse (I measure 56s for outlines and 21s for transformers on average over 10 runs), and my hunch is that this has something to do with moving memory around.

A first approach is to simplify the code to remove boolean indexing entirely and do something equivalent to transformers (feed all input ids to the model at every step), get at least to parity, and take it from there.
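A minimal sketch of what that could look like (an assumption about the approach, not the code that ends up merged): keep every sequence in the batch at each step and mask the finished ones with torch.where instead of boolean indexing:

import torch

def append_next_tokens(token_ids, next_token_ids, is_finished, pad_token_id):
    """Append one sampled token per sequence, padding the sequences that are done."""
    next_token_ids = torch.where(
        is_finished,
        torch.full_like(next_token_ids, pad_token_id),
        next_token_ids,
    )
    return torch.cat([token_ids, next_token_ids.unsqueeze(-1)], dim=-1)

This keeps the hot loop free of data-dependent boolean indexing, at the cost of running the model on already-finished sequences.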

@rlouf
Member Author

rlouf commented Jul 4, 2023

For some reason use_cache=False does not work as intended in the transformers library for "togethercomputer/RedPajama-INCITE-Instruct-3B-v1", although it does for gpt2. This explains @arunpatro's and my observation that outlines is faster than transformers for GPT2 but much slower otherwise.

I added the following lines in transformers/generation/utils.py, right before the model is called, to make sure the cache is not used:

model_inputs.pop("past_key_values")  # drop the cached keys/values
model_inputs["input_ids"] = input_ids  # feed the full sequence at every step

On average over 10 runs, I find the following runtimes for 10 sampled sequences on my Quadro RTX 5000:

  • transformers: 59s ± 0.3s
  • outlines: 59s ± 0.3s

So it's a wash. Similarly, for a single sampled sequence, both libraries take 27s on average over 10 runs.

@rlouf rlouf force-pushed the replace-numpy-with-pytorch branch from 46d9c89 to 3a3eae7 on July 6, 2023 at 12:11
@rlouf rlouf requested a review from arunpatro July 6, 2023 12:40
@rlouf
Member Author

rlouf commented Jul 6, 2023

This is ready for review again.

@rlouf rlouf merged commit 293826c into outlines-dev:main Jul 6, 2023
4 checks passed
@rlouf rlouf deleted the replace-numpy-with-pytorch branch July 6, 2023 14:27