Added the files for WoQ of codegen25 using IPEX #3024

Open · wants to merge 4 commits into base: master
45 changes: 45 additions & 0 deletions examples/large_models/codegen25_ipex/ipex_woq/README.md
@@ -0,0 +1,45 @@
This example demonstrates how to run a code generation model, e.g. Salesforce/codegen25-7b-multi, with TorchServe. We use IPEX Weight Only Quantization (WoQ) to convert the model to INT8.

To set up the conda environment for IPEX WoQ, see the documentation here:

https://github.com/intel/intel-extension-for-pytorch/tree/main/examples/cpu/inference/python/llm
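
For reference, the handler in this example (`codegen_handler.py`) applies the quantization at model-load time. The snippet below is a condensed sketch of that startup path using the defaults from `model-config.yaml`; it is illustrative only and not a standalone replacement for the handler.

```
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "Salesforce/codegen25-7b-multi"
config = AutoConfig.from_pretrained(model_name, torchscript=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, config=config, low_cpu_mem_usage=True, torch_dtype=torch.float
)

# INT8 weight-only quantization recipe (mirrors the defaults in model-config.yaml)
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
    weight_dtype=torch.qint8,
    lowp_mode=ipex.quantization.WoqLowpMode.BF16,
    act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
    group_size=-1,
)
model = ipex.llm.optimize(
    model.eval(),
    dtype=torch.bfloat16,
    quantization_config=qconfig,
    inplace=True,
    deployment_mode=False,
)
```

The handler then traces the optimized model with `torch.jit.trace`, freezes it, saves it to `quantized_model_path`, and reloads the TorchScript artifact for serving.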


1. Package everything into a model archive using torch-model-archiver
```
torch-model-archiver --model-name codegen25 --version 1.0 --handler codegen_handler.py --config-file model-config.yaml
```

2. Move archive to model_store
```
mkdir model_store
mv codegen25.mar ./model_store
```
3. Start TorchServe
```
torchserve --ncs --start --model-store model_store
```

4. From the client, register the model and set the batching parameters. Setting max_batch_size and max_batch_delay in config.properties did not take effect, so they are passed through the management API instead.
```
curl -X POST "localhost:8081/models?url=codegen25.mar&batch_size=4&max_batch_delay=500"
```

5. Check the model status
```
curl http://localhost:8081/models/codegen25
```

6. Send an inference request
```
curl http://localhost:8080/predictions/codegen25 -T ./sample_text_0.txt
```

7. Benchmark batched requests
```
bash benchmark.sh <batch_size>
# e.g., bash benchmark.sh 4
```
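The script sends `batch_size` concurrent requests, repeats this 10 times, and reports the average end-to-end latency per batch in milliseconds (see `benchmark.sh`).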



13 changes: 13 additions & 0 deletions examples/large_models/codegen25_ipex/ipex_woq/benchmark.sh
@@ -0,0 +1,13 @@
#!/bin/bash
# Usage: bash benchmark.sh <batch_size>
batch_size=$1

start_time=$(date +%s%N)
# Run 10 batches; each batch sends <batch_size> concurrent requests and waits for all of them.
for _ in $(seq 1 10); do
    for i in $(seq 1 "$batch_size"); do
        curl http://localhost:8080/predictions/codegen25 -T ./sample_text_0.txt -o output_${i}.txt &
    done
    wait
done
end_time=$(date +%s%N)
# Convert nanoseconds to milliseconds (/1e6) and average over the 10 iterations (/10).
elapsed=$(( (end_time - start_time) / 10000000 ))
echo "Average e2e runtime per batch for batch size = ${batch_size} is ${elapsed} ms"
314 changes: 314 additions & 0 deletions examples/large_models/codegen25_ipex/ipex_woq/codegen_handler.py
@@ -0,0 +1,314 @@
import os
import logging
from abc import ABC
from pathlib import Path
import subprocess

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler
import intel_extension_for_pytorch as ipex

# Supported example-input orderings for tracing; only the key names are
# referenced (see get_example_inputs() in CodeGenHandler below).
EXAMPLE_INPUTS_MODE = {
"MASK_KV": 1,
"KV_MASK": 2,
"MASK_POS_KV": 3,
"MASK_KV_POS": 4,
"MASK_KV_ENC": 5,
}


logger = logging.getLogger(__name__)
logger.info("PyTorch version %s", torch.__version__)
logger.info("IPEX version %s", ipex.__version__)
logger.info("Transformers version %s", transformers.__version__)

class CodeGenHandler(BaseHandler, ABC):

def __init__(self):
super(CodeGenHandler, self).__init__()

def initialize(self, ctx: Context):
model_name = ctx.model_yaml_config["handler"]["model_name"]
        # path to the quantized model; when quantizing on the fly, the traced model is saved to this path
self.quantized_model_path = ctx.model_yaml_config["handler"]["quantized_model_path"]
self.example_inputs_mode = ctx.model_yaml_config["handler"]["example_inputs_mode"]
self.to_channels_last = ctx.model_yaml_config["handler"]["to_channels_last"]

# generation params
self.batch_size = int(ctx.model_yaml_config["handler"]["batch_size"])
self.max_length = int(ctx.model_yaml_config["handler"]["max_length"])

        # when benchmarking, both the min and max number of new tokens are set to this value for exact measurement
self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])

        # optimization params: only weight-only quantization (WoQ) is supported for now; SmoothQuant and other approaches would need additional support
self.ipex_weight_only_quantization = ctx.model_yaml_config["handler"]["ipex_weight_only_quantization"]
self.woq_dtype = ctx.model_yaml_config["handler"]["woq_dtype"]
self.lowp_mode = ctx.model_yaml_config["handler"]["lowp_mode"]
self.amp_enabled = ctx.model_yaml_config["handler"]["amp_enabled"]
self.act_quant_mode = ctx.model_yaml_config["handler"]["act_quant_mode"] # This is only relevant for INT4x2 quantization
self.group_size = ctx.model_yaml_config["handler"]["group_size"]

# decoding parameters
self.greedy = ctx.model_yaml_config["handler"]["greedy"]
logger.info(f"Max length of the sequence context is {self.max_length}")

try:
ipex._C.disable_jit_linear_repack()
torch._C._jit_set_texpr_fuser_enabled(False)
except Exception:
pass

# amp datatype
if self.amp_enabled:
self.amp_dtype = torch.bfloat16
else:
self.amp_dtype = torch.float32

        # generation args: greedy search for now
        self.num_beams = 1 if self.greedy else 4
        # do not force a minimum number of new tokens in demo mode; it is only intended for
        # benchmark mode, where min == max pins the output length.
        # Note: temperature is ignored when do_sample=False.
        self.generate_kwargs = dict(
            do_sample=False,
            temperature=0.9,
            num_beams=self.num_beams,
            max_new_tokens=self.max_new_tokens,
            min_new_tokens=self.max_new_tokens,
        )

# device
device = torch.device("cpu")

# model config
config = AutoConfig.from_pretrained(model_name, torchscript=True, trust_remote_code=True)

# set up max context
if not hasattr(config, "text_max_length"):
config.text_max_length = int(self.max_length)

# load model and tokenizer
self.user_model = AutoModelForCausalLM.from_pretrained(model_name, config=config, low_cpu_mem_usage=True, torch_dtype=torch.float)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code = True)
logger.info("Data type of the model: %s", self.user_model.dtype)

if self.to_channels_last:
self.user_model = self.user_model.to(memory_format=torch.channels_last)
self.user_model.eval()


        # dummy past key/value cache used to build example inputs for tracing
        self.beam_idx_tmp = torch.zeros(
            (2048, int(self.batch_size * self.num_beams)), dtype=torch.long
        ).contiguous()

        def _get_target_nums(names):
            for n in names:
                if hasattr(self.user_model.config, n):
                    return getattr(self.user_model.config, n)
            logger.error(f"None of the attributes {names} were found in the model config")
            exit(1)

        num_heads_names = ["num_attention_heads", "n_head", "num_heads", "n_heads"]
        num_layers_names = ["num_hidden_layers", "n_layer", "num_layers", "n_layers"]
        hidden_size_names = ["hidden_size", "n_embd"]
        self.n_heads = _get_target_nums(num_heads_names)
        self.n_layers = _get_target_nums(num_layers_names)
        hidden_size = _get_target_nums(hidden_size_names)
        self.head_dim = int(hidden_size / self.n_heads)
        self.global_past_key_value = [
            (
                torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
                torch.zeros([1, self.n_heads, 1, self.head_dim]).contiguous(),
                torch.zeros([1, self.n_heads, 1, self.head_dim]).contiguous(),
                self.beam_idx_tmp,
            )
            for i in range(self.n_layers)
        ]

        logger.info(
            f"num_attention_heads: {self.n_heads}, num_hidden_layers: {self.n_layers}, "
            f"hidden_size: {hidden_size}, head_dim: {self.head_dim}"
        )

        # apply IPEX weight-only quantization (WoQ)
if self.ipex_weight_only_quantization:
weight_dtype = torch.quint4x2 if self.woq_dtype == "INT4" else torch.qint8

if self.lowp_mode == "INT8":
lowp_mode = ipex.quantization.WoqLowpMode.INT8
elif self.lowp_mode == "FP32":
lowp_mode = ipex.quantization.WoqLowpMode.NONE
elif self.lowp_mode == "FP16":
lowp_mode = ipex.quantization.WoqLowpMode.FP16
elif self.lowp_mode == "BF16":
lowp_mode = ipex.quantization.WoqLowpMode.BF16
            else:
                # unrecognized setting: default to BF16 low-precision compute
                lowp_mode = ipex.quantization.WoqLowpMode.BF16

act_quant_mode_dict = {
"PER_TENSOR": ipex.quantization.WoqActQuantMode.PER_TENSOR,
"PER_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_IC_BLOCK,
"PER_BATCH": ipex.quantization.WoqActQuantMode.PER_BATCH,
"PER_BATCH_IC_BLOCK": ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK,
}

qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
weight_dtype=weight_dtype,
lowp_mode=lowp_mode,
act_quant_mode=act_quant_mode_dict[self.act_quant_mode],
group_size=self.group_size,
)

            # a pre-quantized low-precision checkpoint could be loaded here, but we assume there is none and quantize on the fly
low_precision_checkpoint = None
self.user_model = ipex.llm.optimize(
self.user_model.eval(),
dtype=self.amp_dtype,
quantization_config=qconfig,
inplace=True,
low_precision_checkpoint=low_precision_checkpoint,
deployment_mode=False,
)
logger.info("The model conversion completed, now tracing the quantized model")

example_inputs = self.get_example_inputs()

with torch.no_grad(), torch.cpu.amp.autocast(
enabled=self.amp_enabled,
dtype=self.amp_dtype
):
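                # trace and freeze the quantized model so it can be saved and reloaded as a TorchScript artifact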
self_jit = torch.jit.trace(self.user_model.eval(), example_inputs, strict=False, check_trace=False)
self_jit = torch.jit.freeze(self_jit.eval())

self_jit.save(self.quantized_model_path)

logger.info("The IPEX Weight only quantization has been completed successfully")

# set PAD token
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token=self.tokenizer.eos_token

logger.info("Loading the IPEX quantized model")
try:
self_jit = torch.jit.load(self.quantized_model_path)
self_jit = torch.jit.freeze(self_jit.eval())
except Exception as e:
            logger.error("Error: loading the quantized model failed: %s", e)
            exit(1)

setattr(self.user_model, "trace_graph", self_jit)
logger.info("Successfully loaded the Model %s with Intel® Extension for PyTorch*", ctx.model_name)

    # Different models need their inputs supplied in a different order unless a dict is passed.
    # For TorchServe, sending a dict is not always possible, so this function orders the
    # input ids, attention mask, position ids, and KV cache according to the configured mode.
def get_example_inputs(self):
example_inputs = None
input_ids = torch.ones(32).to(torch.long)
attention_mask = torch.ones(len(input_ids))
if self.example_inputs_mode == "MASK_POS_KV":
position_ids = torch.arange(len(input_ids))
example_inputs = (
input_ids.unsqueeze(0),
attention_mask.unsqueeze(0),
position_ids.unsqueeze(0),
tuple(self.global_past_key_value),
)
elif self.example_inputs_mode == "MASK_KV_POS":
position_ids = torch.arange(len(input_ids))
example_inputs = (
input_ids.unsqueeze(0),
attention_mask.unsqueeze(0),
tuple(self.global_past_key_value),
position_ids.unsqueeze(0),
)
elif self.example_inputs_mode == "KV_MASK":
example_inputs = (
input_ids.unsqueeze(0),
tuple(self.global_past_key_value),
attention_mask.unsqueeze(0),
)
elif self.example_inputs_mode == "MASK_KV":
example_inputs = (
input_ids.unsqueeze(0),
attention_mask.unsqueeze(0),
tuple(self.global_past_key_value),
)
elif self.example_inputs_mode == "MASK_KV_ENC":
last_hidden_state = torch.rand([1, 32, 2048])
global_past_key_value = [
(
torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
torch.zeros([1, n_heads, 1, head_dim]).contiguous(),
torch.zeros([1, n_heads, 1, head_dim]).contiguous(),
beam_idx_tmp,
torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
torch.zeros([32, 1, n_heads, head_dim]).contiguous(),
torch.zeros([32, 1, n_heads, head_dim]).contiguous(),
beam_idx_tmp,
)
for i in range(n_layers)
]
example_inputs = (
torch.ones(1).to(torch.long).unsqueeze(0),
attention_mask.unsqueeze(0),
tuple(global_past_key_value),
(last_hidden_state,),
)
else:
raise RuntimeError("Your model does not match existing example inputs used in ipex quantization, exiting...")
#if hasattr(model, "extra_inputs"):
# example_inputs = example_inputs + model.extra_inputs
return example_inputs

def preprocess(self, requests):
input_ids_batch = None
attention_mask_batch = None
for idx, data in enumerate(requests):
input_text = data.get("data")
if input_text is None:
input_text = data.get("body")
if isinstance(input_text, (bytes, bytearray)):
input_text = input_text.decode("utf-8")

with torch.inference_mode(), torch.no_grad(), torch.autocast(
device_type="cpu",
enabled=self.amp_enabled,
dtype=self.amp_dtype
):
inputs = self.tokenizer(
input_text,
pad_to_max_length=True,
add_special_tokens=True,
return_tensors="pt",
#max_length=int(self.max_length),
)

input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
            # build a batch out of the received requests
if input_ids.shape is not None:
if input_ids_batch is None:
input_ids_batch = input_ids
attention_mask_batch = attention_mask
else:
input_ids_batch = torch.cat((input_ids_batch, input_ids), 0)
attention_mask_batch = torch.cat((attention_mask_batch, attention_mask), 0)
return (input_ids_batch, attention_mask_batch)

def inference(self, input_batch):
input_ids_batch, attention_mask_batch = input_batch
inferences = []
# total_list = []

with torch.inference_mode(), torch.no_grad(), torch.autocast(
device_type="cpu",
enabled=self.amp_enabled,
dtype=self.amp_dtype
):
outputs = self.user_model.generate(input_ids_batch, attention_mask=attention_mask_batch, **self.generate_kwargs)
            for output in outputs:
                inferences.append(self.tokenizer.decode(output, skip_special_tokens=True))

return inferences

def postprocess(self, inference_output):
return inference_output
20 changes: 20 additions & 0 deletions examples/large_models/codegen25_ipex/ipex_woq/model-config.yaml
@@ -0,0 +1,20 @@
minWorkers: 1
maxWorkers: 1
responseTimeout: 1500

handler:
model_name: "Salesforce/codegen25-7b-multi"
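  # the traced, quantized TorchScript model is saved to and loaded from this path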
quantized_model_path: "best_model.pt"
example_inputs_mode: "MASK_KV_POS"
to_channels_last: false
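  # generation parameters (max_new_tokens also sets min_new_tokens for exact benchmark measurements)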
batch_size: 1
max_length: 2048
max_new_tokens: 128
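  # IPEX weight-only quantization settings (consumed in codegen_handler.py initialize())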
ipex_weight_only_quantization: true
woq_dtype: "INT8"
lowp_mode: "BF16"
amp_enabled: true
act_quant_mode: "PER_IC_BLOCK"
group_size: -1
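  # greedy: true -> greedy search (num_beams=1); false -> beam search with 4 beams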
greedy: true

@@ -0,0 +1,2 @@
Write a python function to compute the factorial of an integer.