Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add deprecate module #1416

Merged
merged 17 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/super_gradients/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__version__ = "3.2.0"

shaydeci marked this conversation as resolved.
Show resolved Hide resolved
from super_gradients.common import init_trainer, is_distributed, object_names
from super_gradients.training import losses, utils, datasets_utils, DataAugmentation, Trainer, KDTrainer, QATTrainer
from super_gradients.common.registry.registry import ARCHITECTURES
Expand All @@ -23,6 +25,4 @@
"AutoTrainBatchSizeSelectionCallback",
]

__version__ = "3.2.0"

env_sanity_check()
78 changes: 78 additions & 0 deletions src/super_gradients/common/deprecate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import warnings
from functools import wraps
from typing import Optional
from pkg_resources import parse_version


def deprecated(deprecated_in_v: str, remove_in_v: str, target: Optional[callable] = None, reason: str = ""):
"""
Decorator to mark a callable as deprecated. Works on functions and classes.
It provides a clear and actionable warning message informing
the user about the version in which the function was deprecated, the version in which it will be removed,
and guidance on how to replace it.

:param deprecated_in_v: Version number when the function was deprecated.
:param remove_in_v: Version number when the function will be removed.
:param target: (Optional) The new function that should be used as a replacement. If provided, it will guide the user to the updated function.
:param reason: (Optional) Additional information or reason for the deprecation.

Example usage:
If a direct replacement function exists:
>> from new.module.path import new_get_local_rank

>> @deprecated(deprecated_in_v='3.2.0', remove_in_v='4.0.0', target=new_get_local_rank, reason="Replaced for optimization")
>> def get_local_rank():
>> return new_get_local_rank()

If there's no direct replacement:
>> @deprecated(deprecated_in_v='3.2.0', remove_in_v='4.0.0', reason="Function is no longer needed due to XYZ reason")
>> def some_old_function():
>> # ... function logic ...

When calling a deprecated function:
>> from some_module import get_local_rank
>> get_local_rank()
DeprecationWarning: Function `some_module.get_local_rank` is deprecated. Deprecated since version `3.2.0`
and will be removed in version `4.0.0`. Reason: `Replaced for optimization`.
Please update your code:
[-] from `some_module` import `get_local_rank`
[+] from `new.module.path` import `new_get_local_rank`.
"""

def decorator(old_func: callable) -> callable:
@wraps(old_func)
def wrapper(*args, **kwargs):
if not wrapper._warned:
import super_gradients

is_still_supported = parse_version(super_gradients.__version__) < parse_version(remove_in_v)
status_msg = "is deprecated" if is_still_supported else "was deprecated and has been removed"
message = (
f"Callable `{old_func.__module__}.{old_func.__name__}` {status_msg} since version `{deprecated_in_v}` "
f"and will be removed in version `{remove_in_v}`.\n"
)
if reason:
message += f"Reason: {reason}.\n"

if target is not None:
message += (
f"Please update your code:\n"
f" [-] from `{old_func.__module__}` import `{old_func.__name__}`\n"
f" [+] from `{target.__module__}` import `{target.__name__}`"
)

if is_still_supported:
warnings.simplefilter("once", DeprecationWarning) # Required, otherwise the warning may never be displayed.
warnings.warn(message, DeprecationWarning, stacklevel=2)
wrapper._warned = True
else:
raise ImportError(message)

return old_func(*args, **kwargs)

# Each decorated object will have its own _warned state
# This state ensures that the warning will appear only once, to avoid polluting the console in case the function is called too often.
wrapper._warned = False
return wrapper

return decorator
57 changes: 17 additions & 40 deletions src/super_gradients/training/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import warnings
from super_gradients.common.deprecate import deprecated

from .sg_module import SgModule
from .classification_models.base_classifer import BaseClassifier
Expand Down Expand Up @@ -135,51 +135,28 @@
from super_gradients.training.utils import make_divisible as _make_divisible_current_version, HpmStruct as CurrVersionHpmStruct


def make_deprecated(func, reason):
def inner(*args, **kwargs):
with warnings.catch_warnings():
warnings.simplefilter("once", DeprecationWarning)
warnings.warn(reason, category=DeprecationWarning, stacklevel=2)
warnings.warn(reason, DeprecationWarning)
return func(*args, **kwargs)
@deprecated(deprecated_in_v="3.1.0", remove_in_v="3.4.0", target=_make_divisible_current_version)
Louis-Dupont marked this conversation as resolved.
Show resolved Hide resolved
def make_divisible(x: int, divisor: int, ceil: bool = True) -> int:
"""
Returns x evenly divisible by divisor.
If ceil=True it will return the closest larger number to the original x, and ceil=False the closest smaller number.
"""
return _make_divisible_current_version(x=x, divisor=divisor, ceil=ceil)

return inner

@deprecated(deprecated_in_v="3.1.0", remove_in_v="3.4.0", target=BasicResNetBlock, reason="This block was renamed to BasicResNetBlock for better clarity.")
class BasicBlock(BasicResNetBlock):
...

make_divisible = make_deprecated(
func=_make_divisible_current_version,
reason="You're importing `make_divisible` from `super_gradients.training.models`. This is deprecated since SuperGradients 3.1.0.\n"
"Please update your code to import it as follows:\n"
"[-] from super_gradients.training.models import make_divisible\n"
"[+] from super_gradients.training.utils import make_divisible\n",
)

@deprecated(deprecated_in_v="3.1.0", remove_in_v="3.4.0", target=NewBottleneck, reason="This block was renamed to BasicResNetBlock for better clarity.")
class Bottleneck(NewBottleneck):
...

BasicBlock = make_deprecated(
func=BasicResNetBlock,
reason="You're importing `BasicBlock` class from `super_gradients.training.models`. This is deprecated since SuperGradients 3.1.0.\n"
"This block was renamed to BasicResNetBlock for better clarity.\n"
"Please update your code to import it as follows:\n"
"[-] from super_gradients.training.models import BasicBlock\n"
"[+] from super_gradients.training.models import BasicResNetBlock\n",
)

Bottleneck = make_deprecated(
func=NewBottleneck,
reason="You're importing `Bottleneck` class from `super_gradients.training.models`. This is deprecated since SuperGradients 3.1.0.\n"
"This block was renamed to BasicResNetBlock for better clarity.\n"
"Please update your code to import it as follows:\n"
"[-] from super_gradients.training.models import Bottleneck\n"
"[+] from super_gradients.training.models.classification_models.resnet import Bottleneck\n",
)

HpmStruct = make_deprecated(
func=CurrVersionHpmStruct,
reason="You're importing `HpmStruct` class from `super_gradients.training.models`. This is deprecated since SuperGradients 3.1.0.\n"
"Please update your code to import it as follows:\n"
"[-] from super_gradients.training.models import HpmStruct\n"
"[+] from super_gradients.training.utils import HpmStruct\n",
)
@deprecated(deprecated_in_v="3.1.0", remove_in_v="3.4.0", target=CurrVersionHpmStruct)
class HpmStruct(CurrVersionHpmStruct):
...


__all__ = [
Expand Down
2 changes: 2 additions & 0 deletions tests/deci_core_unit_test_suite_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
TestTransforms,
TestPostPredictionCallback,
TestModelPredict,
TestDeprecationDecorator,
)
from tests.end_to_end_tests import TestTrainer
from tests.unit_tests.detection_utils_test import TestDetectionUtils
Expand Down Expand Up @@ -153,6 +154,7 @@ def _add_modules_to_unit_tests_suite(self):
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestModelPredict))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionModelExport))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(SlidingWindowTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDeprecationDecorator))

def _add_modules_to_end_to_end_tests_suite(self):
"""
Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tests.unit_tests.transforms_test import TestTransforms
from tests.unit_tests.post_prediction_callback_test import TestPostPredictionCallback
from tests.unit_tests.test_predict import TestModelPredict
from tests.unit_tests.test_deprecate import TestDeprecationDecorator

__all__ = [
"CrashTipTest",
Expand Down Expand Up @@ -53,4 +54,5 @@
"TestTransforms",
"TestPostPredictionCallback",
"TestModelPredict",
"TestDeprecationDecorator",
]
132 changes: 132 additions & 0 deletions tests/unit_tests/test_deprecate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import warnings
import unittest
from unittest.mock import patch

from super_gradients.common.deprecate import deprecated


class TestDeprecationDecorator(unittest.TestCase):
def setUp(self):
"""Prepare required functions before each test."""
self.new_function_message = "This is the new function!"

def new_func():
return self.new_function_message

@deprecated(deprecated_in_v="3.2.0", remove_in_v="10.0.0", target=new_func, reason="Replaced for optimization")
def fully_configured_deprecated_func():
return new_func()

@deprecated(deprecated_in_v="3.2.0", remove_in_v="10.0.0")
def basic_deprecated_func():
return new_func()

self.new_func = new_func
self.fully_configured_deprecated_func = fully_configured_deprecated_func
self.basic_deprecated_func = basic_deprecated_func

class NewClass:
def __init__(self):
pass

@deprecated(deprecated_in_v="3.2.0", remove_in_v="10.0.0", target=NewClass, reason="Replaced for optimization")
class DeprecatedClass:
def __init__(self):
pass

self.NewClass = NewClass
self.DeprecatedClass = DeprecatedClass

def test_emits_warning(self):
"""Ensure that the deprecated function emits a warning when called."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.fully_configured_deprecated_func()
self.assertEqual(len(w), 1)

def test_displays_deprecated_version(self):
"""Ensure that the warning contains the version in which the function was deprecated."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.fully_configured_deprecated_func()
self.assertTrue(any("3.2.0" in str(warning.message) for warning in w))

def test_displays_removed_version(self):
"""Ensure that the warning contains the version in which the function will be removed."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.fully_configured_deprecated_func()
self.assertTrue(any("10.0.0" in str(warning.message) for warning in w))

def test_guidance_on_replacement(self):
"""Ensure that if a replacement target is provided, guidance on using the new function is included in the warning."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.fully_configured_deprecated_func()
self.assertTrue(any("new_func" in str(warning.message) for warning in w))

def test_displays_reason(self):
"""Ensure that if provided, the reason for deprecation is included in the warning."""
reason_str = "Replaced for optimization"
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.fully_configured_deprecated_func()
self.assertTrue(any(reason_str in str(warning.message) for warning in w))

def test_triggered_only_once(self):
"""Ensure that the deprecation warning is triggered only once even if the deprecated function is called multiple times."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
for _ in range(10):
self.fully_configured_deprecated_func()
self.assertEqual(len(w), 1, "Only one warning should be emitted")

def test_basic_deprecation_emits_warning(self):
"""Ensure that a function with minimal deprecation configuration emits a warning."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.basic_deprecated_func()
self.assertEqual(len(w), 1)

def test_class_deprecation_warning(self):
"""Ensure that creating an instance of a deprecated class emits a warning."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
_ = self.DeprecatedClass() # Instantiate the deprecated class
self.assertEqual(len(w), 1)

def test_class_deprecation_message_content(self):
"""Ensure that the emitted warning for a deprecated class contains relevant information including target class."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
_ = self.DeprecatedClass()
self.assertTrue(any("3.2.0" in str(warning.message) for warning in w))
self.assertTrue(any("10.0.0" in str(warning.message) for warning in w))
self.assertTrue(any("DeprecatedClass" in str(warning.message) for warning in w))
self.assertTrue(any("Replaced for optimization" in str(warning.message) for warning in w))
self.assertTrue(any("NewClass" in str(warning.message) for warning in w))

def test_raise_error_when_library_version_equals_removal_version(self):
"""Ensure that an error is raised when the library's version equals the function's removal version."""
with patch("super_gradients.__version__", "10.1.0"): # Mocking the version to be equal to removal version
with self.assertRaises(ImportError):

@deprecated(deprecated_in_v="3.2.0", remove_in_v="10.1.0", target=self.new_func)
def deprecated_func_version_equal():
return

deprecated_func_version_equal()

def test_no_error_when_library_version_below_removal_version(self):
"""Ensure that no error is raised when the library's version is below the function's removal version."""
with patch("super_gradients.__version__", "10.1.0"): # Mocking the version to be below removal version

@deprecated(deprecated_in_v="3.2.0", remove_in_v="10.2.0", target=self.new_func)
def deprecated_func_version_below():
return

deprecated_func_version_below()


if __name__ == "__main__":
unittest.main()