From 8dcd24eeb4dbfaa102e4c812920c7b1dba2a4362 Mon Sep 17 00:00:00 2001
From: Andrew Lapp
Date: Fri, 21 Jun 2024 03:43:39 -0500
Subject: [PATCH] fix models.llamacpp vocabulary normalization function

---
 outlines/models/llamacpp.py                 | 12 ++++++++++-
 tests/generate/test_integration_llamacpp.py | 22 +++++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py
index 840e1364f..a982b080c 100644
--- a/outlines/models/llamacpp.py
+++ b/outlines/models/llamacpp.py
@@ -34,8 +34,10 @@ def __init__(self, model: "Llama"):
         self.tokenizer = model.tokenizer()
 
         # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
+        self._hf_tokenizer = None
         try:
             self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
+            self._hf_tokenizer = model.tokenizer_.hf_tokenizer
         except AttributeError:
             # ###
             for t in range(model.n_vocab()):
@@ -71,7 +73,15 @@ def encode(
         return token_ids, attention_mask
 
     def convert_token_to_string(self, token: str) -> str:
-        return token
+        if self._hf_tokenizer is not None:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            token_str = self._hf_tokenizer.convert_tokens_to_string([token])
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                token_str = " " + token_str
+            return token_str
+        else:
+            return token
 
     def __eq__(self, other):
         if not isinstance(other, LlamaCppTokenizer):
diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py
index b7eb8b3cb..fcd2bfda9 100644
--- a/tests/generate/test_integration_llamacpp.py
+++ b/tests/generate/test_integration_llamacpp.py
@@ -334,3 +334,25 @@ def test_RegexGuide_caching(model, temp_cache_dir):
     assert re.fullmatch(regex, structured)
     assert re.fullmatch(regex, structured_2)
     assert structured != structured_2
+
+
+def test_tokenizer_vocabulary_decode_sanity():
+    """Assert the decoded newline token (198) is the same as the normalized vocab token"""
+    import llama_cpp
+
+    model = models.llamacpp(
+        "bartowski/Meta-Llama-3-8B-Instruct-GGUF",
+        "Meta-Llama-3-8B-Instruct-IQ1_M.gguf",
+        tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
+            "NousResearch/Hermes-2-Pro-Llama-3-8B"
+        ),
+    )
+    tokenizer = generate.regex(model, "a").logits_processor.tokenizer
+
+    decoded_nl_token = tokenizer.decode([198])[0]
+    vocab_nl_token = tokenizer.convert_token_to_string(
+        [token for token, token_id in tokenizer.vocabulary.items() if token_id == 198][
+            0
+        ]
+    )
+    assert decoded_nl_token == vocab_nl_token
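
Note (not part of the patch): a minimal sketch of the normalization the new
convert_token_to_string performs, runnable with only `transformers` installed
and no GGUF download. It uses the same HF checkpoint as the test above and
token id 198, which the test's docstring identifies as the newline token; the
exact "Ċ" placeholder shown in the comments assumes a GPT-2-style byte-level
BPE vocabulary, which is what Llama 3 models use.

    from transformers import AutoTokenizer

    hf_tokenizer = AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Llama-3-8B")

    # The raw vocab entry for token id 198 is a byte-level placeholder string
    # ("Ċ", i.e. 0x0A shifted into printable range), not an actual newline.
    raw_token = [t for t, i in hf_tokenizer.get_vocab().items() if i == 198][0]

    # Routing through the HF tokenizer decodes the placeholder back to "\n",
    # which is what the patched convert_token_to_string does whenever an HF
    # tokenizer is attached to the llama.cpp model.
    print(repr(raw_token))                                           # 'Ċ'
    print(repr(hf_tokenizer.convert_tokens_to_string([raw_token])))  # '\n'

Without this decoding step, convert_token_to_string returned the raw token
unchanged, so the normalized vocabulary seen through the logits processor's
tokenizer could disagree with what tokenizer.decode() produces; that is the
mismatch the new test guards against. The SPIECE_UNDERLINE branch appears to
cover the analogous SentencePiece case, where convert_tokens_to_string drops
the leading-space marker "▁" and the space must be re-attached by hand.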