fix(llm): check multi modal input with provider and fix cost calculation (#13445)

AG-61
fffonion committed Aug 16, 2024
1 parent 5b587f2 commit acbbdf3
Showing 6 changed files with 223 additions and 15 deletions.
4 changes: 4 additions & 0 deletions changelog/unreleased/kong/fix-multi-modal.yml
@@ -0,0 +1,4 @@
message: >
"**AI Plugins**: Fixed an issue for multi-modal inputs are not properly validated and calculated.
type: bugfix
scope: Plugin
17 changes: 12 additions & 5 deletions kong/llm/drivers/bedrock.lua
@@ -165,15 +165,22 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)
system_prompts[#system_prompts+1] = { text = v.content }

else
      local content
      if type(v.content) == "table" then
        content = v.content
      else
        content = {
          {
            text = v.content or ""
          },
        }
      end

      -- for any other role, just construct the chat history as 'parts.text' type
      new_r.messages = new_r.messages or {}
      table_insert(new_r.messages, {
        role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user'
        content = content,
      })
end
end
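For context, a minimal sketch of what this hunk does (an editor's illustration, not part of the commit): string content is wrapped into Bedrock's content-parts shape, while content that already arrives as a table of parts is passed through unchanged.

-- Editor's illustration with hypothetical messages (OpenAI-style shapes):
local msg_plain = { role = "user", content = "hello world" }
local msg_parts = { role = "user", content = { { text = "hello world" } } }
-- After the change above, both produce content = { { text = "hello world" } }:
-- msg_plain is wrapped into a single text part; msg_parts is forwarded as-is.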
39 changes: 35 additions & 4 deletions kong/llm/drivers/shared.lua
@@ -603,6 +603,21 @@ function _M.resolve_plugin_conf(kong_request, conf)
return conf_m
end


local function check_multi_modal(conf, request_table)
  if not request_table.messages or (conf.model.provider == "openai" or conf.model.provider == "bedrock") then
return true
end

for _, m in ipairs(request_table.messages) do
if type(m.content) == "table" then
return false
end
end

return true
end
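A quick behavioral sketch of the new gate (an editor's illustration with a hypothetical provider value, not part of the commit): table-typed message content is only allowed through for openai and bedrock, the two providers whose drivers handle multi-modal parts.

-- Editor's illustration, not part of the commit:
local conf = { model = { provider = "anthropic" } } -- hypothetical provider value
local request_table = {
  messages = {
    { role = "user", content = { { type = "text", text = "hi" } } },
  },
}
-- check_multi_modal(conf, request_table) --> false: rejected in pre_request below.
-- With provider "openai" or "bedrock", or with plain string content,
-- the same call returns true.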

function _M.pre_request(conf, request_table)
-- process form/json body auth information
local auth_param_name = conf.auth and conf.auth.param_name
@@ -621,6 +636,11 @@ function _M.pre_request(conf, request_table)
return nil, "no plugin name is being passed by the plugin"
end

local ok = check_multi_modal(conf, request_table)
if not ok then
    return kong.response.exit(400, "multi-modal input is not supported by current provider")
end

-- if enabled AND request type is compatible, capture the input for analytics
if conf.logging and conf.logging.log_payloads then
kong.log.set_serialize_value(fmt("ai.%s.%s.%s", plugin_name, log_entry_keys.PAYLOAD_CONTAINER, log_entry_keys.REQUEST_BODY), kong.request.get_raw_body())
@@ -826,11 +846,19 @@ function _M.http_request(url, body, method, headers, http_opts, buffered)
end

-- Function to count the number of words in a string, or in the text parts
-- of a multi-modal content table
local function count_words(any)
  local count = 0
  if type(any) == "string" then
    for _ in any:gmatch("%S+") do
      count = count + 1
    end
  elseif type(any) == "table" then -- is multi-modal input
    for _, item in ipairs(any) do
      if item.type == "text" and item.text then
        for _ in (item.text):gmatch("%S+") do
          count = count + 1
        end
      end
    end
  end
  return count
@@ -902,4 +930,7 @@ function _M.calculate_cost(query_body, tokens_models, tokens_factor)
return query_cost, nil
end

-- for unit tests
_M._count_words = count_words

return _M
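Since count_words is now exported for unit tests, a short usage sketch (an editor's illustration, assuming the module is required the way the spec below does):

-- Editor's illustration, not part of the commit:
local ai_shared = require("kong.llm.drivers.shared")
print(ai_shared._count_words("one two three")) --> 3
-- For multi-modal content, only parts with type == "text" and a text field are counted:
print(ai_shared._count_words({
  { type = "text", text = "one two" },
  { type = "image_url", image_url = { url = "https://example.com/a.jpg" } }, -- hypothetical URL
})) --> 2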
43 changes: 43 additions & 0 deletions spec/03-plugins/38-ai-proxy/01-unit_spec.lua
@@ -732,4 +732,47 @@ describe(PLUGIN_NAME .. ": (unit)", function()

end)

describe("count_words", function()
local c = ai_shared._count_words

it("normal prompts", function()
assert.same(10, c(string.rep("apple ", 10)))
end)

it("multi-modal prompts", function()
assert.same(10, c({
{
type = "text",
text = string.rep("apple ", 10),
},
}))

assert.same(20, c({
{
type = "text",
text = string.rep("apple ", 10),
},
{
type = "text",
text = string.rep("banana ", 10),
},
}))

assert.same(10, c({
{
type = "not_text",
text = string.rep("apple ", 10),
},
{
type = "text",
text = string.rep("banana ", 10),
},
{
type = "text",
-- somehow malformed
},
}))
end)
end)

end)
111 changes: 105 additions & 6 deletions spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
@@ -14,6 +14,11 @@ local FILE_LOG_PATH_STATS_ONLY = os.tmpname()
local FILE_LOG_PATH_NO_LOGS = os.tmpname()
local FILE_LOG_PATH_WITH_PAYLOADS = os.tmpname()

local truncate_file = function(path)
local file = io.open(path, "w")
file:close()
end


local function wait_for_json_log_entry(FILE_LOG_PATH)
local json
@@ -764,20 +769,21 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then

    lazy_teardown(function()
      helpers.stop_kong()
      os.remove(FILE_LOG_PATH_STATS_ONLY)
      os.remove(FILE_LOG_PATH_NO_LOGS)
      os.remove(FILE_LOG_PATH_WITH_PAYLOADS)
    end)

    before_each(function()
      client = helpers.proxy_client()
      -- Note: if the file is removed instead of truncated, file-log ends up writing to an unlinked file handle
      truncate_file(FILE_LOG_PATH_STATS_ONLY)
      truncate_file(FILE_LOG_PATH_NO_LOGS)
      truncate_file(FILE_LOG_PATH_WITH_PAYLOADS)
    end)

    after_each(function()
      if client then client:close() end
    end)

describe("openai general", function()
@@ -1391,6 +1397,99 @@
end)
end)

describe("#aaa openai multi-modal requests", function()
it("logs statistics", function()
local r = client:get("/openai/llm/v1/chat/good", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_multi_modal.json"),
})

-- validate that the request succeeded, response status 200
        local body = assert.res_status(200, r)
local json = cjson.decode(body)

-- check this is in the 'kong' response format
assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "gpt-3.5-turbo-0613")
assert.equals(json.object, "chat.completion")

assert.is_table(json.choices)
assert.is_table(json.choices[1].message)
assert.same({
content = "The sum of 1 + 1 is 2.",
role = "assistant",
}, json.choices[1].message)

local log_message = wait_for_json_log_entry(FILE_LOG_PATH_STATS_ONLY)
assert.same("127.0.0.1", log_message.client_ip)
assert.is_number(log_message.request.size)
assert.is_number(log_message.response.size)

-- test ai-proxy stats
-- TODO: as we are reusing this test for ai-proxy and ai-proxy-advanced
-- we are currently stripping the top level key and comparing values directly
local _, first_expected = next(_EXPECTED_CHAT_STATS)
local _, first_got = next(log_message.ai)
local actual_llm_latency = first_got.meta.llm_latency
local actual_time_per_token = first_got.usage.time_per_token
local time_per_token = math.floor(actual_llm_latency / first_got.usage.completion_tokens)

first_got.meta.llm_latency = 1
first_got.usage.time_per_token = 1

assert.same(first_expected, first_got)
assert.is_true(actual_llm_latency > 0)
assert.same(actual_time_per_token, time_per_token)
end)

it("logs payloads", function()
local r = client:get("/openai/llm/v1/chat/good-with-payloads", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_multi_modal.json"),
})

-- validate that the request succeeded, response status 200
        local body = assert.res_status(200, r)
local json = cjson.decode(body)

-- check this is in the 'kong' response format
assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "gpt-3.5-turbo-0613")
assert.equals(json.object, "chat.completion")

assert.is_table(json.choices)
assert.is_table(json.choices[1].message)
assert.same({
content = "The sum of 1 + 1 is 2.",
role = "assistant",
}, json.choices[1].message)

local log_message = wait_for_json_log_entry(FILE_LOG_PATH_WITH_PAYLOADS)
assert.same("127.0.0.1", log_message.client_ip)
assert.is_number(log_message.request.size)
assert.is_number(log_message.response.size)

-- TODO: as we are reusing this test for ai-proxy and ai-proxy-advanced
-- we are currently stripping the top level key and comparing values directly
local _, message = next(log_message.ai)

-- test request bodies
assert.matches('"text": "What\'s in this image?"', message.payload.request, nil, true)
assert.matches('"role": "user"', message.payload.request, nil, true)

-- test response bodies
assert.matches('"content": "The sum of 1 + 1 is 2.",', message.payload.response, nil, true)
assert.matches('"role": "assistant"', message.payload.response, nil, true)
assert.matches('"id": "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2"', message.payload.response, nil, true)
end)
end)

describe("one-shot request", function()
it("success", function()
local ai_driver = require("kong.llm.drivers.openai")
24 changes: 24 additions & 0 deletions spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_multi_modal.json
@@ -0,0 +1,24 @@
{
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
}
]
}
],
"stream": false
}
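As a rough illustration of how this fixture interacts with the updated word counting (an editor's sketch; the exact cost path is an assumption, not asserted by the tests above): only the text part of the user message contributes words, so the image part adds nothing to the estimate that feeds calculate_cost.

-- Editor's illustration, not part of the commit:
local content = {
  { type = "text", text = "What's in this image?" },
  { type = "image_url", image_url = { url = "https://example.com/boardwalk.jpg" } }, -- hypothetical URL
}
-- count_words(content) --> 4 ("What's", "in", "this", "image?")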

1 comment on commit acbbdf3

@github-actions (Contributor)


Bazel Build

Docker image available kong/kong:acbbdf39d578095858d32e433327d2800fcf16fa
Artifacts available https://github.com/Kong/kong/actions/runs/10420532525
