Skip to content

Commit

Permalink
Adding a deprecation option into registeries (#1421)
Browse files Browse the repository at this point in the history
* first draft of registries with deprecate - still need to change how we register all the classes one at a time///

* docstring
  • Loading branch information
Louis-Dupont committed Sep 4, 2023
1 parent 492aef7 commit 4c3fea4
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 9 deletions.
3 changes: 3 additions & 0 deletions src/super_gradients/common/factories/base_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Union, Mapping, Dict

from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
from super_gradients.common.registry.registry import warn_if_deprecated
from super_gradients.training.utils.utils import fuzzy_str, fuzzy_keys, get_fuzzy_mapping_param


Expand Down Expand Up @@ -43,6 +44,7 @@ def get(self, conf: Union[str, dict]):
If provided value is not one of the three above, the value will be returned as is
"""
if isinstance(conf, str):
warn_if_deprecated(name=conf, registry=self.type_dict)
if conf in self.type_dict:
return self.type_dict[conf]()
elif fuzzy_str(conf) in fuzzy_keys(self.type_dict):
Expand All @@ -60,6 +62,7 @@ def get(self, conf: Union[str, dict]):
_type = list(conf.keys())[0] # THE TYPE NAME
_params = list(conf.values())[0] # A DICT CONTAINING THE PARAMETERS FOR INIT
if _type in self.type_dict:
warn_if_deprecated(name=_type, registry=self.type_dict)
return self.type_dict[_type](**_params)
elif fuzzy_str(_type) in fuzzy_keys(self.type_dict):
return get_fuzzy_mapping_param(_type, self.type_dict)(**_params)
Expand Down
5 changes: 4 additions & 1 deletion src/super_gradients/common/factories/type_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import importlib

from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
from super_gradients.common.factories.base_factory import AbstractFactory
from super_gradients.common.factories.base_factory import AbstractFactory, warn_if_deprecated
from super_gradients.training.utils import get_param


Expand Down Expand Up @@ -32,6 +32,9 @@ def get(self, conf: Union[str, type]):
If provided value is already a class type, the value will be returned as is.
"""
if isinstance(conf, str) or isinstance(conf, bool):
if isinstance(conf, str):
warn_if_deprecated(name=conf, registry=self.type_dict)

if conf in self.type_dict:
return self.type_dict[conf]
elif isinstance(conf, str) and get_param(self.type_dict, conf) is not None:
Expand Down
47 changes: 39 additions & 8 deletions src/super_gradients/common/registry/registry.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import inspect
from typing import Callable, Dict, Optional
import warnings

import torch
from torch import nn, optim
import torchvision

from super_gradients.common.object_names import Losses, Transforms, Samplers, Optimizers

_DEPRECATED_KEY = "_deprecated_objects"


def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
"""
Expand All @@ -16,30 +19,58 @@ def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
:return: Register function
"""

def register(name: Optional[str] = None) -> Callable:
def register(name: Optional[str] = None, deprecated_name: Optional[str] = None) -> Callable:
"""
Set up a register decorator.
:param name: If specified, the decorated object will be registered with this name.
:return: Decorator that registers the callable.
:param name: If specified, the decorated object will be registered with this name. Otherwise, the class name will be used to register.
:param deprecated_name: If specified, the decorated object will be registered with this name.
This is done on top of the `official` registration which is done by setting the `name` argument.
:return: Decorator that registers the callable.
"""

def decorator(cls: Callable) -> Callable:
"""Register the decorated callable"""
cls_name = name if name is not None else cls.__name__

if cls_name in registry:
ref = registry[cls_name]
raise Exception(f"`{cls_name}` is already registered and points to `{inspect.getmodule(ref).__name__}.{ref.__name__}")
def _registered_cls(registration_name: str):
if registration_name in registry:
registered_cls = registry[registration_name]
if registered_cls != cls:
raise Exception(
f"`{registration_name}` is already registered and points to `{inspect.getmodule(registered_cls).__name__}.{registered_cls.__name__}"
)
registry[registration_name] = cls

registration_name = name or cls.__name__
_registered_cls(registration_name=registration_name)

if deprecated_name:
# Deprecated objects like other objects - This is meant to avoid any breaking change.
_registered_cls(registration_name=deprecated_name)

# But deprecated objects are also listed in the _deprecated_objects key.
# This can later be used in the factories to know if a name is deprecated and how it should be named instead.
deprecated_registered_objects = registry.get(_DEPRECATED_KEY, {})
deprecated_registered_objects[deprecated_name] = registration_name # Keep the information about how it should be named.
registry[_DEPRECATED_KEY] = deprecated_registered_objects

registry[cls_name] = cls
return cls

return decorator

return register


def warn_if_deprecated(name: str, registry: dict):
"""If the name is deprecated, warn the user about it.
:param name: The name of the object that we want to check if it is deprecated.
:param registry: The registry that may or may not include deprecated objects.
"""
deprecated_names = registry.get(_DEPRECATED_KEY, {})
if name in deprecated_names:
warnings.warn(f"Using `{name}` in the recipe has been deprecated. Please use `{deprecated_names[name]}`", DeprecationWarning)


ARCHITECTURES = {}
register_model = create_register_decorator(registry=ARCHITECTURES)

Expand Down
47 changes: 47 additions & 0 deletions tests/unit_tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import unittest
from typing import List

from super_gradients.common.registry.registry import create_register_decorator
from super_gradients.common.factories.base_factory import BaseFactory, UnknownTypeException


class RegistryTest(unittest.TestCase):
def setUp(self) -> None:
# We do all the registration in `setUp` to avoid having registration ran on import
_DUMMY_REGISTRY = {}
register_class = create_register_decorator(registry=_DUMMY_REGISTRY)

@register_class("good_object_name")
class Class1:
def __init__(self, values: List[float]):
self.values = values

@register_class(deprecated_name="deprecated_object_name")
class Class2:
def __init__(self, values: List[float]):
self.values = values

self.Class1 = Class1 # Save classes, not instances
self.Class2 = Class2
self.factory = BaseFactory(type_dict=_DUMMY_REGISTRY)

def test_instantiate_from_name(self):
instance = self.factory.get({"good_object_name": {"values": [1.0, 2.0]}})
self.assertIsInstance(instance, self.Class1)

def test_instantiate_from_classname_when_name_set(self):
with self.assertRaises(UnknownTypeException):
self.factory.get({"Class1": {"values": [1.0, 2.0]}})

def test_instantiate_from_classname_when_no_name_set(self):
instance = self.factory.get({"Class2": {"values": [1.0, 2.0]}})
self.assertIsInstance(instance, self.Class2)

def test_instantiate_from_deprecated_name(self):
with self.assertWarns(DeprecationWarning):
instance = self.factory.get({"deprecated_object_name": {"values": [1.0, 2.0]}})
self.assertIsInstance(instance, self.Class2)


if __name__ == "__main__":
unittest.main()

0 comments on commit 4c3fea4

Please sign in to comment.