Merge branch 'master' into featuer/SG-000-improve-notebook-checks
# Conflicts:
#	Makefile
BloodAxe committed Nov 6, 2023
2 parents 1202e1c + bcc026c commit e8d783f
Showing 23 changed files with 954 additions and 44 deletions.
1 change: 1 addition & 0 deletions Makefile
@@ -34,6 +34,7 @@ sweeper_test:
NOTEBOOKS_TO_RUN := src/super_gradients/examples/model_export/models_export.ipynb
NOTEBOOKS_TO_RUN += notebooks/what_are_recipes_and_how_to_use.ipynb
NOTEBOOKS_TO_RUN += notebooks/transfer_learning_classification.ipynb
NOTEBOOKS_TO_RUN += notebooks/how_to_use_knowledge_distillation_for_classification.ipynb

# If there are additional notebooks that must not be executed, but still should be checked for version match, add them here
NOTEBOOKS_TO_CHECK := $(NOTEBOOKS_TO_RUN)
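
For context, a minimal sketch of what "running" one of these notebooks in a check can look like, using nbconvert's ExecutePreprocessor. This is an illustration only; the helper name, timeout and kernel below are assumptions, and the actual Makefile targets that consume NOTEBOOKS_TO_RUN are not shown in this diff.

# Illustrative notebook-execution check (not the repository's actual implementation).
import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

def run_notebook(path: str, timeout: int = 1800) -> None:
    # Read the notebook and execute every cell; a CellExecutionError is raised on the first failing cell.
    nb = nbformat.read(path, as_version=4)
    ep = ExecutePreprocessor(timeout=timeout, kernel_name="python3")
    ep.preprocess(nb, {"metadata": {"path": "."}})

if __name__ == "__main__":
    run_notebook("notebooks/how_to_use_knowledge_distillation_for_classification.ipynb")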
4 changes: 3 additions & 1 deletion README.md
@@ -309,7 +309,9 @@ Knowledge Distillation is a training technique that uses a large model, teacher
Learn more about SuperGradients knowledge distillation training with our example notebook on Google Colab, an easy-to-use tutorial on free GPU hardware that distills a pre-trained BEiT-Base teacher model into a ResNet18 student model on CIFAR10
<table class="tfo-notebook-buttons" align="left">
<td width="500">
<a target="_blank" href="https://bit.ly/3BLA5oR"><img src="./documentation/assets/SG_img/colab_logo.png" /> Knowledge Distillation Training</a>
<a target="_blank" href="https://colab.research.google.com/github/Deci-AI/super-gradients/blob/master/notebooks/how_to_use_knowledge_distillation_for_classification.ipynb">
<img src="./documentation/assets/SG_img/colab_logo.png" /> Knowledge Distillation Training
</a>
</td>
</table>
</br></br>
612 changes: 612 additions & 0 deletions notebooks/how_to_use_knowledge_distillation_for_classification.ipynb

Large diffs are not rendered by default.

33 changes: 31 additions & 2 deletions src/super_gradients/common/sg_loggers/base_sg_logger.py
@@ -1,5 +1,7 @@
import collections
import json
import os
import shutil
import signal
import time
from typing import Union, Any
@@ -9,21 +11,21 @@
import psutil
import torch
from PIL import Image
import shutil
from omegaconf import ListConfig, DictConfig, OmegaConf

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig
from super_gradients.common.auto_logging.console_logging import ConsoleSink
from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
from super_gradients.common.decorators.code_save_decorator import saved_codes
from super_gradients.common.environment.checkpoints_dir_utils import is_run_dir
from super_gradients.common.environment.ddp_utils import multi_process_safe
from super_gradients.common.environment.monitoring import SystemMonitor
from super_gradients.common.registry.registry import register_sg_logger
from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
from super_gradients.common.sg_loggers.time_units import TimeUnit
from super_gradients.training.params import TrainingParams
from super_gradients.training.utils import sg_trainer_utils, get_param
from super_gradients.common.environment.checkpoints_dir_utils import is_run_dir

logger = get_logger(__name__)

@@ -312,6 +314,7 @@ def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = None) ->
name += ".pth"
path = os.path.join(self._local_dir, name)

state_dict = self._sanitize_checkpoint(state_dict)
self._save_checkpoint(path=path, state_dict=state_dict)

@multi_process_safe
@@ -348,3 +351,29 @@ def _save_code(self):
self.add_file(name)
code = "\t" + code
self.add_text(name, code.replace("\n", " \n \t")) # this replacement makes tb format the code as code

def _sanitize_checkpoint(self, state_dict: dict) -> dict:
"""
Sanitize state dictionary to be saved in a checkpoint. Iterates recursively over the state_dict and converts
all instances of ListConfig and DictConfig to their native python counterparts.
:param state_dict: Checkpoint state_dict.
:return: Sanitized checkpoint state_dict.
"""
if isinstance(state_dict, (ListConfig, DictConfig)):
state_dict = OmegaConf.to_container(state_dict, resolve=True)

if isinstance(state_dict, torch.Tensor):
pass
elif isinstance(state_dict, collections.OrderedDict):
state_dict = collections.OrderedDict((k, self._sanitize_checkpoint(v)) for k, v in state_dict.items())
elif isinstance(state_dict, dict):
state_dict = dict((k, self._sanitize_checkpoint(v)) for k, v in state_dict.items())
elif isinstance(state_dict, list):
state_dict = [self._sanitize_checkpoint(v) for v in state_dict]
elif isinstance(state_dict, tuple):
state_dict = tuple(self._sanitize_checkpoint(v) for v in state_dict)
else:
pass

return state_dict
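
To illustrate what the sanitization prevents (a hypothetical state_dict, not taken from the library): OmegaConf containers nested in a checkpoint would otherwise be pickled as DictConfig/ListConfig objects, so loading the checkpoint would require omegaconf and whatever classes the config references.

# Hypothetical checkpoint entry produced from a Hydra/OmegaConf recipe.
from collections import OrderedDict

import torch
from omegaconf import DictConfig, OmegaConf

state_dict = OrderedDict(
    net=OrderedDict(weight=torch.zeros(2, 2)),
    processing_params=OmegaConf.create({"image_processor": {"Resize": {"size": 32}}}),
)
assert isinstance(state_dict["processing_params"], DictConfig)

# _sanitize_checkpoint walks the structure above and applies OmegaConf.to_container
# to every DictConfig/ListConfig it finds, so the saved file contains only plain
# dicts/lists/tuples and tensors:
plain = OmegaConf.to_container(state_dict["processing_params"], resolve=True)
assert isinstance(plain, dict) and plain["image_processor"]["Resize"]["size"] == 32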
2 changes: 2 additions & 0 deletions src/super_gradients/common/sg_loggers/clearml_sg_logger.py
@@ -222,6 +222,8 @@ def upload(self):

@multi_process_safe
def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
state_dict = self._sanitize_checkpoint(state_dict)

name = f"ckpt_{global_step}.pth" if tag is None else tag
if not name.endswith(".pth"):
name += ".pth"
1 change: 1 addition & 0 deletions src/super_gradients/common/sg_loggers/dagshub_sg_logger.py
@@ -256,6 +256,7 @@ def upload(self):

@multi_process_safe
def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
state_dict = self._sanitize_checkpoint(state_dict)
name = f"ckpt_{global_step}.pth" if tag is None else tag
if not name.endswith(".pth"):
name += ".pth"
1 change: 1 addition & 0 deletions src/super_gradients/common/sg_loggers/wandb_sg_logger.py
@@ -265,6 +265,7 @@ def _save_wandb_artifact(self, path):

@multi_process_safe
def add_checkpoint(self, tag: str, state_dict: dict, global_step: int = 0):
state_dict = self._sanitize_checkpoint(state_dict)
name = f"ckpt_{global_step}.pth" if tag is None else tag
if not name.endswith(".pth"):
name += ".pth"
@@ -30,6 +30,8 @@ val_dataset_params:
root: /data/cifar100
train: False
transforms:
- Resize:
size: 32
- ToTensor
- Normalize:
mean:
@@ -35,6 +35,8 @@ val_dataset_params:
root: ./data/cifar10
train: False
transforms:
- Resize:
size: 32
- ToTensor
- Normalize:
mean:
2 changes: 1 addition & 1 deletion src/super_gradients/recipes/imagenet_resnet50_kd.yaml
@@ -29,7 +29,7 @@ training_hyperparams:
criterion_params:
distillation_loss_coeff: 0.8
task_loss_fn:
_target_: super_gradients.training.losses.label_smoothing_cross_entropy_loss.LabelSmoothingCrossEntropyLoss
_target_: super_gradients.training.losses.label_smoothing_cross_entropy_loss.CrossEntropyLoss

arch_params:
teacher_input_adapter:
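
For reference, a generic sketch of how a distillation criterion with distillation_loss_coeff: 0.8 typically combines the soft-target term with the configured task loss (here plain cross-entropy). The weighting and temperature handling below follow common convention and are an assumption, not a copy of the library's KD loss implementation.

import torch
import torch.nn.functional as F

def kd_criterion(student_logits: torch.Tensor,
                 teacher_logits: torch.Tensor,
                 targets: torch.Tensor,
                 distillation_loss_coeff: float = 0.8,
                 temperature: float = 1.0) -> torch.Tensor:
    # Soft-target term: KL divergence between temperature-softened student and teacher distributions.
    distillation_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    ) * temperature**2
    # Task term: supervised loss against the hard labels (the task_loss_fn from the recipe).
    task_loss = F.cross_entropy(student_logits, targets)
    return distillation_loss_coeff * distillation_loss + (1.0 - distillation_loss_coeff) * task_loss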
3 changes: 0 additions & 3 deletions src/super_gradients/training/dataloaders/dataloaders.py
@@ -31,7 +31,6 @@
CoCoSegmentationDataSet,
PascalVOC2012SegmentationDataSet,
PascalVOCAndAUGUnifiedDataset,
SuperviselyPersonsDataset,
MapillaryDataset,
)
from super_gradients.training.utils import get_param
@@ -755,7 +754,6 @@ def pascal_voc_segmentation_val(dataset_params: Dict = None, dataloader_params:
def supervisely_persons_train(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
return get_data_loader(
config_name="supervisely_persons_dataset_params",
dataset_cls=SuperviselyPersonsDataset,
train=True,
dataset_params=dataset_params,
dataloader_params=dataloader_params,
@@ -766,7 +764,6 @@ def supervisely_persons_train(dataset_params: Dict = None, dataloader_params: Di
def supervisely_persons_val(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
return get_data_loader(
config_name="supervisely_persons_dataset_params",
dataset_cls=SuperviselyPersonsDataset,
train=False,
dataset_params=dataset_params,
dataloader_params=dataloader_params,
@@ -1,16 +1,18 @@
from typing import Optional, Callable, Union
from typing import Optional, Callable, Union, Dict

from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import Compose

from super_gradients.common.object_names import Datasets
from super_gradients.common.object_names import Datasets, Processings
from super_gradients.common.registry.registry import register_dataset
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training.datasets.classification_datasets.torchvision_utils import get_torchvision_transforms_equivalent_processing


@register_dataset(Datasets.CIFAR_10)
class Cifar10(CIFAR10):
class Cifar10(CIFAR10, HasPreprocessingParams):
"""
CIFAR10 Dataset
@@ -43,9 +45,23 @@ def __init__(
download=download,
)

def get_dataset_preprocessing_params(self) -> Dict:
"""
Get the preprocessing params for the dataset.
It infers preprocessing params from transforms used in the dataset & class names
:return: (dict) Preprocessing params
"""

pipeline = get_torchvision_transforms_equivalent_processing(self.transforms)
params = dict(
image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
class_names=self.classes,
)
return params


@register_dataset(Datasets.CIFAR_100)
class Cifar100(CIFAR100):
class Cifar100(CIFAR100, HasPreprocessingParams):
@resolve_param("transforms", TransformsFactory())
def __init__(
self,
@@ -76,3 +92,17 @@ def __init__(
target_transform=target_transform,
download=download,
)

def get_dataset_preprocessing_params(self) -> Dict:
"""
Get the preprocessing params for the dataset.
It infers preprocessing params from transforms used in the dataset & class names
:return: (dict) Preprocessing params
"""

pipeline = get_torchvision_transforms_equivalent_processing(self.transforms)
params = dict(
image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
class_names=self.classes,
)
return params
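
A short usage sketch of the new hook. The import path, dataset root and normalization values below are assumptions for illustration; the transforms mirror the updated CIFAR10 validation recipe (Resize, ToTensor, Normalize).

# Illustrative usage of get_dataset_preprocessing_params on the CIFAR10 wrapper.
from torchvision.transforms import Normalize, Resize, ToTensor

from super_gradients.training.datasets import Cifar10  # import path assumed

val_set = Cifar10(
    root="./data/cifar10",
    train=False,
    transforms=[Resize(32), ToTensor(), Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616))],
    download=True,
)

params = val_set.get_dataset_preprocessing_params()
# params["class_names"]    -> the ten CIFAR10 class names
# params["image_processor"] -> ComposeProcessing with Resize, StandardizeImage,
#                              NormalizeImage and ImagePermute steps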
@@ -1,16 +1,18 @@
from typing import Union
from typing import Union, Dict

import torchvision.datasets as torch_datasets
from torchvision.transforms import Compose

from super_gradients.common.registry.registry import register_dataset
from super_gradients.common.object_names import Datasets
from super_gradients.common.object_names import Datasets, Processings
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training.datasets.classification_datasets.torchvision_utils import get_torchvision_transforms_equivalent_processing


@register_dataset(Datasets.IMAGENET_DATASET)
class ImageNetDataset(torch_datasets.ImageFolder):
class ImageNetDataset(torch_datasets.ImageFolder, HasPreprocessingParams):
"""ImageNetDataset dataset.
To use this Dataset you need to:
@@ -41,3 +43,17 @@ def __init__(self, root: str, transforms: Union[list, dict] = [], *args, **kwarg
if isinstance(transforms, list):
transforms = Compose(transforms)
super(ImageNetDataset, self).__init__(root, transform=transforms, *args, **kwargs)

def get_dataset_preprocessing_params(self) -> Dict:
"""
Get the preprocessing params for the dataset.
It infers preprocessing params from transforms used in the dataset & class names
:return: (dict) Preprocessing params
"""

pipeline = get_torchvision_transforms_equivalent_processing(self.transforms)
params = dict(
image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
class_names=self.classes,
)
return params
@@ -0,0 +1,39 @@
from typing import List, Any, Dict

from torchvision.datasets.vision import StandardTransform
from torchvision.transforms import Resize, ToTensor, Normalize, CenterCrop, Compose

from super_gradients.common.object_names import Processings


def get_torchvision_transforms_equivalent_processing(transforms: List[Any]) -> List[Dict[str, Any]]:
"""
Get the equivalent SuperGradients processing pipeline for a set of torchvision transforms.

:param transforms: Transforms to convert (a plain list, a Compose, or a StandardTransform wrapper).
:return:           List of Processings operations
"""
# Since we use cv2.imread to read images, our models are in fact trained on BGR images.
# In our pipelines the convention is that input images are RGB, so the channels need to be reversed to BGR
# to match the expected input of the model.
pipeline = []

if isinstance(transforms, StandardTransform):
transforms = transforms.transform

if isinstance(transforms, Compose):
transforms = transforms.transforms

for transform in transforms:
if isinstance(transform, ToTensor):
pipeline.append({Processings.StandardizeImage: {"max_value": 255}})
elif isinstance(transform, Normalize):
pipeline.append({Processings.NormalizeImage: {"mean": tuple(map(float, transform.mean)), "std": tuple(map(float, transform.std))}})
elif isinstance(transform, Resize):
pipeline.append({Processings.Resize: {"size": int(transform.size)}})
elif isinstance(transform, CenterCrop):
pipeline.append({Processings.CenterCrop: {"size": int(transform.size)}})
else:
raise ValueError(f"Unsupported transform: {transform}")

pipeline.append({Processings.ImagePermute: {"permutation": (2, 0, 1)}})
return pipeline
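
A brief sketch of what the helper produces for a typical classification pipeline (the exact dictionary keys come from Processings; the output shown in the comment is approximate).

from torchvision.transforms import Compose, Normalize, Resize, ToTensor

transforms = Compose(
    [
        Resize(32),
        ToTensor(),
        Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]),
    ]
)

pipeline = get_torchvision_transforms_equivalent_processing(transforms)
# Roughly:
# [{Processings.Resize: {"size": 32}},
#  {Processings.StandardizeImage: {"max_value": 255}},
#  {Processings.NormalizeImage: {"mean": (0.4914, 0.4822, 0.4465), "std": (0.2470, 0.2435, 0.2616)}},
#  {Processings.ImagePermute: {"permutation": (2, 0, 1)}}]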
@@ -3,6 +3,7 @@
from typing import Tuple, List, Union

import numpy as np
from omegaconf import ListConfig
from torch.utils.data.dataloader import Dataset

from super_gradients.common.abstractions.abstract_logger import get_logger
@@ -32,9 +33,9 @@ def __init__(
self,
transforms: List[AbstractKeypointTransform],
num_joints: int,
edge_links: Union[List[Tuple[int, int]], np.ndarray],
edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
edge_links: Union[ListConfig, List[Tuple[int, int]], np.ndarray],
edge_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
keypoint_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
):
"""
@@ -50,6 +51,18 @@
load_sample_fn=self.load_random_sample,
)
self.num_joints = num_joints

# Explicitly convert edge_links, keypoint_colors and edge_colors to lists of tuples.
# This is necessary to ensure ListConfig objects do not leak into these properties
# and, from there, into the checkpoint's state_dict.
# Otherwise, through ListConfig instances, the whole configuration file would leak into the state_dict
# and torch.load would attempt to unpickle a lot of unnecessary classes.
edge_links = [(int(from_idx), int(to_idx)) for from_idx, to_idx in edge_links]
if edge_colors is not None:
edge_colors = [(int(r), int(g), int(b)) for r, g, b in edge_colors]
if keypoint_colors is not None:
keypoint_colors = [(int(r), int(g), int(b)) for r, g, b in keypoint_colors]

self.edge_links = edge_links
self.edge_colors = edge_colors or generate_color_mapping(len(edge_links))
self.keypoint_colors = keypoint_colors or generate_color_mapping(num_joints)
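
A standalone illustration of the leak being prevented (not library code): values coming from a Hydra/OmegaConf recipe arrive as ListConfig, and the explicit conversion leaves only plain ints and tuples behind.

from omegaconf import ListConfig, OmegaConf

# edge_links as they would arrive from a YAML recipe resolved by OmegaConf.
edge_links = OmegaConf.create([[0, 1], [1, 2], [2, 3]])
assert isinstance(edge_links, ListConfig)

# Same conversion as in the dataset __init__ above: plain tuples of ints,
# so nothing OmegaConf-specific gets pickled into the checkpoint's state_dict.
edge_links = [(int(from_idx), int(to_idx)) for from_idx, to_idx in edge_links]
assert edge_links == [(0, 1), (1, 2), (2, 3)]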
@@ -3,6 +3,7 @@

import numpy as np
import torch
from omegaconf import ListConfig
from torch.utils.data.dataloader import default_collate, Dataset

from super_gradients.common.abstractions.abstract_logger import get_logger
@@ -28,9 +29,9 @@ def __init__(
transforms: List[KeypointTransform],
min_instance_area: float,
num_joints: int,
edge_links: Union[List[Tuple[int, int]], np.ndarray],
edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
edge_links: Union[ListConfig, List[Tuple[int, int]], np.ndarray],
edge_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
keypoint_colors: Union[ListConfig, List[Tuple[int, int, int]], np.ndarray, None],
):
"""
Expand All @@ -48,6 +49,18 @@ def __init__(
self.transforms = KeypointsCompose(transforms)
self.min_instance_area = min_instance_area
self.num_joints = num_joints

# Explicitly convert edge_links, keypoint_colors and edge_colors to lists of tuples.
# This is necessary to ensure ListConfig objects do not leak into these properties
# and, from there, into the checkpoint's state_dict.
# Otherwise, through ListConfig instances, the whole configuration file would leak into the state_dict
# and torch.load would attempt to unpickle a lot of unnecessary classes.
edge_links = [(int(from_idx), int(to_idx)) for from_idx, to_idx in edge_links]
if edge_colors is not None:
edge_colors = [(int(r), int(g), int(b)) for r, g, b in edge_colors]
if keypoint_colors is not None:
keypoint_colors = [(int(r), int(g), int(b)) for r, g, b in keypoint_colors]

self.edge_links = edge_links
self.edge_colors = edge_colors or generate_color_mapping(len(edge_links))
self.keypoint_colors = keypoint_colors or generate_color_mapping(num_joints)