Fix starcoder issues for IPEX int8 and Weight Only int4 (#508)
* Fix starcoder issues for IPEX int8 and Weight Only int4

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
lvliang-intel committed Oct 23, 2023
1 parent 8f41d49 commit e88c7b6
Showing 4 changed files with 13 additions and 3 deletions.
@@ -67,6 +67,7 @@ def __init__(self):
         self.cache = None
         self.device = None
         self.conv_template = None
+        self.ipex_int8 = None
 
     def match(self, model_path: str):
         """
@@ -106,6 +107,7 @@ def load_model(self, kwargs: dict):
         self.use_hpu_graphs = kwargs["use_hpu_graphs"]
         self.cpu_jit = kwargs["cpu_jit"]
         self.use_cache = kwargs["use_cache"]
+        self.ipex_int8 = kwargs["ipex_int8"]
         load_model(model_name=kwargs["model_name"],
                    tokenizer_name=kwargs["tokenizer_name"],
                    device=kwargs["device"],
@@ -133,14 +135,16 @@ def predict_stream(self, query, config=None):
         config.use_hpu_graphs = self.use_hpu_graphs
         config.cpu_jit = self.cpu_jit
         config.use_cache = self.use_cache
+        config.ipex_int8 = self.ipex_int8
 
         if is_audio_file(query):
             if not os.path.exists(query):
                 raise ValueError(f"The audio file path {query} is invalid.")
 
         query_include_prompt = False
         self.get_conv_template(self.model_name, config.task)
-        if self.conv_template.roles[0] in query and self.conv_template.roles[1] in query:
+        if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
+            "starcoder" in self.model_name:
             query_include_prompt = True
 
         # plugin pre actions
@@ -205,14 +209,16 @@ def predict(self, query, config=None):
         config.use_hpu_graphs = self.use_hpu_graphs
         config.cpu_jit = self.cpu_jit
         config.use_cache = self.use_cache
+        config.ipex_int8 = self.ipex_int8
 
         if is_audio_file(query):
             if not os.path.exists(query):
                 raise ValueError(f"The audio file path {query} is invalid.")
 
         query_include_prompt = False
         self.get_conv_template(self.model_name, config.task)
-        if self.conv_template.roles[0] in query and self.conv_template.roles[1] in query:
+        if (self.conv_template.roles[0] in query and self.conv_template.roles[1] in query) or \
+            "starcoder" in self.model_name:
             query_include_prompt = True
 
         # plugin pre actions
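Taken together, the adapter changes thread one new flag through the request path: load_model now reads "ipex_int8" from its kwargs, the adapter stores it, and both predict_stream and predict copy it onto the per-request config. The same two methods also treat any model whose name contains "starcoder" as already carrying its full prompt, since StarCoder is a plain code-completion model with no chat roles to match against the conversation template. A minimal runnable sketch of that plumbing (class and helper names here are hypothetical, not the repository's):

    from types import SimpleNamespace

    class ChatAdapterSketch:
        """Hypothetical stand-in for the adapter patched above."""

        def __init__(self):
            self.model_name = "bigcode/starcoder"
            self.ipex_int8 = None

        def load_model(self, kwargs: dict):
            # Callers must now supply "ipex_int8"; a missing key raises
            # KeyError, matching the kwargs["ipex_int8"] access in the diff.
            self.ipex_int8 = kwargs["ipex_int8"]

        def predict(self, query: str, config: SimpleNamespace) -> str:
            # Forward the flag per request, as both predict paths now do.
            config.ipex_int8 = self.ipex_int8
            # StarCoder exposes no chat roles, so the raw query is treated
            # as already containing the complete prompt.
            if "starcoder" in self.model_name:
                return query
            return f"[conv-template]{query}[/conv-template]"

    adapter = ChatAdapterSketch()
    adapter.load_model({"ipex_int8": True})
    print(adapter.predict("def fibonacci(n):", SimpleNamespace()))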
intel_extension_for_transformers/transformers/__init__.py (1 addition, 0 deletions)
@@ -54,4 +54,5 @@
     AutoModel,
     AutoModelForSeq2SeqLM,
     OptimizedModel,
+    GPTBigCodeForCausalLM
 )
@@ -20,4 +20,4 @@
 
 from .model import OptimizedModel
 from .modeling_auto import (AutoModel, AutoModelForCausalLM,
-                            AutoModelForSeq2SeqLM)
+                            AutoModelForSeq2SeqLM, GPTBigCodeForCausalLM)
@@ -274,3 +274,6 @@ class AutoModel(_BaseQBitsAutoModelClass):
 
 class AutoModelForSeq2SeqLM(_BaseQBitsAutoModelClass):
     ORIG_MODEL = transformers.AutoModelForSeq2SeqLM
+
+class GPTBigCodeForCausalLM(_BaseQBitsAutoModelClass):
+    ORIG_MODEL = transformers.GPTBigCodeForCausalLM
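With the class exported at both package levels, StarCoder (GPTBigCode) checkpoints load through the same quantization-aware from_pretrained path as the other wrapped auto classes. A hypothetical usage sketch for the weight-only int4 case this commit targets, assuming the WeightOnlyQuantConfig helper shipped in the same package (exact field names may vary by release):

    from intel_extension_for_transformers.transformers import (
        GPTBigCodeForCausalLM,
        WeightOnlyQuantConfig,  # assumed export; verify against your release
    )

    # Weight-only int4 settings; the field name and "int4" value follow the
    # package's documented examples and may differ across versions.
    quant_config = WeightOnlyQuantConfig(weight_dtype="int4")

    model = GPTBigCodeForCausalLM.from_pretrained(
        "bigcode/starcoder",
        quantization_config=quant_config,
    )

Because GPTBigCodeForCausalLM subclasses _BaseQBitsAutoModelClass with ORIG_MODEL set to transformers.GPTBigCodeForCausalLM, the wrapper can fall back to the stock Transformers class when no quantization is requested.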
