From 6d99b3f1a20be695a274a585792fa1f8664a3f80 Mon Sep 17 00:00:00 2001
From: shayaharon
Date: Mon, 18 Sep 2023 15:04:44 +0300
Subject: [PATCH 1/4] fixed cyclic lr state dict

---
 .../training/sg_trainer/sg_trainer.py    |  3 ++-
 .../training/utils/checkpoint_utils.py   | 19 ++++++++++++++++---
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 8e7943b15d..44f04891bb 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -83,6 +83,7 @@
     read_ckpt_state_dict,
     load_checkpoint_to_model,
     load_pretrained_weights,
+    get_scheduler_state,
 )
 from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTensorboardLogger
 from super_gradients.training.utils.callbacks import (
@@ -660,7 +661,7 @@ def _save_checkpoint(
             state["processing_params"] = processing_params
 
         if self._torch_lr_scheduler is not None:
-            state["torch_scheduler_state_dict"] = self._torch_lr_scheduler.state_dict()
+            state["torch_scheduler_state_dict"] = get_scheduler_state(self._torch_lr_scheduler)
 
         # SAVES CURRENT MODEL AS ckpt_latest
         self.sg_logger.add_checkpoint(tag="ckpt_latest.pth", state_dict=state, global_step=epoch)
diff --git a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py
index a16de11d4e..4fbb602377 100644
--- a/src/super_gradients/training/utils/checkpoint_utils.py
+++ b/src/super_gradients/training/utils/checkpoint_utils.py
@@ -1,11 +1,12 @@
 import collections
 import os
 import tempfile
-from typing import Union, Mapping
+from typing import Union, Mapping, Dict
 
 import pkg_resources
 import torch
 from torch import nn, Tensor
+from torch.optim.lr_scheduler import CyclicLR
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
@@ -13,6 +14,7 @@
 from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
 from super_gradients.module_interfaces import HasPredict
 from super_gradients.training.pretrained_models import MODEL_URLS
+from super_gradients.training.utils import torch_version_is_greater_or_equal
 from super_gradients.training.utils.distributed_training_utils import wait_for_the_master
 from super_gradients.common.environment.ddp_utils import get_local_rank
 from super_gradients.training.utils.utils import unwrap_model
@@ -22,7 +24,6 @@
 except (ModuleNotFoundError, ImportError, NameError):
     from torch.hub import _download_url_to_file as download_url_to_file
 
-
 logger = get_logger(__name__)
 
 
@@ -1585,7 +1586,6 @@ def _load_weights(architecture, model, pretrained_state_dict):
 
 
 def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
-
     """
     Loads pretrained weights from the MODEL_URLS dictionary to model
    :param architecture: name of the model's architecture
@@ -1598,3 +1598,16 @@ def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pre
     pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
 
     _load_weights(architecture, model, pretrained_state_dict)
+
+
+def get_scheduler_state(scheduler) -> Dict:
+    """
+    Wrapper for getting a torch lr scheduler state dict, resolving some issues with CyclicLR
+    (see https://github.com/pytorch/pytorch/pull/91400)
+    :param scheduler: torch.optim.lr_scheduler._LRScheduler, the scheduler
+    :return: the scheduler's state_dict
+    """
+    state = scheduler.state_dict()
+    if isinstance(scheduler, CyclicLR) and not torch_version_is_greater_or_equal(2, 0):
+        del state["_scale_fn_ref"]
+    return state
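Why the wrapper in PATCH 1 is needed: on torch versions older than 2.0, CyclicLR keeps a weak reference to its scale function under the "_scale_fn_ref" key of its state_dict, and torch.save cannot pickle weak references, so embedding the raw scheduler state in ckpt_latest.pth fails; the linked upstream PR (pytorch/pytorch#91400) addressed this for torch 2.0. Below is a minimal sketch of the failure mode and the workaround outside the trainer; the model, optimizer, and file names are illustrative and not part of the patch.

import torch
from torch.optim.lr_scheduler import CyclicLR

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = CyclicLR(optimizer, base_lr=1e-4, max_lr=1e-2)

state = scheduler.state_dict()
# On torch < 2.0 the state dict carries "_scale_fn_ref" (a weak reference to the scale function),
# so saving it as part of a checkpoint blows up inside torch.save:
#     torch.save({"torch_scheduler_state_dict": state}, "ckpt_latest.pth")  # TypeError: cannot pickle weakref
# Dropping the key, as get_scheduler_state() does for CyclicLR on torch < 2.0,
# makes the state picklable again.
state.pop("_scale_fn_ref", None)
torch.save({"torch_scheduler_state_dict": state}, "ckpt_latest.pth")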
From 5d42b9d3382e728a853c1db67531d4a2bb43dfaa Mon Sep 17 00:00:00 2001
From: shayaharon
Date: Mon, 18 Sep 2023 16:24:20 +0300
Subject: [PATCH 2/4] fixed circular import

---
 src/super_gradients/training/utils/checkpoint_utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py
index 4fbb602377..79a0709406 100644
--- a/src/super_gradients/training/utils/checkpoint_utils.py
+++ b/src/super_gradients/training/utils/checkpoint_utils.py
@@ -14,7 +14,6 @@
 from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
 from super_gradients.module_interfaces import HasPredict
 from super_gradients.training.pretrained_models import MODEL_URLS
-from super_gradients.training.utils import torch_version_is_greater_or_equal
 from super_gradients.training.utils.distributed_training_utils import wait_for_the_master
 from super_gradients.common.environment.ddp_utils import get_local_rank
 from super_gradients.training.utils.utils import unwrap_model
@@ -1608,6 +1607,6 @@ def get_scheduler_state(scheduler) -> Dict:
     :return: the scheduler's state_dict
     """
     state = scheduler.state_dict()
-    if isinstance(scheduler, CyclicLR) and not torch_version_is_greater_or_equal(2, 0):
+    if isinstance(scheduler, CyclicLR) and int(torch.version.__version__.split(".")[0]) < 2:
         del state["_scale_fn_ref"]
     return state

From b677a89ec8439484e6959d9bdc9b47003ee82c8b Mon Sep 17 00:00:00 2001
From: shayaharon
Date: Wed, 20 Sep 2023 16:24:59 +0300
Subject: [PATCH 3/4] local import of torch_version_is_greater_or_equal

---
 src/super_gradients/training/utils/checkpoint_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py
index 79a0709406..8638752a3e 100644
--- a/src/super_gradients/training/utils/checkpoint_utils.py
+++ b/src/super_gradients/training/utils/checkpoint_utils.py
@@ -1606,7 +1606,9 @@ def get_scheduler_state(scheduler) -> Dict:
     :param scheduler: torch.optim.lr_scheduler._LRScheduler, the scheduler
     :return: the scheduler's state_dict
     """
+    from super_gradients.training.utils import torch_version_is_greater_or_equal
+
     state = scheduler.state_dict()
-    if isinstance(scheduler, CyclicLR) and int(torch.version.__version__.split(".")[0]) < 2:
+    if isinstance(scheduler, CyclicLR) and not torch_version_is_greater_or_equal(2, 0):
         del state["_scale_fn_ref"]
     return state
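PATCH 2 and PATCH 3 deal with a side effect of the fix rather than the fix itself: importing torch_version_is_greater_or_equal from super_gradients.training.utils at the top of checkpoint_utils.py created an import cycle, so PATCH 2 temporarily swapped in an inline parse of the torch version string, and PATCH 3 restored the shared helper by importing it locally inside the function, where the import only runs on first call, after both modules have finished loading. The sketch below spells out what the two version checks compute; beyond what the diffs themselves show, the behaviour of the helper is an assumption.

import torch

# PATCH 2's inline check: take the major version straight off the torch version string,
# e.g. "1.13.1+cu117" -> 1, "2.0.1" -> 2.
needs_cycliclr_workaround = int(torch.version.__version__.split(".")[0]) < 2

# PATCH 3's check: the same condition expressed through the shared helper, imported
# lazily here (inside the function body in the real code) to avoid the circular import.
from super_gradients.training.utils import torch_version_is_greater_or_equal

assert needs_cycliclr_workaround == (not torch_version_is_greater_or_equal(2, 0))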
From 00e7be941e65f373af8e13ae4dde9e613f51b4a2 Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Wed, 11 Oct 2023 09:03:13 +0300
Subject: [PATCH 4/4] Fix missing function after merge

---
 .../training/utils/checkpoint_utils.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py
index 46de06fb58..c524462c9e 100644
--- a/src/super_gradients/training/utils/checkpoint_utils.py
+++ b/src/super_gradients/training/utils/checkpoint_utils.py
@@ -1,7 +1,7 @@
 import collections
 import os
 import tempfile
-from typing import Union, Mapping
+from typing import Union, Mapping, Dict
 
 import pkg_resources
 import torch
@@ -1629,3 +1629,19 @@ def _maybe_load_preprocessing_params(model: Union[nn.Module, HasPredict], checkp
             "predict make sure to call set_dataset_processing_params."
         )
     return False
+
+
+def get_scheduler_state(scheduler) -> Dict[str, Tensor]:
+    """
+    Wrapper for getting a torch lr scheduler state dict, resolving some issues with CyclicLR
+    (see https://github.com/pytorch/pytorch/pull/91400)
+    :param scheduler: torch.optim.lr_scheduler._LRScheduler, the scheduler
+    :return: the scheduler's state_dict
+    """
+    from super_gradients.training.utils import torch_version_is_greater_or_equal
+    from torch.optim.lr_scheduler import CyclicLR
+
+    state = scheduler.state_dict()
+    if isinstance(scheduler, CyclicLR) and not torch_version_is_greater_or_equal(2, 0):
+        del state["_scale_fn_ref"]
+    return state
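For completeness, a sketch of the round trip: saving the stripped state the way the _save_checkpoint change in PATCH 1 does, then restoring it on resume. The resume side is not part of these patches, so the scheduler construction and checkpoint path here are illustrative; the key point is that on the affected torch versions load_state_dict simply updates the scheduler's __dict__ with the saved keys, so a CyclicLR restored from a checkpoint without "_scale_fn_ref" keeps the reference it built in its own __init__.

import torch
from torch.optim.lr_scheduler import CyclicLR

from super_gradients.training.utils.checkpoint_utils import get_scheduler_state

# Illustrative setup; the real trainer builds these from the training recipe.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = CyclicLR(optimizer, base_lr=1e-4, max_lr=1e-2)

# Save: the scheduler entry mirrors what _save_checkpoint now writes.
torch.save({"torch_scheduler_state_dict": get_scheduler_state(scheduler)}, "ckpt_latest.pth")

# Resume: construct the scheduler as usual, then restore the saved counters.
resumed_scheduler = CyclicLR(optimizer, base_lr=1e-4, max_lr=1e-2)
ckpt = torch.load("ckpt_latest.pth", map_location="cpu")
resumed_scheduler.load_state_dict(ckpt["torch_scheduler_state_dict"])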