From 07c64ccea63e9fd5f716ebb6031bad4c852cd95c Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Thu, 12 Oct 2023 09:53:50 +0300
Subject: [PATCH] fixed cyclic lr state dict (#1469) (#1518)

* fixed cyclic lr state dict

* fixed circular import

* local import of torch_version_is_greater_or_equal

* Fix missing function after merge

---------

Co-authored-by: Eugene Khvedchenya
(cherry picked from commit b56fad8ebc5a1842969dcebc237dbd6b426ba6d4)

Co-authored-by: Shay Aharon <80472096+shaydeci@users.noreply.github.com>
---
 .../training/sg_trainer/sg_trainer.py   |  3 ++-
 .../training/utils/checkpoint_utils.py  | 18 +++++++++++++++++-
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 70ec6172b7..9d1648fafa 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -77,6 +77,7 @@
     read_ckpt_state_dict,
     load_checkpoint_to_model,
     load_pretrained_weights,
+    get_scheduler_state,
 )
 from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTensorboardLogger
 from super_gradients.training.utils.callbacks import (
@@ -617,7 +618,7 @@ def _save_checkpoint(
             state["processing_params"] = processing_params
 
         if self._torch_lr_scheduler is not None:
-            state["torch_scheduler_state_dict"] = self._torch_lr_scheduler.state_dict()
+            state["torch_scheduler_state_dict"] = get_scheduler_state(self._torch_lr_scheduler)
 
         # SAVES CURRENT MODEL AS ckpt_latest
         self.sg_logger.add_checkpoint(tag="ckpt_latest.pth", state_dict=state, global_step=epoch)
diff --git a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py
index 35432569b8..6bae61d720 100644
--- a/src/super_gradients/training/utils/checkpoint_utils.py
+++ b/src/super_gradients/training/utils/checkpoint_utils.py
@@ -1,7 +1,7 @@
 import collections
 import os
 import tempfile
-from typing import Union, Mapping
+from typing import Union, Mapping, Dict
 
 import pkg_resources
 import torch
@@ -1628,3 +1628,19 @@ def _maybe_load_preprocessing_params(model: Union[nn.Module, HasPredict], checkp
             "predict make sure to call set_dataset_processing_params."
         )
     return False
+
+
+def get_scheduler_state(scheduler) -> Dict[str, Tensor]:
+    """
+    Wrapper for getting a torch lr scheduler state dict, resolving some issues with CyclicLR
+    (see https://github.com/pytorch/pytorch/pull/91400)
+    :param scheduler: torch.optim.lr_scheduler._LRScheduler, the scheduler
+    :return: the scheduler's state_dict
+    """
+    from super_gradients.training.utils import torch_version_is_greater_or_equal
+    from torch.optim.lr_scheduler import CyclicLR
+
+    state = scheduler.state_dict()
+    if isinstance(scheduler, CyclicLR) and not torch_version_is_greater_or_equal(2, 0):
+        del state["_scale_fn_ref"]
+    return state
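
A minimal usage sketch (not part of the patch) of what the new helper is for: on torch < 2.0, CyclicLR.state_dict() can carry a "_scale_fn_ref" weak reference that torch.save cannot pickle, so get_scheduler_state drops it before the checkpoint is serialized. The model, optimizer, hyperparameters, and file name below are illustrative only; get_scheduler_state is imported from the module this patch defines it in.

import torch
from torch.optim.lr_scheduler import CyclicLR

from super_gradients.training.utils.checkpoint_utils import get_scheduler_state

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = CyclicLR(optimizer, base_lr=1e-4, max_lr=1e-2)

# On torch < 2.0, scheduler.state_dict() may include "_scale_fn_ref" (a weak
# reference to the internal scale function), which cannot be pickled.
# get_scheduler_state returns the same dict with that entry removed.
state = {
    "net": model.state_dict(),
    "torch_scheduler_state_dict": get_scheduler_state(scheduler),
}
torch.save(state, "ckpt_latest.pth")  # no pickling error for CyclicLR

# Restoring: a freshly constructed CyclicLR re-creates its scale function,
# so load_state_dict works fine without the dropped entry.
restored = CyclicLR(optimizer, base_lr=1e-4, max_lr=1e-2)
restored.load_state_dict(torch.load("ckpt_latest.pth")["torch_scheduler_state_dict"])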