
When the same text input is entered consecutively, calling the RegexPrefixAllowedTokens instance results in an error the second time. #1021

Closed
littlepenguin89106 opened this issue Jul 6, 2024 · 2 comments · Fixed by #966

@littlepenguin89106

Describe the issue as clearly as possible:

For some reasons, I need to handle the same text input repeatedly. I found that if I consecutively enter the same text input, I break the assumption made in __call__ of RegexPrefixAllowedTokens:

# If the prefix token IDs have changed we assume that we are dealing with a new
# sample and reset the FSM state
if input_ids[: len(self._prefix)] != self._prefix:
    self._fsm_state = defaultdict(int)
    self._prefix = input_ids
    seq_id = hash(tuple([]))

else:
    # Remove the prefix token IDs from the input token IDs, as the FSM should
    # only be applied to the generated tokens
    input_ids = input_ids[len(self._prefix) :]
    last_token = input_ids[-1]
    last_seq_id = hash(tuple(input_ids[:-1]))
    seq_id = hash(tuple(input_ids))
    self._fsm_state[seq_id] = self.fsm.get_next_state(
        state=self._fsm_state[last_seq_id], token_id=last_token
    )

and get a list index out of range error at last_token = input_ids[-1].
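
For illustration, here is a minimal standalone sketch (hypothetical toy lists, not the actual outlines class) of why the else branch fails when the exact same prompt arrives a second time:

```python
# Hypothetical standalone sketch of the failing branch (not the actual
# outlines code): when the new input_ids are identical to the stored
# prefix, stripping the prefix leaves an empty list and indexing [-1] fails.
prefix = [1, 2, 3]      # token IDs remembered from the first call
input_ids = [1, 2, 3]   # second call with the exact same prompt

if input_ids[: len(prefix)] != prefix:
    print("treated as a new sample")      # not taken: the prefix still matches
else:
    generated = input_ids[len(prefix):]   # [] -> nothing generated yet
    last_token = generated[-1]            # IndexError: list index out of range
```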

Steps/code to reproduce the bug:

# Modified from example of transformers_integration.py

from pydantic import BaseModel
from transformers import pipeline

from outlines.integrations.transformers import JSONPrefixAllowedTokens


class Person(BaseModel):
    first_name: str
    surname: str


pipe = pipeline("text-generation", model="mistralai/Mistral-7B-v0.1")
prefix_allowed_tokens_fn = JSONPrefixAllowedTokens(
    schema=Person, tokenizer_or_pipe=pipe, whitespace_pattern=r" ?"
)

input_text = ["He is Tom Jones"]

for i in range(2):
    results = pipe(
        input_text,
        return_full_text=False,
        do_sample=False,
        max_new_tokens=50,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    )

Expected result:

It should output the same result, without any error, the second time.

Error message:

Traceback (most recent call last):
  File "/work/tvsrt0p1c/vqa/test.py", line 20, in <module>
    results = pipe(
  File "/home/tvsrt0p1c/.local/lib/python3.10/site-packages/transformers/pipelines/text_generation.py", line 240, in __call__
    return super().__call__(text_inputs, **kwargs)
  File "/home/tvsrt0p1c/.local/lib/python3.10/site-packages/transformers/pipelines/base.py", line 1223, in __call__
    outputs = list(final_iterator)
  File "/home/tvsrt0p1c/.local/lib/python3.10/site-packages/transformers/pipelines/pt_utils.py", line 124, in __next__
    item = next(self.iterator)
  File "/home/tvsrt0p1c/.local/lib/python3.10/site-packages/transformers/pipelines/pt_utils.py", line 125, in __next__
    processed = self.infer(item, **self.params)
  File "/home/tvsrt0p1c/.local/lib/python3.10/site-packages/transformers/pipelines/base.py", line 1149, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/tvsrt0p1c/.local/lib/python3.10/site-packages/transformers/pipelines/text_generation.py", line 327, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/home/tvsrt0p1c/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/tvsrt0p1c/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 1576, in generate
    result = self._greedy_search(
  File "/home/tvsrt0p1c/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 2507, in _greedy_search
    next_tokens_scores = logits_processor(input_ids, next_token_logits)
  File "/home/tvsrt0p1c/.local/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 98, in __call__
    scores = processor(input_ids, scores)
  File "/home/tvsrt0p1c/.local/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 1247, in __call__
    prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
  File "/home/tvsrt0p1c/.local/lib/python3.10/site-packages/outlines/integrations/transformers.py", line 115, in __call__
    last_token = input_ids[-1]
IndexError: list index out of range

Outlines/Python version information:

``` 0.0.46 Python 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0] aiohttp==3.9.5 aiosignal==1.3.1 annotated-types==0.7.0 asttokens==2.4.1 async-timeout==4.0.3 attrs==23.2.0 boltons @ file:///home/conda/feedstock_root/build_artifacts/boltons_1677499911949/work Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1693583441880/work certifi==2024.7.4 cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1671179353105/work charset-normalizer==3.3.2 cloudpickle==3.0.0 colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work conda==23.3.1 conda-libmamba-solver @ file:///home/conda/feedstock_root/build_artifacts/conda-libmamba-solver_1680508672016/work/src conda-package-handling @ file:///home/conda/feedstock_root/build_artifacts/conda-package-handling_1691048088238/work conda_package_streaming @ file:///home/conda/feedstock_root/build_artifacts/conda-package-streaming_1691009212940/work cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography-split_1691444160254/work datasets==2.20.0 decorator==5.1.1 dill==0.3.8 diskcache==5.6.3 exceptiongroup==1.2.1 executing==2.0.1 filelock==3.15.4 frozenlist==1.4.1 fsspec==2024.5.0 huggingface-hub==0.23.4 idna==3.7 interegular==0.3.3 ipdb==0.13.13 ipython==8.26.0 jedi==0.19.1 Jinja2==3.1.4 jsonpatch @ file:///home/conda/feedstock_root/build_artifacts/jsonpatch_1632759296524/work jsonpointer==2.0 jsonschema==4.22.0 jsonschema-specifications==2023.12.1 lark==1.1.9 libmambapy @ file:///home/conda/feedstock_root/build_artifacts/mamba-split_1680791035685/work/libmambapy llvmlite==0.43.0 mamba @ file:///home/conda/feedstock_root/build_artifacts/mamba-split_1680791035685/work/mamba MarkupSafe==2.1.5 matplotlib-inline==0.1.7 mpmath==1.3.0 multidict==6.0.5 multiprocess==0.70.16 nest-asyncio==1.6.0 networkx==3.3 numba==0.60.0 numpy==1.24.4 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.5.82 nvidia-nvtx-cu12==12.1.105 outlines==0.0.46 packaging==24.1 pandas==2.2.2 parso==0.8.4 pexpect==4.9.0 pillow==10.3.0 pluggy @ file:///home/conda/feedstock_root/build_artifacts/pluggy_1693086607691/work prompt_toolkit==3.0.47 ptyprocess==0.7.0 pure-eval==0.2.2 pyairports==2.1.1 pyarrow==16.1.0 pyarrow-hotfix==0.6 pycosat @ file:///home/conda/feedstock_root/build_artifacts/pycosat_1666836542287/work pycountry==24.6.1 pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work pydantic==2.8.2 pydantic_core==2.20.1 Pygments==2.18.0 pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1685514481738/work PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work python-dateutil==2.9.0.post0 pytz==2024.1 PyYAML==6.0.1 referencing==0.35.1 regex==2024.5.15 requests==2.32.3 rpds-py==0.18.1 ruamel.yaml @ file:///home/conda/feedstock_root/build_artifacts/ruamel.yaml_1686993901728/work ruamel.yaml.clib @ file:///home/conda/feedstock_root/build_artifacts/ruamel.yaml.clib_1670412719074/work safetensors==0.4.3 six==1.16.0 stack-data==0.6.3 sympy==1.12.1 tokenizers==0.19.1 tomli==2.0.1 toolz @ file:///home/conda/feedstock_root/build_artifacts/toolz_1657485559105/work torch==2.3.0 torchvision==0.18.0 tqdm==4.66.4 traitlets==5.14.3 
transformers==4.40.2 triton==2.3.0 typing_extensions==4.12.2 tzdata==2024.1 urllib3==2.2.2 wcwidth==0.2.13 xxhash==3.4.1 yarl==1.9.4 zstandard @ file:///home/conda/feedstock_root/build_artifacts/zstandard_1667296087208/work ```

Context for the issue:

I think the assumption does not fully cover all usages, and it may cause other unexpected errors.

@lapp0 (Collaborator) commented Jul 13, 2024

Could you please install the branch from this PR and check whether it resolves your issue? #966

@ianrandman

> Could you please install the branch from this PR and check whether it resolves your issue? #966

Hi, I am also encountering this issue on both the main branch and the suggested branch.

I am able to trigger the error in two ways:

  • Calling pipe() multiple times with the same JSONPrefixAllowedTokens instance.
  • Calling pipe() with a multimodal model (e.g. LLaVA) with more than one image as input for a single prompt. As far as I can tell, the internal implementation in transformers processes the images sequentially in this case.

Both cases fail because the second query results in input_ids == self._prefix on the first call to RegexPrefixAllowedTokens.__call__.

Here is the relevant code block:

# If the prefix token IDs have changed we assume that we are dealing with a new
# sample and reset the FSM state
if input_ids[: len(self._prefix)] != self._prefix:
    self._fsm_state = defaultdict(int)
    self._prefix = input_ids
    seq_id = hash(tuple([]))
else:
    # Remove the prefix token IDs from the input token IDs, as the FSM should
    # only be applied to the generated tokens
    input_ids = input_ids[len(self._prefix) :]
    last_token = input_ids[-1]
    last_seq_id = hash(tuple(input_ids[:-1]))
    seq_id = hash(tuple(input_ids))
    self._fsm_state[seq_id] = self.fsm.get_next_state(
        state=self._fsm_state[last_seq_id], token_id=last_token
    )

I propose changing the first conditional:

# If the prefix token IDs have changed, or they are equal to the entirety of the input, we assume that we are dealing with a new 
# sample and reset the FSM state 
if input_ids[: len(self._prefix)] != self._prefix or len(input_ids) == len(self._prefix): 
   self._fsm_state = defaultdict(int) 
   self._prefix = input_ids 
   seq_id = hash(tuple([])) 

else: 
   # Remove the prefix token IDs from the input token IDs, as the FSM should 
   # only be applied to the generated tokens 
   input_ids = input_ids[len(self._prefix) :] 

   last_token = input_ids[-1] 
   last_seq_id = hash(tuple(input_ids[:-1])) 
   seq_id = hash(tuple(input_ids)) 
   self._fsm_state[seq_id] = self.fsm.get_next_state( 
       state=self._fsm_state[last_seq_id], token_id=last_token 
   )

to recognize when a new but identical prompt has come in, and reset the FSM state in that case. This appears to solve the problem, but I am not 100% sure about its correctness (particularly with respect to possible edge cases).
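
As a quick sanity check, here is a toy sketch (plain lists and a hypothetical helper, no real FSM involved) showing that the revised condition routes the problematic case into the reset branch:

```python
# Toy check of the proposed condition (hypothetical, no FSM involved):
# an identical repeated prompt should now hit the reset branch instead
# of the else branch that indexes an empty list.
def branch(prefix, input_ids):
    if input_ids[: len(prefix)] != prefix or len(input_ids) == len(prefix):
        return "reset"      # new or identical prompt: reset FSM state
    return "advance"        # prompt plus generated tokens: advance FSM

prefix = [1, 2, 3]
assert branch(prefix, [1, 2, 3]) == "reset"        # repeated prompt, no crash
assert branch(prefix, [1, 2, 3, 7]) == "advance"   # generation in progress
assert branch(prefix, [4, 5]) == "reset"           # different prompt
```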

Thoughts?
