Fixed ListConfig in pose estimation dataset classes #1602

Merged: 9 commits, Nov 6, 2023
33 changes: 31 additions & 2 deletions src/super_gradients/common/sg_loggers/base_sg_logger.py
@@ -1,5 +1,7 @@
import collections
import json
+import os
+import shutil
import signal
import time
from typing import Union, Any
@@ -9,21 +11,21 @@
import psutil
import torch
from PIL import Image
-import shutil
+from omegaconf import ListConfig, DictConfig, OmegaConf

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig
from super_gradients.common.auto_logging.console_logging import ConsoleSink
from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
from super_gradients.common.decorators.code_save_decorator import saved_codes
+from super_gradients.common.environment.checkpoints_dir_utils import is_run_dir
from super_gradients.common.environment.ddp_utils import multi_process_safe
from super_gradients.common.environment.monitoring import SystemMonitor
from super_gradients.common.registry.registry import register_sg_logger
from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
from super_gradients.common.sg_loggers.time_units import TimeUnit
from super_gradients.training.params import TrainingParams
from super_gradients.training.utils import sg_trainer_utils, get_param
-from super_gradients.common.environment.checkpoints_dir_utils import is_run_dir

logger = get_logger(__name__)

@@ -312,6 +314,7 @@ def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = None) ->
name += ".pth"
path = os.path.join(self._local_dir, name)

state_dict = self._sanitize_checkpoint(state_dict)
self._save_checkpoint(path=path, state_dict=state_dict)

@multi_process_safe
@@ -348,3 +351,29 @@ def _save_code(self):
            self.add_file(name)
            code = "\t" + code
            self.add_text(name, code.replace("\n", " \n \t"))  # this replacement makes tb format the code as code
+
+    def _sanitize_checkpoint(self, state_dict: dict) -> dict:
+        """
+        Sanitize state dictionary to be saved in a checkpoint. Iterates recursively over the state_dict and converts
+        all instances of ListConfig and DictConfig to their native Python counterparts.
+
+        :param state_dict: Checkpoint state_dict.
+        :return: Sanitized checkpoint state_dict.
+        """
+        if isinstance(state_dict, (ListConfig, DictConfig)):
+            state_dict = OmegaConf.to_container(state_dict, resolve=True)
+
+        if isinstance(state_dict, torch.Tensor):
+            pass
+        elif isinstance(state_dict, collections.OrderedDict):
+            state_dict = collections.OrderedDict((k, self._sanitize_checkpoint(v)) for k, v in state_dict.items())
+        elif isinstance(state_dict, dict):
+            state_dict = dict((k, self._sanitize_checkpoint(v)) for k, v in state_dict.items())
+        elif isinstance(state_dict, list):
+            state_dict = [self._sanitize_checkpoint(v) for v in state_dict]
+        elif isinstance(state_dict, tuple):
+            state_dict = tuple(self._sanitize_checkpoint(v) for v in state_dict)
+        else:
+            pass
+
+        return state_dict
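
A minimal sketch of what this sanitization does, assuming an already-constructed BaseSGLogger instance named sg_logger (the state_dict contents below are made up for illustration):

from collections import OrderedDict

import torch
from omegaconf import OmegaConf

# Hypothetical checkpoint mixing tensors with OmegaConf containers,
# similar to what an OmegaConf/Hydra-driven recipe can produce.
state_dict = OrderedDict(
    net=OrderedDict(weight=torch.zeros(2, 2)),
    processing_params={"edge_links": OmegaConf.create([[0, 1], [1, 2]])},
)

sanitized = sg_logger._sanitize_checkpoint(state_dict)

# The nested ListConfig is converted to a plain list; tensors pass through untouched.
assert isinstance(sanitized["processing_params"]["edge_links"], list)
assert torch.equal(sanitized["net"]["weight"], state_dict["net"]["weight"])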
2 changes: 2 additions & 0 deletions src/super_gradients/common/sg_loggers/clearml_sg_logger.py
@@ -222,6 +222,8 @@ def upload(self):

    @multi_process_safe
    def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
+        state_dict = self._sanitize_checkpoint(state_dict)
+
        name = f"ckpt_{global_step}.pth" if tag is None else tag
        if not name.endswith(".pth"):
            name += ".pth"
@@ -256,6 +256,7 @@ def upload(self):

    @multi_process_safe
    def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
+        state_dict = self._sanitize_checkpoint(state_dict)
        name = f"ckpt_{global_step}.pth" if tag is None else tag
        if not name.endswith(".pth"):
            name += ".pth"
1 change: 1 addition & 0 deletions src/super_gradients/common/sg_loggers/wandb_sg_logger.py
@@ -265,6 +265,7 @@ def _save_wandb_artifact(self, path):

    @multi_process_safe
    def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
+        state_dict = self._sanitize_checkpoint(state_dict)
        name = f"ckpt_{global_step}.pth" if tag is None else tag
        if not name.endswith(".pth"):
            name += ".pth"
@@ -3,6 +3,7 @@
from typing import Tuple, List, Union

import numpy as np
+from omegaconf import ListConfig
from torch.utils.data.dataloader import Dataset

from super_gradients.common.abstractions.abstract_logger import get_logger
@@ -32,9 +33,9 @@ def __init__(
        self,
        transforms: List[AbstractKeypointTransform],
        num_joints: int,
-        edge_links: Union[List[Tuple[int, int]], np.ndarray],
-        edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
-        keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
+        edge_links: Union[ListConfig, List[Tuple[int, int]], np.ndarray],
+        edge_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
+        keypoint_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
    ):
"""

@@ -50,6 +51,18 @@
            load_sample_fn=self.load_random_sample,
        )
        self.num_joints = num_joints
+
+        # Explicitly convert edge_links, keypoint_colors and edge_colors to lists of tuples.
+        # This is necessary to ensure ListConfig objects do not leak into these properties
+        # and, from there, into the checkpoint's state_dict.
+        # Otherwise, through the ListConfig instances, the whole configuration file would leak
+        # into the state_dict, and torch.load would attempt to unpickle a lot of unnecessary classes.
+        edge_links = [(int(from_idx), int(to_idx)) for from_idx, to_idx in edge_links]
+        if edge_colors is not None:
+            edge_colors = [(int(r), int(g), int(b)) for r, g, b in edge_colors]
+        if keypoint_colors is not None:
+            keypoint_colors = [(int(r), int(g), int(b)) for r, g, b in keypoint_colors]
+
        self.edge_links = edge_links
        self.edge_colors = edge_colors or generate_color_mapping(len(edge_links))
        self.keypoint_colors = keypoint_colors or generate_color_mapping(num_joints)
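
These comprehensions work on either declared input type because a ListConfig is iterable just like a plain list. A small sketch of the conversion in isolation (the values are made up):

from omegaconf import OmegaConf

# What Hydra would pass when edge_links comes from a YAML recipe.
edge_links = OmegaConf.create([[0, 1], [1, 2]])

converted = [(int(from_idx), int(to_idx)) for from_idx, to_idx in edge_links]

# Result contains only native tuples of ints, safe to pickle anywhere.
assert converted == [(0, 1), (1, 2)]
assert all(type(idx) is int for pair in converted for idx in pair)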
@@ -3,6 +3,7 @@

import numpy as np
import torch
+from omegaconf import ListConfig
from torch.utils.data.dataloader import default_collate, Dataset

from super_gradients.common.abstractions.abstract_logger import get_logger
@@ -28,9 +29,9 @@ def __init__(
        transforms: List[KeypointTransform],
        min_instance_area: float,
        num_joints: int,
-        edge_links: Union[List[Tuple[int, int]], np.ndarray],
-        edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
-        keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
+        edge_links: Union[ListConfig, List[Tuple[int, int]], np.ndarray],
+        edge_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
+        keypoint_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
    ):
"""

@@ -48,6 +49,18 @@
        self.transforms = KeypointsCompose(transforms)
        self.min_instance_area = min_instance_area
        self.num_joints = num_joints
+
+        # Explicitly convert edge_links, keypoint_colors and edge_colors to lists of tuples.
+        # This is necessary to ensure ListConfig objects do not leak into these properties
+        # and, from there, into the checkpoint's state_dict.
+        # Otherwise, through the ListConfig instances, the whole configuration file would leak
+        # into the state_dict, and torch.load would attempt to unpickle a lot of unnecessary classes.
+        edge_links = [(int(from_idx), int(to_idx)) for from_idx, to_idx in edge_links]
+        if edge_colors is not None:
+            edge_colors = [(int(r), int(g), int(b)) for r, g, b in edge_colors]
+        if keypoint_colors is not None:
+            keypoint_colors = [(int(r), int(g), int(b)) for r, g, b in keypoint_colors]
+
        self.edge_links = edge_links
        self.edge_colors = edge_colors or generate_color_mapping(len(edge_links))
        self.keypoint_colors = keypoint_colors or generate_color_mapping(num_joints)
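
With only native containers left in these attributes, a checkpoint embedding them round-trips cleanly. A sketch using torch's restricted unpickler (weights_only is available in recent torch releases; this check is illustrative only and not part of this PR):

import io

import torch

buffer = io.BytesIO()
torch.save({"edge_links": [(0, 1), (1, 2)], "weight": torch.ones(3)}, buffer)
buffer.seek(0)

# weights_only=True refuses arbitrary classes, so this succeeds only because
# the state_dict contains nothing but native types and tensors.
restored = torch.load(buffer, weights_only=True)
assert restored["edge_links"] == [(0, 1), (1, 2)]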