Add llama.cpp backend (#231)
baptistecolle committed Jul 30, 2024
1 parent c45aecd commit 0aac010
Showing 21 changed files with 356 additions and 15 deletions.
48 changes: 48 additions & 0 deletions .github/workflows/test_cli_llama_cpp.yaml
@@ -0,0 +1,48 @@
name: CLI Llama.cpp Tests

on:
  workflow_dispatch:
  push:
    branches:
      - main
    paths:
      - .github/workflows/test_cli_llama_cpp.yaml
      - "optimum_benchmark/**"
      - "docker/**"
      - "tests/**"
      - "setup.py"
  pull_request:
    branches:
      - main
    paths:
      - .github/workflows/test_cli_llama_cpp.yaml
      - "optimum_benchmark/**"
      - "docker/**"
      - "tests/**"
      - "setup.py"

concurrency:
  cancel-in-progress: true
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}

jobs:
  run_cli_llama_cpp_tests:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Set up Python 3.10
        uses: actions/setup-python@v3
        with:
          python-version: "3.10"

      - name: Install requirements
        run: |
          pip install --upgrade pip
          pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
          pip install -e .[testing,llama-cpp]

      - name: Run tests
        run: pytest -s -k "llama_cpp"
4 changes: 4 additions & 0 deletions .gitignore
@@ -172,3 +172,7 @@ work-in-progress/
experiments/
amdsmi/
amd-*

# Mac specific
.DS_Store
outputs/
8 changes: 4 additions & 4 deletions CONTRIBUTING.md
@@ -48,16 +48,16 @@ If you would like to work on any of the open Issues:
6. Depending on the feature you're working on and your development environment, you can run tests locally in an isolated docker container using the [makefile](Makefile). For example, to test the CLI with CPU device and PyTorch backend, you can run the following commands:

```bash
make install_cli_cpu_pytorch_extras
make install_cli_cpu_pytorch
make test_cli_cpu_pytorch
```

For a better development experience, we recommend using isolated docker containers to run tests:

```bash
make build_docker_cpu
make run_docker_cpu
make install_cli_cpu_pytorch_extras
make build_cpu_image
make run_cpu_container
make install_cli_cpu_pytorch
make test_cli_cpu_pytorch
```

2 changes: 2 additions & 0 deletions Makefile
@@ -173,6 +173,8 @@ test_cli_rocm_pytorch_single_gpu:
	pytest -s -k "cli and rocm and pytorch and not (dp or ddp or device_map or deepspeed) and not (bnb or awq)"

# llm-perf
test_cli_llama_cpp:
	pytest -s -k "llama_cpp"

install_llm_perf_cuda_pytorch:
	pip install packaging && pip install flash-attn einops scipy auto-gptq optimum bitsandbytes autoawq codecarbon
26 changes: 26 additions & 0 deletions examples/llama_cpp_embedding.yaml
@@ -0,0 +1,26 @@
defaults:
  - benchmark
  - scenario: inference
  - launcher: inline
  - backend: llama_cpp
  - _base_
  - _self_

name: llama_cpp_llama

backend:
  device: mps
  model: nomic-ai/nomic-embed-text-v1.5-GGUF
  task: feature-extraction
  filename: nomic-embed-text-v1.5.Q4_0.gguf

scenario:
  input_shapes:
    batch_size: 1
    sequence_length: 256
    vocab_size: 30000
    type_vocab_size: 1
    max_position_embeddings: 512
  generate_kwargs:
    max_new_tokens: 100
    min_new_tokens: 100
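As a rough standalone illustration (not part of this diff), the feature-extraction path configured above boils down to llama-cpp-python's embedding API, which the backend added later in this commit wraps. A minimal sketch reusing the repo and filename from this config (the prompt string is made up) could be:

```python
# Minimal sketch of the llama_cpp embedding call behind the feature-extraction
# task above; repo_id/filename come from this config, the prompt is arbitrary.
from llama_cpp import Llama

llm = Llama.from_pretrained(
    repo_id="nomic-ai/nomic-embed-text-v1.5-GGUF",
    filename="nomic-embed-text-v1.5.Q4_0.gguf",
    embedding=True,  # required for Llama.embed()
    verbose=False,
)

vectors = llm.embed(["benchmarking embeddings with llama.cpp"])
print(len(vectors), len(vectors[0]))  # number of prompts, embedding dimension
```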
25 changes: 25 additions & 0 deletions examples/llama_cpp_text_generation.yaml
@@ -0,0 +1,25 @@
defaults:
  - benchmark
  - scenario: inference
  - launcher: inline
  - backend: llama_cpp
  - _base_
  - _self_

name: llama_cpp_llama

backend:
  device: mps
  model: TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF
  task: text-generation
  filename: tinyllama-1.1b-chat-v1.0.Q4_0.gguf

scenario:
  input_shapes:
    batch_size: 1
    sequence_length: 256
    vocab_size: 32000
  generate_kwargs:
    max_new_tokens: 100
    min_new_tokens: 100
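These example configs are meant for the Hydra CLI, but the same benchmark can be assembled programmatically. A sketch using the Python API (assuming `Benchmark`, `BenchmarkConfig`, `InferenceConfig`, `InlineConfig` and the new `LlamaCppConfig` are all exported at the package root, as the `__init__.py` and `cli.py` changes below suggest) might look like:

```python
# Sketch of launching the text-generation benchmark above through the Python
# API instead of the Hydra CLI; the top-level exports are assumed from this diff.
from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, InlineConfig, LlamaCppConfig

backend_config = LlamaCppConfig(
    device="mps",
    task="text-generation",
    model="TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
    filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf",
)
scenario_config = InferenceConfig(
    input_shapes={"batch_size": 1, "sequence_length": 256, "vocab_size": 32000},
    generate_kwargs={"max_new_tokens": 100, "min_new_tokens": 100},
)
benchmark_config = BenchmarkConfig(
    name="llama_cpp_llama",
    backend=backend_config,
    scenario=scenario_config,
    launcher=InlineConfig(),
)
benchmark_report = Benchmark.launch(benchmark_config)
print(benchmark_report)
```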
26 changes: 26 additions & 0 deletions examples/pytorch_bert_mps.yaml
@@ -0,0 +1,26 @@
defaults:
  - benchmark
  - scenario: inference
  - launcher: process # launcher: inline also works
  - backend: pytorch
  - _base_
  - _self_

name: pytorch_bert

# launcher:
#   start_method: spawn

scenario:
  latency: true
  memory: true
  input_shapes:
    batch_size: 1
    sequence_length: 128

backend:
  device: cpu
  no_weights: true
  model: bert-base-uncased
2 changes: 2 additions & 0 deletions optimum_benchmark/__init__.py
@@ -1,6 +1,7 @@
from .backends import (
    BackendConfig,
    INCConfig,
    LlamaCppConfig,
    LLMSwarmConfig,
    ORTConfig,
    OVConfig,
@@ -38,4 +39,5 @@
    "TrainingConfig",
    "TRTLLMConfig",
    "VLLMConfig",
    "LlamaCppConfig",
]
2 changes: 2 additions & 0 deletions optimum_benchmark/backends/__init__.py
@@ -1,4 +1,5 @@
from .config import BackendConfig
from .llama_cpp.config import LlamaCppConfig
from .llm_swarm.config import LLMSwarmConfig
from .neural_compressor.config import INCConfig
from .onnxruntime.config import ORTConfig
@@ -20,4 +21,5 @@
    "LLMSwarmConfig",
    "BackendConfig",
    "VLLMConfig",
    "LlamaCppConfig",
]
4 changes: 3 additions & 1 deletion optimum_benchmark/backends/base.py
@@ -68,7 +68,9 @@ def __init__(self, config: BackendConfigT):
            self.automodel_loader = get_timm_automodel_loader()
            self.pretrained_processor = None
            self.generation_config = None

        elif self.config.library == "llama_cpp":
            self.logger.info("\t+ Benchmarking a Llama.cpp model")
            self.model_shapes = {}
        else:
            self.logger.info("\t+ Benchmarking a Transformers model")
            self.generation_config = get_transformers_generation_config(self.config.model, **self.config.model_kwargs)
12 changes: 8 additions & 4 deletions optimum_benchmark/backends/config.py
@@ -53,14 +53,16 @@ def __post_init__(self):

        # TODO: add cache_dir, token, etc. to these methods
        if self.task is None:
            self.task = infer_task_from_model_name_or_path(self.model, self.model_kwargs.get("revision", None))
            self.task = infer_task_from_model_name_or_path(
                self.model, self.model_kwargs.get("revision", None), self.library
            )

        if self.library is None:
            self.library = infer_library_from_model_name_or_path(self.model, self.model_kwargs.get("revision", None))

        if self.model_type is None:
            self.model_type = infer_model_type_from_model_name_or_path(
                self.model, self.model_kwargs.get("revision", None)
                self.model, self.model_kwargs.get("revision", None), self.library
            )

        if self.device is None:
@@ -90,8 +92,10 @@ def __post_init__(self):
            else:
                raise RuntimeError("CUDA device is only supported on systems with NVIDIA or ROCm drivers.")

        if self.library not in ["transformers", "diffusers", "timm"]:
            raise ValueError(f"`library` must be either `transformers`, `diffusers` or `timm`, but got {self.library}")
        if self.library not in ["transformers", "diffusers", "timm", "llama_cpp"]:
            raise ValueError(
                f"`library` must be either `transformers`, `diffusers`, `timm` or `llama_cpp`, but got {self.library}"
            )

        if self.inter_op_num_threads is not None:
            if self.inter_op_num_threads == -1:
Empty file.
92 changes: 92 additions & 0 deletions optimum_benchmark/backends/llama_cpp/backend.py
@@ -0,0 +1,92 @@
from tempfile import TemporaryDirectory
from typing import Any, Dict, Tuple

from llama_cpp import Llama

from ..base import Backend
from .config import LlamaCppConfig


class LlamaCppBackend(Backend[LlamaCppConfig]):
    NAME: str = "llama_cpp"

    def __init__(self, config: LlamaCppConfig) -> None:
        super().__init__(config)

        if self.config.no_weights:
            self.logger.info("\t+ Loading no weights model")
            raise NotImplementedError("No weights model is not yet implemented")

    def load(self) -> None:
        self.logger.info("\t+ Creating backend temporary directory")
        self.tmpdir = TemporaryDirectory()
        self.logger.info("\t+ Loading pretrained model")
        self.load_model_from_pretrained()
        self.tmpdir.cleanup()

    def load_model_from_pretrained(self) -> None:
        """
        Load the pretrained model from the given model name (normally GGUF, GGML)
        """
        embedding = True if self.config.task == "feature-extraction" else False

        self.pretrained_model = Llama.from_pretrained(
            repo_id=self.config.model,  # type: ignore
            filename=self.config.filename,
            verbose=False,
            echo=False,
            embedding=embedding,
        )  # type: ignore

    def validate_task(self) -> None:
        if self.config.task not in ["text-generation", "feature-extraction"]:
            raise ValueError(f"Task {self.config.task} not supported by {self.NAME}")

    def prepare_inputs(self, inputs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        if self.config.task == "text-generation":
            if inputs["input_ids"].shape[0] != 1:
                raise ValueError("Batch size must be 1 for Llama.cpp text generation")

            inputs = super().prepare_inputs(inputs)
            inputs["tokens"] = inputs["input_ids"].squeeze()

            return inputs
        elif self.config.task == "feature-extraction":
            detokenized_batch = list(map(self.pretrained_model.detokenize, inputs["input_ids"]))
            decoded_batch = [x.decode("utf-8") for x in detokenized_batch]

            inputs["input_str"] = decoded_batch
            return inputs

        raise ValueError(f"Task {self.config.task} not supported by {self.NAME}")

    def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
        """
        Forward pass of the model: compute the embeddings of the input tokens.
        """

        return self.pretrained_model.embed(inputs["input_str"])

    def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Prefill the model with the input tokens.
        We treat prefill as the time to first token, so we measure how long the model
        takes to generate its first token.
        """

        next(self.pretrained_model.generate(tokens=inputs["tokens"]))
        return inputs

    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> list[int]:
        """
        Generate new tokens from the pretrained model.
        """

        output = []

        for token in self.pretrained_model.generate(tokens=inputs["tokens"]):
            output.append(token)
            if len(output) >= kwargs["max_new_tokens"]:
                break

        return output
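For reference, the backend above is a thin wrapper over llama-cpp-python. A self-contained sketch of the same calls it relies on (`Llama.from_pretrained`, `tokenize`, the `generate` iterator, `detokenize`), with the repo and filename borrowed from the text-generation example and an assumed token budget, could look like:

```python
# Standalone sketch of the llama_cpp calls LlamaCppBackend builds on; the
# model comes from the example config, max_new_tokens is an assumed value.
from llama_cpp import Llama

llm = Llama.from_pretrained(
    repo_id="TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
    filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf",
    verbose=False,
)

tokens = llm.tokenize(b"Benchmarking llama.cpp is")

# "Prefill" as measured by the backend: the time to pull the first token
# out of the generate() iterator.
first_token = next(llm.generate(tokens=tokens))

# Decoding with a hard cap, mirroring LlamaCppBackend.generate().
max_new_tokens = 16
output = []
for token in llm.generate(tokens=tokens):
    output.append(token)
    if len(output) >= max_new_tokens:
        break

print(llm.detokenize(output).decode("utf-8", errors="replace"))
```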
34 changes: 34 additions & 0 deletions optimum_benchmark/backends/llama_cpp/config.py
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from logging import getLogger
from typing import Optional

from ...import_utils import llama_cpp_version
from ..config import BackendConfig

LOGGER = getLogger("backend")


def llama_cpp_model_kwargs():
    return {"verbose": True}


@dataclass
class LlamaCppConfig(BackendConfig):
    name: str = "llama_cpp"
    version: Optional[str] = llama_cpp_version()
    _target_: str = "optimum_benchmark.backends.llama_cpp.backend.LlamaCppBackend"

    no_weights: bool = False
    library: str = "llama_cpp"
    filename: Optional[str] = None

    def __post_init__(self):
        super().__post_init__()

        self.device = self.device.lower()  # type: ignore
        self.library = "llama_cpp"

        if self.device not in ["cuda", "mps", "cpu"]:
            raise ValueError(f"Llama.cpp Backend only supports 'cpu', 'mps' and 'cuda' devices, got {self.device}")

        LOGGER.warning("Llama.cpp automatically selects the device, ignoring the device parameter in the config.")
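Since `LlamaCppConfig` is a plain dataclass on top of `BackendConfig`, it can also be constructed directly; a small sketch with values borrowed from the embedding example (everything else left at its defaults) might be:

```python
# Sketch of constructing the backend config directly; values are taken from
# examples/llama_cpp_embedding.yaml, the rest stays at the dataclass defaults.
from optimum_benchmark import LlamaCppConfig

config = LlamaCppConfig(
    device="cpu",
    task="feature-extraction",
    model="nomic-ai/nomic-embed-text-v1.5-GGUF",
    filename="nomic-embed-text-v1.5.Q4_0.gguf",
)

# __post_init__ lower-cases the device, pins library to "llama_cpp",
# and warns that llama.cpp picks the device on its own.
print(config.library, config.device, config.filename)
```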
2 changes: 2 additions & 0 deletions optimum_benchmark/cli.py
@@ -13,6 +13,7 @@
    INCConfig,
    InferenceConfig,
    InlineConfig,
    LlamaCppConfig,
    LLMSwarmConfig,
    ORTConfig,
    OVConfig,
@@ -44,6 +45,7 @@
cs.store(group="backend", name=PyTXIConfig.name, node=PyTXIConfig)
cs.store(group="backend", name=LLMSwarmConfig.name, node=LLMSwarmConfig)
cs.store(group="backend", name=VLLMConfig.name, node=VLLMConfig)
cs.store(group="backend", name=LlamaCppConfig.name, node=LlamaCppConfig)
# scenarios configurations
cs.store(group="scenario", name=TrainingConfig.name, node=TrainingConfig)
cs.store(group="scenario", name=InferenceConfig.name, node=InferenceConfig)
7 changes: 6 additions & 1 deletion optimum_benchmark/generators/task_generator.py
@@ -394,7 +394,12 @@ class FeatureExtractionGenerator(TextGenerator, ImageGenerator):
    def __call__(self):
        dummy = {}

        if self.shapes["num_channels"] is not None and self.shapes["height"] is not None:
        if (
            "num_channels" in self.shapes
            and self.shapes["num_channels"] is not None
            and "height" in self.shapes
            and self.shapes["height"] is not None
        ):
            dummy["pixel_values"] = self.pixel_values()
        else:
            dummy["input_ids"] = self.input_ids()