Skip to content

Commit

Permalink
[Features]Support MethodInputsRecorder and FunctionInputsRecorder (
Browse files Browse the repository at this point in the history
…#320)

* support MethodInputsRecorder and FunctionInputsRecorder

* fix bugs that the model can not be pickled

* WIP: add pytest for ema model

* fix bugs in recorder and delivery when ema_hook is used

* don't register the DummyDataset

* fix pytest
  • Loading branch information
HIT-cwh authored Oct 24, 2022
1 parent 31052ea commit 972fd8e
Show file tree
Hide file tree
Showing 13 changed files with 535 additions and 68 deletions.
4 changes: 2 additions & 2 deletions mmrazor/models/task_modules/delivery/distill_delivery.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from queue import Queue
from collections import deque
from typing import Callable


Expand Down Expand Up @@ -33,7 +33,7 @@ class DistillDelivery(metaclass=ABCMeta):
def __init__(self, max_keep_data: int = 1) -> None:

self._override_data = False
self.data_queue: Queue = Queue(maxsize=max_keep_data)
self.data_queue: deque = deque([], maxlen=max_keep_data)
self.max_keep_data = max_keep_data

@property
Expand Down
50 changes: 29 additions & 21 deletions mmrazor/models/task_modules/delivery/function_outputs_delivery.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,7 @@ def __init__(self, func_path: str, max_keep_data: int):
super().__init__(max_keep_data)

self._check_valid_path(func_path)
module_path = self._get_module_path(func_path)
try:
module = import_modules_from_strings(module_path)
except ImportError:
raise ImportError(f'{module_path} is not imported correctly.')
self.module = module

func_name = self._get_func_name(func_path)
assert hasattr(module, func_name), \
f'{func_name} is not in {module_path}.'
self.func_name = func_name

origin_func = getattr(module, func_name)
if not isinstance(origin_func, FunctionType):
raise TypeError(f'{func_name} should be a FunctionType '
f'instance, but got {type(origin_func)}')
self.origin_func = origin_func
self.func_path = func_path

@staticmethod
def _check_valid_path(func_path: str) -> None:
Expand All @@ -121,6 +105,24 @@ def __enter__(self) -> None:
Wrap the origin function.
"""
module_path = self._get_module_path(self.func_path)
try:
module = import_modules_from_strings(module_path)
except ImportError:
raise ImportError(f'{module_path} is not imported correctly.')
self.module = module

func_name = self._get_func_name(self.func_path)
assert hasattr(module, func_name), \
f'{func_name} is not in {module_path}.'
self.func_name = func_name

origin_func = getattr(module, func_name)
if not isinstance(origin_func, FunctionType):
raise TypeError(f'{func_name} should be a FunctionType '
f'instance, but got {type(origin_func)}')
self.origin_func = origin_func

wrapped_func = self.deliver_wrapper(self.origin_func)
setattr(self.module, self.func_name, wrapped_func)

Expand All @@ -131,6 +133,11 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
setattr(self.module, self.func_name, self.origin_func)

# self.module and self.origin_func can not be pickled.
# Delete these two attributes to avoid errors when ema model is used.
del self.module
del self.origin_func

def deliver_wrapper(self, origin_func: Callable) -> Callable:
"""Wrap the specific function to make the intermediate results of the
model can be delivered."""
Expand All @@ -139,12 +146,13 @@ def deliver_wrapper(self, origin_func: Callable) -> Callable:
def wrap_func(*args, **kwargs):

if self.override_data:
assert not self.data_queue.empty(), 'pop from an empty queue'
outputs = self.data_queue.get()
assert len(self.data_queue) > 0, 'pop from an empty queue'
outputs = self.data_queue.popleft()
else:
assert not self.data_queue.full(), 'push into an full queue'
assert len(self.data_queue) < self.data_queue.maxlen,\
'push into an full queue'
outputs = origin_func(*args, **kwargs)
self.data_queue.put(outputs)
self.data_queue.append(outputs)
return outputs

return wrap_func
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,13 @@ def deliver_wrapper(self, origin_method: Callable) -> Callable:
def wrap_method(*args, **kwargs):

if self.override_data:
assert not self.data_queue.empty(), 'pop from an empty queue'
outputs = self.data_queue.get()
assert len(self.data_queue) > 0, 'pop from an empty queue'
outputs = self.data_queue.popleft()
else:
assert not self.data_queue.full(), 'push into an full queue'
assert len(self.data_queue) < self.data_queue.maxlen,\
'push into an full queue'
outputs = origin_method(*args, **kwargs)
self.data_queue.put(outputs)
self.data_queue.append(outputs)
return outputs

return wrap_method
4 changes: 3 additions & 1 deletion mmrazor/models/task_modules/recorder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .function_inputs_recorder import FunctionInputsRecorder
from .function_outputs_recorder import FunctionOutputsRecorder
from .method_inputs_recorder import MethodInputsRecorder
from .method_outputs_recorder import MethodOutputsRecorder
from .module_inputs_recorder import ModuleInputsRecorder
from .module_outputs_recorder import ModuleOutputsRecorder
Expand All @@ -9,5 +11,5 @@
__all__ = [
'FunctionOutputsRecorder', 'MethodOutputsRecorder',
'ModuleOutputsRecorder', 'ParameterRecorder', 'RecorderManager',
'ModuleInputsRecorder'
'ModuleInputsRecorder', 'MethodInputsRecorder', 'FunctionInputsRecorder'
]
71 changes: 71 additions & 0 deletions mmrazor/models/task_modules/recorder/function_inputs_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from inspect import signature
from typing import Callable, List

from mmrazor.registry import TASK_UTILS
from .function_outputs_recorder import FunctionOutputsRecorder


@TASK_UTILS.register_module()
class FunctionInputsRecorder(FunctionOutputsRecorder):
    """Recorder for intermediate results which are ``FunctionType``'s inputs.

    Every call of the wrapped function appends one list to ``data_buffer``.
    The list holds the positional arguments in call order, followed by the
    keyword arguments whose names appear in the function's signature,
    ordered as in that signature. Arguments that are not passed explicitly
    (i.e. left at their default value) are not recorded.

    Notes:
        The form of `source` needs special attention. For example,
        `anchor_inside_flags` is a function in mmdetection to check whether
        the anchors are inside the border. This function is in
        `mmdet/core/anchor/utils.py` and used in
        `mmdet/models/dense_heads/anchor_head.py`. Then the source should be
        `mmdet.models.dense_heads.anchor_head.anchor_inside_flags` but not
        `mmdet.core.anchor.utils.anchor_inside_flags`.

    Examples:
        >>> # Below code in toy_module.py
        >>> def toy_func(a, b):
        ...     return a, b
        >>> def execute_toy_func(a, b):
        ...     toy_func(a, b)

        >>> # Below code in main.py
        >>> # Now, we want to get teacher's inputs by recorder.
        >>> from toy_module import execute_toy_func
        >>> r1 = FunctionInputsRecorder('toy_module.toy_func')
        >>> r1.initialize()
        >>> with r1:
        ...     execute_toy_func(1, 2)
        ...     execute_toy_func(1, b=2)
        ...     execute_toy_func(b=2, a=1)
        >>> r1.data_buffer
        [[1, 2], [1, 2], [1, 2]]
    """

    def func_record_wrapper(self, origin_func: Callable,
                            data_buffer: List) -> Callable:
        """Build a wrapper that saves the function's inputs.

        Args:
            origin_func (FunctionType): The function whose inputs need to be
                recorded.
            data_buffer (list): A list of data. One entry is appended per
                call of the wrapped function.

        Returns:
            Callable: A function that forwards to ``origin_func`` and
            returns its result unchanged.
        """
        # Signature order decides where keyword arguments land in a record.
        param_names = tuple(signature(origin_func).parameters)

        @functools.wraps(origin_func)
        def wrap_func(*args, **kwargs):
            # The function runs first, so a call that raises records nothing.
            result = origin_func(*args, **kwargs)
            keyword_values = [
                kwargs[name] for name in param_names if name in kwargs
            ]
            # N executions of the function produce N entries in the buffer.
            data_buffer.append([*args, *keyword_values])
            return result

        return wrap_func
49 changes: 25 additions & 24 deletions mmrazor/models/task_modules/recorder/function_outputs_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,28 +65,8 @@ class FunctionOutputsRecorder(BaseRecorder):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self._check_valid_source(self.source)

# import the module that contains the function
try:
mod = import_modules_from_strings(self.module_string)
except ImportError:
raise ImportError(
f'{self.module_string} is not imported correctly.')

self.imported_module: ModuleType = mod

assert hasattr(mod, self.func_name), \
f'{self.func_name} is not in {self.module_string}.'

origin_func = getattr(mod, self.func_name)
if not isinstance(origin_func, FunctionType):
raise TypeError(f'{self.func_name} should be a FunctionType '
f'instance, but got {type(origin_func)}')

self.origin_func: Callable = origin_func

@staticmethod
def _check_valid_source(source):
"""Check if the source's format is valid."""
Expand Down Expand Up @@ -118,8 +98,7 @@ def func_record_wrapper(self, origin_func: Callable,
Args:
origin_func (FunctionType): The method whose outputs need to be
recorded.
buffer_key (str): The key of the function's outputs saved in
``data_buffer``.
data_buffer (list): A list of data.
"""

@functools.wraps(origin_func)
Expand All @@ -136,8 +115,25 @@ def __enter__(self):
"""Enter the context manager."""
super().__enter__()

mod = self.imported_module
origin_func = self.origin_func
# import the module that contains the function
try:
mod = import_modules_from_strings(self.module_string)
except ImportError:
raise ImportError(
f'{self.module_string} is not imported correctly.')

self.imported_module: ModuleType = mod

assert hasattr(mod, self.func_name), \
f'{self.func_name} is not in {self.module_string}.'

origin_func = getattr(mod, self.func_name)
if not isinstance(origin_func, FunctionType):
raise TypeError(f'{self.func_name} should be a FunctionType '
f'instance, but got {type(origin_func)}')

self.origin_func: Callable = origin_func

# add record wrapper to origin function.
record_func = self.func_record_wrapper(origin_func, self.data_buffer)

Expand All @@ -159,3 +155,8 @@ def __exit__(self, exc_type, exc_value, traceback):

# restore the origin function
setattr(mod, self.func_name, origin_func)

# self.imported_module and self.origin_func can not be pickled.
# Delete these two attributes to avoid errors when ema model is used.
del self.imported_module
del self.origin_func
83 changes: 83 additions & 0 deletions mmrazor/models/task_modules/recorder/method_inputs_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from inspect import signature
from typing import Callable, List

from mmrazor.registry import TASK_UTILS
from .method_outputs_recorder import MethodOutputsRecorder


@TASK_UTILS.register_module()
class MethodInputsRecorder(MethodOutputsRecorder):
    """Recorder for intermediate results which are ``MethodType``'s inputs.

    Every call of the wrapped method appends one list to ``data_buffer``.
    The list holds the positional arguments after the first one, in call
    order, followed by the keyword arguments whose names appear in the
    method's signature, ordered as in that signature. Arguments that are not
    passed explicitly (i.e. left at their default value) are not recorded.

    Note:
        Different from ``FunctionType``, ``MethodType`` is the type of
        methods of class instances.

    Examples:
        >>> # Below code in toy_module.py
        >>> class Toy():
        ...     def toy_func(self, x, y=0):
        ...         return x + y

        >>> # Below code in main.py
        >>> # Now, we want to get teacher's inputs by recorder.
        >>> from toy_module import Toy
        >>> toy = Toy()
        >>> r1 = MethodInputsRecorder('toy_module.Toy.toy_func')
        >>> r1.initialize()
        >>> with r1:
        ...     _ = toy.toy_func(1, 2)
        >>> r1.data_buffer
        [[1, 2]]
        >>> r1.get_record_data(record_idx=0, data_idx=0)
        1
        >>> r1.get_record_data(record_idx=0, data_idx=1)
        2

        >>> from toy_module import Toy
        >>> toy = Toy()
        >>> r1 = MethodInputsRecorder('toy_module.Toy.toy_func')
        >>> r1.initialize()
        >>> with r1:
        ...     _ = toy.toy_func(1, 2)
        ...     _ = toy.toy_func(y=2, x=1)
        >>> r1.data_buffer
        [[1, 2], [1, 2]]
        >>> r1.get_record_data(record_idx=1, data_idx=0)
        1
        >>> r1.get_record_data(record_idx=1, data_idx=1)
        2
    """

    def method_record_wrapper(self, orgin_method: Callable,
                              data_buffer: List) -> Callable:
        """Build a wrapper that saves the method's inputs.

        Args:
            orgin_method (MethodType): The method whose inputs need to be
                recorded.
            data_buffer (list): A list of data. One entry is appended per
                call of the wrapped method.

        Returns:
            Callable: A function that forwards to ``orgin_method`` and
            returns its result unchanged.
        """
        # Signature order decides where keyword arguments land in a record.
        param_names = tuple(signature(orgin_method).parameters)

        @functools.wraps(orgin_method)
        def wrap_method(*args, **kwargs):
            # The method runs first, so a call that raises records nothing.
            result = orgin_method(*args, **kwargs)
            # args[0] is the instance the method was called on (``self``);
            # it is not part of the recorded inputs.
            positional_values = args[1:]
            keyword_values = [
                kwargs[name] for name in param_names if name in kwargs
            ]
            # N executions of the method produce N entries in the buffer.
            data_buffer.append([*positional_values, *keyword_values])
            return result

        return wrap_method
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ def method_record_wrapper(self, orgin_method: Callable,
Args:
origin_method (MethodType): The method whose outputs need to be
recorded.
buffer_key (str): The key of the method's outputs saved in
``data_buffer``.
data_buffer (list): A list of data.
"""

@functools.wraps(orgin_method)
Expand Down
Loading

0 comments on commit 972fd8e

Please sign in to comment.