Update KD Notebook for classification #1595

Merged · Nov 5, 2023 · 15 commits · Changes from 7 commits
2 changes: 1 addition & 1 deletion Makefile
@@ -32,7 +32,7 @@ sweeper_test:

# Here you define a list of notebooks we want to execute and convert to markdown files
# NOTEBOOKS = hellomake.ipynb hellofunc.ipynb helloclass.ipynb
NOTEBOOKS = src/super_gradients/examples/model_export/models_export.ipynb notebooks/what_are_recipes_and_how_to_use.ipynb
NOTEBOOKS = src/super_gradients/examples/model_export/models_export.ipynb notebooks/what_are_recipes_and_how_to_use.ipynb notebooks/how_to_use_knowledge_distillation_for_classification.ipynb

# This Makefile target runs notebooks listed below and converts them to markdown files in documentation/source/
run_and_convert_notebooks_to_docs: $(NOTEBOOKS)
4 changes: 3 additions & 1 deletion README.md
@@ -312,7 +312,9 @@ Knowledge Distillation is a training technique that uses a large model, teacher
Learn more about SuperGradients knowledge distillation training with our pre-trained BEiT base teacher model and Resnet18 student model on CIFAR10 example notebook on Google Colab for an easy to use tutorial using free GPU hardware
<table class="tfo-notebook-buttons" align="left">
<td width="500">
<a target="_blank" href="https://bit.ly/3BLA5oR"><img src="./documentation/assets/SG_img/colab_logo.png" /> Knowledge Distillation Training</a>
<a target="_blank" href="https://colab.research.google.com/github/Deci-AI/super-gradients/blob/master/notebooks/how_to_use_knowledge_distillation_for_classification.ipynb">
<img src="./documentation/assets/SG_img/colab_logo.png" /> Knowledge Distillation Training
</a>
</td>
</table>
</br></br>
2 changes: 1 addition & 1 deletion src/super_gradients/recipes/imagenet_resnet50_kd.yaml
@@ -29,7 +29,7 @@ training_hyperparams:
  criterion_params:
    distillation_loss_coeff: 0.8
    task_loss_fn:
      _target_: super_gradients.training.losses.label_smoothing_cross_entropy_loss.LabelSmoothingCrossEntropyLoss
      _target_: super_gradients.training.losses.label_smoothing_cross_entropy_loss.CrossEntropyLoss

arch_params:
  teacher_input_adapter:
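Note on this hunk: `distillation_loss_coeff` weights the distillation term against the task loss defined by `task_loss_fn` (now `CrossEntropyLoss`). The exact combination lives in the KD loss class configured elsewhere in this recipe and is not shown in this diff; the sketch below only illustrates the usual convention and is an assumption, not the recipe's actual implementation.

```python
import torch.nn.functional as F


def kd_classification_loss(student_logits, teacher_logits, targets, distillation_loss_coeff: float = 0.8):
    """Hedged sketch: weighted sum of the task loss and a KL-divergence distillation term."""
    task_loss = F.cross_entropy(student_logits, targets)  # plays the role of task_loss_fn above
    distillation_loss = F.kl_div(
        F.log_softmax(student_logits, dim=1),
        F.softmax(teacher_logits, dim=1),
        reduction="batchmean",
    )
    return (1 - distillation_loss_coeff) * task_loss + distillation_loss_coeff * distillation_loss
```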
@@ -1,16 +1,18 @@
from typing import Optional, Callable, Union
from typing import Optional, Callable, Union, Dict

from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import Compose

from super_gradients.common.object_names import Datasets
from super_gradients.common.object_names import Datasets, Processings
from super_gradients.common.registry.registry import register_dataset
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training.datasets.classification_datasets.torchvision_utils import get_torchvision_transforms_equivalent_processing


@register_dataset(Datasets.CIFAR_10)
class Cifar10(CIFAR10):
class Cifar10(CIFAR10, HasPreprocessingParams):
"""
CIFAR10 Dataset

@@ -43,9 +45,23 @@ def __init__(
            download=download,
        )

    def get_dataset_preprocessing_params(self) -> Dict:
        """
        Get the preprocessing params for the dataset.
        The preprocessing pipeline is inferred from the transforms used in the dataset; class names are taken from the dataset itself.
        :return: (dict) Preprocessing params
        """

        pipeline = get_torchvision_transforms_equivalent_processing(self.transforms)
        params = dict(
            image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
            class_names=self.classes,
        )
        return params


@register_dataset(Datasets.CIFAR_100)
class Cifar100(CIFAR100):
class Cifar100(CIFAR100, HasPreprocessingParams):
    @resolve_param("transforms", TransformsFactory())
    def __init__(
        self,
@@ -76,3 +92,17 @@ def __init__(
            target_transform=target_transform,
            download=download,
        )

    def get_dataset_preprocessing_params(self) -> Dict:
        """
        Get the preprocessing params for the dataset.
        The preprocessing pipeline is inferred from the transforms used in the dataset; class names are taken from the dataset itself.
        :return: (dict) Preprocessing params
        """

        pipeline = get_torchvision_transforms_equivalent_processing(self.transforms)
        params = dict(
            image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
            class_names=self.classes,
        )
        return params
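Note on this hunk: a rough usage sketch of the new `get_dataset_preprocessing_params` method. The dataset arguments and transform values are illustrative rather than taken from the notebook, and the import assumes `Cifar10` is re-exported from `super_gradients.training.datasets`.

```python
from torchvision.transforms import Normalize, Resize, ToTensor

from super_gradients.training.datasets import Cifar10  # assumed export path

dataset = Cifar10(
    root="./data",
    train=False,
    download=True,
    transforms=[Resize(32), ToTensor(), Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616))],
)

params = dataset.get_dataset_preprocessing_params()
# params["class_names"]     -> the ten CIFAR-10 class names
# params["image_processor"] -> {Processings.ComposeProcessing: {"processings": [...]}} built from the transforms above
```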
@@ -1,16 +1,18 @@
from typing import Union
from typing import Union, Dict

import torchvision.datasets as torch_datasets
from torchvision.transforms import Compose

from super_gradients.common.registry.registry import register_dataset
from super_gradients.common.object_names import Datasets
from super_gradients.common.object_names import Datasets, Processings
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training.datasets.classification_datasets.torchvision_utils import get_torchvision_transforms_equivalent_processing


@register_dataset(Datasets.IMAGENET_DATASET)
class ImageNetDataset(torch_datasets.ImageFolder):
class ImageNetDataset(torch_datasets.ImageFolder, HasPreprocessingParams):
"""ImageNetDataset dataset.
To use this Dataset you need to:
@@ -41,3 +43,17 @@ def __init__(self, root: str, transforms: Union[list, dict] = [], *args, **kwarg
        if isinstance(transforms, list):
            transforms = Compose(transforms)
        super(ImageNetDataset, self).__init__(root, transform=transforms, *args, **kwargs)

    def get_dataset_preprocessing_params(self) -> Dict:
        """
        Get the preprocessing params for the dataset.
        The preprocessing pipeline is inferred from the transforms used in the dataset; class names are taken from the dataset itself.
        :return: (dict) Preprocessing params
        """

        pipeline = get_torchvision_transforms_equivalent_processing(self.transforms)
        params = dict(
            image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
            class_names=self.classes,
        )
        return params
@@ -0,0 +1,39 @@
from typing import List, Any, Dict

from torchvision.datasets.vision import StandardTransform
from torchvision.transforms import Resize, ToTensor, Normalize, CenterCrop, Compose

from super_gradients.common.object_names import Processings


def get_torchvision_transforms_equivalent_processing(transforms: List[Any]) -> List[Dict[str, Any]]:
    """
    Get the equivalent processing pipeline for torchvision transforms.
    :param transforms: List of torchvision transforms (or a Compose / StandardTransform wrapping them)
    :return: List of Processings operations
    """
    # Since we use cv2.imread to read images, the model is in fact trained on BGR images.
    # The convention in our pipelines is that input images are RGB, so the channels need to be reversed to BGR
    # to match the expected input of the model.
    pipeline = []

    if isinstance(transforms, StandardTransform):
        transforms = transforms.transform

    if isinstance(transforms, Compose):
        transforms = transforms.transforms

    for transform in transforms:
        if isinstance(transform, ToTensor):
            pipeline.append({Processings.StandardizeImage: {"max_value": 255}})
        elif isinstance(transform, Normalize):
            pipeline.append({Processings.NormalizeImage: {"mean": tuple(map(float, transform.mean)), "std": tuple(map(float, transform.std))}})
        elif isinstance(transform, Resize):
            pipeline.append({Processings.Resize: {"size": int(transform.size)}})
        elif isinstance(transform, CenterCrop):
            pipeline.append({Processings.CenterCrop: {"size": int(transform.size)}})
        else:
            raise ValueError(f"Unsupported transform: {transform}")

    pipeline.append({Processings.ImagePermute: {"permutation": (2, 0, 1)}})
    return pipeline
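Note on this hunk: to make the mapping concrete, here is a minimal sketch of what the helper returns for an ImageNet-style evaluation transform (the values in the trailing comment are approximate; the keys are the `Processings` constants used above).

```python
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from super_gradients.training.datasets.classification_datasets.torchvision_utils import (
    get_torchvision_transforms_equivalent_processing,
)

transforms = Compose(
    [
        Resize(256),
        ToTensor(),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

pipeline = get_torchvision_transforms_equivalent_processing(transforms)
# Roughly:
# [{Processings.Resize: {"size": 256}},
#  {Processings.StandardizeImage: {"max_value": 255}},
#  {Processings.NormalizeImage: {"mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)}},
#  {Processings.ImagePermute: {"permutation": (2, 0, 1)}}]
```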
20 changes: 19 additions & 1 deletion src/super_gradients/training/kd_trainer/kd_trainer.py
@@ -1,4 +1,4 @@
from typing import Union, Dict, Mapping, Any
from typing import Union, Dict, Mapping, Any, Optional

import hydra
import torch.nn
@@ -7,6 +7,7 @@

from super_gradients.common import MultiGPUMode, StrictLoad
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.module_interfaces import HasPredict, HasPreprocessingParams
from super_gradients.training import utils as core_utils, models
from super_gradients.training.dataloaders import dataloaders
from super_gradients.common.exceptions.kd_trainer_exceptions import (
@@ -330,3 +331,20 @@ def train(
            valid_loader=valid_loader,
            additional_configs_to_log=additional_configs_to_log,
        )

    def _get_preprocessing_from_valid_loader(self) -> Optional[dict]:
        valid_loader = self.valid_loader

        if isinstance(unwrap_model(self.net).student, HasPredict) and isinstance(valid_loader.dataset, HasPreprocessingParams):
            try:
                return valid_loader.dataset.get_dataset_preprocessing_params()
            except Exception as e:
                logger.warning(
                    f"Could not set preprocessing pipeline from the validation dataset:\n {e}.\n Before calling "
                    "predict, make sure to call set_dataset_processing_params."
                )

    def _maybe_set_preprocessing_params_for_model_from_dataset(self):
        processing_params = self._get_preprocessing_from_valid_loader()
        if processing_params is not None:
            unwrap_model(self.net).student.set_dataset_processing_params(**processing_params)
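Note on this hunk: the new `_maybe_set_preprocessing_params_for_model_from_dataset` override mirrors the `Trainer` hook added in `sg_trainer.py` below, but applies the params to the *student* inside the KD wrapper rather than to the wrapper itself. A rough standalone sketch of the effect, assuming the student implements `HasPredict` (as recent SuperGradients classification models do) and that `cifar10_val` is the usual dataloader factory:

```python
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import cifar10_val  # assumed dataloader factory

student = models.get("resnet18", num_classes=10)
valid_loader = cifar10_val()

# What KDTrainer now does for us at the start of training, shown standalone:
processing_params = valid_loader.dataset.get_dataset_preprocessing_params()
student.set_dataset_processing_params(**processing_params)

# After (or even without) KD training, the student can run predict() on raw images:
# student.predict("path/to/image.jpg").show()
```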
9 changes: 6 additions & 3 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -1429,9 +1429,7 @@ def forward(self, inputs, targets):
            max_train_batches=self.max_train_batches,
        )

        processing_params = self._get_preprocessing_from_valid_loader()
        if processing_params is not None:
            unwrap_model(self.net).set_dataset_processing_params(**processing_params)
        self._maybe_set_preprocessing_params_for_model_from_dataset()

        try:
            # HEADERS OF THE TRAINING PROGRESS
@@ -1578,6 +1576,11 @@ def forward(self, inputs, targets):
            if not self.ddp_silent_mode:
                self.sg_logger.close()

    def _maybe_set_preprocessing_params_for_model_from_dataset(self):
        processing_params = self._get_preprocessing_from_valid_loader()
        if processing_params is not None:
            unwrap_model(self.net).set_dataset_processing_params(**processing_params)

    def _get_preprocessing_from_valid_loader(self) -> Optional[dict]:
        valid_loader = self.valid_loader
