
Add unpatch to dropconnect #198

Merged: 10 commits, Apr 5, 2022
49 changes: 45 additions & 4 deletions baal/bayesian/common.py
@@ -1,8 +1,12 @@
from typing import Callable
import copy
import warnings
from typing import Callable, Optional

import torch
from torch import nn


def replace_layers_in_module(module: nn.Module, mapping_fn: Callable) -> bool:
def replace_layers_in_module(module: nn.Module, mapping_fn: Callable, *args, **kwargs) -> bool:
"""
Recursively iterate over the children of a module and replace them according to `mapping_fn`.

@@ -11,12 +15,49 @@ def replace_layers_in_module(module: nn.Module, mapping_fn: Callable) -> bool:
"""
changed = False
for name, child in module.named_children():
new_module = mapping_fn(child)
new_module = mapping_fn(child, *args, **kwargs)

if new_module is not None:
changed = True
module.add_module(name, new_module)

# recursively apply to child
changed |= replace_layers_in_module(child, mapping_fn)
changed |= replace_layers_in_module(child, mapping_fn, *args, **kwargs)
return changed


class BayesianModule(torch.nn.Module):
patching_function: Callable[..., torch.nn.Module]
unpatch_function: Callable[..., torch.nn.Module]

def __init__(self, module, *args, **kwargs):
super().__init__()
self.parent_module = self.__class__.patching_function(module, *args, **kwargs)

def forward(self, *args, **kwargs):
return self.parent_module(*args, **kwargs)

def unpatch(self) -> torch.nn.Module:
return self.__class__.unpatch_function(self.parent_module)

# Context Manager
def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.unpatch()


def _patching_wrapper(
module: nn.Module,
inplace: bool,
patching_fn: Callable[..., Optional[nn.Module]],
*args,
**kwargs
) -> nn.Module:
if not inplace:
module = copy.deepcopy(module)
changed = replace_layers_in_module(module, patching_fn, *args, **kwargs)
if not changed:
warnings.warn("No layer was modified by patch_module!", UserWarning)
return module
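
A minimal sketch of how the new helpers in common.py are intended to compose into a concrete patched module. Only BayesianModule, _patching_wrapper, and replace_layers_in_module come from this diff; MyDropout, the mapping functions, and MyMCModule below are hypothetical illustrations.

from typing import Optional

import torch
from torch import nn

from baal.bayesian.common import BayesianModule, _patching_wrapper


class MyDropout(nn.Dropout):
    """Hypothetical dropout variant that stays stochastic at eval time."""

    def forward(self, x):
        return nn.functional.dropout(x, self.p, training=True)


def _my_mapping_fn(module: nn.Module) -> Optional[nn.Module]:
    # Return a replacement module, or None to leave the child untouched.
    if isinstance(module, nn.Dropout) and not isinstance(module, MyDropout):
        return MyDropout(p=module.p)
    return None


def _my_unmapping_fn(module: nn.Module) -> Optional[nn.Module]:
    if isinstance(module, MyDropout):
        return nn.Dropout(p=module.p)
    return None


def my_patch_module(module: nn.Module, inplace: bool = True) -> nn.Module:
    return _patching_wrapper(module, inplace=inplace, patching_fn=_my_mapping_fn)


def my_unpatch_module(module: nn.Module, inplace: bool = True) -> nn.Module:
    return _patching_wrapper(module, inplace=inplace, patching_fn=_my_unmapping_fn)


class MyMCModule(BayesianModule):
    # The base class provides __init__, forward, unpatch(), and the context manager.
    patching_function = my_patch_module
    unpatch_function = my_unpatch_module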
35 changes: 6 additions & 29 deletions baal/bayesian/consistent_dropout.py
@@ -7,7 +7,7 @@
from torch.nn import functional as F
from torch.nn.modules.dropout import _DropoutNd

from baal.bayesian.common import replace_layers_in_module
from baal.bayesian.common import replace_layers_in_module, _patching_wrapper, BayesianModule


class ConsistentDropout(_DropoutNd):
@@ -115,12 +115,7 @@ def patch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
The modified module, which is either the same object as you passed in
(if inplace = True) or a copy of that object.
"""
if not inplace:
module = copy.deepcopy(module)
changed = replace_layers_in_module(module, _consistent_dropout_mapping_fn)
if not changed:
warnings.warn("No layer was modified by patch_module!", UserWarning)
return module
return _patching_wrapper(module, inplace=inplace, patching_fn=_consistent_dropout_mapping_fn)


def unpatch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
@@ -137,12 +132,7 @@ def unpatch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
The modified module, which is either the same object as you passed in
(if inplace = True) or a copy of that object.
"""
if not inplace:
module = copy.deepcopy(module)
changed = replace_layers_in_module(module, _consistent_dropout_unmapping_fn)
if not changed:
warnings.warn("No layer was modified by patch_module!", UserWarning)
return module
return _patching_wrapper(module, inplace=inplace, patching_fn=_consistent_dropout_unmapping_fn)


def _consistent_dropout_mapping_fn(module: torch.nn.Module) -> Optional[nn.Module]:
@@ -163,19 +153,6 @@ def _consistent_dropout_unmapping_fn(module: torch.nn.Module) -> Optional[nn.Module]:
return new_module


class MCConsistentDropoutModule(torch.nn.Module):
def __init__(self, module: torch.nn.Module):
"""Create a module that with all dropout layers patched.

Args:
module (torch.nn.Module):
A fully specified neural network.
"""
super().__init__()
self.parent_module = patch_module(module)

def forward(self, *args, **kwargs):
return self.parent_module.forward(*args, **kwargs)

def unpatch(self) -> torch.nn.Module:
return unpatch_module(self.parent_module)
class MCConsistentDropoutModule(BayesianModule):
patching_function = patch_module
unpatch_function = unpatch_module
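
For completeness, a small sketch of the function-level API these thin wrappers now expose; patch_module, unpatch_module, and their inplace argument are taken from this diff, while the toy model below is hypothetical.

import torch
from baal.bayesian.consistent_dropout import patch_module, unpatch_module

# Hypothetical toy model, not part of the PR.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(5, 2),
)

# With inplace=False, the original model is deep-copied before patching.
patched = patch_module(model, inplace=False)       # Dropout -> ConsistentDropout
restored = unpatch_module(patched, inplace=False)  # ConsistentDropout -> Dropout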
40 changes: 11 additions & 29 deletions baal/bayesian/dropout.py
@@ -1,13 +1,11 @@
import copy
import warnings
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.dropout import _DropoutNd

from baal.bayesian.common import replace_layers_in_module
from baal.bayesian.common import BayesianModule, _patching_wrapper


class Dropout(_DropoutNd):
@@ -100,12 +98,7 @@ def patch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
The modified module, which is either the same object as you passed in
(if inplace = True) or a copy of that object.
"""
if not inplace:
module = copy.deepcopy(module)
changed = replace_layers_in_module(module, _dropout_mapping_fn)
if not changed:
warnings.warn("No layer was modified by patch_module!", UserWarning)
return module
return _patching_wrapper(module, inplace=inplace, patching_fn=_dropout_mapping_fn)


def unpatch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
@@ -122,12 +115,7 @@ def unpatch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
The modified module, which is either the same object as you passed in
(if inplace = True) or a copy of that object.
"""
if not inplace:
module = copy.deepcopy(module)
changed = replace_layers_in_module(module, _dropout_unmapping_fn)
if not changed:
warnings.warn("No layer was modified by patch_module!", UserWarning)
return module
return _patching_wrapper(module, inplace=inplace, patching_fn=_dropout_unmapping_fn)


def _dropout_mapping_fn(module: torch.nn.Module) -> Optional[nn.Module]:
@@ -148,19 +136,13 @@ def _dropout_unmapping_fn(module: torch.nn.Module) -> Optional[nn.Module]:
return new_module


class MCDropoutModule(torch.nn.Module):
def __init__(self, module: torch.nn.Module):
"""Create a module that with all dropout layers patched.
class MCDropoutModule(BayesianModule):
"""Create a module that with all dropout layers patched.

Args:
module (torch.nn.Module):
A fully specified neural network.
"""
super().__init__()
self.parent_module = patch_module(module)

def forward(self, *args, **kwargs):
return self.parent_module(*args, **kwargs)
Args:
module (torch.nn.Module):
A fully specified neural network.
"""

def unpatch(self) -> torch.nn.Module:
return unpatch_module(self.parent_module)
patching_function = patch_module
unpatch_function = unpatch_module
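
A short usage sketch of the resulting API: MCDropoutModule, its unpatch method, and the new context-manager behaviour all come from this diff, while the toy Sequential model is invented for illustration.

import torch
from baal.bayesian.dropout import MCDropoutModule

# Hypothetical toy network, not part of the PR.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(5, 2),
).eval()

mc_model = MCDropoutModule(model)
samples = [mc_model(torch.randn(8, 10)) for _ in range(10)]  # stochastic even in eval mode
plain = mc_model.unpatch()  # dropout layers restored to torch.nn.Dropout

# The BayesianModule base class also allows context-manager style use;
# unpatch() is called automatically on exit.
with MCDropoutModule(model) as mc_model:
    samples = [mc_model(torch.randn(8, 10)) for _ in range(10)]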
106 changes: 64 additions & 42 deletions baal/bayesian/weight_drop.py
@@ -1,8 +1,9 @@
import copy
import warnings
from typing import List

from typing import List, Optional, cast, Dict
from baal.bayesian.common import replace_layers_in_module, _patching_wrapper, BayesianModule
import torch
from torch import nn
from packaging.version import parse as parse_version


@@ -13,7 +14,16 @@ def get_weight_drop_module(name: str, weight_dropout, **kwargs):
return {"Conv2d": WeightDropConv2d, "Linear": WeightDropLinear}[name](weight_dropout, **kwargs)


class WeightDropLinear(torch.nn.Linear):
class WeightDropMixin:
_kwargs: Dict

def unpatch(self):
new_module = self.__class__.__bases__[0](**self._kwargs)
new_module.load_state_dict(self.state_dict())
return new_module


class WeightDropLinear(torch.nn.Linear, WeightDropMixin):
"""
Thanks to PytorchNLP for the initial implementation
# code from https://pytorchnlp.readthedocs.io/en/latest/_modules/torchnlp/nn/weight_drop.html
@@ -27,16 +37,16 @@ class WeightDropLinear(torch.nn.Linear):

def __init__(self, weight_dropout=0.0, **kwargs):
wanted = ["in_features", "out_features"]
kwargs = {k: v for k, v in kwargs.items() if k in wanted}
super().__init__(**kwargs)
self._kwargs = {k: v for k, v in kwargs.items() if k in wanted}
super().__init__(**self._kwargs)
self._weight_dropout = weight_dropout

def forward(self, input):
w = torch.nn.functional.dropout(self.weight, p=self._weight_dropout, training=True)
return torch.nn.functional.linear(input, w, self.bias)


class WeightDropConv2d(torch.nn.Conv2d):
class WeightDropConv2d(torch.nn.Conv2d, WeightDropMixin):
"""
Reimplementation of WeightDrop for Conv2D. Thanks to PytorchNLP for the initial implementation
of class WeightDropLinear. Their `License
@@ -49,8 +59,8 @@ class WeightDropConv2d(torch.nn.Conv2d):

def __init__(self, weight_dropout=0.0, **kwargs):
wanted = ["in_channels", "out_channels", "kernel_size", "dilation", "padding"]
kwargs = {k: v for k, v in kwargs.items() if k in wanted}
super().__init__(**kwargs)
self._kwargs = {k: v for k, v in kwargs.items() if k in wanted}
super().__init__(**self._kwargs)
self._weight_dropout = weight_dropout
self._torch_version = parse_version(torch.__version__)

@@ -70,7 +80,8 @@ def forward(self, input):
def patch_module(
module: torch.nn.Module, layers: Sequence, weight_dropout: float = 0.0, inplace: bool = True
) -> torch.nn.Module:
"""Replace given layers with weight_drop module of that layer.
"""
Replace given layers with weight_drop module of that layer.

Args:
module : torch.nn.Module
@@ -86,41 +97,57 @@ def patch_module(
The modified module, which is either the same object as you passed in
(if inplace = True) or a copy of that object.
"""
if not inplace:
module = copy.deepcopy(module)
changed = _patch_layers(module, layers, weight_dropout)
if not changed:
warnings.warn("No layer was modified by patch_module!", UserWarning)
return module
return _patching_wrapper(
module,
inplace=inplace,
patching_fn=_dropconnect_mapping_fn,
layers=layers,
weight_dropout=weight_dropout,
)


def _patch_layers(module: torch.nn.Module, layers: Sequence, weight_dropout: float) -> bool:
def unpatch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
Review comment (Collaborator):
Is there a way we can reduce the duplication between patch_module and unpatch_module? I understand that we don't want to expose the user to unnecessary complexity, hence not taking the mapping_fn as an input, so maybe there is no other way around it.

Reply (Member, author):
You are right, we could uniformize all of these.

"""
Recursively iterate over the children of a module and replace them if
they are in the layers list. This function operates in-place.
Unpatch Dropconnect module to recover initial module.

Args:
module (torch.nn.Module):
The module in which you would like to replace dropout layers.
inplace (bool, optional):
Whether to modify the module in place or return a copy of the module.

Returns:
torch.nn.Module
The modified module, which is either the same object as you passed in
(if inplace = True) or a copy of that object.
"""
changed = False
for name, child in module.named_children():
new_module = None
for layer in layers:
if isinstance(child, getattr(torch.nn, layer)):
new_module = get_weight_drop_module(layer, weight_dropout, **child.__dict__)
break
return _patching_wrapper(module, inplace=inplace, patching_fn=_droconnect_unmapping_fn)

if new_module is not None:
changed = True
module.add_module(name, new_module)

# The dropout layer should be deactivated to use DropConnect.
if isinstance(child, torch.nn.Dropout):
child.p = 0
def _dropconnect_mapping_fn(module: torch.nn.Module, layers, weight_dropout) -> Optional[nn.Module]:
new_module: Optional[nn.Module] = None
for layer in layers:
if isinstance(module, getattr(torch.nn, layer)):
new_module = get_weight_drop_module(layer, weight_dropout, **module.__dict__)
break
if isinstance(module, nn.Dropout):
module._baal_p: float = module.p # type: ignore
module.p = 0.0
return new_module

# Recursively apply to child.
changed |= _patch_layers(child, layers, weight_dropout)
return changed

def _droconnect_unmapping_fn(module: torch.nn.Module) -> Optional[nn.Module]:
new_module: Optional[nn.Module] = None
if isinstance(module, WeightDropMixin):
new_module = module.unpatch()

class MCDropoutConnectModule(torch.nn.Module):
if isinstance(module, nn.Dropout):
module.p = module._baal_p # type: ignore

return new_module


class MCDropoutConnectModule(BayesianModule):
"""Create a module that with all dropout layers patched.
With MCDropoutConnectModule, it could be decided which type of modules to be
replaced.
@@ -133,10 +160,5 @@ class MCDropoutConnectModule(torch.nn.Module):
weight_dropout (float): The probability a weight will be dropped.
"""

def __init__(self, module: torch.nn.Module, layers: Sequence, weight_dropout=0.0):
super().__init__()
self.parent_module = module
_patch_layers(self.parent_module, layers, weight_dropout)

def forward(self, x):
return self.parent_module(x)
patching_function = patch_module
unpatch_function = unpatch_module
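
A short usage sketch of the DropConnect unpatch path added here. MCDropoutConnectModule, its layers/weight_dropout arguments, and unpatch come from this diff (layer names must match torch.nn class names, as in get_weight_drop_module); the toy model is hypothetical.

import torch
from baal.bayesian.weight_drop import MCDropoutConnectModule

# Hypothetical toy network, not part of the PR.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(5, 2),
)

# Replace Linear layers with WeightDropLinear and zero out regular Dropout.
mc_model = MCDropoutConnectModule(model, layers=["Linear"], weight_dropout=0.5)
out = mc_model(torch.randn(8, 10))

# New in this PR: recover the original Linear/Dropout configuration.
restored = mc_model.unpatch()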
10 changes: 3 additions & 7 deletions tests/bayesian/consistent_dropout_test.py
@@ -19,7 +19,7 @@ def a_model_with_dropout():
torch.nn.Linear(5, 5),
torch.nn.Dropout(p=0.5),
torch.nn.Linear(5, 2),
))
)).eval()

def test_1d_eval_is_not_stochastic(a_model_with_dropout):
dummy_input = torch.randn(8, 10)
@@ -69,7 +69,7 @@ def test_patch_module_replaces_all_dropout_layers(inplace, a_model_with_dropout):
)


def test_module_class_replaces_dropout_layers(a_model_with_dropout):
def test_module_class_replaces_dropout_layers(a_model_with_dropout, is_deterministic):
dummy_input = torch.randn(8, 10)
test_module = a_model_with_dropout
test_mc_module = baal.bayesian.consistent_dropout.MCConsistentDropoutModule(test_module)
@@ -99,11 +99,7 @@ def test_module_class_replaces_dropout_layers(a_model_with_dropout):
# Check that unpatch works
module = test_mc_module.unpatch()
module.eval()
with torch.no_grad():
assert all(
torch.allclose(module(dummy_input), module(dummy_input))
for _ in range(10)
)
assert is_deterministic(module, (8, 10))
assert not any(isinstance(mod, baal.bayesian.consistent_dropout.ConsistentDropout) for mod in module.modules())

@pytest.mark.parametrize("inplace", (True, False))