Skip to content

Commit

Permalink
Delay importing torch until needed from transformers pathway, raise a…
Browse files Browse the repository at this point in the history
…ppropriate error when not installed. (#1187)

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
  • Loading branch information
rahul-tuli and dbogunowicz committed Aug 22, 2023
1 parent 545348b commit 8dbcb0c
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion src/deepsparse/transformers/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import numpy
from tqdm import tqdm

import torch
from deepsparse import Pipeline
from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline
from deepsparse.transformers.utils.helpers import pad_to_fixed_length
Expand All @@ -49,6 +48,7 @@ def __init__(self, pipeline: Pipeline, batch_size: int = 16):
:param batch_size: The batch size to split the input text into
non-overlapping batches
"""
torch = _import_torch()
if not isinstance(pipeline, TextGenerationPipeline):
raise ValueError(
"Perplexity can only be computed for text generation pipelines"
Expand All @@ -67,6 +67,7 @@ def add_batch(self, predictions: List[str]):
:param predictions: The predictions to compute perplexity on
"""
torch = _import_torch()
# tokenize the input text
encodings = self._pipeline.tokenizer(
predictions,
Expand Down Expand Up @@ -225,3 +226,21 @@ def compute(self) -> Dict[str, float]:
results["f1_std"] = f1.std()

return results


def _import_torch():
"""
Import and return the required torch module. Raises an ImportError if torch is not
installed.
:raises ImportError: if torch is not installed
:return: torch module
"""
try:
import torch

return torch
except ImportError as import_error:
raise ImportError(
"Please install `deepsparse[torch]` to use this pathway"
) from import_error

0 comments on commit 8dbcb0c

Please sign in to comment.