llama : fix embeddings #5796

Merged · 9 commits · Mar 4, 2024
1 change: 1 addition & 0 deletions README.md
@@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)

### Recent API changes

- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
- [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849

### Hot topics
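To illustrate the "[2024 Mar 4] Embeddings API updated" entry above, here is a minimal sketch (not part of the PR; the `embed_prompt` helper is hypothetical). It assumes a `llama_context` created with `llama_context_params.embeddings = true` and uses the `llama_batch_add` helper from `common.h`: the pooled per-sequence embedding is fetched with `llama_get_embeddings_seq`, falling back to the per-token `llama_get_embeddings_ith` when pooling is disabled.

```cpp
#include <vector>

#include "llama.h"
#include "common.h" // for the llama_batch_add helper used by the examples

// Sketch: embed one tokenized prompt as sequence 0 and return a copy of the result.
// Assumes ctx was created with llama_context_params.embeddings = true.
static std::vector<float> embed_prompt(llama_context * ctx, const llama_model * model,
                                       const std::vector<llama_token> & tokens) {
    const int n_embd = llama_n_embd(model);

    llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);
    for (size_t i = 0; i < tokens.size(); i++) {
        // request output only for the last token of the sequence
        llama_batch_add(batch, tokens[i], (llama_pos) i, { 0 }, i == tokens.size() - 1);
    }

    std::vector<float> out(n_embd, 0.0f);
    if (llama_decode(ctx, batch) == 0) {
        // pooled embedding for the whole sequence (pooling_type != NONE) ...
        const float * embd = llama_get_embeddings_seq(ctx, 0);
        if (embd == NULL) {
            // ... otherwise fall back to the embedding of the last (output) token
            embd = llama_get_embeddings_ith(ctx, batch.n_tokens - 1);
        }
        if (embd != NULL) {
            out.assign(embd, embd + n_embd);
        }
    }

    llama_batch_free(batch);
    return out;
}
```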
2 changes: 1 addition & 1 deletion common/common.cpp
@@ -1299,7 +1299,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.seed = params.seed;
cparams.logits_all = params.logits_all;
- cparams.embedding = params.embedding;
+ cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale;
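For reference, this renamed flag is what callers set to get embedding output at all; a minimal sketch (the `make_embedding_cparams` helper is hypothetical, not from the PR):

```cpp
#include "llama.h"

// Sketch: build context params with embedding output enabled
// (the field was renamed from `embedding` to `embeddings` in this PR).
static struct llama_context_params make_embedding_cparams() {
    struct llama_context_params cparams = llama_context_default_params();
    cparams.embeddings = true;
    return cparams;
}
```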
28 changes: 21 additions & 7 deletions examples/embedding/embedding.cpp
Review comment from @cebtenzzre (Collaborator), Mar 4, 2024:

Does the `llama_kv_cache_clear` still do anything useful?

edit: I remembered that this example is used for models with causal attention as well. I won't need the equivalent if I'm just working with embedding models.

Reply from @ggerganov (Owner, PR author):

Yes, `llama_kv_cache_clear` is not necessary for embedding models, but it makes sense for causal-attention models.
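A minimal sketch of the pattern being discussed (not from the PR; `decode_for_embeddings` and its `causal` flag are hypothetical): with a causal-attention model, KV-cache state from the previous prompt should not leak into the next one, so the cache is cleared before each decode, while a non-causal embedding model can skip the call.

```cpp
#include "llama.h"

// Sketch: clear leftover KV-cache state before decoding the next prompt when
// the model uses causal attention; pure embedding (non-causal) models skip it.
static int32_t decode_for_embeddings(llama_context * ctx, llama_batch & batch, bool causal) {
    if (causal) {
        llama_kv_cache_clear(ctx); // drop state from the previous prompt
    }
    return llama_decode(ctx, batch);
}
```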

@@ -19,11 +19,11 @@ static std::vector<std::string> split_lines(const std::string & s) {

static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
for (size_t i = 0; i < tokens.size(); i++) {
- llama_batch_add(batch, tokens[i], i, { seq_id }, false);
+ llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
}
}

- static void normalize(float * vec, float * out, int n) {
+ static void normalize(const float * vec, float * out, int n) {
float norm = 0;
for (int i = 0; i < n; i++) {
norm += vec[i] * vec[i];
@@ -45,10 +45,23 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}

// normalize on copy
-    for (int k = 0; k < n_seq; k++) {
-        float * emb = llama_get_embeddings_ith(ctx, k);
-        float * out = output + k * n_embd;
-        normalize(emb, out, n_embd);
+    for (int i = 0; i < batch.n_tokens; i++) {
+        if (!batch.logits[i]) {
+            continue;
+        }
+
+        // try to get sequence embeddings - supported only when pooling_type is not NONE
+        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+        if (embd == NULL) {
+            embd = llama_get_embeddings_ith(ctx, i);
+            if (embd == NULL) {
+                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
+                continue;
+            }
+        }
+
+        float * out = output + batch.seq_id[i][0] * n_embd;
+        normalize(embd, out, n_embd);
}
}

@@ -132,7 +145,7 @@ int main(int argc, char ** argv) {

// initialize batch
const int n_prompts = prompts.size();
- struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
+ struct llama_batch batch = llama_batch_init(n_batch, 0, 1);

// allocate output
const int n_embd = llama_n_embd(model);
@@ -145,6 +158,7 @@
for (int k = 0; k < n_prompts; k++) {
// clamp to n_batch tokens
auto & inp = inputs[k];

const uint64_t n_toks = inp.size();

// encode if at capacity
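Because `normalize` writes unit-length vectors into the output buffer, comparing two prompts reduces to a plain dot product; a small sketch (hypothetical helper, mirroring what the Python script below computes with NumPy):

```cpp
// Cosine similarity of two embeddings that are already L2-normalized:
// for unit-norm vectors the cosine is just the dot product.
static float cosine_similarity_normalized(const float * a, const float * b, int n) {
    float dot = 0.0f;
    for (int i = 0; i < n; i++) {
        dot += a[i] * b[i];
    }
    return dot;
}
```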
34 changes: 34 additions & 0 deletions examples/server-embd.py
@@ -0,0 +1,34 @@
import asyncio
import requests
import numpy as np

n = 8

result = []

async def requests_post_async(*args, **kwargs):
    return await asyncio.to_thread(requests.post, *args, **kwargs)

async def main():
    model_url = "http://127.0.0.1:6900"
    responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
        url= f"{model_url}/embedding",
        json= {"content": str(i)*1024}
    ) for i in range(n)])

    for response in responses:
        embedding = response.json()["embedding"]
        print(embedding[-8:])
        result.append(embedding)

asyncio.run(main())

# compute cosine similarity

for i in range(n-1):
    for j in range(i+1, n):
        embedding1 = np.array(result[i])
        embedding2 = np.array(result[j])
        similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
        print(f"Similarity between {i} and {j}: {similarity:.2f}")

53 changes: 42 additions & 11 deletions examples/server/server.cpp
@@ -1210,7 +1210,7 @@ struct llama_server_context
queue_results.send(res);
}

- void send_embedding(server_slot &slot)
+ void send_embedding(server_slot & slot, const llama_batch & batch)
{
task_result res;
res.id = slot.task_id;
@@ -1219,6 +1219,7 @@
res.stop = true;

const int n_embd = llama_n_embd(model);

if (!params.embedding)
{
LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
@@ -1229,12 +1230,29 @@
}
else
{
-    const float *data = llama_get_embeddings(ctx);
-    std::vector<float> embedding(data, data + n_embd);
-    res.result_json = json
-    {
-        {"embedding", embedding},
-    };
+    for (int i = 0; i < batch.n_tokens; ++i) {
+        if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+            continue;
+        }
+
+        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+        if (embd == NULL) {
+            embd = llama_get_embeddings_ith(ctx, i);
+            if (embd == NULL) {
+                LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
+                res.result_json = json
+                {
+                    {"embedding", std::vector<float>(n_embd, 0.0f)},
+                };
+                continue;
+            }
+        }
+
+        res.result_json = json
+        {
+            {"embedding", std::vector<float>(embd, embd + n_embd)},
+        };
+    }
}
queue_results.send(res);
}
@@ -1845,7 +1863,7 @@ struct llama_server_context
ga_i += ga_w/ga_n;
}
}
- llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
+ llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
slot_npast++;
}

@@ -1881,7 +1899,7 @@

for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
{
- const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+ const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

for (auto & slot : slots)
{
@@ -1954,7 +1972,7 @@
// prompt evaluated for embedding
if (slot.embedding)
{
- send_embedding(slot);
+ send_embedding(slot, batch_view);
slot.release();
slot.i_batch = -1;
continue;
@@ -2036,6 +2054,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" --pooling {none,mean,cls}\n");
printf(" pooling type for embeddings, use model default if unspecified\n");
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -2276,6 +2296,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.yarn_beta_slow = std::stof(argv[i]);
}
else if (arg == "--pooling")
{
if (++i >= argc) {
invalid_param = true;
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else { invalid_param = true; break; }
}
else if (arg == "--threads" || arg == "-t")
{
if (++i >= argc)
@@ -2330,7 +2362,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
params.n_batch = std::stoi(argv[i]);
- params.n_batch = std::min(512, params.n_batch);
}
else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
{
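The new `--pooling` flag maps a string onto the pooling-type enum values used above; the same mapping can be factored into a small helper, sketched here under the assumption that `llama.h` exposes the `llama_pooling_type` enum (the `parse_pooling_type` name is hypothetical, not part of the PR):

```cpp
#include <string>

#include "llama.h"

// Sketch: mirror of the --pooling argument parsing added above.
// Returns true on success and writes the parsed value to `out`.
static bool parse_pooling_type(const std::string & value, enum llama_pooling_type & out) {
    if (value == "none") { out = LLAMA_POOLING_TYPE_NONE; return true; }
    if (value == "mean") { out = LLAMA_POOLING_TYPE_MEAN; return true; }
    if (value == "cls")  { out = LLAMA_POOLING_TYPE_CLS;  return true; }
    return false;
}
```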