Skip to content

Commit

Permalink
llama : add option to render special/control tokens (#6807)
Browse files Browse the repository at this point in the history
* make : fix common dep on llama.h

* llama : add option to render special tokens

* readme : add API change notice

ggml-ci

* swift : fix build
  • Loading branch information
ggerganov committed Apr 21, 2024
1 parent b9cc76d commit 40f74e4
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o
llama.o: llama.cpp unicode.h ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
$(CXX) $(CXXFLAGS) -c $< -o $@

COMMON_H_DEPS = common/common.h common/sampling.h common/log.h
COMMON_H_DEPS = common/common.h common/sampling.h common/log.h llama.h
COMMON_DEPS = common.o sampling.o grammar-parser.o build-info.o json-schema-to-grammar.o

common.o: common/common.cpp $(COMMON_H_DEPS)
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)

### Recent API changes

- [2024 Apr 21] `llama_token_to_piece` can now optionally render special tokens https://github.com/ggerganov/llama.cpp/pull/6807
- [2024 Apr 4] State and session file functions reorganized under `llama_state_*` https://github.com/ggerganov/llama.cpp/pull/6341
- [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122
- [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017
Expand Down
4 changes: 2 additions & 2 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2328,10 +2328,10 @@ std::vector<llama_token> llama_tokenize(

std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), true);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), true);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
Expand Down
5 changes: 3 additions & 2 deletions examples/batched.swift/Sources/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,16 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {

private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
var result = [CChar](repeating: 0, count: 8)
let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count))
let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count), false)
if nTokens < 0 {
let actualTokensCount = -Int(nTokens)
result = .init(repeating: 0, count: actualTokensCount)
let check = llama_token_to_piece(
model,
token,
&result,
Int32(result.count)
Int32(result.count),
false
)
assert(check == actualTokensCount)
} else {
Expand Down
4 changes: 2 additions & 2 deletions examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
Original file line number Diff line number Diff line change
Expand Up @@ -322,15 +322,15 @@ actor LlamaContext {
defer {
result.deallocate()
}
let nTokens = llama_token_to_piece(model, token, result, 8)
let nTokens = llama_token_to_piece(model, token, result, 8, false)

if nTokens < 0 {
let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens))
newResult.initialize(repeating: Int8(0), count: Int(-nTokens))
defer {
newResult.deallocate()
}
let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens)
let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, false)
let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
return Array(bufferPointer)
} else {
Expand Down
25 changes: 13 additions & 12 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1600,12 +1600,12 @@ struct llama_mlock {
};
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;

static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
GGML_ASSERT(check == -n_tokens);
}
else {
Expand Down Expand Up @@ -13312,7 +13312,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c

for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string piece = llama_token_to_piece(ctx, id);
const std::string piece = llama_token_to_piece(ctx, id, false);

if (llama_token_is_eog(&ctx->model, id)) {
if (!allow_eog) {
candidates->data[i].logit = -INFINITY;
Expand Down Expand Up @@ -13512,7 +13513,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false);
}

const std::string piece = llama_token_to_piece(ctx, token);
const std::string piece = llama_token_to_piece(ctx, token, false);

// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
Expand Down Expand Up @@ -16991,7 +16992,7 @@ static std::string llama_decode_text(const std::string & text) {
}

// does not write null-terminator to buf
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length) {
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
if (0 <= token && token < llama_n_vocab(model)) {
switch (llama_vocab_get_type(model->vocab)) {
case LLAMA_VOCAB_TYPE_WPM:
Expand All @@ -17006,7 +17007,9 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_user_defined_token(model->vocab, token)) {
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
Expand All @@ -17019,8 +17022,6 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
}
memcpy(buf, "\xe2\x96\x85", 3);
return 3;
} else if (llama_is_control_token(model->vocab, token)) {
;
} else if (llama_is_byte_token(model->vocab, token)) {
if (length < 1) {
return -1;
Expand All @@ -17041,15 +17042,15 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_user_defined_token(model->vocab, token)) {
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_control_token(model->vocab, token)) {
;
}
break;
}
Expand Down
4 changes: 3 additions & 1 deletion llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -828,11 +828,13 @@ extern "C" {
// Uses the vocabulary in the provided context.
// Does not write null terminator to the buffer.
// User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
// @param special If true, special tokens are rendered in the output.
LLAMA_API int32_t llama_token_to_piece(
const struct llama_model * model,
llama_token token,
char * buf,
int32_t length);
int32_t length,
bool special);

/// Apply chat template. Inspired by hf apply_chat_template() on python.
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
Expand Down

0 comments on commit 40f74e4

Please sign in to comment.