Skip to content

Commit

Permalink
cherry pick #42255 (fuse conv + bn in QAT) and #42378 (support skip_o…
Browse files Browse the repository at this point in the history
…p_list in PTQ) (#43301)

* support fuse conv and bn in QAT (#42255)

* support skip_op_list in PostTrainingQuantization (#42378)

* fix unittest
  • Loading branch information
yghstill committed Jun 9, 2022
1 parent f4e0939 commit 0a00fc4
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,27 @@ def forward(self, input):
return input


def fuse_conv_bn(model):
is_train = False
if model.training:
model.eval()
is_train = True
fuse_list = []
tmp_pair = [None, None]
for name, layer in model.named_sublayers():
if isinstance(layer, nn.Conv2D):
tmp_pair[0] = name
if isinstance(layer, nn.BatchNorm2D):
tmp_pair[1] = name

if tmp_pair[0] and tmp_pair[1] and len(tmp_pair) == 2:
fuse_list.append(tmp_pair)
tmp_pair = [None, None]
model = fuse_layers(model, fuse_list)
if is_train:
model.train()


def fuse_layers(model, layers_to_fuse, inplace=False):
'''
fuse layers in layers_to_fuse
Expand Down
10 changes: 10 additions & 0 deletions python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import warnings

import paddle
import paddle.nn as nn
import paddle.nn.quant.quant_layers as quant_layers
from paddle.fluid import dygraph, core, framework, unique_name
from paddle.fluid.framework import IrGraph
Expand All @@ -32,6 +33,7 @@
from paddle.fluid.log_helper import get_logger
from .. import quantization_pass
from . import utils
from . import fuse_utils

__all__ = ['ImperativeQuantAware']

Expand All @@ -52,6 +54,7 @@ def __init__(
weight_bits=8,
activation_bits=8,
moving_rate=0.9,
fuse_conv_bn=False,
weight_preprocess_layer=None,
act_preprocess_layer=None,
weight_quantize_layer=None,
Expand All @@ -76,6 +79,7 @@ def __init__(
activation_bits(int): quantization bit number for activations.
moving_rate(float): the parameter for 'moving_average_abs_max'
quantization.
fuse_conv_bn(bool): Whether to fuse conv and bn, default is False.
weight_preprocess_layer(paddle.nn.Layer, optional): A paddle
Layer that defines how to preprocess weight before quantization.
Using this can quickly test if user's preprocess method works
Expand Down Expand Up @@ -188,6 +192,7 @@ def forward(self, inputs):
model_path="./imperative_model_qat")
"""
super(ImperativeQuantAware, self).__init__()
self.fuse_conv_bn = fuse_conv_bn

kwargs = {
"quantizable_layer_type": quantizable_layer_type,
Expand Down Expand Up @@ -256,8 +261,13 @@ def forward(self, inputs):
"""
assert isinstance(model, dygraph.Layer), \
"The model must be the instance of dygraph.Layer."

if self.fuse_conv_bn:
fuse_utils.fuse_conv_bn(model)

self._quantize_inputs.apply(model)
self._quantize_outputs.apply(model)
return model

def save_quantized_model(self, layer, path, input_spec=None, **config):
self._quantize_outputs.save_quantized_model(layer, path, input_spec,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __init__(self,
onnx_format=False,
optimize_model=False,
is_use_cache_file=False,
skip_tensor_list=None,
cache_dir=None):
'''
Constructor.
Expand Down Expand Up @@ -198,6 +199,7 @@ def __init__(self,
the model accuracy is usually higher when it is 'channel_wise_abs_max'.
onnx_format(bool): Whether to export the quantized model with format of ONNX.
Default is False.
skip_tensor_list(list): List of skip quant tensor name.
optimize_model(bool, optional): If set optimize_model as True, it applies
some passes to the model before quantization, and it supports
`conv2d/depthwise_conv2d + bn` pass so far. Some targets require the
Expand Down Expand Up @@ -301,6 +303,7 @@ def __init__(self,
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
self._onnx_format = onnx_format
self._skip_tensor_list = skip_tensor_list
self._is_full_quantize = is_full_quantize
if is_full_quantize:
self._quantizable_op_type = self._support_quantize_op_type
Expand Down Expand Up @@ -547,6 +550,12 @@ def collect_var_name(var_name_list, persistable_var_names, op_type):
persistable_var_names = _all_persistable_var_names(self._program)
for block_id in range(len(self._program.blocks)):
for op in self._program.blocks[block_id].ops:
# skip quant form self._skip_tensor_list
if self._skip_tensor_list is not None:
for inp_name in utils._get_op_input_var_names(op):
if inp_name in self._skip_tensor_list:
op._set_attr("op_namescope", "skip_quant")

op_type = op.type
if self._is_full_quantize and \
op_type not in self._quantizable_op_type:
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ set_tests_properties(test_quantization_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_qat_channelwise PROPERTIES TIMEOUT 200)
set_tests_properties(test_user_defined_quantization PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_qat PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_qat_fuse PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_out_scale PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_qat_user_defined PROPERTIES TIMEOUT 200)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,15 @@ def set_vars(self):
self.onnx_format = False
self.check_export_model_accuracy = True
self.diff_threshold = 0.01
self.fuse_conv_bn = False

def func_qat(self):
self.set_vars()

imperative_qat = ImperativeQuantAware(
weight_quantize_type=self.weight_quantize_type,
activation_quantize_type=self.activation_quantize_type)
activation_quantize_type=self.activation_quantize_type,
fuse_conv_bn=self.fuse_conv_bn)

with fluid.dygraph.guard():
# For CI coverage
Expand Down Expand Up @@ -214,6 +216,7 @@ def set_vars(self):
self.activation_quantize_type = 'moving_average_abs_max'
self.onnx_format = True
self.diff_threshold = 0.025
self.fuse_conv_bn = False


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def set_vars(self):
self.activation_quantize_type = 'moving_average_abs_max'
self.diff_threshold = 0.01
self.onnx_format = False
self.fuse_conv_bn = False
print('weight_quantize_type', self.weight_quantize_type)


Expand All @@ -52,6 +53,7 @@ def set_vars(self):
self.activation_quantize_type = 'moving_average_abs_max'
self.onnx_format = True
self.diff_threshold = 0.025
self.fuse_conv_bn = False
print('weight_quantize_type', self.weight_quantize_type)


Expand Down
50 changes: 50 additions & 0 deletions python/paddle/fluid/contrib/slim/tests/test_imperative_qat_fuse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.

from __future__ import print_function

import os
import numpy as np
import random
import unittest
import logging

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.log_helper import get_logger

from test_imperative_qat import TestImperativeQat

paddle.enable_static()

os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})

_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')


class TestImperativeQatfuseBN(TestImperativeQat):
def set_vars(self):
self.weight_quantize_type = 'abs_max'
self.activation_quantize_type = 'moving_average_abs_max'
self.diff_threshold = 0.01
self.onnx_format = False
self.fuse_conv_bn = True


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def generate_quantized_model(self,
is_optimize_model=False,
batch_size=10,
batch_nums=10,
onnx_format=False):
onnx_format=False,
skip_tensor_list=None):

place = fluid.CPUPlace()
exe = fluid.Executor(place)
Expand All @@ -132,6 +133,7 @@ def generate_quantized_model(self,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
onnx_format=onnx_format,
skip_tensor_list=skip_tensor_list,
is_use_cache_file=is_use_cache_file)
ptq.quantize()
ptq.save_quantized_model(self.int8_model_path)
Expand All @@ -150,7 +152,8 @@ def run_test(self,
batch_size=10,
infer_iterations=10,
quant_iterations=5,
onnx_format=False):
onnx_format=False,
skip_tensor_list=None):

origin_model_path = self.download_model(data_url, data_md5, model_name)
origin_model_path = os.path.join(origin_model_path, model_name)
Expand All @@ -162,10 +165,10 @@ def run_test(self,

print("Start INT8 post training quantization for {0} on {1} images ...".
format(model_name, quant_iterations * batch_size))
self.generate_quantized_model(origin_model_path, algo, round_type,
quantizable_op_type, is_full_quantize,
is_use_cache_file, is_optimize_model,
batch_size, quant_iterations, onnx_format)
self.generate_quantized_model(
origin_model_path, algo, round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model, batch_size,
quant_iterations, onnx_format, skip_tensor_list)

print("Start INT8 inference for {0} on {1} images ...".format(
model_name, infer_iterations * batch_size))
Expand Down Expand Up @@ -422,5 +425,38 @@ def test_post_training_mse_onnx_format_full_quant(self):
onnx_format=onnx_format)


class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
def test_post_training_avg_skip_op(self):
model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "avg"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.01
batch_size = 10
infer_iterations = 50
quant_iterations = 5
skip_tensor_list = ["fc_0.w_0"]
self.run_test(
model_name,
data_url,
data_md5,
algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
batch_size,
infer_iterations,
quant_iterations,
skip_tensor_list=skip_tensor_list)


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ def generate_quantized_model(self,
is_full_quantize=False,
is_use_cache_file=False,
is_optimize_model=False,
onnx_format=False):
onnx_format=False,
skip_tensor_list=None):
try:
os.system("mkdir " + self.int8_model)
except Exception as e:
Expand All @@ -264,6 +265,7 @@ def generate_quantized_model(self,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
onnx_format=onnx_format,
skip_tensor_list=skip_tensor_list,
is_use_cache_file=is_use_cache_file)
ptq.quantize()
ptq.save_quantized_model(self.int8_model)
Expand All @@ -279,7 +281,8 @@ def run_test(self,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=False):
onnx_format=False,
skip_tensor_list=None):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
Expand All @@ -293,10 +296,10 @@ def run_test(self,

print("Start INT8 post training quantization for {0} on {1} images ...".
format(model, sample_iterations * batch_size))
self.generate_quantized_model(model_cache_folder + "/model",
quantizable_op_type, algo, round_type,
is_full_quantize, is_use_cache_file,
is_optimize_model, onnx_format)
self.generate_quantized_model(
model_cache_folder + "/model", quantizable_op_type, algo,
round_type, is_full_quantize, is_use_cache_file, is_optimize_model,
onnx_format, skip_tensor_list)

print("Start INT8 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
Expand Down Expand Up @@ -444,5 +447,38 @@ def test_post_training_onnx_format_mobilenetv1(self):
onnx_format=onnx_format)


class TestPostTrainingForMobilenetv1SkipOP(TestPostTrainingQuantization):
def test_post_training_mobilenetv1_skip(self):
model = "MobileNet-V1"
algo = "avg"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
skip_tensor_list = ["fc_0.w_0"]
self.run_test(
model,
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
skip_tensor_list=skip_tensor_list)


if __name__ == '__main__':
unittest.main()

0 comments on commit 0a00fc4

Please sign in to comment.