feat: implement a templated endpoint for visibility into chat requests (#2333)

* feat: implement a templated endpoint for visibility into chat requests

* feat: improve to tokenize too

* fix: adjust return type

* feat: simplify prepare_chat_input logic and adjust start stop chars
drbh committed Aug 6, 2024
1 parent 29b8d19 commit e11f5f1
Showing 2 changed files with 156 additions and 60 deletions.
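For context on how the new route is exercised, here is a minimal client sketch. It is not part of this commit; it assumes the reqwest crate (with the blocking and json features) and serde_json are available, that the router listens on localhost:3000, and that the chat messages use the usual role/content shape accepted by ChatRequest.

use serde_json::{json, Value};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Minimal ChatRequest body; "tgi" selects the base model (no adapter).
    let body = json!({
        "model": "tgi",
        "messages": [
            {"role": "user", "content": "What is the capital of France?"}
        ]
    });

    // POST to the new templating/tokenization endpoint added by this commit.
    let resp: Value = reqwest::blocking::Client::new()
        .post("http://localhost:3000/chat_tokenize")
        .json(&body)
        .send()?
        .json()?;

    // `templated_text` is the chat template applied to the messages;
    // `tokenize_response` is the list of SimpleToken entries for that text.
    println!("templated: {}", resp["templated_text"]);
    println!("tokens: {}", resp["tokenize_response"]);
    Ok(())
}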
6 changes: 6 additions & 0 deletions router/src/lib.rs
@@ -1157,6 +1157,12 @@ pub(crate) struct GenerateResponse {
pub details: Option<Details>,
}

#[derive(Serialize, ToSchema)]
pub(crate) struct ChatTokenizeResponse {
pub(crate) tokenize_response: TokenizeResponse,
pub(crate) templated_text: String,
}

#[derive(Serialize, ToSchema)]
#[serde(transparent)]
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
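The ChatTokenizeResponse added above pairs the tokenized output with the templated prompt text. As a rough illustration (not part of this commit), and assuming SimpleToken serializes its id, text, start, and stop fields under those names, a response would serialize along these lines; the token ids, texts, and offsets below are made up:

use serde_json::{json, Value};

// Example shape only: TokenizeResponse is #[serde(transparent)], so it
// appears as a bare array of SimpleToken objects rather than a nested object.
fn example_chat_tokenize_response() -> Value {
    json!({
        "tokenize_response": [
            { "id": 1, "text": "<s>", "start": 0, "stop": 3 },
            { "id": 22557, "text": "Hello", "start": 3, "stop": 8 }
        ],
        "templated_text": "<s>Hello"
    })
}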
210 changes: 150 additions & 60 deletions router/src/server.rs
@@ -8,6 +8,7 @@ use crate::kserve::{
kserve_model_metadata, kserve_model_metadata_ready,
};
use crate::validation::ValidationError;
use crate::ChatTokenizeResponse;
use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -22,7 +23,7 @@ use crate::{
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
VertexResponse,
};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
@@ -115,6 +116,107 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
Json(info.0)
}

#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/chat_tokenize",
request_body = ChatRequest,
responses((status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse))
)]
async fn get_chat_tokenize(
Extension(infer): Extension<Infer>,
Json(req): Json<ChatRequest>,
) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
metrics::counter!("tgi_request_count").increment(1);

let ChatRequest {
model,
max_tokens,
messages,
seed,
stop,
stream,
tools,
tool_choice,
tool_prompt,
temperature,
response_format,
..
} = req;

let tool_prompt = tool_prompt.unwrap_or_default();
let (inputs, _grammar, _tool_grammar) = prepare_chat_input(
&infer,
response_format,
tools,
tool_choice,
&tool_prompt,
messages,
)?;

let generate_request = GenerateRequest {
inputs,
parameters: GenerateParameters {
best_of: None,
temperature,
repetition_penalty: None,
frequency_penalty: None,
top_k: None,
top_p: None,
typical_p: None,
do_sample: true,
max_new_tokens: max_tokens,
return_full_text: None,
stop: stop.unwrap_or_default(),
truncate: None,
watermark: false,
details: false,
decoder_input_details: !stream,
seed,
top_n_tokens: None,
grammar: _grammar,
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
},
};

let input = generate_request.inputs.clone();
let encoding = infer.tokenize(generate_request).await?;
if let Some(encoding) = encoding {
let tokens: Vec<SimpleToken> = encoding
.get_ids()
.iter()
.zip(encoding.get_offsets())
.map(|(&id, &(start, stop))| {
let text = input
.chars()
.skip(start)
.take(stop - start)
.collect::<String>();
SimpleToken {
id,
text,
start,
stop,
}
})
.collect();

let resp = ChatTokenizeResponse {
tokenize_response: TokenizeResponse(tokens),
templated_text: input,
};
Ok((HeaderMap::new(), Json(resp)))
} else {
Err((
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: "No fast tokenizer or tokenizer.json for this model".to_string(),
error_type: "no fast tokenizer".to_string(),
}),
))
}
}

#[utoipa::path(
get,
tag = "Text Generation Inference",
@@ -1034,63 +1136,14 @@ async fn chat_completions(
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};

// response_format and tools are mutually exclusive
if response_format.is_some() && tools.as_ref().is_some() {
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Grammar and tools are mutually exclusive".to_string(),
error_type: "grammar and tools".to_string(),
}),
));
}

// extract tool grammar if present
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar,
Err(err) => {
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
));
}
};

// determine the appropriate arguments for apply_chat_template
let tools_grammar_prompt = tool_grammar
.as_ref()
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));

let (tools_grammar_prompt, grammar) = match response_format {
Some(response_format) => (None, Some(response_format)),
None => (
tools_grammar_prompt.clone(),
tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
),
};

// apply chat template to flatten the request into a single input
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
Ok(inputs) => inputs,
Err(err) => {
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
));
}
};
let (inputs, grammar, tool_grammar) = prepare_chat_input(
&infer,
response_format,
tools,
tool_choice,
&tool_prompt,
messages,
)?;

// build the request passing some parameters
let generate_request = GenerateRequest {
@@ -1360,8 +1413,11 @@ async fn tokenize(
.iter()
.zip(encoding.get_offsets())
.map(|(&id, &(start, stop))| {
let text: String =
String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string();
let text = input
.chars()
.skip(start)
.take(stop - start)
.collect::<String>();
SimpleToken {
id,
text,
@@ -2036,6 +2092,7 @@ async fn start(
}
let info_routes = Router::new()
.route("/", get(health))
.route("/chat_tokenize", post(get_chat_tokenize))
.route("/info", get(get_model_info))
.route("/health", get(health))
.route("/ping", get(health))
@@ -2332,3 +2389,36 @@ fn create_post_processor(

Ok(post_processor)
}

type PreparedInput = (String, Option<GrammarType>, Option<Tools>);

fn prepare_chat_input(
infer: &Infer,
response_format: Option<GrammarType>,
tools: Option<Vec<Tool>>,
tool_choice: ToolChoice,
tool_prompt: &str,
messages: Vec<Message>,
) -> Result<PreparedInput, InferError> {
if response_format.is_some() && tools.is_some() {
return Err(InferError::ToolError(
"Grammar and tools are mutually exclusive".into(),
));
}

if let Some(format) = response_format {
let inputs = infer.apply_chat_template(messages, None)?;
return Ok((inputs, Some(format), None));
}

// if tools are set, apply the tool grammar and then the chat template
let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
let grammar = tool_grammar
.as_ref()
.map(|t| GrammarType::Json(serde_json::json!(t)));
let tools_grammar_prompt = tool_grammar
.as_ref()
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
let inputs = infer.apply_chat_template(messages, tools_grammar_prompt)?;
Ok((inputs, grammar, tool_grammar))
}
