Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding a deprecation option into registeries #1421

Merged
merged 2 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/super_gradients/common/factories/base_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Union, Mapping, Dict

from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
from super_gradients.common.registry.registry import warn_if_deprecated
from super_gradients.training.utils.utils import fuzzy_str, fuzzy_keys, get_fuzzy_mapping_param


Expand Down Expand Up @@ -43,6 +44,7 @@ def get(self, conf: Union[str, dict]):
If provided value is not one of the three above, the value will be returned as is
"""
if isinstance(conf, str):
warn_if_deprecated(name=conf, registry=self.type_dict)
if conf in self.type_dict:
return self.type_dict[conf]()
elif fuzzy_str(conf) in fuzzy_keys(self.type_dict):
Expand All @@ -60,6 +62,7 @@ def get(self, conf: Union[str, dict]):
_type = list(conf.keys())[0] # THE TYPE NAME
_params = list(conf.values())[0] # A DICT CONTAINING THE PARAMETERS FOR INIT
if _type in self.type_dict:
warn_if_deprecated(name=_type, registry=self.type_dict)
return self.type_dict[_type](**_params)
elif fuzzy_str(_type) in fuzzy_keys(self.type_dict):
return get_fuzzy_mapping_param(_type, self.type_dict)(**_params)
Expand Down
5 changes: 4 additions & 1 deletion src/super_gradients/common/factories/type_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import importlib

from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
from super_gradients.common.factories.base_factory import AbstractFactory
from super_gradients.common.factories.base_factory import AbstractFactory, warn_if_deprecated
from super_gradients.training.utils import get_param


Expand Down Expand Up @@ -32,6 +32,9 @@ def get(self, conf: Union[str, type]):
If provided value is already a class type, the value will be returned as is.
"""
if isinstance(conf, str) or isinstance(conf, bool):
if isinstance(conf, str):
warn_if_deprecated(name=conf, registry=self.type_dict)

if conf in self.type_dict:
return self.type_dict[conf]
elif isinstance(conf, str) and get_param(self.type_dict, conf) is not None:
Expand Down
47 changes: 39 additions & 8 deletions src/super_gradients/common/registry/registry.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import inspect
from typing import Callable, Dict, Optional
import warnings

import torch
from torch import nn, optim
import torchvision

from super_gradients.common.object_names import Losses, Transforms, Samplers, Optimizers

_DEPRECATED_KEY = "_deprecated_objects"


def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
"""
Expand All @@ -16,30 +19,58 @@ def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
:return: Register function
"""

def register(name: Optional[str] = None) -> Callable:
def register(name: Optional[str] = None, deprecated_name: Optional[str] = None) -> Callable:
"""
Set up a register decorator.

:param name: If specified, the decorated object will be registered with this name.
:return: Decorator that registers the callable.
:param name: If specified, the decorated object will be registered with this name. Otherwise, the class name will be used to register.
:param deprecated_name: If specified, the decorated object will be registered with this name.
This is done on top of the `official` registration which is done by setting the `name` argument.
:return: Decorator that registers the callable.
"""

def decorator(cls: Callable) -> Callable:
"""Register the decorated callable"""
cls_name = name if name is not None else cls.__name__

if cls_name in registry:
ref = registry[cls_name]
raise Exception(f"`{cls_name}` is already registered and points to `{inspect.getmodule(ref).__name__}.{ref.__name__}")
def _registered_cls(registration_name: str):
if registration_name in registry:
registered_cls = registry[registration_name]
if registered_cls != cls:
raise Exception(
f"`{registration_name}` is already registered and points to `{inspect.getmodule(registered_cls).__name__}.{registered_cls.__name__}"
)
registry[registration_name] = cls

registration_name = name or cls.__name__
_registered_cls(registration_name=registration_name)

if deprecated_name:
# Deprecated objects like other objects - This is meant to avoid any breaking change.
_registered_cls(registration_name=deprecated_name)

# But deprecated objects are also listed in the _deprecated_objects key.
# This can later be used in the factories to know if a name is deprecated and how it should be named instead.
deprecated_registered_objects = registry.get(_DEPRECATED_KEY, {})
deprecated_registered_objects[deprecated_name] = registration_name # Keep the information about how it should be named.
registry[_DEPRECATED_KEY] = deprecated_registered_objects

registry[cls_name] = cls
return cls

return decorator

return register


def warn_if_deprecated(name: str, registry: dict):
"""If the name is deprecated, warn the user about it.
:param name: The name of the object that we want to check if it is deprecated.
:param registry: The registry that may or may not include deprecated objects.
"""
deprecated_names = registry.get(_DEPRECATED_KEY, {})
if name in deprecated_names:
warnings.warn(f"Using `{name}` in the recipe has been deprecated. Please use `{deprecated_names[name]}`", DeprecationWarning)


ARCHITECTURES = {}
register_model = create_register_decorator(registry=ARCHITECTURES)

Expand Down
47 changes: 47 additions & 0 deletions tests/unit_tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import unittest
from typing import List

from super_gradients.common.registry.registry import create_register_decorator
from super_gradients.common.factories.base_factory import BaseFactory, UnknownTypeException


class RegistryTest(unittest.TestCase):
def setUp(self) -> None:
# We do all the registration in `setUp` to avoid having registration ran on import
_DUMMY_REGISTRY = {}
register_class = create_register_decorator(registry=_DUMMY_REGISTRY)

@register_class("good_object_name")
class Class1:
def __init__(self, values: List[float]):
self.values = values

@register_class(deprecated_name="deprecated_object_name")
class Class2:
def __init__(self, values: List[float]):
self.values = values

self.Class1 = Class1 # Save classes, not instances
self.Class2 = Class2
self.factory = BaseFactory(type_dict=_DUMMY_REGISTRY)

def test_instantiate_from_name(self):
instance = self.factory.get({"good_object_name": {"values": [1.0, 2.0]}})
self.assertIsInstance(instance, self.Class1)

def test_instantiate_from_classname_when_name_set(self):
with self.assertRaises(UnknownTypeException):
self.factory.get({"Class1": {"values": [1.0, 2.0]}})

def test_instantiate_from_classname_when_no_name_set(self):
instance = self.factory.get({"Class2": {"values": [1.0, 2.0]}})
self.assertIsInstance(instance, self.Class2)

def test_instantiate_from_deprecated_name(self):
with self.assertWarns(DeprecationWarning):
instance = self.factory.get({"deprecated_object_name": {"values": [1.0, 2.0]}})
self.assertIsInstance(instance, self.Class2)


if __name__ == "__main__":
unittest.main()