
[Cpp Graph] Beam Search Pybind (model archs: gptj and gptneox) (#449)
zhentaoyu authored Oct 17, 2023
1 parent 87b00d8 commit 958d048
Showing 11 changed files with 308 additions and 101 deletions.
10 changes: 10 additions & 0 deletions intel_extension_for_transformers/llm/runtime/graph/__init__.py
@@ -55,6 +55,8 @@ def __import_package(self, model_name):
import intel_extension_for_transformers.llm.runtime.graph.chatglm2_cpp as cpp_model
elif model_name == "baichuan":
import intel_extension_for_transformers.llm.runtime.graph.baichuan_cpp as cpp_model
elif model_name == "polyglot":
import intel_extension_for_transformers.llm.runtime.graph.polyglot_cpp as cpp_model
else:
raise TypeError("Unspported model type {}!".format(model_name))
self.module = cpp_model
@@ -107,7 +109,15 @@ def generate(self, input_ids, streamer=None, interactive=False, **kwargs):

# TODO support multi batch
assert input_ids.shape[0] == 1, "Unsupported multi-batch input ids."
beam_search = False
if ("num_beams" in kwargs and kwargs["num_beams"] > 1) and not \
kwargs.get("do_sample", False):
beam_search = True
if streamer:
if beam_search:
print("ERROR, can not use streamer when use beam search for generation!")
import sys
sys.exit(1)
if self.generate_round == 0:
streamer.put(input_ids)
while not self.is_token_end():
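For context, here is a minimal usage sketch of the Python-side beam-search gate added above. The checkpoint name, model binary path, and the init_from_bin call are illustrative assumptions; only the num_beams/do_sample behavior and the streamer restriction come from this diff.

# Sketch under assumptions: Model is the wrapper defined in this __init__.py;
# "gptj", the .bin path, and init_from_bin are placeholders, not confirmed API.
from transformers import AutoTokenizer
from intel_extension_for_transformers.llm.runtime.graph import Model

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")  # example checkpoint
input_ids = tokenizer("She opened the door and saw", return_tensors="pt").input_ids

model = Model()
model.init_from_bin("gptj", "gptj-q4.bin")  # assumed initializer for a quantized binary

# num_beams > 1 with do_sample=False takes the beam-search branch above;
# passing a streamer in that mode aborts with an error.
output_ids = model.generate(input_ids, num_beams=4, do_sample=False, max_new_tokens=32)
print(output_ids)  # raw token ids; decode with the tokenizer as needed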
@@ -56,6 +56,7 @@ compile_quant(quant_gptj quant_model.cpp gptj gptj)
compile_quant(quant_falcon quant_model.cpp falcon falcon)
compile_quant(quant_gptneox quant_model.cpp gptneox gptneox)
compile_quant(quant_dolly quant_model.cpp dolly gptneox)
compile_quant(quant_polyglot quant_model.cpp polyglot gptneox)
compile_quant(quant_llama quant_model.cpp llama llama)
compile_quant(quant_mpt quant_model.cpp mpt mpt)
compile_quant(quant_starcoder quant_model.cpp starcoder starcoder)
@@ -84,6 +85,7 @@ set(mymap_bloom 9)
set(mymap_chatglm2 10)
set(mymap_chatglm 11)
set(mymap_baichuan 12)
set(mymap_polyglot 13)

function(compile_run TARGET SRC MODEL_NAME MODEL_LIB)
add_executable_w_warning(${TARGET} ${SRC})
@@ -106,6 +108,7 @@ compile_run(run_gptj main_run.cpp gptj gptj)
compile_run(run_falcon main_run.cpp falcon falcon)
compile_run(run_gptneox main_run.cpp gptneox gptneox)
compile_run(run_dolly main_run.cpp dolly gptneox)
compile_run(run_polyglot main_run.cpp polyglot gptneox)
compile_run(run_llama main_run.cpp llama llama)
compile_run(run_mpt main_run.cpp mpt mpt)
compile_run(run_starcoder main_run.cpp starcoder starcoder)
@@ -56,10 +56,11 @@ class Model {
if (ctx) model_free(ctx);
}
void init_model(const std::string& model_path, int n_predict, int batch_size, int ctx_size, int seed, int threads,
float repeat_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature);
float repeat_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature,
int min_new_tokens, float length_penalty, bool early_stopping);
void reinit();
std::vector<int> generate(const std::vector<int>& input_ids);
std::vector<int> generate_tokens(const std::vector<int>& input_ids);
std::vector<model_token> generate(const std::vector<model_token>& input_ids);
std::vector<model_token> generate_tokens(const std::vector<model_token>& input_ids);
bool is_token_end() { return token_eos; }
static int quant_model(const std::string& model_path, const std::string& out_path, const std::string& weight_dtype,
const std::string& alg, int group_size, const std::string& scale_dtype,
@@ -68,22 +69,23 @@ class Model {
private:
model_context* ctx = nullptr;
gpt_params params;
std::vector<int> curr_input_ids;
std::vector<model_token> curr_input_ids;
int n_past = 0;
int n_vocab = 0;
int n_ctx = 0;
std::vector<model_token> last_n_tokens;
bool token_eos = false;

int post_process(float* logits);
int post_greedy_search(float* logits);
int post_beam_search(float* logits);
int post_sample_top_k_top_p_repeat(float* logits);
model_token post_process(float* logits);
model_token post_greedy_search(float* logits);
std::vector<model_token> post_beam_search(model_context* lctx, const int& n_predict, const model_token* tokens_inp,
const int& n_tokens, const int& n_threads);
model_token post_sample_top_k_top_p_repeat(float* logits);
};

void Model::init_model(const std::string& model_path, int max_new_tokens, int batch_size, int ctx_size, int seed,
int threads, float repeat_penalty, int num_beams, bool do_sample, int top_k, float top_p,
float temperature) {
float temperature, int min_new_tokens, float length_penalty, bool early_stopping) {
#ifdef MODEL_NAME
params.model_name = MODEL_NAME;
#endif
@@ -97,6 +99,10 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int ba
params.repeat_penalty = repeat_penalty;
params.beam_size = num_beams;
params.do_sample = do_sample;
params.beam_search = (num_beams > 1 && !do_sample) ? true : false;
if (params.beam_search) {
params.memory_type = KV_MEM_TYPE_F16; // TODO NO MHA IN BEAM SEARCH
}
params.top_k = top_k;
params.top_p = top_p;
params.temp = temperature;
@@ -111,6 +117,9 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int ba
n_vocab = model_n_vocab(ctx);
n_ctx = model_n_ctx(ctx);
last_n_tokens.resize(n_ctx, 0);
ctx->generation_conf.min_new_tokens = min_new_tokens;
ctx->generation_conf.length_penalty = length_penalty;
ctx->generation_conf.do_early_stopping = early_stopping;
}

void Model::reinit() {
@@ -123,7 +132,7 @@ void Model::reinit() {
ctx->t_sample_us = 0;
}

std::vector<int> Model::generate(const std::vector<int>& input_ids) {
std::vector<model_token> Model::generate(const std::vector<model_token>& input_ids) {
if (curr_input_ids.empty()) {
curr_input_ids = input_ids;
}
@@ -149,7 +158,7 @@ std::vector<int> Model::generate(const std::vector<int>& input_ids) {
n_past += curr_input_ids.size();

float* logits = model_get_logits(ctx);
int next_token_id = post_process(logits);
model_token next_token_id = post_process(logits);
curr_input_ids = {next_token_id};

if (next_token_id == ctx->vocab.eos_token_id || n_past - input_ids.size() == params.n_predict) {
@@ -159,9 +168,9 @@ std::vector<int> Model::generate(const std::vector<int>& input_ids) {
return {next_token_id};
}

std::vector<int> Model::generate_tokens(const std::vector<int>& input_ids) {
std::vector<model_token> Model::generate_tokens(const std::vector<model_token>& input_ids) {
int n_remain = params.n_predict;
std::vector<int> output_ids;
std::vector<model_token> output_ids;

if (curr_input_ids.empty()) {
curr_input_ids = input_ids;
@@ -186,11 +195,15 @@ std::vector<int> Model::generate_tokens(const std::vector<int>& input_ids) {
curr_input_ids.insert(curr_input_ids.begin(), last_n_tokens.begin() + n_ctx - n_left / 2 - curr_input_ids.size(),
last_n_tokens.end() - curr_input_ids.size());
}
if (ctx->beam_search) {
output_ids = post_beam_search(ctx, n_remain, curr_input_ids.data(), curr_input_ids.size(), params.n_threads);
break;
}
model_eval(ctx, &curr_input_ids[0], curr_input_ids.size(), n_past, params.n_threads);
n_past += curr_input_ids.size();

float* logits = model_get_logits(ctx);
int next_token_id = post_process(logits);
model_token next_token_id = post_process(logits);
curr_input_ids = {next_token_id};
output_ids.push_back(next_token_id);
if (next_token_id == ctx->vocab.eos_token_id || n_past - input_ids.size() == params.n_predict) {
@@ -202,18 +215,25 @@ std::vector<int> Model::generate_tokens(const std::vector<int>& input_ids) {
return output_ids;
}

int Model::post_greedy_search(float* logits) {
int id = std::max_element(logits, logits + n_vocab) - logits;
model_token Model::post_greedy_search(float* logits) {
model_token id = std::max_element(logits, logits + n_vocab) - logits;
return id;
}

int Model::post_beam_search(float* logits) {
std::vector<model_token> Model::post_beam_search(model_context* lctx, const int& n_predict,
const model_token* tokens_inp, const int& n_tokens,
const int& n_threads) {
// TODO: to implement
fprintf(stderr, "\nERROR: beam search is not supported!\n");
return -1;
static std::set<model_archs> supported_archs = {MODEL_GPTJ, MODEL_GPTNEOX};
if (supported_archs.count(params.model_arch) != 0) {
return beam_search(lctx, n_predict, tokens_inp, n_tokens, n_threads);
} else {
fprintf(stderr, "\nERROR: this model does not support beam search generation!\n");
return std::vector<model_token>();
}
}

int Model::post_sample_top_k_top_p_repeat(float* logits) {
model_token Model::post_sample_top_k_top_p_repeat(float* logits) {
int n_logits = n_vocab;
std::random_device rd;
std::mt19937 rng{rd()};
@@ -310,21 +330,14 @@ int Model::post_sample_top_k_top_p_repeat(float* logits) {
return logits_id[idx].second;
}

int Model::post_process(float* logits) {
model_token Model::post_process(float* logits) {
if (params.beam_size == 1) {
if (params.do_sample == false) {
return post_greedy_search(logits);
} else {
return post_sample_top_k_top_p_repeat(logits);
}
} else {
if (params.do_sample == false) {
return post_beam_search(logits);
}
}
fprintf(stderr, "\nERROR: post process (beam_size=%d, do_sample=%d) is not supported!\n", params.beam_size,
params.do_sample);
return -1;
}

int Model::quant_model(const std::string& model_path, const std::string& out_path, const std::string& weight_dtype,
@@ -410,6 +423,10 @@ PYBIND11_MODULE(chatglm_cpp, m)

PYBIND11_MODULE(baichuan_cpp, m)

#elif MODEL_NAME_ID == 13

PYBIND11_MODULE(polyglot_cpp, m)

#endif
{
m.doc() = "cpp model python binding";
@@ -418,7 +435,8 @@ PYBIND11_MODULE(baichuan_cpp, m)
.def("init_model", &Model::init_model, "initial model with model path and parameters", py::arg("model_path"),
py::arg("max_new_tokens") = -1, py::arg("batch_size") = 512, py::arg("ctx_size") = 512, py::arg("seed") = -1,
py::arg("threads") = 8, py::arg("repeat_penalty") = 1.1f, py::arg("num_beams") = 1,
py::arg("do_sample") = false, py::arg("top_k") = 40, py::arg("top_p") = 0.95, py::arg("temperature") = 0.8)
py::arg("do_sample") = false, py::arg("top_k") = 40, py::arg("top_p") = 0.95, py::arg("temperature") = 0.8,
py::arg("min_new_tokens") = 0, py::arg("length_penalty") = 1.0, py::arg("early_stopping") = false)
.def("generate", &Model::generate, "Generate token with input ids", py::arg("input_ids"))
.def("generate_tokens", &Model::generate_tokens, "Generate tokens with input ids", py::arg("input_ids"))
.def_static("quant_model", &Model::quant_model, "Quantize model", py::arg("model_path"), py::arg("out_path"),
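The three new keyword arguments (min_new_tokens, length_penalty, early_stopping) are forwarded to ctx->generation_conf during init_model. Below is a hedged sketch of driving the binding directly: the gptj_cpp module name, Model class construction, model path, and token ids are illustrative assumptions, while the argument names and defaults mirror the .def above.

# Sketch only: module name, .bin path, and token ids are placeholders.
import intel_extension_for_transformers.llm.runtime.graph.gptj_cpp as gptj_cpp  # assumed module name

m = gptj_cpp.Model()  # assumes the binding exposes a default-constructible Model
m.init_model(
    "gptj-q4.bin",        # placeholder path to a quantized model file
    max_new_tokens=64,
    num_beams=4,          # >1 with do_sample=False enables the beam-search path
    do_sample=False,
    min_new_tokens=8,     # new argument in this commit
    length_penalty=1.2,   # new argument in this commit
    early_stopping=True,  # new argument in this commit
)
prompt_ids = [1212, 318, 257, 1332]  # placeholder token ids
out = m.generate_tokens(prompt_ids)  # returns the best hypothesis as a token list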
