diff --git a/src/super_gradients/common/factories/base_factory.py b/src/super_gradients/common/factories/base_factory.py index 6edab85fed..ab81ff97ee 100644 --- a/src/super_gradients/common/factories/base_factory.py +++ b/src/super_gradients/common/factories/base_factory.py @@ -1,6 +1,7 @@ from typing import Union, Mapping, Dict from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException +from super_gradients.common.registry.registry import warn_if_deprecated from super_gradients.training.utils.utils import fuzzy_str, fuzzy_keys, get_fuzzy_mapping_param @@ -43,6 +44,7 @@ def get(self, conf: Union[str, dict]): If provided value is not one of the three above, the value will be returned as is """ if isinstance(conf, str): + warn_if_deprecated(name=conf, registry=self.type_dict) if conf in self.type_dict: return self.type_dict[conf]() elif fuzzy_str(conf) in fuzzy_keys(self.type_dict): @@ -60,6 +62,7 @@ def get(self, conf: Union[str, dict]): _type = list(conf.keys())[0] # THE TYPE NAME _params = list(conf.values())[0] # A DICT CONTAINING THE PARAMETERS FOR INIT if _type in self.type_dict: + warn_if_deprecated(name=_type, registry=self.type_dict) return self.type_dict[_type](**_params) elif fuzzy_str(_type) in fuzzy_keys(self.type_dict): return get_fuzzy_mapping_param(_type, self.type_dict)(**_params) diff --git a/src/super_gradients/common/factories/type_factory.py b/src/super_gradients/common/factories/type_factory.py index aa4f368463..222509c080 100644 --- a/src/super_gradients/common/factories/type_factory.py +++ b/src/super_gradients/common/factories/type_factory.py @@ -3,7 +3,7 @@ import importlib from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException -from super_gradients.common.factories.base_factory import AbstractFactory +from super_gradients.common.factories.base_factory import AbstractFactory, warn_if_deprecated from super_gradients.training.utils import get_param @@ -32,6 +32,9 @@ def get(self, conf: Union[str, type]): If provided value is already a class type, the value will be returned as is. """ if isinstance(conf, str) or isinstance(conf, bool): + if isinstance(conf, str): + warn_if_deprecated(name=conf, registry=self.type_dict) + if conf in self.type_dict: return self.type_dict[conf] elif isinstance(conf, str) and get_param(self.type_dict, conf) is not None: diff --git a/src/super_gradients/common/registry/registry.py b/src/super_gradients/common/registry/registry.py index 95da751cb6..f00f0a3193 100644 --- a/src/super_gradients/common/registry/registry.py +++ b/src/super_gradients/common/registry/registry.py @@ -1,5 +1,6 @@ import inspect from typing import Callable, Dict, Optional +import warnings import torch from torch import nn, optim @@ -7,6 +8,8 @@ from super_gradients.common.object_names import Losses, Transforms, Samplers, Optimizers +_DEPRECATED_KEY = "_deprecated_objects" + def create_register_decorator(registry: Dict[str, Callable]) -> Callable: """ @@ -16,23 +19,41 @@ def create_register_decorator(registry: Dict[str, Callable]) -> Callable: :return: Register function """ - def register(name: Optional[str] = None) -> Callable: + def register(name: Optional[str] = None, deprecated_name: Optional[str] = None) -> Callable: """ Set up a register decorator. - :param name: If specified, the decorated object will be registered with this name. - :return: Decorator that registers the callable. + :param name: If specified, the decorated object will be registered with this name. Otherwise, the class name will be used to register. + :param deprecated_name: If specified, the decorated object will be registered with this name. + This is done on top of the `official` registration which is done by setting the `name` argument. + :return: Decorator that registers the callable. """ def decorator(cls: Callable) -> Callable: """Register the decorated callable""" - cls_name = name if name is not None else cls.__name__ - if cls_name in registry: - ref = registry[cls_name] - raise Exception(f"`{cls_name}` is already registered and points to `{inspect.getmodule(ref).__name__}.{ref.__name__}") + def _registered_cls(registration_name: str): + if registration_name in registry: + registered_cls = registry[registration_name] + if registered_cls != cls: + raise Exception( + f"`{registration_name}` is already registered and points to `{inspect.getmodule(registered_cls).__name__}.{registered_cls.__name__}" + ) + registry[registration_name] = cls + + registration_name = name or cls.__name__ + _registered_cls(registration_name=registration_name) + + if deprecated_name: + # Deprecated objects like other objects - This is meant to avoid any breaking change. + _registered_cls(registration_name=deprecated_name) + + # But deprecated objects are also listed in the _deprecated_objects key. + # This can later be used in the factories to know if a name is deprecated and how it should be named instead. + deprecated_registered_objects = registry.get(_DEPRECATED_KEY, {}) + deprecated_registered_objects[deprecated_name] = registration_name # Keep the information about how it should be named. + registry[_DEPRECATED_KEY] = deprecated_registered_objects - registry[cls_name] = cls return cls return decorator @@ -40,6 +61,16 @@ def decorator(cls: Callable) -> Callable: return register +def warn_if_deprecated(name: str, registry: dict): + """If the name is deprecated, warn the user about it. + :param name: The name of the object that we want to check if it is deprecated. + :param registry: The registry that may or may not include deprecated objects. + """ + deprecated_names = registry.get(_DEPRECATED_KEY, {}) + if name in deprecated_names: + warnings.warn(f"Using `{name}` in the recipe has been deprecated. Please use `{deprecated_names[name]}`", DeprecationWarning) + + ARCHITECTURES = {} register_model = create_register_decorator(registry=ARCHITECTURES) diff --git a/tests/unit_tests/test_registry.py b/tests/unit_tests/test_registry.py new file mode 100644 index 0000000000..ac8f237d26 --- /dev/null +++ b/tests/unit_tests/test_registry.py @@ -0,0 +1,47 @@ +import unittest +from typing import List + +from super_gradients.common.registry.registry import create_register_decorator +from super_gradients.common.factories.base_factory import BaseFactory, UnknownTypeException + + +class RegistryTest(unittest.TestCase): + def setUp(self) -> None: + # We do all the registration in `setUp` to avoid having registration ran on import + _DUMMY_REGISTRY = {} + register_class = create_register_decorator(registry=_DUMMY_REGISTRY) + + @register_class("good_object_name") + class Class1: + def __init__(self, values: List[float]): + self.values = values + + @register_class(deprecated_name="deprecated_object_name") + class Class2: + def __init__(self, values: List[float]): + self.values = values + + self.Class1 = Class1 # Save classes, not instances + self.Class2 = Class2 + self.factory = BaseFactory(type_dict=_DUMMY_REGISTRY) + + def test_instantiate_from_name(self): + instance = self.factory.get({"good_object_name": {"values": [1.0, 2.0]}}) + self.assertIsInstance(instance, self.Class1) + + def test_instantiate_from_classname_when_name_set(self): + with self.assertRaises(UnknownTypeException): + self.factory.get({"Class1": {"values": [1.0, 2.0]}}) + + def test_instantiate_from_classname_when_no_name_set(self): + instance = self.factory.get({"Class2": {"values": [1.0, 2.0]}}) + self.assertIsInstance(instance, self.Class2) + + def test_instantiate_from_deprecated_name(self): + with self.assertWarns(DeprecationWarning): + instance = self.factory.get({"deprecated_object_name": {"values": [1.0, 2.0]}}) + self.assertIsInstance(instance, self.Class2) + + +if __name__ == "__main__": + unittest.main()