-
Notifications
You must be signed in to change notification settings - Fork 488
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature/sg 849 add replace in channels (#1557)
* wip * first draft * wip * remove self.in_channels * add get_input_channels to all SgModel * ADD TESTS * SupportsReplaceInChannels -> SupportsReplaceInputChannels * remove unwanted comment * replace_in_channels -> replace_input_channels * remove unwanted comment * add docstring * rename replace_input_channels_with_random_weights -> replace_conv2d_input_channels + adding docstings * fix * update test to also run foward * add pretrained test * add minor docstring * use existing channels when replacing * set self.in_channels in replace_input_channels * add num_input_channels in models.get * automatically set self.input_channels when calling replace_input_channels * remove #TODO * add to train_from_recipe * add to * add to kd * make inherit from
- Loading branch information
1 parent
41c97e7
commit c0594a3
Showing
41 changed files
with
847 additions
and
93 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from typing import Optional, Callable | ||
|
||
import torch | ||
from torch import nn | ||
|
||
__all__ = ["replace_conv2d_input_channels", "replace_conv2d_input_channels_with_random_weights"] | ||
|
||
|
||
def replace_conv2d_input_channels(conv: nn.Conv2d, in_channels: int, fn: Optional[Callable[[nn.Conv2d, int], nn.Conv2d]] = None) -> nn.Module: | ||
"""Instantiate a new Conv2d module with same attributes as input Conv2d, except for the input channels. | ||
:param conv: Conv2d to replace the input channels in. | ||
:param in_channels: New number of input channels. | ||
:param fn: (Optional) Function to instantiate the new Conv2d. | ||
By default, it will initialize the new weights with the same mean and std as the original weights. | ||
:return: Conv2d with new number of input channels. | ||
""" | ||
if fn: | ||
return fn(conv, in_channels) | ||
else: | ||
return replace_conv2d_input_channels_with_random_weights(conv=conv, in_channels=in_channels) | ||
|
||
|
||
def replace_conv2d_input_channels_with_random_weights(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d: | ||
""" | ||
Replace the input channels in the input Conv2d with random weights. | ||
Returned module will have the same device and dtype as the original module. | ||
Random weights are initialized with the same mean and std as the original weights. | ||
:param conv: Conv2d to replace the input channels in. | ||
:param in_channels: New number of input channels. | ||
:return: Conv2d with new number of input channels. | ||
""" | ||
|
||
if in_channels % conv.groups != 0: | ||
raise ValueError( | ||
f"Incompatible number of input channels ({in_channels}) with the number of groups ({conv.groups})." | ||
f"The number of input channels must be divisible by the number of groups." | ||
) | ||
|
||
new_conv = nn.Conv2d( | ||
in_channels, | ||
conv.out_channels, | ||
kernel_size=conv.kernel_size, | ||
stride=conv.stride, | ||
padding=conv.padding, | ||
dilation=conv.dilation, | ||
groups=conv.groups, | ||
bias=conv.bias is not None, | ||
device=conv.weight.device, | ||
dtype=conv.weight.dtype, | ||
) | ||
|
||
if in_channels <= conv.in_channels: | ||
new_conv.weight.data = conv.weight.data[:, :in_channels, ...] | ||
else: | ||
new_conv.weight.data[:, : conv.in_channels, ...] = conv.weight.data | ||
|
||
# Pad the remaining channels with random weights | ||
torch.nn.init.normal_(new_conv.weight.data[:, conv.in_channels :, ...], mean=conv.weight.mean().item(), std=conv.weight.std().item()) | ||
|
||
if conv.bias is not None: | ||
torch.nn.init.normal_(new_conv.bias, mean=conv.bias.mean().item(), std=conv.bias.std().item()) | ||
|
||
return new_conv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.