Skip to content

Commit

Permalink
Support for SentenceTransformers with `deepsparse.sentence_transforme…
Browse files Browse the repository at this point in the history
…rs.SentenceTransformer` (#1301)

* Support for SentenceTransformer with `deepsparse.sentence_transformers.SentenceTransformer`

* Format

* Update install

* Update

* Address comments

* Add README

* Fix docs

* Update setup.py

* Update README

* Add batching
  • Loading branch information
mgoin committed Oct 19, 2023
1 parent 9d0d897 commit 869af57
Show file tree
Hide file tree
Showing 4 changed files with 341 additions and 2 deletions.
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def _parse_requirements_file(file_path):
_onnxruntime_deps = [
"onnxruntime>=1.7.0",
]
_torch_deps = ["torch>=1.7.0,<=2.0"]
_image_classification_deps = [
"torchvision>=0.3.0,<0.14",
"opencv-python<=4.6.0.66",
Expand All @@ -150,6 +151,7 @@ def _parse_requirements_file(file_path):
"scikit-learn",
"seqeval",
]
_sentence_transformers_integration_deps = ["optimum-deepsparse"] + _torch_deps

# haystack dependencies are installed from a requirements file to avoid
# conflicting versions with NM's deepsparse/transformers
Expand All @@ -168,8 +170,6 @@ def _parse_requirements_file(file_path):
"transformers<4.35",
]

_torch_deps = ["torch>=1.7.0,<=2.0"]


def _check_supported_system():
if sys.platform.startswith("linux"):
Expand Down Expand Up @@ -275,6 +275,7 @@ def _setup_extras() -> Dict:
"yolov8": _yolov8_integration_deps,
"transformers": _transformers_integration_deps,
"llm": _transformers_integration_deps,
"sentence_transformers": _sentence_transformers_integration_deps,
"torch": _torch_deps,
"clip": _clip_deps,
}
Expand Down
86 changes: 86 additions & 0 deletions src/deepsparse/sentence_transformers/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@

# DeepSparse SentenceTransformers

```python
from deepsparse.sentence_transformers import SentenceTransformer
```

[DeepSparse](https://github.com/neuralmagic/deepsparse) enhances [SentenceTransformers](https://www.sbert.net/), enabling more efficient computation of embeddings for text and images across numerous languages. This improvement hinges on advanced sparse inference methods from DeepSparse and provides performance improvements on CPUs as a result. The system, originally built on PyTorch and Transformers, gains additional muscle from DeepSparse, expanding its repertoire of pre-trained models. It's especially adept at tasks like identifying similar meanings in text, supporting applications in semantic search, paraphrase detection, and more.

## Installation

You can install the DeepSparse SentenceTransformers extension using pip:

```bash
pip install -U deepsparse-nightly[sentence_transformers]
```

## Usage

Using DeepSparse SentenceTransformers is straightforward and similar to the original:

```python
from deepsparse.sentence_transformers import SentenceTransformer
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', export=True)

# Our sentences we like to encode
sentences = ['This framework generates embeddings for each input sentence',
'Sentences are passed as a list of string.',
'The quick brown fox jumps over the lazy dog.']

# Sentences are encoded by calling model.encode()
embeddings = model.encode(sentences)

# Print the embeddings
for sentence, embedding in zip(sentences, embeddings):
print("Sentence:", sentence)
print("Embedding:", embedding.shape)
print("")
```

## Accuracy Validation with MTEB

DeepSparse's efficiency doesn't compromise its accuracy, thanks to testing with the Multilingual Text Embedding Benchmark (MTEB). This process validates the model's performance against standard tasks, ensuring its reliability.

To initiate this, you'll need to install MTEB, along with the necessary DeepSparse and SentenceTransformers libraries. Use the following command:

```
pip install mteb deepsparse-nightly[sentence_transformers] sentence-transformers
```

Once installed, you can leverage MTEB for an evaluation as shown in the Python script below:

```python
from mteb import MTEB

# Specify the model to use
model_name = "TaylorAI/bge-micro-v2"

# DeepSparse Model Evaluation
from deepsparse.sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name, export=True)
evaluation = MTEB(tasks=["Banking77Classification"])
results_ds = evaluation.run(model, output_folder=f"results/ds-{model_name}")
print(results_ds)

# Original SentenceTransformers Model Evaluation
import sentence_transformers
model = sentence_transformers.SentenceTransformer(model_name)
evaluation = MTEB(tasks=["Banking77Classification"])
results_st = evaluation.run(model, output_folder=f"results/st-{model_name}")
print(results_st)
```

Output:
```
{'Banking77Classification': {'mteb_version': '1.1.1', 'dataset_revision': '0fd18e25b25c072e09e0d92ab615fda904d66300', 'mteb_dataset_name': 'Banking77Classification', 'test': {'accuracy': 0.8117207792207791, 'f1': 0.8109893836310513, 'accuracy_stderr': 0.007164150669501205, 'f1_stderr': 0.007346045502756079, 'main_score': 0.8117207792207791, 'evaluation_time': 8.05}}}
{'Banking77Classification': {'mteb_version': '1.1.1', 'dataset_revision': '0fd18e25b25c072e09e0d92ab615fda904d66300', 'mteb_dataset_name': 'Banking77Classification', 'test': {'accuracy': 0.8117207792207791, 'f1': 0.8109893836310513, 'accuracy_stderr': 0.007164150669501205, 'f1_stderr': 0.007346045502756079, 'main_score': 0.8117207792207791, 'evaluation_time': 12.21}}}
```

This script performs a comparative analysis between the DeepSparse-optimized model and the original SentenceTransformers model, using MTEB's "Banking77Classification" task as a benchmark. The results are then saved in separate directories for a clear, side-by-side comparison. This thorough evaluation ensures that the enhancements provided by DeepSparse maintain the high standards of accuracy expected from state-of-the-art NLP models.

---

This documentation is based on the original README from [SentenceTransformers](https://www.sbert.net/). It extends the original functionalities with the optimizations provided by [DeepSparse](https://github.com/neuralmagic/deepsparse).

**Note**: The example usage is designed for the DeepSparse-enhanced version of SentenceTransformers. Make sure to follow the specific installation instructions for full compatibility. Performance optimizations with batching and other advanced features will be part of future updates.
36 changes: 36 additions & 0 deletions src/deepsparse/sentence_transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Helpers for running SentenceTransformer based models with DeepSparse and integrating with
huggingface/transformers
"""

# flake8: noqa

from deepsparse.analytics import deepsparse_analytics as _analytics


_analytics.send_event("python__sentence_transformers__init")


try:
import optimum.deepsparse
except ImportError:
raise ImportError(
"Please install deepsparse[sentence_transformers] to use this pathway"
)


from .sentence_transformer import SentenceTransformer
216 changes: 216 additions & 0 deletions src/deepsparse/sentence_transformers/sentence_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, List, Tuple, Union

import numpy as np
from tqdm.autonotebook import trange
from transformers.onnx.utils import get_preprocessor

import torch
from optimum.deepsparse import DeepSparseModelForFeatureExtraction


logger = logging.getLogger(__name__)

DEFAULT_MODEL_NAME = "zeroshot/bge-small-en-v1.5-quant"


class SentenceTransformer:
"""
Loads or creates a SentenceTransformer-compatible model that can be used to map
text to embeddings.
:param model_name_or_path: If it is a filepath on disc, it loads the model from
that path. If it is not a path, it first tries to download and export a model
from a HuggingFace models repository with that name.
:param export: To load a PyTorch checkpoint and convert it to the DeepSparse
format on-the-fly, you can set `export=True` when loading your model.
:param max_seq_length: Sets a limit on the maxmimum sequence length allowed,
this should be set to 512 for most models. Any text that exceeds this
token length will be truncated.
:param use_auth_token: HuggingFace authentication token to download private models.
"""

def __init__(
self,
model_name_or_path: str = DEFAULT_MODEL_NAME,
export: bool = False,
max_seq_length: int = 512,
use_auth_token: Union[bool, str, None] = None,
):

self.model_name_or_path = model_name_or_path
self.model = DeepSparseModelForFeatureExtraction.from_pretrained(
model_name_or_path, export=export, use_auth_token=use_auth_token
)
self.model.compile(batch_size=0)
self.tokenizer = get_preprocessor(model_name_or_path)

self._max_seq_length = max_seq_length

def encode(
self,
sentences: Union[str, List[str]],
batch_size: int = 1,
show_progress_bar: bool = None,
output_value: str = "sentence_embedding",
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
normalize_embeddings: bool = False,
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
"""
Computes sentence embeddings
:param sentences: the sentences to embed
:param batch_size: the batch size used for the computation
:param show_progress_bar: Output a progress bar when encode sentences
:param output_value: Default sentence_embedding, to get sentence embeddings.
Can be set to token_embeddings to get wordpiece token embeddings. Set to
None, to get all output values
:param convert_to_numpy: If true, the output is a list of numpy vectors.
Else, it is a list of PyTorch tensors.
:param convert_to_tensor: If true, you get one large tensor as return.
Overwrites any setting from convert_to_numpy
:param normalize_embeddings: If set to true, returned vectors will have
length 1. In that case, the faster dot-product (util.dot_score)
instead of cosine similarity can be used.
:return:
By default, a list of tensors is returned. If convert_to_tensor,
a stacked tensor is returned. If convert_to_numpy, a numpy matrix
is returned.
"""

if show_progress_bar is None:
show_progress_bar = logger.getEffectiveLevel() in (
logging.INFO,
logging.DEBUG,
)

if convert_to_tensor:
convert_to_numpy = False

if output_value != "sentence_embedding":
convert_to_tensor = False
convert_to_numpy = False

input_was_string = False
if isinstance(sentences, str) or not hasattr(
sentences, "__len__"
): # Cast an individual sentence to a list with length 1
sentences = [sentences]
input_was_string = True

all_embeddings = []
length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

for start_index in trange(
0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar
):
sentences_batch = sentences_sorted[start_index : start_index + batch_size]

model_inputs = self.tokenize(sentences_batch)
model_output = self.model(**model_inputs)

out_features = {}
out_features["sentence_embedding"] = self.mean_pooling(
model_output, model_inputs["attention_mask"]
)

embeddings = []
if output_value == "token_embeddings":
for token_emb, attention in zip(
out_features[output_value], out_features["attention_mask"]
):
# Apply the attention mask to remove embeddings for padding tokens
# Count non-zero values in the attention mask
actual_tokens_count = attention.sum().item()
# Slice the embeddings using this count
embeddings.append(token_emb[:actual_tokens_count])
elif output_value is None:
# Return all outputs
for sent_idx in range(len(out_features["sentence_embedding"])):
row = {name: out_features[name][sent_idx] for name in out_features}
embeddings.append(row)
else:
# Sentence embeddings
embeddings = out_features[output_value]
if normalize_embeddings:
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

all_embeddings.extend(embeddings)

all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]

if convert_to_tensor:
all_embeddings = torch.stack(all_embeddings)
elif convert_to_numpy:
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

if input_was_string:
all_embeddings = all_embeddings[0]

return all_embeddings

def get_max_seq_length(self) -> int:
"""
Returns the maximal sequence length for input the model accepts.
Longer inputs will be truncated
"""
return self._max_seq_length

def _text_length(self, text: Union[List[int], List[List[int]]]) -> int:
"""
Help function to get the length for the input text. Text can be either
a list of ints (which means a single text as input), or a tuple of list of ints
(representing several text inputs to the model).
"""

if isinstance(text, dict): # {key: value} case
return len(next(iter(text.values())))
elif not hasattr(text, "__len__"): # Object has no len() method
return 1
elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints
return len(text)
else:
return sum([len(t) for t in text]) # Sum of length of individual strings

def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]):
"""
Tokenizes the texts
"""
return self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

def mean_pooling(
self, model_output: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""
Compute mean pooling of token embeddings weighted by attention mask.
Args:
model_output (torch.Tensor): The model's output tensor.
attention_mask (torch.Tensor): The attention mask tensor.
Returns:
torch.Tensor: Mean-pooled embeddings.
"""
# First element of model_output contains all token embeddings
token_embeddings = model_output[0]
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)

0 comments on commit 869af57

Please sign in to comment.