Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Delay torch import until needed for deepsparse.transformers.eval_downstream #1187

Merged
merged 2 commits into from
Aug 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/deepsparse/transformers/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import numpy
from tqdm import tqdm

import torch
from deepsparse import Pipeline
from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline
from deepsparse.transformers.utils.helpers import pad_to_fixed_length
Expand All @@ -49,6 +48,7 @@ def __init__(self, pipeline: Pipeline, batch_size: int = 16):
:param batch_size: The batch size to split the input text into
non-overlapping batches
"""
torch = _import_torch()
if not isinstance(pipeline, TextGenerationPipeline):
raise ValueError(
"Perplexity can only be computed for text generation pipelines"
Expand All @@ -67,6 +67,7 @@ def add_batch(self, predictions: List[str]):

:param predictions: The predictions to compute perplexity on
"""
torch = _import_torch()
# tokenize the input text
encodings = self._pipeline.tokenizer(
predictions,
Expand Down Expand Up @@ -225,3 +226,21 @@ def compute(self) -> Dict[str, float]:
results["f1_std"] = f1.std()

return results


def _import_torch():
rahul-tuli marked this conversation as resolved.
Show resolved Hide resolved
"""
Import and return the required torch module. Raises an ImportError if torch is not
installed.

:raises ImportError: if torch is not installed
:return: torch module
"""
try:
import torch

return torch
except ImportError as import_error:
raise ImportError(
"Please install `deepsparse[torch]` to use this pathway"
) from import_error
Loading