AIMET and YOLOv5 #1067

Open · cascosula opened this issue Mar 1, 2022 · 21 comments

@cascosula

Hi, I am currently trying to quantize YOLOv5n (v6.0) with AIMET (v1.19.1), aiming to further deploy the quantized model to a device through the Snapdragon Neural Processing SDK (1.59).

I have followed the suggestions in #168 but failed to apply AIMET to YOLOv5.
An error occurred when I applied batch norm folding to YOLOv5:

model = torch.hub.load('ultralytics/yolov5', 'yolov5n', path=path)
sim = QuantizationSimModel(model, dummy_input,
                           quant_scheme=quant_params.quant_scheme,
                           rounding_mode=quant_params.round_mode,
                           default_output_bw=quant_params.act_bw,
                           default_param_bw=quant_params.weight_bw,
                           config_file=args.quantization_config,
                           in_place=True)
# add attributes 'f' and 'i' defined in yolo to the sim wrapper layers
add_yolo_attribute_to_quantized_layer(model)

folded_pairs = fold_all_batch_norms(model, input_shapes=input_shapes)

The error message is:

2022-03-01 16:03:35,919 - ConnectedGraph - ERROR - Input of 3 defined in (%3 : (Tensor) = ^SteGatingFuncForParameters(StaticGridQuantWrapper(
  (_module_to_wrap): MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False)
))(%1), scope: __module.model/__module.model.model.9/__module.model.model.9.m # /home/gvsai/anaconda3/envs/aimet-yolov5/lib/python3.7/site-packages/aimet_torch/qc_quantize_op.py:358:0
) with name None already exists. Ensure that no modules are being reused in the model.
Traceback (most recent call last):
  File "quantize_yolov5.py", line 587, in <module>
    main(args)
  File "quantize_yolov5.py", line 499, in main
    folded_pairs = fold_all_batch_norms(model, input_shapes=input_shapes)
  File "/home/gvsai/anaconda3/envs/aimet-yolov5/lib/python3.7/site-packages/aimet_torch/batch_norm_fold.py", line 240, in fold_all_batch_norms
    bn_conv_linear_pairs = find_all_batch_norms_to_fold(model, input_shapes)
  File "/home/gvsai/anaconda3/envs/aimet-yolov5/lib/python3.7/site-packages/aimet_torch/batch_norm_fold.py", line 199, in find_all_batch_norms_to_fold
    conv_linear_bn_activation_info_dict = find_all_conv_bn_with_activation(model, input_shapes)
  File "/home/gvsai/anaconda3/envs/aimet-yolov5/lib/python3.7/site-packages/aimet_torch/batch_norm_fold.py", line 278, in find_all_conv_bn_with_activation
    connected_graph = ConnectedGraph(model, inp_tensor_list)
  File "/home/gvsai/anaconda3/envs/aimet-yolov5/lib/python3.7/site-packages/aimet_torch/meta/connectedgraph.py", line 132, in __init__
    self._construct_graph(model, model_input)
  File "/home/gvsai/anaconda3/envs/aimet-yolov5/lib/python3.7/site-packages/aimet_torch/meta/connectedgraph.py", line 261, in _construct_graph
    output_map)
  File "/home/gvsai/anaconda3/envs/aimet-yolov5/lib/python3.7/site-packages/aimet_torch/meta/connectedgraph.py", line 542, in _construct_ops_and_products
    self._handle_ir_nodes_of_interest(ir_nodes_list, passthrough_types, input_types_to_ignore)
  File "/home/gvsai/anaconda3/envs/aimet-yolov5/lib/python3.7/site-packages/aimet_torch/meta/connectedgraph.py", line 915, in _handle_ir_nodes_of_interest
    connections_to_ir_nodes_dict = self._create_connections_to_ir_nodes_dict(ir_nodes_list)
  File "/home/gvsai/anaconda3/envs/aimet-yolov5/lib/python3.7/site-packages/aimet_torch/meta/connectedgraph.py", line 900, in _create_connections_to_ir_nodes_dict
    raise AssertionError
AssertionError

It seems that AIMET cannot be applied to models with reused layers. However, I could do batch norm folding on YOLOv5 with the SNPE tools and obtain a quantized model as a *.dlc.

I am wondering what the difference between AIMET and SNPE is here.
Is it possible to build an AIMET+SNPE quantization pipeline for YOLOv5?

@zhiqwang

zhiqwang commented Mar 2, 2022

Hi @cascosula, it seems this issue is caused by YOLOv5's SPPF module; maybe you could replace it with the equivalent SPP. See the following for additional context about the SPPF module.
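
For concreteness, a minimal sketch of such a swap (assuming a standard ultralytics/yolov5 checkout where models.common provides SPP and SPPF; the helper name swap_sppf_for_spp is mine):

# Hedged sketch: replace each SPPF with an equivalent SPP before running
# AIMET's graph passes. SPPF(k=5) matches SPP(k=(5, 9, 13)) because stacked
# 5x5 stride-1 max pools give 9x9 and 13x13 receptive fields, and both
# modules concatenate [x, p5, p9, p13] in the same order.
from models.common import SPP, SPPF  # ultralytics/yolov5 checkout assumed

def swap_sppf_for_spp(model):
    for parent in model.modules():
        for name, child in list(parent.named_children()):
            if isinstance(child, SPPF):
                c1 = child.cv1.conv.in_channels
                c2 = child.cv2.conv.out_channels
                spp = SPP(c1, c2, k=(5, 9, 13))
                spp.cv1, spp.cv2 = child.cv1, child.cv2  # reuse trained weights
                for attr in ('f', 'i', 'type', 'np'):    # YOLOv5 routing attrs
                    if hasattr(child, attr):
                        setattr(spp, attr, getattr(child, attr))
                setattr(parent, name, spp)
    return model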

@cascosula (Author)

cascosula commented Mar 2, 2022

> Hi @cascosula, it seems this issue is caused by YOLOv5's SPPF module; maybe you could replace it with the equivalent SPP. See the following for additional context about the SPPF module.

Thanks for your reply! This works for me!
AIMET works after replacing SPPF with SPP.

I have another question: if I need SPPF for the speedup mentioned in ultralytics/yolov5#4420,
would the following modification, which simply copies the max-pooling layer in the SPPF instead of reusing it,
behave identically to the original SPPF?

import copy

import torch

class SPPF_modified(torch.nn.Module):
    """SPPF with the shared MaxPool2d unrolled into three separate copies,
    so no module is reused across call sites."""

    def __init__(self, sppf):
        super().__init__()
        self.cv1 = sppf.cv1
        self.cv2 = sppf.cv2
        self.m1 = sppf.m                 # original pooling layer
        self.m2 = copy.deepcopy(sppf.m)  # independent copies instead of reuse
        self.m3 = copy.deepcopy(sppf.m)
        # YOLOv5 routing/bookkeeping attributes set by the model parser
        self.f = sppf.f
        self.i = sppf.i
        self.type = sppf.type
        self.np = sppf.np

    def forward(self, x):
        x = self.cv1(x)
        y1 = self.m1(x)
        y2 = self.m2(y1)
        return self.cv2(torch.cat([x, y1, y2, self.m3(y2)], 1))
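
For reference, a quick numerical check of this idea (a sketch: it stubs the parser-set attributes f/i/type/np with dummy values so the constructor above can copy them):

import copy
import torch
from models.common import SPPF  # ultralytics/yolov5 checkout assumed

sppf = SPPF(64, 64, k=5).eval()
# f/i/type/np are normally attached by YOLOv5's model parser; stub them here.
sppf.f, sppf.i, sppf.type, sppf.np = -1, 9, 'models.common.SPPF', 0
modified = SPPF_modified(sppf).eval()

x = torch.randn(1, 64, 32, 32)
with torch.no_grad():
    assert torch.allclose(sppf(x), modified(x), atol=1e-6)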

@zhiqwang

zhiqwang commented Mar 2, 2022

I guess we can use other techniques to speed up the inference of SPP.

@quic-ssiddego (Contributor)

@cascosula Thanks for reporting this, and @zhiqwang thanks for pitching in. @cascosula, just curious about the modified block - why was it needed?

@zhiqwang

zhiqwang commented Mar 3, 2022

Hi @quic-ssiddego and @cascosula,

This problem occurs because SPPF reuses the nn.MaxPool2d module. In fact, many downstream inference frameworks do not support this calling pattern well, such as nni and ncnn (ncnn later added an assign-unique pass to resolve this issue).

The main purpose of YOLOv5's newly proposed SPPF is to speed up inference (vs. SPP), but this speedup has only been tested on the PyTorch platform, and its internal operators may change further when the model is converted to other platforms via ONNX or TorchScript. We may need further data to judge whether this module can actually be accelerated on the target platform.

I think at this stage we can directly use the traditional SPP, and optimize the inference of this module with something like a pass in the inference framework, given the downstream incompatibilities.
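
To illustrate the pattern at issue, here is a hypothetical toy module: one pooling instance, three call sites, which is exactly what trips graph builders that assume a one-to-one mapping between modules and graph nodes:

import torch

class ReusedPool(torch.nn.Module):  # hypothetical illustration only
    def __init__(self):
        super().__init__()
        self.m = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2)

    def forward(self, x):
        y1 = self.m(x)        # one module object...
        y2 = self.m(y1)       # ...three call sites: graph builders that key
        return self.m(y2)     # outputs by module see duplicate producers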

@cascosula (Author)

cascosula commented Mar 3, 2022

> @cascosula Thanks for reporting this, and @zhiqwang thanks for pitching in. @cascosula, just curious about the modified block - why was it needed?

Hi @quic-ssiddego,
I need faster inference when deploying models to devices; my final goal is to run inference on a streaming device. So, intuitively, I assumed that keeping the SPPF structure might speed up inference on device, since SPPF is faster than SPP in Python testing.

@cascosula (Author)

> Hi @quic-ssiddego and @cascosula,
>
> This problem occurs because SPPF reuses the nn.MaxPool2d module. In fact, many downstream inference frameworks do not support this calling pattern well, such as nni and ncnn. […]

Hi @zhiqwang,
I agree with you. I think my top priority is to make my whole pipeline work. Applying AIMET to YOLOv5 caused several issues besides the reused-layer problem. Thanks for your suggestion!

@tucachmo2202

Hi, did you quantize YOLOv5 with SPP successfully? I tried to quantize YOLOv5 (v5) with QuantizationSimModel, but when calling sim.compute_encodings() it throws: "AttributeError: 'StaticGridQuantWrapper' object has no attribute 'f'". Could you help me fix it? Thanks.

@cascosula (Author)

cascosula commented May 11, 2022

> Hi, did you quantize YOLOv5 with SPP successfully? I tried to quantize YOLOv5 (v5) with QuantizationSimModel, but when calling sim.compute_encodings() it throws: "AttributeError: 'StaticGridQuantWrapper' object has no attribute 'f'". Could you help me fix it? Thanks.

Hi @tucachmo2202,

In YOLOv5's forward method, the model relies on the attributes "f" and "i" assigned to each module to route its dataflow.
QuantSim adds wrapper modules around YOLOv5's layers, but these two attributes are not copied to the wrappers.
When YOLOv5 runs forward and tries to access "f" and "i" on a wrapper module, this error is raised.

After QuantSim has added the wrapper modules, you can copy these attributes onto the wrappers externally, before calling sim.compute_encodings:

    from aimet_torch.qc_quantize_op import QcQuantizeWrapper, QcQuantizeRecurrent

    # copy YOLOv5's routing attributes from the wrapped module to the wrapper
    for name, module in model.named_modules():
        if isinstance(module, (QcQuantizeWrapper, QcQuantizeRecurrent)):
            wrapped_module = module._module_to_wrap
            if hasattr(wrapped_module, 'i'):
                module.i = wrapped_module.i
            if hasattr(wrapped_module, 'f'):
                module.f = wrapped_module.f

Alternatively, you can modify the initialization of the wrapper module in qc_quantize_op.py so that it copies these attributes from module_to_wrap to the wrapper.
Remember to copy "i" as well; it is also used in YOLOv5's forward method.

The same error occurs when exporting the ONNX model.
In that case, however, you can only modify the initialization of the wrapper module to avoid the problem:
modify CustomMarker in onnx_utils.py.
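
A sketch of that second option; I am assuming it goes at the end of the wrapper's __init__ in aimet_torch/qc_quantize_op.py, where module_to_wrap is the constructor argument:

# Inside the wrapper's __init__, after self._module_to_wrap is set:
for attr in ('f', 'i'):  # YOLOv5 routing attributes set by its model parser
    if hasattr(module_to_wrap, attr):
        setattr(self, attr, getattr(module_to_wrap, attr))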

@tucachmo2202

> Hi @tucachmo2202,
>
> In YOLOv5's forward method, the model relies on the attributes "f" and "i" assigned to each module to route its dataflow. QuantSim adds wrapper modules around YOLOv5's layers, but these two attributes are not copied to the wrappers. […]

Hi, I appreciate the quick reply. However, after adding 'i' and 'f' to the wrapper modules, I hit another error:

2022-05-11 21:36:46,119 - Quant - ERROR - Expecting quantize activation input of type torch.Tensor but got <class 'list'>
Traceback (most recent call last):
  File "yolov5_quantization.py", line 305, in <module>
    sim.compute_encodings(forward_pass_callback=pass_calibration_data, forward_pass_callback_args=500)
  File "/usr/local/lib/python3.6/dist-packages/aimet_torch/quantsim.py", line 255, in compute_encodings
    _ = forward_pass_callback(self.model, forward_pass_callback_args)
  File "yolov5_quantization.py", line 265, in pass_calibration_data
    sim_model(inputs_batch)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/aimet/traffic_pytorch/models/yolo.py", line 124, in forward
    return self.forward_once(x, profile, visualize)  # single-scale inference, train
  File "/home/aimet/traffic_pytorch/models/yolo.py", line 156, in forward_once
    x = m(x)  # run
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/aimet_torch/qc_quantize_op.py", line 347, in forward
    quantized_inputs = self._quantize_activation(self.input_quantizers, inputs)
  File "/usr/local/lib/python3.6/dist-packages/aimet_torch/qc_quantize_op.py", line 447, in _quantize_activation
    raise AssertionError
AssertionError

Do you know how to fix it? Thanks for helping me!

@cascosula (Author)

> Hi, I appreciate the quick reply. However, after adding 'i' and 'f' to the wrapper modules, I hit another error:
>
> 2022-05-11 21:36:46,119 - Quant - ERROR - Expecting quantize activation input of type torch.Tensor but got <class 'list'>
> […]
> Do you know how to fix it? Thanks for helping me!

Hi @tucachmo2202,

This problem is caused by the Concat module in YOLOv5.
The Concat module takes a list of tensors as input, but a valid input for a PyTorch module under AIMET is a single tensor.
In fact, I do not have an effective solution for this problem.
I modified the source code in qc_quantize_op.py to skip quantization of the Concat layer.
I hope this suggestion helps you.

@tucachmo2202

tucachmo2202 commented May 11, 2022

Hi, thanks for your guidance. But could you provide the code for skipping quantization of the Concat layer? I really don't know how to do it. And could you share the accuracy you got with AIMET's quantization techniques (CLE, AdaRound, or QAT) on YOLOv5?
Best regards

@cascosula (Author)

cascosula commented May 12, 2022

> Hi, thanks for your guidance. But could you provide the code for skipping quantization of the Concat layer? I really don't know how to do it. […]

Hi @tucachmo2202,

You could modify the forward method of StaticGridQuantWrapper in qc_quantize_op.py:

    def forward(self, *inputs):
        """
        Forward-pass routine. This quantizes the weights before delegating to the wrapped module and
        then quantizes the output before returning the same
        :param inputs: Inputs passed to the module in the forward pass
        :return: Quantized output from the wrapped module
        """
        # guard added at the top of the original method: pass Concat through
        # unquantized, since its list input breaks _quantize_activation
        if 'Concat' in self._module_to_wrap._get_name():
            return self._module_to_wrap(*inputs)
        # ... the rest of the original forward body continues unchanged ...
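
If you would rather not edit the installed package, the same skip can be monkey-patched from your own script. A sketch against the StaticGridQuantWrapper class used in this thread:

from aimet_torch.qc_quantize_op import StaticGridQuantWrapper

_original_forward = StaticGridQuantWrapper.forward

def _forward_skip_concat(self, *inputs):
    # pass YOLOv5's Concat through unquantized; quantize everything else
    if 'Concat' in self._module_to_wrap._get_name():
        return self._module_to_wrap(*inputs)
    return _original_forward(self, *inputs)

StaticGridQuantWrapper.forward = _forward_skip_concat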

@tucachmo2202

> In YOLOv5's forward method, the model relies on the attributes "f" and "i" assigned to each module to route its dataflow. QuantSim adds wrapper modules around YOLOv5's layers, but these two attributes are not copied to the wrappers. […]

Hi, I followed your guidance and successfully quantized with the QuantSim technique. However, when exporting to ONNX, I don't know how to add "i" and "f" to the wrapper module in onnx_utils.py. Could you help me get past this problem? Thanks very much!

@cascosula (Author)

cascosula commented May 12, 2022

> You can only modify the initialization of the wrapper module to avoid this problem:
> modify CustomMarker in onnx_utils.py.

Hi @tucachmo2202,

This problem is similar: QuantSim adds wrapper modules, but the wrapper modules do not have these attributes.
In this case, it is impossible to add the attributes before exporting to ONNX,
since the wrapper modules are added and removed inside the export method.
You can only modify the initialization of the wrapper module to solve this problem.
Find CustomMarker in the AIMET source code (onnx_utils.py) and add these attributes to the wrapper modules.

@tucachmo2202

> Hi @tucachmo2202,
>
> This problem is similar: QuantSim adds wrapper modules, but the wrapper modules do not have these attributes. In this case, it is impossible to add the attributes before exporting to ONNX, since the wrapper modules are added and removed inside the export method. […]

Hi, it's embarrassing to bother you again, but I don't know how to modify the initialization of the wrapper module. It would be very kind of you to provide the code for that. Thank you very much!

@manhmox

manhmox commented May 15, 2022

Hi @cascosula, I'm stuck converting to ONNX too. Would you mind sharing the code to modify CustomMarker in onnx_utils.py? Thank you very much!

@cascosula (Author)

Hi,

You can find a classmethod of the OnnxSaver class;
CustomMarker is defined inside this classmethod. The attribute-copying loop at the end of its __init__ below is the added modification:

    @classmethod
    def _add_markers(cls, starting_module, module_name_map):
        """Recursively add marker layers
        """

        class CustomMarkerFunc(torch.autograd.Function):
            """
            This function helps add a custom layer when exporting to ONNX
            Note the input tensor has a trivial operation performed on it (clamp). This is needed to force
            pytorch trace to not ignore the function.
            """

            @staticmethod
            def symbolic(g, inp, identifier, start):
                """
                Magic method that helps with exporting a custom ONNX node
                """
                return g.op('CustomMarker', inp, id_s=identifier, start_s=start)

            @staticmethod
            def forward(ctx, inp, _identifier, _start):     # pylint: disable=arguments-differ
                return inp.clamp(0)

            @staticmethod
            def backward(ctx, _grad):                       # pylint: disable=arguments-differ
                raise NotImplementedError()

        class CustomMarker(torch.nn.Module):
            """
            This is a temporary layer that is inserted next to a real layer to distinguish the real layer in the
            exported ONNX format
            """

            def __init__(self, module, identifier):
                super(CustomMarker, self).__init__()
                self.marked_module = module
                self.identifier = identifier
                # added: copy the wrapped module's attributes (e.g. YOLOv5's
                # 'f' and 'i') onto the marker so the model's forward still works
                for attr in dir(module):
                    if hasattr(module, attr) and not hasattr(self, attr):
                        setattr(self, attr, getattr(module, attr))

@tucachmo2202

> Hi,
>
> You can find a classmethod of the OnnxSaver class; CustomMarker is defined inside this classmethod. […]

I exported the model to ONNX successfully. Thank you very much!

@Sanidhya27

@cascosula Is support for nodes with list inputs available? I have a custom node that takes a list of tensors. I also can't simply skip quantizing such nodes: I have an issue with a Concat node whose input tensors have different ranges, so I need channel-wise quantization in that case.

@Dong-tranction

> In YOLOv5's forward method, the model relies on the attributes "f" and "i" assigned to each module to route its dataflow. QuantSim adds wrapper modules around YOLOv5's layers, but these two attributes are not copied to the wrappers. When YOLOv5 runs forward and tries to access "f" and "i" on a wrapper module, this error is raised.

> Hi, I followed your guidance and successfully quantized with the QuantSim technique. However, when exporting to ONNX, I don't know how to add "i" and "f" to the wrapper module in onnx_utils.py. Could you help me solve this problem? Thanks very much!

Hi, have you successfully run QAT on YOLOv5? My accuracy never improves during quantization-aware training. Could I take a look at your quantization code?
