Skip to content

Commit

Permalink
added docs
Browse files Browse the repository at this point in the history
  • Loading branch information
shaydeci committed Oct 11, 2023
1 parent 9c24d6c commit f85d1fb
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
63 changes: 63 additions & 0 deletions src/super_gradients/common/deprecate.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,37 @@ def wrapper(*args, **kwargs):


def deprecated_training_param(deprecated_tparam_name: str, deprecated_since: str, removed_from: str, new_arg_assigner: Callable, message: str = ""):
"""
Decorator for deprecating training hyperparameters.
Recommended to be used as a decorator on top of super_gradients.training.params.TrainingParams's override method:
class TrainingParams(HpmStruct):
def __init__(self, **entries):
# We initialize with the default training params, overridden by the provided params
default_training_params = deepcopy(DEFAULT_TRAINING_PARAMS)
super().__init__(**default_training_params)
self.set_schema(TRAINING_PARAM_SCHEMA)
if len(entries) > 0:
self.override(**entries)
@deprecated_training_param(
"criterion_params", "3.2.1", "3.3.0", new_arg_assigner=get_deprecated_nested_params_to_factory_format_assigner("loss", "criterion_params")
)
def override(self, **entries):
super().override(**entries)
self.validate()
:param deprecated_tparam_name: str, the name of the deprecated hyperparameter.
:param deprecated_since: str, SG version of deprecation.
:param removed_from: str, SG version of removal.
:param new_arg_assigner: Callable, a handler to assign the deprecated parameter value to the updated
hyperparameter entry.
:param message: str, message to append to the deprecation warning (default="")
:return:
"""

def decorator(func):
def wrapper(*args, **training_params):
if deprecated_tparam_name in training_params:
Expand Down Expand Up @@ -107,6 +138,38 @@ def wrapper(*args, **training_params):


def get_deprecated_nested_params_to_factory_format_assigner(param_name: str, nested_params_name: str) -> Callable:
"""
Returns an assigner to be used by deprecated_training_param decorator.
The assigner takes a deprecated parameter name and its __init__ arguments that were previously passed
through the nested_params_name entry in training_params, and manipulates the training_params so they are in 'Factory' format.
For example:
class TrainingParams(HpmStruct):
def __init__(self, **entries):
# We initialize with the default training params, overridden by the provided params
default_training_params = deepcopy(DEFAULT_TRAINING_PARAMS)
super().__init__(**default_training_params)
self.set_schema(TRAINING_PARAM_SCHEMA)
if len(entries) > 0:
self.override(**entries)
@deprecated_training_param(
"criterion_params", "3.2.1", "3.3.0", new_arg_assigner=get_deprecated_nested_params_to_factory_format_assigner("loss", "criterion_params")
)
def override(self, **entries):
super().override(**entries)
self.validate()
then under the hood, training_params.loss will be set to
{training_params.loss: training_params.criterion_params}
:param param_name: str, parameter name (for example, 'loss').
:param nested_params_name: str, nested_params_name (for example, 'criterion_params')
:return: Callable as described above.
"""

def deprecated_nested_params_to_factory_format_assigner(**params):
nested_params = params.get(nested_params_name)
param_val = params.get(param_name)
Expand Down
4 changes: 1 addition & 3 deletions tests/unit_tests/test_deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from omegaconf import DictConfig
from torch import nn

from super_gradients import setup_device, Trainer
from super_gradients import Trainer
from super_gradients.common.registry import register_model
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
Expand Down Expand Up @@ -126,7 +126,6 @@ def test_deprecated_criterion_params(self):
train_params.override(criterion_params={"ignore_index": 0})

def test_train_with_deprecated_criterion_params(self):
setup_device(device="cpu")
trainer = Trainer("test_train_with_precise_bn_explicit_size")
net = ResNet18(num_classes=5, arch_params={})
train_params = {
Expand All @@ -145,7 +144,6 @@ def test_train_with_deprecated_criterion_params(self):
"metric_to_watch": "Accuracy",
"greater_metric_to_watch_is_better": True,
"precise_bn": True,
"precise_bn_batch_size": 100,
}
trainer.train(
model=net,
Expand Down

0 comments on commit f85d1fb

Please sign in to comment.