fix the perf of llama. llama support static_shape in optimum habana (#…
sywangyi committed Jul 27, 2023
1 parent 71b5ac9 commit 481f389
Showing 1 changed file with 61 additions and 36 deletions.
workflows/chatbot/inference/generate.py (97 changes: 61 additions, 36 deletions)
@@ -175,6 +175,7 @@ def max_input_len(model, outlen=0):
# need to adjust due to perf and real usage
return 128


def create_prompts(examples):
prompts = []
for example in examples:
@@ -200,6 +201,14 @@ def get_optimized_model_name(config):
return None


+ def model_is_optimized(config):
+     """
+     Checks if the given config belongs to a model in optimum/habana/transformers/models, which has a
+     new input token_idx.
+     """
+     return get_optimized_model_name(config) is not None or config.model_type == "mpt"


def get_ds_injection_policy(config):
model_type = get_optimized_model_name(config)
policy = {}
@@ -328,11 +337,6 @@ def load_model(
if model.generation_config.eos_token_id is None:
model.generation_config.eos_token_id = tokenizer.eos_token_id

- if peft_path:
-     from peft import PeftModel
-
-     model = PeftModel.from_pretrained(model, peft_path)
-     model = model.to(torch.bfloat16)

if device == "hpu":
model = model.eval().to("hpu")
@@ -341,7 +345,18 @@
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

model = wrap_in_hpu_graph(model)

+ if peft_path:
+     from peft import PeftModel
+     model = PeftModel.from_pretrained(model, peft_path)
+     model = model.to(torch.bfloat16)
else:

+ if peft_path:
+     from peft import PeftModel
+     model = PeftModel.from_pretrained(model, peft_path)
+     model = model.to(torch.bfloat16)

import intel_extension_for_pytorch as intel_ipex

model = intel_ipex.optimize(
@@ -363,7 +378,7 @@ def load_model(
if not model.config.is_encoder_decoder:
tokenizer.padding_side = "left"

- if tokenizer.pad_token is None:
+ if tokenizer.pad_token is None and tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = model.generation_config.eos_token_id

@@ -441,8 +456,14 @@ def predict_stream(**params):
[prompt], return_tensors="pt", padding=True
)
input_token_len = input_tokens.input_ids.shape[-1]
- stop_token_ids = [model.generation_config.eos_token_id]
- stop_token_ids = stop_token_ids + list(torch.flatten(tokenizer(".", return_tensors="pt").input_ids))
+ if isinstance(model.generation_config.eos_token_id, list):
+     stop_token_ids = copy.deepcopy(model.generation_config.eos_token_id)
+ else:
+     stop_token_ids = [model.generation_config.eos_token_id]
+ end_token_id = torch.flatten(tokenizer("go.", return_tensors="pt").input_ids)[-1]
+ stop_token_ids.append(end_token_id)
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
@@ -486,8 +507,14 @@ def generate_output():
max_length=max_input_len(model, max_new_tokens),
)
input_token_len = input_tokens.input_ids.shape[-1]
- stop_token_ids = [model.generation_config.eos_token_id]
- stop_token_ids = stop_token_ids + list(torch.flatten(tokenizer(".", return_tensors="pt").input_ids))
+ if isinstance(model.generation_config.eos_token_id, list):
+     stop_token_ids = copy.deepcopy(model.generation_config.eos_token_id)
+ else:
+     stop_token_ids = [model.generation_config.eos_token_id]
+ end_token_id = torch.flatten(tokenizer("go.", return_tensors="pt").input_ids)[-1]
+ stop_token_ids.append(end_token_id)
generate_kwargs = {
"stopping_criteria": StoppingCriteriaList(
[
@@ -499,14 +526,6 @@ def generate_output():
]
)
}
- is_graph_optimized = False
- if (
-     re.search("gpt", model_name, re.IGNORECASE)
-     or re.search("bloom", model_name, re.IGNORECASE)
-     or re.search("mpt", model_name, re.IGNORECASE)
-     or re.search("opt", model_name, re.IGNORECASE)
- ):
-     is_graph_optimized = True
# Move inputs to target device(s)
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
@@ -521,7 +540,7 @@ def generate_output():
generation_config.bad_words_ids = bad_words_ids
generation_config.force_words_ids = force_words_ids
generation_config.num_return_sequences = num_return_sequences
- generation_config.static_shapes = is_graph_optimized
+ generation_config.static_shapes = model_is_optimized(model.config)
generation_config.top_k = top_k
# TODO there is an issue when top_p is used in Habana
# generation_config.top_p = top_p
@@ -540,7 +559,7 @@ def generate_output():
max_new_tokens=max_new_tokens,
lazy_mode=True,
hpu_graphs=use_hpu_graphs,
- ignore_eos = False,
+ ignore_eos=False,
)

generation_thread = Thread(target=generate_output)
@@ -620,8 +639,14 @@ def predict(**params):
[prompt], return_tensors="pt", padding=True
)
input_token_len = input_tokens.input_ids.shape[-1]
- stop_token_ids = [model.generation_config.eos_token_id]
- stop_token_ids = stop_token_ids + list(torch.flatten(tokenizer(".", return_tensors="pt").input_ids))
+ if isinstance(model.generation_config.eos_token_id, list):
+     stop_token_ids = copy.deepcopy(model.generation_config.eos_token_id)
+ else:
+     stop_token_ids = [model.generation_config.eos_token_id]
+ end_token_id = torch.flatten(tokenizer("go.", return_tensors="pt").input_ids)[-1]
+ stop_token_ids.append(end_token_id)
generation_config = GenerationConfig(
temperature=temperature,
top_p=top_p,
@@ -659,8 +684,14 @@ def predict(**params):
max_length=max_input_len(model, max_new_tokens),
)
input_token_len = input_tokens.input_ids.shape[-1]
- stop_token_ids = [model.generation_config.eos_token_id]
- stop_token_ids = stop_token_ids + list(torch.flatten(tokenizer(".", return_tensors="pt").input_ids))
+ if isinstance(model.generation_config.eos_token_id, list):
+     stop_token_ids = copy.deepcopy(model.generation_config.eos_token_id)
+ else:
+     stop_token_ids = [model.generation_config.eos_token_id]
+ end_token_id = torch.flatten(tokenizer("go.", return_tensors="pt").input_ids)[-1]
+ stop_token_ids.append(end_token_id)
generate_kwargs = {
"stopping_criteria": StoppingCriteriaList(
[
@@ -672,14 +703,6 @@ def predict(**params):
]
)
}
- is_graph_optimized = False
- if (
-     re.search("gpt", model_name, re.IGNORECASE)
-     or re.search("bloom", model_name, re.IGNORECASE)
-     or re.search("mpt", model_name, re.IGNORECASE)
-     or re.search("opt", model_name, re.IGNORECASE)
- ):
-     is_graph_optimized = True
# Move inputs to target device(s)
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
@@ -694,7 +717,7 @@ def predict(**params):
generation_config.bad_words_ids = bad_words_ids
generation_config.force_words_ids = force_words_ids
generation_config.num_return_sequences = num_return_sequences
- generation_config.static_shapes = is_graph_optimized
+ generation_config.static_shapes = model_is_optimized(model.config)
generation_config.top_k = top_k
# TODO there is an issue when top_p is used in Habana
# generation_config.top_p = top_p
@@ -711,7 +734,7 @@ def predict(**params):
max_new_tokens=max_new_tokens,
lazy_mode=True,
use_hpu_graphs=use_hpu_graphs,
- ignore_eos = False,
+ ignore_eos=False,
)
output = tokenizer.decode(generation_output.sequences[0], skip_special_tokens=True)
if "### Response:" in output:
@@ -829,7 +852,7 @@ def main():
use_hpu_graphs=args.use_hpu_graphs,
cpu_jit=args.jit,
use_cache=args.use_kv_cache,
- peft_path=args.peft_model_path
+ peft_path=args.peft_model_path,
)

if args.habana and rank in [-1, 0]:
@@ -895,7 +918,9 @@ def main():
token_len = token_len + 1
if args.local_rank in [-1, 0]:
duration = time.time() - first_time_stamp
- logger.info(f"duration: {time.time() - start_time}, msecond_per_token = {duration*1000/(token_len-1)}")
+ logger.info(
+     f"duration: {time.time() - start_time}, msecond_per_token = {duration*1000/(token_len-1)}"
+ )
logger.info("=" * (60 + len(idxs)))

for idx, tp in enumerate(zip(prompts, args.instructions)):
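The substantive change behind the commit title is how static_shapes gets enabled: instead of turning it on only for model names matching gpt/bloom/mpt/opt, the code now asks model_is_optimized(model.config), which also covers llama because optimum-habana ships an optimized llama implementation. The sketch below contrasts the two decisions; FakeConfig and the hard-coded optimized_model_types tuple are illustrative stand-ins for the real transformers config and get_optimized_model_name(), not the actual objects.

# Minimal sketch of the static-shapes decision before and after this commit.
import re
from dataclasses import dataclass


@dataclass
class FakeConfig:
    model_type: str


def static_shapes_before(model_name: str) -> bool:
    # Old behaviour: keyed off the model name, so llama never matched and
    # static shapes stayed disabled for it.
    return bool(re.search(r"gpt|bloom|mpt|opt", model_name, re.IGNORECASE))


def static_shapes_after(config, optimized_model_types=("gpt2", "bloom", "opt", "llama")) -> bool:
    # New behaviour, mirroring model_is_optimized(): enable static shapes for
    # any model type with an optimum-habana optimized implementation, plus mpt.
    # The tuple above is an illustrative stand-in, not the library's real list.
    return config.model_type in optimized_model_types or config.model_type == "mpt"


print(static_shapes_before("meta-llama/Llama-2-7b-hf"))     # False
print(static_shapes_after(FakeConfig(model_type="llama")))  # True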

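The other repeated change is the construction of stop_token_ids for the stopping criteria: the new code copies the model's eos_token_id (which may already be a list) and appends only the last token id of tokenizer("go."), i.e. the id of a period attached to a word, instead of every id produced by tokenizing a bare ".". A standalone sketch of that logic, assuming a Hugging Face tokenizer that returns PyTorch tensors and a generation config exposing eos_token_id:

# Sketch of the new stop-token construction used in predict()/predict_stream().
import copy

import torch


def build_stop_token_ids(tokenizer, generation_config):
    # eos_token_id may be a single id or a list of ids.
    if isinstance(generation_config.eos_token_id, list):
        stop_token_ids = copy.deepcopy(generation_config.eos_token_id)
    else:
        stop_token_ids = [generation_config.eos_token_id]
    # Tokenize "go." and keep only the last id: with sentencepiece-style
    # tokenizers a bare "." encodes differently from a "." attached to a word,
    # and it is the attached form that actually appears in generated text.
    end_token_id = torch.flatten(tokenizer("go.", return_tensors="pt").input_ids)[-1]
    stop_token_ids.append(end_token_id)
    return stop_token_ids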