Fix broken commits
paul-grundmann committed Jun 14, 2024
1 parent d676bb9 commit 1385324
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions outlines/integrations/vllm.py
@@ -108,17 +108,18 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
             state=self._fsm_state[last_seq_id], token_id=last_token
         )
 
-        allowed_tokens = self.fsm.get_next_instruction(
-            state=self._fsm_state[seq_id]
-        ).tokens
-
-        cache_key = hash(tuple(allowed_tokens[:2048]))
-        if cache_key not in self.mask_cache:
-            mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
+        state_id = self._fsm_state[seq_id]
+        if state_id not in self.mask_cache:
+            allowed_tokens = self.fsm.get_next_instruction(
+                state=self._fsm_state[seq_id]
+            ).tokens
+            mask = torch.full((scores.shape[-1],), -math.inf)
             mask[allowed_tokens] = 0
-            self.mask_cache[cache_key] = mask
+            mask = mask.pin_memory()
+            self.mask_cache[state_id] = mask
         else:
-            mask = self.mask_cache[cache_key]
+            mask = self.mask_cache[state_id]
+        mask = mask.to(device=scores.device, non_blocking=True)
         biased_scores = scores + mask
 
         return biased_scores
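In short, the change keys the cached logit mask on the FSM state id instead of a hash of the (truncated) allowed-token list, builds each mask once on the CPU in pinned memory, and copies it to the scores device with a non-blocking transfer. Below is a minimal, self-contained sketch of that caching pattern; the StateMaskCache class, its bias method, and the vocab_size argument are illustrative names, not part of the Outlines or vLLM API.

import math
from typing import Dict, Sequence

import torch


class StateMaskCache:
    """Illustrative sketch: cache one bias mask per FSM state.

    Each mask is built once on the CPU (-inf everywhere, 0 for allowed
    tokens), pinned when CUDA is available so the host-to-device copy can
    run asynchronously, and reused for every sequence in the same state.
    """

    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size
        self.mask_cache: Dict[int, torch.Tensor] = {}

    def bias(
        self, state_id: int, allowed_tokens: Sequence[int], scores: torch.Tensor
    ) -> torch.Tensor:
        if state_id not in self.mask_cache:
            # Build the mask once per FSM state.
            mask = torch.full((self.vocab_size,), -math.inf)
            mask[list(allowed_tokens)] = 0
            if torch.cuda.is_available():
                # Pinned (page-locked) memory enables non-blocking copies.
                mask = mask.pin_memory()
            self.mask_cache[state_id] = mask
        # Copy to the scores device without blocking the host.
        mask = self.mask_cache[state_id].to(device=scores.device, non_blocking=True)
        return scores + mask

Keying on the FSM state avoids hashing potentially long allowed-token lists (and the collision risk of hashing only the first 2048 entries), while pinned memory plus non_blocking=True lets the host-to-device copy overlap with other GPU work.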
