[LLM Runtime] Enable GPTQ models (#611)
* Enable GPTQ for bloom model

Signed-off-by: zhenwei-intel <zhenwei.liu@intel.com>
zhenwei-intel committed Nov 24, 2023
1 parent dfcfc09 commit 8145e63
Showing 11 changed files with 715 additions and 13 deletions.
@@ -110,7 +110,7 @@ Argument description of WeightOnlyQuantConfig:
| group_size | Int | Group size: Int (default: 32) |
| scale_dtype | String | Data type of scales: fp32/bf16 (default: fp32) |
| use_ggml | Bool | Enable ggml for quantization and inference (default: False) |
| not_quant | Bool | Determine whether or not the model will be quantized. (default: False) |
| use_quant | Bool | Determine whether or not the model will be quantized. (default: True) |
| use_cache | Bool | Use local quantized model if file exists (default: False) |
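
A minimal usage sketch of the renamed flag (the model id and the surrounding from_pretrained call follow the repository's Transformers-style API and are illustrative assumptions, not part of this diff):

from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

# keep the FP32 graph instead of quantizing (previously spelled not_quant=True)
woq_config = WeightOnlyQuantConfig(use_quant=False, use_cache=True)
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", quantization_config=woq_config)  # placeholder model id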

Argument description of generate function:
13 changes: 10 additions & 3 deletions intel_extension_for_transformers/llm/runtime/graph/__init__.py
@@ -75,7 +75,7 @@ def get_model_type(model_config):
            model_type = "chatglm2"
        return model_type

    def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs):
    def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, **quant_kwargs):
        self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model_type = Model.get_model_type(self.config)
@@ -94,20 +94,27 @@ def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs):
                quant_desc += "_pc"
            else:
                quant_desc += "_g{}".format(quant_kwargs['group_size'])
        if use_gptq:
            quant_desc = "gptq"
        quant_bin = "{}/ne_{}_q_{}.bin".format(output_path, model_type, quant_desc)

        if not_quant:
        if not use_quant:
            self.bin_file = fp32_bin
        else:
            self.bin_file = quant_bin
        if use_cache and os.path.exists(self.bin_file):
            return

        if use_gptq:
            convert_model(model_name, quant_bin, "f32")
            return


        if not use_cache or not os.path.exists(fp32_bin):
            convert_model(model_name, fp32_bin, "f32")
            assert os.path.exists(fp32_bin), "Failed to convert PyTorch model"

        if not_quant:
        if not use_quant:
            print("FP32 model will be used.")
            return
        self.module.Model.quant_model(model_path=fp32_bin, out_path=quant_bin, **quant_kwargs)
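
For context, a sketch of how the new flag is reached from Python; the checkpoint path is a placeholder and the extra quantization kwargs mirror the WeightOnlyQuantConfig fields above, both assumptions rather than part of this diff. With use_gptq=True the runtime skips the FP32 dump and converts the GPTQ checkpoint straight to the quantized binary:

from intel_extension_for_transformers.llm.runtime.graph import Model

model = Model()
# "gptq" in the path routes convert_model to the convert_gptq_* script (see scripts/convert.py below);
# the Transformers-style wrapper normally supplies these quant kwargs
model.init("./bloom-7b1-gptq", use_gptq=True, compute_dtype="int8", weight_dtype="int4", group_size=128)
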
@@ -32,8 +32,12 @@
#include <unordered_map>
#include <utility>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include "common.h"
#include "core/layers/jblas_common.hpp"
#include "models/model_utils/model_types.h"
#include "models/model_utils/model_config.h"
#include "models/model_utils/model_types.h"
#include "models/model_utils/model_utils.h"
@@ -87,6 +91,32 @@ class Model {
generate_count = 0;
}

  static size_t jblas_qpack(const int8_t* src_w, const float* src_scales, const int8_t* src_zps, void* dstpr,
                            const quant_params_internal params, int nthread, int n, int k);
  static size_t jblas_quantize(const float* src_w, void* dstpr, const quant_params_internal params, int nthread, int n,
                               int k);
  static size_t np_jblas_qpack(py::array_t<int8_t> src_w, py::array_t<float> src_scales, py::array_t<int8_t> dst) {
    int8_t* w_ptr = src_w.mutable_data();
    float* scales_ptr = src_scales.mutable_data();
    int8_t* dst_ptr = dst.mutable_data();

    quant_params_internal q_params;
    q_params.bits = quant_bits::q4;
    q_params.scale_dtype = quant_sdtype::fp32;
    q_params.compute_dtype = quant_comp::int8;
    q_params.group_size = 128;
    return Model::jblas_qpack(w_ptr, scales_ptr, nullptr, dst_ptr, q_params, 1, src_w.shape(0), src_w.shape(1));
  }

  static size_t np_jblas_quantize(py::array_t<float> src_w, py::array_t<int8_t> dst) {
    quant_params_internal q_params;
    q_params.bits = quant_bits::q4;
    q_params.scale_dtype = quant_sdtype::fp32;
    q_params.compute_dtype = quant_comp::int8;
    q_params.group_size = 32;
    return Model::jblas_quantize(src_w.mutable_data(), dst.mutable_data(), q_params, 8, src_w.shape(0), src_w.shape(1));
  }

private:
model_context* ctx = nullptr;
gpt_params params;
@@ -480,6 +510,57 @@ int Model::quant_model(const std::string& model_path, const std::string& out_pat
return 0;
}

size_t Model::jblas_qpack(const int8_t* src_w, const float* src_scales, const int8_t* src_zps, void* dstpr,
                          const quant_params_internal params, int nthread, int n, int k) {
  using CompType = jblas::prologue::weight_comp::gemm_kblcok::PrologueBIDs;
  using namespace ne_jblas;
  auto cd = jblas::utils::parallel::CpuDevice::getInstance();
  auto dstbptr = (int8_t*)dstpr;
  cd->setThreads(nthread);
  // int8: using Kernel = WeiS8Fp32<GcCompInt8KBlock, JblasAVX512F>;
  using Kernel = WeiS4ClipFp32<GcCompInt8KBlock, JblasAVX512F>;
  static Kernel kernel;
  auto packedw = kernel.createStorage(n, k, params.group_size);

  // jblas::utils::aligned_vector<int8_t> buffer(packedw.mSize);
  packedw.assign(dstbptr);

  jblas::utils::aligned_vector<int8_t> tmpq(n * k);
  std::copy(src_w, src_w + n * k, tmpq.data());

  int nk_scale = jblas::utils::updiv(k, packedw.mBlockSize);
  auto ssize = (size_t)n * nk_scale;
  jblas::utils::avector<float> Tscales(ssize);
  std::copy(src_scales, src_scales + ssize, Tscales.data());

  jblas::utils::avector<int8_t> Tzps(packedw.mIsAsym ? ssize : 0);

  kernel.packQWeight(n, k, tmpq.data(), n, Tscales.data(), Tzps.data(), &packedw);

  // kernel.unpackWeight(n, k, &packedw, dstbptr, n);
  return packedw.mSize;
}

size_t Model::jblas_quantize(const float* src_w, void* dstpr, const quant_params_internal params, int nthread, int n,
                             int k) {
  using CompType = jblas::prologue::weight_comp::gemm_kblcok::PrologueBIDs;
  using namespace ne_jblas;
  auto cd = jblas::utils::parallel::CpuDevice::getInstance();
  auto dstbptr = (int8_t*)dstpr;
  cd->setThreads(nthread);
  // using Kernel = WeiS8Fp32<GcCompInt8KBlock, JblasAVX512F>;
  using Kernel = WeiS4ClipFp32<GcCompInt8KBlock, JblasAVX512F>;
  static Kernel kernel;
  auto packedw = kernel.createStorage(n, k, params.group_size);

  // jblas::utils::aligned_vector<int8_t> buffer(packedw.mSize);
  packedw.assign(dstbptr);

  kernel.packTransposeWeight(n, k, src_w, k, &packedw);
  // kernel.unpackTransposeWeight(n, k, &packedw, dstbptr, n);
  return packedw.mSize;
}

#if MODEL_NAME_ID == 1

PYBIND11_MODULE(gptj_cpp, m)
@@ -561,5 +642,7 @@ PYBIND11_MODULE(qwen_cpp, m)
py::arg("threads") = 8)
.def("is_token_end", &Model::is_token_end)
.def("reset_token_end", &Model::reset_token_end)
.def_static("np_jblas_qpack", &Model::np_jblas_qpack)
.def_static("np_jblas_quantize", &Model::np_jblas_quantize)
.def("reinit", &Model::reinit);
}
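
With the two static methods bound above, a GPTQ conversion script can hand pre-quantized int4 weights and their per-group scales to the jblas packer straight from numpy. A rough sketch under the assumptions hard-coded in np_jblas_qpack (symmetric int4, group size 128); the module name bloom_cpp, the layer shapes, and the scratch-buffer sizing are all illustrative assumptions:

import numpy as np
import bloom_cpp  # assumed: the compiled pybind module for the bloom build

n, k = 4096, 4096                                   # output/input dims of one linear layer
qweight = np.zeros((n, k), dtype=np.int8)           # int4 values from a GPTQ checkpoint, one per int8
scales = np.ones((n, k // 128), dtype=np.float32)   # one scale per 128-element group
dst = np.zeros(n * k, dtype=np.int8)                # oversized scratch; the call returns the packed byte count

packed_size = bloom_cpp.Model.np_jblas_qpack(qweight, scales, dst)
dst[:packed_size].tofile("layer_0.packed")          # illustrative only
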
@@ -36,7 +36,7 @@ def cmpData(numa, numb):
args = parser.parse_args()

woq_configs = {
    "fp32": WeightOnlyQuantConfig(use_cache=True, not_quant=True),
    "fp32": WeightOnlyQuantConfig(use_cache=True, use_quant=False),
    "ggml_int4": WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4", use_cache=True, use_ggml=True),
    "jblas_int4": WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4", use_cache=True),
    "jblas_int8": WeightOnlyQuantConfig(compute_dtype="bf16", weight_dtype="int8", use_cache=True),
144 changes: 144 additions & 0 deletions intel_extension_for_transformers/llm/runtime/graph/scripts/common.py
@@ -0,0 +1,144 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from pathlib import Path
import numpy as np
import json
import struct
from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
                    Literal, Optional, Sequence, Tuple, TypeVar, Union)
from sentencepiece import SentencePieceProcessor # type: ignore

GGML_QK8_0 = 32
GGML_QK4_0 = 32
GGML_QK4_1 = 32
GGML_QK5_0 = 32
GGML_QK5_1 = 32

def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
    # equivalent to ggml_quantize_q4_0 in ggml.c
    assert tensor.shape[1] % GGML_QK4_0 == 0
    tensor = tensor.view(-1, GGML_QK4_0)
    abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
    max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
    scale = max_values / -8
    tensor = (tensor / scale + 8).round().clamp(min=0, max=15).char()
    # compress two int4 weights into an int8
    tensor = tensor[:, :16] | (tensor[:, 16:] << 4)
    # add scale into each block
    tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
    return tensor
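
A quick check of the block layout produced above: every 32-value block becomes a 2-byte fp16 scale followed by 16 bytes of packed nibbles, i.e. 18 int8 values per block.

import torch

w = torch.randn(2, 32)      # two blocks of 32 fp32 weights
q = quantize_q4_0(w)
assert q.dtype == torch.int8
assert q.shape == (2, 18)   # per block: 2 scale bytes + 16 packed weight bytes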

class SentencePieceVocab:
    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
        added_tokens: Dict[str, int]
        if fname_added_tokens is not None:
            added_tokens = json.load(open(fname_added_tokens))
        else:
            added_tokens = {}
        vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
        actual_ids = sorted(added_tokens.values())
        if expected_ids != actual_ids:
            raise Exception(f"Expected added token IDs to be sequential and start at {vocab_size}; got {actual_ids}")
        items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
        self.added_tokens_list = [text for (text, idx) in items]
        self.vocab_size_base: int = vocab_size
        self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer
        self.fname_added_tokens = fname_added_tokens

    def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
        tokenizer = self.sentencepiece_tokenizer
        for i in range(tokenizer.vocab_size()):
            text: bytes
            if tokenizer.is_unknown(i):
                text = " \u2047 ".encode("utf-8")
            elif tokenizer.is_control(i):
                text = b""
            elif tokenizer.is_byte(i):
                piece = tokenizer.id_to_piece(i)
                if len(piece) != 6:
                    raise Exception(f"Invalid token: {piece}")
                byte_value = int(piece[3:-1], 16)
                text = struct.pack("B", byte_value)
            else:
                text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
            score: float = tokenizer.get_score(i)
            yield text, score

    def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score

    def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
        yield from self.sentencepiece_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"

def load_vocab(path: Path) -> SentencePieceVocab:
    # Be extra-friendly and accept either a file or a directory. Also, if it's
    # a directory, it might be the model directory, and tokenizer.model might
    # be in the parent of that.
    if path.is_dir():
        path2 = path / "tokenizer.model"
        # Use `.parent` instead of /.. to handle the symlink case better.
        path3 = path.parent / "tokenizer.model"
        if path2.exists():
            path = path2
        elif path3.exists():
            path = path3
        else:
            raise FileNotFoundError(f"Could not find tokenizer.model in {path} or its parent; if it's in another directory, pass the directory as --vocab-dir")
    added_tokens_path = path.parent / "added_tokens.json"
    print(f"Loading vocab file {path}")
    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
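
A short usage sketch of the helpers above (the checkpoint directory is a placeholder): load_vocab resolves tokenizer.model from a file, the model directory, or its parent, and the resulting vocab yields (token bytes, score) pairs for a converter to serialize.

from pathlib import Path

vocab = load_vocab(Path("/path/to/model_dir"))   # placeholder checkpoint directory
print(vocab)                                     # <SentencePieceVocab with N base tokens and M added tokens>
for text, score in vocab.all_tokens():
    pass                                         # a convert_gptq_* script would write each token/score into the ne file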

def expandToInt4(qweight):
    eweight = qweight.repeat(8, axis=2)
    eweight = eweight.astype(np.uint32)
    for i in range(0, eweight.shape[2]):
        offset = i % (32 // 4) * 4
        eweight[:, :, i] = eweight[:, :, i] >> offset & (2 ** 4 - 1)
    return eweight
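
expandToInt4 unpacks the eight 4-bit values that GPTQ packs into every uint32 along the last axis, least-significant nibble first; for instance:

import numpy as np

packed = np.array([[[0x76543210]]], dtype=np.uint32)  # one word holding the nibbles 0..7
expandToInt4(packed)                                  # -> array([[[0, 1, 2, 3, 4, 5, 6, 7]]], dtype=uint32)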


def to_ggml_int16(eweight):
    qweight = np.zeros((eweight.shape[0], eweight.shape[1], eweight.shape[2] // 4), dtype=np.uint16)
    eweight = np.asarray(eweight, dtype=np.uint16)
    for i in range(0, qweight.shape[2]):
        qweight[:, :, i] = eweight[:, :, i * 2 + 0]
        qweight[:, :, i] |= eweight[:, :, i * 2 + 32] << 1 * 4
        qweight[:, :, i] |= eweight[:, :, i * 2 + 1] << 2 * 4
        qweight[:, :, i] |= eweight[:, :, i * 2 + 33] << 3 * 4
    return qweight.astype(np.int16)
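
to_ggml_int16 then packs four of those expanded nibbles into each int16, taking elements 2i, 2i+32, 2i+1 and 2i+33 of the last axis (the interleaved order used by the q4 conversion). A small check of that layout:

import numpy as np

e = (np.arange(128, dtype=np.uint16) % 16).reshape(1, 1, 128)
q = to_ggml_int16(e)
assert q.shape == (1, 1, 32)
assert q[0, 0, 0] == 0x1100   # packs e[0]=0, e[32]=0, e[1]=1, e[33]=1, low nibble first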


def qzeros_to_zeros(qzeros, bits=4):
    zeros = np.zeros((qzeros.shape[0], qzeros.shape[1] * (32 // bits)), dtype=np.float32)
    i = 0
    col = 0
    while col < qzeros.shape[1]:
        for j in range(i, i + (32 // bits)):
            zeros[:, j] = (qzeros[:, col] >> (bits * (j - i)) & (2 ** bits - 1)) + 1
        i += 32 // bits
        col += 1
    return zeros
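
qzeros_to_zeros does the matching unpack for GPTQ zero points, again eight 4-bit values per uint32, plus the +1 offset applied when GPTQ stores them:

import numpy as np

qz = np.array([[0x76543210]], dtype=np.uint32)
qzeros_to_zeros(qz)   # -> array([[1., 2., 3., 4., 5., 6., 7., 8.]], dtype=float32)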
@@ -26,7 +26,11 @@ def convert_model(model, outfile, outtype):
    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
    model_type = model_maps.get(config.model_type, config.model_type)

    path = Path(Path(__file__).parent.absolute(), "convert_{}.py".format(model_type))
    gpt_model = 'gptq' in model.lower()
    if gpt_model:
        path = Path(Path(__file__).parent.absolute(), "convert_gptq_{}.py".format(model_type))
    else:
        path = Path(Path(__file__).parent.absolute(), "convert_{}.py".format(model_type))
    cmd = []
    cmd.extend(["python", path])
    cmd.extend(["--outfile", outfile])
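
So a checkpoint whose path contains "gptq" is routed to the matching convert_gptq_<model_type>.py script; a placeholder invocation (paths and file names are illustrative):

# dispatches to scripts/convert_gptq_bloom.py because "gptq" appears in the model path
convert_model("./bloom-7b1-gptq", "ne_bloom_q_gptq.bin", "f32")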
