Mask cache performance optimization for vLLM (#939)
## Problem
The current implementation allocates a mask for every token during
generation, which significantly impacts performance.

## Proposed Solution
To improve performance, we can cache the mask, as it depends only on the
tokens the FSM allows in a given state. Additionally, limiting the hash
function's input to the first 2k tokens results in a notable speedup.
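
The sketch below illustrates both ideas; the names `MAX_HASH_TOKENS`, `seq_key`, and `get_mask` are illustrative and not part of the commit. The actual change, shown in the diff below, keys the cache on the FSM state and keeps each cached mask in pinned host memory so it can be copied to the GPU asynchronously.

```python
import math

import torch

MAX_HASH_TOKENS = 2048  # illustrative constant for the "first 2k tokens" limit


def seq_key(input_ids):
    # Hash only a bounded prefix so the key computation does not grow with sequence length.
    return hash(tuple(input_ids[:MAX_HASH_TOKENS]))


mask_cache = {}


def get_mask(state_id, allowed_tokens, vocab_size, device):
    # Build the -inf/0 bias mask once per FSM state, then reuse it.
    if state_id not in mask_cache:
        mask = torch.full((vocab_size,), -math.inf)
        mask[allowed_tokens] = 0
        # Pinned host memory allows an asynchronous copy to the GPU.
        mask_cache[state_id] = mask.pin_memory() if torch.cuda.is_available() else mask
    return mask_cache[state_id].to(device=device, non_blocking=True)
```

Keying the cache on the FSM state rather than on the token sequence means every sequence that reaches the same state reuses the same mask.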

## Discussion
While using only the first 2k tokens for the hash can in principle introduce
cache collisions, the likelihood of such a collision is very low.
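
As a toy illustration of what such a collision would mean (not from the commit): two sequences that agree on their first 2048 tokens but diverge afterwards receive the same truncated key, even though their full hashes would almost surely differ.

```python
a = tuple(range(3000))
b = a[:2048] + (999_999,) * 952   # same first 2048 tokens, different tail

assert hash(a[:2048]) == hash(b[:2048])  # truncated keys collide
assert hash(a) != hash(b)                # full-sequence hashes differ (with overwhelming probability)
```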

## TODO
- [x] Provide measurements of the performance impact

---------

Co-authored-by: pgrundmann <pgrundmann@bht-berlin.de>
paul-grundmann and pgrundmann committed Jun 16, 2024
1 parent 61d1f43 commit 0c1935a
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions outlines/integrations/vllm.py
@@ -27,7 +27,7 @@

 import math
 from collections import defaultdict
-from typing import TYPE_CHECKING, DefaultDict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Type, Union

 import torch
 from pydantic import BaseModel
@@ -78,6 +78,7 @@ def __init__(self, regex_string: str, llm: "LLM"):
"`tokenizer` attribute or a `get_tokenizer` method."
)
tokenizer = adapt_tokenizer(tokenizer=tokenizer)
self.mask_cache: Dict[int, torch.Tensor] = {}
self.fsm = RegexGuide(regex_string, tokenizer)
self._fsm_state: DefaultDict[int, int] = defaultdict(int)

@@ -107,12 +108,18 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
                 state=self._fsm_state[last_seq_id], token_id=last_token
             )

-        allowed_tokens = self.fsm.get_next_instruction(
-            state=self._fsm_state[seq_id]
-        ).tokens
-
-        mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
-        mask[allowed_tokens] = 0
+        state_id = self._fsm_state[seq_id]
+        if state_id not in self.mask_cache:
+            allowed_tokens = self.fsm.get_next_instruction(
+                state=self._fsm_state[seq_id]
+            ).tokens
+            mask = torch.full((scores.shape[-1],), -math.inf)
+            mask[allowed_tokens] = 0
+            mask = mask.pin_memory()
+            self.mask_cache[state_id] = mask
+        else:
+            mask = self.mask_cache[state_id]
+        mask = mask.to(device=scores.device, non_blocking=True)
         biased_scores = scores + mask

         return biased_scores
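
For context, a typical way to attach this processor when generating with vLLM (a hedged sketch: the model name and regex are placeholders, and the exact vLLM API may vary between versions):

```python
from vllm import LLM, SamplingParams

from outlines.integrations.vllm import RegexLogitsProcessor

llm = LLM("mistralai/Mistral-7B-Instruct-v0.2")  # placeholder model
processor = RegexLogitsProcessor(r"[0-9]{4}-[0-9]{2}-[0-9]{2}", llm=llm)
params = SamplingParams(max_tokens=16, logits_processors=[processor])
outputs = llm.generate(["Today's date in ISO format: "], params)
print(outputs[0].outputs[0].text)
```

With the cache in place, every time generation revisits an FSM state the pinned mask for that state is reused instead of a new tensor being allocated per token.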
