[SQ bug] Unify weight_amax for modules sharing the same input (#1139)
Unify weight_amax for modules that share the same input; otherwise the q/k/v projections get different scales and accuracy suffers.
xin3he authored Aug 7, 2023
1 parent d9d1fcc commit 8f36452
Showing 3 changed files with 79 additions and 49 deletions.
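
For context, both qdq_quantize hunks below derive the SmoothQuant input scale from the same formula. A minimal standalone sketch of that computation (same tensor names as the diff; not the library code itself):

import torch

def sq_scale(abs_input_max, weight_max, alpha):
    # abs_input_max: amax of the activation shared by the group (e.g. the q/k/v input)
    # weight_max: the (now unified) weight amax for that group
    input_power = torch.pow(abs_input_max, alpha)
    weight_power = torch.pow(weight_max, 1 - alpha)
    return torch.clip(input_power / weight_power, min=1e-5)

Before this commit the scale was computed per traced op, so linear layers fed by the same tensor (the q/k/v projections) could end up with different input scales.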
40 changes: 22 additions & 18 deletions neural_compressor/adaptor/pytorch.py
@@ -1388,19 +1388,21 @@ def qdq_quantize(self, model, tune_cfg):
assert not q_model._smoothquant_optimized, \
"The model is already optimized by smoothquant, cannot apply new alpha."
alpha = tune_cfg['recipe_cfgs']['smooth_quant_args']['alpha']
for op_name, info in sq_max_info.items():
for _, info in sq_max_info.items():
if alpha == 'auto':
alpha = info['alpha']
absorbed_layer = info['absorbed_layer']
input_minmax = info['input_minmax']
weight_max = info['weight_max']
abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
input_power = torch.pow(abs_input_max, alpha)
weight_power = torch.pow(weight_max, 1 - alpha)
scale = torch.clip(input_power / weight_power, min=1e-5)
module = fetch_module(q_model, op_name)
new_module = SQLinearWrapper(module, 1.0/scale, input_minmax, alpha)
set_module(q_model, op_name, new_module)
logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}")
for op_name in absorbed_layer:
module = fetch_module(q_model, op_name)
new_module = SQLinearWrapper(module, 1.0/scale, input_minmax, alpha)
set_module(q_model, op_name, new_module)
logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}")

smoothquant_op_info = {'sq_linear': {}, 'qdq_linear': []}
stats_result['SQLinearWrapper'] = {'INT8(QDQ)': 0, 'BF16': 0, 'FP32': 0}
@@ -3117,27 +3119,29 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
from .torch_utils.model_wrapper import SQLinearWrapper
from .torch_utils.util import fetch_module
alpha = tune_cfg['recipe_cfgs']['smooth_quant_args']['alpha']
for op_name, info in sq_max_info.items():
for _, info in sq_max_info.items():
if alpha == 'auto':
alpha = info['alpha']
absorbed_layer = info['absorbed_layer']
input_minmax = info['input_minmax']
weight_max = info['weight_max']
abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
input_power = torch.pow(abs_input_max, alpha)
weight_power = torch.pow(weight_max, 1 - alpha)
scale = torch.clip(input_power / weight_power, min=1e-5)
module = fetch_module(q_model._model, op_name)
new_module = SQLinearWrapper(module, 1.0/scale, input_minmax, alpha)
weight_scale = new_module._get_weight_scale()
smoothquant_scale_info[op_name] = {
'alpha': new_module.alpha,
'input_scale_for_mul': new_module.input_scale,
'input_scale_after_mul': new_module.scale,
'input_zero_point_after_mul': new_module.zero_point,
'input_dtype': new_module.dtype,
'weight_scale_after_mul': weight_scale,
}
logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}")
for op_name in absorbed_layer:
module = fetch_module(q_model._model, op_name)
new_module = SQLinearWrapper(module, 1.0/scale, input_minmax, alpha)
weight_scale = new_module._get_weight_scale()
smoothquant_scale_info[op_name] = {
'alpha': new_module.alpha,
'input_scale_for_mul': new_module.input_scale,
'input_scale_after_mul': new_module.scale,
'input_zero_point_after_mul': new_module.zero_point,
'input_dtype': new_module.dtype,
'weight_scale_after_mul': weight_scale,
}
logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}")

# Check save_qconf_summary part is a workaroud for IPEX bug.
# Sometimes the prepared model from get_op_capablitiy loss this attribute
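A made-up numeric illustration of why a shared input needs one weight amax (values are not from the repository, and how the amax is actually unified lives outside the hunks shown here; a group maximum is just one plausible choice):

import torch

abs_input_max = torch.tensor(4.0)  # activation amax shared by q/k/v (illustrative)
weight_max = {"q": torch.tensor(1.0), "k": torch.tensor(2.0), "v": torch.tensor(4.0)}  # illustrative per-layer amax
alpha = 0.5

# per-layer weight_max -> three different scales for one shared input
per_layer = {name: torch.clip(abs_input_max ** alpha / w ** (1 - alpha), min=1e-5)
             for name, w in weight_max.items()}
# per_layer: q=2.0, k~1.41, v=1.0

# one unified amax for the whole group -> one shared scale
unified_w = max(w.item() for w in weight_max.values())
shared = torch.clip(abs_input_max ** alpha / unified_w ** (1 - alpha), min=1e-5)
# shared: 1.0 for every projection, so the common input is scaled consistently

The commit message's point is that q/k/v used to land in the first case; iterating over absorbed_layer above forces the second.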
23 changes: 19 additions & 4 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -706,6 +706,15 @@ def transform(self, alpha=0.5, folding=False, percentile=99.999, op_types=['Line
if need_calibration: ##avoid multiple calibaration during tuning if the only difference is alpha
if self.insert_mul:
self.self_absorb_layers = self._get_all_layer_names() # TODO: only support linear now.
# fetch modules with the same input
group_modules = self._trace(op_types, skip_unsupported_layers=False)
for k, v in group_modules.items():
# use one input for qkv
for i in v:
if i in self.self_absorb_layers:
self.self_absorb_layers.pop(i)
self.self_absorb_layers[v[0]] = v
logger.debug(f"self_absorb_layers:{self.self_absorb_layers}")
if self.allow_absorb:
self.absorb_to_layer, no_absorb_layers = self._trace(
op_types) ##TODO we need to insert mul layer for no_absorb_layers later
@@ -836,7 +845,7 @@ def _get_example_input(self):

return self.example_inputs

def _trace(self, op_types):
def _trace(self, op_types, skip_unsupported_layers=True):
"""
Try the model to find the layers which can be smooth quantized.
:param op_types: The op types to be smooth quantized
@@ -846,7 +855,12 @@ def _trace(self, op_types):
"""
tg = GraphTrace()
self._get_example_input()
absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer(self.traced_model, self.example_inputs, op_types)
absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer(
self.traced_model, self.example_inputs, op_types,
skip_unsupported_layers=skip_unsupported_layers
)
if not skip_unsupported_layers:
return absorb_to_layer
if absorb_to_layer == None and no_absorb_layers == None:
logger.warning("sorry, could not trace the model, smooth quant is skipped")
logger.warning("if you are using huggingface model,"
@@ -994,7 +1008,7 @@ def mapping_torch_module_to_aten(self, op_types):
res = list(set(res))
return res

def get_absorb_to_layer(self, model, example_input, op_types):
def get_absorb_to_layer(self, model, example_input, op_types, skip_unsupported_layers=True):
traced_model = self.trace(model, example_input)
if traced_model == None:
return None, None
@@ -1019,7 +1033,8 @@ def get_absorb_to_layer(self, model, example_input, op_types):
absorb_to_layer[absorb_name].append(layer_name)
else:
absorb_to_layer[absorb_name] = [layer_name]
absorb_to_layer = self.remove_unsupported_layers(model, absorb_to_layer, no_absorb_layers)
if skip_unsupported_layers:
absorb_to_layer = self.remove_unsupported_layers(model, absorb_to_layer, no_absorb_layers)
return absorb_to_layer, no_absorb_layers

def remove_unsupported_layers(self, model, absorb_to_layer, no_absorb_layers):
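The transform() hunk above merges traced layers that share one parent into a single self_absorb_layers entry, so the whole group is later wrapped with one scale. A standalone sketch of that re-keying (layer names are illustrative, not real module paths):

# stand-ins for the traced structures
self_absorb_layers = {"q_proj": ["q_proj"], "k_proj": ["k_proj"], "v_proj": ["v_proj"], "fc1": ["fc1"]}
group_modules = {"layer_norm": ["q_proj", "k_proj", "v_proj"]}  # parent -> layers consuming its output

for _, layers in group_modules.items():
    for name in layers:
        if name in self_absorb_layers:
            self_absorb_layers.pop(name)     # drop the per-layer entries...
    self_absorb_layers[layers[0]] = layers   # ...and key the whole group by its first member

print(self_absorb_layers)
# {'fc1': ['fc1'], 'q_proj': ['q_proj', 'k_proj', 'v_proj']}

The new skip_unsupported_layers=False path in _trace / get_absorb_to_layer exists only to produce this grouping; the default tracing path is unchanged.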
65 changes: 38 additions & 27 deletions test/algorithm/test_smooth_quant.py
@@ -12,6 +12,7 @@
from neural_compressor.data import Datasets, DATALOADERS
from neural_compressor.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
from neural_compressor.adaptor.torch_utils.smooth_quant import TorchSmoothQuant
from neural_compressor.adaptor.torch_utils.model_wrapper import SQLinearWrapper
import logging
logger = logging.getLogger("neural_compressor")

@@ -22,6 +23,31 @@
TEST_IPEX = False


class DemoModel(torch.nn.Module):
def __init__(self):
super(DemoModel, self).__init__()
self.fc1 = torch.nn.Linear(3, 4)
self.fc2 = torch.nn.Linear(4, 3)

def forward(self, x):
out = self.fc1(x)
out = self.fc2(out)
return out

class DemoCalibDataloader:
def __init__(self):
self.batch_size = 1
def __iter__(self):
yield torch.randn([1, 3])


class LLMCalibDataloader:
def __init__(self):
self.batch_size = 1
def __iter__(self):
yield torch.ones([1, 3], dtype=torch.long)


class TestSqDepthwiseConv(unittest.TestCase):
@classmethod
def setUpClass(self):
@@ -579,7 +605,6 @@ def forward(self, x):

sq = TorchSmoothQuant(model, self.linear_dl)
sq.transform(alpha=0.5, calib_iter=1) # By default, folding=False
from neural_compressor.adaptor.torch_utils.model_wrapper import SQLinearWrapper
assert isinstance(sq.model.fc1, SQLinearWrapper)

def test_sq_quant(self):
@@ -617,7 +642,6 @@ def calib_func(model):
calib_dataloader=CalibDataloader(),
eval_func=lambda x: 0.1,
)
from neural_compressor.adaptor.torch_utils.model_wrapper import SQLinearWrapper
assert isinstance(q_model.model.fc1, SQLinearWrapper)

q_model.save('saved_result')
@@ -642,6 +666,7 @@ def calib_func(model):

# with calib_func
conf = PostTrainingQuantConfig(
example_inputs=input_ids,
recipes={"smooth_quant": True,
"smooth_quant_args": {'alpha': 'auto', 'folding': False}}
)
@@ -748,7 +773,17 @@ def forward(self, x):
sq = TorchSmoothQuant(model, self.linear_dl)
sq.transform(alpha='auto', calib_iter=1, folding=True)
#the layernorm could not used for sq-absorb because it outputs to an add op.
assert len(sq.absorb_to_layer) == 0
assert len(sq.absorb_to_layer) == 0

def test_sq_no_skip_op_auto(self):
model = transformers.AutoModelForCausalLM.from_pretrained(
'facebook/opt-125m', torchscript=True,
)
sq = TorchSmoothQuant(model, LLMCalibDataloader())
sq.transform(alpha='auto', calib_iter=0, folding=False)
# folding=False will absorb all Linears with mul, kqv will use same input.
assert len(sq.absorb_to_layer['model.decoder.layers.2.self_attn.q_proj']) == 3


class TestSqSkipOp_attn(unittest.TestCase):
@classmethod
@@ -801,30 +836,6 @@ def forward(self, x):
assert len(sq.absorb_to_layer) == 0


class DemoModel(torch.nn.Module):
def __init__(self):
super(DemoModel, self).__init__()
self.fc1 = torch.nn.Linear(3, 4)
self.fc2 = torch.nn.Linear(4, 3)

def forward(self, x):
out = self.fc1(x)
out = self.fc2(out)
return out

class DemoCalibDataloader:
def __init__(self):
self.batch_size = 1
def __iter__(self):
yield torch.randn([1, 3])


class LLMCalibDataloader:
def __init__(self):
self.batch_size = 1
def __iter__(self):
yield torch.ones([1, 3], dtype=torch.long)

class TestTuneSqAlpha(unittest.TestCase):
@classmethod
def setUpClass(self):
