Using context-free grammars to guide generation does not work #959

Open
border-b opened this issue Jun 12, 2024 · 3 comments · May be fixed by #1067 or lapp0/outlines#34

@border-b

Describe the issue as clearly as possible:

I was trying to run the example of using a context-free grammar to guide generation, but there seems to be an issue with `CFGGuide`: running the example without any changes produces an error.

Steps/code to reproduce the bug:

import outlines

arithmetic_grammar = """
    ?start: expression

    ?expression: term (("+" | "-") term)*

    ?term: factor (("*" | "/") factor)*

    ?factor: NUMBER
           | "-" factor
           | "(" expression ")"

    %import common.NUMBER
"""

model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1")
generator = outlines.generate.cfg(model, arithmetic_grammar)
sequence = generator("Alice had 4 apples and Bob ate 2. Write an expression for Alice's apples:")

print(sequence)
# (8-2)

Expected result:

(8-2)

Error message:

Traceback (most recent call last)
Cell In[1], line 19
     17 model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1")
     18 generator = outlines.generate.cfg(model, arithmetic_grammar)
---> 19 sequence = generator("Alice had 4 apples and Bob ate 2. Write an expression for Alice's apples:")
     21 print(sequence)
     22 # (8-2)

File /usr/local/lib/python3.10/site-packages/outlines/generate/api.py:207, in SequenceGenerator.__call__(self, prompts, max_tokens, stop_at, rng)
    205 while True:
    206     try:
--> 207         last_state = next(states)
    208         if max_tokens or stop_sequences:
    209             token_ids = last_state.token_ids

File /usr/local/lib/python3.10/site-packages/outlines/generate/generator.py:92, in sequence_generator(model, sampler, fsms, token_ids, sequence_weights, attention_masks, fsm_states, rng)
     89 fsms = reorder_fsms(fsms, ancestors)
     90 fsm_states = reorder_fsm_states(fsm_states, ancestors)
---> 92 fsm_states = get_next_fsm_states(fsms, fsm_states, next_token_ids)
     93 is_finished = is_generation_finished(fsms, fsm_states)
     95 if is_finished:

File /usr/local/lib/python3.10/site-packages/outlines/generate/generator.py:131, in get_next_fsm_states(fsms, fsm_states, next_token_ids)
    114 def get_next_fsm_states(
    115     fsms: List["Guide"], fsm_states: List[int], next_token_ids: "torch.Tensor"
    116 ) -> List[int]:
    117     """
    118 
    119     Parameters
   (...)
    129 
    130     """
--> 131     return [
    132         fsm.get_next_state(fsm_state, int(token_id[0]))
    133         for fsm, fsm_state, token_id in zip(fsms, fsm_states, next_token_ids)
    134     ]

File /usr/local/lib/python3.10/site-packages/outlines/generate/generator.py:132, in <listcomp>(.0)
    114 def get_next_fsm_states(
    115     fsms: List["Guide"], fsm_states: List[int], next_token_ids: "torch.Tensor"
    116 ) -> List[int]:
    117     """
    118 
    119     Parameters
   (...)
    129 
    130     """
    131     return [
--> 132         fsm.get_next_state(fsm_state, int(token_id[0]))
    133         for fsm, fsm_state, token_id in zip(fsms, fsm_states, next_token_ids)
    134     ]

File /usr/local/lib/python3.10/site-packages/outlines/fsm/guide.py:416, in CFGGuide.get_next_state(self, state, token_id)
    413     self.reset_state = False
    414     state = self.start_state
--> 416 return self.regex_fsm.get_next_state(state, token_id)

AttributeError: 'CFGGuide' object has no attribute 'regex_fsm'

Outlines/Python version information:

Version information

```
0.0.41
Python 3.10.8 (main, Dec 6 2022, 14:24:03) [GCC 10.2.1 20210110]
accelerate==0.30.1 aiohttp==3.8.3 aiosignal==1.3.1 aiostream==0.4.4 annotated-types==0.6.0 anyio==3.7.0 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 asgiref==3.5.2 asttokens==2.2.1 async-lru==2.0.4 async-timeout==4.0.2 attrs==23.1.0 Babel==2.15.0 backcall==0.2.0 beautifulsoup4==4.12.3 bitsandbytes==0.43.1 bleach==6.1.0 bytecode==0.14.2 cattrs==23.1.2 certifi==2023.5.7 cffi==1.16.0 charset-normalizer==2.1.1 click==8.1.3 cloudpickle==2.0.0 comm==0.2.2 commonmark==0.9.1 datasets==2.19.1 ddsketch==2.0.4 ddtrace==1.5.2 debugpy==1.8.1 decorator==5.1.1 defusedxml==0.7.1 dill==0.3.8 diskcache==5.6.3 docstring_parser==0.16 envier==0.4.0 exceptiongroup==1.1.1 executing==1.2.0 fastapi==0.88.0 fastjsonschema==2.19.1 fastprogress==1.0.0 filelock==3.14.0 fqdn==1.5.1 frozenlist==1.3.3 fsspec==2024.3.1 grpclib==0.4.3 h11==0.14.0 h2==4.1.0 hpack==4.0.0 httpcore==1.0.5 httpx==0.27.0 huggingface-hub==0.23.0 hyperframe==6.0.1 idna==3.4 importlib-metadata==4.8.1 interegular==0.3.3 ipykernel==6.29.4 ipython==8.14.0 ipywidgets==8.1.2 isoduration==20.11.0 jedi==0.18.2 Jinja2==3.1.4 json5==0.9.25 jsonpointer==2.4 jsonschema==4.22.0 jsonschema-specifications==2023.12.1 jupyter==1.0.0 jupyter-console==6.6.3 jupyter-events==0.10.0 jupyter-lsp==2.2.5 jupyter_client==8.6.1 jupyter_core==5.7.2 jupyter_server==2.14.0 jupyter_server_terminals==0.5.3 jupyterlab==4.1.8 jupyterlab_pygments==0.3.0 jupyterlab_server==2.27.1 jupyterlab_widgets==3.0.10 lark==1.1.9 llvmlite==0.42.0 MarkupSafe==2.1.5 matplotlib-inline==0.1.6 mistune==3.0.2 modal==0.62.139 mpmath==1.3.0 multidict==6.0.4 multiprocess==0.70.16 nbclient==0.10.0 nbconvert==7.16.4 nbformat==5.10.4 nest-asyncio==1.6.0 networkx==3.2.1 notebook==7.1.3 notebook_shim==0.2.4 numba==0.59.1 numpy==1.25.0 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.19.3 nvidia-nvjitlink-cu12==12.1.105 nvidia-nvtx-cu12==12.1.105 outlines==0.0.41 overrides==7.7.0 packaging==23.1 pandas==2.2.2 pandocfilters==1.5.1 parso==0.8.3 peft==0.10.0 pexpect==4.8.0 pickleshare==0.7.5 pillow==10.2.0 platformdirs==4.2.1 prometheus_client==0.20.0 prompt-toolkit==3.0.38 protobuf==3.20.3 psutil==5.9.8 ptyprocess==0.7.0 pure-eval==0.2.2 pyarrow==16.0.0 pyarrow-hotfix==0.6 pycparser==2.22 pydantic==2.7.1
pydantic_core==2.18.2 Pygments==2.15.1 pyrsistent==0.19.3 python-dateutil==2.9.0.post0 python-json-logger==2.0.7 python-multipart==0.0.6 pytz==2024.1 PyYAML==6.0.1 pyzmq==26.0.3 qtconsole==5.5.2 QtPy==2.4.1 referencing==0.35.1 regex==2024.5.10 requests==2.31.0 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rich==12.3.0 rpds-py==0.18.1 safetensors==0.4.3 Send2Trash==1.8.3 sentencepiece==0.2.0 shtab==1.7.1 six==1.16.0 sniffio==1.3.0 soupsieve==2.5 stack-data==0.6.2 starlette==0.22.0 sympy==1.12 tblib==1.7.0 tenacity==8.2.2 terminado==0.18.1 tinycss2==1.3.0 tokenizers==0.19.1 toml==0.10.2 tomli==2.0.1 torch==2.2.2+cu121 torchaudio==2.2.2+cu121 torchvision==0.17.2+cu121 tornado==6.4 tqdm==4.66.4 traitlets==5.9.0 transformers==4.40.2 triton==2.2.0 trl==0.8.6 typeguard==4.0.0 typer==0.6.1 types-certifi==2021.10.8.3 types-python-dateutil==2.9.0.20240316 types-toml==0.10.4 typing_extensions==4.9.0 tyro==0.8.4 tzdata==2024.1 unsloth @ git+https://github.com/unslothai/unsloth.git@47ffd39abd02338e8a5f226d0f529347fb7e5f89 uri-template==1.3.0 urllib3==2.2.1 wcwidth==0.2.6 webcolors==1.13 webencodings==0.5.1 websocket-client==1.8.0 widgetsnbextension==4.0.10 xformers==0.0.25.post1 xmltodict==0.13.0 xxhash==3.4.1 yarl==1.9.2 zipp==3.15.0
```

Context for the issue:

I was trying to test the output with a custom grammar, but the provided example fails to generate any output.

border-b added the bug label Jun 12, 2024
@rlouf (Member) commented Jun 12, 2024

Can you upgrade outlines to 0.0.43 and try again?
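
(For reference, assuming a standard pip-based install, the upgrade would be `pip install --upgrade outlines==0.0.43`.)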

@border-b (Author)

@rlouf Upgrading to 0.0.43 fixes this error, but it raises another one:

Traceback (most recent call last)
File /usr/local/lib/python3.10/site-packages/lark/lexer.py:673, in ContextualLexer.lex(self, lexer_state, parser_state)
    672 last_token = lexer_state.last_token  # Save last_token. Calling root_lexer.next_token will change this to the wrong token
--> 673 token = self.root_lexer.next_token(lexer_state, parser_state)
    674 raise UnexpectedToken(token, e.allowed, state=parser_state, token_history=[last_token], terminals_by_name=self.root_lexer.terminals_by_name)

File /usr/local/lib/python3.10/site-packages/lark/lexer.py:598, in BasicLexer.next_token(self, lex_state, parser_state)
    597         allowed = {"<END-OF-FILE>"}
--> 598     raise UnexpectedCharacters(lex_state.text, line_ctr.char_pos, line_ctr.line, line_ctr.column,
    599                                allowed=allowed, token_history=lex_state.last_token and [lex_state.last_token],
    600                                state=parser_state, terminals_by_name=self.terminals_by_name)
    602 value, type_ = res

UnexpectedCharacters: No terminal matches 'e' in the current parser context, at line 1 col 193

98*0.5000000000000022*2.2204460492503131e-
                                        ^
Expected one of: 
	* RPAR
	* STAR
	* NUMBER
	* SLASH
	* PLUS
	* MINUS
	* LPAR

Previous tokens: Token('NUMBER', '2.2204460492503131')


During handling of the above exception, another exception occurred:

UnexpectedCharacters                      Traceback (most recent call last)
Cell In[2], line 19
     17 model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1")
     18 generator = outlines.generate.cfg(model, arithmetic_grammar)
---> 19 sequence = generator("Alice had 4 apples and Bob ate 2. Write an expression for Alice's apples:")
     21 print(sequence)
     22 # (8-2)

File /usr/local/lib/python3.10/site-packages/outlines/generate/api.py:207, in SequenceGenerator.__call__(self, prompts, max_tokens, stop_at, rng)
    205 while True:
    206     try:
--> 207         last_state = next(states)
    208         if max_tokens or stop_sequences:
    209             token_ids = last_state.token_ids

File /usr/local/lib/python3.10/site-packages/outlines/generate/generator.py:80, in sequence_generator(model, sampler, fsms, token_ids, sequence_weights, attention_masks, fsm_states, rng)
     75 except IndexError:  # Exceeding the context length
     76     raise ContextLengthExceededError(
     77         "The input length exceeds the context length of the model."
     78     )
---> 80 allowed_tokens = get_allowed_tokens(fsms, fsm_states)
     81 biased_logits = bias_logits(logits, allowed_tokens)
     82 next_token_ids, ancestors, sequence_weights = sampler(
     83     biased_logits, sequence_weights, rng
     84 )

File /usr/local/lib/python3.10/site-packages/outlines/generate/generator.py:155, in get_allowed_tokens(fsms, fsm_states)
    138 def get_allowed_tokens(
    139     fsms: List["Guide"], fsm_states: List[int]
    140 ) -> List[Optional[Iterable[int]]]:
    141     """Get the new instructions for each sequence from the finite-state machine.
    142 
    143     Parameters
   (...)
    153 
    154     """
--> 155     return [
    156         fsm.get_next_instruction(state).tokens for fsm, state in zip(fsms, fsm_states)
    157     ]

File /usr/local/lib/python3.10/site-packages/outlines/generate/generator.py:156, in <listcomp>(.0)
    138 def get_allowed_tokens(
    139     fsms: List["Guide"], fsm_states: List[int]
    140 ) -> List[Optional[Iterable[int]]]:
    141     """Get the new instructions for each sequence from the finite-state machine.
    142 
    143     Parameters
   (...)
    153 
    154     """
    155     return [
--> 156         fsm.get_next_instruction(state).tokens for fsm, state in zip(fsms, fsm_states)
    157     ]

File /usr/local/lib/python3.10/site-packages/outlines/fsm/guide.py:349, in CFGGuide.get_next_instruction(self, state)
    346         self.regex_fsm_last = proposer
    348 interactive = self.parser.parse_interactive(self.generation)
--> 349 interactive.exhaust_lexer()
    351 options = {self.terminal_regexps[x] for x in interactive.accepts()}
    352 # add %ignore terminals

File /usr/local/lib/python3.10/site-packages/lark/parsers/lalr_interactive_parser.py:52, in InteractiveParser.exhaust_lexer(self)
     47 def exhaust_lexer(self) -> List[Token]:
     48     """Try to feed the rest of the lexer state into the interactive parser.
     49 
     50     Note that this modifies the instance in place and does not feed an '$END' Token
     51     """
---> 52     return list(self.iter_parse())

File /usr/local/lib/python3.10/site-packages/lark/parsers/lalr_interactive_parser.py:43, in InteractiveParser.iter_parse(self)
     35 def iter_parse(self) -> Iterator[Token]:
     36     """Step through the different stages of the parse, by reading tokens from the lexer
     37     and feeding them to the parser, one per iteration.
     38 
   (...)
     41     When the parse is over, the resulting tree can be found in ``InteractiveParser.result``.
     42     """
---> 43     for token in self.lexer_thread.lex(self.parser_state):
     44         yield token
     45         self.result = self.feed_token(token)

File /usr/local/lib/python3.10/site-packages/lark/lexer.py:676, in ContextualLexer.lex(self, lexer_state, parser_state)
    674     raise UnexpectedToken(token, e.allowed, state=parser_state, token_history=[last_token], terminals_by_name=self.root_lexer.terminals_by_name)
    675 except UnexpectedCharacters:
--> 676     raise e

File /usr/local/lib/python3.10/site-packages/lark/lexer.py:665, in ContextualLexer.lex(self, lexer_state, parser_state)
    663     while True:
    664         lexer = self.lexers[parser_state.position]
--> 665         yield lexer.next_token(lexer_state, parser_state)
    666 except EOFError:
    667     pass

File /usr/local/lib/python3.10/site-packages/lark/lexer.py:598, in BasicLexer.next_token(self, lex_state, parser_state)
    596     if not allowed:
    597         allowed = {"<END-OF-FILE>"}
--> 598     raise UnexpectedCharacters(lex_state.text, line_ctr.char_pos, line_ctr.line, line_ctr.column,
    599                                allowed=allowed, token_history=lex_state.last_token and [lex_state.last_token],
    600                                state=parser_state, terminals_by_name=self.terminals_by_name)
    602 value, type_ = res
    604 ignored = type_ in self.ignore_types

UnexpectedCharacters: No terminal matches 'e' in the current parser context, at line 1 col 193

98*0.5000000000000022*2.2204460492503131e-
                                        ^
Expected one of: 
	* RPAR
	* STAR
	* SLASH
	* PLUS
	* MINUS

Previous tokens: Token('NUMBER', '2.2204460492503131')

It seems generation runs now, but the output is not following the grammar?
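
As a sanity check, here is a minimal sketch (it assumes only that `lark` is installed; the exponent `e-16` below is a hypothetical completion of the truncated `e-` above). Lark's `common.NUMBER` imports `FLOAT`, which permits scientific notation, so the grammar itself should accept the number once the exponent is complete. If this parses, the `UnexpectedCharacters` above is likely triggered by lexing a partial string that ends mid-terminal, not by the output actually leaving the grammar:

```
# Minimal sketch: parse a completed version of the failing string directly
# with lark. "e-16" is a hypothetical completion of the truncated "e-" above.
from lark import Lark

arithmetic_grammar = """
    ?start: expression
    ?expression: term (("+" | "-") term)*
    ?term: factor (("*" | "/") factor)*
    ?factor: NUMBER
           | "-" factor
           | "(" expression ")"
    %import common.NUMBER
"""

parser = Lark(arithmetic_grammar)
# Should parse: common.NUMBER -> FLOAT allows an exponent part.
print(parser.parse("98*0.5000000000000022*2.2204460492503131e-16").pretty())
```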

@lapp0 (Collaborator) commented Jun 12, 2024

This seems to be the same issue I ran into in #796.

I'll be working on getting CFG and the parser in a good state over the coming weeks. You can track progress by subscribing to this issue: #684
