
Commit

Merge branch 'master' into feature/saving_metrics_to_yaml
hakuryuu96 committed Oct 30, 2023
2 parents e7b8d83 + 827c4f8 commit 2ce24a7
Showing 42 changed files with 855 additions and 95 deletions.
33 changes: 33 additions & 0 deletions src/super_gradients/common/deprecate.py
@@ -135,3 +135,36 @@ def wrapper(*args, **training_params):
return wrapper

return decorator


def deprecate_param(
deprecated_param_name: str,
new_param_name: str = "",
deprecated_since: str = "",
removed_from: str = "",
reason: str = "",
):
"""
Utility function to warn about a deprecated parameter (or dictionary key).
:param deprecated_param_name: Name of the deprecated parameter.
:param new_param_name: Name of the new parameter/key that should replace the deprecated one.
:param deprecated_since: Version number when the parameter was deprecated.
:param removed_from: Version number when the parameter will be removed.
:param reason: Additional information or reason for the deprecation.
"""
is_still_supported = deprecated_since < removed_from
status_msg = "is deprecated" if is_still_supported else "was deprecated and has been removed"
message = f"Parameter `{deprecated_param_name}` {status_msg} " f"since version `{deprecated_since}` and will be removed in version `{removed_from}`.\n"

if reason:
message += f"Reason: {reason}.\n"

if new_param_name:
message += f"Please update your code to use the `{new_param_name}` instead of `{deprecated_param_name}`."

if is_still_supported:
warnings.simplefilter("once", DeprecationWarning) # Required, otherwise the warning may never be displayed.
warnings.warn(message, DeprecationWarning)
else:
raise ValueError(message)
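
For context, a minimal usage sketch of the new helper (the parameter names and version numbers below are hypothetical, not taken from this commit):

from super_gradients.common.deprecate import deprecate_param

# Hypothetical rename of a config key `max_epochs` -> `num_epochs`.
deprecate_param(
    deprecated_param_name="max_epochs",
    new_param_name="num_epochs",
    deprecated_since="3.3.0",
    removed_from="3.5.0",
    reason="align naming with the rest of the training params",
)
# Emits a DeprecationWarning (once) telling the user that `max_epochs` is deprecated
# since 3.3.0, will be removed in 3.5.0, and that `num_epochs` should be used instead.
# If `deprecated_since` is not lower than `removed_from`, the same call raises ValueError.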
3 changes: 2 additions & 1 deletion src/super_gradients/module_interfaces/__init__.py
@@ -1,4 +1,4 @@
from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses
from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses, SupportsReplaceInputChannels
from .exceptions import ModelHasNoPreprocessingParamsException
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule
from .exportable_pose_estimation import ExportablePoseEstimationModel, PoseEstimationModelExportResult, AbstractPoseEstimationDecodingModule
@@ -8,6 +8,7 @@
"HasPredict",
"HasPreprocessingParams",
"SupportsReplaceNumClasses",
"SupportsReplaceInputChannels",
"ExportableObjectDetectionModel",
"AbstractObjectDetectionDecodingModule",
"ModelHasNoPreprocessingParamsException",
32 changes: 32 additions & 0 deletions src/super_gradients/module_interfaces/module_interfaces.py
@@ -1,3 +1,4 @@
from abc import ABC
from typing import Callable, Optional, TYPE_CHECKING

from torch import nn
@@ -69,3 +70,34 @@ def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable
:return: None
"""
raise NotImplementedError(f"replace_num_classes is not implemented in the derived class {self.__class__.__name__}")


class SupportsReplaceInputChannels(ABC):
"""
Protocol interface for modules that support replacing the number of input channels.
Derived classes should implement the `replace_input_channels` method.
This interface class explicitly indicates whether a class supports optimized input channel replacement:
>>> class InputLayer(nn.Module, SupportsReplaceInputChannels):
>>> def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module] = None):
>>> ...
"""

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]]):
"""
Replace the number of input channels in the module.
:param in_channels: New number of input channels.
:param compute_new_weights_fn: (Optional) function that computes the new weights for the new input channels.
It takes the existing nn.Module and returns a new one.
"""
raise NotImplementedError(f"`replace_input_channels` is not implemented in the derived class `{self.__class__.__name__}`")

def get_input_channels(self) -> int:
"""Get the number of input channels for the model.
:return: Number of input channels.
"""
raise NotImplementedError(f"`get_input_channels` is not implemented in the derived class `{self.__class__.__name__}`")
23 changes: 20 additions & 3 deletions src/super_gradients/modules/conv_bn_act_block.py
@@ -1,11 +1,12 @@
from typing import Union, Tuple, Type
from typing import Union, Tuple, Type, Callable, Optional

from torch import nn

from super_gradients.modules.utils import autopad
from super_gradients.module_interfaces import SupportsReplaceInputChannels


class ConvBNAct(nn.Module):
class ConvBNAct(nn.Module, SupportsReplaceInputChannels):
"""
Class for Convolution2d-Batchnorm2d-Activation layer.
Default behaviour is Conv-BN-Act. To exclude Batchnorm module use
@@ -67,8 +68,16 @@ def __init__(
def forward(self, x):
return self.seq(x)

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
from super_gradients.modules.weight_replacement_utils import replace_conv2d_input_channels

self.seq[0] = replace_conv2d_input_channels(conv=self.seq[0], in_channels=in_channels, fn=compute_new_weights_fn)

def get_input_channels(self) -> int:
return self.seq[0].in_channels


class Conv(nn.Module):
class Conv(nn.Module, SupportsReplaceInputChannels):
# STANDARD CONVOLUTION
# TODO: This class is illegally similar to ConvBNAct, and the only reason it exists is due to the fact that some models were using it
# previously, and one has to find a bulletproof way to drop this class while still being able to load models that were using it. Perhaps
@@ -85,3 +94,11 @@ def forward(self, x):

def fuseforward(self, x):
return self.act(self.conv(x))

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
from super_gradients.modules.weight_replacement_utils import replace_conv2d_input_channels

self.conv = replace_conv2d_input_channels(conv=self.conv, in_channels=in_channels, fn=compute_new_weights_fn)

def get_input_channels(self) -> int:
return self.conv.in_channels
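
Both blocks take an optional `compute_new_weights_fn`, letting callers decide how the new kernel is built instead of falling back to random re-initialization. Below is a hedged sketch of one such strategy; the function name is made up and it assumes `groups == 1`:

from torch import nn


def tile_existing_kernels(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    # Hypothetical strategy: repeat the original input-channel slices to fill the new width,
    # instead of drawing the extra channels from a normal distribution.
    new_conv = nn.Conv2d(
        in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        device=conv.weight.device,
        dtype=conv.weight.dtype,
    )
    repeats = -(-in_channels // conv.in_channels)  # ceiling division
    new_conv.weight.data.copy_(conv.weight.data.repeat(1, repeats, 1, 1)[:, :in_channels, ...])
    if conv.bias is not None:
        new_conv.bias.data.copy_(conv.bias.data)
    return new_conv


# Usage (block being any ConvBNAct/Conv instance):
# block.replace_input_channels(in_channels=6, compute_new_weights_fn=tile_existing_kernels)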
31 changes: 26 additions & 5 deletions src/super_gradients/modules/detection_modules.py
@@ -1,16 +1,19 @@
from abc import ABC, abstractmethod
from typing import Union, List
from typing import Union, List, Optional, Callable

import torch
from torch import nn
from omegaconf import DictConfig
from omegaconf.listconfig import ListConfig

from super_gradients.common.registry.registry import register_detection_module
from super_gradients.modules.base_modules import BaseDetectionModule
from super_gradients.modules.multi_output_modules import MultiOutputModule
from super_gradients.training.models import MobileNet, MobileNetV2
from super_gradients.training.models.classification_models.mobilenetv2 import InvertedResidual
from super_gradients.training.utils.utils import HpmStruct
from torch import nn
from super_gradients.module_interfaces import SupportsReplaceInputChannels


__all__ = [
"PANNeck",
@@ -28,7 +31,7 @@


@register_detection_module()
class NStageBackbone(BaseDetectionModule):
class NStageBackbone(BaseDetectionModule, SupportsReplaceInputChannels):
"""
A backbone with a stem -> N stages -> context module
Returns outputs of the layers listed in out_layers
@@ -83,6 +86,18 @@ def forward(self, x):

return outputs

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
if isinstance(self.stem, SupportsReplaceInputChannels):
self.stem.replace_input_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn)
else:
raise NotImplementedError(f"`{self.stem.__class__.__name__}` does not support `replace_input_channels`")

def get_input_channels(self) -> int:
if isinstance(self.stem, SupportsReplaceInputChannels):
return self.stem.get_input_channels()
else:
raise NotImplementedError(f"`{self.stem.__class__.__name__}` does not support `get_input_channels`")


@register_detection_module()
class PANNeck(BaseDetectionModule):
@@ -176,14 +191,14 @@ def combine_preds(self, preds):
return outputs if self.training else (torch.cat(outputs, 1), outputs_logits)


class MultiOutputBackbone(BaseDetectionModule):
class MultiOutputBackbone(BaseDetectionModule, SupportsReplaceInputChannels):
"""
Defines a backbone using MultiOutputModule with the interface of BaseDetectionModule
"""

def __init__(self, in_channels: int, backbone: nn.Module, out_layers: List):
super().__init__(in_channels)
self.multi_output_backbone = MultiOutputModule(backbone, out_layers)
self.multi_output_backbone = MultiOutputModule(module=backbone, output_paths=out_layers)
self._out_channels = [x.shape[1] for x in self.forward(torch.empty((1, in_channels, 64, 64)))]

@property
@@ -193,6 +208,12 @@ def out_channels(self) -> Union[List[int], int]:
def forward(self, x):
return self.multi_output_backbone(x)

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
self.multi_output_backbone.replace_input_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn)

def get_input_channels(self) -> int:
return self.multi_output_backbone.get_input_channels()


@register_detection_module()
class MobileNetV1Backbone(MultiOutputBackbone):
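
Because the composite backbones now forward `replace_input_channels` down to their stem (or wrapped backbone), the number of input channels of a fully built detector can be changed after construction. A hedged sketch, assuming the model exposes its backbone as `model.backbone` and that the stem implements the interface:

from super_gradients.common.object_names import Models
from super_gradients.module_interfaces import SupportsReplaceInputChannels
from super_gradients.training import models

model = models.get(Models.YOLO_NAS_S, num_classes=80)  # illustrative architecture choice
if isinstance(model.backbone, SupportsReplaceInputChannels):
    # Propagates to the stem's first convolution, e.g. for 4-channel (RGB + NIR) input.
    model.backbone.replace_input_channels(in_channels=4)
    print(model.backbone.get_input_channels())  # expected: 4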
19 changes: 18 additions & 1 deletion src/super_gradients/modules/multi_output_modules.py
@@ -1,9 +1,12 @@
from collections import OrderedDict
from typing import Optional, Callable
from torch import nn
from omegaconf.listconfig import ListConfig

from super_gradients.module_interfaces import SupportsReplaceInputChannels

class MultiOutputModule(nn.Module):

class MultiOutputModule(nn.Module, SupportsReplaceInputChannels):
"""
This module wraps around a container nn.Module (such as Module, Sequential and ModuleList) and allows extracting
multiple outputs from its inner modules on each forward call (as a list of output tensors).
@@ -99,3 +102,17 @@ def _prune(self, module: nn.Module, output_paths: list):
def _slice_odict(self, odict: OrderedDict, start: int, end: int):
"""Slice an OrderedDict in the same logic list,tuple... are sliced"""
return OrderedDict([(k, v) for (k, v) in odict.items() if k in list(odict.keys())[start:end]])

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
module = self._modules["0"]
if isinstance(module, SupportsReplaceInputChannels):
module.replace_input_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn)
else:
raise NotImplementedError(f"`{module.__class__.__name__}` does not support `replace_input_channels`")

def get_input_channels(self) -> int:
module = self._modules["0"]
if isinstance(module, SupportsReplaceInputChannels):
return module.get_input_channels()
else:
raise NotImplementedError(f"`{module.__class__.__name__}` does not support `get_input_channels`")
65 changes: 65 additions & 0 deletions src/super_gradients/modules/weight_replacement_utils.py
@@ -0,0 +1,65 @@
from typing import Optional, Callable

import torch
from torch import nn

__all__ = ["replace_conv2d_input_channels", "replace_conv2d_input_channels_with_random_weights"]


def replace_conv2d_input_channels(conv: nn.Conv2d, in_channels: int, fn: Optional[Callable[[nn.Conv2d, int], nn.Conv2d]] = None) -> nn.Module:
"""Instantiate a new Conv2d module with same attributes as input Conv2d, except for the input channels.
:param conv: Conv2d to replace the input channels in.
:param in_channels: New number of input channels.
:param fn: (Optional) Function to instantiate the new Conv2d.
By default, it will initialize the new weights with the same mean and std as the original weights.
:return: Conv2d with new number of input channels.
"""
if fn:
return fn(conv, in_channels)
else:
return replace_conv2d_input_channels_with_random_weights(conv=conv, in_channels=in_channels)


def replace_conv2d_input_channels_with_random_weights(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
"""
Replace the input channels in the input Conv2d with random weights.
Returned module will have the same device and dtype as the original module.
Random weights are initialized with the same mean and std as the original weights.
:param conv: Conv2d to replace the input channels in.
:param in_channels: New number of input channels.
:return: Conv2d with new number of input channels.
"""

if in_channels % conv.groups != 0:
raise ValueError(
f"Incompatible number of input channels ({in_channels}) with the number of groups ({conv.groups})."
f"The number of input channels must be divisible by the number of groups."
)

new_conv = nn.Conv2d(
in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
bias=conv.bias is not None,
device=conv.weight.device,
dtype=conv.weight.dtype,
)

if in_channels <= conv.in_channels:
new_conv.weight.data = conv.weight.data[:, :in_channels, ...]
else:
new_conv.weight.data[:, : conv.in_channels, ...] = conv.weight.data

# Pad the remaining channels with random weights
torch.nn.init.normal_(new_conv.weight.data[:, conv.in_channels :, ...], mean=conv.weight.mean().item(), std=conv.weight.std().item())

if conv.bias is not None:
torch.nn.init.normal_(new_conv.bias, mean=conv.bias.mean().item(), std=conv.bias.std().item())

return new_conv
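
A short usage sketch of the default path (the 3-to-4 channel expansion below is a made-up example):

import torch
from torch import nn

from super_gradients.modules.weight_replacement_utils import replace_conv2d_input_channels_with_random_weights

conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)

wider = replace_conv2d_input_channels_with_random_weights(conv=conv, in_channels=4)
assert wider.weight.shape == (16, 4, 3, 3)
# The first 3 input slices are copied from the original kernel; the 4th is drawn
# from a normal distribution matching the original weights' mean and std.
assert torch.equal(wider.weight[:, :3], conv.weight)

narrower = replace_conv2d_input_channels_with_random_weights(conv=conv, in_channels=1)
assert narrower.weight.shape == (16, 1, 3, 3)  # keeps only the first input slice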
2 changes: 2 additions & 0 deletions src/super_gradients/training/kd_trainer/kd_trainer.py
@@ -77,6 +77,7 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
checkpoint_path=cfg.student_checkpoint_params.checkpoint_path,
load_backbone=cfg.student_checkpoint_params.load_backbone,
checkpoint_num_classes=get_param(cfg.student_checkpoint_params, "checkpoint_num_classes"),
num_input_channels=get_param(cfg.student_arch_params, "num_input_channels"),
)

teacher = models.get(
@@ -87,6 +88,7 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
checkpoint_path=cfg.teacher_checkpoint_params.checkpoint_path,
load_backbone=cfg.teacher_checkpoint_params.load_backbone,
checkpoint_num_classes=get_param(cfg.teacher_checkpoint_params, "checkpoint_num_classes"),
num_input_channels=get_param(cfg.teacher_arch_params, "num_input_channels"),
)

recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
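
With this change a KD recipe can override the number of input channels per network through `student_arch_params` / `teacher_arch_params` (e.g. a `num_input_channels: 1` entry for grayscale inputs). A hedged sketch of the same idea through the Python API, assuming `models.get` forwards the argument as this commit suggests:

from super_gradients.common.object_names import Models
from super_gradients.training import models

# Hypothetical grayscale student; the architecture's stem must implement
# SupportsReplaceInputChannels for the override to take effect.
student = models.get(
    Models.RESNET18,
    num_classes=10,
    num_input_channels=1,  # assumption: forwarded to replace_input_channels(in_channels=1)
)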
12 changes: 10 additions & 2 deletions src/super_gradients/training/models/classification_models/beit.py
@@ -20,7 +20,7 @@
# --------------------------------------------------------'
import math
from functools import partial
from typing import Optional, Tuple
from typing import Optional, Tuple, Callable

import torch
import torch.nn as nn
@@ -322,7 +322,9 @@ def __init__(
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.grad_checkpointing = False

self.patch_embed = PatchEmbed(img_size=image_size, patch_size=patch_size, in_channels=in_chans, hidden_dim=embed_dim)
self.image_size = image_size
self.patch_size = patch_size
self.patch_embed = PatchEmbed(img_size=self.image_size, patch_size=self.patch_size, in_channels=in_chans, hidden_dim=self.embed_dim)
num_patches = self.patch_embed.num_patches

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
@@ -453,6 +455,12 @@ def replace_head(self, new_num_classes=None, new_head=None):
else:
self.head = nn.Linear(self.head.in_features, new_num_classes)

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
self.patch_embed = PatchEmbed(img_size=self.image_size, patch_size=self.patch_size, in_channels=in_channels, hidden_dim=self.embed_dim)

def get_input_channels(self) -> int:
return self.patch_embed.get_input_channels()


@register_model(Models.BEIT_BASE_PATCH16_224)
class BeitBasePatch16_224(Beit):