Skip to content

Commit

Permalink
[Fix] fix the bug in Recorders when nn.Module contains the 'inplace' …
Browse files Browse the repository at this point in the history
…attribute
  • Loading branch information
cape-zck committed Feb 6, 2023
1 parent a27952d commit 00bb24a
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 6 deletions.
49 changes: 47 additions & 2 deletions mmrazor/models/algorithms/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, OrderedDict, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -39,6 +39,8 @@ class BaseAlgorithm(BaseModel):
config of :class:`BaseDataPreprocessor`. Defaults to None.
init_cfg (dict): The weight initialized config for
:class:`BaseModule`.
module_inplace(bool): Whether to allow module inplace attribute True.
Defaults to False.
Note:
If `data_preprocessor` is None, :obj:`BaseAlgorithm` will set
Expand All @@ -56,7 +58,8 @@ class BaseAlgorithm(BaseModel):
def __init__(self,
architecture: Union[BaseModel, Dict],
data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
init_cfg: Optional[Dict] = None) -> None:
init_cfg: Optional[Dict] = None,
module_inplace: bool = False) -> None:

# super().__init__() needs built data_preprocessor, so
# build model first.
Expand All @@ -79,6 +82,13 @@ def __init__(self,
# Cannot assign module before Module.__init__()
self.architecture = architecture

# Find all nn.Modules in the model that contain the 'inplace' attribute
# and set them to False
self.module_inplace = module_inplace
if not self.module_inplace:
self.set_module_inplace_false(architecture, 'self.architecture')
pass

def forward(self,
inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
Expand Down Expand Up @@ -162,3 +172,38 @@ def _predict(
"""Predict results from a batch of inputs and data samples with post-
processing."""
return self.architecture(inputs, data_samples, mode='predict')

def set_module_inplace_false(self, architecture: Union[OrderedDict,
nn.Module],
varstr: str) -> None:
"""Find all nn.Modules in the model that contain the 'inplace'
attribute and set them to False in order to prevent occur error in
Recorders using recursion algorithm.
This function will disassemble the Args architecture .If type
'nn.Module' is detected, determine if it contains an 'inplace'
attribute and set False if it does. If none, get the OrderedDict
and then iterate through the dictionary to continue the recursive
search.
Args:
architecture (OrderedDict | nn.Module): The config OrderedDict
for model or built model.
varstr (str): Records the call-level string containing the
'inplace' attribute.
Returns:
None
"""

if isinstance(architecture, nn.Module):
if hasattr(eval(varstr), 'inplace'):
eval(varstr).inplace = False
else:
self.set_module_inplace_false(architecture._modules,
varstr + '._modules')
elif isinstance(architecture, OrderedDict):
for key, value in architecture.items():
self.set_module_inplace_false(value, varstr + f"['{key}']")
else:
return
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class SingleTeacherDistill(BaseAlgorithm):
to True.
calculate_student_loss (bool): Whether to calculate student loss
(original task loss) to update student model. Defaults to True.
teacher_module_inplace(bool): Whether to allow teacher module inplace
attribute True. Defaults to False.
"""

def __init__(self,
Expand All @@ -42,6 +44,7 @@ def __init__(self,
teacher_norm_eval: bool = True,
student_trainable: bool = True,
calculate_student_loss: bool = True,
teacher_module_inplace: bool = False,
**kwargs) -> None:
super().__init__(**kwargs)

Expand All @@ -56,6 +59,13 @@ def __init__(self,
f'{type(teacher)}')

self.teacher = teacher

# Find all nn.Modules in the model that contain the 'inplace' attribute
# and set them to False.
self.teacher_module_inplace = teacher_module_inplace
if not self.teacher_module_inplace:
self.set_module_inplace_false(teacher, 'self.teacher')

if teacher_ckpt:
# avoid loaded parameters be overwritten
self.teacher.init_weights()
Expand Down
30 changes: 26 additions & 4 deletions tests/test_models/test_algorithms/test_base_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mmengine.model import BaseDataPreprocessor, BaseModel

from mmrazor.models import BaseAlgorithm
from mmrazor.models.task_modules import ModuleOutputsRecorder
from mmrazor.registry import MODELS


Expand All @@ -25,16 +26,18 @@ class ToyModel(BaseModel):
def __init__(self, data_preprocessor=None):
super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
self.conv = nn.Conv2d(3, 1, 1)
self.bn = nn.BatchNorm2d(1)
self.relu = nn.ReLU(inplace=True)

def forward(self, inputs, data_samples=None, mode='tensor'):
def forward(self, batch_inputs, data_samples=None, mode='tensor'):
if mode == 'loss':
out = self.conv(inputs)
out = self.relu(self.bn(self.conv(batch_inputs)))
return dict(loss=out)
elif mode == 'predict':
out = self.conv(inputs) + 1
out = self.relu(self.bn(self.conv(batch_inputs) + 1))
return out
elif mode == 'tensor':
out = self.conv(inputs) + 2
out = self.relu(self.bn(self.conv(batch_inputs) + 2))
return out


Expand Down Expand Up @@ -100,3 +103,22 @@ def test_forward(self):

with self.assertRaisesRegex(RuntimeError, 'Invalid mode "A"'):
alg(inputs, mode='A')

def test_set_module_inplace_false(self):
inputs = torch.randn(1, 3, 8, 8)

model = ToyModel()
res_before = model(inputs)
_ = BaseAlgorithm(model)

r1 = ModuleOutputsRecorder('bn')
r1.initialize(model)
with r1:
res_after = model(inputs)
self.assertIs(torch.equal(res_before, res_after), True)

self.assertIs(model.relu.inplace, False)

self.assertIs(
torch.equal(r1.data_buffer[0], model.bn(model.conv(inputs) + 2)),
True)

0 comments on commit 00bb24a

Please sign in to comment.