Skip to content

Commit

Permalink
YoloNAS_Pose_Fine_Tuning_Animals_Pose_Dataset (#1876)
Browse files Browse the repository at this point in the history
returned FunctionLRScheduler (looks like it was dropped by mistake)
  • Loading branch information
ofrimasad committed Feb 29, 2024
1 parent 6c8fa3a commit 12cd550
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions src/super_gradients/training/utils/callbacks/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,36 @@ def compute_learning_rate(cls, step: Union[float, np.ndarray], total_steps: floa
return lr * (1 - final_lr_ratio) + (initial_lr * final_lr_ratio)


@register_lr_scheduler(LRSchedulers.FUNCTION, deprecated_name="function")
class FunctionLRScheduler(LRCallbackBase):
"""
Hard coded rate scheduling for user defined lr scheduling function.
"""

def __init__(self, max_epochs, lr_schedule_function, **kwargs):
super().__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
assert callable(lr_schedule_function), "self.lr_function must be callable"
self.lr_schedule_function = lr_schedule_function
self.max_epochs = max_epochs

def is_lr_scheduling_enabled(self, context):
post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs

def perform_scheduling(self, context):
effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
effective_max_epochs = self.max_epochs - self.training_params.lr_warmup_epochs - self.training_params.lr_cooldown_epochs
for group_name in self.lr.keys():
self.lr[group_name] = self.lr_schedule_function(
initial_lr=self.initial_lr[group_name],
epoch=effective_epoch,
iter=context.batch_idx,
max_epoch=effective_max_epochs,
iters_per_epoch=self.train_loader_len,
)
self.update_lr(context.optimizer, context.epoch, context.batch_idx)


class IllegalLRSchedulerMetric(Exception):
"""Exception raised illegal combination of training parameters.
Expand Down

0 comments on commit 12cd550

Please sign in to comment.