
Commit

Fix gptq desc_act and static_group (#1395)
changwangss authored Mar 20, 2024
1 parent edede40 commit 528d7de
Showing 2 changed files with 22 additions and 21 deletions.
File 1 of 2:

@@ -169,22 +169,27 @@ def set_weights_bias(
     q_config,
     bias=None,
 ):
-    if q_config.quant_method.value == "gptq" and (
-        q_config.desc_act and not q_config.static_groups
-    ):
-        int_weight2 = int_weight.clone()
-        group_size = q_config.group_size
-        group_dict = {}
-        for i in range(len(g_idx)):
-            group_idx = g_idx[i].item()
-            if group_idx not in group_dict:
-                target_idx = group_idx * group_size
-                group_dict[group_idx] = 0
-            else:
-                group_dict[group_idx] = group_dict[group_idx] + 1
-                target_idx = group_idx * group_size + group_dict[group_idx]
-            int_weight2[target_idx] = int_weight[i]
-        int_weight = int_weight2
+    if q_config.quant_method.value == "gptq":
+        if q_config.desc_act:
+            if not q_config.static_groups:
+                int_weight2 = int_weight.clone()
+                group_size = q_config.group_size
+                group_dict = {}
+                for i in range(len(g_idx)):
+                    group_idx = g_idx[i].item()
+                    if group_idx not in group_dict:
+                        target_idx = group_idx * group_size
+                        group_dict[group_idx] = 0
+                    else:
+                        group_dict[group_idx] = group_dict[group_idx] + 1
+                        target_idx = group_idx * group_size + group_dict[group_idx]
+                    int_weight2[target_idx] = int_weight[i]
+                int_weight = int_weight2
+            else:
+                g_idx = torch.empty(0, dtype=torch.int32)
+        else:
+            g_idx = torch.empty(0, dtype=torch.int32)
 
     if q_config.bits == 4:
         int_weight = (int_weight - 8) * 16
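The reordering loop in this hunk gathers the rows of int_weight so that each quantization group's rows sit contiguously (group 0 first, then group 1, and so on), preserving the original row order within each group. Below is an illustrative sketch, not part of the patch, of what that loop computes together with a vectorized equivalent; the helper name, the toy tensors, and the assumption that every group holds exactly group_size rows are choices made for the sketch.

# Illustrative sketch (not from the patch): the desc_act reordering loop above,
# wrapped as a helper, plus a vectorized equivalent via a stable sort of g_idx.
# Assumes every group holds exactly `group_size` rows.
import torch

def reorder_rows_by_group(int_weight, g_idx, group_size):
    # Same logic as the loop in the diff: row i is scattered to
    # group_id * group_size + (number of that group's rows already placed).
    int_weight2 = int_weight.clone()
    group_dict = {}
    for i in range(len(g_idx)):
        group_idx = g_idx[i].item()
        if group_idx not in group_dict:
            target_idx = group_idx * group_size
            group_dict[group_idx] = 0
        else:
            group_dict[group_idx] = group_dict[group_idx] + 1
            target_idx = group_idx * group_size + group_dict[group_idx]
        int_weight2[target_idx] = int_weight[i]
    return int_weight2

group_size = 4
# Three groups of exactly group_size rows each, in scattered (desc_act-style) order.
g_idx = torch.tensor([2, 0, 1, 0, 2, 1, 0, 1, 2, 0, 1, 2], dtype=torch.int32)
int_weight = torch.randint(0, 16, (len(g_idx), 8), dtype=torch.int8)

looped = reorder_rows_by_group(int_weight, g_idx, group_size)

# A stable sort of g_idx yields the same destinations, because with full groups
# the scatter target of row i is exactly its position in the stable sort.
perm = torch.sort(g_idx, stable=True).indices
assert torch.equal(looped, int_weight[perm])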
@@ -194,11 +199,7 @@ def set_weights_bias(
     if q_config.sym:
         gptq_zeros = torch.empty(0, dtype=torch.int8)
 
-    if (
-        q_config.quant_method.value != "gptq"
-        or q_config.static_groups
-        or (not q_config.desc_act)
-    ):
+    if q_config.quant_method.value != "gptq":
         g_idx = torch.empty(0, dtype=torch.int32)
 
     packw = torch.ops.bestlaop.woq_packq(
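For context on the condition above: g_idx maps each input channel to its quantization group, so it only carries extra information when desc_act scatters the group ids; otherwise the group is simply channel_index // group_size and the tensor can be dropped. A rough illustration, where the sizes and the random permutation standing in for GPTQ's activation ordering are made up:

import torch

in_features, group_size = 16, 4
# Without desc_act the group id is implied by position: 0,0,0,0,1,1,1,1,...
g_idx_plain = torch.arange(in_features, dtype=torch.int32) // group_size
# With desc_act the same ids appear scattered across channels; `order` is an
# arbitrary stand-in for the activation-based ordering GPTQ would use.
order = torch.randperm(in_features)
g_idx_desc_act = g_idx_plain[order]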
File 2 of 2:

@@ -67,7 +67,6 @@
     convert_to_quantized_model,
     replace_linear,
 )
-from ..llm.quantization.nn.modules import QuantizedLinearQBits
 from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear
 from transformers.configuration_utils import PretrainedConfig
 from transformers import AutoConfig

@@ -83,6 +82,7 @@ def recover_export_model(model, current_key_name=None):
     Return optimum format model.
     """
+    from ..llm.quantization.nn.modules import QuantizedLinearQBits
     for name, module in model.named_children():
         if current_key_name is None:
             current_key_name = []
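The second file moves the QuantizedLinearQBits import from module scope into recover_export_model, so it is resolved only when the function is first called rather than when the module is imported; this is commonly done to keep import-time dependencies light or to sidestep an import cycle, though the commit itself does not state the motivation. A minimal sketch of the deferred-import pattern, using an illustrative function and a stdlib module rather than the ones from the patch:

def dump_config(config):
    # Deferred import: `json` is resolved on the first call to dump_config,
    # not when the surrounding module is imported.
    import json
    return json.dumps(config, indent=2)

print(dump_config({"bits": 4, "group_size": 128, "desc_act": True}))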
