Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
liukai committed Feb 6, 2023
1 parent 92b40cf commit 8846a0e
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 0 deletions.
63 changes: 63 additions & 0 deletions tests/test_models/test_algorithms/test_prune_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from mmrazor.models.algorithms.pruning.ite_prune_algorithm import (
ItePruneAlgorithm, ItePruneConfigManager)
from mmrazor.registry import MODELS
from projects.group_fisher.modules.group_fisher_algorthm import \
GroupFisherAlgorithm
from projects.group_fisher.modules.group_fisher_ops import GroupFisherConv2d
from ...utils.set_dist_env import SetDistEnv


Expand Down Expand Up @@ -262,3 +265,63 @@ def test_resume(self):
print(algorithm2.mutator.current_choices)
self.assertDictEqual(algorithm.mutator.current_choices,
algorithm2.mutator.current_choices)


class TestGroupFisherPruneAlgorithm(TestItePruneAlgorithm):
    """Integration test for ``GroupFisherAlgorithm``.

    Runs a short two-epoch pruning loop on fake CIFAR data and checks that
    the algorithm records its configuration and writes (then allows cleanup
    of) the delta-threshold checkpoint and the fake config file.
    """

    def test_group_fisher_prune(self):
        """Exercise the full group-fisher pruning loop end to end."""
        data = self.fake_cifar_data()

        # Local config, so snake_case rather than constant-style naming.
        mutator_cfg = dict(
            type='GroupFisherChannelMutator',
            parse_cfg=dict(type='ChannelAnalyzer', tracer_type='FxTracer'),
            channel_unit_cfg=dict(type='GroupFisherChannelUnit'))

        epoch = 2
        interval = 1
        delta = 'flops'

        algorithm = GroupFisherAlgorithm(
            MODEL_CFG,
            pruning=True,
            mutator=mutator_cfg,
            delta=delta,
            interval=interval,
            save_ckpt_delta_thr=[1.1]).to(DEVICE)
        mutator = algorithm.mutator

        # Checkpoint name the algorithm is expected to produce once the
        # delta threshold is crossed — TODO confirm the 0.99 suffix against
        # GroupFisherAlgorithm's save logic.
        ckpt_path = os.path.dirname(__file__) + f'/{delta}_0.99.pth'

        fake_cfg_path = os.path.dirname(__file__) + '/cfg.py'
        self.gen_fake_cfg(fake_cfg_path)
        self.assertTrue(os.path.exists(fake_cfg_path))

        message_hub = MessageHub.get_current_instance()
        # Use a context manager so the handle is closed deterministically
        # (the original `open(...).read()` leaked the file object), and
        # match the encoding used by gen_fake_cfg when writing.
        with open(fake_cfg_path, encoding='utf-8') as cfg_file:
            cfg_str = cfg_file.read()
        message_hub.update_info('cfg', cfg_str)

        for e in range(epoch):
            for ite in range(10):
                self._set_epoch_ite(e, ite, epoch)
                algorithm.forward(
                    data['inputs'], data['data_samples'], mode='loss')
                self.gen_fake_grad(mutator)
        self.assertEqual(delta, algorithm.delta)
        self.assertEqual(interval, algorithm.interval)
        self.assertTrue(os.path.exists(ckpt_path))
        os.remove(ckpt_path)
        os.remove(fake_cfg_path)
        self.assertTrue(not os.path.exists(ckpt_path))
        self.assertTrue(not os.path.exists(fake_cfg_path))

    def gen_fake_grad(self, mutator):
        """Copy each recorded input onto the gradient slot of every
        ``GroupFisherConv2d`` so fisher statistics can be computed without
        a real backward pass."""
        for unit in mutator.mutable_units:
            for channel in unit.input_related:
                module = channel.module
                if isinstance(module, GroupFisherConv2d):
                    module.recorded_grad = module.recorded_input

    def gen_fake_cfg(self, fake_cfg_path):
        """Write a minimal config file containing only ``work_dir``.

        Uses write mode ('w') rather than append so a stale cfg.py left by
        a previously interrupted run cannot accumulate duplicate lines.
        """
        with open(fake_cfg_path, 'w', encoding='utf-8') as cfg:
            cfg.write(f'work_dir = \'{os.path.dirname(__file__)}\'')
            cfg.write('\n')
1 change: 1 addition & 0 deletions tests/test_projects/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
1 change: 1 addition & 0 deletions tests/test_projects/test_expand/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
57 changes: 57 additions & 0 deletions tests/test_projects/test_expand/test_expand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import torch

from mmrazor.models.mutables import SimpleMutableChannel
from mmrazor.models.mutators import ChannelMutator
from projects.cores.expandable_ops.ops import ExpandLinear
from projects.cores.expandable_ops.unit import (ExpandableUnit, expand_model,
expand_static_model)
from ...data.models import MultiConcatModel, SingleLineModel


class TestExpand(unittest.TestCase):
    """Tests for the expandable-ops project: expanding channel units in a
    model and expanding individual dynamic layers."""

    def test_expand(self):
        """Expanding every mutable unit by 10 channels (zero-initialized)
        must leave the model's output unchanged."""
        inputs = torch.rand([1, 3, 224, 224])
        net = MultiConcatModel()
        print(net)
        mutator = ChannelMutator[ExpandableUnit](
            channel_unit_cfg=ExpandableUnit)
        mutator.prepare_from_supernet(net)
        print(mutator.choice_template)
        print(net)
        out_before = net(inputs)

        for unit in mutator.mutable_units:
            unit.expand(10)
            print(unit.mutable_channel.mask.shape)
        expand_model(net, zero=True)
        print(net)
        out_after = net(inputs)
        max_diff = (out_before - out_after).abs().max()
        self.assertTrue(max_diff < 1e-3)

    def test_expand_static_model(self):
        """Statically expanding a plain model to a channel divisor of 4
        must preserve its output."""
        inputs = torch.rand([1, 3, 224, 224])
        net = SingleLineModel()
        out_before = net(inputs)
        expand_static_model(net, divisor=4)
        out_after = net(inputs)
        print(out_before.reshape([-1])[:5])
        print(out_after.reshape([-1])[:5])
        max_diff = (out_before - out_after).abs().max()
        self.assertTrue(max_diff < 1e-3)

    def test_ExpandConv2d(self):
        """Expand a single ExpandLinear according to its channel masks."""
        fc = ExpandLinear(3, 3)
        in_mutable = SimpleMutableChannel(3)
        out_mutable = SimpleMutableChannel(3)
        fc.register_mutable_attr('in_channels', in_mutable)
        fc.register_mutable_attr('out_channels', out_mutable)

        print(fc.weight)

        in_mutable.mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0])
        out_mutable.mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0])
        expanded = fc.expand(zero=True)
        print(expanded.weight)

0 comments on commit 8846a0e

Please sign in to comment.