
feat: accept list as prompt and use first string #1702

Merged
merged 16 commits into main from extract-first-prompt-if-list on Apr 17, 2024

Conversation

drbh
Collaborator

@drbh drbh commented Apr 3, 2024

This PR allows `CompletionRequest.prompt` to be sent as a string or an array of strings. When an array is sent, the first value is used if it's a string; otherwise the corresponding error is thrown.

Fixes: #1690
Similar to: https://github.com/vllm-project/vllm/pull/323/files

match value {
    Value::String(s) => Ok(s),
    Value::Array(arr) => arr
        .first()
Collaborator

If we're not treating the array properly (as multiple queries), I suggest we just don't support it.

I don't think we want to support arrays (it was done this way for a long time in pipelines and created so many headaches that it's not worth it).

If we still want that exact functionality, we need to YELL if the array contains more than 1 element (instead of silently ignoring it).

@OlivierDehaene
Member

It would be pretty easy to support arrays like we do in TEI: just push all the requests into the internal queue and wait.
But I feel that the client would time out very often waiting on the slowest request from the batch, and that could lead to a lot of wasted compute.
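
For reference, accepting either shape at the deserialization layer is straightforward with serde. The sketch below is illustrative only; the type names, fields, and error handling are assumptions, not the PR's actual code.

use serde::Deserialize;

// Accept the OpenAI-style `prompt` either as a single string or as an
// array of strings (illustrative types, not TGI's actual definitions).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum Prompt {
    Single(String),
    Batch(Vec<String>),
}

#[derive(Debug, Deserialize)]
struct CompletionRequest {
    prompt: Prompt,
    max_tokens: u32,
}

fn main() {
    // Both request shapes parse into the same enum.
    let single: CompletionRequest =
        serde_json::from_str(r#"{"prompt": "Say this is a test", "max_tokens": 2}"#).unwrap();
    let batch: CompletionRequest =
        serde_json::from_str(r#"{"prompt": ["Say", "this"], "max_tokens": 2}"#).unwrap();
    println!("{single:?}\n{batch:?}");
}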

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@drbh drbh force-pushed the extract-first-prompt-if-list branch from 02414b5 to b08038c on April 9, 2024 00:43
@drbh
Collaborator Author

drbh commented Apr 9, 2024

notes

@drbh
Collaborator Author

drbh commented Apr 10, 2024

example requests:

streaming with openai

from openai import OpenAI

YOUR_TOKEN = "YOUR_API_KEY"

# Initialize the client, pointing it to one of the available models
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key=YOUR_TOKEN,
)

completion = client.completions.create(
    model="gpt-3.5-turbo-instruct",
    prompt=["Say", "this", "is", "a", "test"],
    echo=True,
    n=1,
    stream=True,
    max_tokens=10,
)

for chunk in completion:
    print(chunk)

# Completion(id='', choices=[CompletionChoice(finish_reason='', index=4, logprobs=None, text=' =')], created=1712722135, model='google/gemma-7b', object='text_completion', system_fingerprint='1.4.5-native', usage=None)
# Completion(id='', choices=[CompletionChoice(finish_reason='', index=4, logprobs=None, text=' ')], created=1712722135, model='google/gemma-7b', object='text_completion', system_fingerprint='1.4.5-native', usage=None)
# Completion(id='', choices=[CompletionChoice(finish_reason='', index=4, logprobs=None, text='1')], created=1712722136, model='google/gemma-7b', object='text_completion', system_fingerprint='1.4.5-native', usage=None)
# Completion(id='', choices=[CompletionChoice(finish_reason='', index=4, logprobs=None, text='0')], created=1712722136, model='google/gemma-7b', object='text_completion', system_fingerprint='1.4.5-native', usage=None)
# ...

with aiohttp (streaming)

from aiohttp import ClientSession
import json
import asyncio

base_url = "http://localhost:3000"


request = {
    "model": "tgi",
    "prompt": [
        "What color is the sky?",
        "Is water wet?",
        "What is the capital of France?",
        "def mai",
    ],
    "max_tokens": 10,
    "seed": 0,
    "stream": True,
}

url = f"{base_url}/v1/completions"


async def main():

    async with ClientSession() as session:
        async with session.post(url, json=request) as response:
            async for chunk in response.content.iter_any():
                chunk = chunk.decode().split("\n\n")
                chunk = [c.replace("data:", "") for c in chunk]
                chunk = [c for c in chunk if c]
                chunk = [json.loads(c) for c in chunk]

                for c in chunk:
                    print(c)

asyncio.run(main())

# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 1, 'text': ' a', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 2, 'text': ' Paris', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 3, 'text': 'nic', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 0, 'text': ' blue', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}
# {'id': '', 'object': 'text_completion', 'created': 1712863765, 'choices': [{'index': 1, 'text': ' liquid', 'logprobs': None, 'finish_reason': ''}], 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native'}

sync with requests (non-streaming)

import requests

base_url = "http://localhost:3000"

response = requests.post(
    f"{base_url}/v1/completions",
    json={
        "model": "tgi",
        "prompt": ["Say", "this", "is", "a", "test"],
        "max_tokens": 2,
        "seed": 0,
    },
    stream=False,
)
response = response.json()

print(response)
# {'id': '', 'object': 'text_completion', 'created': 1712722405, 'model': 'google/gemma-7b', 'system_fingerprint': '1.4.5-native', 'choices': [{'index': 0, 'text': " you'", 'logprobs': None, 'finish_reason': 'length'}, {'index': 1, 'text': ' the sequence', 'logprobs': None, 'finish_reason': 'length'}, {'index': 2, 'text': '_cases', 'logprobs': None, 'finish_reason': 'length'}, {'index': 3, 'text': '.\n\n', 'logprobs': None, 'finish_reason': 'length'}, {'index': 4, 'text': '. ', 'logprobs': None, 'finish_reason': 'length'}], 'usage': {'prompt_tokens': 10, 'completion_tokens': 10, 'total_tokens': 20}}

@drbh
Collaborator Author

drbh commented Apr 10, 2024

**Note:** the client library intentionally does not include a `completions` method because this is a legacy API. The changes in this PR align with the API and address integrations with existing tools (LangChain retrieval chain).

@drbh drbh requested a review from Narsil April 10, 2024 16:45
json={
    "model": "tgi",
    "prompt": ["Say", "this", "is", "a", "test"],
    "max_tokens": 5,
Collaborator

Can we use different numbers than 5 in both dimensions? It makes it hard to understand what is what.
What happens if the lengths of the completions vary?
Can you make the prompts of various sizes too?

Does that mean that both queries have to wait on each other to send back chunks to the client?

Collaborator Author

@drbh drbh Apr 11, 2024

**Updates:**

Tests are updated with `"max_tokens": 10` and four prompts of varying lengths.

There is no way to specify a different `max_tokens` per prompt in the OpenAI API; the same value is applied to each prompt.

With the recent change, responses no longer need to wait on each other and are interleaved; responses can complete at different times (chunks with that index simply stop being emitted).
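
Since chunks for different prompts arrive interleaved, a client can reassemble each completion by grouping chunks by their choice index. A minimal sketch over stand-in data (the (index, text) pairs are placeholders, not real server output):

use std::collections::HashMap;

fn main() {
    // Stand-in (index, text) pairs mimicking interleaved streamed chunks.
    let chunks = [(0usize, " blue"), (1, " a"), (2, " Paris"), (1, " liquid"), (3, "nic")];

    // Group text fragments by their choice index.
    let mut by_index: HashMap<usize, String> = HashMap::new();
    for (index, text) in chunks {
        by_index.entry(index).or_default().push_str(text);
    }

    // Print the reassembled text per prompt, in index order.
    let mut indices: Vec<_> = by_index.keys().copied().collect();
    indices.sort();
    for index in indices {
        println!("prompt {index}: {}", by_index[&index]);
    }
}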

Resolved review threads (outdated): integration-tests/models/test_completion_prompts.py, launcher/src/main.rs, router/src/server.rs
Comment on lines 641 to 643
let mut x_compute_type = "unknown".to_string();
let mut x_compute_characters = 0u32;
let mut x_accel_buffering = "no".to_string();
Collaborator

Not a big fan of mutables here. Not sure I have an easy better way atm.

Resolved review threads (outdated): router/src/server.rs (5 threads)
@drbh drbh requested a review from Narsil April 12, 2024 02:19
Resolved review threads (outdated): router/src/server.rs (2 threads)
Collaborator

@Narsil Narsil left a comment

If a user kills the connection, make sure the inference is not running in the background.
The logs are rather poor compared to the regular endpoints.

2024-04-16T10:42:49.931556Z  INFO text_generation_router::server: router/src/server.rs:500: Success

vs

2024-04-16T10:42:56.302342Z  INFO generate_stream{parameters=GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: None, typical_p: None, do_sample: false, max_new_tokens: Some(10), return_full_text: None, stop: [], truncate: None, watermark: false, details: false, decoder_input_details: false, seed: None, top_n_tokens: None, grammar: None } total_time="429.831681ms" validation_time="217.73µs" queue_time="64.823µs" inference_time="429.549248ms" time_per_token="42.954924ms" seed="None"}: text_generation_router::server: router/src/server.rs:500: Success

@drbh drbh self-assigned this Apr 16, 2024
@Narsil
Collaborator

Narsil commented Apr 16, 2024

Should be good after rebase.

@drbh drbh force-pushed the extract-first-prompt-if-list branch from 46d97d8 to 52d234f on April 16, 2024 15:55
@drbh
Collaborator Author

drbh commented Apr 16, 2024

**Note:** the failing client tests do not seem related to these changes and are resolved here: #1751

@drbh
Collaborator Author

drbh commented Apr 16, 2024

... The logs are rather poor compared to the regular endpoints.

2024-04-16T10:42:49.931556Z  INFO text_generation_router::server: router/src/server.rs:500: Success

vs

2024-04-16T10:42:56.302342Z  INFO generate_stream{parameters=GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: None, typical_p: None, do_sample: false, max_new_tokens: Some(10), return_full_text: None, stop: [], truncate: None, watermark: false, details: false, decoder_input_details: false, seed: None, top_n_tokens: None, grammar: None } total_time="429.831681ms" validation_time="217.73µs" queue_time="64.823µs" inference_time="429.549248ms" time_per_token="42.954924ms" seed="None"}: text_generation_router::server: router/src/server.rs:500: Success

Yeah, it's a bit strange that the same logging line produces more output in one case. Any ideas on how to have it emit the same output?

Narsil
Narsil previously approved these changes Apr 16, 2024
Collaborator

@Narsil Narsil left a comment

LGTM

@Narsil
Collaborator

Narsil commented Apr 16, 2024

... The logs are rather poor compared to the regular endpoints.

2024-04-16T10:42:49.931556Z  INFO text_generation_router::server: router/src/server.rs:500: Success

vs

2024-04-16T10:42:56.302342Z  INFO generate_stream{parameters=GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, frequency_penalty: None, top_k: None, top_p: None, typical_p: None, do_sample: false, max_new_tokens: Some(10), return_full_text: None, stop: [], truncate: None, watermark: false, details: false, decoder_input_details: false, seed: None, top_n_tokens: None, grammar: None } total_time="429.831681ms" validation_time="217.73µs" queue_time="64.823µs" inference_time="429.549248ms" time_per_token="42.954924ms" seed="None"}: text_generation_router::server: router/src/server.rs:500: Success

Yeah, it's a bit strange that the same logging line produces more output in one case. Any ideas on how to have it emit the same output?

Should be about the span capture.

@drbh drbh requested a review from Narsil April 17, 2024 01:42
@drbh
Collaborator Author

drbh commented Apr 17, 2024

Logs are now bubbled up to the calling function and output the same information as `generate` and `generate_stream`.

Change: `generate_internal` and `generate_stream_internal` now take a span as an argument, which is passed to `tracing::info` as the parent span.
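
A minimal sketch of the pattern (hypothetical signatures, not the actual TGI functions): the handler builds the instrumented span, and the shared internal function logs against it as the parent, so the emitted line carries the same span fields as `generate` / `generate_stream`.

use tracing::{info_span, Span};

// Hypothetical internal function: it logs against the caller's span rather
// than its own location, so the caller's recorded fields appear in the output.
async fn generate_internal(span: Span) {
    // ... inference would happen here ...
    tracing::info!(parent: &span, "Success");
}

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt().init();

    // Illustrative span and fields; the real handler records the full
    // GenerateParameters and timing information.
    let span = info_span!("completions", max_new_tokens = 10u32);
    generate_internal(span).await;
}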

Collaborator

@Narsil Narsil left a comment

LGTM, very nice PR in the end.

@Narsil Narsil merged commit 06c3d4b into main Apr 17, 2024
8 of 9 checks passed
@Narsil Narsil deleted the extract-first-prompt-if-list branch April 17, 2024 08:41
Nilabhra pushed a commit to TII-AI-Research-Center/text-generation-inference that referenced this pull request May 14, 2024
kdamaszk pushed a commit to kdamaszk/tgi-gaudi that referenced this pull request May 27, 2024
kdamaszk pushed a commit to kdamaszk/tgi-gaudi that referenced this pull request Jun 3, 2024
Successfully merging this pull request may close these issues.

OpenAI API support - Langchain passes prompt as a list instead of str