fix(llm): check multi modal input with provider and fix cost calculation #13445

Merged: 6 commits, Aug 16, 2024
4 changes: 4 additions & 0 deletions changelog/unreleased/kong/fix-multi-modal.yml
@@ -0,0 +1,4 @@
message: >
  "**AI Plugins**: Fixed an issue where multi-modal inputs were not properly validated and calculated."
type: bugfix
scope: Plugin
17 changes: 12 additions & 5 deletions kong/llm/drivers/bedrock.lua
@@ -165,15 +165,22 @@ local function to_bedrock_chat_openai(request_table, model_info, route_type)
        system_prompts[#system_prompts+1] = { text = v.content }

      else
        local content
        if type(v.content) == "table" then
          content = v.content
        else
Review comment (Contributor):

Not a blocker. Are we going to keep compatibility with the old format?

Reply (@fffonion, Contributor Author, Aug 15, 2024):

Both formats remain supported; neither is deprecated. Right now only openai and bedrock support the multi-modal format in addition to the non-multi-modal ("old") format.
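To make the two request shapes concrete, here is a hypothetical pair of message bodies; the field names follow the OpenAI chat format used by the fixture later in this PR:

-- "old" non-multi-modal format: content is a plain string
local old_format_message = {
  role = "user",
  content = "What's in this image?",
}

-- multi-modal format: content is an array of typed parts
local multi_modal_message = {
  role = "user",
  content = {
    { type = "text", text = "What's in this image?" },
    { type = "image_url", image_url = { url = "https://example.com/image.jpg" } },
  },
}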

          content = {
            {
              text = v.content or ""
            },
          }
        end

        -- for any other role, just construct the chat history as 'parts.text' type
        new_r.messages = new_r.messages or {}
        table_insert(new_r.messages, {
          role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user'
          content = content,
        })
      end
    end
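In effect, old-format string content is wrapped into Bedrock's parts shape, while table (multi-modal) content is forwarded untouched. A hypothetical before/after pair matching the logic above:

-- string content gets wrapped:
--   input:  { role = "user", content = "hello" }
--   output: { role = "user", content = { { text = "hello" } } }
-- multi-modal table content passes through as-is:
--   input:  { role = "user", content = { { text = "What's in this image?" } } }
--   output: { role = "user", content = { { text = "What's in this image?" } } }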
39 changes: 35 additions & 4 deletions kong/llm/drivers/shared.lua
@@ -521,6 +521,21 @@ function _M.resolve_plugin_conf(kong_request, conf)
  return conf_m
end


local function check_multi_modal(conf, request_table)
  if not request_table.messages or (conf.model.provider == "openai" or conf.model.provider == "bedrock") then
    return true
  end

  for _, m in ipairs(request_table.messages) do
    if type(m.content) == "table" then
      return false
    end
  end

  return true
end
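For illustration, a minimal sketch of how this guard behaves; the conf tables and message bodies here are hypothetical fixtures, and the provider names follow the check above:

local conf_openai    = { model = { provider = "openai" } }
local conf_anthropic = { model = { provider = "anthropic" } }

local multi_modal_body = {
  messages = {
    { role = "user", content = { { type = "text", text = "What's in this image?" } } },
  },
}

-- table-typed content is only let through for openai and bedrock
assert(check_multi_modal(conf_openai, multi_modal_body))         -- true
assert(not check_multi_modal(conf_anthropic, multi_modal_body))  -- false

-- plain-string content passes for any provider
local plain_body = { messages = { { role = "user", content = "hello" } } }
assert(check_multi_modal(conf_anthropic, plain_body))            -- true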

function _M.pre_request(conf, request_table)
  -- process form/json body auth information
  local auth_param_name = conf.auth and conf.auth.param_name
@@ -539,6 +554,11 @@ function _M.pre_request(conf, request_table)
    return nil, "no plugin name is being passed by the plugin"
  end

  local ok = check_multi_modal(conf, request_table)
  if not ok then
    return kong.response.exit(400, "multi-modal input is not supported by current provider")
  end

  -- if enabled AND request type is compatible, capture the input for analytics
  if conf.logging and conf.logging.log_payloads then
    kong.log.set_serialize_value(fmt("ai.%s.%s.%s", plugin_name, log_entry_keys.PAYLOAD_CONTAINER, log_entry_keys.REQUEST_BODY), kong.request.get_raw_body())
@@ -744,11 +764,19 @@ function _M.http_request(url, body, method, headers, http_opts, buffered)
end

-- Function to count the number of words in a string or in multi-modal content parts
local function count_words(any)
  local count = 0
  if type(any) == "string" then
    for _ in any:gmatch("%S+") do
      count = count + 1
    end
  elseif type(any) == "table" then -- is multi-modal input
    for _, item in ipairs(any) do
      if item.type == "text" and item.text then
        for _ in (item.text):gmatch("%S+") do
          count = count + 1
        end
      end
    end
  end
  return count
@@ -820,4 +848,7 @@ function _M.calculate_cost(query_body, tokens_models, tokens_factor)
  return query_cost, nil
end

-- for unit tests
_M._count_words = count_words

return _M
43 changes: 43 additions & 0 deletions spec/03-plugins/38-ai-proxy/01-unit_spec.lua
@@ -732,4 +732,47 @@ describe(PLUGIN_NAME .. ": (unit)", function()

end)

describe("count_words", function()
local c = ai_shared._count_words

it("normal prompts", function()
assert.same(10, c(string.rep("apple ", 10)))
end)

it("multi-modal prompts", function()
assert.same(10, c({
{
type = "text",
text = string.rep("apple ", 10),
},
}))

assert.same(20, c({
{
type = "text",
text = string.rep("apple ", 10),
},
{
type = "text",
text = string.rep("banana ", 10),
},
}))

assert.same(10, c({
{
type = "not_text",
text = string.rep("apple ", 10),
},
{
type = "text",
text = string.rep("banana ", 10),
},
{
type = "text",
-- somehow malformed
},
}))
end)
end)

end)
111 changes: 105 additions & 6 deletions spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
@@ -14,6 +14,11 @@ local FILE_LOG_PATH_STATS_ONLY = os.tmpname()
local FILE_LOG_PATH_NO_LOGS = os.tmpname()
local FILE_LOG_PATH_WITH_PAYLOADS = os.tmpname()

local truncate_file = function(path)
  local file = io.open(path, "w")
  file:close()
end
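As background for why the spec truncates these files instead of removing them, here is a minimal sketch, assuming a POSIX-style filesystem; the path is illustrative:

-- a process holding an open handle keeps writing after os.remove(),
-- but the data goes to the now-unlinked inode, not to a fresh file at the path
local f = assert(io.open("/tmp/demo.log", "a"))
os.remove("/tmp/demo.log")                -- unlinks the name; f still points at the old inode
f:write("this line is invisible at the path\n")
f:flush()
print(io.open("/tmp/demo.log", "r"))      -- nil: nothing exists at the path anymore
f:close()

Truncating in place keeps the same inode, so the file-log plugin's open handle and the test's reader see the same file.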


local function wait_for_json_log_entry(FILE_LOG_PATH)
local json
@@ -764,20 +769,21 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then

    lazy_teardown(function()
      helpers.stop_kong()
      os.remove(FILE_LOG_PATH_STATS_ONLY)
      os.remove(FILE_LOG_PATH_NO_LOGS)
      os.remove(FILE_LOG_PATH_WITH_PAYLOADS)
    end)

    before_each(function()
      client = helpers.proxy_client()
      -- Note: if the file is removed instead of truncated, file-log keeps writing to an unlinked file handle
      truncate_file(FILE_LOG_PATH_STATS_ONLY)
      truncate_file(FILE_LOG_PATH_NO_LOGS)
      truncate_file(FILE_LOG_PATH_WITH_PAYLOADS)
    end)

    after_each(function()
      if client then client:close() end
    end)

describe("openai general", function()
@@ -1391,6 +1397,99 @@
end)
end)

describe("#aaa openai multi-modal requests", function()
it("logs statistics", function()
local r = client:get("/openai/llm/v1/chat/good", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_multi_modal.json"),
})

-- validate that the request succeeded, response status 200
local body = assert.res_status(200 , r)
local json = cjson.decode(body)

-- check this is in the 'kong' response format
assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "gpt-3.5-turbo-0613")
assert.equals(json.object, "chat.completion")

assert.is_table(json.choices)
assert.is_table(json.choices[1].message)
assert.same({
content = "The sum of 1 + 1 is 2.",
role = "assistant",
}, json.choices[1].message)

local log_message = wait_for_json_log_entry(FILE_LOG_PATH_STATS_ONLY)
assert.same("127.0.0.1", log_message.client_ip)
assert.is_number(log_message.request.size)
assert.is_number(log_message.response.size)

-- test ai-proxy stats
-- TODO: as we are reusing this test for ai-proxy and ai-proxy-advanced
-- we are currently stripping the top level key and comparing values directly
local _, first_expected = next(_EXPECTED_CHAT_STATS)
local _, first_got = next(log_message.ai)
local actual_llm_latency = first_got.meta.llm_latency
local actual_time_per_token = first_got.usage.time_per_token
local time_per_token = math.floor(actual_llm_latency / first_got.usage.completion_tokens)

first_got.meta.llm_latency = 1
first_got.usage.time_per_token = 1

assert.same(first_expected, first_got)
assert.is_true(actual_llm_latency > 0)
assert.same(actual_time_per_token, time_per_token)
end)

it("logs payloads", function()
local r = client:get("/openai/llm/v1/chat/good-with-payloads", {
headers = {
["content-type"] = "application/json",
["accept"] = "application/json",
},
body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_multi_modal.json"),
})

-- validate that the request succeeded, response status 200
local body = assert.res_status(200 , r)
local json = cjson.decode(body)

-- check this is in the 'kong' response format
assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "gpt-3.5-turbo-0613")
assert.equals(json.object, "chat.completion")

assert.is_table(json.choices)
assert.is_table(json.choices[1].message)
assert.same({
content = "The sum of 1 + 1 is 2.",
role = "assistant",
}, json.choices[1].message)

local log_message = wait_for_json_log_entry(FILE_LOG_PATH_WITH_PAYLOADS)
assert.same("127.0.0.1", log_message.client_ip)
assert.is_number(log_message.request.size)
assert.is_number(log_message.response.size)

-- TODO: as we are reusing this test for ai-proxy and ai-proxy-advanced
-- we are currently stripping the top level key and comparing values directly
local _, message = next(log_message.ai)

-- test request bodies
assert.matches('"text": "What\'s in this image?"', message.payload.request, nil, true)
assert.matches('"role": "user"', message.payload.request, nil, true)

-- test response bodies
assert.matches('"content": "The sum of 1 + 1 is 2.",', message.payload.response, nil, true)
assert.matches('"role": "assistant"', message.payload.response, nil, true)
assert.matches('"id": "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2"', message.payload.response, nil, true)
end)
end)

describe("one-shot request", function()
it("success", function()
local ai_driver = require("kong.llm.drivers.openai")
24 changes: 24 additions & 0 deletions spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_multi_modal.json
@@ -0,0 +1,24 @@
{
  "messages": [
    {
      "role": "system",
      "content": "You are a helpful assistant."
    },
    {
      "role": "user",
      "content": [
        {
          "type": "text",
          "text": "What's in this image?"
        },
        {
          "type": "image_url",
          "image_url": {
            "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
          }
        }
      ]
    }
  ],
  "stream": false
}