Enhance embedding to support jit model (#1335)
* Enhance embedding to support jit model

Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
yuwenzho committed Mar 8, 2024
1 parent c5e294a commit 588c608
Showing 5 changed files with 144 additions and 46 deletions.
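In practice, the change lets the langchain-style embedding wrappers below accept a directory containing a torch.jit-saved (e.g. INT8) encoder as `model_name`. A minimal, hedged usage sketch — the path is a placeholder and the import assumes the embeddings package re-exports the classes defined in embeddings.py:

```python
# Hedged usage sketch, not part of the commit: the model path is a placeholder for a
# directory holding a TorchScript (jit) encoder; import path assumes package re-exports.
from intel_extension_for_transformers.langchain.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="/path/to/bge-base-en-v1.5-int8-static")
vectors = embeddings.embed_documents(["TorchScript embedding models are now supported."])
print(len(vectors), len(vectors[0]))
```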
41 changes: 17 additions & 24 deletions intel_extension_for_transformers/langchain/embeddings/embeddings.py
@@ -16,6 +16,7 @@
# limitations under the License.

import logging
import importlib.util
from typing import Any, Dict, List, Optional
from .optimized_instructor_embedding import OptimizedInstructor
from .optimized_sentence_transformers import OptimizedSentenceTransformer
@@ -74,14 +75,13 @@ class HuggingFaceEmbeddings(langchain_core.pydantic_v1.BaseModel, langchain_core
    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer."""
        super().__init__(**kwargs)
        # check sentence_transformers python package
        if importlib.util.find_spec("sentence_transformers") is None: # pragma: no cover
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "Please install it with `pip install -U sentence-transformers`."
            )

        self.client = OptimizedSentenceTransformer(
            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
Expand All @@ -104,7 +104,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
        import sentence_transformers

        texts = list(map(lambda x: x.replace("\n", " "), texts))
        if self.multi_process: # pragma: no cover
            pool = self.client.start_multi_process_pool()
            embeddings = self.client.encode_multi_process(texts, pool)
            sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
@@ -161,19 +161,18 @@ class HuggingFaceBgeEmbeddings(langchain_core.pydantic_v1.BaseModel, langchain_c
    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer."""
        super().__init__(**kwargs)
        # check sentence_transformers python package
        if importlib.util.find_spec("sentence_transformers") is None: # pragma: no cover
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "Please install it with `pip install -U sentence-transformers`."
            )

        self.client = OptimizedSentenceTransformer(
            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
        )
        if "-zh" in self.model_name: # pragma: no cover
            self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH

    class Config:
@@ -250,24 +249,18 @@ def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)

        # check sentence_transformers python package
        if importlib.util.find_spec("sentence_transformers") is None: # pragma: no cover
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "Please install it with `pip install -U sentence-transformers`."
            )

        # check InstructorEmbedding python package
        if importlib.util.find_spec("InstructorEmbedding") is None: # pragma: no cover
            raise ImportError(
                "Could not import InstructorEmbedding python package. "
                "Please install it with `pip install -U InstructorEmbedding`."
            )

        self.client = OptimizedInstructor(
            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
intel_extension_for_transformers/langchain/embeddings/optimized_instructor_embedding.py
@@ -17,6 +17,7 @@

import os
import json
import torch
import logging
from collections import OrderedDict
from intel_extension_for_transformers.transformers import OptimizedModel
@@ -37,19 +38,57 @@ def __init__(self, *args, **kwargs):

    def _load_model(self, model_name_or_path, config, cache_dir, **model_args):
        """Loads the transformer model"""
        self.auto_model = OptimizedModel.from_pretrained(model_name_or_path,
                                                         config=config,
                                                         cache_dir=cache_dir,
                                                         **model_args)
        if isinstance(self.auto_model, torch.jit.ScriptModule):
            setattr(self.auto_model, "config", config)

    def forward(self, features):
        """Returns token_embeddings, cls_token"""
        trans_features = {'input_ids': features['input_ids'], 'attention_mask': features['attention_mask']}
        if 'token_type_ids' in features: # pragma: no cover
            trans_features['token_type_ids'] = features['token_type_ids']

        context_masks = None
        if 'context_masks' in features: # pragma: no cover
            context_masks = features['context_masks']

        if isinstance(self.auto_model, torch.jit.ScriptModule):
            output_states = self.auto_model(**trans_features)
            if isinstance(output_states, dict):
                output_states = tuple(output_states.values())
            output_tokens = output_states[0]
        else:
            output_states = self.auto_model(**trans_features, return_dict=False)
            output_tokens = output_states[0]
        attention_mask = features['attention_mask']
        if context_masks is not None:
            assert len(context_masks) == len(attention_mask)
            n = len(attention_mask)
            for local_idx in range(n):
                assert torch.sum(attention_mask[local_idx]).item() >= context_masks[local_idx].item(),\
                    f'{attention_mask[local_idx]}, {context_masks[local_idx]}, ' \
                    f'{torch.sum(attention_mask[local_idx]).item()}, {context_masks[local_idx].item()}'
                attention_mask[local_idx][:context_masks[local_idx]] = 0

        features.update({'token_embeddings': output_tokens, 'attention_mask': attention_mask})

        if self.auto_model.config.output_hidden_states:
            all_layer_idx = 2
            if len(output_states) < 3: # Some models only output last_hidden_states and all_hidden_states
                all_layer_idx = 1

            hidden_states = output_states[all_layer_idx]
            features.update({'all_layer_embeddings': hidden_states})

        return features

class OptimizedInstructor(InstructorEmbedding.INSTRUCTOR):
    def __init__(self, *args, **kwargs):
        """Initialize the OptimizedInstructor."""
        self._jit_model = False
        super().__init__(*args, **kwargs)

    def _load_auto_model(self,
@@ -67,6 +106,8 @@ def _load_auto_model(self,
model_args={"token": token, "trust_remote_code": trust_remote_code, "revision": revision},
tokenizer_args={"token": token, "trust_remote_code": trust_remote_code, "revision": revision},
)
if isinstance(transformer_model.auto_model, torch.jit.ScriptModule):
self._jit_model = True
pooling_model = sentence_transformers.models.Pooling(
transformer_model.get_word_embedding_dimension(), 'mean')
return [transformer_model, pooling_model]
@@ -149,6 +190,8 @@ def _load_sbert_model(self,
                    else:
                        kwargs["tokenizer_args"] = hub_kwargs
                    module = OptimizedInstructorTransformer(model_name_or_path, cache_dir=cache_folder, **kwargs)
                    if isinstance(module.auto_model, torch.jit.ScriptModule):
                        self._jit_model = True
                elif module_class == sentence_transformers.models.Pooling:
                    module_class = InstructorEmbedding.INSTRUCTOR_Pooling
                    module_path = sentence_transformers.util.load_dir_path(
@@ -175,3 +218,10 @@ def _load_sbert_model(self,
                modules[module_config['name']] = module

        return modules

    def encode(self, sentences, device=None, *args, **kwargs):
        if self._jit_model and device is None:
            # set default device to 'cpu' for jit model, otherwise may fail when getting device
            return super().encode(sentences, device='cpu', *args, **kwargs)
        else:
            return super().encode(sentences, device=device, *args, **kwargs)
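The `encode` override only kicks in when no device is given: a `torch.jit.ScriptModule` does not carry the attributes normally used to infer a target device, so the wrapper pins it to CPU. A standalone sketch of the same fallback — the helper name is illustrative and not part of the repository:

```python
# Illustrative helper, not from the repository: mirrors the device fallback used by
# the encode() overrides in this commit.
import torch

def pick_encode_device(model, device=None):
    """Return an explicit device, falling back to CPU for TorchScript modules."""
    if device is not None:
        return device
    if isinstance(model, torch.jit.ScriptModule):
        # Traced/scripted modules may not expose the metadata used for device inference.
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"
```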
intel_extension_for_transformers/langchain/embeddings/optimized_sentence_transformers.py
@@ -38,19 +38,44 @@ def __init__(self, *args, **kwargs):

    def _load_model(self, model_name_or_path, config, cache_dir, **model_args):
        """Loads the transformer model"""
        self.auto_model = OptimizedModel.from_pretrained(model_name_or_path,
                                                         config=config,
                                                         cache_dir=cache_dir,
                                                         **model_args)
        if isinstance(self.auto_model, torch.jit.ScriptModule):
            setattr(self.auto_model, "config", config)

    def forward(self, features):
        """Returns token_embeddings, cls_token"""
        trans_features = {"input_ids": features["input_ids"], "attention_mask": features["attention_mask"]}
        if "token_type_ids" in features:
            trans_features["token_type_ids"] = features["token_type_ids"]

        if isinstance(self.auto_model, torch.jit.ScriptModule):
            output_states = self.auto_model(**trans_features)
            if isinstance(output_states, dict):
                output_states = tuple(output_states.values())
            output_tokens = output_states[0]
        else:
            output_states = self.auto_model(**trans_features, return_dict=False)
            output_tokens = output_states[0]

        features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})

        if self.auto_model.config.output_hidden_states:
            all_layer_idx = 2
            if len(output_states) < 3: # Some models only output last_hidden_states and all_hidden_states
                all_layer_idx = 1

            hidden_states = output_states[all_layer_idx]
            features.update({"all_layer_embeddings": hidden_states})

        return features

class OptimizedSentenceTransformer(sentence_transformers.SentenceTransformer):
    def __init__(self, *args, **kwargs):
        """Initialize the OptimizedSentenceTransformer."""
        self._jit_model = False
        super().__init__(*args, **kwargs)

    def _load_auto_model(
@@ -71,6 +96,8 @@ def _load_auto_model(
model_args={"token": token, "trust_remote_code": trust_remote_code, "revision": revision},
tokenizer_args={"token": token, "trust_remote_code": trust_remote_code, "revision": revision},
)
if isinstance(transformer_model.auto_model, torch.jit.ScriptModule):
self._jit_model = True
pooling_model = sentence_transformers.models.Pooling(
transformer_model.get_word_embedding_dimension(), 'mean')
return [transformer_model, pooling_model]
@@ -161,8 +188,10 @@ def _load_sbert_model(
kwargs["tokenizer_args"].update(hub_kwargs)
else:
kwargs["tokenizer_args"] = hub_kwargs
module = sentence_transformers.models.Transformer(
module = OptimzedTransformer(
model_name_or_path, cache_dir=cache_folder, **kwargs)
if isinstance(module.auto_model, torch.jit.ScriptModule):
self._jit_model = True
else:
# Normalize does not require any files to be loaded
if module_class == sentence_transformers.models.Normalize:
@@ -179,3 +208,10 @@ def _load_sbert_model(
            modules[module_config['name']] = module

        return modules

    def encode(self, sentences, device=None, *args, **kwargs):
        if self._jit_model and device is None:
            # set default device to 'cpu' for jit model, otherwise may fail when getting device
            return super().encode(sentences, device='cpu', *args, **kwargs)
        else:
            return super().encode(sentences, device=device, *args, **kwargs)
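For context, a TorchScript encoder of the kind these loaders receive as a `torch.jit.ScriptModule` could be produced roughly as sketched below; the model id, inputs, and file name are placeholders, and the INT8 models exercised by the updated retrieval tests come from the INC/IPEX quantization tooling rather than plain tracing:

```python
# Hedged sketch of exporting a plain (non-quantized) encoder to TorchScript.
import torch
from transformers import AutoModel, AutoTokenizer

model_id = "BAAI/bge-base-en-v1.5"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, torchscript=True).eval()

example = tokenizer("hello world", return_tensors="pt")
traced = torch.jit.trace(
    model, (example["input_ids"], example["attention_mask"]), strict=False
)
traced = torch.jit.freeze(traced.eval())
traced.save("traced_encoder.pt")  # loadable later with torch.jit.load
```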
@@ -102,7 +102,7 @@ def test_build_chatbot_with_retrieval_plugin_bge_int8(self):
plugins.retrieval.args["input_path"] = "../../../README.md"
# Intel/bge-base-en-v1.5-sts-int8-static is private now, so we need to load it from local.
plugins.retrieval.args["embedding_model"] = \
"/tf_dataset2/inc-ut/bge-base-en-v1.5-sts-int8-static"
"/tf_dataset2/inc-ut/embedding_models/itrex-int8/bge-base-en-v1.5-int8-static"
        pipeline_config = PipelineConfig(model_name_or_path="facebook/opt-125m",
                                         plugins=plugins)
        chatbot = build_chatbot(pipeline_config)
@@ -126,10 +126,26 @@ def _run_retrieval(local_dir):
            self.assertIsNotNone(response)
            plugins.retrieval.enable = False

        # test fp32 model
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/embedding_models/fp32/paraphrase-multilingual-mpnet-base-v2")
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/embedding_models/fp32/bge-base-en-v1.5")
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/embedding_models/fp32/instructor-base")

        # test itrex optimized model
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/embedding_models/itrex-int8/paraphrase-multilingual-mpnet-base-v2-int8-static")
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/embedding_models/itrex-int8/bge-base-en-v1.5-int8-static")
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/embedding_models/itrex-int8/instructor-base-int8-static")

        # test itrex optimized model in sentence-transformers format
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/embedding_models/sentence-transformers-int8/paraphrase-multilingual-mpnet-base-v2-int8-static-st")
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/embedding_models/sentence-transformers-int8/bge-base-en-v1.5-int8-static-st")
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/embedding_models/sentence-transformers-int8/instructor-base-int8-static-st")

        # test ipex optimized model
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/embedding_models/ipex-int8/paraphrase-multilingual-mpnet-base-v2-int8-static-ipex")
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/embedding_models/ipex-int8/bge-base-en-v1.5-int8-static-ipex")
        _run_retrieval(local_dir="/tf_dataset2/inc-ut/embedding_models/ipex-int8/instructor-base-int8-static-ipex")


if __name__ == "__main__":
    unittest.main()
@@ -88,6 +88,9 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs):
            import intel_extension_for_pytorch # pylint: disable=E0401
            logger.info("the INC IPEX quantization optimized model is loading.")
            weight_file = os.path.join(model_name_or_path, WEIGHTS_NAME)
            if not os.path.exists(weight_file):
                from huggingface_hub import hf_hub_download
                weight_file = hf_hub_download(model_name_or_path, filename=WEIGHTS_NAME)
            q_model = torch.jit.load(weight_file)
            q_model = torch.jit.freeze(q_model.eval())
            return q_model
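With the added fallback, `from_pretrained` can fetch the TorchScript weight file from the Hugging Face Hub when `model_name_or_path` is not a local directory. A hedged usage sketch — the repo id is a placeholder and would need to point at an INC/IPEX-quantized TorchScript model:

```python
# Hedged usage sketch, not from the commit: repo id is a placeholder for an
# INC/IPEX-quantized TorchScript model hosted on the Hugging Face Hub.
from intel_extension_for_transformers.transformers import OptimizedModel

q_model = OptimizedModel.from_pretrained("some-org/bge-base-en-v1.5-int8-static-ipex")
# q_model is a frozen torch.jit.ScriptModule, usable by the embedding wrappers above.
```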
