Skip to content

Commit

Permalink
make LlamaCppTokenizer an outlines Tokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed May 30, 2024
1 parent cb16b16 commit 280611f
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 64 deletions.
39 changes: 2 additions & 37 deletions outlines/integrations/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"""

import math
from typing import TYPE_CHECKING, Dict, Optional, Set, Type, Union
from typing import TYPE_CHECKING, Optional, Type, Union

import numpy as np
import torch
Expand All @@ -36,47 +36,12 @@
from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.integrations.utils import convert_json_schema_to_str
from outlines.models.llamacpp import LlamaCppTokenizer

if TYPE_CHECKING:
from llama_cpp import Llama


class LlamaCppTokenizer:
def __init__(self, model: "Llama"):
self.eos_token_id = model.token_eos()
self.eos_token = model.tokenizer().decode([self.eos_token_id])
self.pad_token_id = self.eos_token_id
self.special_tokens: Set[int] = set()

self.vocabulary: Dict[str, int] = dict()

tokenizer = model.tokenizer()

self.decode = tokenizer.decode

# TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
try:
self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
except AttributeError:
# ###
for t in range(model.n_vocab()):
token_piece = model.tokenizer().decode([t])
self.vocabulary[token_piece] = t

def convert_token_to_string(self, token: str) -> str:
return token

def __getstate__(self):
"""Allow tokenizer to be used as hash key by excluding self.decode"""
return (
self.vocabulary.items(),
self.eos_token_id,
self.eos_token,
self.pad_token_id,
sorted(self.special_tokens),
)


class LogitsProcessor:
"""Bias LlamaCpp generation using a finite state machine.
Expand Down
84 changes: 83 additions & 1 deletion outlines/models/llamacpp.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,97 @@
import dataclasses
import pickle
import warnings
from typing import TYPE_CHECKING, Iterator, List, Optional, TypedDict, Union
from typing import (
TYPE_CHECKING,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
TypedDict,
Union,
)

from typing_extensions import Unpack

from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
from llama_cpp import Llama, LogitsProcessorList


class LlamaCppTokenizer(Tokenizer):
def __init__(self, model: "Llama"):
self.eos_token_id = model.token_eos()
self.eos_token = model.tokenizer().decode([self.eos_token_id])
self.pad_token_id = self.eos_token_id
self.special_tokens: Set[int] = set()

self.vocabulary: Dict[str, int] = dict()

self.tokenizer = model.tokenizer()

# TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
try:
self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
except AttributeError:
# ###
for t in range(model.n_vocab()):
token_piece = model.tokenizer().decode([t])
self.vocabulary[token_piece] = t

self._hash = None

def decode(self, token_ids: List[int]) -> List[str]:
decoded_bytes = self.tokenizer.detokenize(token_ids)
return [decoded_bytes.decode("utf-8", errors="ignore")]

def encode(
self, prompt: Union[str, List[str]], add_bos: bool = True, special: bool = True
) -> Tuple[List[int], List[int]]:
if isinstance(prompt, list):
raise NotImplementedError(
"llama-cpp-python tokenizer doesn't support batch tokenization"
)
token_ids = self.tokenizer.tokenize(
prompt.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
)
# generate attention mask, missing from llama-cpp-python
attention_mask = [
1 if token_id != self.pad_token_id else 0 for token_id in token_ids
]
return token_ids, attention_mask

def convert_token_to_string(self, token: str) -> str:
return token

def __eq__(self, other):
if not isinstance(other, LlamaCppTokenizer):
return False
return self.__getstate__() == other.__getstate__()

def __hash__(self):
# cache object hash
if self._hash is None:
self._hash = hash(pickle.dumps(self))
return self._hash

def __getstate__(self):
"""Create a stable representation for outlines.caching"""
return (
self.vocabulary,
self.eos_token_id,
self.eos_token,
self.pad_token_id,
self.special_tokens,
)

def __setstate__(self, state):
raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")


class LlamaCppParams(TypedDict, total=False):
suffix: Optional[str]
temperature: float
Expand Down
24 changes: 24 additions & 0 deletions tests/generate/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from importlib import reload

import pytest


@pytest.fixture
def temp_cache_dir():
import os
import tempfile

import outlines.caching
import outlines.fsm.guide

with tempfile.TemporaryDirectory() as tempdir:
os.environ["OUTLINES_CACHE_DIR"] = tempdir
outlines.caching.get_cache.cache_clear()
reload(outlines)
reload(outlines.fsm.guide)
cache_status = outlines.caching._caching_enabled
try:
outlines.caching._caching_enabled = True
yield
finally:
outlines.caching._caching_enabled = cache_status
55 changes: 51 additions & 4 deletions tests/generate/test_integration_llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,56 @@ def test_llama_cpp_pre_tokenizer_remains_broken():
generate.choice(model, ["skirt", "dress", "pen", "jacket"])


def test_create_states_mapping_llamacpp_tokenizer_regression(model):
"""Minimal reproducer for #922, error passing llamacpp tokenizer to create_states_mapping"""
def test_RegexGuide_caching(model, temp_cache_dir):
import llama_cpp

import outlines.caching
from outlines.fsm.guide import create_states_mapping
from outlines.integrations.llamacpp import LlamaCppTokenizer

create_states_mapping("a", LlamaCppTokenizer(model.model))
assert outlines.caching._caching_enabled

regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
prompt = "What is the IP address of the Google DNS servers? "

cache = outlines.caching.get_cache()

# Returns (hits, misses)
_ = cache.stats(enable=True)
assert cache.statistics

assert create_states_mapping.__memory__ is cache

generator = generate.regex(model, regex, sampler=samplers.greedy())
assert cache.stats() == (0, 1)

model_2 = models.llamacpp(
"Qwen/Qwen1.5-0.5B-Chat-GGUF",
"*q2*.gguf",
tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
"Qwen/Qwen1.5-0.5B-Chat"
),
)
generator_2 = generate.regex(model_2, regex, sampler=samplers.greedy())
assert cache.stats() == (0, 2)

# These two different models and tokenizers should not have the same state
# mapping results
assert (
generator.logits_processor.fsm.states_to_token_maps
!= generator_2.logits_processor.fsm.states_to_token_maps
)

generator_3 = generate.regex(model_2, regex, sampler=samplers.greedy())
assert cache.stats() == (1, 2)
assert (
generator_2.logits_processor.fsm.states_to_token_maps
== generator_3.logits_processor.fsm.states_to_token_maps
)

# Just for fun...
structured = generator(prompt, max_tokens=30)
structured_2 = generator_2(prompt, max_tokens=30)

assert re.fullmatch(regex, structured)
assert re.fullmatch(regex, structured_2)
assert structured != structured_2
22 changes: 0 additions & 22 deletions tests/generate/test_integration_transformers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import datetime
import re
from enum import Enum
from importlib import reload
from typing import List, Union

import pytest
Expand All @@ -15,27 +14,6 @@
from outlines.samplers import beam_search, greedy, multinomial


@pytest.fixture
def temp_cache_dir():
import os
import tempfile

import outlines.caching
import outlines.fsm.guide

with tempfile.TemporaryDirectory() as tempdir:
os.environ["OUTLINES_CACHE_DIR"] = tempdir
outlines.caching.get_cache.cache_clear()
reload(outlines)
reload(outlines.fsm.guide)
cache_status = outlines.caching._caching_enabled
try:
outlines.caching._caching_enabled = True
yield
finally:
outlines.caching._caching_enabled = cache_status


def test_transformers_integration_text():
rng = torch.Generator()
rng.manual_seed(10000) # Choosen so <EOS> is generated
Expand Down

0 comments on commit 280611f

Please sign in to comment.