-
Notifications
You must be signed in to change notification settings - Fork 488
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add deprecate module * add * add tests * add tests * add version and tests * rename * add to tests stak * fix docstring typo * change error * fix import loop with version and some other minor fixes * move test to runtime * EnvironmentError to ImportError * update names --------- Co-authored-by: Ofri Masad <ofrimasad@users.noreply.github.com>
- Loading branch information
1 parent
e623981
commit 32fc041
Showing
6 changed files
with
233 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import warnings | ||
from functools import wraps | ||
from typing import Optional | ||
from pkg_resources import parse_version | ||
|
||
|
||
def deprecated(deprecated_since: str, removed_from: str, target: Optional[callable] = None, reason: str = ""): | ||
""" | ||
Decorator to mark a callable as deprecated. Works on functions and classes. | ||
It provides a clear and actionable warning message informing | ||
the user about the version in which the function was deprecated, the version in which it will be removed, | ||
and guidance on how to replace it. | ||
:param deprecated_since: Version number when the function was deprecated. | ||
:param removed_from: Version number when the function will be removed. | ||
:param target: (Optional) The new function that should be used as a replacement. If provided, it will guide the user to the updated function. | ||
:param reason: (Optional) Additional information or reason for the deprecation. | ||
Example usage: | ||
If a direct replacement function exists: | ||
>> from new.module.path import new_get_local_rank | ||
>> @deprecated(deprecated_since='3.2.0', removed_from='4.0.0', target=new_get_local_rank, reason="Replaced for optimization") | ||
>> def get_local_rank(): | ||
>> return new_get_local_rank() | ||
If there's no direct replacement: | ||
>> @deprecated(deprecated_since='3.2.0', removed_from='4.0.0', reason="Function is no longer needed due to XYZ reason") | ||
>> def some_old_function(): | ||
>> # ... function logic ... | ||
When calling a deprecated function: | ||
>> from some_module import get_local_rank | ||
>> get_local_rank() | ||
DeprecationWarning: Function `some_module.get_local_rank` is deprecated. Deprecated since version `3.2.0` | ||
and will be removed in version `4.0.0`. Reason: `Replaced for optimization`. | ||
Please update your code: | ||
[-] from `some_module` import `get_local_rank` | ||
[+] from `new.module.path` import `new_get_local_rank`. | ||
""" | ||
|
||
def decorator(old_func: callable) -> callable: | ||
@wraps(old_func) | ||
def wrapper(*args, **kwargs): | ||
if not wrapper._warned: | ||
import super_gradients | ||
|
||
is_still_supported = parse_version(super_gradients.__version__) < parse_version(removed_from) | ||
status_msg = "is deprecated" if is_still_supported else "was deprecated and has been removed" | ||
message = ( | ||
f"Callable `{old_func.__module__}.{old_func.__name__}` {status_msg} since version `{deprecated_since}` " | ||
f"and will be removed in version `{removed_from}`.\n" | ||
) | ||
if reason: | ||
message += f"Reason: {reason}.\n" | ||
|
||
if target is not None: | ||
message += ( | ||
f"Please update your code:\n" | ||
f" [-] from `{old_func.__module__}` import `{old_func.__name__}`\n" | ||
f" [+] from `{target.__module__}` import `{target.__name__}`" | ||
) | ||
|
||
if is_still_supported: | ||
warnings.simplefilter("once", DeprecationWarning) # Required, otherwise the warning may never be displayed. | ||
warnings.warn(message, DeprecationWarning, stacklevel=2) | ||
wrapper._warned = True | ||
else: | ||
raise ImportError(message) | ||
|
||
return old_func(*args, **kwargs) | ||
|
||
# Each decorated object will have its own _warned state | ||
# This state ensures that the warning will appear only once, to avoid polluting the console in case the function is called too often. | ||
wrapper._warned = False | ||
return wrapper | ||
|
||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import warnings | ||
import unittest | ||
from unittest.mock import patch | ||
|
||
from super_gradients.common.deprecate import deprecated | ||
|
||
|
||
class TestDeprecationDecorator(unittest.TestCase): | ||
def setUp(self): | ||
"""Prepare required functions before each test.""" | ||
self.new_function_message = "This is the new function!" | ||
|
||
def new_func(): | ||
return self.new_function_message | ||
|
||
@deprecated(deprecated_since="3.2.0", removed_from="10.0.0", target=new_func, reason="Replaced for optimization") | ||
def fully_configured_deprecated_func(): | ||
return new_func() | ||
|
||
@deprecated(deprecated_since="3.2.0", removed_from="10.0.0") | ||
def basic_deprecated_func(): | ||
return new_func() | ||
|
||
self.new_func = new_func | ||
self.fully_configured_deprecated_func = fully_configured_deprecated_func | ||
self.basic_deprecated_func = basic_deprecated_func | ||
|
||
class NewClass: | ||
def __init__(self): | ||
pass | ||
|
||
@deprecated(deprecated_since="3.2.0", removed_from="10.0.0", target=NewClass, reason="Replaced for optimization") | ||
class DeprecatedClass: | ||
def __init__(self): | ||
pass | ||
|
||
self.NewClass = NewClass | ||
self.DeprecatedClass = DeprecatedClass | ||
|
||
def test_emits_warning(self): | ||
"""Ensure that the deprecated function emits a warning when called.""" | ||
with warnings.catch_warnings(record=True) as w: | ||
warnings.simplefilter("always") | ||
self.fully_configured_deprecated_func() | ||
self.assertEqual(len(w), 1) | ||
|
||
def test_displays_deprecated_version(self): | ||
"""Ensure that the warning contains the version in which the function was deprecated.""" | ||
with warnings.catch_warnings(record=True) as w: | ||
warnings.simplefilter("always") | ||
self.fully_configured_deprecated_func() | ||
self.assertTrue(any("3.2.0" in str(warning.message) for warning in w)) | ||
|
||
def test_displays_removed_version(self): | ||
"""Ensure that the warning contains the version in which the function will be removed.""" | ||
with warnings.catch_warnings(record=True) as w: | ||
warnings.simplefilter("always") | ||
self.fully_configured_deprecated_func() | ||
self.assertTrue(any("10.0.0" in str(warning.message) for warning in w)) | ||
|
||
def test_guidance_on_replacement(self): | ||
"""Ensure that if a replacement target is provided, guidance on using the new function is included in the warning.""" | ||
with warnings.catch_warnings(record=True) as w: | ||
warnings.simplefilter("always") | ||
self.fully_configured_deprecated_func() | ||
self.assertTrue(any("new_func" in str(warning.message) for warning in w)) | ||
|
||
def test_displays_reason(self): | ||
"""Ensure that if provided, the reason for deprecation is included in the warning.""" | ||
reason_str = "Replaced for optimization" | ||
with warnings.catch_warnings(record=True) as w: | ||
warnings.simplefilter("always") | ||
self.fully_configured_deprecated_func() | ||
self.assertTrue(any(reason_str in str(warning.message) for warning in w)) | ||
|
||
def test_triggered_only_once(self): | ||
"""Ensure that the deprecation warning is triggered only once even if the deprecated function is called multiple times.""" | ||
with warnings.catch_warnings(record=True) as w: | ||
warnings.simplefilter("always") | ||
for _ in range(10): | ||
self.fully_configured_deprecated_func() | ||
self.assertEqual(len(w), 1, "Only one warning should be emitted") | ||
|
||
def test_basic_deprecation_emits_warning(self): | ||
"""Ensure that a function with minimal deprecation configuration emits a warning.""" | ||
with warnings.catch_warnings(record=True) as w: | ||
warnings.simplefilter("always") | ||
self.basic_deprecated_func() | ||
self.assertEqual(len(w), 1) | ||
|
||
def test_class_deprecation_warning(self): | ||
"""Ensure that creating an instance of a deprecated class emits a warning.""" | ||
with warnings.catch_warnings(record=True) as w: | ||
warnings.simplefilter("always") | ||
_ = self.DeprecatedClass() # Instantiate the deprecated class | ||
self.assertEqual(len(w), 1) | ||
|
||
def test_class_deprecation_message_content(self): | ||
"""Ensure that the emitted warning for a deprecated class contains relevant information including target class.""" | ||
with warnings.catch_warnings(record=True) as w: | ||
warnings.simplefilter("always") | ||
_ = self.DeprecatedClass() | ||
self.assertTrue(any("3.2.0" in str(warning.message) for warning in w)) | ||
self.assertTrue(any("10.0.0" in str(warning.message) for warning in w)) | ||
self.assertTrue(any("DeprecatedClass" in str(warning.message) for warning in w)) | ||
self.assertTrue(any("Replaced for optimization" in str(warning.message) for warning in w)) | ||
self.assertTrue(any("NewClass" in str(warning.message) for warning in w)) | ||
|
||
def test_raise_error_when_library_version_equals_removal_version(self): | ||
"""Ensure that an error is raised when the library's version equals the function's removal version.""" | ||
with patch("super_gradients.__version__", "10.1.0"): # Mocking the version to be equal to removal version | ||
with self.assertRaises(ImportError): | ||
|
||
@deprecated(deprecated_since="3.2.0", removed_from="10.1.0", target=self.new_func) | ||
def deprecated_func_version_equal(): | ||
return | ||
|
||
deprecated_func_version_equal() | ||
|
||
def test_no_error_when_library_version_below_removal_version(self): | ||
"""Ensure that no error is raised when the library's version is below the function's removal version.""" | ||
with patch("super_gradients.__version__", "10.1.0"): # Mocking the version to be below removal version | ||
|
||
@deprecated(deprecated_since="3.2.0", removed_from="10.2.0", target=self.new_func) | ||
def deprecated_func_version_below(): | ||
return | ||
|
||
deprecated_func_version_below() | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |