Skip to content

Commit

Permalink
unwrap when deprecating decorated methods (#1942)
Browse files Browse the repository at this point in the history
* unwrap when deprecating decorated methods

* lint

* remove breaking changes. Will deal with that later

* less verbosity

---------

Co-authored-by: Shay Aharon <80472096+shaydeci@users.noreply.github.com>
  • Loading branch information
NatanBagrov and shaydeci committed Apr 1, 2024
1 parent b3e698c commit 11326d9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/super_gradients/common/deprecate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,14 @@ def wrapper(*args, **kwargs):
message += f"Reason: {reason}.\n"

if target is not None:
if hasattr(target, "__func__"): # unwraps `__func__` for @classmethod and @staticmethod decorators
target_func = target.__func__
else:
target_func = target
message += (
f"Please update your code:\n"
f" [-] from `{old_func.__module__}` import `{old_func.__name__}`\n"
f" [+] from `{target.__module__}` import `{target.__name__}`"
f" [+] from `{target_func.__module__}` import `{target_func.__name__}`"
)

if is_still_supported:
Expand Down
30 changes: 30 additions & 0 deletions tests/unit_tests/test_deprecate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@ def __init__(self):
self.NewClass = NewClass
self.DeprecatedClass = DeprecatedClass

@classmethod
def new_class_func(cls):
return None

@classmethod
@deprecated(deprecated_since="3.2.0", removed_from="10.0.0", target=new_class_func)
def deprecated_class_func(cls):
return cls.new_class_func()

@staticmethod
def new_static_func():
return None

@staticmethod
@deprecated(deprecated_since="3.2.0", removed_from="10.0.0", target=new_static_func)
def deprecated_static_func():
return TestDeprecationDecorator.new_static_func

def test_emits_warning(self):
"""Ensure that the deprecated function emits a warning when called."""
with warnings.catch_warnings(record=True) as w:
Expand Down Expand Up @@ -88,6 +106,18 @@ def test_basic_deprecation_emits_warning(self):
self.basic_deprecated_func()
self.assertEqual(len(w), 1)

def test_class_method_deprecation_emits_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
_ = TestDeprecationDecorator.deprecated_class_func()
self.assertEqual(len(w), 1)

def test_static_method_deprecation_emits_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
_ = TestDeprecationDecorator.deprecated_static_func()
self.assertEqual(len(w), 1)

def test_class_deprecation_warning(self):
"""Ensure that creating an instance of a deprecated class emits a warning."""
with warnings.catch_warnings(record=True) as w:
Expand Down

0 comments on commit 11326d9

Please sign in to comment.