Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis-Dupont committed Aug 27, 2023
1 parent 7240726 commit 1ff5418
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch import nn

from super_gradients.common.factories.base_factory import AbstractFactory
from super_gradients.common.factories.base_factory import AbstractFactory, maybe_deprecate_registered_name
from super_gradients.training.utils.activations_utils import get_builtin_activation_type


Expand All @@ -26,11 +26,13 @@ def get(self, conf: Union[str, Mapping, Type[nn.Module]]) -> Type[nn.Module]:
If provided value is not one of the three above, the value will be returned as is
"""
if isinstance(conf, str):
maybe_deprecate_registered_name(conf)
return get_builtin_activation_type(conf)

if isinstance(conf, Mapping):
(type_name,) = list(conf.keys())
type_args = conf[type_name]
maybe_deprecate_registered_name(type_name)
return get_builtin_activation_type(type_name, **type_args)

if issubclass(conf, nn.Module):
Expand Down
9 changes: 9 additions & 0 deletions src/super_gradients/common/factories/base_factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import warnings
from typing import Union, Mapping, Dict

from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
from super_gradients.training.utils.utils import fuzzy_str, fuzzy_keys, get_fuzzy_mapping_param
from super_gradients.common.registry.registry import _DEPRECATED_REGISTRIES


def maybe_deprecate_registered_name(conf: str):
if conf in _DEPRECATED_REGISTRIES:
warnings.warn(f"Using `{conf}` in the recipe has been deprecated. Please use `{_DEPRECATED_REGISTRIES[conf]}`", DeprecationWarning)


class AbstractFactory:
Expand Down Expand Up @@ -43,6 +50,7 @@ def get(self, conf: Union[str, dict]):
If provided value is not one of the three above, the value will be returned as is
"""
if isinstance(conf, str):
maybe_deprecate_registered_name(conf)
if conf in self.type_dict:
return self.type_dict[conf]()
elif fuzzy_str(conf) in fuzzy_keys(self.type_dict):
Expand All @@ -60,6 +68,7 @@ def get(self, conf: Union[str, dict]):
_type = list(conf.keys())[0] # THE TYPE NAME
_params = list(conf.values())[0] # A DICT CONTAINING THE PARAMETERS FOR INIT
if _type in self.type_dict:
maybe_deprecate_registered_name(_type)
return self.type_dict[_type](**_params)
elif fuzzy_str(_type) in fuzzy_keys(self.type_dict):
return get_fuzzy_mapping_param(_type, self.type_dict)(**_params)
Expand Down
5 changes: 4 additions & 1 deletion src/super_gradients/common/factories/type_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import importlib

from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
from super_gradients.common.factories.base_factory import AbstractFactory
from super_gradients.common.factories.base_factory import AbstractFactory, maybe_deprecate_registered_name
from super_gradients.training.utils import get_param


Expand Down Expand Up @@ -32,6 +32,9 @@ def get(self, conf: Union[str, type]):
If provided value is already a class type, the value will be returned as is.
"""
if isinstance(conf, str) or isinstance(conf, bool):
if isinstance(conf, str):
maybe_deprecate_registered_name(conf)

if conf in self.type_dict:
return self.type_dict[conf]
elif isinstance(conf, str) and get_param(self.type_dict, conf) is not None:
Expand Down
74 changes: 45 additions & 29 deletions src/super_gradients/common/registry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@
from super_gradients.common.object_names import Losses, Transforms, Samplers, Optimizers


def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
_DEPRECATED_REGISTRIES = dict()


def create_register_decorator(registry: Dict[str, Callable], deprecate_custom_name: bool) -> Callable:
"""
Create a decorator that registers object of specified type (model, metric, ...)
:param registry: Dict including registered objects (maps name to object that you register)
:return: Register function
:param registry: Dict including registered objects (maps name to object that you register)
:param deprecate_custom_name: If True, deprecate instantiating objects with the custom name.
Instead, it should be instantiated with the name of the class.
:return: Register function
"""

def register(name: Optional[str] = None) -> Callable:
Expand All @@ -26,13 +31,25 @@ def register(name: Optional[str] = None) -> Callable:

def decorator(cls: Callable) -> Callable:
"""Register the decorated callable"""
cls_name = name if name is not None else cls.__name__

if cls_name in registry:
ref = registry[cls_name]
raise Exception(f"`{cls_name}` is already registered and points to `{inspect.getmodule(ref).__name__}.{ref.__name__}")
def _register_class(class_name: str):
if class_name in registry:
registered_cls = registry[class_name]
if registered_cls != cls:
raise Exception(
f"`{class_name}` is already registered and points to `{inspect.getmodule(registered_cls).__name__}.{registered_cls.__name__}"
)
registry[class_name] = cls

if name is None:
_register_class(class_name=cls.__name__)
elif name == cls.__name__ or not deprecate_custom_name:
_register_class(class_name=name)
else:
_DEPRECATED_REGISTRIES[name] = cls.__name__
_register_class(class_name=name)
_register_class(class_name=cls.__name__)

registry[cls_name] = cls
return cls

return decorator
Expand All @@ -41,26 +58,25 @@ def decorator(cls: Callable) -> Callable:


ARCHITECTURES = {}
register_model = create_register_decorator(registry=ARCHITECTURES)
register_model = create_register_decorator(registry=ARCHITECTURES, deprecate_custom_name=True)

KD_ARCHITECTURES = {}
register_kd_model = create_register_decorator(registry=KD_ARCHITECTURES)
register_kd_model = create_register_decorator(registry=KD_ARCHITECTURES, deprecate_custom_name=True)

ALL_DETECTION_MODULES = {}
register_detection_module = create_register_decorator(registry=ALL_DETECTION_MODULES)
register_detection_module = create_register_decorator(registry=ALL_DETECTION_MODULES, deprecate_custom_name=True)

METRICS = {}
register_metric = create_register_decorator(registry=METRICS)
register_metric = create_register_decorator(registry=METRICS, deprecate_custom_name=True)

LOSSES = {Losses.MSE: nn.MSELoss}
register_loss = create_register_decorator(registry=LOSSES)

register_loss = create_register_decorator(registry=LOSSES, deprecate_custom_name=True)

ALL_DATALOADERS = {}
register_dataloader = create_register_decorator(registry=ALL_DATALOADERS)
register_dataloader = create_register_decorator(registry=ALL_DATALOADERS, deprecate_custom_name=False) # Dataloaders need to work with custom names.

CALLBACKS = {}
register_callback = create_register_decorator(registry=CALLBACKS)
register_callback = create_register_decorator(registry=CALLBACKS, deprecate_custom_name=True)

TRANSFORMS = {
Transforms.Compose: torchvision.transforms.Compose,
Expand Down Expand Up @@ -99,34 +115,34 @@ def decorator(cls: Callable) -> Callable:
Transforms.RandomAutocontrast: torchvision.transforms.RandomAutocontrast,
Transforms.RandomEqualize: torchvision.transforms.RandomEqualize,
}
register_transform = create_register_decorator(registry=TRANSFORMS)
register_transform = create_register_decorator(registry=TRANSFORMS, deprecate_custom_name=True)

ALL_DATASETS = {}
register_dataset = create_register_decorator(registry=ALL_DATASETS)
register_dataset = create_register_decorator(registry=ALL_DATASETS, deprecate_custom_name=True)

ALL_PRE_LAUNCH_CALLBACKS = {}
register_pre_launch_callback = create_register_decorator(registry=ALL_PRE_LAUNCH_CALLBACKS)
register_pre_launch_callback = create_register_decorator(registry=ALL_PRE_LAUNCH_CALLBACKS, deprecate_custom_name=True)

BACKBONE_STAGES = {}
register_unet_backbone_stage = create_register_decorator(registry=BACKBONE_STAGES)
register_unet_backbone_stage = create_register_decorator(registry=BACKBONE_STAGES, deprecate_custom_name=True)

UP_FUSE_BLOCKS = {}
register_unet_up_block = create_register_decorator(registry=UP_FUSE_BLOCKS)
register_unet_up_block = create_register_decorator(registry=UP_FUSE_BLOCKS, deprecate_custom_name=True)

ALL_TARGET_GENERATORS = {}
register_target_generator = create_register_decorator(registry=ALL_TARGET_GENERATORS)
register_target_generator = create_register_decorator(registry=ALL_TARGET_GENERATORS, deprecate_custom_name=True)

LR_SCHEDULERS_CLS_DICT = {}
register_lr_scheduler = create_register_decorator(registry=LR_SCHEDULERS_CLS_DICT)
register_lr_scheduler = create_register_decorator(registry=LR_SCHEDULERS_CLS_DICT, deprecate_custom_name=True)

LR_WARMUP_CLS_DICT = {}
register_lr_warmup = create_register_decorator(registry=LR_WARMUP_CLS_DICT)
register_lr_warmup = create_register_decorator(registry=LR_WARMUP_CLS_DICT, deprecate_custom_name=True)

SG_LOGGERS = {}
register_sg_logger = create_register_decorator(registry=SG_LOGGERS)
register_sg_logger = create_register_decorator(registry=SG_LOGGERS, deprecate_custom_name=True)

ALL_COLLATE_FUNCTIONS = {}
register_collate_function = create_register_decorator(registry=ALL_COLLATE_FUNCTIONS)
register_collate_function = create_register_decorator(registry=ALL_COLLATE_FUNCTIONS, deprecate_custom_name=True)

SAMPLERS = {
Samplers.DISTRIBUTED: torch.utils.data.DistributedSampler,
Expand All @@ -135,7 +151,7 @@ def decorator(cls: Callable) -> Callable:
Samplers.RANDOM: torch.utils.data.RandomSampler,
Samplers.WEIGHTED_RANDOM: torch.utils.data.WeightedRandomSampler,
}
register_sampler = create_register_decorator(registry=SAMPLERS)
register_sampler = create_register_decorator(registry=SAMPLERS, deprecate_custom_name=True)


OPTIMIZERS = {
Expand All @@ -158,7 +174,7 @@ def decorator(cls: Callable) -> Callable:
"LinearLR": torch.optim.lr_scheduler.LinearLR,
}

register_optimizer = create_register_decorator(registry=OPTIMIZERS)
register_optimizer = create_register_decorator(registry=OPTIMIZERS, deprecate_custom_name=True)

PROCESSINGS = {}
register_processing = create_register_decorator(registry=PROCESSINGS)
register_processing = create_register_decorator(registry=PROCESSINGS, deprecate_custom_name=True)

0 comments on commit 1ff5418

Please sign in to comment.