Fixed ListConfig in pose estimation dataset classes #1602

Merged · 9 commits · Nov 6, 2023
30 changes: 29 additions & 1 deletion src/super_gradients/common/sg_loggers/base_sg_logger.py
@@ -2,7 +2,8 @@
 import os
 import signal
 import time
-from typing import Union, Any
+import typing
+from typing import Union, Any, Mapping
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -11,6 +12,8 @@
 from PIL import Image
 import shutil
 
+from omegaconf import ListConfig, DictConfig
+
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig
 from super_gradients.common.auto_logging.console_logging import ConsoleSink
@@ -312,6 +315,7 @@ def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = None) ->
             name += ".pth"
         path = os.path.join(self._local_dir, name)
 
+        self._validate_checkpoint(state_dict)
         self._save_checkpoint(path=path, state_dict=state_dict)
 
     @multi_process_safe
@@ -348,3 +352,27 @@ def _save_code(self):
             self.add_file(name)
             code = "\t" + code
             self.add_text(name, code.replace("\n", " \n \t"))  # this replacement makes tb format the code as code
+
+    def _validate_checkpoint(self, state_dict: Mapping, path=None) -> None:
+        """
+        Validate the checkpoint state_dict to make sure it contains only primitive types and tensors.
+        This function raises a ValueError if the checkpoint contains a ListConfig or DictConfig.
+
+        :param state_dict: Checkpoint state_dict.
+        :param path:       Current path inside the checkpoint
+                           (used to print a meaningful path when a problematic key is detected).
+        """
+        if isinstance(state_dict, (ListConfig, DictConfig)):
+            raise ValueError(
+                f"Checkpoint state_dict element {path} contains a ListConfig or DictConfig. "
+                f"Only primitive types and Tensors are supported. "
+                f"Most likely, you forgot to convert them to a plain container using OmegaConf.to_container()."
+            )
+        if isinstance(state_dict, typing.Mapping):
+            for k, v in state_dict.items():
+                self._validate_checkpoint(v, path=f"{path}.{k}" if path is not None else str(k))
+        elif isinstance(state_dict, (torch.Tensor, str, bytes)):
+            # Tensors and strings are themselves iterable, so they must be accepted here,
+            # before the generic Iterable branch below (recursing into a non-empty string
+            # would never terminate).
+            pass
+        elif isinstance(state_dict, typing.Iterable):
+            for index, value in enumerate(state_dict):
+                self._validate_checkpoint(value, path=f"{path}[{index}]" if path is not None else f"[{index}]")
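For context, the sketch below mirrors the traversal added above in a self-contained form. `validate_checkpoint` and the sample `state_dict` are hypothetical, for illustration only, and not part of the SuperGradients API:

```python
import typing

import torch
from omegaconf import OmegaConf, DictConfig, ListConfig


def validate_checkpoint(value, path=None) -> None:
    """Recursively reject OmegaConf containers hiding inside a checkpoint."""
    if isinstance(value, (ListConfig, DictConfig)):
        raise ValueError(f"Checkpoint element {path} contains a ListConfig or DictConfig.")
    if isinstance(value, typing.Mapping):
        for k, v in value.items():
            validate_checkpoint(v, path=f"{path}.{k}" if path is not None else str(k))
    elif isinstance(value, (torch.Tensor, str, bytes)):
        pass  # Leaf values; also keeps the walk from iterating into tensors and strings.
    elif isinstance(value, typing.Iterable):
        for index, item in enumerate(value):
            validate_checkpoint(item, path=f"{path}[{index}]" if path is not None else f"[{index}]")


state_dict = {
    "net": {"conv.weight": torch.zeros(3, 3)},
    "processing_params": OmegaConf.create({"edge_links": [[0, 1], [1, 2]]}),  # the leak
}
validate_checkpoint(state_dict)  # ValueError naming the "processing_params" path
```

Note that the ListConfig/DictConfig check must fire first: a DictConfig also satisfies the Mapping branch, so without the leading check the walk would silently descend into it.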
2 changes: 2 additions & 0 deletions src/super_gradients/common/sg_loggers/clearml_sg_logger.py
@@ -222,6 +222,8 @@ def upload(self):
 
     @multi_process_safe
     def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
+        self._validate_checkpoint(state_dict)
+
         name = f"ckpt_{global_step}.pth" if tag is None else tag
         if not name.endswith(".pth"):
             name += ".pth"
(file name not captured)
@@ -256,6 +256,7 @@ def upload(self):
 
     @multi_process_safe
     def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
+        self._validate_checkpoint(state_dict)
         name = f"ckpt_{global_step}.pth" if tag is None else tag
         if not name.endswith(".pth"):
             name += ".pth"
2 changes: 2 additions & 0 deletions src/super_gradients/common/sg_loggers/wandb_sg_logger.py
@@ -265,6 +265,8 @@ def _save_wandb_artifact(self, path):
 
     @multi_process_safe
    def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
+        self._validate_checkpoint(state_dict)
+
         name = f"ckpt_{global_step}.pth" if tag is None else tag
         if not name.endswith(".pth"):
             name += ".pth"
(file name not captured)
@@ -3,6 +3,7 @@
 from typing import Tuple, List, Union
 
 import numpy as np
+from omegaconf import ListConfig
 from torch.utils.data.dataloader import Dataset
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
@@ -32,9 +33,9 @@ def __init__(
         self,
         transforms: List[AbstractKeypointTransform],
         num_joints: int,
-        edge_links: Union[List[Tuple[int, int]], np.ndarray],
-        edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
-        keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
+        edge_links: Union[ListConfig, List[Tuple[int, int]], np.ndarray],
+        edge_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
+        keypoint_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
     ):
         """
 
@@ -50,6 +51,18 @@
             load_sample_fn=self.load_random_sample,
         )
         self.num_joints = num_joints
+
+        # Explicitly convert edge_links, keypoint_colors and edge_colors to lists of tuples.
+        # This is necessary to ensure ListConfig objects do not leak into these properties
+        # and, from there, into the checkpoint's state_dict. Otherwise a whole configuration
+        # file would leak into the state_dict through the ListConfig instances, and torch.load
+        # would attempt to unpickle a lot of unnecessary classes.
+        edge_links = [(int(from_idx), int(to_idx)) for from_idx, to_idx in edge_links]
+        if edge_colors is not None:
+            edge_colors = [(int(r), int(g), int(b)) for r, g, b in edge_colors]
+        if keypoint_colors is not None:
+            keypoint_colors = [(int(r), int(g), int(b)) for r, g, b in keypoint_colors]
+
         self.edge_links = edge_links
         self.edge_colors = edge_colors or generate_color_mapping(len(edge_links))
         self.keypoint_colors = keypoint_colors or generate_color_mapping(num_joints)
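To see why this conversion is needed, here is a small sketch (the `cfg` contents are hypothetical) of what a Hydra/OmegaConf recipe actually hands to these constructors, and what the comprehension above turns it into:

```python
from omegaconf import OmegaConf

# What the dataset receives when edge_links comes straight from a YAML recipe:
cfg = OmegaConf.create({"edge_links": [[0, 1], [1, 2], [2, 3]]})
print(type(cfg.edge_links))  # <class 'omegaconf.listconfig.ListConfig'>

# A ListConfig node keeps a reference to its parent config, so pickling it into
# a checkpoint drags the surrounding configuration (and omegaconf classes) along.
edge_links = [(int(from_idx), int(to_idx)) for from_idx, to_idx in cfg.edge_links]
print(edge_links)  # [(0, 1), (1, 2), (2, 3)] -- plain ints and tuples, safe to pickle
```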
(file name not captured)
@@ -3,6 +3,7 @@
 
 import numpy as np
 import torch
+from omegaconf import ListConfig
 from torch.utils.data.dataloader import default_collate, Dataset
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
@@ -28,9 +29,9 @@ def __init__(
         transforms: List[KeypointTransform],
         min_instance_area: float,
         num_joints: int,
-        edge_links: Union[List[Tuple[int, int]], np.ndarray],
-        edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
-        keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
+        edge_links: Union[ListConfig, List[Tuple[int, int]], np.ndarray],
+        edge_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
+        keypoint_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
     ):
         """
 
@@ -48,6 +49,18 @@ def __init__(
         self.transforms = KeypointsCompose(transforms)
         self.min_instance_area = min_instance_area
         self.num_joints = num_joints
+
+        # Explicitly convert edge_links, keypoint_colors and edge_colors to lists of tuples.
+        # This is necessary to ensure ListConfig objects do not leak into these properties
+        # and, from there, into the checkpoint's state_dict. Otherwise a whole configuration
+        # file would leak into the state_dict through the ListConfig instances, and torch.load
+        # would attempt to unpickle a lot of unnecessary classes.
+        edge_links = [(int(from_idx), int(to_idx)) for from_idx, to_idx in edge_links]
+        if edge_colors is not None:
+            edge_colors = [(int(r), int(g), int(b)) for r, g, b in edge_colors]
+        if keypoint_colors is not None:
+            keypoint_colors = [(int(r), int(g), int(b)) for r, g, b in keypoint_colors]
+
         self.edge_links = edge_links
         self.edge_colors = edge_colors or generate_color_mapping(len(edge_links))
         self.keypoint_colors = keypoint_colors or generate_color_mapping(num_joints)
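The ValueError raised by `_validate_checkpoint` points users at `OmegaConf.to_container()` as the general escape hatch; a minimal sketch of that route (the `cfg` here is again a made-up example):

```python
import torch
from omegaconf import OmegaConf

cfg = OmegaConf.create({"edge_links": [[0, 1], [1, 2]], "num_joints": 17})

# Recursively converts DictConfig/ListConfig into plain dict/list/ints,
# so nothing omegaconf-specific gets pickled into the checkpoint file.
plain = OmegaConf.to_container(cfg, resolve=True)
assert type(plain) is dict and type(plain["edge_links"]) is list

torch.save({"processing_params": plain}, "ckpt.pth")  # torch.load now needs no omegaconf
```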