diff --git a/outlines/integrations/vllm.py b/outlines/integrations/vllm.py index a70f0e3e8..6ed56d71b 100644 --- a/outlines/integrations/vllm.py +++ b/outlines/integrations/vllm.py @@ -100,9 +100,7 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: # Initialize the FSM state dictionary if the input_ids are empty, as this means # that the input_ids are the first tokens of the sequence. - if len(input_ids) == 0: - self._fsm_state = defaultdict(int) - else: + if len(input_ids) > 0: last_token = input_ids[-1] last_seq_id = hash(tuple(input_ids[:-1])) self._fsm_state[seq_id] = self.fsm.get_next_state(