Skip to content

Commit

Permalink
Add weight_only support for PyTorch framework (#297)
Browse files Browse the repository at this point in the history
  • Loading branch information
PenghuiCheng authored Sep 15, 2023
1 parent 93ca550 commit 3a064fa
Show file tree
Hide file tree
Showing 23 changed files with 1,011 additions and 154 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/cpp-graph-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ jobs:
cd ${{ github.workspace }}
conda activate cpp-graph-test || source activate cpp-graph-test
pip install build --upgrade
python -m build -s -w
pip install -r requirements.txt
python setup.py sdist bdist_wheel
pip install dist/intel_extension_for_transformers*.whl
pip list
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/llm-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ jobs:
cd ${{ github.workspace }}
conda activate llm-test || source activate llm-test
pip install build --upgrade
python -m build -s -w
pip install -r requirements.txt
python setup.py sdist bdist_wheel
pip install dist/intel_extension_for_transformers*.whl
pip list
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/script/formatScan/pylint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ else
echo "Not found requirements.txt file."
fi
# install packages
pip install accelerate nlpaug nltk optimum-intel
pip install accelerate nlpaug nltk schema optimum-intel
pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@83dbfbf6070324f3e5872f63e49d49ff7ef4c9b3

echo "[DEBUG] list pipdeptree..."
Expand All @@ -39,7 +39,7 @@ python -m pylint -f json --disable=R,C,W,E1129 \
--max-line-length=120 \
--extension-pkg-whitelist=numpy,nltk \
--ignored-classes=TensorProto,NodeProto \
--ignored-modules=tensorflow,torch,torch.quantization,torch.tensor,torchvision,mxnet,onnx,onnxruntime,neural_compressor,neural_compressor.benchmark,intel_extension_for_transformers.transformers.modeling.modeling_causal,intel_extension_for_transformers.neural_engine_py \
--ignored-modules=tensorflow,torch,torch.quantization,torch.tensor,torchvision,mxnet,onnx,onnxruntime,neural_compressor,neural_compressor.benchmark,intel_extension_for_transformers.neural_engine_py \
/intel-extension-for-transformers/intel_extension_for_transformers >${log_dir}/pylint.json
exit_code=$?

Expand Down
7 changes: 4 additions & 3 deletions .github/workflows/script/install_binary.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ $BOLD_YELLOW && echo "---------------- git submodule update --init --recursive -
git config --global --add safe.directory "*"
git submodule update --init --recursive

$BOLD_YELLOW && echo "---------------- run python setup.py bdist_wheel -------------" && $RESET
pip install build --upgrade
python3 -m build -s -w

$BOLD_YELLOW && echo "---------------- run python setup.py sdist bdist_wheel -------------" && $RESET
python setup.py bdist_wheel


$BOLD_YELLOW && echo "---------------- pip install binary -------------" && $RESET
pip install dist/intel_extension_for_transformers*.whl
Expand Down
21 changes: 17 additions & 4 deletions intel_extension_for_transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

try:
from ._version import __version__ # load _version file generated by setuptools_scm
except ModuleNotFoundError:
__version__ = "1.1"
def _get_version(default='x.x.x.dev'):
try:
from pkg_resources import DistributionNotFound, get_distribution
except ImportError:
return default
else:
try:
return get_distribution(__package__).version
except DistributionNotFound: # Run without install
return default
except ValueError: # Python 3 setup
return default
except TypeError: # Python 2 setup
return default


__version__ = _get_version()
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ JBLAS_CODE dequant_kblock_s8_f32_fwd(int8_t* srcptr, float* dstptr, int row, int
auto sptr = scales + kpos * NPad;
int j = 0;
for (; j < simd_process_num; j += Vlen) {
auto s8_ymm_v = _mm_loadu_si64(srcptr + i * ld_src + j);
auto s8_ymm_v = _mm_loadl_epi64(reinterpret_cast<__m128i_u*>( srcptr + i * ld_src + j));
auto s32_ymm_v = _mm256_cvtepi8_epi32(s8_ymm_v);
if constexpr (WITH_ZP) {
s32_ymm_v = _mm256_sub_epi32(s32_ymm_v, _mm256_cvtepi8_epi32(_mm_loadu_si64(zero_points + kpos * NPad + j)));
s32_ymm_v = _mm256_sub_epi32(s32_ymm_v, _mm256_cvtepi8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i_u*>( zero_points + kpos * NPad + j))));
}
auto f32_ymm_v = _mm256_cvtepi32_ps(s32_ymm_v);
f32_ymm_v = _mm256_mul_ps(f32_ymm_v, _mm256_loadu_ps(sptr + j));
Expand Down Expand Up @@ -347,7 +347,7 @@ static inline JBLAS_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T*
vdsrc = _mm256_min_epi32(vdsrc, vff);
vdsrc = _mm256_max_epi32(vdsrc, v0);
auto vbsrc = avx2_cvtepi32_epu8(vdsrc);
_mm_storeu_si64((__m128i*)&dstptr[(j + ij) + i * ld_dst], vbsrc);
_mm_storel_epi64((__m128i*)&dstptr[(j + ij) + i * ld_dst], vbsrc);
}
if (ij < blocksize) {
for (; ij < blocksize; ij++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
## See the License for the specific language governing permissions and
## limitations under the License.
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(QBits LANGUAGES C CXX)
project(qbits LANGUAGES C CXX)


set(QBITS_TORCH_PATH "" CACHE STRING "Torch install path")
Expand All @@ -37,9 +37,9 @@ add_subdirectory(dispatcher)

file(GLOB qbits_src ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
# Define our library target
add_library(QBits SHARED ${qbits_src})
add_library(qbits SHARED ${qbits_src})
# Enable C++14
target_compile_features(QBits PRIVATE cxx_std_14)
target_compile_features(qbits PRIVATE cxx_std_14)

# Link against LibTorch
target_link_libraries(QBits jblas_dispatcher)
target_link_libraries(qbits jblas_dispatcher)
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ inline JBLAS_CODE alphabeta_dt_cvt_process(float* tmp_dst, const int cachestep,
auto cptr = reinterpret_cast<DST_T*>(_param.C) + COffset;
if constexpr (std::is_same_v<DST_T, float>) {
return jblas::kernel::wrapper::Memcpy2D::template forward<ISA_T, float, DST_T>(
(void*)tmp_dst, (void*)cptr, M, N * sizeof(DST_T), cachestep * sizeof(float), _param.ldc * sizeof(DST_T), NULL);
tmp_dst, cptr, M, N, cachestep, _param.ldc, NULL);
}
if constexpr (std::is_same_v<DST_T, jblas::utils::bf16>) {
return jblas::kernel::wrapper::Memcpy2DFp32CvtBf16::template forward<ISA_T>(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .functions import matmul_4bit
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import operator
import torch
from functools import reduce
from torch import Tensor
from typing import Tuple, Optional, List

def prod(iterable):
return reduce(operator.mul, iterable, 1)

class MatMul4Bit(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=None):
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B
ctx.bias = bias
B_shape = state[1]
if A.shape[-1] == B_shape[0]:
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)


# 1. Dequantize
# 2. MatmulnN
# torch.ops.weight_only_jblasop.jblas_symqdq_weight(B, False, 4, 32) # TODO: replace with dequantize
output = torch.nn.functional.linear(A, B.to(A.dtype), bias)

# 3. Save state
ctx.state = state
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

if any(ctx.needs_input_grad[:2]):
ctx.tensors = (A, B)
else:
ctx.tensors = (None, None)

return output

@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
A, B = ctx.tensors
state = ctx.state

grad_A, grad_B, grad_bias = None, None, None

if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

# not supported by PyTorch. TODO: create work-around
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
# torch.ops.weight_only_jblasop.jblas_symqdq_weight(B, False, 4, 32) # TODO: replace with dequantize
if req_gradA: grad_A = torch.matmul(grad_output, B.to(grad_output.dtype))

return grad_A, grad_B, None, grad_bias, None

def matmul_4bit(A: Tensor, B: Tensor, quant_state: List = None, out: Tensor = None, bias=None, do_dequant=True):
# assert quant_state is not None
if do_dequant:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state) # TODO: replace with 4bit matmul
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .quantization_config import WeightOnlyConfig
Loading

0 comments on commit 3a064fa

Please sign in to comment.