Fix intermittent qsim2 test failures
Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com>
quic-mtuttle committed Jul 12, 2024
1 parent df67bc7 commit 00dd8f9
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions TrainingExtensions/torch/test/python/v2/quantsim/test_quantsim.py
@@ -39,9 +39,11 @@
 import os
 import json
 import pytest
+import random
+import numpy as np
 from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
 from aimet_common.defs import QuantizationDataType
-from aimet_torch.quantsim import load_encodings_to_sim
+from aimet_torch.quantsim import load_encodings_to_sim, QuantScheme
 from aimet_torch.v2.quantsim import QuantizationSimModel
 from aimet_torch.v2.quantization.encoding_analyzer import PercentileEncodingAnalyzer
 from aimet_torch.v2.quantization.base import QuantizerBase
@@ -59,6 +61,12 @@ def encodings_are_close(quantizer_1: AffineQuantizerBase, quantizer_2: AffineQuantizerBase):
            and quantizer_1.bitwidth == quantizer_2.bitwidth \
            and quantizer_1.symmetric == quantizer_2.symmetric
 
+@pytest.fixture(autouse=True)
+def set_seed():
+    random.seed(0)
+    torch.manual_seed(0)
+    np.random.seed(0)
+
 
 class ConcatModel(torch.nn.Module):
 
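For context (not part of this commit): because the fixture above is marked autouse=True, pytest runs it before every test in the module, so each test starts from the same RNG state and randomly generated inputs no longer vary from run to run. A minimal, self-contained sketch of the same pattern; the test name and assertion are illustrative:

    import random

    import numpy as np
    import pytest
    import torch


    @pytest.fixture(autouse=True)
    def set_seed():
        # Reseed every RNG source the tests draw from, before each test.
        random.seed(0)
        torch.manual_seed(0)
        np.random.seed(0)


    def test_first_draw_is_deterministic():
        # The fixture seeded torch with 0, so the first draw in this test
        # matches a fresh draw taken right after reseeding with the same value.
        first = torch.rand(3)
        torch.manual_seed(0)
        assert torch.equal(first, torch.rand(3))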
@@ -103,7 +111,7 @@ def forward_pass(model, args):
     def test_set_and_freeze_param_encodings(self, config_file):
         model = test_models.BasicConv2d(kernel_size=3)
         dummy_input = torch.rand(1, 64, 16, 16)
-        sim = QuantizationSimModel(model, dummy_input, config_file=config_file)
+        sim = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf, config_file=config_file)
         sim.compute_encodings(lambda model, _: model(dummy_input), None)
 
         with tempfile.TemporaryDirectory() as temp_dir:
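This hunk and the two that follow make the same change: quant_scheme is passed explicitly rather than inherited from the constructor default. A rough sketch of the resulting pattern, assuming only the calls visible in this diff (the toy model is illustrative):

    import torch
    from aimet_torch.quantsim import QuantScheme
    from aimet_torch.v2.quantsim import QuantizationSimModel

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3), torch.nn.ReLU())
    dummy_input = torch.rand(1, 3, 32, 32)  # reproducible under the set_seed fixture

    # Pinning quant_scheme keeps the expected encodings stable even if the
    # constructor's default scheme changes in a future release.
    sim = QuantizationSimModel(model, dummy_input,
                               quant_scheme=QuantScheme.post_training_tf)
    sim.compute_encodings(lambda m, _: m(dummy_input), None)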
@@ -149,7 +157,7 @@ def test_set_and_freeze_param_encodings(self, config_file):
     def test_load_and_freeze_encodings(self, config_file):
         model = test_models.TinyModel()
         dummy_input = torch.rand(1, 3, 32, 32)
-        sim = QuantizationSimModel(model, dummy_input, config_file=config_file)
+        sim = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf, config_file=config_file)
         sim.compute_encodings(lambda model, _: model(dummy_input), None)
 
         with tempfile.TemporaryDirectory() as temp_dir:
@@ -208,7 +216,7 @@ def test_load_and_freeze_with_partial_encodings(self):
             "param_encodings": {"conv1.weight": [sample_encoding]}
         }
 
-        sim = QuantizationSimModel(model, dummy_input)
+        sim = QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf)
         all_quantizers = [q for q in sim.model.modules() if isinstance(q, QuantizerBase)]
         sim.load_and_freeze_encodings(partial_encodings)
 
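As an aside on why pinning QuantScheme.post_training_tf stabilizes these tests: the tf scheme derives its quantization grid from the observed min/max alone, so once inputs are seeded the computed scale and offset are identical on every run. A rough illustration of that idea; this helper is a simplified stand-in, not AIMET's implementation:

    import torch

    def minmax_scale_offset(x: torch.Tensor, bitwidth: int = 8):
        # Typical asymmetric min/max encoding: stretch the grid to cover the
        # observed range (including zero) and quantize it uniformly.
        qmax = 2 ** bitwidth - 1
        x_min = min(x.min().item(), 0.0)
        x_max = max(x.max().item(), 0.0)
        scale = (x_max - x_min) / qmax
        offset = round(x_min / scale)
        return scale, offset

    torch.manual_seed(0)
    print(minmax_scale_offset(torch.randn(1000)))  # same numbers on every run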
