Feature/sg 573 Integrate new EMA decay schedules #647

Merged · 18 commits · Jan 30, 2023 · Showing changes from 15 commits

@@ -16,7 +16,7 @@ ema: True
ema_params:
decay: 0.9999
beta: 15
exp_activation: True
decay_type: exp

train_metrics_list:
- PixelAccuracy:

@@ -30,8 +30,8 @@ criterion_params: {} # when `loss` is one of SuperGradient's built in options, i
ema: False # whether to use Model Exponential Moving Average
ema_params: # parameters for the ema model.
decay: 0.9999
decay_type: exp
beta: 15
exp_activation: True


train_metrics_list: [] # Metrics to log during training. For more information on torchmetrics see https://torchmetrics.rtfd.io/en/latest/.

@@ -17,8 +17,8 @@ optimizer_params:

ema: True
ema_params:
exp_activation: False
decay: 0.9999
decay_type: constant

loss: cross_entropy
criterion_params:
@@ -42,4 +42,3 @@ valid_metrics_list: # metrics for evaluation
- Top5

_convert_: all

@@ -17,7 +17,7 @@ optimizer_params:

ema: True
ema_params:
exp_activation: False
decay_type: constant
decay: 0.9999

loss: cross_entropy
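Taken together, the recipe changes above replace the boolean `exp_activation` flag with an explicit `decay_type` string (`constant`, `exp`, or `threshold`, per the deprecation message later in this diff). A minimal sketch of the same settings expressed as the `ema`/`ema_params` entries of a Python `training_params` dict, mirroring the dicts used in the test files at the end of this PR:

```python
# Sketch only: the EMA keys follow this PR's new format.
training_params = {
    "ema": True,
    # New format: name the decay schedule explicitly.
    "ema_params": {"decay": 0.9999, "decay_type": "exp", "beta": 15},
    # Legacy format, still accepted but now emitting a deprecation warning:
    # "ema_params": {"decay": 0.9999, "exp_activation": True, "beta": 15},
}
```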
64 changes: 54 additions & 10 deletions src/super_gradients/training/kd_trainer/kd_trainer.py
@@ -1,6 +1,7 @@
import hydra
import torch.nn
from omegaconf import DictConfig, OmegaConf
from super_gradients.training.utils.ema_decay_schedules import EMA_DECAY_FUNCTIONS
from torch.utils.data import DataLoader

from super_gradients.training.utils.distributed_training_utils import setup_device
@@ -255,17 +256,60 @@ def _get_hyper_param_config(self):
)
return hyper_param_config

def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> KDModelEMA:
"""Instantiate KD ema model for KDModule.

If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.
:param decay: the maximum decay value. as the training process advances, the decay will climb towards
this value until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
:param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will
saturate to its final value. beta=15 is ~40% of the training process.
:param exp_activation:
def _instantiate_ema_model(self, decay_type: str = None, decay: float = None, **kwargs) -> KDModelEMA:
"""Instantiate ema model for standard SgModule.
:param decay_type: (str) The decay climb schedule. See EMA_DECAY_FUNCTIONS for more details.
:param decay: The maximum decay value. As the training process advances, the decay will climb towards this value
according to decay_type schedule. See EMA_DECAY_FUNCTIONS for more details.
:param kwargs: Additional parameters for the decay function. See EMA_DECAY_FUNCTIONS for more details.
"""
return KDModelEMA(self.net, decay, beta, exp_activation)
if decay is None:
logger.warning(
"Parameter `decay` is not specified for EMA model. Please specify `decay` parameter explicitly in your config:\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: exp\n"
" beta: 15\n"
"In the next major release of SG this warning will become an error."
)

if "exp_activation" in kwargs:
logger.warning(
"Parameter `exp_activation` is deprecated for EMA model. Please update your config to use decay_type: str (constant|exp|threshold) instead:\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: exp\n"
" beta: 15\n"
"\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: constant\n"
"\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: threshold\n"
"In the next major release of SG this warning will become an error."
)
decay_type = "exp" if bool(kwargs.pop("exp_activation")) else "constant"

if decay_type is None:
logger.warning(
"Parameter decay_type is not specified for EMA model. Please specify decay_type parameter explicitly in your config:\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: exp\n"
" beta: 15\n"
"In the next major release of SG this warning will become an error."
)
decay_type = "exp"

decay_function = EMA_DECAY_FUNCTIONS[decay_type](**kwargs)
return KDModelEMA(self.net, decay, decay_function)

def _save_best_checkpoint(self, epoch, state):
"""
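The deprecation branch above keeps old-style configs working by mapping the legacy `exp_activation` flag onto the new `decay_type` names. A standalone sketch of that mapping (the helper name `migrate_ema_params` is ours, for illustration only):

```python
def migrate_ema_params(ema_params: dict) -> dict:
    """Translate legacy ema_params into the new decay_type-based format.

    Mirrors the shim above: exp_activation=True becomes decay_type="exp",
    exp_activation=False becomes decay_type="constant". (Illustrative helper,
    not part of this PR.)
    """
    params = dict(ema_params)
    if "exp_activation" in params:
        params["decay_type"] = "exp" if params.pop("exp_activation") else "constant"
    return params


print(migrate_ema_params({"decay": 0.9999, "exp_activation": True, "beta": 15}))
# {'decay': 0.9999, 'beta': 15, 'decay_type': 'exp'}
print(migrate_ema_params({"decay": 0.9999, "exp_activation": False}))
# {'decay': 0.9999, 'decay_type': 'constant'}
```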
67 changes: 58 additions & 9 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -8,6 +8,7 @@
import torch
import hydra
from omegaconf import DictConfig
from super_gradients.training.utils.ema_decay_schedules import EMA_DECAY_FUNCTIONS
from torch import nn
from torch.utils.data import DataLoader, SequentialSampler
from torch.cuda.amp import GradScaler, autocast
@@ -551,9 +552,11 @@ def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context
torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.training_params.clip_grad_norm)

# ACCUMULATE GRADIENT FOR X BATCHES BEFORE OPTIMIZING
integrated_batches_num = batch_idx + len(self.train_loader) * epoch + 1
local_step = batch_idx + 1
global_step = local_step + len(self.train_loader) * epoch
total_steps = len(self.train_loader) * self.max_epochs

if integrated_batches_num % self.batch_accumulate == 0:
if global_step % self.batch_accumulate == 0:
self.phase_callback_handler.on_train_batch_gradient_step_start(context)

# SCALER IS ENABLED ONLY IF self.training_params.mixed_precision=True
@@ -562,7 +565,7 @@

self.optimizer.zero_grad()
if self.ema:
self.ema_model.update(self.net, integrated_batches_num / (len(self.train_loader) * self.max_epochs))
self.ema_model.update(self.net, step=global_step, total_steps=total_steps)

# RUN PHASE CALLBACKS
self.phase_callback_handler.on_train_batch_gradient_step_end(context)
@@ -1902,14 +1905,60 @@ def _instantiate_net(

return net

def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> ModelEMA:
def _instantiate_ema_model(self, decay_type: str = None, decay: float = None, **kwargs) -> ModelEMA:
"""Instantiate ema model for standard SgModule.
:param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
:param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to
its final value. beta=15 is ~40% of the training process.
:param decay_type: (str) The decay climb schedule. See EMA_DECAY_FUNCTIONS for more details.
:param decay: The maximum decay value. As the training process advances, the decay will climb towards this value
according to decay_type schedule. See EMA_DECAY_FUNCTIONS for more details.
:param kwargs: Additional parameters for the decay function. See EMA_DECAY_FUNCTIONS for more details.
"""
return ModelEMA(self.net, decay, beta, exp_activation)
if decay is None:
logger.warning(
"Parameter `decay` is not specified for EMA model. Please specify `decay` parameter explicitly in your config:\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: exp\n"
" beta: 15\n"
"In the next major release of SG this warning will become an error."
)

if "exp_activation" in kwargs:
logger.warning(
"Parameter `exp_activation` is deprecated for EMA model. Please update your config to use decay_type: str (constant|exp|threshold) instead:\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: exp\n"
" beta: 15\n"
"\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: constant\n"
"\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: threshold\n"
"In the next major release of SG this warning will become an error."
)
decay_type = "exp" if bool(kwargs.pop("exp_activation")) else "constant"

if decay_type is None:
logger.warning(
"Parameter decay_type is not specified for EMA model. Please specify decay_type parameter explicitly in your config:\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: exp\n"
" beta: 15\n"
"In the next major release of SG this warning will become an error."
)
decay_type = "exp"

decay_function = EMA_DECAY_FUNCTIONS[decay_type](**kwargs)
return ModelEMA(self.net, decay, decay_function)

@property
def get_net(self):
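In the `_backward_step` change above, the EMA update now receives explicit `step` and `total_steps` counters instead of a pre-computed training percentage. A quick worked example of the bookkeeping with made-up sizes:

```python
# Illustrative numbers only: 100 batches per epoch, 10 epochs,
# currently at epoch 3 (0-based), batch index 24, batch_accumulate = 5.
len_train_loader, max_epochs = 100, 10
epoch, batch_idx, batch_accumulate = 3, 24, 5

local_step = batch_idx + 1                            # 25: step within the current epoch
global_step = local_step + len_train_loader * epoch   # 325: step across the whole run
total_steps = len_train_loader * max_epochs           # 1000

# The optimizer (and, when enabled, the EMA model) only steps when
# global_step % batch_accumulate == 0, which is true here since 325 % 5 == 0.
# The EMA update then receives step=325 and total_steps=1000, so an "exp"
# schedule evaluates its ramp at training fraction 325 / 1000 = 0.325.
print(local_step, global_step, total_steps, global_step % batch_accumulate == 0)
```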
24 changes: 12 additions & 12 deletions src/super_gradients/training/utils/ema.py
@@ -1,4 +1,3 @@
import math
import warnings
from copy import deepcopy
from typing import Union
@@ -9,6 +8,7 @@
from super_gradients.training import utils as core_utils
from super_gradients.training.models import SgModule
from super_gradients.training.models.kd_modules.kd_module import KDModule
from super_gradients.training.utils.ema_decay_schedules import IDecayFunction


def copy_attr(a: nn.Module, b: nn.Module, include: Union[list, tuple] = (), exclude: Union[list, tuple] = ()):
@@ -30,7 +30,7 @@ class ModelEMA:
GPU assignment and distributed training wrappers.
"""

def __init__(self, model, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True):
def __init__(self, model, decay: float, decay_function: IDecayFunction):
"""
Init the EMA
:param model: Union[SgModule, nn.Module], the training model to construct the EMA model by
@@ -44,10 +44,8 @@ def __init__(self, model, decay: float = 0.9999, beta: float = 15, exp_activatio
# Create EMA
self.ema = deepcopy(model)
self.ema.eval()
if exp_activation:
self.decay_function = lambda x: decay * (1 - math.exp(-x * beta)) # decay exponential ramp (to help early epochs)
else:
self.decay_function = lambda x: decay # always return the same decay factor
self.decay = decay
self.decay_function = decay_function

""""
we hold a list of model attributes (not wights and biases) which we would like to include in each
@@ -65,15 +63,17 @@ def __init__(self, model, decay: float = 0.9999, beta: float = 15, exp_activatio
for p in self.ema.module.parameters():
p.requires_grad_(False)

def update(self, model, training_percent: float):
def update(self, model, step: int, total_steps: int):
"""
Update the state of the EMA model.
:param model: current training model
:param training_percent: the percentage of the training process [0,1]. i.e 0.4 means 40% of the training have passed

:param model: Current training model
:param step: Current training step
:param total_steps: Total training steps
"""
# Update EMA parameters
with torch.no_grad():
decay = self.decay_function(training_percent)
decay = self.decay_function(self.decay, step, total_steps)

for ema_v, model_v in zip(self.ema.module.state_dict().values(), model.state_dict().values()):
if ema_v.dtype.is_floating_point:
@@ -101,7 +101,7 @@ class KDModelEMA(ModelEMA):
GPU assignment and distributed training wrappers.
"""

def __init__(self, kd_model: KDModule, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True):
def __init__(self, kd_model: KDModule, decay: float, decay_function: IDecayFunction):
"""
Init the EMA
:param kd_model: KDModule, the training Knowledge distillation model to construct the EMA model by
@@ -113,7 +113,7 @@ def __init__(self, kd_model: KDModule, decay: float = 0.9999, beta: float = 15,
its final value. beta=15 is ~40% of the training process.
"""
# Only work on the student (we don't want to update and to have a duplicate of the teacher)
super().__init__(model=core_utils.WrappedModel(kd_model.module.student), decay=decay, beta=beta, exp_activation=exp_activation)
super().__init__(model=core_utils.WrappedModel(kd_model.module.student), decay=decay, decay_function=decay_function)

# Overwrite current ema attribute with combination of the student model EMA (current self.ema)
# with already the instantiated teacher, to have the final KD EMA
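After this refactor, `ModelEMA` no longer builds its own decay lambda: the caller passes the decay ceiling plus an `IDecayFunction` instance (defined in the new `ema_decay_schedules` module shown next). A minimal wiring sketch under a few assumptions: the import paths follow the file locations in this diff, and a wrapped `nn.Linear` stands in for the real training network (the class reads `model.module` internally, which is why `KDModelEMA` above wraps the student in `WrappedModel`, and the toy example does the same):

```python
import torch.nn as nn

from super_gradients.training import utils as core_utils
from super_gradients.training.utils.ema import ModelEMA
from super_gradients.training.utils.ema_decay_schedules import ExpDecay

# Toy stand-in for the training network; a real run would use the trainer's
# (already wrapped) self.net rather than a hand-wrapped module.
net = core_utils.WrappedModel(nn.Linear(10, 2))

ema = ModelEMA(net, decay=0.9999, decay_function=ExpDecay(beta=15))

total_steps = 1000
for step in range(1, total_steps + 1):
    # ... forward pass, loss.backward() and optimizer.step() would go here ...
    ema.update(net, step=step, total_steps=total_steps)
```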
38 changes: 38 additions & 0 deletions src/super_gradients/training/utils/ema_decay_schedules.py
@@ -0,0 +1,38 @@
import math
from abc import abstractmethod

__all__ = ["IDecayFunction", "ConstantDecay", "ThresholdDecay", "ExpDecay", "EMA_DECAY_FUNCTIONS"]


class IDecayFunction:
    @abstractmethod
    def __call__(self, decay: float, step: int, total_steps: int):
        pass


class ConstantDecay(IDecayFunction):
    def __init__(self, **kwargs):
        pass

    def __call__(self, decay: float, step: int, total_steps: int):
        return decay


class ThresholdDecay(IDecayFunction):
    def __init__(self, **kwargs):
        pass

    def __call__(self, decay: float, step, total_steps: int):
        return min(decay, (1 + step) / (10 + step))


class ExpDecay(IDecayFunction):
    def __init__(self, beta: float, **kwargs):
        self.beta = beta

    def __call__(self, decay: float, step, total_steps: int):
        x = step / total_steps
        return decay * (1 - math.exp(-x * self.beta))


EMA_DECAY_FUNCTIONS = {"constant": ConstantDecay, "threshold": ThresholdDecay, "exp": ExpDecay}
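For reference, a small sketch of how the three schedules behave over a hypothetical 1000-step run, using only the classes and registry defined in this new file (the import path matches the file's location in this diff):

```python
from super_gradients.training.utils.ema_decay_schedules import EMA_DECAY_FUNCTIONS

decay, total_steps = 0.9999, 1000

# ExpDecay needs its `beta` exponent; the other schedules take no parameters.
schedules = {
    "exp": EMA_DECAY_FUNCTIONS["exp"](beta=15),
    "threshold": EMA_DECAY_FUNCTIONS["threshold"](),
    "constant": EMA_DECAY_FUNCTIONS["constant"](),
}

for step in (0, 100, 400, 1000):
    values = {name: round(fn(decay, step, total_steps), 4) for name, fn in schedules.items()}
    print(f"step={step:4d}", values)
```

With `beta=15`, the exponential schedule starts at 0, reaches roughly 0.78 after 10% of training, and is within about 0.25% of the 0.9999 ceiling by 40% of training (step 400 here), consistent with the `beta=15 is ~40% of the training process` note in the EMA docstrings; the threshold schedule follows `(1 + step) / (10 + step)` capped at `decay`, and the constant schedule always returns `decay`.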
16 changes: 13 additions & 3 deletions tests/integration_tests/ema_train_integration_test.py
@@ -29,11 +29,21 @@ def _init_model(self) -> None:
def tearDownClass(cls) -> None:
pass

def test_train(self):
def test_train_exp_decay(self):
self._init_model()
self._train({})
self._train({"decay_type": "exp", "beta": 15, "decay": 0.9999})

def test_train_threshold_decay(self):
self._init_model()
self._train({"decay_type": "threshold", "decay": 0.9999})

def test_train_constant_decay(self):
self._init_model()
self._train({"decay_type": "constant", "decay": 0.9999})

def test_train_with_old_ema_params(self):
self._init_model()
self._train({"exp_activation": False})
self._train({"decay": 0.9999, "exp_activation": True, "beta": 10})

def _train(self, ema_params):
training_params = {
1 change: 1 addition & 0 deletions tests/unit_tests/kd_ema_test.py
@@ -34,6 +34,7 @@ def setUp(cls):
"greater_metric_to_watch_is_better": True,
"average_best_models": False,
"ema": True,
"ema_params": {"decay_type": "constant", "decay": 0.999},
}

def test_teacher_ema_not_duplicated(self):
2 changes: 2 additions & 0 deletions tests/unit_tests/kd_trainer_test.py
@@ -133,6 +133,8 @@ def test_load_ckpt_best_for_student_with_ema(self):
train_params = self.kd_train_params.copy()
train_params["max_epochs"] = 1
train_params["ema"] = True
train_params["ema_params"] = {"decay_type": "constant", "decay": 0.999}

kd_trainer.train(
training_params=train_params,
student=student,