[vLLM integration bug] Generated output is stopped for all samples in batch #757

Closed
saattrupdan opened this issue Mar 18, 2024 · 5 comments · Fixed by #760


saattrupdan commented Mar 18, 2024

Describe the issue as clearly as possible:

When vLLM is used with the JSONLogitsProcessor and the batch is sufficiently large, generation is stopped prematurely for the entire batch, leaving the majority of the outputs as invalid JSON.

Steps/code to reproduce the bug:

import vllm
from pydantic import BaseModel, conlist
from outlines.integrations.vllm import JSONLogitsProcessor

class Person(BaseModel):
    first_names: conlist(str, max_length=2)
    surnames: conlist(str, max_length=2)

llm = vllm.LLM(model="mhenrichsen/danskgpt-tiny", max_model_len=512)
logits_processor = JSONLogitsProcessor(schema=Person, llm=llm, whitespace_pattern=r" ?")

def generate(num: int, sampling_params: vllm.SamplingParams) -> str:
    result = llm.generate(
        ["He is Tom Jones", "This is funny"] * num,  # 2 * num prompts in the batch
        sampling_params=sampling_params,
        use_tqdm=False,
    )
    first_output = result[0].outputs[0].text
    second_output = result[1].outputs[0].text
    return f"First output: {first_output!r}\nSecond output: {second_output!r}"

sampling_params_with_logits_processors = vllm.SamplingParams(
    temperature=0.0,
    max_tokens=128,
    logits_processors=[logits_processor],
)

sampling_params_without_logits_processors = vllm.SamplingParams(
    temperature=0.0,
    max_tokens=128,
)

all_sampling_params = [
    sampling_params_with_logits_processors, 
    sampling_params_without_logits_processors,
]

for sampling_params in all_sampling_params:
    for num in [1, 128, 129, 130]:
        with_logits_processors = sampling_params.logits_processors is not None
        print(f"With logits processors: {with_logits_processors} - {num} samples")
        print(generate(num=num, sampling_params=sampling_params))
        print()

The above script generates correct JSON output when num is at most 128, but the first output is truncated whenever num is 129 or above. When the logits processor is left out of the sampling params, the output is unaffected by the batch size.

Here is the full output:

With logits processors: True - 1 samples
First output: '{ "first_names" : ["Tom Jones" ] ,"surnames" :[ "Tom Jones" ] }'
Second output: '{ "first_names" :[  ] ,"surnames" :[  ] }'

With logits processors: True - 128 samples
First output: '{ "first_names" : ["Tom Jones" ] ,"surnames" :[ "Tom Jones" ] }'
Second output: '{ "first_names" :[  ] ,"surnames" :[  ] }'

With logits processors: True - 129 samples
First output: '{ "first_names" : ["Tom Jones" ] ,"surnames" '
Second output: '{ "first_names" :[  ] ,"surnames" :[  ] }'

With logits processors: True - 130 samples
First output: '{ "first_names" : ["Tom Jones" ] ,"surnames" '
Second output: '{ "first_names" :[  ] ,"surnames" :[  ] }'

With logits processors: False - 1 samples
First output: '.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der'
Second output: '. Jeg er selv fra 90erne og har aldrig hørt om det.'

With logits processors: False - 128 samples
First output: '.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der'
Second output: '. Jeg har lige været i USA og der er det ikke unormalt at se folk med en 1000$ tegnebog i lommen.'

With logits processors: False - 129 samples
First output: '.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der'
Second output: '. Jeg har lige været i USA og der er det ikke unormalt at se folk med en 1000$ tegnebog i lommen.'

With logits processors: False - 130 samples
First output: '.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der galt med dig?\n-Jeg er Tom Jones.\n-Hvad er der'
Second output: '. Jeg har lige været i USA og der er det ikke unormalt at se folk med en 1000$ tegnebog i lommen.'

Expected result:

The output should remain the same for any batch size.

Outlines/Python version information:


0.0.36
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
absl-py==2.0.0
accelerate==0.26.1
ai2-olmo==0.2.4
aiohttp==3.9.1
aioprometheus==23.12.0
aiosignal==1.3.1
annotated-types==0.6.0
antlr4-python3-runtime==4.9.3
anyio==4.2.0
async-timeout==4.0.3
attrs==23.2.0
bert-score==0.3.13
bitsandbytes==0.42.0
boto3==1.34.37
botocore==1.34.37
cached-path==1.5.1
cachetools==5.3.2
certifi==2023.11.17
charset-normalizer==3.3.2
chex==0.1.85
click==8.1.7
cloudpickle==3.0.0
contourpy==1.2.0
cupy-cuda12x==12.1.0
cycler==0.12.1
datasets==2.16.1
demjson3==3.0.6
dill==0.3.7
diskcache==5.6.3
distro==1.9.0
einops==0.7.0
etils==1.6.0
evaluate==0.4.1
exceptiongroup==1.2.0
fastapi==0.108.0
fastrlock==0.8.2
filelock==3.12.4
flash-attn==2.4.2
flax==0.8.1
fonttools==4.47.0
frozenlist==1.4.1
fsspec==2023.10.0
google-api-core==2.17.0
google-auth==2.27.0
google-cloud-core==2.4.1
google-cloud-storage==2.14.0
google-crc32c==1.5.0
google-resumable-media==2.7.0
googleapis-common-protos==1.62.0
h11==0.14.0
httpcore==1.0.2
httptools==0.6.1
httpx==0.26.0
huggingface-hub==0.19.4
idna==3.6
importlib-resources==6.1.1
interegular==0.3.3
jax==0.4.24
jaxlib==0.4.24
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.3.2
jsonschema==4.20.0
jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
lark==1.1.9
Levenshtein==0.24.0
llvmlite==0.42.0
lm-format-enforcer==0.8.2
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.8.2
mdurl==0.1.2
ml-dtypes==0.3.2
mpmath==1.3.0
msgpack==1.0.7
multidict==6.0.4
multiprocess==0.70.15
nest-asyncio==1.5.8
networkx==3.2.1
ninja==1.11.1.1
nltk==3.8.1
numba==0.59.0
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
omegaconf==2.3.0
pydantic==2.6.0
pydantic_core==2.16.1
Pygments==2.17.2
pyinfer==0.0.3
pynvml==11.5.0
pyparsing==3.1.1
python-dateutil==2.8.2
python-dotenv==1.0.1
pytz==2023.3.post1
PyYAML==6.0.1
quantile-python==1.1
rapidfuzz==3.6.1
ray==2.9.0
referencing==0.32.1
regex==2023.12.25
requests==2.31.0
responses==0.18.0
rich==13.7.0
rouge_score==0.1.2
rpds-py==0.16.2
rsa==4.9
s3transfer==0.10.0
sacremoses==0.1.1
safetensors==0.4.1
ScandEval @ git+https://github.com/ScandEval/ScandEval@ad34ea8812dfbc2b8ca700d33ff8ddb57e5edee8
scikit-learn==1.3.2
scipy==1.11.4
sentencepiece==0.1.99
seqeval==1.2.2
six==1.16.0
sniffio==1.3.0
starlette==0.32.0.post1
sympy==1.12
tabulate==0.9.0
tensorstore==0.1.52
termcolor==2.4.0
threadpoolctl==3.2.0
tiktoken==0.5.2
tokenizers==0.15.0
toolz==0.12.0
torch==2.1.2
tqdm==4.66.1
transformers==4.38.1
triton==2.1.0
typing_extensions==4.9.0
tzdata==2023.4
urllib3==2.0.7
uvicorn==0.25.0
uvloop==0.19.0
vllm==0.3.3
watchfiles==0.21.0
websockets==12.0
xformers==0.0.23.post1
xxhash==3.4.1
yarl==1.9.4
zipp==3.17.0

Context for the issue:

This causes all models to get random performance on the NER task in the ScandEval benchmark, which uses outlines under the hood for structured generation.


saattrupdan commented Mar 18, 2024

This error can also be reproduced with outlines==0.0.34, so it doesn't seem to be caused by any recent changes. I also tried downgrading vllm to 0.3.1, which didn't change anything either. The vLLM logits processor from lm-format-enforcer does not exhibit the same issue, so as far as I can tell the problem lies in outlines.


saattrupdan commented Mar 19, 2024

@rlouf After extensive debugging, I've identified that this happens here, where we reset the self._fsm_state variable. Since the inference calls in vLLM happen in parallel, resetting it causes generation for all other in-flight sequences to stop: their self._fsm_state lookups now return 0, because the keys no longer exist and 0 is the default value of a defaultdict(int).
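
To illustrate the mechanism, here is a toy sketch (not the actual outlines code; the variable names are hypothetical): once the dict is reset, the state of any sequence that is already mid-generation silently falls back to 0, the defaultdict default.

from collections import defaultdict

# Toy stand-in for the per-sequence FSM state bookkeeping (hypothetical names).
fsm_state = defaultdict(int)

seq_a = hash(tuple([11, 22, 33]))  # a sequence that is already mid-generation
fsm_state[seq_a] = 7               # its FSM has advanced to state 7

# A new request with len(input_ids) == 0 arrives in the same batch and
# triggers the reset branch:
fsm_state = defaultdict(int)

# The running sequence's next lookup hits a missing key, and defaultdict(int)
# silently returns 0, so the FSM restarts from its initial state instead of
# continuing from state 7 -- which is what shows up as premature stopping.
print(fsm_state[seq_a])  # prints 0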

If we thus simply replace the following code block:

if len(input_ids) == 0:
    self._fsm_state = defaultdict(int)
else:
    last_token = input_ids[-1]
    last_seq_id = hash(tuple(input_ids[:-1]))
    self._fsm_state[seq_id] = self.fsm.get_next_state(
        state=self._fsm_state[last_seq_id], token_id=last_token
    )

with the following:

if len(input_ids) > 0:
    last_token = input_ids[-1]
    last_seq_id = hash(tuple(input_ids[:-1]))
    self._fsm_state[seq_id] = self.fsm.get_next_state(
        state=self._fsm_state[last_seq_id], token_id=last_token
    )

then it works just fine in my tests.
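
For context, here is a rough sketch of how that block sits inside the logits processor's __call__. The masking part at the end, including the allowed_token_ids method name, is paraphrased and may not match the outlines source exactly; the point is that seq_id is the hash of the full input_ids, so the lookup via last_seq_id is exactly what hits the freshly reset dict.

import math
from collections import defaultdict

import torch

class SketchJSONLogitsProcessor:
    """Simplified sketch of the vLLM logits processor (illustrative only)."""

    def __init__(self, fsm):
        self.fsm = fsm                      # guided-decoding FSM (assumed interface)
        self._fsm_state = defaultdict(int)  # per-sequence FSM state, keyed by sequence hash

    def __call__(self, input_ids, scores):
        # Called once per sequence per decoding step.
        seq_id = hash(tuple(input_ids))

        # Proposed fix: only advance the FSM; never reset self._fsm_state here,
        # since other sequences in the batch still keep live states in it.
        if len(input_ids) > 0:
            last_token = input_ids[-1]
            last_seq_id = hash(tuple(input_ids[:-1]))
            self._fsm_state[seq_id] = self.fsm.get_next_state(
                state=self._fsm_state[last_seq_id], token_id=last_token
            )

        # Mask out every token the FSM does not allow from the current state
        # (allowed_token_ids is an assumed method name for this sketch).
        allowed_tokens = self.fsm.allowed_token_ids(self._fsm_state[seq_id])
        mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
        mask[allowed_tokens] = 0
        return scores + mask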

Is there a reason for resetting self._fsm_state? If not, I can create a PR with the above-mentioned fix.

saattrupdan commented:

The PR is open now (#760); feel free to close it if you see an issue with not resetting the FSM state.


rlouf commented Mar 20, 2024

Good find, I think this condition was left from a previous version. There is no good reason to reset _fsm_state here.

saattrupdan commented:

> Good find, I think this condition was left from a previous version. There is no good reason to reset _fsm_state here.

Perfect. The PR is ready when you've got time 🙂
