Commit
- Use HF_TOKEN to download model from hub to generate embeddings.
bikash119 committed Sep 30, 2024
1 parent 2d0aa76 commit 778532f
Showing 1 changed file with 42 additions and 14 deletions.
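Below is a minimal usage sketch (not part of the commit) of how the new `hf_token` runtime parameter might be passed when pulling a gated GGUF model from the Hub. The field names mirror the diff that follows; the repository id, the filename, and the `encode` call are illustrative assumptions. The download-directory behaviour described in the new docstring is sketched after the diff.

import os

from distilabel.embeddings.llamacpp import LlamaCppEmbeddings

# Hypothetical example: repo id and filename are placeholders. Per the diff,
# `model_path` is forwarded to `hf_hub_download` as the filename inside the repo.
embeddings = LlamaCppEmbeddings(
    repo_id="your-org/your-gguf-model",
    model_path="model.Q4_K_M.gguf",
    hf_token=os.environ.get("HF_TOKEN"),  # token for gated repositories
)
embeddings.load()
# `encode` is assumed from the `Embeddings` base class.
vectors = embeddings.encode(["distilabel is awesome"])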
56 changes: 42 additions & 14 deletions src/distilabel/embeddings/llamacpp.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
from typing import TYPE_CHECKING, Any, Dict, List, Union

from pydantic import Field, PrivateAttr
@@ -31,6 +33,7 @@ class LlamaCppEmbeddings(Embeddings, CudaDevicePlacementMixin):
model_path: contains the path to the GGUF quantized model, compatible with the
installed version of the `llama.cpp` Python bindings.
repo_id: the Hugging Face Hub repository id.
hf_token: Hugging Face token for accessing gated models.
verbose: whether to print verbose output. Defaults to `False`.
n_gpu_layers: number of layers to run on the GPU. Defaults to `-1` (use the GPU if available).
disable_cuda_device_placement: whether to disable CUDA device placement. Defaults to `True`.
@@ -73,6 +76,10 @@ class LlamaCppEmbeddings(Embeddings, CudaDevicePlacementMixin):
default=None,
description="The Hugging Face Hub repository id.",
)
hf_token: RuntimeParameter[Union[None, str]] = Field(
default=None,
description="Hugging Face token for accessing gated models.",
)
n_gpu_layers: int = 0
disable_cuda_device_placement: RuntimeParameter[bool] = Field(
default=True,
@@ -96,7 +103,11 @@ class LlamaCppEmbeddings(Embeddings, CudaDevicePlacementMixin):
_model: Union["_LlamaCpp", None] = PrivateAttr(None)

def load(self) -> None:
"""Loads the `gguf` model using either the path or the Hugging Face Hub repository id."""
"""
Loads the `gguf` model using either the path or the Hugging Face Hub repository id.
If using Hugging Face Hub, the model will be downloaded to a local directory
specified by the DISTILABEL_MODEL_DIR environment variable or to a temporary directory.
"""
super().load()

CudaDevicePlacementMixin.load(self)
@@ -111,34 +122,52 @@ def load(self) -> None:

if self.repo_id is not None:
try:
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import validate_repo_id

validate_repo_id(self.repo_id)
except ImportError as ie:
raise ImportError(
"Llama.from_pretrained requires the huggingface-hub package. "
"You can install it with `pip install huggingface-hub`."
) from ie

validate_repo_id(self.repo_id)

# Determine the download directory
download_dir = os.environ.get("DISTILABEL_MODEL_DIR")
if download_dir is None:
download_dir = tempfile.gettempdir()

self._logger.info(
f"Attempting to download model from Hugging Face Hub: {self.repo_id}"
)
try:
self._logger.info(
f"Attempting to load model from Hugging Face Hub: {self.repo_id}"
)
self._model = _LlamaCpp.from_pretrained(
model_path = hf_hub_download(
repo_id=self.repo_id,
filename=self.model_path,
token=self.hf_token,
local_dir=download_dir,
)
self._logger.info(f"Model downloaded successfully to: {model_path}")
except Exception as e:
self._logger.error(
f"Failed to download model from Hugging Face Hub: {str(e)}"
)
raise

try:
self._model = _LlamaCpp(
model_path=model_path,
n_gpu_layers=self.n_gpu_layers,
seed=self.seed,
n_ctx=self.n_ctx,
n_batch=self.n_batch,
verbose=self.verbose,
embedding=True,
kwargs=self.extra_kwargs,
**self.extra_kwargs,
)
self._logger.info("Model loaded successfully from Hugging Face Hub")
self._logger.info("Model loaded successfully")
except Exception as e:
self._logger.error(
f"Failed to load model from Hugging Face Hub: {str(e)}"
)
self._logger.error(f"Failed to load model: {str(e)}")
raise
else:
try:
Expand All @@ -151,9 +180,8 @@ def load(self) -> None:
n_batch=self.n_batch,
verbose=self.verbose,
embedding=True,
kwargs=self.extra_kwargs,
**self.extra_kwargs,
)
self._logger.info(f"self._model: {self._model}")
self._logger.info("Model loaded successfully")
except Exception as e:
self._logger.error(f"Failed to load model: {str(e)}")

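The new docstring notes that, when a `repo_id` is given, the GGUF file is downloaded to the directory named by the `DISTILABEL_MODEL_DIR` environment variable, falling back to a temporary directory otherwise. A short sketch of how a caller might pin that location before loading; the path is a placeholder, and the resolution mirrors the logic added in this commit.

import os
import tempfile

# Hypothetical cache directory; set before calling `load()`.
os.environ["DISTILABEL_MODEL_DIR"] = "/data/gguf-models"

# Resolution as implemented in the commit: environment variable first,
# then the system temporary directory.
download_dir = os.environ.get("DISTILABEL_MODEL_DIR")
if download_dir is None:
    download_dir = tempfile.gettempdir()
print(f"GGUF models will be downloaded to: {download_dir}")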