Support neural-chat-7b model for chatbot (#112)
* Support neural-chat-7b model

Signed-off-by: Lv, Liang1 <liang1.lv@intel.com>

* support jit trace

Signed-off-by: Lv, Liang1 <liang1.lv@intel.com>

* support neural-chat-7b for finetuning

Signed-off-by: Lv, Liang1 <liang1.lv@intel.com>

* support neural-chat-7b-v2

Signed-off-by: Lv, Liang1 <liang1.lv@intel.com>

* update finetune

Signed-off-by: Lv, Liang1 <liang1.lv@intel.com>

* update code for latency data show

Signed-off-by: Lv, Liang1 <liang1.lv@intel.com>

---------

Signed-off-by: Lv, Liang1 <liang1.lv@intel.com>
lvliang-intel authored Aug 11, 2023
1 parent 8fa0dc7 commit 126d07b
Showing 2 changed files with 20 additions and 15 deletions.
@@ -478,7 +478,8 @@ def main():
     # Load model
     if model_args.model_name_or_path:
         model_dtype = torch.bfloat16 if training_args.bf16 else None
-        if re.search("mpt", model_args.model_name_or_path, re.IGNORECASE):
+        if (re.search("mpt", model_args.model_name_or_path, re.IGNORECASE) or
+                re.search("neural-chat-7b-v1", model_args.model_name_or_path, re.IGNORECASE)):
             from models.mpt.modeling_mpt import MPTForCausalLM

             model = MPTForCausalLM.from_pretrained(
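This first hunk belongs to the fine-tuning script (the first of the two changed files): any checkpoint whose name matches neural-chat-7b-v1 is now routed through the repository's custom MPTForCausalLM class, the same path already used for MPT checkpoints, while v2 is handled by the stock Hugging Face classes in the inference changes below. A minimal sketch of that name-based dispatch, with a hypothetical pick_model_class helper and placeholder model names:

    import re

    def pick_model_class(model_name_or_path: str) -> str:
        # Hypothetical helper mirroring the dispatch in the hunk above:
        # MPT and neural-chat-7b-v1 share the repo's custom MPT modeling code,
        # everything else would fall back to the generic Hugging Face classes.
        if re.search(r"mpt|neural-chat-7b-v1", model_name_or_path, re.IGNORECASE):
            return "models.mpt.modeling_mpt.MPTForCausalLM"
        return "transformers.AutoModelForCausalLM"

    print(pick_model_class("neural-chat-7b-v1"))  # -> custom MPT class
    print(pick_model_class("neural-chat-7b-v2"))  # -> generic auto class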
32 changes: 18 additions & 14 deletions workflows/chatbot/inference/generate.py
@@ -338,14 +338,16 @@ def load_model(
         MODELS[model_name] = {}
     tokenizer = AutoTokenizer.from_pretrained(
         tokenizer_name,
-        use_fast=False if re.search("llama", model_name, re.IGNORECASE) else True,
+        use_fast=False if (re.search("llama", model_name, re.IGNORECASE)
+            or re.search("neural-chat-7b-v2", model_name, re.IGNORECASE)) else True,
     )
     if re.search("flan-t5", model_name, re.IGNORECASE):
         with smart_context_manager(use_deepspeed=use_deepspeed):
             model = AutoModelForSeq2SeqLM.from_pretrained(
                 model_name, low_cpu_mem_usage=True
             )
-    elif re.search("mpt", model_name, re.IGNORECASE):
+    elif (re.search("mpt", model_name, re.IGNORECASE)
+          or re.search("neural-chat-7b-v1", model_name, re.IGNORECASE)):
         from models.mpt.modeling_mpt import MPTForCausalLM

         with smart_context_manager(use_deepspeed=use_deepspeed):
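The tokenizer change disables the fast (Rust) tokenizer for neural-chat-7b-v2 exactly as it already does for LLaMA checkpoints, so both fall back to the slow SentencePiece implementation. A rough sketch of deriving and passing that flag, assuming a placeholder model id:

    import re
    from transformers import AutoTokenizer

    def use_fast_tokenizer(model_name: str) -> bool:
        # LLaMA-style names (neural-chat-7b-v2 among them) get the slow tokenizer.
        return not re.search(r"llama|neural-chat-7b-v2", model_name, re.IGNORECASE)

    model_name = "Intel/neural-chat-7b-v2"  # placeholder id for illustration
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_fast=use_fast_tokenizer(model_name)
    )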
@@ -361,14 +363,15 @@ def load_model(
         or re.search("bloom", model_name, re.IGNORECASE)
         or re.search("llama", model_name, re.IGNORECASE)
         or re.search("opt", model_name, re.IGNORECASE)
+        or re.search("neural-chat-7b-v2", model_name, re.IGNORECASE)
     ):
         with smart_context_manager(use_deepspeed=use_deepspeed):
             model = AutoModelForCausalLM.from_pretrained(
                 model_name, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
             )
     else:
         raise ValueError(
-            f"Unsupported model {model_name}, only supports FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT now."
+            f"Unsupported model {model_name}, only supports FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/NEURAL-CHAT now."
         )

     if re.search("llama", model.config.architectures[0], re.IGNORECASE):
@@ -437,7 +440,8 @@ def load_model(
             level="O1",
             auto_kernel_selection=True,
         )
-    if cpu_jit and re.search("mpt-7b", model_name, re.IGNORECASE):
+    if cpu_jit and (re.search("mpt-7b", model_name, re.IGNORECASE)
+                    or re.search("neural-chat-7b-v1", model_name, re.IGNORECASE)):
         from models.mpt.mpt_trace import jit_trace_mpt_7b, MPTTSModelForCausalLM

         model = jit_trace_mpt_7b(model)
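jit_trace_mpt_7b and MPTTSModelForCausalLM are repo-specific helpers; the hunk only widens the condition so that neural-chat-7b-v1 (served through the MPT path) is also traced when cpu_jit is requested. The underlying idea is ordinary TorchScript tracing: run the model once on example inputs, record the graph, and reuse the frozen graph afterwards. A toy, self-contained illustration of that pattern, not the actual helper:

    import torch
    import torch.nn as nn

    class TinyLM(nn.Module):
        # Stand-in model so the tracing pattern is runnable on its own.
        def __init__(self, vocab=32, hidden=16):
            super().__init__()
            self.embed = nn.Embedding(vocab, hidden)
            self.head = nn.Linear(hidden, vocab)

        def forward(self, input_ids):
            return self.head(self.embed(input_ids))

    model = TinyLM().eval()
    example = torch.randint(0, 32, (1, 8))
    with torch.no_grad():
        traced = torch.jit.trace(model, example)  # record the graph once
        traced = torch.jit.freeze(traced)         # fold weights/constants
    print(traced(example).shape)                  # torch.Size([1, 8, 32])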
@@ -640,32 +644,32 @@ def generate_output():
             raise ValueError(
                 f"Unsupported device type {device}, only supports cpu and hpu now."
             )
-        output_token_len = 0
+        output_word_len = 0

         for new_text in streamer:
             if len(new_text) == 0:
                 continue
-            if output_token_len == 0:
+            if output_word_len == 0:
                 first_token_output_time = datetime.now()
-            output_token_len += 1
+            output_word_len += 1
             yield new_text

         end_time = datetime.now()
         duration = int((end_time - start_time).total_seconds() * 1000)
-        first_token_latency = int(
+        first_word_latency = int(
             (first_token_output_time - start_time).total_seconds() * 1000
         )
-        token_per_second = (
-            (duration - first_token_latency) / (output_token_len - 1)
-            if output_token_len != 1
+        msecond_per_word = (
+            (duration - first_word_latency) / (output_word_len - 1)
+            if output_word_len != 1
             else 0
         )
         stats = {
             "input_token_len": input_token_len,
-            "output_token_len": output_token_len,
+            "output_word_len": output_word_len,
             "duration": duration,
-            "first_token_latency": first_token_latency,
-            "token_per_second": token_per_second,
+            "first_word_latency": first_word_latency,
+            "msecond_per_word": msecond_per_word,
         }
         yield "END_OF_STREAM_STATS={}".format(stats)
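The last hunk renames the streaming metrics from token- to word-based names, presumably because the streamer yields decoded text chunks (roughly words) rather than raw token ids; the new msecond_per_word also matches the formula better than the old token_per_second did, since the value is milliseconds per unit rather than units per second. A small self-contained sketch of the same measurement over a fake streamer with made-up timings:

    import time
    from datetime import datetime

    def fake_streamer():
        # Stand-in for the real text streamer; sleeps mimic generation time.
        for word in ["Hello", " there", "!"]:
            time.sleep(0.05)
            yield word

    start_time = datetime.now()
    output_word_len = 0
    first_word_output_time = start_time

    for new_text in fake_streamer():
        if output_word_len == 0:
            first_word_output_time = datetime.now()
        output_word_len += 1

    duration = int((datetime.now() - start_time).total_seconds() * 1000)
    first_word_latency = int(
        (first_word_output_time - start_time).total_seconds() * 1000
    )
    msecond_per_word = (
        (duration - first_word_latency) / (output_word_len - 1)
        if output_word_len != 1 else 0
    )
    print({"output_word_len": output_word_len,
           "first_word_latency": first_word_latency,
           "msecond_per_word": msecond_per_word})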
