Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
support prelu to speedup model
Browse files Browse the repository at this point in the history
  • Loading branch information
twmht committed Jun 16, 2021
1 parent 71fc4da commit 1062411
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
27 changes: 27 additions & 0 deletions nni/compression/pytorch/speedup/compress_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
'AvgPool2d': lambda module, mask: no_replace(module, mask),
'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
'ReLU': lambda module, mask: no_replace(module, mask),
'PReLU': lambda module, mask: replace_prelu(module, mask),
'ReLU6': lambda module, mask: no_replace(module, mask),
'Sigmoid': lambda module, mask: no_replace(module, mask),
'Linear': lambda module, mask: replace_linear(module, mask),
Expand All @@ -31,6 +32,32 @@ def no_replace(module, mask):
_logger.debug("no need to replace")
return module

def replace_prelu(norm, mask):
"""
Parameters
----------
norm : torch.nn.BatchNorm2d
The batchnorm module to be replace
mask : ModuleMasks
The masks of this module
Returns
-------
torch.nn.BatchNorm2d
The new batchnorm module
"""
assert isinstance(mask, ModuleMasks)
assert 'weight' in mask.param_masks
index = mask.param_masks['weight'].mask_index[0]
num_features = index.size()[0]
_logger.debug("replace prelu with num_features: %d", num_features)
if num_features == 0:
return torch.nn.Identity()
new_norm = torch.nn.PReLU(num_features)
# assign weights
new_norm.weight.data = torch.index_select(norm.weight.data, 0, index)
# print (f'replace prelu {new_norm.weight.data}')
return new_norm

def replace_linear(linear, mask):
"""
Expand Down
58 changes: 58 additions & 0 deletions nni/compression/pytorch/speedup/infer_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def __repr__(self):
infer_from_inshape = {
'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask),
'PReLU': lambda module_masks, mask: prelu_inshape(module_masks, mask),
'Sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::tanh': lambda module_masks, mask: relu_inshape(module_masks, mask),
Expand Down Expand Up @@ -293,6 +294,7 @@ def __repr__(self):
'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),

'ReLU': lambda module_masks, mask: relu_outshape(module_masks, mask),
'PReLU': lambda module_masks, mask: prelu_outshape(module_masks, mask),
'ReLU6': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::relu': lambda module_masks, mask: relu_outshape(module_masks, mask),
'aten::tanh': lambda module_masks, mask: relu_outshape(module_masks, mask),
Expand Down Expand Up @@ -735,6 +737,62 @@ def maxpool2d_outshape(module_masks, mask):
module_masks.set_output_mask(mask)
return mask

def prelu_inshape(module_masks, mask):
"""
We assume only the second dimension has coarse grained mask
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the batchnorm2d
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None
module_masks.set_input_mask(mask)
module_masks.set_output_mask(mask)
weight_cmask = CoarseMask(num_dim=1)
weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
module_masks.set_param_masks('weight', weight_cmask)
return mask

def prelu_outshape(module_masks, mask):
"""
We assume only the second dimension has coarse grained mask
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the batchnorm2d
mask : CoarseMask
The mask of its input tensor
Returns
-------
CoarseMask
The mask of its output tensor
"""
assert isinstance(mask, CoarseMask)
assert mask.mask_index[1] is not None
assert mask.mask_index[0] is None
assert mask.mask_index[2] is None
assert mask.mask_index[3] is None

weight_cmask = CoarseMask(num_dim=4)
weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
module_masks.set_param_masks('weight', weight_cmask)

return mask


def relu_inshape(module_masks, mask):
"""
Expand Down

0 comments on commit 1062411

Please sign in to comment.