mamba : multiple sequences, but one at a time
This is a step towards making this Mamba implementation usable
with the server example (the way the system prompt is kept when clearing
the client slots will need to be changed before this can work, though).

The KV cache size for this kind of model is tied to the maximum number
of sequences kept at any single time.
For now, this number is obtained from n_parallel (plus one,
to have an extra sequence to dedicate to the system prompt),
but there might be a better way to do this that wouldn't also
make the main example use 2 cells even when only 1 is really used.
(For that specific case, --parallel 0 helps.)
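
For illustration, a minimal sketch of the sizing rule described above, with a hypothetical helper name (the commit itself only wires params.n_parallel through llama_context_params, as in the common/common.cpp diff below):

```c
#include <stdint.h>

// Hypothetical sketch, not code from this commit: size the recurrent-state
// cache from the number of parallel sequences, reserving one extra cell for
// the system prompt shared by the server's client slots.
static int32_t mamba_n_cache_cells(int32_t n_parallel) {
    return n_parallel + 1;
}
```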

Simultaneous sequence processing will probably require changes to
ggml_ssm_scan, and possibly a new operator for the conv step.

* mamba : support llama_kv_cache_seq_cp

This (mis)uses the logic around K shifts, because tokens in a state
can't be shifted anyway, and because inp_K_shift has the right shape and type.
Using ggml_get_rows is a nice way to do copies, but copy chains can't work.
Fortunately, copy chains don't really seem to be used in the examples.

Each KV cell is dedicated to the sequence ID corresponding to its own index.
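
A hedged sketch of that copy mechanism, with illustrative names (ggml_get_rows gathers every destination row from the pre-copy states in a single pass, which is also why copy chains cannot work):

```c
#include "ggml.h"

// Hypothetical sketch, not the literal code from this commit: copy recurrent
// states between cache cells with a single ggml_get_rows.
//   states   : F32, {state_size, n_cells} -- one flattened state per cell
//   copy_ids : I32, {n_cells}             -- copy_ids[dst] = source cell index;
//              a cell that keeps its own state points at itself.
static struct ggml_tensor * copy_cell_states(
        struct ggml_context * ctx,
        struct ggml_tensor  * states,
        struct ggml_tensor  * copy_ids) {
    return ggml_get_rows(ctx, states, copy_ids);
}
```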

* mamba : use a state mask

It's cleaner than the previous heuristic of
checking the pos of the first token in the batch.

inp_KQ_mask could not be re-used for this, because it has the wrong shape
and because it seems more suited to the next step of
simultaneous sequence processing (helping with the problem of
remembering which token belongs to which sequence(s)/state(s)).
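
A plain-C sketch of the effect of such a state mask, under assumed names (in the compute graph this is a per-cell multiplication, but the result is the same):

```c
#include <stddef.h>

// Hypothetical sketch, not code from this commit: clear the recurrent state of
// every cell whose sequence starts fresh in this batch, keep the others.
//   states     : [n_cells][state_size], one flattened state per cell
//   state_mask : [n_cells], 1.0f to keep the previous state, 0.0f to clear it
static void apply_state_mask(float * states, const float * state_mask,
                             int n_cells, int state_size) {
    for (int c = 0; c < n_cells; ++c) {
        for (int i = 0; i < state_size; ++i) {
            states[(size_t) c*state_size + i] *= state_mask[c];
        }
    }
}
```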

* llama : replace the usage of n_ctx with kv_self.size in many places

* mamba : use n_tokens directly instead of n_tok
compilade committed Feb 14, 2024
1 parent 98e6328 commit 9c4c257
Showing 4 changed files with 253 additions and 88 deletions.
1 change: 1 addition & 0 deletions common/common.cpp
@@ -1176,6 +1176,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 
     cparams.n_ctx = params.n_ctx;
     cparams.n_batch = params.n_batch;
+    cparams.n_parallel = params.n_parallel;
     cparams.n_threads = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.mul_mat_q = params.mul_mat_q;
26 changes: 13 additions & 13 deletions ggml.c
@@ -5905,15 +5905,15 @@ struct ggml_tensor * ggml_ssm_scan(
     GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D
 
     {
-        const int64_t d_state = s->ne[0];
-        const int64_t d_inner = s->ne[1];
-        const int64_t n_tok = x->ne[1];
+        const int64_t d_state  = s->ne[0];
+        const int64_t d_inner  = s->ne[1];
+        const int64_t n_tokens = x->ne[1];
 
         GGML_ASSERT(x->ne[0] == d_inner);
         GGML_ASSERT(A->ne[0] == d_state);
         GGML_ASSERT(A->ne[1] == d_inner);
         GGML_ASSERT(B->ne[0] == d_state);
-        GGML_ASSERT(B->ne[1] == n_tok);
+        GGML_ASSERT(B->ne[1] == n_tokens);
     }
 
     bool is_node = false;
@@ -14178,12 +14178,12 @@ static void ggml_compute_forward_ssm_scan_f32(
 
     // first batch
     {
-        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tokens}
         float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
-        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
-        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
+        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
+        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
         float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B = (float *) ((char *) src4->data); // {d_state, n_tok}
+        float * B = (float *) ((char *) src4->data); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
             float dt_soft_plus = log1pf(expf(dt[i1]));
@@ -14199,12 +14199,12 @@ static void ggml_compute_forward_ssm_scan_f32(
 
     // compute state for rest of tokens, previous state comes from dest
     for (int i2 = 1; i2 < n_t; ++i2) {
-        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
-        float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
-        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
-        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tokens}
+        float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tokens}
+        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tokens}
+        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tokens}
         float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tok}
+        float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
             float dt_soft_plus = log1pf(expf(dt[i1]));
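
The inner loop bodies are truncated in the diff above. For orientation, here is a hedged scalar reference of the recurrence a Mamba-style selective scan applies for one token; it is consistent with the dt_soft_plus lines shown, but it is a sketch, not the exact loop body from this commit:

```c
#include <math.h>

// Hypothetical reference, not the exact code from this diff: one scan step,
// updating the state in place for a single token.
//   state : [d_inner][d_state]    A : [d_inner][d_state]
//   x, dt : [d_inner]             B : [d_state]
// Per channel i1: dt' = softplus(dt[i1]) = log1p(exp(dt[i1]));
//   state[i1][i0] = state[i1][i0] * exp(dt' * A[i1][i0]) + dt' * x[i1] * B[i0]
static void ssm_scan_step(float * state, const float * x, const float * dt,
                          const float * A, const float * B,
                          int d_inner, int d_state) {
    for (int i1 = 0; i1 < d_inner; ++i1) {
        const float dt_soft_plus = log1pf(expf(dt[i1]));
        const float x_dt         = x[i1] * dt_soft_plus;
        for (int i0 = 0; i0 < d_state; ++i0) {
            const int i = i1*d_state + i0;
            state[i] = state[i]*expf(dt_soft_plus*A[i]) + B[i0]*x_dt;
        }
    }
}
```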
