Skip to content

Commit

Permalink
fix post process with topk topp of python api (#476)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenwei-intel committed Oct 17, 2023
1 parent 5feac76 commit 7b4730d
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def quant_model(self, model_name, model_path, out_path, **kwargs):
self.module.Model.quant_model(model_path = model_path,
out_path = out_path, **kwargs)

def generate(self, input_ids, streamer=None, interactive=False, **kwargs):
def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, **kwargs):
if self.model is None:
self.init_from_bin(self.model_type, self.bin_file, **kwargs)
self.generate_round = 0
Expand All @@ -104,7 +104,7 @@ def generate(self, input_ids, streamer=None, interactive=False, **kwargs):
self.generate_round = 0

ret = [[]]
if self.generate_round == 0:
if self.generate_round == 0 and not ignore_prompt:
ret = input_ids.tolist()

# TODO support multi batch
Expand All @@ -118,12 +118,13 @@ def generate(self, input_ids, streamer=None, interactive=False, **kwargs):
print("ERROR, can not use streamer when use beam search for generation!")
import sys
sys.exit(1)
if self.generate_round == 0:
if self.generate_round == 0 and not ignore_prompt:
streamer.put(input_ids)
while not self.is_token_end():
out = self.model.generate(input_ids = input_ids.tolist()[0])
streamer.put(torch.tensor([out]))
ret[0].extend(out)
streamer.end()
else:
ret[0].extend(self.model.generate_tokens(input_ids = input_ids.tolist()[0]))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class Model {
if (ctx) model_free(ctx);
}
void init_model(const std::string& model_path, int n_predict, int batch_size, int ctx_size, int seed, int threads,
float repeat_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature,
float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature,
int min_new_tokens, float length_penalty, bool early_stopping);
void reinit();
std::vector<model_token> generate(const std::vector<model_token>& input_ids);
Expand Down Expand Up @@ -84,7 +84,7 @@ class Model {
};

void Model::init_model(const std::string& model_path, int max_new_tokens, int batch_size, int ctx_size, int seed,
int threads, float repeat_penalty, int num_beams, bool do_sample, int top_k, float top_p,
int threads, float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p,
float temperature, int min_new_tokens, float length_penalty, bool early_stopping) {
#ifdef MODEL_NAME
params.model_name = MODEL_NAME;
Expand All @@ -96,7 +96,7 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int ba
params.n_ctx = ctx_size;
params.seed = seed;
params.n_threads = threads;
params.repeat_penalty = repeat_penalty;
params.repeat_penalty = repetition_penalty;
params.beam_size = num_beams;
params.do_sample = do_sample;
params.beam_search = (num_beams > 1 && !do_sample) ? true : false;
Expand Down Expand Up @@ -161,7 +161,7 @@ std::vector<model_token> Model::generate(const std::vector<model_token>& input_i
model_token next_token_id = post_process(logits);
curr_input_ids = {next_token_id};

if (next_token_id == ctx->vocab.eos_token_id || n_past - input_ids.size() == params.n_predict) {
if (next_token_id == ctx->vocab.eos_token_id || n_past - input_ids.size() >= params.n_predict) {
token_eos = true;
}

Expand Down Expand Up @@ -206,7 +206,7 @@ std::vector<model_token> Model::generate_tokens(const std::vector<model_token>&
model_token next_token_id = post_process(logits);
curr_input_ids = {next_token_id};
output_ids.push_back(next_token_id);
if (next_token_id == ctx->vocab.eos_token_id || n_past - input_ids.size() == params.n_predict) {
if (next_token_id == ctx->vocab.eos_token_id || n_past - input_ids.size() >= params.n_predict) {
token_eos = true;
break;
}
Expand Down Expand Up @@ -234,100 +234,38 @@ std::vector<model_token> Model::post_beam_search(model_context* lctx, const int&
}

model_token Model::post_sample_top_k_top_p_repeat(float* logits) {
int n_logits = n_vocab;
std::random_device rd;
std::mt19937 rng{rd()};

const auto* plogits = logits;
int alpha_frequency = 0;
int alpha_presence = 0;
int repeat_last_n = 64;
float repeat_penalty = 1.02;
if (params.temp <= 0) {
// select the token with the highest logit directly
float max_logit = plogits[0];
gpt_vocab::id max_id = 0;

for (int i = 1; i < n_logits; ++i) {
if (plogits[i] > max_logit) {
max_logit = plogits[i];
max_id = i;
}
}
return max_id;
}

std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);

{
const float scale = 1.0f / params.temp;
for (int i = 0; i < n_logits; ++i) {
// repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
if (repeat_last_n > 0 &&
std::find(last_n_tokens.end() - repeat_last_n, last_n_tokens.end(), i) != last_n_tokens.end()) {
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if (plogits[i] < 0.0f) {
logits_id.push_back(std::make_pair(plogits[i] * scale * repeat_penalty, i));
} else {
logits_id.push_back(std::make_pair(plogits[i] * scale / repeat_penalty, i));
}
} else {
logits_id.push_back(std::make_pair(plogits[i] * scale, i));
}
}
}

// find the top K tokens
std::partial_sort(logits_id.begin(), logits_id.begin() + params.top_k, logits_id.end(),
[](const std::pair<double, gpt_vocab::id>& a, const std::pair<double, gpt_vocab::id>& b) {
return a.first > b.first;
});

logits_id.resize(params.top_k);

double maxl = -INFINITY;
for (const auto& kv : logits_id) {
maxl = std::max(maxl, kv.first);
}

// compute probs for the top K tokens
std::vector<double> probs;
probs.reserve(logits_id.size());

double sum = 0.0;
for (const auto& kv : logits_id) {
double p = exp(kv.first - maxl);
probs.push_back(p);
sum += p;
}

// normalize the probs
for (auto& p : probs) {
p /= sum;
int top_k = params.top_k;
float tfs_z = 1.00f;
float typical_p = 1.00f;
float top_p = params.top_p;
float temp = params.temp;
std::vector<model_token_data> candidates;
candidates.reserve(n_vocab);
for (model_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(model_token_data{token_id, logits[token_id], 0.0f});
}

if (params.top_p < 1.0f) {
double cumsum = 0.0f;
for (int i = 0; i < params.top_k; i++) {
cumsum += probs[i];
if (cumsum >= params.top_p) {
params.top_k = i + 1;
probs.resize(params.top_k);
logits_id.resize(params.top_k);
break;
}
}

cumsum = 1.0 / cumsum;
for (int i = 0; i < (int)probs.size(); i++) {
probs[i] *= cumsum;
}
}

std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);

return logits_id[idx].second;
model_token_data_array candidates_p = {candidates.data(), candidates.size(), false};

// Apply penalties
float nl_logit = logits[model_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
model_sample_repetition_penalty(ctx, &candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, params.repeat_penalty);
model_sample_frequency_and_presence_penalties(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
// int id = model_sample_token_greedy(ctx, &candidates_p);
// Temperature sampling
model_sample_top_k(ctx, &candidates_p, top_k, 1);
model_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
model_sample_typical(ctx, &candidates_p, typical_p, 1);
model_sample_top_p(ctx, &candidates_p, top_p, 1);
model_sample_temperature(ctx, &candidates_p, temp);
int id = model_sample_token(ctx, &candidates_p);
return id;
}

model_token Model::post_process(float* logits) {
Expand Down Expand Up @@ -434,7 +372,7 @@ PYBIND11_MODULE(polyglot_cpp, m)
.def(py::init())
.def("init_model", &Model::init_model, "initial model with model path and parameters", py::arg("model_path"),
py::arg("max_new_tokens") = -1, py::arg("batch_size") = 512, py::arg("ctx_size") = 512, py::arg("seed") = -1,
py::arg("threads") = 8, py::arg("repeat_penalty") = 1.1f, py::arg("num_beams") = 1,
py::arg("threads") = 8, py::arg("repetition_penalty") = 1.1f, py::arg("num_beams") = 1,
py::arg("do_sample") = false, py::arg("top_k") = 40, py::arg("top_p") = 0.95, py::arg("temperature") = 0.8,
py::arg("min_new_tokens") = 0, py::arg("length_penalty") = 1.0, py::arg("early_stopping") = false)
.def("generate", &Model::generate, "Generate token with input ids", py::arg("input_ids"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def main(args_in: Optional[List[str]] = None) -> None:
hparams = model.config.to_dict()

print("Model loaded: ", dir_model)
os.makedirs(os.path.dirname(fname_out), exist_ok=True)
fout = open(fname_out, "wb")

print(hparams)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def check_submodules():
check_submodules()
ext_modules.extend([
CMakeExtension("intel_extension_for_transformers.neural_engine_py", "intel_extension_for_transformers/llm/runtime/deprecated/"),
CMakeExtension("intel_extension_for_transformers.llm.runtime.graph.Model", "intel_extension_for_transformers/llm/runtime/graph/"),
CMakeExtension("intel_extension_for_transformers.llm.runtime.graph.mpt_cpp", "intel_extension_for_transformers/llm/runtime/graph/"),
])
cmdclass={'build_ext': CMakeBuild}

Expand Down

0 comments on commit 7b4730d

Please sign in to comment.