Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] MMRazor Quantization Design #347

Open
humu789 opened this issue Nov 11, 2022 · 5 comments
Open

[RFC] MMRazor Quantization Design #347

humu789 opened this issue Nov 11, 2022 · 5 comments
Assignees
Labels
RFC Request for Comments

Comments

@humu789
Copy link
Collaborator

humu789 commented Nov 11, 2022

Motivation

  1. To design and implement the better quantization part of MMRazor with community.

  2. Collect more requirements and suggestions before releasing quantization by RFC (Request for Comments).

Overview

MMRazor quantization will be an algorithm platform not just provide basic quantization function api. We hope it will help us in the following ways:

  1. Compress and deploy your model faster.

  2. Producing better models with our quantization algorithms.

  3. Implement some novel quantization algorithms easier.

Goals

  1. Support implementing mainstream QAT and PTQ algorithms, such as LSQ, Adaround and so on.

  2. Support complete working pipeline from quantization to deployment. You can deploy quantized models on multiple backends with mmdploy.

  3. Adaptive OpenMMLab 2.0. Thus it can unified support OpenMMLab upstream repositories without extra code.

  4. Easier to use. You can quantize your model just by modifying config and running script, rather than modify your source model.

Algorithms

We plan to support some quantization algorithms in future as follows. Welcome to propose your requirements.

QAT

  1. LSQ

  2. LSQ+

  3. IAO

......

PTQ

  1. Adaround

  2. BRECQ

  3. QDrop

......

Main features

We list some main features to be supported in future. Welcome to comment.

  1. Quantization type: QAT and PTQ(static/dynamic)

  2. Quantization bits: 1 ~ 32 Note: 1 bit is not binaryzation, just common quantization.

  3. Quantization methods (uniform quantization):

    1. per_tensor / per_channel
    2. symmetry / asymmetry
    3. FP_scale / Pot_scale (power of two)
  4. Multiple backends:

    1. TensorRT
    2. SNPE
    3. ncnn
    4. .....

Some algorithms and features to be supported will be implemented in the next several versions due to lack of manpower, welcome to create PRs to speed up development.

Most features will be released in the first release, except dynamic quantization and more backends supporting. According to quantization algorithms, we will release them by ranks in the next two versions.

Release plan

We will release our first version in December 2022 if everything goes well.

Design and Implement

We will extend and develop to implement our design based on PyTorch basic quantization function api and torch.fx. So some modules in PyTorch will be inherited and also some new modules will be created.

User-friendly config

We will use Qscheme to convert user-friendly config to API oriented parameters. Demo config is as follows.

_base_ = [
    'mmcls::resnet/resnet18_8xb32_in1k.py'
]

model = dict(
    _delete_=True,
    type='mmrazor.GeneralQuant',
    architecture=_base_.model,
    quantizer=dict(
        type='mmrazor.CustomQuantizer',
        is_qat=False,
        # `skipped_methods` is to trace model automatically by skipping 
        # these untraced method.
        skipped_methods=[
            'mmcls.models.heads.ClsHead._get_loss', 
            'mmcls.models.heads.ClsHead._get_predictions'],
        qconfig=dict(
            qtype='affine',
            w_observer=dict(type='mmrazor.MSEObserver'),
            a_observer=dict(type='mmrazor.EMAMSEObserver'),
            w_fake_quant=dict(type='mmrazor.AdaRoundFakeQuantize'),
            a_fake_quant=dict(type='mmrazor.FakeQuantize'),
            w_qscheme=dict(
                bit=2,
                is_symmetry=False,
                is_per_channel=True,
                is_pot_scale=False,
            ),
            a_qscheme=dict(
                bit=4,
                is_symmetry=False,
                is_per_channel=False,
                is_pot_scale=False),
        )
    )
)

Usage

Quantization algorithms' entrance is like other model compression algorithms as follows.

QAT: tools/train.py

PTQ: tools/test.py

Deploy quantized model's entrance is mmdeploy/tools/deploy.py. So you can just run the following commands to implement the pipeline from quantization to deploy

python mmrazor/tools/train.py (test.py)
python mmdeploy/tools/deploy.py

For more details about the above commands, please refer to the quantization document to be released.

Core modules

  1. Observers

In forward, they will update the statistics of the observed Tensor. And they should provide a calculate_qparams function that computes the quantization parameters given the collected statistics.

from torch.ao.quantization.observer import UniformQuantizationObserverBase

class BaseObserver(UniformQuantizationObserverBase):

    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
        self,
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        ch_axis=-1,
        is_pot_scale=False,
        factory_kwargs=None,
        eps=torch.finfo(torch.float32).eps) -> None:
        super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, 
            factory_kwargs, eps)
        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
        self.ch_axis = ch_axis
        self.is_pot_scale = is_pot_scale

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Calculates the quantization parameters."""
        pass

    @torch.jit.export
    def extra_repr(self):
        pass

    @torch.jit.export
    def reset_min_max_vals(self):
        """Resets the min/max values."""
        pass
  1. FakeQuantizes

In forward, they will update the statistics of the observed Tensor and fake quantize the input. They should also provide a calculate_qparams function that computes the quantization parameters given the collected statistics.

In fake quantize, you can implement some algorithms' special operations.

from torch.ao.quantization import FakeQuantizeBase
from mmrazor.registry import MODELS

@MODELS.register_module()
class FakeQuantize(FakeQuantizeBase):

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self, observer, **observer_kwargs):
        super().__init__()
        self.activation_post_process = observer(**observer_kwargs)
        self.quant_min = self.activation_post_process.quant_min
        self.quant_max = self.activation_post_process.quant_max
        if _is_float_qparams(self.activation_post_process.qscheme):
            zero_point_dtype = torch.float
        else:
            zero_point_dtype = torch.int
        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
        self.register_buffer('zero_point', torch.tensor([0], dtype=zero_point_dtype))
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        assert _is_per_channel(self.qscheme) or \
            _is_per_tensor(self.qscheme), \
            'Only per channel and per tensor quantization are supported in fake quantize' + \
            ' got qscheme: ' + str(self.qscheme)
        self.is_per_channel = _is_per_channel(self.qscheme)
        
        bitrange = torch.tensor(self.quant_max - self.quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())
        self.is_pot_scale = self.activation_post_process.is_pot_scale
        self.is_symmetric_quant = _is_symmetric_quant(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        if self.observer_enabled[0] == 1:
            pass

        if self.fake_quant_enabled[0] == 1:
            pass
            
        return X

    @torch.jit.export
    def extra_repr(self):
        pass

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        pass
        
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        pass
  1. Quantizers

They implement some core quantization function APIs for algorithm, such as qconfig_convert, prepare,convert_model, fuse_model, and so on. What is more, different quantizers can deal with different backends to be deployed, thus we can configure it in the config for different backends.

@MODELS.register_module()
class CustomQuantizer(BaseModule):

    def __init__(self,
                 qconfig=DefalutQconfigs['default'],
                 is_qat=True,
                 skipped_methods=None,
                 prepare_custom_config_dict=None,
                 convert_custom_config_dict=None,
                 equalization_qconfig_dict=None,
                 _remove_qconfig=True,
                 init_cfg=None):
        super().__init__(init_cfg)
        if self.check_qconfig(qconfig):
            qconfig = self.qconfig_convert(qconfig)
            self.qconfig_dict = {"": qconfig}
        else:
            raise ValueError('qconfig is incorrect!')
        
        if prepare_custom_config_dict is None:
            self.prepare_custom_config_dict = {}
        else:
            self.prepare_custom_config_dict = prepare_custom_config_dict
        if convert_custom_config_dict is None:
            self.convert_custom_config_dict = {}
        else:
            self.convert_custom_config_dict = convert_custom_config_dict
        if equalization_qconfig_dict is None:
            self.equalization_qconfig_dict = {}
        else:
            self.equalization_qconfig_dict = equalization_qconfig_dict

        check_is_valid_qconfig_dict(self.qconfig_dict)
        check_is_valid_prepare_custom_config_dict(self.prepare_custom_config_dict)
        check_is_valid_convert_custom_config_dict(self.convert_custom_config_dict)
        check_is_valid_qconfig_dict(self.equalization_qconfig_dict)
        
        self.is_qat = is_qat
        self.skipped_methods = skipped_methods
        self._remove_qconfig = _remove_qconfig
        self.tracer = self.build_tracer()

    def prepare(self, model, graph_module):
        pass

    def convert(self, graph_module):
        pass
        
    def qconfig_convert(self, qconfig):
        pass

    def build_tracer(self):
        pass
    
    def fuse_model(self, graph_module):
        pass
    
    .....
    
  1. Algorithms

They will provide some core APIs for Quantization Loops to implement quantization pipelines. Such as calib_step,prepare,convert and so on. And Algorithms also maintain traced graphs and forward with graphs.

  1. Quantization Loops

They inherited mmengine's TrainLoop and TestLoop, adding some core quantization steps, such as calibrate,preprare,convert. There are also some special steps for some quantization algorithms, such as subgraph reconstruction.

How to trace the model automatically

Because torch.fx has its own limitations, some models' forward can not be traced when there are some special cases in forward, such as dynamic judgment.

For tracing the model automatically, we custom a CustomTracer and UntracedMethodRegistry. UntracedMethodRegistry can be used as a decorator to make decorated methods skipped by CustomTracer. What is more, methods to be skipped can be configured in our configs. Please refer to the chapter User-friendly config to learn about its usage.

So the solution is as follows.

  1. Collect these untraceable codes to a function or a method and make the rest of the pipeline traceable. In OpenMMLab 2.0, we refactored some model interfaces to adapt torch.fx preliminary.

  2. Specified these methods to be skipped in our configs.

WIP code

For more details about the implementation, please refer to the branch: https://github.com/open-mmlab/mmrazor/tree/quantize

Note:

The quantize branch is in development, modifying code will happen at any time.

@humu789 humu789 added the RFC Request for Comments label Nov 11, 2022
@humu789 humu789 pinned this issue Nov 14, 2022
@ZhangZhiPku
Copy link

ZhangZhiPku commented Nov 16, 2022

训练算法并不是网络量化的核心

  1. 量化的规则是神经网络量化的核心。
  2. 先有正确的模拟,才有高精度的量化。
  3. 网络调度、误差分析、以及计算图的变换与融合,是量化系统必须的能力。
  4. 开放的接口,灵活可靠的量化控制是解决复杂量化问题的基础。
  5. 基于训练的算法只是量化算法的一小部分。

1. 符合规则的量化

说起量化的规则,我们常常会提到量化位宽、对称-非对称量化、逐层-逐通道量化、以及power-of-2的量化等等概念,他们是量化中真正影响量化系统精度的重要因素,但并不是全部。我们设计的 Qscheme 对象可以对上述重要的因素有描述的能力,但我将向你介绍有关量化规则的一些其他内容。

有那些图模式可以融合

我们熟知的融合是激活函数融合,这是简单的;pytorch.fx也提供了相应的接口帮助我们完成这些功能。但实际的图融合可以更加复杂,甚至可以不遵循特定的图模式。在 TensorRT 中,一种常见的图融合模式是 Conv - Add 融合,Add 算子将被融合进入 Conv,因此 Add 节点将不需要量化信息,或仅需要单侧分支的量化信息。以目前的接口设计,这样的融合几乎不可能被正确的模拟。一种更复杂的融合是 Add - Mul - Div 等串联的元素级算子的融合(PWN),对于 TensorRT 而言,其可以融合最多64个串联的,执行元素级操作的算子,不论被融合的算子类型是什么。对于被融合的串联结构块,其中所有节点不需要量化信息,只有头尾节点需要量化。目前的设计中我们很难处理这样的融合。

特殊算子的量化规则

对于一个向 int8 量化的网络而言,我们不仅仅需要讨论量化是否是对称的;是否是 per-channel 的;是否是 power-of-2 的;我们还需要讨论:

  • Average pool, max pool 算子,输入输出必须共享 scale
  • Concat 算子,输入输出必须共享 scale
  • Add, Sub, 输入必须共享 scale
  • Conv, Gemm 算子,bias scale 必须是 input scale * weight scale
  • Log, Exp, Swish, Sigmoid, Tanh 算子,Output scale 必须是 1/128
  • Reshape, Resize, SpaceToDepth, Gather, Squeeze, Unsqueeze... 输入输出必须共享 scale
  • Rounding policy,包括 Resize, average pooling 算子的内部计算逻辑,在他们执行 int8 除法时如何 round
  • Clip 算子,min, max 必须与 Input 共享 scale
  • Pad 算子,padding value 必须与 Input 共享 scale

我们在系统的设计之初就要考虑到如何支持这些特殊算子的量化规则,如果不按照上述规则完成量化,则要么无法完成量化网络的部署,要么无法得到正确的量化模拟,要么得到性能极差的量化网络。事实上目前设计极其单薄的 Qscheme 不足以描述这些特殊的量化规则——一个网络的量化并不是由一个个独立的 Qscheme 就能完成描述的,相反地他们彼此之间一定会存在 相互依赖 的关系,会出现一些主要量化控制信息,和一些从属的量化控制信息。以下图为例,考虑到 Concat, average pooling 的量化约束,全图中几乎所有激活值的量化都从属于average pool的输出量化,Qscheme 作为唯一的量化控制结构体,必须能够正确表达上述逻辑。这是网络成功量化的根本
Time series

如果我们罔顾上述约束,只考虑卷积和矩阵乘的量化,会得出一系列独立的量化信息,对于 Openvino, TensorRT 而言,这都将导致极低的量化网络执行效率。考虑上图中的例子,如果我们令4个卷积的输出分别具有不同的 scale,则 TensorRT 会在他们与 concat 层中间插入 reformat 节点转换数据流的 scale,这些 reformat 节点的执行时间与卷积本身差不多长,将导致网络整体的运行效率降低 100%,这完全耗尽了量化所带来的性能收益。另外,上述约束条件并非一成不变。对于一些框架而言,他们的bias可以独立于input和weight单独量化,从 Qscheme 的设计上我们需要提供这样的兼容性。

2. 模拟误差决定了量化系统的上限

我们所讨论的模拟误差,是指由我们编写的软件量化模拟器,与硬件实际执行结果之间的误差。承接前文所述,为了进行正确的模拟,我们必须熟知硬件的量化规则,并严格地遵守每一项需要被遵守的约束,但量化规则并不是所有模拟误差的来源。模拟误差的存在,导致那些在软件模拟器中表现良好的网络,部署到硬件上时精度却无法保障,这是所有量化软件都必须严肃对待的问题。

对于一些平台而言,其量化算子的计算逻辑可能与浮点算子有本质上的区别,这使得我们无法正确地模拟量化算子的行为。并且软件量化系统的浮点误差也会影响我们对硬件的模拟。需要认识到,一个网络完成量化后,它的精度能达到多高,其上限是由模拟误差决定的。对于一个分类精度72%的浮点网络,若我们软件量化系统关于硬件的模拟误差是10%(relative error),我们就几乎不可能通过该系统得到一个分类精度超过70%的量化网络,无论我们使用怎样的算法。

为了解决模拟误差的问题,除了重新设计 Qscheme 外,还需要认识到量化算子与浮点算子之间可能存在的行为差异。量化模拟器必须拥有针对不同硬件,随时替换算子执行逻辑的能力,这也是前期软件设计中就应当考虑到的问题——FakeQuantize 对象只能表示对算子输入/输出的量化逻辑,没有能力描述算子执行逻辑的变化。

3. 网络调度、误差分析、复杂图变换

网络调度与混合精度执行

高可用的工具势必需要一种 "备选方案",对于量化而言,备选方案就是 fall back 回到浮点——如果一个网络层无法用 int8 完成表示,那我们只能将其回退到浮点。混合精度调度是工业部署迫切需要的一项能力,我建议能够在一开始就考虑到这样的需求,能够以一种最为稳定的方式提高模型部署时的精度——量化不是一锤子买卖,只是一种精度和速度的权衡。

我们的工具必须能够为用户提供调度的接口,允许用户显示地说明某一层的计算精度,软件中的其他逻辑也必须以混合精度为前提做出相应的改变。Qscheme 中提供了 num_of_bits 接口,但这远远不够。考虑下面的图结构:

Time series (1)

图左侧描述了 4 bit 卷积串联 8 bit 卷积的情况,我们想问的是 4bit 卷积的输出需要被量化几次,使用谁的量化参数完成量化?图右侧描述了混合精度 concat 的问题,我们想问的是 concat 要求所有输入输出共享 scale 的约束此时如何实现,这个concat 的逻辑如何完成执行?

上文中我们先后提到了量化规则与量化模拟的重要性,在混合精度量化中我们将更难满足量化规则;没有任何框架能够正确模拟混合精度推理,事实上推理库自己都不知道怎么处理这样的问题。量化工具是混合精度量化的先行者,我们在设计中需要考虑到如此复杂的问题,并提出合理的设计、留下余地为混合推理提供支持。目前的 Qscheme, FakeQuantize 是否具备支持混合精度的能力?

误差分析是网络量化的有力工具

"我的网络为什么量化效果这么差",这是我们经常要回答的问题。我们对此会有一些浅显的认识,譬如一个具有很多离群值的网络是不好量化的,depthwise 卷积是不好量化的,我们希望能够更好地解答类似的问题。

对于误差分析,我们有两类视角:一类从误差的统计学性质出发,可以证明量化的误差服从高斯分布,且量化噪声的能量大小取决于网络权重的方差。对于 depthwise 卷积而言,由于 batchnorm 的影响,他的权重方差显著高于普通卷积。基于这样的理论我们可以理解为什么 depthwise 卷积量化通常很差;另一类视角从数值分析角度出发,一个卷积是否能够完成量化,取决于它对于输入摄动的敏感性,这一敏感性由卷积条件数决定,卷积的条件数越大,权重及其逆卷积权重就越奇异,也就越不可能完成量化,这解释了为什么存在离群点的层不能量化。

我们需要提供这样的工具,展示网络中权重的基本统计信息,以及网络中误差的扩散情况。这将告诉我们如何解决量化中存在的问题,以及网络量化的极限——如果网络对输入摄动非常敏感,则它不可能完成量化,它的量化将是一个病态问题,难以求解。

复杂结构量化的基础是图变换

LSTM, Transformer 的量化基础是图变换,下图展示了如何 GRU 算子被分解成一系列更原子的算子:
微信图片_20221116164548

执行上述分解的原因是为了得到正确的量化信息,我们将与硬件团队详细讨论 GRU 的每一步计算使用何种精度,遵循怎样的量化约束。对于上图而言,我们可能会给出 gemm, add, sub, mul 4处量化信息,其中部分节点使用16比特位宽完成量化,并且要求 gemm 间共享 scale。

针对这样复杂的结构,直接使用 pytorch fx 完成的量化不具有部署能力,量化工具必须与硬件团队达成一致,并对计算图进行定制化的优化。这是目前的软件设计未曾考虑到的内容。量化工具应当具有强而有力的图模式匹配和图修改的能力,这是处理复杂结构时难以绕开的问题。

4. 开放的接口,灵活的量化

Pytorch Fx 的接口设计与目前大部分硬件厂商量化工具的接口设计目的是一致的,都寄希望于通过一种简单的接口完成高度自动化的量化。接口函数 prepare(self, model, graph_module) 将完成插入量化节点、图变换、模型追踪等一系列功能。这样的接口确实解决了硬件厂商所面临的一些问题,但是对于定制化的需求而言经常要做出更改。

进入2022年以来,AI 软件的工业化部署越来越受到重视,很多算法工程师会尝试去了解模型部署,并参与到模型优化的流程当中来。为此我们需要提供一组更为基础的,更为灵活的接口供算法工程师使用,从而完成高度定制化的模型量化需求。上文所述的 LSTM, Transformer 量化,其规则之复杂,模型结构之多变,几乎不可能通过统一的管线完成流水线式的量化——几乎每一个 transformer 模型都有着不太一样的 self-attention 结构,都意味着不同的量化规则,这样的复杂性难以全部依赖 prepare 接口函数提供实现。

为此我们可以做出的努力包括:

  • 将模型调度、图融合、模型追踪、量化优化过程、量化器、模型导出、observer全部解耦并在 API 中提供接口
  • 允许用户自定义量化器、优化过程、调度器、observer,模型导出器,提供注册接口
  • 提供一套基础量化接口,允许用户以手动的方式调用优化过程;允许用户手动控制 prepare 流程;
  • 提供误差分析接口,帮助用户定位量化问题
  • 多写注释,多写文档

5. 训练只是完成模型量化的一小类方法

看到目前的软件规划中包含的算法是 LSQ, LSQ+, Adaround, Brecq, Qdrop,我对这样的设计感到遗憾。我始终认为低精度计算的数学基础是去噪,而非重新训练一个新的网络。这些基于训练的方法,哪怕挂着PTQ的标签,我们都可以轻而易举的证明他们只不过是以原来的网络为初值,训练了一个新的适合量化的网络从而提高精度。感谢梯度下降算法的有效性,这类方法在目前能够获得不错的效果,但这方向难以出现有足够理论支撑的论文。

以 Adaround 的单层优化问题为例,该算法限制权重 x 在优化过程中只能取 ceil(x) 与 floor(x) 两个值。这导致它的优化问题是可以定性为0-1约束下的二次规划问题,通常不可能使用梯度下降法求到最值,这类问题大部分情况下是 np-hard 的问题。由于0-1约束是一道不可逾越的难关,因此我不认为基于梯度训练的方法能够在网络量化中有长远的发展。

我们应当思考如何消去网络中的量化噪声,从这一角度出发可以衍生出许多更加有理有据的方法,也可以获得更好的网络量化效果。包括但不限于:

  • 基于滤波的去噪方法
  • 基于 FFT 变换与小波变换的去噪
  • 基于低秩方法的去噪
  • 图的等价变换,包括层间权重变换,层的分裂与合并

信号去噪和数值计算方法有着深厚而广泛的理论基础,但现在我们提到量化算法,思路依然仅局限于基于训练的方法,这是十分令人遗憾的。但是我们要如何设计接口从而支持这更为广泛的方法?这是我们需要考虑的问题。

结语

Pytorch Fx 的接口只是单纯为训练服务的,从我的陈述中你可以发现它的局限性。事实上基于这套接口去实现 LSQ 等算法所需的代码不过寥寥几千行,但 AIMET, VitisAI, PPQ 等量化框架均有数万行甚至数十万行的体量,你可以看到这些训练算法绝不是这些量化系统的主要内容。

我们建议重新审视软件的整体设计方案:保证量化模拟的正确性;保证所有量化规则得到遵守;为用户提供高可用、高度灵活的API接口;与硬件团队一起协商复杂结构的量化方案;并尝试探索和支持其他量化方法。

@humu789
Copy link
Collaborator Author

humu789 commented Nov 18, 2022

@ZhangZhiPku 非常感谢这么用心的回复,从回复内容可以看出对量化有着专业的洞察与见地。关于mmrazor的量化也谈下我们背后的一些思考。
其实我们的定位是做量化算法平台,且更倾向于QAT这类训练算法,并不是全面且复杂的量化系统,原因有以下几点:

  1. PTQ量化工具已经有很多,主流的推理引擎很多都支持PTQ量化,并且一些硬件平台的专用推理引擎对硬件过程更了解,例如SNPE,这类工具在PTQ算法上有着天然的优势,如果效果掉点可接受用户其实会更倾向于选择自带的量化,毕竟更方便可靠。
  2. 从效果上限来看,QAT对于PTQ有着明显的优势
  3. MMRazor是OpenMMLab生态下的模型压缩库,会优先支持OpenMMLab生态的能力建设,QAT算法更容易和蒸馏等其他算法结合,可以为上游repo提供综合的模型压缩解决方案。
  4. 如果是做一个包含基础功能接口的全面的量化系统,按照repo职责划分,在mmrazor里并不好独立实现,很多实现是需要在mmengine和mmdeploy里去添加的。
  5. pytorch量化相关的功能接口在1.10以后每版都有较大的迭代改进,我们不全部重复造轮子也是看好pytorch量化的快速迭代进步,且对pytorch的量化用户更加友好。

关于回复中的5个重点,整体的回复如下:

  1. 符合规则的量化
    算法落地中,符合规则的量化确实很重要,这块我们是在quantizer中提供了抽象的上层接口prepare,对于不同平台的量化规则适配,可以通过register不同的qauntizer去实现的。
  2. 模拟误差
    关于模拟硬件误差方面,由于缺少硬件团队的支撑,对硬件实际执行过程缺乏了解,后续希望能通过内部和社区交流完善这方面的能力建设。
  3. 网络调度、误差分析、复杂图变换
    混合精度和误差分析工具确实都很重要,目前由于人力和量化的定位会优先保证对常规功能的支持,后续会根据优先级逐步支持一些进阶的功能,也希望能和更多像你一样资深的用户来共建。
    关于复杂图变换,在mmrazor中pruning的tracer是可以获取网络节点的相关连接关系的,根据连接关系和变换规则应该是可以在prepare阶段实现复杂图变换的功能的
  4. 开放的接口,灵活的量化
  • 我们是共用mmengine的注册机制,目前已有的模块全都是可以注册,甚至包括自定义的tracer,并且在OpenMMLab内的repo之间可以跨库调用的
  • 在quantizer中提供了高层抽象的prepare,convert等接口,是可以满足用户手动控制流程的需求的
  • 由于人力原因,目前优先保证一些基础功能,一些进阶的功能是计划在后续逐步添加完善的
  • 目前开放的代码分支是wip,在完成开发后我们会按照OpenMMLab的要求统一补充文档和注释的,也会出一些关于量化的系列教程。
  1. 训练只是完成模型量化的一小类方法
    关于低精度计算基础是去噪问题我们是基本认同的,但是我们不是特别认同以追求理论支撑为主要目标,更希望是以实际效果为目标。虽然目前基于梯度下降的深度学习算法可解释性依旧存在争议,但在众多领域是会比解释性更好的传统算法效果做到更好的。
    关于去噪问题,目前基于深度学习训练的方法已经能做到效果优于大多数传统的信号处理方法,而且考虑到去噪问题本身并没有一个能衡量真实效果的客观指标,因此哪种方法更好我们认为是存在争议的。相较于信号处理的方法,我们认为QAT量化算法也是一个值得探索的方向。最后也希望能借此吸引大家一起来共建,去支持更多方向的探索。

@choong-park
Copy link

Hi, I have a question about this quantization scheme for mmdet and mmseg.

There are only config files for mmcls, but no references for mmdet and mmseg.
How can I use this quantization scheme for them?

For example, there is "skipped_methods" field in config file,
and If I want to quantize mmdetection models, it should be modified.
But I don't know what methods in mmdet should be "skipped"...

Please give me any insights.
Thanks

@taofuyu
Copy link

taofuyu commented Dec 13, 2022

Thanks for your work

@FabianSchuetze
Copy link

Thanks for the wonderful work!

From the proposal I have not understood if the is pytorch-cuda support for the quantization. The backend lists the three examples, and I wonder if native pytorch-cuda backend is supported of if one would go through TensorRT?

humu789 pushed a commit to humu789/mmrazor that referenced this issue Feb 13, 2023
* move to lib

* optional import pytorch rewriter

* reduce torch dependancy of tensorrt export

* remove more mmcv support

* fix pytest

* remove mmcv logge

* Add `mmdeploy.utils.logging`

* Improve the common of the `get_logger`

* Fix lint

* onnxruntim add try catch to  import wrapper if pytorch is available

* Using `mmcv.utils.logging` in all files under `mmdeploy/codebase`

* add __init__

* add prebuild tools

* support windows

* for comment

* exit if failed

* add exist

* decouple

* add tags

* remove .mmdeploy_python

* read python version from system

* update windows config

* update linux config

* remote many

* better build name

* rename python tag

* fix pyhon-tag

* update window config

* add env search

* update tag

* fix build without CUDA_TOOLKIT_ROOT_DIR

Co-authored-by: HinGwenWoong <peterhuang0323@outlook.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
RFC Request for Comments
Projects
None yet
Development

No branches or pull requests

6 participants