Feature/sg 849 add replace in channels #1557

Merged: 37 commits, Oct 29, 2023
Changes shown are from 2 of the 37 commits.

Commits (all authored by Louis-Dupont)
- 1085517 wip (Oct 19, 2023)
- 906b351 first draft (Oct 21, 2023)
- 356fb18 wip (Oct 21, 2023)
- 0c9c4d4 remove self.in_channels (Oct 21, 2023)
- 2f2c43b add get_input_channels to all SgModel (Oct 21, 2023)
- 0cb05a9 ADD TESTS (Oct 21, 2023)
- 3be5c64 SupportsReplaceInChannels -> SupportsReplaceInputChannels (Oct 21, 2023)
- 8d654eb remove unwanted comment (Oct 21, 2023)
- 79e861e Merge branch 'master' into feature/SG-849-add_replace_in_channels (Oct 21, 2023)
- 975b0ed replace_in_channels -> replace_input_channels (Oct 21, 2023)
- 6d43aae remove unwanted comment (Oct 21, 2023)
- 87ce5a2 add docstring (Oct 22, 2023)
- 622c61e rename replace_input_channels_with_random_weights -> replace_conv2d_i… (Oct 23, 2023)
- 6964249 Merge branch 'master' into feature/SG-849-add_replace_in_channels (Oct 23, 2023)
- db9cec1 fix (Oct 23, 2023)
- cd0b67e Merge branch 'master' into hotfix/SG-000-fix_csp_darknet53_forward (Oct 23, 2023)
- 4afccd0 update test to also run foward (Oct 23, 2023)
- b5144b9 Merge branch 'master' into feature/SG-849-add_replace_in_channels (Oct 23, 2023)
- e74b8be Merge branch 'hotfix/SG-000-fix_csp_darknet53_forward' into feature/S… (Oct 23, 2023)
- d54f268 add pretrained test (Oct 23, 2023)
- 37ab77b add minor docstring (Oct 23, 2023)
- c386a30 Merge branch 'master' into feature/SG-849-add_replace_in_channels (Oct 23, 2023)
- 3d114af use existing channels when replacing (Oct 23, 2023)
- ab760e8 set self.in_channels in replace_input_channels (Oct 24, 2023)
- 287b28f add num_input_channels in models.get (Oct 24, 2023)
- 12170e7 Merge branch 'master' into feature/SG-849-add_replace_in_channels (Oct 24, 2023)
- c56d553 automatically set self.input_channels when calling replace_input_chan… (Oct 24, 2023)
- b075062 Merge branch 'master' into feature/SG-849-add_replace_in_channels (Oct 25, 2023)
- cccecb6 Merge branch 'master' into feature/SG-849-add_replace_in_channels (Oct 25, 2023)
- 9318806 Merge branch 'master' into feature/SG-849-add_replace_in_channels (Oct 26, 2023)
- bf64cad remove #TODO (Oct 26, 2023)
- 9d86c3e add to train_from_recipe (Oct 26, 2023)
- 0a38118 add to (Oct 26, 2023)
- a5e3c86 add to kd (Oct 26, 2023)
- 01ed612 Merge branch 'master' into feature/SG-849-add_replace_in_channels (Oct 29, 2023)
- 503129f make inherit from (Oct 29, 2023)
- df11873 Merge branch 'master' into feature/SG-849-add_replace_in_channels (Oct 29, 2023)
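Taken together, the commits above add a `SupportsReplaceInChannels` module interface, `replace_in_channels` implementations for the common stem blocks, and a `num_input_channels` argument to `models.get`. A minimal sketch of the intended end-user flow, based only on the commit messages (the exact `models.get` signature is an assumption, not shown in this diff):

```python
from super_gradients.training import models
from super_gradients.common.object_names import Models

# Hypothetical usage: build a classifier whose first layer accepts 4-channel
# input (e.g. RGB + depth). The `num_input_channels` name follows the commit
# "add num_input_channels in models.get"; the exact signature is an assumption.
model = models.get(Models.BEIT_BASE_PATCH16_224, num_classes=10, num_input_channels=4)
```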
33 changes: 33 additions & 0 deletions src/super_gradients/common/deprecate.py
@@ -177,3 +177,36 @@ def deprecated_nested_params_to_factory_format_assigner(**params):
return params

return deprecated_nested_params_to_factory_format_assigner


def deprecate_param(
deprecated_param_name: str,
new_param_name: str = "",
deprecated_since: str = "",
removed_from: str = "",
reason: str = "",
):
"""
Utility function to warn about a deprecated parameter (or dictionary key).

:param deprecated_param_name: Name of the deprecated parameter.
:param new_param_name: Name of the new parameter/key that should replace the deprecated one.
:param deprecated_since: Version number when the parameter was deprecated.
:param removed_from: Version number when the parameter will be removed.
:param reason: Additional information or reason for the deprecation.
"""
is_still_supported = deprecated_since < removed_from
status_msg = "is deprecated" if is_still_supported else "was deprecated and has been removed"
message = f"Parameter `{deprecated_param_name}` {status_msg} " f"since version `{deprecated_since}` and will be removed in version `{removed_from}`.\n"

if reason:
message += f"Reason: {reason}.\n"

if new_param_name:
message += f"Please update your code to use the `{new_param_name}` instead of `{deprecated_param_name}`."

if is_still_supported:
warnings.simplefilter("once", DeprecationWarning) # Required, otherwise the warning may never be displayed.
warnings.warn(message, DeprecationWarning)
else:
raise ValueError(message)
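A short sketch of how this helper is intended to be called from a parameter-parsing path (the surrounding function, keys, and version strings are illustrative, not part of this diff):

```python
def parse_arch_params(arch_params: dict) -> dict:
    # Illustrative call site: warn while the old key is still accepted and
    # remap it onto the new key; version strings are placeholders.
    if "in_channels" in arch_params:
        deprecate_param(
            deprecated_param_name="in_channels",
            new_param_name="num_input_channels",
            deprecated_since="3.4.0",
            removed_from="3.6.0",
            reason="naming consistency across models",
        )
        arch_params["num_input_channels"] = arch_params.pop("in_channels")
    return arch_params
```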
2 changes: 1 addition & 1 deletion src/super_gradients/conversion/conversion_utils.py
@@ -12,7 +12,7 @@
(torch.int16, np.int16),
(torch.int8, np.int8),
(torch.uint8, np.uint8),
(torch.bool, np.bool),
# (torch.bool, np.bool),
]


3 changes: 2 additions & 1 deletion src/super_gradients/module_interfaces/__init__.py
@@ -1,4 +1,4 @@
from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses
from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses, SupportsReplaceInChannels
from .exceptions import ModelHasNoPreprocessingParamsException
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule
from .exportable_pose_estimation import ExportablePoseEstimationModel, PoseEstimationModelExportResult, AbstractPoseEstimationDecodingModule
@@ -8,6 +8,7 @@
"HasPredict",
"HasPreprocessingParams",
"SupportsReplaceNumClasses",
"SupportsReplaceInChannels",
"ExportableObjectDetectionModel",
"AbstractObjectDetectionDecodingModule",
"ModelHasNoPreprocessingParamsException",
27 changes: 27 additions & 0 deletions src/super_gradients/module_interfaces/module_interfaces.py
@@ -1,3 +1,4 @@
from abc import ABC
from typing import Callable, Optional, TYPE_CHECKING

from torch import nn
@@ -69,3 +70,29 @@ def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable
:return: None
"""
raise NotImplementedError(f"replace_num_classes is not implemented in the derived class {self.__class__.__name__}")


class SupportsReplaceInChannels(ABC):
"""
Protocol interface for modules that support replacing the number of input channels.
Derived classes should implement the `replace_in_channels` method.

This interface class serves the purpose of explicitly indicating whether a class supports optimized input channel replacement:

>>> class InputLayer(nn.Module, SupportsReplaceInChannels):
>>> def replace_in_channels(self, in_channels: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module] = None):
>>> ...

"""

# @abstractmethod
def replace_in_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]]):
"""
Replace the number of input channels in the module.

:param in_channels: New number of input channels.
:param compute_new_weights_fn: An optional function that computes the new weights for the new input channels.
It takes the existing nn.Module and returns a new one.
:return: None
"""
raise NotImplementedError(f"`replace_in_channels` is not implemented in the derived class `{self.__class__.__name__}`")
60 changes: 60 additions & 0 deletions src/super_gradients/modules/backbone_replacement_utils.py
@@ -0,0 +1,60 @@
from typing import Union, Optional, Callable

import torch
from torch import nn

__all__ = ["replace_in_channels_with_random_weights"]


def compute_new_weights(
module: Union[nn.Conv2d, nn.Linear, nn.Module],
in_channels: int,
fn: Optional[Callable[[nn.Module, int], nn.Module]] = None,
) -> nn.Module:
fn = fn or replace_in_channels_with_random_weights
return fn(module=module, in_channels=in_channels)


def replace_in_channels_with_random_weights(module: Union[nn.Conv2d, nn.Linear, nn.Module], in_channels: int) -> nn.Module:
"""
Replace the input channels in the module with random weights.
This is useful for replacing the input layer of a model.
This implementation supports Conv2d layers, as well as nn.Sequential containers whose first module is a Conv2d.
The returned module will have the same device and dtype as the original module.
Random weights are initialized with the same mean and std as the original weights.

:param module: (nn.Module) Module to replace the input channels in.
:param in_channels: New number of input channels.
:return: nn.Module
"""
if isinstance(module, nn.Conv2d):

if in_channels % module.groups != 0:
raise ValueError(
f"Incompatible number of input channels ({in_channels}) with the number of groups ({module.groups})."
f"The number of input channels must be divisible by the number of groups."
)

new_module = nn.Conv2d(
in_channels,
module.out_channels,
kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups, # Be cautious: if in_channels % groups != 0, it will raise an error.
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
)
torch.nn.init.normal_(new_module.weight, mean=module.weight.mean().item(), std=module.weight.std().item())
if module.bias is not None:
torch.nn.init.normal_(new_module.bias, mean=module.bias.mean().item(), std=module.bias.std().item())

return new_module
elif isinstance(module, nn.Sequential):
# TODO: check - is it as safe as it looks ? Any possible side effect?
module[0] = replace_in_channels_with_random_weights(module=module[0], in_channels=in_channels)
return module
else:
raise ValueError(f"Module {module} does not support replacing the input channels")
24 changes: 20 additions & 4 deletions src/super_gradients/modules/conv_bn_act_block.py
@@ -1,11 +1,12 @@
from typing import Union, Tuple, Type
from typing import Union, Tuple, Type, Callable, Optional

from torch import nn

from super_gradients.modules.utils import autopad
from super_gradients.module_interfaces import SupportsReplaceInChannels


class ConvBNAct(nn.Module):
class ConvBNAct(nn.Module, SupportsReplaceInChannels):
"""
Class for Convolution2d-Batchnorm2d-Activation layer.
Default behaviour is Conv-BN-Act. To exclude Batchnorm module use
@@ -40,11 +41,13 @@ def __init__(
if activation_kwargs is None:
activation_kwargs = {}

self.in_channels = in_channels

self.seq = nn.Sequential()
self.seq.add_module(
"conv",
nn.Conv2d(
in_channels,
self.in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
@@ -67,6 +70,12 @@ def __init__(
def forward(self, x):
return self.seq(x)

def replace_in_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
from super_gradients.modules.backbone_replacement_utils import compute_new_weights

self.in_channels = in_channels
self.seq[0] = compute_new_weights(module=self.seq[0], in_channels=self.in_channels, fn=compute_new_weights_fn)


class Conv(nn.Module):
# STANDARD CONVOLUTION
@@ -76,7 +85,8 @@ class Conv(nn.Module):
def __init__(self, input_channels, output_channels, kernel, stride, activation_type: Type[nn.Module], padding: int = None, groups: int = None):
super().__init__()

self.conv = nn.Conv2d(input_channels, output_channels, kernel, stride, autopad(kernel, padding), groups=groups or 1, bias=False)
self.in_channels = input_channels
self.conv = nn.Conv2d(self.in_channels, output_channels, kernel, stride, autopad(kernel, padding), groups=groups or 1, bias=False)
self.bn = nn.BatchNorm2d(output_channels)
self.act = activation_type()

@@ -85,3 +95,9 @@ def forward(self, x):

def fuseforward(self, x):
return self.act(self.conv(x))

def replace_in_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
from super_gradients.modules.backbone_replacement_utils import compute_new_weights

self.in_channels = in_channels
self.conv = compute_new_weights(module=self.conv, in_channels=self.in_channels, fn=compute_new_weights_fn)
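A hedged usage sketch for the block above; the constructor keywords are assumed from the visible parts of `__init__` and may not be exhaustive:

```python
import torch
from torch import nn

# Swap the input channels of a ConvBNAct block in place and run a forward pass.
# Constructor arguments are assumptions based on the signature shown above.
block = ConvBNAct(in_channels=3, out_channels=16, kernel_size=3, padding=1, activation_type=nn.ReLU)
block.replace_in_channels(in_channels=4)

assert block(torch.randn(1, 4, 32, 32)).shape[1] == 16
assert block.in_channels == 4
```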
25 changes: 20 additions & 5 deletions src/super_gradients/modules/detection_modules.py
@@ -1,16 +1,19 @@
from abc import ABC, abstractmethod
from typing import Union, List
from typing import Union, List, Optional, Callable

import torch
from torch import nn
from omegaconf import DictConfig
from omegaconf.listconfig import ListConfig

from super_gradients.common.registry.registry import register_detection_module
from super_gradients.modules.base_modules import BaseDetectionModule
from super_gradients.modules.multi_output_modules import MultiOutputModule
from super_gradients.training.models import MobileNet, MobileNetV2
from super_gradients.training.models.classification_models.mobilenetv2 import InvertedResidual
from super_gradients.training.utils.utils import HpmStruct
from torch import nn
from super_gradients.module_interfaces import SupportsReplaceInChannels


__all__ = [
"PANNeck",
@@ -28,7 +31,7 @@


@register_detection_module()
class NStageBackbone(BaseDetectionModule):
class NStageBackbone(BaseDetectionModule, SupportsReplaceInChannels):
"""
A backbone with a stem -> N stages -> context module
Returns outputs of the layers listed in out_layers
@@ -83,6 +86,13 @@ def forward(self, x):

return outputs

def replace_in_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
self.in_channels = in_channels
if isinstance(self.stem, SupportsReplaceInChannels):
self.stem.replace_in_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn)
else:
raise NotImplementedError(f"`{self.stem.__class__.__name__}` does not support `replace_in_channels`")


@register_detection_module()
class PANNeck(BaseDetectionModule):
@@ -176,14 +186,15 @@ def combine_preds(self, preds):
return outputs if self.training else (torch.cat(outputs, 1), outputs_logits)


class MultiOutputBackbone(BaseDetectionModule):
class MultiOutputBackbone(BaseDetectionModule, SupportsReplaceInChannels):
"""
Defines a backbone using MultiOutputModule with the interface of BaseDetectionModule
"""

def __init__(self, in_channels: int, backbone: nn.Module, out_layers: List):
super().__init__(in_channels)
self.multi_output_backbone = MultiOutputModule(backbone, out_layers)
self.in_channels = in_channels
self.multi_output_backbone = MultiOutputModule(module=backbone, output_paths=out_layers)
self._out_channels = [x.shape[1] for x in self.forward(torch.empty((1, in_channels, 64, 64)))]

@property
@@ -193,6 +204,10 @@ def out_channels(self) -> Union[List[int], int]:
def forward(self, x):
return self.multi_output_backbone(x)

def replace_in_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
self.in_channels = in_channels
self.multi_output_backbone.replace_in_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn)


@register_detection_module()
class MobileNetV1Backbone(MultiOutputBackbone):
12 changes: 11 additions & 1 deletion src/super_gradients/modules/multi_output_modules.py
@@ -1,9 +1,12 @@
from collections import OrderedDict
from typing import Optional, Callable
from torch import nn
from omegaconf.listconfig import ListConfig

from super_gradients.module_interfaces import SupportsReplaceInChannels

class MultiOutputModule(nn.Module):

class MultiOutputModule(nn.Module, SupportsReplaceInChannels):
"""
This module wraps around a container nn.Module (such as Module, Sequential and ModuleList) and allows to extract
multiple output from its inner modules on each forward call() (as a list of output tensors)
@@ -99,3 +102,10 @@ def _prune(self, module: nn.Module, output_paths: list):
def _slice_odict(self, odict: OrderedDict, start: int, end: int):
"""Slice an OrderedDict in the same logic list,tuple... are sliced"""
return OrderedDict([(k, v) for (k, v) in odict.items() if k in list(odict.keys())[start:end]])

def replace_in_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
module = self._modules["0"]
if isinstance(module, SupportsReplaceInChannels):
module.replace_in_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn)
else:
raise NotImplementedError(f"`{module.__class__.__name__}` does not support `replace_in_channels`")
14 changes: 14 additions & 0 deletions src/super_gradients/modules/qarepvgg_block.py
@@ -329,3 +329,17 @@ def prep_model_for_conversion(self, input_size: Optional[Union[tuple, list]] = N
self.full_fusion()
else:
self.partial_fusion()

# def replace_in_channels(self, new_in_channels: int):
# """
# Replace the in_channels of the block and initialize the new channels with random weights.
# """
# # Replace the 3x3 branch
# new_branch_3x3_conv = replace_in_channels_with_random_weights(self.branch_3x3.conv, new_in_channels)
# self.branch_3x3.conv = new_branch_3x3_conv
#
# # Replace the 1x1 branch
# new_branch_1x1_conv = replace_in_channels_with_random_weights(self.branch_1x1, new_in_channels)
# self.branch_1x1 = new_branch_1x1_conv
#
# self.in_channels = new_in_channels
Changes to the Beit classification model (file header not captured):
@@ -20,7 +20,7 @@
# --------------------------------------------------------'
import math
from functools import partial
from typing import Optional, Tuple
from typing import Optional, Tuple, Callable

import torch
import torch.nn as nn
@@ -322,7 +322,10 @@ def __init__(
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.grad_checkpointing = False

self.patch_embed = PatchEmbed(img_size=image_size, patch_size=patch_size, in_channels=in_chans, hidden_dim=embed_dim)
self.image_size = image_size
self.patch_size = patch_size
self.in_channels = in_chans
self.patch_embed = PatchEmbed(img_size=self.image_size, patch_size=self.patch_size, in_channels=self.in_channels, hidden_dim=self.embed_dim)
num_patches = self.patch_embed.num_patches

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
@@ -453,6 +456,10 @@ def replace_head(self, new_num_classes=None, new_head=None):
else:
self.head = nn.Linear(self.head.in_features, new_num_classes)

def replace_in_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
self.in_channels = in_channels
self.patch_embed = PatchEmbed(img_size=self.image_size, patch_size=self.patch_size, in_channels=self.in_channels, hidden_dim=self.embed_dim)


@register_model(Models.BEIT_BASE_PATCH16_224)
class BeitBasePatch16_224(Beit):
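Note that `Beit.replace_in_channels` above rebuilds `PatchEmbed` from scratch, so the patch-embedding weights are re-initialized rather than adapted from the existing ones. A hedged usage sketch (model name taken from the registry entry above; the `models.get` arguments are assumptions):

```python
from super_gradients.training import models
from super_gradients.common.object_names import Models

# Load a Beit classifier and rebuild its patch embedding for 4-channel input.
model = models.get(Models.BEIT_BASE_PATCH16_224, num_classes=10)
model.replace_in_channels(in_channels=4)
```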
Changes to the DenseNet classification model (file header not captured):
@@ -1,3 +1,5 @@
from typing import Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -132,6 +134,12 @@ def forward(self, x):
out = self.classifier(out)
return out

def replace_in_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
from super_gradients.modules.backbone_replacement_utils import compute_new_weights

self.in_channels = in_channels
self.features[0] = compute_new_weights(module=self.features[0], in_channels=self.in_channels, fn=compute_new_weights_fn)


@register_model(Models.CUSTOM_DENSENET)
class CustomizedDensnet(DenseNet):