Skip to content

Commit

Permalink
Fix aimet_torch cle acceptance test (#3161)
Browse files Browse the repository at this point in the history
Signed-off-by: Hitarth Mehta <quic_hitameht@quicinc.com>
  • Loading branch information
quic-hitameht committed Jul 11, 2024
1 parent 51780c8 commit c257278
Showing 1 changed file with 61 additions and 46 deletions.
107 changes: 61 additions & 46 deletions NightlyTests/torch/test_cross_layer_equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2017-2021, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2017-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
Expand Down Expand Up @@ -35,74 +35,92 @@
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Cross Layer Equalization acceptance tests for ResNet model. """

import os
import unittest
import tempfile
import pytest
import copy
import torch
import numpy as np
from contextlib import contextmanager
from torchvision import models

from aimet_torch import cross_layer_equalization as cle
from aimet_torch import batch_norm_fold
from aimet_torch.cross_layer_equalization import CrossLayerScaling, HighBiasFold, equalize_model
from aimet_torch import visualize_model
from models.mobilenet import MobileNetV2


class TestCrossLayerEqualization(unittest.TestCase):
@contextmanager
def _use_python_impl(flag: bool):
orig_flag = cle.USE_PYTHON_IMPL
try:
cle.USE_PYTHON_IMPL = flag
yield
finally:
cle.USE_PYTHON_IMPL = orig_flag


@pytest.fixture(params=[True, False])
def use_python_impl(request):
param: bool = request.param

with _use_python_impl(param):
yield


class TestCrossLayerEqualization:
""" Acceptance tests related to winnowing ResNet models. """

def test_cross_layer_equalization_resnet(self):
def test_cross_layer_equalization_resnet(self, use_python_impl):

torch.manual_seed(10)
model = models.resnet18(pretrained=True)

model = model.eval()
model = models.resnet18().eval()

folded_pairs = batch_norm_fold.fold_all_batch_norms(model, (1, 3, 224, 224))
bn_dict = {}
for conv_bn in folded_pairs:
bn_dict[conv_bn[0]] = conv_bn[1]

self.assertFalse(isinstance(model.layer2[0].bn1, torch.nn.BatchNorm2d))
assert not isinstance(model.layer2[0].bn1, torch.nn.BatchNorm2d)

w1 = model.layer1[0].conv1.weight.detach().numpy()
w2 = model.layer1[0].conv2.weight.detach().numpy()
w3 = model.layer1[1].conv1.weight.detach().numpy()
w1 = model.layer1[0].conv1.weight.clone()
w2 = model.layer1[0].conv2.weight.clone()
w3 = model.layer1[1].conv1.weight.clone()

cls_set_info_list = CrossLayerScaling.scale_model(model, (1, 3, 224, 224))

# check if weights are updating
assert not np.allclose(model.layer1[0].conv1.weight.detach().numpy(), w1)
assert not np.allclose(model.layer1[0].conv2.weight.detach().numpy(), w2)
assert not np.allclose(model.layer1[1].conv1.weight.detach().numpy(), w3)
assert not torch.allclose(model.layer1[0].conv1.weight, w1)
assert not torch.allclose(model.layer1[0].conv2.weight, w2)
assert not torch.allclose(model.layer1[1].conv1.weight, w3)

b1 = model.layer1[0].conv1.bias.data
b2 = model.layer1[1].conv2.bias.data
b1 = model.layer1[0].conv1.bias.clone()
b2 = model.layer1[1].conv2.bias.clone()

HighBiasFold.bias_fold(cls_set_info_list, bn_dict)

for i in range(len(model.layer1[0].conv1.bias.data)):
self.assertTrue(model.layer1[0].conv1.bias.data[i] <= b1[i])
for i in range(len(model.layer1[0].conv1.bias)):
assert model.layer1[0].conv1.bias[i] <= b1[i]

for i in range(len(model.layer1[1].conv2.bias.data)):
self.assertTrue(model.layer1[1].conv2.bias.data[i] <= b2[i])
for i in range(len(model.layer1[1].conv2.bias)):
assert model.layer1[1].conv2.bias[i] <= b2[i]

def test_cross_layer_equalization_mobilenet_v2(self):
def test_cross_layer_equalization_mobilenet_v2(self, use_python_impl):
torch.manual_seed(10)

model = MobileNetV2().to(torch.device('cpu'))
print(model)

model = model.eval()
equalize_model(model, (1, 3, 224, 224))

def test_cross_layer_equalization_vgg(self):
def test_cross_layer_equalization_vgg(self, use_python_impl):
torch.manual_seed(10)
model = models.vgg16().to(torch.device('cpu'))
model = model.eval()
equalize_model(model, (1, 3, 224, 224))

@unittest.skip("Takes 1 min 42 secs to run")
def test_cross_layer_equalization_mobilenet_v2_visualize_after_optimization(self):
@pytest.mark.skip("Takes 1 min 42 secs to run")
def test_cross_layer_equalization_mobilenet_v2_visualize_after_optimization(self, use_python_impl):
torch.manual_seed(10)
model = MobileNetV2().to(torch.device('cpu'))
model = model.eval()
Expand All @@ -116,22 +134,19 @@ def test_cross_layer_equalization_mobilenet_v2_visualize_after_optimization(self
equalize_model(model, (1, 3, 224, 224))
visualize_model.visualize_changes_after_optimization(model_copy, model, results_dir)

def test_cross_layer_equalization_resnet18_visualize_to_identify_problem_layers(self):
def test_cross_layer_equalization_resnet18_visualize_to_identify_problem_layers(self, use_python_impl):
torch.manual_seed(10)
model = models.resnet18()
model = model.eval()
model = models.resnet18().eval()

results_dir = 'artifacts'
if not os.path.exists('artifacts'):
os.makedirs('artifacts')
file = os.path.join(results_dir, 'visualize_relative_weight_ranges_to_identify_problematic_layers.html')
with tempfile.TemporaryDirectory() as tmp_dir:
file = os.path.join(tmp_dir, 'visualize_relative_weight_ranges_to_identify_problematic_layers.html')

batch_norm_fold.fold_all_batch_norms(model, (1, 3, 224, 224))
batch_norm_fold.fold_all_batch_norms(model, (1, 3, 224, 224))

visualize_model.visualize_relative_weight_ranges_to_identify_problematic_layers(model, results_dir)
self.assertTrue(os.path.isfile(file))
visualize_model.visualize_relative_weight_ranges_to_identify_problematic_layers(model, tmp_dir)
assert os.path.isfile(file)

def test_cle_transposed_conv2D(self):
def test_cle_transposed_conv2D(self, use_python_impl):
class TransposedConvModel(torch.nn.Module):
def __init__(self):
super(TransposedConvModel, self).__init__()
Expand Down Expand Up @@ -169,13 +184,13 @@ def forward(self, x):
cls_set_info_list = CrossLayerScaling.scale_model(model, input_shapes)
HighBiasFold.bias_fold(cls_set_info_list, bn_dict)

self.assertEqual(w_shape_1, model.conv1.weight.shape)
self.assertEqual(w_shape_2, model.conv2.weight.shape)
assert w_shape_1 == model.conv1.weight.shape
assert w_shape_2 == model.conv2.weight.shape

output_after_cle = model(random_input).detach().numpy()
self.assertTrue(np.allclose(output_before_cle, output_after_cle, rtol=1.e-2))
assert np.allclose(output_before_cle, output_after_cle, rtol=1.e-2)

def test_cle_depthwise_transposed_conv2D(self):
def test_cle_depthwise_transposed_conv2D(self, use_python_impl):

class TransposedConvModel(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -220,13 +235,13 @@ def forward(self, x):
cls_set_info_list = CrossLayerScaling.scale_model(model, input_shapes)
HighBiasFold.bias_fold(cls_set_info_list, bn_dict)

self.assertEqual(w_shape_1, model.conv1.weight.shape)
self.assertEqual(w_shape_2, model.conv2.weight.shape)
assert w_shape_1 == model.conv1.weight.shape
assert w_shape_2 == model.conv2.weight.shape

output_after_cle = model(random_input).detach().numpy()
self.assertTrue(np.allclose(output_before_cle, output_after_cle, rtol=1.e-2))
assert np.allclose(output_before_cle, output_after_cle, rtol=1.e-2)

def test_cle_for_maskrcnn(self):
def test_cle_for_maskrcnn(self, use_python_impl):
class JITTraceableWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
Expand Down

0 comments on commit c257278

Please sign in to comment.