Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add deprecate module #1416

Merged
merged 17 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions src/super_gradients/common/deprecate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import warnings
from functools import wraps
from typing import Optional


def deprecate_call(deprecated_in_v: str, removed_in_v: str, target: Optional[callable] = None, reason: str = ""):
Louis-Dupont marked this conversation as resolved.
Show resolved Hide resolved
"""
Decorator to mark a callable as deprecated. It provides a clear and actionable warning message informing
the user about the version in which the function was deprecated, the version in which it will be removed,
and guidance on how to replace it.

:param deprecated_in_v: Version number when the function was deprecated.
:param removed_in_v: Version number when the function will be removed.
:param target: (Optional) The new function that should be used as a replacement. If provided, it will guide the user to the updated function.
:param reason: (Optional) Additional information or reason for the deprecation.

Example usage:
If a direct replacement function exists:
>> from new.module.path import new_get_local_rank

>> @deprecate_call(deprecated_in_v='3.2.0', removed_in_v='4.0.0', target=new_get_local_rank, reason="Replaced for optimization")
>> def get_local_rank():
>> return new_get_local_rank()

If there's no direct replacement:
>> @deprecate_call(deprecated_in_v='3.2.0', removed_in_v='4.0.0', reason="Function is no longer needed due to XYZ reason")
>> def some_old_function():
>> # ... function logic ...

When calling a deprecated function:
>> from some_module import get_local_rank
>> get_local_rank()
DeprecationWarning: Function `some_module.get_local_rank` is deprecated. Deprecated since version `3.2.0`
and will be removed in version `4.0.0`. Reason: `Replaced for optimization`.
Please update your code:
[-] from `some_module` import `get_local_rank`
[+] from `new.module.path` import `new_get_local_rank`.
"""

def decorator(old_func: callable) -> callable:
@wraps(old_func)
def wrapper(*args, **kwargs):
if not wrapper._warned:
message = (
f"Function `{old_func.__module__}.{old_func.__name__}` is deprecated since version `{deprecated_in_v}` "
f"and will be removed in version `{removed_in_v}`.\n"
)
if reason:
message += f"Reason: {reason}.\n"

if target is not None:
message += (
f"Please update your code:\n"
f" [-] from `{old_func.__module__}` import `{old_func.__name__}`\n"
f" [+] from `{target.__module__}` import `{target.__name__}`"
)

warnings.warn(message, DeprecationWarning, stacklevel=2)
wrapper._warned = True

return old_func(*args, **kwargs)

# Each decorated object will have its own _warned state
# This state ensures that the warning will appear only once, to avoid polluting the console in case the function is called too often.
wrapper._warned = False
return wrapper

return decorator


def make_deprecated(func, reason):
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
def inner(*args, **kwargs):
with warnings.catch_warnings():
warnings.simplefilter("once", DeprecationWarning)
warnings.warn(reason, category=DeprecationWarning, stacklevel=2)
warnings.warn(reason, DeprecationWarning)
return func(*args, **kwargs)

return inner
13 changes: 1 addition & 12 deletions src/super_gradients/training/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import warnings
from super_gradients.common.deprecate import make_deprecated

from .sg_module import SgModule
from .classification_models.base_classifer import BaseClassifier
Expand Down Expand Up @@ -135,17 +135,6 @@
from super_gradients.training.utils import make_divisible as _make_divisible_current_version, HpmStruct as CurrVersionHpmStruct


def make_deprecated(func, reason):
def inner(*args, **kwargs):
with warnings.catch_warnings():
warnings.simplefilter("once", DeprecationWarning)
warnings.warn(reason, category=DeprecationWarning, stacklevel=2)
warnings.warn(reason, DeprecationWarning)
return func(*args, **kwargs)

return inner


make_divisible = make_deprecated(
func=_make_divisible_current_version,
reason="You're importing `make_divisible` from `super_gradients.training.models`. This is deprecated since SuperGradients 3.1.0.\n"
Expand Down
79 changes: 79 additions & 0 deletions tests/unit_tests/test_deprecate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import unittest
import warnings
from super_gradients.common.deprecate import deprecate_call


class TestDeprecationDecorator(unittest.TestCase):
def setUp(self):
"""Prepare required functions before each test."""
self.new_function_message = "This is the new function!"

def new_func():
return self.new_function_message

@deprecate_call(deprecated_in_v="3.2.0", removed_in_v="4.0.0", target=new_func, reason="Replaced for optimization")
def fully_configured_deprecated_func():
return new_func()

@deprecate_call(deprecated_in_v="3.2.0", removed_in_v="4.0.0")
def basic_deprecated_func():
return new_func()

self.new_func = new_func
self.fully_configured_deprecated_func = fully_configured_deprecated_func
self.basic_deprecated_func = basic_deprecated_func

def test_emits_warning(self):
"""Ensure that the deprecated function emits a warning when called."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.fully_configured_deprecated_func()
self.assertEqual(len(w), 1)

def test_displays_deprecated_version(self):
"""Ensure that the warning contains the version in which the function was deprecated."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.fully_configured_deprecated_func()
self.assertTrue(any("3.2.0" in str(warning.message) for warning in w))

def test_displays_removed_version(self):
"""Ensure that the warning contains the version in which the function will be removed."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.fully_configured_deprecated_func()
self.assertTrue(any("4.0.0" in str(warning.message) for warning in w))

def test_guidance_on_replacement(self):
"""Ensure that if a replacement target is provided, guidance on using the new function is included in the warning."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.fully_configured_deprecated_func()
self.assertTrue(any("new_func" in str(warning.message) for warning in w))

def test_displays_reason(self):
"""Ensure that if provided, the reason for deprecation is included in the warning."""
reason_str = "Replaced for optimization"
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.fully_configured_deprecated_func()
self.assertTrue(any(reason_str in str(warning.message) for warning in w))

def test_triggered_only_once(self):
"""Ensure that the deprecation warning is triggered only once even if the deprecated function is called multiple times."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
for _ in range(10):
self.fully_configured_deprecated_func()
self.assertEqual(len(w), 1, "Only one warning should be emitted")

def test_basic_deprecation_emits_warning(self):
"""Ensure that a function with minimal deprecation configuration emits a warning."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.basic_deprecated_func()
self.assertEqual(len(w), 1)


if __name__ == "__main__":
unittest.main()