Skip to content

Commit

Permalink
add version and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis-Dupont committed Aug 27, 2023
1 parent 4723a64 commit a210e0a
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 9 deletions.
20 changes: 15 additions & 5 deletions src/super_gradients/common/deprecate.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
import warnings
from functools import wraps
from typing import Optional
from pkg_resources import parse_version

import super_gradients

def deprecate_call(deprecated_in_v: str, removed_in_v: str, target: Optional[callable] = None, reason: str = ""):

def deprecate_call(deprecated_in_v: str, remove_in_v: str, target: Optional[callable] = None, reason: str = ""):
"""
Decorator to mark a callable as deprecated. It provides a clear and actionable warning message informing
the user about the version in which the function was deprecated, the version in which it will be removed,
and guidance on how to replace it.
:param deprecated_in_v: Version number when the function was deprecated.
:param removed_in_v: Version number when the function will be removed.
:param remove_in_v: Version number when the function will be removed.
:param target: (Optional) The new function that should be used as a replacement. If provided, it will guide the user to the updated function.
:param reason: (Optional) Additional information or reason for the deprecation.
Example usage:
If a direct replacement function exists:
>> from new.module.path import new_get_local_rank
>> @deprecate_call(deprecated_in_v='3.2.0', removed_in_v='4.0.0', target=new_get_local_rank, reason="Replaced for optimization")
>> @deprecate_call(deprecated_in_v='3.2.0', remove_in_v='4.0.0', target=new_get_local_rank, reason="Replaced for optimization")
>> def get_local_rank():
>> return new_get_local_rank()
If there's no direct replacement:
>> @deprecate_call(deprecated_in_v='3.2.0', removed_in_v='4.0.0', reason="Function is no longer needed due to XYZ reason")
>> @deprecate_call(deprecated_in_v='3.2.0', remove_in_v='4.0.0', reason="Function is no longer needed due to XYZ reason")
>> def some_old_function():
>> # ... function logic ...
Expand All @@ -38,12 +41,19 @@ def deprecate_call(deprecated_in_v: str, removed_in_v: str, target: Optional[cal
"""

def decorator(old_func: callable) -> callable:

if parse_version(super_gradients.__version__) >= parse_version(remove_in_v):
raise ValueError(
f"`super_gradients.__version__={super_gradients.__version__}` >= `remove_in_v={remove_in_v}`. "
f"Please remove {old_func.__module__}.{old_func.__name__} from your code base."
)

@wraps(old_func)
def wrapper(*args, **kwargs):
if not wrapper._warned:
message = (
f"Function `{old_func.__module__}.{old_func.__name__}` is deprecated since version `{deprecated_in_v}` "
f"and will be removed in version `{removed_in_v}`.\n"
f"and will be removed in version `{remove_in_v}`.\n"
)
if reason:
message += f"Reason: {reason}.\n"
Expand Down
60 changes: 56 additions & 4 deletions tests/unit_tests/test_deprecate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest
import warnings
import unittest
from unittest.mock import patch

from super_gradients.common.deprecate import deprecate_call


Expand All @@ -11,18 +13,30 @@ def setUp(self):
def new_func():
return self.new_function_message

@deprecate_call(deprecated_in_v="3.2.0", removed_in_v="4.0.0", target=new_func, reason="Replaced for optimization")
@deprecate_call(deprecated_in_v="3.2.0", remove_in_v="10.0.0", target=new_func, reason="Replaced for optimization")
def fully_configured_deprecated_func():
return new_func()

@deprecate_call(deprecated_in_v="3.2.0", removed_in_v="4.0.0")
@deprecate_call(deprecated_in_v="3.2.0", remove_in_v="10.0.0")
def basic_deprecated_func():
return new_func()

self.new_func = new_func
self.fully_configured_deprecated_func = fully_configured_deprecated_func
self.basic_deprecated_func = basic_deprecated_func

class NewClass:
def __init__(self):
pass

@deprecate_call(deprecated_in_v="3.2.0", remove_in_v="10.0.0", target=NewClass, reason="Replaced for optimization")
class DeprecatedClass:
def __init__(self):
pass

self.NewClass = NewClass
self.DeprecatedClass = DeprecatedClass

def test_emits_warning(self):
"""Ensure that the deprecated function emits a warning when called."""
with warnings.catch_warnings(record=True) as w:
Expand All @@ -42,7 +56,7 @@ def test_displays_removed_version(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.fully_configured_deprecated_func()
self.assertTrue(any("4.0.0" in str(warning.message) for warning in w))
self.assertTrue(any("10.0.0" in str(warning.message) for warning in w))

def test_guidance_on_replacement(self):
"""Ensure that if a replacement target is provided, guidance on using the new function is included in the warning."""
Expand Down Expand Up @@ -74,6 +88,44 @@ def test_basic_deprecation_emits_warning(self):
self.basic_deprecated_func()
self.assertEqual(len(w), 1)

def test_class_deprecation_warning(self):
"""Ensure that creating an instance of a deprecated class emits a warning."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
_ = self.DeprecatedClass() # Instantiate the deprecated class
self.assertEqual(len(w), 1)

def test_class_deprecation_message_content(self):
"""Ensure that the emitted warning for a deprecated class contains relevant information including target class."""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
_ = self.DeprecatedClass()
self.assertTrue(any("3.2.0" in str(warning.message) for warning in w))
self.assertTrue(any("10.0.0" in str(warning.message) for warning in w))
self.assertTrue(any("DeprecatedClass" in str(warning.message) for warning in w))
self.assertTrue(any("Replaced for optimization" in str(warning.message) for warning in w))
self.assertTrue(any("NewClass" in str(warning.message) for warning in w))

def test_raise_error_when_library_version_equals_removal_version(self):
"""Ensure that an error is raised when the library's version equals the function's removal version."""
with patch("super_gradients.__version__", "10.1.0"): # Mocking the version to be equal to removal version
with self.assertRaises(ValueError):

@deprecate_call(deprecated_in_v="3.2.0", remove_in_v="10.1.0", target=self.new_func)
def deprecated_func_version_equal():
return

def test_no_error_when_library_version_below_removal_version(self):
"""Ensure that no error is raised when the library's version is below the function's removal version."""
with patch("super_gradients.__version__", "10.1.0"): # Mocking the version to be below removal version

@deprecate_call(deprecated_in_v="3.2.0", remove_in_v="10.2.0", target=self.new_func)
def deprecated_func_version_below():
return

# Actually call the function to check no exception is raised
deprecated_func_version_below()


if __name__ == "__main__":
unittest.main()

0 comments on commit a210e0a

Please sign in to comment.