Skip to content

Commit

Permalink
Remove stale code (#3334)
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
  • Loading branch information
quic-kyunggeu authored Sep 10, 2024
1 parent 20e3287 commit ab2f42b
Showing 1 changed file with 2 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import itertools
from typing import List, Optional, Tuple
import torch
from torch import nn

from aimet_common.defs import QuantScheme, QuantizationDataType, MAP_ROUND_MODE_TO_PYMO
from aimet_common.utils import AimetLogger, log_with_error_and_assert_if_false
Expand Down Expand Up @@ -205,41 +204,18 @@ def realize_v2_wrapper(self):
else:
quantized_module = _legacy_impl.FakeQuantizationMixin.from_module(self._module_to_wrap)

def set_recursive(module_list, i, quantizer):
"""
Set quantizer recursively.
AIMET V1 handles nested input/output tensors with single input quantizers,
whereas V2 quantized module allows having nested input/output quantizers.
(For reference, see the class definition of FakeQuantizedLSTM in fake_quant.py)
To implement V1 behavior, we set the nested input/output quantizers to
share the same single quantizer, for example as below:
- self.input_quantizers = [q1, q2, q3]
- quant_module.input_quantizers = [None, [None, None], None]
(before set_recursive)
- quant_module.input_quantizers = [q1, [q2, q2], q3]
(after set_recursive)
"""
if module_list[i] is None:
module_list[i] = quantizer
elif isinstance(module_list[i], nn.ModuleList):
for j in range(len(module_list[i])):
set_recursive(module_list[i], j, quantizer)
else:
raise RuntimeError

# For unused modules, quantsim assumes # inputs = # outputs = 1
# If this is incorrect, propagate the configuration of the last input/output quantizers to the remaining
# quantizer positions
for i, _ in list(enumerate(quantized_module.input_quantizers)):
q_idx = min(i, len(self.input_quantizers) - 1)
quantizer = self.input_quantizers[q_idx].realize()
set_recursive(quantized_module.input_quantizers, i, quantizer)
quantized_module.input_quantizers[i] = quantizer

for i, _ in list(enumerate(quantized_module.output_quantizers)):
q_idx = min(i, len(self.input_quantizers) - 1)
quantizer = self.output_quantizers[q_idx].realize()
set_recursive(quantized_module.output_quantizers, i, quantizer)
quantized_module.output_quantizers[i] = quantizer

for param_name, quant_builder in self.param_quantizers.items():
quantized_module.param_quantizers[param_name] = quant_builder.realize()
Expand Down

0 comments on commit ab2f42b

Please sign in to comment.