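"""Weight-only quantization (WOQ) text-generation example for Intel GPUs (XPU).

Quantizes a causal LM with RTN, GPTQ, or AutoRound via
intel-extension-for-transformers, then optionally benchmarks generation
latency/throughput and evaluates accuracy on lm-eval tasks.
"""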
import argparse
import re
import time
import json
import torch
from transformers import AutoConfig, AutoTokenizer
from transformers.generation import GenerationConfig
import intel_extension_for_pytorch as ipex
from intel_extension_for_transformers.transformers.llm.utils.generation import _beam_search, _greedy_search
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, AutoRoundConfig, RtnConfig, GPTQConfig
from intel_extension_for_transformers.transformers.llm.quantization.utils import convert_dtype_str2torch
from transformers.utils import check_min_version
import contextlib
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model", nargs="?", default="Qwen/Qwen-7B-Chat", const="Qwen/Qwen-7B-Chat"
)
parser.add_argument("--revision", default=None, type=str)
parser.add_argument("--trust_remote_code", action="store_true")
parser.add_argument(
    "--dataset", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k"
)
parser.add_argument(
    "--max-new-tokens", default=32, type=int, help="maximum number of new tokens to generate"
)
parser.add_argument("--num_beams", default=1, type=int, help="number of beams")
parser.add_argument("--output_dir", nargs="?", default="./saved_results")
# ============Benchmark configs==============
parser.add_argument("--benchmark", action="store_true")
parser.add_argument("--benchmark_batch_size", default=1, type=int,
                    help="benchmark batch size")
parser.add_argument("--do_profiling", action="store_true")
parser.add_argument("--profile_token_latency", action="store_true")
parser.add_argument("--benchmark_iters", default=10, type=int, help="number of benchmark iterations")
parser.add_argument("--num_warmup", default=3, type=int, help="number of warmup iterations")
# ============Accuracy configs==============
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--eval_batch_size", default=56, type=int,
                    help="evaluation batch size")
parser.add_argument("--save_accuracy_path", default=None,
                    help="path to save accuracy results")
parser.add_argument("--tasks", default="lambada_openai", type=str,
                    help="comma-separated list of tasks for accuracy validation")
# ============WeightOnlyQuant configs===============
parser.add_argument("--bits", type=int, default=4, choices=[4])
parser.add_argument("--woq", action="store_true")
parser.add_argument("--woq_algo", default="Rtn", choices=['Rtn', 'GPTQ', 'AutoRound'],
                    help="weight-only quantization algorithm")
parser.add_argument("--weight_dtype", type=str, default="int4",
                    choices=[
                        "int4",  # int4 == int4_fullrange
                        "int4_fullrange",
                    ])
parser.add_argument("--batch_size", default=8, type=int,
                    help="calibration batch size")
parser.add_argument("--group_size", type=int, default=128)
parser.add_argument("--scheme", default="sym")
parser.add_argument("--device", default="xpu")
parser.add_argument("--compute_dtype", default="fp16")
parser.add_argument("--load_in_4bit", type=bool, default=False)  # note: argparse bool() treats any non-empty string as True
parser.add_argument("--load_in_8bit", type=bool, default=False)
# ============GPTQ configs==============
parser.add_argument(
    "--desc_act",
    action="store_true",
    help="whether to apply the activation-order GPTQ heuristic",
)
parser.add_argument(
    "--damp_percent",
    type=float,
    default=0.01,
    help="percent of the average Hessian diagonal used for dampening",
)
parser.add_argument(
    "--blocksize",
    type=int,
    default=128,
    help="block size: size of the sub-weight-matrix chunks GPTQ quantizes at a time",
)
parser.add_argument(
    "--n_samples", type=int, default=512, help="number of calibration data samples"
)
parser.add_argument(
    "--seq_len",
    type=int,
    default=2048,
    help="maximum sequence length of the calibration dataset; should align with the model config",
)
parser.add_argument(
    "--static_groups",
    action="store_true",
    help="use statically determined groups for quantization",
)
# ============AutoRound==================
parser.add_argument(
    "--lr",
    type=float,
    default=None,
    help="learning rate; if None, it is set to 1.0/iters automatically",
)
parser.add_argument(
    "--minmax_lr",
    type=float,
    default=None,
    help="min-max learning rate; if None, it is set to the same value as lr",
)
parser.add_argument("--autoround_iters", default=200, type=int, help="number of iterations for AutoRound calibration")
parser.add_argument(
    "--disable_quanted_input",
    action="store_true",
    help="disable using the output of the previous quantized block as input when tuning the next block",
)
parser.add_argument(
    "--quant_lm_head",
    action="store_true",
    help="whether to quantize the lm_head layer",
)
# =======================================
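# Example invocations (illustrative only; all flags are defined above):
#   python run_generation_gpu_woq.py --model Qwen/Qwen-7B-Chat --woq --woq_algo Rtn --benchmark
#   python run_generation_gpu_woq.py --model Qwen/Qwen-7B-Chat --woq --woq_algo GPTQ \
#       --accuracy --tasks lambada_openai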
args = parser.parse_args()
torch_dtype = convert_dtype_str2torch(args.compute_dtype)
# transformers >= 4.32.0 contains the MPT modeling definition.
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py
check_min_version("4.31.0")
# get model config
config = AutoConfig.from_pretrained(
    args.model,
    use_cache=True,  # to use kv cache.
    trust_remote_code=args.trust_remote_code,
    revision=args.revision,
)
user_model = None
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
quantization_config = None
if args.woq:
    if args.woq_algo.lower() == "gptq":
        quantization_config = GPTQConfig(
            tokenizer=tokenizer,
            dataset=args.dataset,
            bits=args.bits,
            desc_act=args.desc_act,
            damp_percent=args.damp_percent,
            sym=args.scheme == "sym",
            blocksize=args.blocksize,
            n_samples=args.n_samples,
            static_groups=args.static_groups,
            group_size=args.group_size,
            seq_len=args.seq_len,
            compute_dtype=args.compute_dtype,
            scale_dtype=args.compute_dtype,
            weight_dtype=args.weight_dtype,
            batch_size=args.batch_size,
        )
    elif args.woq_algo.lower() == "autoround":
        quantization_config = AutoRoundConfig(
            tokenizer=tokenizer,
            dataset=args.dataset,
            bits=args.bits,
            sym=args.scheme == "sym",
            group_size=args.group_size,
            compute_dtype=args.compute_dtype,
            scale_dtype=args.compute_dtype,
            weight_dtype=args.weight_dtype,
            iters=args.autoround_iters,
            seq_len=args.seq_len,
            n_samples=args.n_samples,
            lr=args.lr,
            minmax_lr=args.minmax_lr,
            disable_quanted_input=args.disable_quanted_input,
            quant_lm_head=args.quant_lm_head,
        )
    elif args.woq_algo.lower() == "rtn":
        quantization_config = RtnConfig(
            compute_dtype=args.compute_dtype, weight_dtype=args.weight_dtype,
            group_size=args.group_size, scale_dtype=args.compute_dtype
        )  # default is A16W4G16
# get model
if quantization_config is not None:
    user_model = AutoModelForCausalLM.from_pretrained(args.model,
                                                      device_map=args.device,
                                                      quantization_config=quantization_config,
                                                      trust_remote_code=args.trust_remote_code,
                                                      torch_dtype=torch.float16,
                                                      use_neural_speed=False
                                                      )
elif args.load_in_4bit or args.load_in_8bit:
    # 4-bit/8-bit weight-only loading is provided by intel-extension-for-transformers.
    user_model = AutoModelForCausalLM.from_pretrained(args.model,
                                                      device_map=args.device,
                                                      load_in_4bit=args.load_in_4bit,
                                                      load_in_8bit=args.load_in_8bit,
                                                      use_neural_speed=False
                                                      )
if user_model is not None:
    user_model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
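# Only the model types listed below are routed through ipex.optimize_transformers;
# all others run without the IPEX transformers optimization path.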
enable_optimize_transformers = False
opt_gpu_model_type_list = ["llama", "gptj", "mistral", "qwen"]
if config.model_type in opt_gpu_model_type_list:
    enable_optimize_transformers = True
if args.benchmark:
    if config.model_type == "qwen":
        # Chinese prompt for Qwen ("It's done and submitted. You can play Delicious
        # Survival on Android and the web. Playing on the web works, but you have to
        # simulate multi-touch to move the table.")
        prompt = "它完成了,并提交了。你可以在Android和网络上玩美味生存。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子."
    else:
        prompt = "Once upon a time, there existed a little girl, who liked to have adventures. She wanted to go to places and meet new people, and have fun."
    input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
    print("---- Prompt size:", input_size)
    user_model = AutoModelForCausalLM.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \
        if user_model is None else user_model
    user_model = user_model.to(memory_format=torch.channels_last)
    if quantization_config is None:
        quantization_config = user_model.quantization_config if hasattr(user_model, "quantization_config") else None
    if enable_optimize_transformers:
        print("Optimize with IPEX...")
        user_model = ipex.optimize_transformers(
            user_model.eval(), device=args.device, inplace=True, quantization_config=quantization_config, dtype=torch_dtype)
    else:
        print("Optimization with IPEX disabled...")
    # start
    num_iter = args.benchmark_iters
    num_warmup = args.num_warmup
    prompt = [prompt] * args.benchmark_batch_size
    amp_enabled = True
    amp_dtype = torch_dtype
    generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=args.num_beams)
    if args.profile_token_latency:
        ipex.transformers.optimize.convert_function(user_model, "greedy_search", _greedy_search)
        ipex.transformers.optimize.convert_function(user_model, "_greedy_search", _greedy_search)
        if not enable_optimize_transformers:
            ipex.transformers.optimize.convert_function(user_model, "beam_search", _beam_search)
            ipex.transformers.optimize.convert_function(user_model, "_beam_search", _beam_search)
        user_model.config.token_latency = True
    total_time = 0.0
    total_list = []
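    # Run num_warmup + benchmark_iters generations; only post-warmup iterations
    # are timed. With --profile_token_latency, generate() returns a tuple whose
    # second element holds per-token latencies (collected into total_list).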
    with torch.inference_mode(), torch.no_grad(), torch.autocast(
        device_type=args.device,
        enabled=amp_enabled,
        dtype=amp_dtype if amp_enabled else None,
    ):
        for i in range(num_iter + num_warmup):
            if args.do_profiling:
                context = torch.autograd.profiler_legacy.profile(enabled=args.do_profiling, use_xpu=True, record_shapes=True)
            else:
                context = contextlib.nullcontext()
            with context as prof:
                input_ids = tokenizer(
                    prompt, return_tensors="pt").input_ids.to(args.device)
                tic = time.time()
                output = user_model.generate(
                    input_ids, max_new_tokens=int(args.max_new_tokens), **generate_kwargs
                )
                if args.device == "xpu":
                    torch.xpu.synchronize()
                toc = time.time()
                gen_ids = output[0] if args.profile_token_latency else output
                gen_text = tokenizer.batch_decode(
                    gen_ids, skip_special_tokens=True)
            if args.do_profiling and i >= num_warmup and (i == num_warmup or i == num_iter + num_warmup - 1):
                print(f"Save pt for iter {i}")
                torch.save(prof.key_averages().table(
                    sort_by="self_xpu_time_total"), f"./profile_{i}.pt")
                # torch.save(prof.table(sort_by="id", row_limit=-1),
                #            './profile_id.pt')
                # torch.save(prof.key_averages(
                #            group_by_input_shape=True).table(), "./profile_detail.pt")
                prof.export_chrome_trace(f"./trace_{i}.json")
            input_tokens_lengths = [x.shape[0] for x in input_ids]
            output_tokens_lengths = [x.shape[0] for x in gen_ids]
            total_new_tokens = [
                o - i if user_model.config.model_type != "t5" else o
                for i, o in zip(input_tokens_lengths, output_tokens_lengths)
            ]
            print(gen_text, total_new_tokens, flush=True)
            print("Iteration: %d, Time: %.6f sec" % (i, toc - tic), flush=True)
            if i >= num_warmup:
                total_time += toc - tic
                if args.profile_token_latency:
                    total_list.append(output[1])
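    # Aggregate statistics over the timed (post-warmup) iterations.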
print("\n", "-" * 10, "Summary:", "-" * 10)
latency = total_time / (num_iter - num_warmup)
print("Inference latency: %.5f sec." % latency)
throughput = (args.max_new_tokens + input_size) / latency
print("Average throughput: {} samples/sec".format(throughput))
if args.profile_token_latency:
import numpy as np
from itertools import chain
first_latency = np.mean([x[0] for x in total_list])
average_2n = list(chain(*[x[1:] for x in total_list]))
average_2n.sort()
average_2n_latency = np.mean(average_2n)
print("First token average latency: %.5f sec." % first_latency)
print("Average 2... latency: %.5f sec." % average_2n_latency)
print(total_list)
if args.accuracy:
    user_model = AutoModelForCausalLM.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \
        if user_model is None else user_model
    if quantization_config is None:
        quantization_config = user_model.quantization_config if hasattr(user_model, "quantization_config") else None
    if enable_optimize_transformers:
        print("Optimize with IPEX...")
        user_model = ipex.optimize_transformers(
            user_model.eval(), device=args.device, inplace=True, quantization_config=quantization_config, dtype=torch_dtype)
    else:
        print("Optimization with IPEX disabled...")
    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
    # Use a separate name to avoid shadowing the CLI `args` namespace.
    eval_args = LMEvalParser(model="hf",
                             tokenizer=tokenizer,
                             user_model=user_model,
                             tasks=args.tasks,
                             device=args.device,
                             batch_size=args.eval_batch_size)
    results = evaluate(eval_args)
    for task_name in args.tasks.split(","):
        if task_name == "wikitext":
            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity,none"]))
        else:
            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc,none"]))