Llama wa quant support down_proj mix bits (#451)
helloyongyang committed Jul 1, 2024
1 parent c91d734 commit b8501a9
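
This commit adds a ppl_w8a8_mixdown mode for Llama activation-weight quantization: selected layers keep their MLP down_proj in fp16 (mixed bits) while gate_proj/up_proj and the rest of the w8a8 PPL path stay in int8. The layers that keep an fp16 down_proj are listed under a "mixdown" key in the model's network config (defaulting to all layers), and a new gatesilu_i32_fp16_ppl wrapper produces the fp16 activation that feeds the fp16 down projection via torch.mm.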
Showing 5 changed files with 213 additions and 52 deletions.
54 changes: 41 additions & 13 deletions lightllm/common/basemodel/cuda_kernel/ppl_awquant.py
@@ -1,23 +1,45 @@
import torch


def CONTIGUOUS_TENSOR(tensor: torch.Tensor):
""" Helper function """
if tensor.is_contiguous(): return tensor
else: return tensor.contiguous()
"""Helper function"""
if tensor.is_contiguous():
return tensor
else:
return tensor.contiguous()


def skiprmsnorm_ppl(x, weight, skip=None, eps=1e-6):
from lightllm_ppl_int8_kernel import SkipRmsNormForward_fp16_i8
if skip is None: skip = torch.zeros_like(x)
return SkipRmsNormForward_fp16_i8(
CONTIGUOUS_TENSOR(x), CONTIGUOUS_TENSOR(weight),
CONTIGUOUS_TENSOR(skip), eps)

if skip is None:
skip = torch.zeros_like(x)
return SkipRmsNormForward_fp16_i8(CONTIGUOUS_TENSOR(x), CONTIGUOUS_TENSOR(weight), CONTIGUOUS_TENSOR(skip), eps)


def gatesilu_i32_i8_ppl(x, y, x_scale, y_scale, token_scale):
from lightllm_ppl_int8_kernel import GateSilu_i32_i8

return GateSilu_i32_i8(
CONTIGUOUS_TENSOR(x), CONTIGUOUS_TENSOR(y),
CONTIGUOUS_TENSOR(token_scale), CONTIGUOUS_TENSOR(x_scale),
CONTIGUOUS_TENSOR(y_scale))
CONTIGUOUS_TENSOR(x),
CONTIGUOUS_TENSOR(y),
CONTIGUOUS_TENSOR(token_scale),
CONTIGUOUS_TENSOR(x_scale),
CONTIGUOUS_TENSOR(y_scale),
)


def gatesilu_i32_fp16_ppl(x, y, x_scale, y_scale, token_scale):
from lightllm_ppl_int8_kernel import GateSilu_i32_fp16

return GateSilu_i32_fp16(
CONTIGUOUS_TENSOR(x),
CONTIGUOUS_TENSOR(y),
CONTIGUOUS_TENSOR(token_scale),
CONTIGUOUS_TENSOR(x_scale),
CONTIGUOUS_TENSOR(y_scale),
)


def matmul_i8_i32_ppl(
A: torch.Tensor,
@@ -26,12 +48,14 @@ def matmul_i8_i32_ppl(
split_k_slices: int = 1,
) -> torch.Tensor:
from lightllm_ppl_int8_kernel import GemmForward_i8_i32
return GemmForward_i8_i32(
CONTIGUOUS_TENSOR(A), CONTIGUOUS_TENSOR(B), selected_algo, split_k_slices)

return GemmForward_i8_i32(CONTIGUOUS_TENSOR(A), CONTIGUOUS_TENSOR(B), selected_algo, split_k_slices)


def dynamic_channelwise_quant_fp16_i8_ppl(x: torch.Tensor, channel_idx=0, tp_rank=8):
x = x.transpose(0, 1).to(dtype=torch.float16).cuda(tp_rank)
from lightllm_ppl_int8_kernel import QuantizeTensor_LG

assert channel_idx < x.ndim, "channel index out of range"
# reorder channel to first dimension, then invoke group quantize impl.
num_of_channel = x.shape[channel_idx]
@@ -41,6 +65,10 @@ def dynamic_channelwise_quant_fp16_i8_ppl(x: torch.Tensor, channel_idx=0, tp_ran
qt = qt.view_as(_).transpose(0, channel_idx)
return qt, scale


def channel_token_dequant_i32_fp16_ppl(x: torch.Tensor, scale_tokenwise: torch.Tensor, scale_channelwise: torch.Tensor):
from lightllm_ppl_int8_kernel import PerTokenPerChannelDequant_i32_fp16
return PerTokenPerChannelDequant_i32_fp16(CONTIGUOUS_TENSOR(x), CONTIGUOUS_TENSOR(scale_tokenwise), CONTIGUOUS_TENSOR(scale_channelwise))

return PerTokenPerChannelDequant_i32_fp16(
CONTIGUOUS_TENSOR(x), CONTIGUOUS_TENSOR(scale_tokenwise), CONTIGUOUS_TENSOR(scale_channelwise)
)
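
For readers unfamiliar with the w8a8 flow these helpers wrap, the sketch below re-creates the same numerics in plain PyTorch: per-row symmetric int8 quantization of activations and weights, an int8 GEMM, then per-token/per-channel dequantization, mirroring dynamic_channelwise_quant_fp16_i8_ppl, matmul_i8_i32_ppl and channel_token_dequant_i32_fp16_ppl. The scale layout and the float accumulation are simplifying assumptions of the sketch, not the behaviour of lightllm_ppl_int8_kernel itself.

```python
import torch


def quant_per_row_int8(t: torch.Tensor):
    # symmetric int8 quantization with one scale per row
    # (per token for activations, per output channel for weights)
    scale = t.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
    return q, scale


x = torch.randn(4, 64)   # activations: [tokens, in_features]
w = torch.randn(32, 64)  # weight: [out_features, in_features]

qx, token_scale = quant_per_row_int8(x)    # token_scale: [4, 1]
qw, channel_scale = quant_per_row_int8(w)  # channel_scale: [32, 1]

# int8 x int8 GEMM; the real kernel accumulates in int32, float32 is exact here
# because the products and partial sums stay far below 2**24.
acc = qx.float() @ qw.float().t()  # [tokens, out_features]

# per-token / per-channel dequantization back to floating point
out = acc * token_scale * channel_scale.t()

ref = x @ w.t()
print("max abs error vs fp32 GEMM:", (out - ref).abs().max().item())
```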
150 changes: 113 additions & 37 deletions lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py
@@ -21,7 +25,11 @@
skiprmsnorm_ppl,
channel_token_dequant_i32_fp16_ppl,
)
from lightllm.common.basemodel.cuda_kernel.ppl_awquant import dynamic_channelwise_quant_fp16_i8_ppl, gatesilu_i32_i8_ppl
from lightllm.common.basemodel.cuda_kernel.ppl_awquant import (
dynamic_channelwise_quant_fp16_i8_ppl,
gatesilu_i32_i8_ppl,
gatesilu_i32_fp16_ppl,
)
from lightllm.common.basemodel.triton_kernel.quantize_gemm_int8 import matmul_quantize_int8
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
from lightllm.utils.infer_utils import mark_cost_time
@@ -52,7 +56,7 @@ def _bind_func(self):
return

def _bind_norm(self):
if "ppl_w8a8" in self.mode:
if "ppl_w8a8" in self.mode or "ppl_w8a8_mixdown" in self.mode:
self._awquant_att_norm = partial(
LlamaTransformerLayerInferActivationWeightQuantPpl._awquant_att_norm_ppl_int8, self
)
@@ -64,7 +68,7 @@ def _bind_norm(self):
return

def _bind_matmul(self):
if "ppl_w8a8" in self.mode:
if "ppl_w8a8" in self.mode or "ppl_w8a8_mixdown" in self.mode:
self._awquant_matmul_for_qkv = partial(
LlamaTransformerLayerInferActivationWeightQuantPpl._awquant_matmul_ppl_int8_quant_dequant, self
)
@@ -283,6 +287,112 @@ def _awquant_silu_ppl_int8(self, x, y, x_scale, y_scale, token_scale):
return gatesilu_i32_i8_ppl(x, y, x_scale, y_scale, token_scale)


class LlamaTransformerLayerInferActivationWeightQuantPplMixdown(LlamaTransformerLayerInferActivationWeightQuantPpl):
""" """

def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
super(LlamaTransformerLayerInferActivationWeightQuantPpl, self).__init__(
layer_num, tp_rank, world_size, network_config, mode
)
self.eps_ = network_config["rms_norm_eps"]
self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_
self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.world_size_
self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.world_size_
self.tp_o_head_num_ = self.tp_q_head_num_
self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"]
self.embed_dim_ = network_config["hidden_size"]
self.inter_dim_ = network_config["intermediate_size"]
self._init_mixdown()
self._bind_func()
self._bind_ffn()
return

def _init_mixdown(self):
self.mixdown = self.network_config_.get("mixdown", list(range(self.network_config_["num_hidden_layers"])))
assert isinstance(self.mixdown, list), "mixdown must be all or a list."

def _bind_silu(self):
if "ppl_w8a8_mixdown" in self.mode:
if self.layer_num_ in self.mixdown:
func = partial(LlamaTransformerLayerInferActivationWeightQuantPplMixdown._awquant_silu_ppl_fp16, self)
else:
func = partial(LlamaTransformerLayerInferActivationWeightQuantPpl._awquant_silu_ppl_int8, self)
self._awquant_silu = func
else:
raise Exception(f"error mode {self.mode}")
return

def _bind_ffn(self):
if "ppl_w8a8_mixdown" in self.mode:
if self.layer_num_ in self.mixdown:
func = partial(LlamaTransformerLayerInferActivationWeightQuantPplMixdown._ffn_down_fp16, self)
else:
func = partial(LlamaTransformerLayerInferActivationWeightQuantPplMixdown._ffn_down_int8, self)
self._ffn = func
else:
raise Exception(f"error mode {self.mode}")
return

def _ffn_down_int8(
self,
input,
token_scale,
infer_state: LlamaInferStateInfo,
layer_weight: LlamaTransformerLayerActivationWeightQuantPpl,
) -> torch.Tensor:
gate_out = self._awquant_matmul_for_ffn_up(
input.view(-1, self.embed_dim_),
layer_weight.gate_proj,
is_prefill=infer_state.is_prefill,
)
up_out = self._awquant_matmul_for_ffn_up(
input.view(-1, self.embed_dim_),
layer_weight.up_proj,
is_prefill=infer_state.is_prefill,
)
input = None
_, gate_proj_scale = layer_weight.gate_proj
_, up_proj_scale = layer_weight.up_proj
ffn1_out, ffn1_out_scale = self._awquant_silu(gate_out, up_out, gate_proj_scale, up_proj_scale, token_scale)
gate_out, up_out = None, None
ffn2_out = self._awquant_matmul_for_ffn_down(
ffn1_out, layer_weight.down_proj, is_prefill=infer_state.is_prefill, token_scale=ffn1_out_scale
)
ffn1_out = None

return ffn2_out

def _ffn_down_fp16(
self,
input,
token_scale,
infer_state: LlamaInferStateInfo,
layer_weight: LlamaTransformerLayerActivationWeightQuantPpl,
) -> torch.Tensor:
gate_out = self._awquant_matmul_for_ffn_up(
input.view(-1, self.embed_dim_),
layer_weight.gate_proj,
is_prefill=infer_state.is_prefill,
)
up_out = self._awquant_matmul_for_ffn_up(
input.view(-1, self.embed_dim_),
layer_weight.up_proj,
is_prefill=infer_state.is_prefill,
)
input = None
_, gate_proj_scale = layer_weight.gate_proj
_, up_proj_scale = layer_weight.up_proj
ffn1_out = self._awquant_silu(gate_out, up_out, gate_proj_scale, up_proj_scale, token_scale)
gate_out, up_out = None, None
ffn2_out = torch.mm(ffn1_out, layer_weight.down_proj)
ffn1_out = None

return ffn2_out

def _awquant_silu_ppl_fp16(self, x, y, x_scale, y_scale, token_scale):
return gatesilu_i32_fp16_ppl(x, y, x_scale, y_scale, token_scale)


class LlamaTransformerLayerInferActivationWeightQuantTriton(TransformerLayerInferActivationWeightQuantTpl):
""" """

@@ -470,37 +580,3 @@ def _awquant_matmul_triton_w8a8(
if bias is not None:
out.add_(bias)
return out

def _awquant_matmul_ppl_int8_quant_dequant(
self, input, quant_weight_params, is_prefill, token_scale=None, out=None, bias=None, has_act=False
):
if input.dtype == torch.float16:
input, token_scale = dynamic_channelwise_quant_fp16_i8_ppl(input.transpose(0, 1))
assert has_act is False
qweight, qscale = quant_weight_params
out = matmul_i8_i32_ppl(input, qweight)
out = channel_token_dequant_i32_fp16_ppl(out, token_scale, qscale)
if bias is not None:
out.add_(bias)
return out

def _awquant_matmul_ppl_int8_quant(
self, input, quant_weight_params, is_prefill, out=None, bias=None, has_act=False
):
assert has_act is False
qweight, qscale = quant_weight_params
out = matmul_i8_i32_ppl(input, qweight)
if bias is not None:
out.add_(bias)
return out

def _awquant_att_norm_ppl_int8(self, input, infer_state: LlamaInferStateInfo, layer_weight):
if getattr(infer_state, "skip", None) is None:
infer_state.skip = torch.zeros_like(input)
return skiprmsnorm_ppl(input, layer_weight.att_norm_weight_, skip=infer_state.skip)

def _awquant_ffn_norm_ppl_int8(self, input, infer_state: LlamaInferStateInfo, layer_weight):
return skiprmsnorm_ppl(input, layer_weight.ffn_norm_weight_, skip=infer_state.skip)

def _awquant_silu_ppl_int8(self, x, y, x_scale, y_scale, token_scale):
return gatesilu_i32_i8_ppl(x, y, x_scale, y_scale, token_scale)
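
To make the new binding logic easier to follow, here is a standalone sketch of the per-layer dispatch: layers whose index appears in the "mixdown" list are bound to the fp16 down_proj path (gatesilu_i32_fp16_ppl followed by torch.mm), all other layers keep the int8 path. The config values are hypothetical; only the selection rule mirrors _init_mixdown, _bind_silu and _bind_ffn above.

```python
from functools import partial


def ffn_down_int8(layer_num):
    return f"layer {layer_num}: int8 down_proj (gatesilu_i32_i8_ppl + int8 GEMM + dequant)"


def ffn_down_fp16(layer_num):
    return f"layer {layer_num}: fp16 down_proj (gatesilu_i32_fp16_ppl + torch.mm)"


# hypothetical config: 8 layers, keep fp16 down_proj in the first and last two layers
network_config = {"num_hidden_layers": 8, "mixdown": [0, 1, 6, 7]}

# same default as _init_mixdown: without a "mixdown" key, every layer keeps fp16 down_proj
mixdown = network_config.get("mixdown", list(range(network_config["num_hidden_layers"])))

bound_ffn = {
    layer_num: partial(ffn_down_fp16 if layer_num in mixdown else ffn_down_int8, layer_num)
    for layer_num in range(network_config["num_hidden_layers"])
}

for fn in bound_ffn.values():
    print(fn())
```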
51 changes: 51 additions & 0 deletions lightllm/models/llama_awquant/layer_weights/transformer_layer_weight.py
@@ -7,6 +7,9 @@
from lightllm.common.basemodel import TransformerLayerWeight
from lightllm.common.basemodel.cuda_kernel.ppl_awquant import dynamic_channelwise_quant_fp16_i8_ppl
from lightllm.common.basemodel.triton_kernel.quantize_gemm_int8 import quantize_int8
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class LlamaTransformerLayerActivationWeightQuantPpl(TransformerLayerWeight):
@@ -108,6 +111,54 @@ def _load_ffn_weights(self, weights):
return


class LlamaTransformerLayerActivationWeightQuantPplMixdown(LlamaTransformerLayerActivationWeightQuantPpl):
def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)
self._init_mixdown()

def _init_mixdown(self):
self.mixdown = self.network_config_.get("mixdown", list(range(self.network_config_["num_hidden_layers"])))
assert isinstance(self.mixdown, list), "mixdown must be all or a list."

def init_quant_mode(self):
if "ppl_w8a8_mixdown" in self.mode:
self.quantize_weight = partial(dynamic_channelwise_quant_fp16_i8_ppl, tp_rank=self.tp_rank_)
else:
raise Exception(f"error mode {self.mode}")

def _load_ffn_weights(self, weights):
if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights:
self.ffn_norm_weight_ = self._cuda(
weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]
)

inter_size = self.network_config_["intermediate_size"]
split_inter_size = inter_size // self.world_size_

if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights:
up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][
split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
]
self.up_proj = self.quantize_weight(up_proj.transpose(0, 1).to(self.data_type_))

if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights:
gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][
split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
]
self.gate_proj = self.quantize_weight(gate_proj.transpose(0, 1).to(self.data_type_))

if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights:
down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][
:, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)
]
if self.layer_num_ in self.mixdown:
self.down_proj = self._cuda(down_proj.transpose(0, 1))
logger.info(f"layer {self.layer_num_} down_proj set to fp16")
else:
self.down_proj = self.quantize_weight(down_proj.transpose(0, 1))
return


class LlamaTransformerLayerActivationWeightQuantTriton(TransformerLayerWeight):
def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)
5 changes: 5 additions & 0 deletions lightllm/models/llama_awquant/model.py
@@ -5,10 +5,12 @@

from lightllm.models.llama_awquant.layer_weights.transformer_layer_weight import (
LlamaTransformerLayerActivationWeightQuantPpl,
LlamaTransformerLayerActivationWeightQuantPplMixdown,
LlamaTransformerLayerActivationWeightQuantTriton,
)
from lightllm.models.llama_awquant.layer_infer.transformer_layer_infer import (
LlamaTransformerLayerInferActivationWeightQuantPpl,
LlamaTransformerLayerInferActivationWeightQuantPplMixdown,
LlamaTransformerLayerInferActivationWeightQuantTriton,
)
from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
@@ -28,6 +30,9 @@ def __init__(self, kvargs):
if "ppl_w8a8" in kvargs["mode"]:
self.transformer_weight_class = LlamaTransformerLayerActivationWeightQuantPpl
self.transformer_layer_infer_class = LlamaTransformerLayerInferActivationWeightQuantPpl
elif "ppl_w8a8_mixdown" in kvargs["mode"]:
self.transformer_weight_class = LlamaTransformerLayerActivationWeightQuantPplMixdown
self.transformer_layer_infer_class = LlamaTransformerLayerInferActivationWeightQuantPplMixdown
elif "triton_w8a8" in kvargs["mode"]:
self.transformer_weight_class = LlamaTransformerLayerActivationWeightQuantTriton
self.transformer_layer_infer_class = LlamaTransformerLayerInferActivationWeightQuantTriton
5 changes: 3 additions & 2 deletions lightllm/server/api_server.py
@@ -345,8 +345,9 @@ def main():
default=[],
nargs="+",
help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding
| triton_gqa_attention | triton_gqa_flashdecoding]
[triton_w4a16 | triton_w8a16 | triton_w8a8 | lmdeploy_w4a16 | ppl_w4a16 | ppl_w8a8],
| triton_gqa_attention | triton_gqa_flashdecoding
| triton_w4a16 | triton_w8a16 | triton_w8a8 | lmdeploy_w4a16
| ppl_w4a16 | ppl_w8a8 | ppl_w8a8_mixdown],
triton_flashdecoding mode is for long context, current support llama llama2 qwen;
triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA;
triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel;
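
As a usage note, the new mode is selected at launch with --mode ppl_w8a8_mixdown (see the help text above), and the layer list is read from the model's network config under the "mixdown" key. The snippet below is a hedged sketch of patching that key into a model's config.json, assuming (as elsewhere in lightllm) that the network config is loaded from that file; the path and the chosen layer indices are placeholders.

```python
import json

cfg_path = "/path/to/llama/config.json"  # placeholder: the model directory's config file

with open(cfg_path) as f:
    network_config = json.load(f)

# Keep down_proj in fp16 for the first four layers only; every other layer
# quantizes down_proj to int8 like the rest of the ppl_w8a8 path.
# If "mixdown" is omitted, the default keeps fp16 down_proj in all layers.
network_config["mixdown"] = [0, 1, 2, 3]

with open(cfg_path, "w") as f:
    json.dump(network_config, f, indent=2)
```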
