[Feature] Add group_fisher pruning algorithm to prune a detection model #410

Open · wants to merge 6 commits into base: main
2 changes: 2 additions & 0 deletions .gitignore
@@ -121,3 +121,5 @@ venv.bak/
# Srun
*.out
batchscript-*
work_dir
mmdeploy
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -68,4 +68,5 @@ repos:
^test
| ^docs
| ^configs
| ^.*/configs*
)
1 change: 1 addition & 0 deletions mmrazor/models/mutators/channel_mutator/channel_mutator.py
@@ -66,6 +66,7 @@ def __init__(self,
dict,
Type[MutableChannelUnit]] = SequentialMutableChannelUnit,
parse_cfg: Dict = dict(
_scope_='mmrazor',
type='ChannelAnalyzer',
demo_input=(1, 3, 224, 224),
tracer_type='BackwardTracer'),
@@ -6,6 +6,7 @@

from mmrazor.registry import TASK_UTILS
from mmrazor.utils import get_placeholder
from ...algorithms.base import BaseAlgorithm
from .demo_inputs import (BaseDemoInput, DefaultMMClsDemoInput,
DefaultMMDemoInput, DefaultMMDetDemoInput,
DefaultMMPoseDemoInput, DefaultMMRotateDemoInput,
@@ -70,8 +71,12 @@ def get_default_demo_input_class(model, scope):

def defaul_demo_inputs(model, input_shape, training=False, scope=None):
"""Get demo input according to a model and scope."""
demo_input = get_default_demo_input_class(model, scope)
return demo_input().get_data(model, input_shape, training)
if isinstance(model, BaseAlgorithm):
return defaul_demo_inputs(model.architecture, input_shape, training,
scope)
else:
demo_input = get_default_demo_input_class(model, scope)
return demo_input().get_data(model, input_shape, training)


@TASK_UTILS.register_module()
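With this change, `defaul_demo_inputs` also accepts an algorithm wrapper: when it receives a `BaseAlgorithm`, it recurses into `model.architecture` so the demo input is built for the wrapped network rather than the wrapper itself. A minimal, self-contained sketch of that unwrapping pattern (the `ToyAlgorithm` class below is an illustrative stand-in, not part of this diff):

```python
import torch
import torch.nn as nn


class ToyAlgorithm(nn.Module):
    """Illustrative stand-in for an algorithm that wraps an architecture."""

    def __init__(self, architecture: nn.Module):
        super().__init__()
        self.architecture = architecture


def demo_inputs(model, input_shape):
    """Recurse until a bare model is reached, then build a random input."""
    if isinstance(model, ToyAlgorithm):
        return demo_inputs(model.architecture, input_shape)
    return torch.rand(input_shape)


algo = ToyAlgorithm(nn.Conv2d(3, 8, 3))
assert demo_inputs(algo, (1, 3, 32, 32)).shape == (1, 3, 32, 32)
```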
6 changes: 4 additions & 2 deletions mmrazor/models/task_modules/demo_inputs/demo_inputs.py
@@ -51,7 +51,9 @@ def _get_data(self, model, input_shape=None, training=None):
return data

def _get_mm_data(self, model, input_shape, training=False):
return {'inputs': torch.rand(input_shape), 'data_samples': None}
data = {'inputs': torch.rand(input_shape), 'data_samples': None}
data = model.data_preprocessor(data, training)
return data


@TASK_UTILS.register_module()
@@ -132,7 +134,7 @@ def _get_mm_data(self, model, input_shape, training=False):
from mmpose.models import TopdownPoseEstimator

from .mmpose_demo_input import demo_mmpose_inputs
assert isinstance(model, TopdownPoseEstimator)
assert isinstance(model, TopdownPoseEstimator), f'{type(model)}'

data = demo_mmpose_inputs(model, input_shape)
return data
3 changes: 3 additions & 0 deletions projects/cores/__init__.py
@@ -0,0 +1,3 @@
from .counters import * # noqa
from .hooks import * # noqa
from .models import * # noqa
60 changes: 60 additions & 0 deletions projects/cores/counters.py
@@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn

from mmrazor.models.task_modules.estimators.counters import (Conv2dCounter,
LinearCounter)
from mmrazor.registry import TASK_UTILS


@TASK_UTILS.register_module()
class DynamicConv2dCounter(Conv2dCounter):
"""Flop counter for DynamicCon2d."""

@staticmethod
def add_count_hook(module: nn.Conv2d, input: Tuple[torch.Tensor],
output: torch.Tensor) -> None:
"""Count the flops and params of a DynamicConv2d.

Args:
module (nn.Conv2d): A Conv2d module.
input (Tuple[torch.Tensor]): Input of this module.
output (torch.Tensor): Output of this module.
"""
batch_size = input[0].shape[0]
output_dims = list(output.shape[2:])

kernel_dims = list(module.kernel_size)

out_channels = module.mutable_attrs['out_channels'].activated_channels
in_channels = module.mutable_attrs['in_channels'].activated_channels

groups = module.groups

filters_per_channel = out_channels / groups
conv_per_position_flops = int(
np.prod(kernel_dims)) * in_channels * filters_per_channel

active_elements_count = batch_size * int(np.prod(output_dims))

overall_conv_flops = conv_per_position_flops * active_elements_count
overall_params = conv_per_position_flops

bias_flops = 0
if module.bias is not None:
bias_flops = out_channels * active_elements_count
overall_params += out_channels

overall_flops = overall_conv_flops + bias_flops

module.__flops__ += overall_flops
module.__params__ += int(overall_params)


@TASK_UTILS.register_module()
class DynamicLinearCounter(LinearCounter):
pass
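Unlike the static `Conv2dCounter`, the hook above reads the activated channel counts from `mutable_attrs`, so the reported FLOPs and parameters shrink as channels are pruned away. A rough back-of-the-envelope check that mirrors the arithmetic in `add_count_hook` (the concrete sizes are made up for illustration):

```python
import numpy as np

# Pruned 3x3 conv: 16 activated input channels, 32 activated output channels,
# 28x28 output map, batch size 1, no grouping.
batch_size, out_h, out_w = 1, 28, 28
kernel_dims = [3, 3]
in_channels, out_channels, groups = 16, 32, 1

filters_per_channel = out_channels / groups
conv_per_position_flops = int(np.prod(kernel_dims)) * in_channels * filters_per_channel
active_elements_count = batch_size * out_h * out_w

overall_conv_flops = conv_per_position_flops * active_elements_count  # 3,612,672
bias_flops = out_channels * active_elements_count                     # 25,088 (if bias)
overall_params = conv_per_position_flops + out_channels               # 4,640 (if bias)
print(overall_conv_flops, bias_flops, overall_params)
```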
1 change: 1 addition & 0 deletions projects/cores/expandable_ops/__init__.py
@@ -0,0 +1 @@
"""This module is used to expand the channels of a supernet."""
209 changes: 209 additions & 0 deletions projects/cores/expandable_ops/ops.py
@@ -0,0 +1,209 @@
import torch
import torch.nn as nn

from mmrazor.models.architectures import dynamic_ops
from mmrazor.models.mutables import MutableChannelContainer


class ExpandableMixin:
"""This minin coroperates with dynamic ops.

It defines interfaces to expand the channels of ops. We can get a wider
network than original supernet with it.
"""

def expand(self, zero=False):
"""Expand the op.

Args:
zero (bool, optional): whether to set new weights to zero. Defaults
to False.
"""
return self.get_expand_op(
self.expanded_in_channel,
self.expanded_out_channel,
zero=zero,
)

def get_expand_op(self, in_c, out_c, zero=False):
"""Get an expanded op.

Args:
in_c (int): New input channels
out_c (int): New output channels
zero (bool, optional): Whether to zero new weights. Defaults to
False.
"""
pass

@property
def _original_in_channel(self):
"""Return original in channel."""
raise NotImplementedError()

@property
def _original_out_channel(self):
"""Return original out channel."""

@property
def expanded_in_channel(self):
"""Return expanded in channel number."""
if self.in_mutable is not None:
return self.in_mutable.current_mask.numel()
else:
return self._original_in_channel

@property
def expanded_out_channel(self):
"""Return expanded out channel number."""
if self.out_mutable is not None:
return self.out_mutable.current_mask.numel()
else:
return self._original_out_channel

@property
def mutable_in_mask(self):
"""Return the mutable in mask."""
if self.in_mutable is not None:
return self.in_mutable.current_mask
else:
if hasattr(self, 'weight'):
return self.weight.new_ones([self.expanded_in_channel])
else:
return torch.ones([self.expanded_in_channel])

@property
def mutable_out_mask(self):
"""Return the mutable out mask."""
if self.out_mutable is not None:
return self.out_mutable.current_mask
else:
if hasattr(self, 'weight'):
return self.weight.new_ones([self.expanded_out_channel])
else:
return torch.ones([self.expanded_out_channel])

@property
def in_mutable(self) -> MutableChannelContainer:
"""In channel mask."""
return self.get_mutable_attr('in_channels') # type: ignore

@property
def out_mutable(self) -> MutableChannelContainer:
"""Out channel mask."""
return self.get_mutable_attr('out_channels') # type: ignore

def zero_weight_(self: nn.Module):
"""Zero all weights."""
for p in self.parameters():
p.data.zero_()

@torch.no_grad()
def expand_matrix(self, weight: torch.Tensor, old_weight: torch.Tensor):
"""Expand weight matrix."""
assert len(weight.shape) == 3 # out in c
assert len(old_weight.shape) == 3 # out in c
mask = self.mutable_out_mask.float().unsqueeze(
-1) * self.mutable_in_mask.float().unsqueeze(0)
mask = mask.unsqueeze(-1).expand(*weight.shape)
weight.data.masked_scatter_(mask.bool(), old_weight)
return weight

@torch.no_grad()
def expand_vector(self, weight: torch.Tensor, old_weight: torch.Tensor):
"""Expand weight vector."""
assert len(weight.shape) == 2 # out c
assert len(old_weight.shape) == 2 # out c
mask = self.mutable_out_mask
mask = mask.unsqueeze(-1).expand(*weight.shape)
weight.data.masked_scatter_(mask.bool(), old_weight)
return weight

@torch.no_grad()
def expand_bias(self, bias: torch.Tensor, old_bias: torch.Tensor):
"""Expand bias."""
assert len(bias.shape) == 1 # out c
assert len(old_bias.shape) == 1 # out c
return self.expand_vector(bias.unsqueeze(-1),
old_bias.unsqueeze(-1)).squeeze(1)


class ExpandableConv2d(dynamic_ops.DynamicConv2d, ExpandableMixin):

@property
def _original_in_channel(self):
return self.in_channels

@property
def _original_out_channel(self):
return self.out_channels

def get_expand_op(self, in_c, out_c, zero=False):
module = nn.Conv2d(in_c, out_c, self.kernel_size, self.stride,
self.padding, self.dilation, self.groups, self.bias
is not None, self.padding_mode)
if zero:
ExpandableMixin.zero_weight_(module)

weight = self.expand_matrix(
module.weight.flatten(2), self.weight.flatten(2))
module.weight.data = weight.reshape(module.weight.shape)
if module.bias is not None and self.bias is not None:
bias = self.expand_vector(
module.bias.unsqueeze(-1), self.bias.unsqueeze(-1))
module.bias.data = bias.reshape(module.bias.shape)
return module


class ExpandLinear(dynamic_ops.DynamicLinear, ExpandableMixin):

@property
def _original_in_channel(self):
return self.in_features

@property
def _original_out_channel(self):
return self.out_features

def get_expand_op(self, in_c, out_c, zero=False):
module = nn.Linear(in_c, out_c, self.bias is not None)
if zero:
ExpandableMixin.zero_weight_(module)

weight = self.expand_matrix(
module.weight.unsqueeze(-1), self.weight.unsqueeze(-1))
module.weight.data = weight.reshape(module.weight.shape)
if module.bias is not None:
bias = self.expand_vector(
module.bias.unsqueeze(-1), self.bias.unsqueeze(-1))
module.bias.data = bias.reshape(module.bias.shape)
return module


class ExpandableBatchNorm2d(dynamic_ops.DynamicBatchNorm2d, ExpandableMixin):

@property
def _original_in_channel(self):
return self.num_features

@property
def _original_out_channel(self):
return self.num_features

def get_expand_op(self, in_c, out_c, zero=False):
assert in_c == out_c
module = nn.BatchNorm2d(in_c, self.eps, self.momentum, self.affine,
self.track_running_stats)
if zero:
ExpandableMixin.zero_weight_(module)

if module.running_mean is not None:
module.running_mean.data = self.expand_bias(
module.running_mean, self.running_mean)

if module.running_var is not None:
module.running_var.data = self.expand_bias(module.running_var,
self.running_var)
module.weight.data = self.expand_bias(module.weight, self.weight)
module.bias.data = self.expand_bias(module.bias, self.bias)
return module
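These ops go in the opposite direction of pruning: `get_expand_op` builds a wider static layer and scatters the original weights into the slots selected by the channel masks, while newly added positions either keep the fresh initialization of the new layer or are zeroed when `zero=True`. The core mechanism is the masked scatter used in `expand_matrix`; a self-contained sketch of it with made-up masks and plain tensors (no dynamic ops required):

```python
import torch

# Copy an old (2 out, 3 in) conv kernel into a new (4 out, 6 in) zero kernel,
# at the positions marked by the channel masks.
old_weight = torch.arange(2 * 3 * 9, dtype=torch.float).reshape(2, 3, 9)  # out, in, k*k
new_weight = torch.zeros(4, 6, 9)

out_mask = torch.tensor([1., 0., 1., 0.])         # old out-channels land in slots 0 and 2
in_mask = torch.tensor([1., 1., 1., 0., 0., 0.])  # old in-channels land in slots 0-2

mask = out_mask.unsqueeze(-1) * in_mask.unsqueeze(0)        # (4, 6)
mask = mask.unsqueeze(-1).expand(*new_weight.shape).bool()  # (4, 6, 9)
new_weight.masked_scatter_(mask, old_weight)

assert torch.equal(new_weight[0, :3], old_weight[0])
assert torch.equal(new_weight[2, :3], old_weight[1])
assert new_weight[1].abs().sum() == 0  # newly added channels stay zero
```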