
Add post training observer_quantizer #3915

Merged
merged 13 commits into microsoft:master on Jul 28, 2021

Conversation

@chenbohua3 (Contributor) commented Jul 8, 2021

This PR adds a post-training quantizer, ObserverQuantizer. It uses the official PyTorch HistogramObserver to calculate quantization information for activations and MinMaxObserver for weights. This quantizer has two advantages:

  1. Users can directly obtain quantization information for a pre-trained model with a small amount of calibration data; no training is needed.
  2. The quantization information generated by this quantizer can be used to initialize the scale and zero point of the QAT quantizer, which leads to faster convergence and better accuracy.

Also, users can easily customize their own observers, as long as they are used in the same way as PyTorch's observers. I will upload our implementation of a classic post-training quantization observer (e.g. the Kullback-Leibler observer used by TensorRT) in the next PR.
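A minimal usage sketch (names such as Mnist, device, calib_loader, model_path, and the exact constructor arguments are assumptions for illustration; only ObserverQuantizer, compress(), and export_model() come from this PR):

    import torch
    from nni.algorithms.compression.pytorch.quantization import ObserverQuantizer

    model = Mnist().to(device)            # assumed pre-trained model and device
    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},   # only 8-bit is supported for now
        'op_types': ['Conv2d', 'Linear'],
    }]
    quantizer = ObserverQuantizer(model.eval(), config_list)

    # Feed a small amount of calibration data through the model so the observers
    # can collect statistics; no gradient computation or training is involved.
    with torch.no_grad():
        for data, _ in calib_loader:      # assumed calibration data loader
            model(data.to(device))

    quantizer.compress()                  # compute scale / zero point from the observers
    calibration_config = quantizer.export_model(model_path, calibration_path)
    print("calibration_config: ", calibration_config)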

from nni.algorithms.compression.pytorch.quantization import ObserverQuantizer


class Mnist(torch.nn.Module):
Contributor:

Please import the model directly from mnist.

Contributor Author:

done

@@ -119,6 +122,217 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma
return grad_output


class ObserverQuantizer(Quantizer):
"""
Contributor:

Please add a description of ObserverQuantizer here, emphasizing key points such as that it is a post-training quantizer and that it currently only works in evaluation mode.

Contributor Author:

done

scale, zero_point = observer.calculate_qparams()
return scale, zero_point

def quantize(self, x, scale, zero_point, qmin, qmax):
Contributor:

For class-internal functions, we recommend using '_' as a prefix.

Contributor Author:

done

calibration_config = quantizer.export_model(model_path, calibration_path)
print("calibration_config: ", calibration_config)

# For now the quantization settings of ObserverQuantizer does not match the TensorRT,
Contributor:

Why can't we support TensorRT currently? Is it because runtime errors may be raised, or because the results of simulated quantization and TensorRT would not be aligned?

Contributor Author:

Because the dtype and qscheme of the PyTorch default observers differ from those used by TensorRT. For example, TensorRT uses per_tensor_symmetric with int8 for activations, while PyTorch uses per_tensor_affine with quint8.
When customization of dtype and qscheme is ready, we can support TensorRT.
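As a rough illustration of the mismatch, using the standard PyTorch observer arguments (nothing below is code from this PR):

    import torch
    from torch.quantization import HistogramObserver

    # PyTorch default activation settings used by this quantizer:
    # affine quantization with unsigned 8-bit integers.
    default_act_observer = HistogramObserver(dtype=torch.quint8,
                                             qscheme=torch.per_tensor_affine,
                                             reduce_range=True)

    # A TensorRT-style activation observer would instead be symmetric and signed.
    trt_style_act_observer = HistogramObserver(dtype=torch.qint8,
                                               qscheme=torch.per_tensor_symmetric)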

model(data)


def test_trt(engine, test_loader):
Contributor:

It's a little strange that we keep this function without using it.

Contributor Author:

I have deleted it.

def validate_config(self, model, config_list):
schema = CompressorSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]),
Optional('quant_bits'): Or(And(int, lambda n: n == 8), Schema({
Contributor:

For post-training quantization, we only support int8 right now. If we want to support all bit widths or mixed precision, are there any obstacles?

Contributor Author:

I don't think there would be major obstacles. The reason we only support int8 right now is that PyTorch quantization observers only support 8-bit quantization (see here). To support other bit widths, we would need to extend/customize the observers.
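For context, the observer workflow this quantizer relies on looks roughly like the following sketch; the 8-bit limitation comes from the observers themselves:

    import torch
    from torch.quantization import MinMaxObserver

    # Observers are modules: feeding tensors through them records statistics,
    # and calculate_qparams() turns those statistics into scale / zero point.
    observer = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
    for _ in range(4):
        observer(torch.randn(16, 8))      # calibration tensors
    scale, zero_point = observer.calculate_qparams()
    # The resulting qparams target an 8-bit range; other bit widths would require
    # extending or customizing the observer.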

# NOTE: this quantizer is experimental for now. The dtype and qscheme of quantization
# is hard-coded.
# TODO:
# 1. support dtype and qscheme customization through config_list. Current settings:
Contributor:

If dtype and qscheme can be applied to each layer separately, then it is better to support them in config_list; otherwise, it is better to support them as the quantizer's initialization arguments.

Contributor:

Agreed, and I think they should be supported in config_list, since weights and activations may use different settings, and even different layers can use different settings.

Contributor Author:

dtype and qscheme can be applied to each layer separately. Also, we can customize the validate_config function to control the rules for each quantizer.
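To make this concrete, a hypothetical per-layer config_list entry might look like the sketch below; the 'quant_dtype' and 'quant_scheme' keys are illustrative only and are not implemented in this PR:

    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'quant_dtype': {'weight': 'int8', 'output': 'uint8'},      # hypothetical key
        'quant_scheme': {'weight': 'per_tensor_symmetric',         # hypothetical key
                         'output': 'per_tensor_affine'},
        'op_types': ['Conv2d', 'Linear'],
    }]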

def calibration(model, device, test_loader):
model.eval()
with torch.no_grad():
for data, target in test_loader:
Contributor:

Recommend using _ in place of target if target is not used.

Contributor Author:

done
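With that change, the calibration helper would read roughly as follows (the device transfer line is assumed):

    def calibration(model, device, test_loader):
        model.eval()
        with torch.no_grad():
            for data, _ in test_loader:   # target is unused, so '_' replaces it
                data = data.to(device)
                model(data)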


@linbinskn (Contributor) commented Jul 19, 2021

Post-training quantization is a critical feature in model compression and I think this PR is a good start for NNI to support it. I have some thoughts related to this PR and the current NNI quantization design:

  • Both the bit-width choice and the observers should be supported in config_list at the layer level. The reasons are:

    1. For post-training quantization, in addition to 8-bit, other bit widths such as 4-bit and 1-bit are also supported by hardware such as NVIDIA GPUs (Turing architecture or later), and more bit widths will be supported by future hardware that enables more quantization data types.
    2. For deployment, different backend frameworks may use different observers, and we have to offer general support for them through config_list. For research, researchers may want to use different observers within a layer (activation/weight) and across layers to get the best performance.
  • The current design of quantization in NNI is not good enough, and it is important for us to clearly define post-training quantization and quantization-aware training (not only QAT) in NNI's model-compression design. For instance, one potential abstraction is that post-training quantization is a one-shot quantization algorithm, and quantization-aware training algorithms (not only QAT) are different fine-tuning methods on top of it. Should we separate them as different quantizers? Or are they different stages in the whole quantization pipeline (post-training quantization is necessary and quantization-aware training is optional) that can be applied sequentially?

I think we don't need to modify a lot in this PR. We can discuss these points, make an appropriate design, and adjust the corresponding APIs gradually. I am very glad to contribute to this part to make it better.

@QuanluZhang (Contributor):

> (quoting @linbinskn's comment above)

Thanks for the discussion. At the current stage, it is not a good idea to put bit widths in config_list: users could then easily specify different bit widths for different layers, and if the quantization algorithm does not support per-layer bit widths, this would be error-prone and not user friendly.

About the refactoring of quantization in NNI: we can think about how to adapt quantizers into a modularized framework similar to the NNI pruners. The observer in quantization is very similar to the "metric calculator" in our refined NNI pruning framework. On the other hand, we should survey all quantization-aware training algorithms to figure out whether they can be seen as a type of fine-tuning.

@J-shang (Contributor) commented Jul 19, 2021

> (quoting @linbinskn's comment and @QuanluZhang's reply above)

Agreed. Maybe we need a meeting to discuss these points and rethink what should go in the compressor and what in config_list. Are all post-training quantization algorithms based on observers?

@@ -120,6 +123,222 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma
return grad_output


class ObserverQuantizer(Quantizer):
Contributor:

Please add a unit test for this quantizer.

Contributor Author:

done

scale, zero_point = self.calculate_qparams(layer.name, 'weight')
module.register_buffer('weight_scale', scale.to(self.device))
module.register_buffer('weight_zero_point', zero_point.to(self.device))
# todo: recover old_weight to weight, because the compressed
Contributor:

Could you explain more about the case where old_weight should be recovered to weight?

Contributor Author:

Because after being wrapped by QuantizerModuleWrapper, two things may happen to the weight of each layer:

  1. Its type may change from torch.nn.Parameter to torch.Tensor.
  2. The weight of BN may have been folded.

Basically, we need to ensure that the structure and parameter types of the model are consistent before and after PTQ. In theory, we can use the original model for downstream tasks (like QAT or deployment), since the current PTQ does not change the weights of the model. But users may also use the model exported by the PTQ for downstream tasks, so I think it is better to recover them.

Contributor Author:

I have added some logic in export_model to recover the weight.

Contributor:

  1. Supposedly, the quantizer's export should export the model after quantization; it is not proper to export the original model.

  2. I notice an inconsistency among the supported quantizers: some handle BN folding and some do not; some recover the weight and some do not. Could you explain the current status a little more?

Contributor Author:

As discussed offline, I have replaced the original weight with the pseudo-quantized weight.
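"Pseudo-quantized" here means fake-quantized: the weight is rounded onto the integer grid and immediately dequantized back to float. A minimal sketch of the idea (not the exact implementation in this PR):

    import torch

    def pseudo_quantize(x, scale, zero_point, qmin, qmax):
        # quantize to the integer grid, clamp to the quantized range,
        # then dequantize back to floating point
        q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
        return (q - zero_point) * scale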

@chenbohua3 (Contributor Author):

Is it necessary to add documentation? I think it's better to leave this until the dtype/quant-scheme customization is ready. Otherwise there will be many NOTEs or TODOs in the doc :)

@chenbohua3 chenbohua3 closed this Jul 26, 2021
@chenbohua3 chenbohua3 reopened this Jul 26, 2021
@QuanluZhang (Contributor):

> (quoting @chenbohua3's question above)

Agree.

else:
self.record(wrapper, 'weight', old_weight)
new_weight = old_weight
return new_weight
Contributor:

It seems we never use this new_weight, because there isn't something like module.weight = new_weight.

Contributor Author:

Yes, new_weight is not used. I just return it the way the QAT quantizer does; it also returns an unused new_weight. Should I delete it?

Contributor:

I think for the quantization simulation after compress(), we need to add module.weight = new_weight.

        if self.compressed:
            new_weight = self._quantize(old_weight,
                                        module.weight_scale,
                                        module.weight_zero_point,
                                        module.weight_qmin,
                                        module.weight_qmax)
            module.weight = new_weight

Will we get the correct inference result in this way?

Contributor Author:

You are right; I have corrected it.

calibration_config[name]['weight_bit'] = 8
val = float(module.weight_scale * module.weight_qmax)
calibration_config[name]['tracked_min_weight'] = val
calibration_config[name]['tracked_max_weight'] = -val
Contributor:

Should it be calibration_config[name]['tracked_min_weight'] = -val and calibration_config[name]['tracked_max_weight'] = val?

Contributor Author:

Yes, I have corrected it.
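For reference, the corrected export lines should read:

    val = float(module.weight_scale * module.weight_qmax)
    calibration_config[name]['tracked_min_weight'] = -val
    calibration_config[name]['tracked_max_weight'] = val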

@chenbohua3 (Contributor Author):

I have removed the code that replaced the weight with the quantized one in the compress function. @QuanluZhang @J-shang

new_parameters = dict(model.named_parameters())
self.assertTrue(all(torch.equal(v, new_parameters[k]) for k, v in origin_parameters.items()))
self.assertTrue(calibration_config is not None)
self.assertTrue(len(calibration_config) == 4)
Contributor:

Please update the test accordingly.

Contributor Author:

done

modules_to_compress = self.get_modules_to_compress()
all_observers = defaultdict(dict)
weight_q_min, weight_q_max = -127, 127
activation_q_min, activation_q_max = 0, 127 # reduce_range is set to True
Contributor:

Why do we set the quantized activation range to (0, 127) instead of (0, 255)?

Contributor Author:

By default, the activation observer's reduce_range is set to True. This means that the range of the quantized data type is reduced by 1 bit, which is sometimes required to avoid instruction overflow.
However, there was indeed a mismatch between the range used here and that in export_model; I have corrected it.
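A small sketch of the effect, using the standard PyTorch observer API (not code from this PR):

    import torch
    from torch.quantization import HistogramObserver

    obs = HistogramObserver(dtype=torch.quint8, reduce_range=True)
    obs(torch.rand(64, 8))                 # some calibration data
    scale, zero_point = obs.calculate_qparams()
    # With reduce_range=True the effective quantized range is [0, 127] rather than
    # the full quint8 range [0, 255], leaving one bit of headroom against overflow.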

@QuanluZhang QuanluZhang merged commit 370e88d into microsoft:master Jul 28, 2021