Skip to content

Commit

Permalink
Restore FSM interface for backward compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Mar 12, 2024
1 parent 5d97ee1 commit d47bd6b
Show file tree
Hide file tree
Showing 2 changed files with 416 additions and 0 deletions.
69 changes: 69 additions & 0 deletions outlines/fsm/fsm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import warnings
from typing import TYPE_CHECKING, List, NewType

from outlines.fsm.guide import CFGGuide, RegexGuide, StopAtEOSGuide

if TYPE_CHECKING:
from outlines.models.tokenizer import Tokenizer

# Distinct static type used to annotate FSM state arguments and return values
# in the legacy interface below; `NewType` makes it a plain `int` at runtime.
FSMState = NewType("FSMState", int)


class StopAtEosFSM(StopAtEOSGuide):
    """FSM to generate text until EOS has been generated.

    Deprecated backward-compatibility wrapper around `StopAtEOSGuide` that
    re-exposes the legacy `allowed_token_ids` / `next_state` method names.
    Instantiating it emits a `UserWarning`.
    """

    def __init__(self, tokenizer: "Tokenizer"):
        """Warn about the deprecation, then initialise the underlying guide.

        Parameters
        ----------
        tokenizer
            Forwarded unchanged to `StopAtEOSGuide.__init__`.
        """
        # The message must name *this* class so users can find the call site
        # they need to migrate (it previously referred to `StopAtTokenFSM`).
        warnings.warn(
            UserWarning(
                "The `StopAtEosFSM` interface is deprecated and will be removed on 2024-06-01. Please use `StopAtEOSGuide` instead."
            )
        )
        super().__init__(tokenizer)

    def allowed_token_ids(self, state: FSMState) -> List[int]:
        """Return the `tokens` of the guide's next instruction for `state`.

        Legacy name for `get_next_instruction(state).tokens`.
        """
        next_instruction = self.get_next_instruction(state)
        return next_instruction.tokens

    def next_state(self, state: FSMState, token_id: int) -> FSMState:
        """Return the state reached from `state` after `token_id`.

        Legacy name for `get_next_state`; the result is wrapped in `FSMState`.
        """
        return FSMState(self.get_next_state(state, token_id))


class RegexFSM(RegexGuide):
    """Legacy FSM-style front end for `RegexGuide`.

    Kept for backward compatibility: it adds the old `allowed_token_ids` and
    `next_state` method names on top of the guide and emits a `UserWarning`
    whenever an instance is created.
    """

    def __init__(self, regex_string: str, tokenizer):
        """Emit the deprecation warning and delegate to `RegexGuide.__init__`.

        Parameters
        ----------
        regex_string
            Forwarded unchanged to the parent constructor.
        tokenizer
            Forwarded unchanged to the parent constructor.
        """
        deprecation_notice = UserWarning(
            "The `RegexFSM` interface is deprecated and will be removed on 2024-06-01. Please use `RegexGuide` instead."
        )
        warnings.warn(deprecation_notice)
        super().__init__(regex_string, tokenizer)

    def allowed_token_ids(self, state: FSMState) -> List[int]:
        """Return the `tokens` attribute of the next instruction for `state`."""
        return self.get_next_instruction(state).tokens

    def next_state(self, state: FSMState, token_id: int) -> FSMState:
        """Return, as an `FSMState`, the guide's next state after `token_id`."""
        following_state = self.get_next_state(state, token_id)
        return FSMState(following_state)


class CFGFSM(CFGGuide):
    """Legacy FSM-style front end for `CFGGuide`.

    Kept for backward compatibility: it adds the old `allowed_token_ids` and
    `next_state` method names on top of the guide and emits a `UserWarning`
    whenever an instance is created.
    """

    def __init__(self, cfg_string: str, tokenizer):
        """Emit the deprecation warning and delegate to `CFGGuide.__init__`.

        Parameters
        ----------
        cfg_string
            Forwarded unchanged to the parent constructor.
        tokenizer
            Forwarded unchanged to the parent constructor.
        """
        deprecation_notice = UserWarning(
            "The `CFGFSM` interface is deprecated and will be removed on 2024-06-01. Please use `CFGGuide` instead."
        )
        warnings.warn(deprecation_notice)
        super().__init__(cfg_string, tokenizer)

    def allowed_token_ids(self, state: FSMState) -> List[int]:
        """Return the `tokens` attribute of the next instruction for `state`."""
        next_instruction = self.get_next_instruction(state)
        return next_instruction.tokens

    def next_state(self, state: FSMState, token_id: int) -> FSMState:
        """Return, as an `FSMState`, the guide's next state after `token_id`."""
        following_state = self.get_next_state(state, token_id)
        return FSMState(following_state)

    def copy(self) -> "CFGFSM":
        """Build a new `CFGFSM` from this instance's grammar and tokenizer."""
        # Goes through `CFGFSM.__init__`, so the deprecation warning is
        # issued again for the new instance.
        duplicate = CFGFSM(self.cfg_string, self.tokenizer)
        return duplicate
Loading

0 comments on commit d47bd6b

Please sign in to comment.