Skip to content

Commit

Permalink
Integrate llama.cpp via a logits processor
Browse files Browse the repository at this point in the history
  • Loading branch information
dtiarks authored and rlouf committed Feb 16, 2024
1 parent bc71b23 commit e99d92d
Show file tree
Hide file tree
Showing 18 changed files with 649 additions and 349 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ __pycache__
docs/build
.coverage
.idea/
*.gguf
2 changes: 1 addition & 1 deletion docs/reference/models/llamacpp.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ Assuming [Phi2's weights](https://huggingface.co/TheBloke/phi-2-GGUF) are in the
```python
from outlines import models, generate

model = models.llamacpp("./phi-2.Q4_K_M.gguf", device="cpu")
model = models.llamacpp("./phi-2.Q4_K_M.gguf")
```
4 changes: 2 additions & 2 deletions examples/llamacpp_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class Character(BaseModel):


if __name__ == "__main__":
# Download model from https://huggingface.co/TheBloke/phi-2-GGUF
model = outlines.models.llamacpp("./phi-2.Q3_K_M.gguf", device="cpu")
# curl -L -o mistral-7b-instruct-v0.2.Q5_K_M.gguf https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q5_K_M.gguf
model = outlines.models.llamacpp("./mistral-7b-instruct-v0.2.Q5_K_M.gguf")

# Construct structured sequence generator
generator = outlines.generate.json(model, Character)
Expand Down
50 changes: 50 additions & 0 deletions examples/llamacpp_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from enum import Enum

from llama_cpp import Llama, LogitsProcessorList
from pydantic import BaseModel, constr

from outlines.generate.processors import JSONLogitsProcessor
from outlines.models.llamacpp import LlamaCppTokenizer


class Weapon(str, Enum):
sword = "sword"
axe = "axe"
mace = "mace"
spear = "spear"
bow = "bow"
crossbow = "crossbow"


class Armor(str, Enum):
leather = "leather"
chainmail = "chainmail"
plate = "plate"


class Character(BaseModel):
name: constr(max_length=10)
age: int
armor: Armor
weapon: Weapon
strength: int


if __name__ == "__main__":
llama = Llama("./phi-2.Q4_K_M.gguf")
tokenizer = LlamaCppTokenizer(llama)

prompt = "Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:"

logits_processor = JSONLogitsProcessor(Character, tokenizer)

json_str = llama.create_completion(
prompt,
top_k=40,
top_p=0.95,
temperature=0.7,
max_tokens=100,
logits_processor=LogitsProcessorList([logits_processor]),
)["choices"][0]["text"]

print(json_str)
4 changes: 2 additions & 2 deletions outlines/fsm/fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def copy(self) -> "StopAtEosFSM":
class RegexFSM(FSM):
"""FSM to generate text that is in the language of a regular expression."""

def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
def __init__(self, regex_string: str, tokenizer):
@cache()
def create_states_mapping(
regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int]]
Expand Down Expand Up @@ -190,7 +190,7 @@ def copy(self) -> "RegexFSM":
class CFGFSM(FSM):
"""FSM to generate text that is in the language of a context-free grammar."""

def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
def __init__(self, cfg_string: str, tokenizer):
self.cfg_string = cfg_string
self.tokenizer = tokenizer

Expand Down
16 changes: 4 additions & 12 deletions outlines/fsm/json_schema.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import inspect
import json
import re
from typing import Callable, Optional, Union
from typing import Callable, Optional

from jsonschema.protocols import Validator
from pydantic import BaseModel, create_model
from pydantic import create_model
from referencing import Registry, Resource
from referencing._core import Resolver
from referencing.jsonschema import DRAFT202012
Expand Down Expand Up @@ -38,9 +38,7 @@
}


def build_regex_from_object(
object: Union[str, Callable, BaseModel], whitespace_pattern: Optional[str] = None
):
def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None):

This comment has been minimized.

Copy link
@joennlae

joennlae Mar 19, 2024

rename of this function :-)

"""Turn a JSON schema into a regex that matches any JSON object that follows
this schema.
Expand Down Expand Up @@ -72,13 +70,7 @@ def build_regex_from_object(
"""

if isinstance(object, type(BaseModel)):
schema = object.model_json_schema()
elif callable(object):
schema = get_schema_from_signature(object)
else:
schema = json.loads(object)

schema = json.loads(schema)
Validator.check_schema(schema)

# Build reference resolver
Expand Down
19 changes: 19 additions & 0 deletions outlines/generate/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from outlines.fsm.fsm import CFGFSM
from outlines.generate.api import SequenceGenerator
from outlines.models import OpenAI
from outlines.models.llamacpp import CFGLogitsProcessor, LlamaCpp
from outlines.samplers import Sampler, multinomial


Expand Down Expand Up @@ -31,6 +32,24 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera
return generator


@cfg.register(LlamaCpp)
def cfg_llamacpp(
model: LlamaCpp,
cfg_str: str,
sampler: Sampler = multinomial(),
):
if not isinstance(sampler, multinomial):
raise NotImplementedError(
r"The llama.cpp integration does not currently support any other sampling algorithm "
+ "than the multinomial sampler."
)

logits_processor = CFGLogitsProcessor(cfg_str, model.tokenizer)
model.logits_processor = logits_processor

return model


@cfg.register(OpenAI)
def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()):
raise NotImplementedError(
Expand Down
6 changes: 5 additions & 1 deletion outlines/generate/choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ def choice(
model, choices: List[str], sampler: Sampler = multinomial()
) -> SequenceGenerator:
regex_str = r"(" + r"|".join(choices) + r")"
return regex(model, regex_str, sampler)

generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: x

return generator


@choice.register(OpenAI)
Expand Down
8 changes: 4 additions & 4 deletions outlines/generate/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel

from outlines.fsm.json_schema import build_regex_from_object, get_schema_from_signature
from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
from outlines.generate.api import SequenceGenerator
from outlines.models import OpenAI
from outlines.samplers import Sampler, multinomial
Expand Down Expand Up @@ -45,17 +45,17 @@ def json(
"""
if isinstance(schema_object, type(BaseModel)):
schema = pyjson.dumps(schema_object.model_json_schema())
regex_str = build_regex_from_object(schema, whitespace_pattern)
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: schema_object.parse_raw(x)
elif callable(schema_object):
schema = pyjson.dumps(get_schema_from_signature(schema_object))
regex_str = build_regex_from_object(schema, whitespace_pattern)
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: pyjson.loads(x)
elif isinstance(schema_object, str):
schema = schema_object
regex_str = build_regex_from_object(schema, whitespace_pattern)
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: pyjson.loads(x)
else:
Expand Down
25 changes: 24 additions & 1 deletion outlines/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from outlines.fsm.fsm import RegexFSM
from outlines.generate.api import SequenceGenerator
from outlines.models import OpenAI
from outlines.models.llamacpp import LlamaCpp, RegexLogitsProcessor
from outlines.samplers import Sampler, multinomial


Expand Down Expand Up @@ -35,8 +36,30 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
return generator


@regex.register(LlamaCpp)
def regex_llamacpp(
model: LlamaCpp,
regex_str: str,
sampler: Sampler = multinomial(),
):
if not isinstance(sampler, multinomial):
raise NotImplementedError(
r"The llama.cpp integration does not currently support any other sampling algorithm "
+ "than the multinomial sampler."
)

logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer)
model.logits_processor = logits_processor

return model


@regex.register(OpenAI)
def regex_openai(model, regex_str: str, sampler: Sampler = multinomial()):
def regex_openai(
model: OpenAI,
regex_str: str,
sampler: Sampler = multinomial(),
):
raise NotImplementedError(
"Cannot use regex-structured generation with an OpenAI model"
+ "due to the limitations of the OpenAI API."
Expand Down
15 changes: 13 additions & 2 deletions outlines/generate/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from outlines.fsm.fsm import StopAtEosFSM
from outlines.generate import SequenceGenerator
from outlines.models import OpenAI
from outlines.models import LlamaCpp, OpenAI
from outlines.samplers import Sampler, multinomial


Expand Down Expand Up @@ -36,12 +36,23 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
return generator


@text.register(LlamaCpp)
def text_llamacpp(model: LlamaCpp, sampler: Sampler = multinomial()):
if not isinstance(sampler, multinomial):
raise NotImplementedError(
r"The llama.cpp API does not support any other sampling algorithm "
+ "than the multinomial sampler."
)

return model


@text.register(OpenAI)
def text_openai(model: OpenAI, sampler: Sampler = multinomial()) -> OpenAI:
if not isinstance(sampler, multinomial):
raise NotImplementedError(
r"The OpenAI API does not support any other sampling algorithm "
+ "that the multinomial sampler."
+ "than the multinomial sampler."
)

return model
Loading

1 comment on commit e99d92d

@remixer-dec
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this broke sglang

Please sign in to comment.