- Download from hub is now available through mixin
bikash119 committed Oct 2, 2024
1 parent 778532f commit 55c3a0d
Showing 3 changed files with 121 additions and 88 deletions.
102 changes: 23 additions & 79 deletions src/distilabel/embeddings/llamacpp.py
@@ -12,21 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import os
-import tempfile
from typing import TYPE_CHECKING, Any, Dict, List, Union

from pydantic import Field, PrivateAttr

from distilabel.embeddings.base import Embeddings
from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
+from distilabel.mixins.hub_downloader import HuggingFaceModelLoaderMixin
from distilabel.mixins.runtime_parameters import RuntimeParameter

if TYPE_CHECKING:
from llama_cpp import Llama as _LlamaCpp


-class LlamaCppEmbeddings(Embeddings, CudaDevicePlacementMixin):
+class LlamaCppEmbeddings(
+    Embeddings, CudaDevicePlacementMixin, HuggingFaceModelLoaderMixin
+):
"""`LlamaCpp` library implementation for embedding generation.
Attributes:
@@ -71,16 +72,8 @@ class LlamaCppEmbeddings(Embeddings, CudaDevicePlacementMixin):
```
"""

-    model_path: str
-    repo_id: RuntimeParameter[Union[None, str]] = Field(
-        default=None,
-        description="The Hugging Face Hub repository id.",
-    )
-    hf_token: RuntimeParameter[Union[None, str]] = Field(
-        default=None,
-        description="Hugging Face token for accessing gated models.",
-    )
-    n_gpu_layers: int = 0
+    model_file: str
+    n_gpu_layers: RuntimeParameter[int] = Field(default=0, description="The number of GPU layers to use.")
disable_cuda_device_placement: RuntimeParameter[bool] = Field(
default=True,
description="Whether to disable CUDA device placement.",
@@ -120,72 +113,23 @@ def load(self) -> None:
                " `pip install llama-cpp-python`."
            ) from ie

-        if self.repo_id is not None:
-            try:
-                from huggingface_hub import hf_hub_download
-                from huggingface_hub.utils import validate_repo_id
-            except ImportError as ie:
-                raise ImportError(
-                    "Llama.from_pretrained requires the huggingface-hub package. "
-                    "You can install it with `pip install huggingface-hub`."
-                ) from ie
-
-            validate_repo_id(self.repo_id)
-
-            # Determine the download directory
-            download_dir = os.environ.get("DISTILABEL_MODEL_DIR")
-            if download_dir is None:
-                download_dir = tempfile.gettempdir()
-
-            self._logger.info(
-                f"Attempting to download model from Hugging Face Hub: {self.repo_id}"
-            )
-            try:
-                model_path = hf_hub_download(
-                    repo_id=self.repo_id,
-                    filename=self.model_path,
-                    token=self.hf_token,
-                    local_dir=download_dir,
-                )
-                self._logger.info(f"Model downloaded successfully to: {model_path}")
-            except Exception as e:
-                self._logger.error(
-                    f"Failed to download model from Hugging Face Hub: {str(e)}"
-                )
-                raise
-
-            try:
-                self._model = _LlamaCpp(
-                    model_path=model_path,
-                    n_gpu_layers=self.n_gpu_layers,
-                    seed=self.seed,
-                    n_ctx=self.n_ctx,
-                    n_batch=self.n_batch,
-                    verbose=self.verbose,
-                    embedding=True,
-                    **self.extra_kwargs,
-                )
-                self._logger.info("Model loaded successfully")
-            except Exception as e:
-                self._logger.error(f"Failed to load model: {str(e)}")
-                raise
-        else:
-            try:
-                self._logger.info(f"Attempting to load model from: {self.model_path}")
-                self._model = _LlamaCpp(
-                    model_path=self.model_path,
-                    seed=self.seed,
-                    n_gpu_layers=self.n_gpu_layers,
-                    n_ctx=self.n_ctx,
-                    n_batch=self.n_batch,
-                    verbose=self.verbose,
-                    embedding=True,
-                    **self.extra_kwargs,
-                )
-                self._logger.info("Model loaded successfully")
-            except Exception as e:
-                self._logger.error(f"Failed to load model: {str(e)}")
-                raise
+        model_path = self.download_model()
+        try:
+            self._logger.info(f"Attempting to load model from: {model_path}")
+            self._model = _LlamaCpp(
+                model_path=model_path,
+                seed=self.seed,
+                n_gpu_layers=self.n_gpu_layers,
+                n_ctx=self.n_ctx,
+                n_batch=self.n_batch,
+                verbose=self.verbose,
+                embedding=True,
+                **self.extra_kwargs,
+            )
+            self._logger.info("Model loaded successfully")
+        except Exception as e:
+            self._logger.error(f"Failed to load model: {str(e)}")
+            raise

def unload(self) -> None:
"""Unloads the `gguf` model."""
@@ -195,7 +139,7 @@ def unload(self) -> None:
@property
def model_name(self) -> str:
"""Returns the name of the model."""
-        return self.model_path
+        return self.model_file

def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
"""Generates embeddings for the provided inputs.
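With these changes, `LlamaCppEmbeddings` resolves its model through the mixin: a plain `model_file` path is loaded directly, while adding a `repo_id` triggers a Hub download first. A minimal usage sketch (not part of the commit; the import path, repo id, and file names are illustrative assumptions):

```python
from distilabel.embeddings import LlamaCppEmbeddings

# Load from a local GGUF file: with repo_id unset, model_file is the path.
local_embeddings = LlamaCppEmbeddings(model_file="/models/my-model.gguf")
local_embeddings.load()
print(local_embeddings.encode(["Hello, how are you?"]))
local_embeddings.unload()

# Load from the Hugging Face Hub: model_file names the GGUF file inside the
# repo, and load() calls download_model() before handing the path to llama.cpp.
hub_embeddings = LlamaCppEmbeddings(
    repo_id="example-org/example-gguf-repo",  # hypothetical repo id
    model_file="example-model.Q4_K_M.gguf",  # hypothetical file name
    normalize_embeddings=True,
)
hub_embeddings.load()
vectors = hub_embeddings.encode(["What a nice day!"])
hub_embeddings.unload()
```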
89 changes: 89 additions & 0 deletions src/distilabel/mixins/hub_downloader.py
@@ -0,0 +1,89 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
from typing import Optional

from pydantic import BaseModel, Field


class HuggingFaceModelLoaderMixin(BaseModel):
"""
A mixin for downloading models from the Hugging Face Hub.
Attributes:
repo_id (Optional[str]): The Hugging Face Hub repository id.
model_file (str): The name of the model file to download.
hf_token (Optional[str]): Hugging Face token for accessing gated models.
"""

repo_id: Optional[str] = Field(
default=None,
description="The Hugging Face Hub repository id.",
)
model_file: str = Field(
description="The name of the model file to download.",
)
hf_token: Optional[str] = Field(
default=None,
description="Hugging Face token for accessing gated models.",
)

def download_model(self) -> str:
"""
Downloads the model from Hugging Face Hub if repo_id is provided.
Returns:
str: The path to the downloaded or local model file.
Raises:
ImportError: If huggingface_hub is not installed.
            ValueError: If the provided repo_id is invalid.
            Exception: If the download from the Hub fails.
"""
if self.repo_id is None:
return self.model_file

try:
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import validate_repo_id
except ImportError as ie:
raise ImportError(
"huggingface_hub package is not installed. "
"You can install it with `pip install huggingface_hub`."
) from ie

try:
validate_repo_id(self.repo_id)
except ValueError as ve:
raise ValueError(f"Invalid repo_id: {self.repo_id}") from ve

# Determine the download directory
download_dir = os.environ.get("DISTILABEL_MODEL_DIR")
if download_dir is None:
download_dir = tempfile.gettempdir()

try:
model_path = hf_hub_download(
repo_id=self.repo_id,
filename=self.model_file,
token=self.hf_token,
local_dir=download_dir,
)
return model_path
except Exception as e:
raise Exception(
f"Failed to download model from Hugging Face Hub: {str(e)}"
) from e
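Since the download logic now lives in a standalone pydantic mixin, other components can reuse it without touching llama.cpp. A short sketch of such reuse, assuming the mixin behaves exactly as defined above (the consumer class and file names are hypothetical):

```python
import os

from distilabel.mixins.hub_downloader import HuggingFaceModelLoaderMixin


class GGUFModelConsumer(HuggingFaceModelLoaderMixin):
    """Hypothetical component that only needs a resolved model path."""


# Optional: route Hub downloads to a fixed directory instead of the temp dir.
os.environ["DISTILABEL_MODEL_DIR"] = "/tmp/distilabel-models"

# With repo_id set, download_model() fetches the named file from the Hub
# (requires network access and the huggingface_hub package).
hub_consumer = GGUFModelConsumer(
    repo_id="example-org/example-gguf-repo",  # hypothetical repo id
    model_file="example-model.Q4_K_M.gguf",  # hypothetical file name
)
# model_path = hub_consumer.download_model()

# With repo_id left unset, download_model() short-circuits and returns
# model_file unchanged, treating it as a local path.
local_consumer = GGUFModelConsumer(model_file="/models/local-model.gguf")
assert local_consumer.download_model() == "/models/local-model.gguf"
```

Keeping the local-path fallback inside `download_model()` is what lets `load()` above stay branch-free.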
18 changes: 9 additions & 9 deletions tests/unit/embeddings/test_llamacpp.py
@@ -25,8 +25,8 @@ def test_model_name(self) -> None:
"""
Test if the model name is correctly set.
"""
-        embeddings = LlamaCppEmbeddings(model_path=self.model_name)
-        assert embeddings.model_path == self.model_name
+        embeddings = LlamaCppEmbeddings(model_file=self.model_name)
+        assert embeddings.model_file == self.model_name

def test_encode(self, local_llamacpp_model_path) -> None:
"""
@@ -35,7 +35,7 @@ def test_encode(self, local_llamacpp_model_path) -> None:
Args:
local_llamacpp_model_path (str): Fixture providing the local model path.
"""
-        embeddings = LlamaCppEmbeddings(model_path=local_llamacpp_model_path)
+        embeddings = LlamaCppEmbeddings(model_file=local_llamacpp_model_path)
inputs = [
"Hello, how are you?",
"What a nice day!",
@@ -54,7 +54,7 @@ def test_load_model_from_local(self, local_llamacpp_model_path):
Args:
local_llamacpp_model_path (str): Fixture providing the local model path.
"""
-        embeddings = LlamaCppEmbeddings(model_path=local_llamacpp_model_path)
+        embeddings = LlamaCppEmbeddings(model_file=local_llamacpp_model_path)
inputs = [
"Hello, how are you?",
"What a nice day!",
@@ -73,7 +73,7 @@ def test_load_model_from_repo(self):
"""
embeddings = LlamaCppEmbeddings(
repo_id=self.repo_id,
-            model_path=self.model_name,
+            model_file=self.model_name,
normalize_embeddings=True,
)
inputs = [
@@ -94,7 +94,7 @@ def test_normalize_embeddings_true(self, local_llamacpp_model_path):
Test if embeddings are normalized when normalize_embeddings is True.
"""
embeddings = LlamaCppEmbeddings(
-            model_path=local_llamacpp_model_path, normalize_embeddings=True
+            model_file=local_llamacpp_model_path, normalize_embeddings=True
)
embeddings.load()

@@ -118,7 +118,7 @@ def test_normalize_embeddings_false(self, local_llamacpp_model_path):
Test if embeddings are not normalized when normalize_embeddings is False.
"""
embeddings = LlamaCppEmbeddings(
-            model_path=local_llamacpp_model_path, normalize_embeddings=False
+            model_file=local_llamacpp_model_path, normalize_embeddings=False
)
embeddings.load()

@@ -150,7 +150,7 @@ def test_encode_batch(self, local_llamacpp_model_path) -> None:
Args:
local_llamacpp_model_path (str): Fixture providing the local model path.
"""
-        embeddings = LlamaCppEmbeddings(model_path=local_llamacpp_model_path)
+        embeddings = LlamaCppEmbeddings(model_file=local_llamacpp_model_path)
embeddings.load()

# Test with different batch sizes
@@ -181,7 +181,7 @@ def test_encode_batch_consistency(self, local_llamacpp_model_path) -> None:
Args:
local_llamacpp_model_path (str): Fixture providing the local model path.
"""
-        embeddings = LlamaCppEmbeddings(model_path=local_llamacpp_model_path)
+        embeddings = LlamaCppEmbeddings(model_file=local_llamacpp_model_path)
embeddings.load()

input_text = "This is a test sentence for consistency"
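The updated tests exercise the mixin only through `LlamaCppEmbeddings`; direct coverage of `HuggingFaceModelLoaderMixin` could look like the sketch below. These are hypothetical tests, not part of this commit; the second assumes `huggingface_hub` is installed, since `validate_repo_id` rejects malformed ids with a `ValueError` subclass.

```python
import pytest

from distilabel.mixins.hub_downloader import HuggingFaceModelLoaderMixin


def test_returns_local_path_when_repo_id_is_unset() -> None:
    loader = HuggingFaceModelLoaderMixin(model_file="/models/local.gguf")
    # No repo_id: model_file is returned as-is and nothing is downloaded.
    assert loader.download_model() == "/models/local.gguf"


def test_invalid_repo_id_raises_value_error() -> None:
    loader = HuggingFaceModelLoaderMixin(
        repo_id="not a valid repo id!!",  # spaces and '!' fail validation
        model_file="model.gguf",
    )
    with pytest.raises(ValueError):
        loader.download_model()
```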
