Skip to content

Commit

Permalink
Update KD Notebook for classification (#1595)
Browse files Browse the repository at this point in the history
* Update what_are_recipes_and_how_to_use notebook

* Improve notebook version check script and added what_are_recipes_and_how_to_use to the list of checked notebooks

* Move import of nbformat inside get_first_cell_content method

* Update KD notebook, make it using predict()
Fixed Cifar/Imagenet datasets that using torchvision transforms to support HasPreprocessingParams

* Added Resize op to dataset params

* Added notebook

* Added test_get_torchvision_transforms_equivalent_processing & Fixed Resize implementation to match the implementation from Torchvision

* Update notebook
  • Loading branch information
BloodAxe committed Nov 5, 2023
1 parent 24102a8 commit 885e2f0
Show file tree
Hide file tree
Showing 14 changed files with 793 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ sweeper_test:

# Here you define a list of notebooks we want to execute and convert to markdown files
# NOTEBOOKS = hellomake.ipynb hellofunc.ipynb helloclass.ipynb
NOTEBOOKS = src/super_gradients/examples/model_export/models_export.ipynb notebooks/what_are_recipes_and_how_to_use.ipynb notebooks/transfer_learning_classification.ipynb
NOTEBOOKS = src/super_gradients/examples/model_export/models_export.ipynb notebooks/what_are_recipes_and_how_to_use.ipynb notebooks/transfer_learning_classification.ipynb notebooks/how_to_use_knowledge_distillation_for_classification.ipynb

# This Makefile target runs notebooks listed below and converts them to markdown files in documentation/source/
run_and_convert_notebooks_to_docs: $(NOTEBOOKS)
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,9 @@ Knowledge Distillation is a training technique that uses a large model, teacher
Learn more about SuperGradients knowledge distillation training with our pre-trained BEiT base teacher model and Resnet18 student model on CIFAR10 example notebook on Google Colab for an easy to use tutorial using free GPU hardware
<table class="tfo-notebook-buttons" align="left">
<td width="500">
<a target="_blank" href="https://bit.ly/3BLA5oR"><img src="./documentation/assets/SG_img/colab_logo.png" /> Knowledge Distillation Training</a>
<a target="_blank" href="https://colab.research.google.com/github/Deci-AI/super-gradients/blob/master/notebooks/how_to_use_knowledge_distillation_for_classification.ipynb">
<img src="./documentation/assets/SG_img/colab_logo.png" /> Knowledge Distillation Training
</a>
</td>
</table>
</br></br>
Expand Down
612 changes: 612 additions & 0 deletions notebooks/how_to_use_knowledge_distillation_for_classification.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ val_dataset_params:
root: /data/cifar100
train: False
transforms:
- Resize:
size: 32
- ToTensor
- Normalize:
mean:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ val_dataset_params:
root: ./data/cifar10
train: False
transforms:
- Resize:
size: 32
- ToTensor
- Normalize:
mean:
Expand Down
2 changes: 1 addition & 1 deletion src/super_gradients/recipes/imagenet_resnet50_kd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ training_hyperparams:
criterion_params:
distillation_loss_coeff: 0.8
task_loss_fn:
_target_: super_gradients.training.losses.label_smoothing_cross_entropy_loss.LabelSmoothingCrossEntropyLoss
_target_: super_gradients.training.losses.label_smoothing_cross_entropy_loss.CrossEntropyLoss

arch_params:
teacher_input_adapter:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from typing import Optional, Callable, Union
from typing import Optional, Callable, Union, Dict

from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import Compose

from super_gradients.common.object_names import Datasets
from super_gradients.common.object_names import Datasets, Processings
from super_gradients.common.registry.registry import register_dataset
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training.datasets.classification_datasets.torchvision_utils import get_torchvision_transforms_equivalent_processing


@register_dataset(Datasets.CIFAR_10)
class Cifar10(CIFAR10):
class Cifar10(CIFAR10, HasPreprocessingParams):
"""
CIFAR10 Dataset
Expand Down Expand Up @@ -43,9 +45,23 @@ def __init__(
download=download,
)

def get_dataset_preprocessing_params(self) -> Dict:
"""
Get the preprocessing params for the dataset.
It infers preprocessing params from transforms used in the dataset & class names
:return: (dict) Preprocessing params
"""

pipeline = get_torchvision_transforms_equivalent_processing(self.transforms)
params = dict(
image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
class_names=self.classes,
)
return params


@register_dataset(Datasets.CIFAR_100)
class Cifar100(CIFAR100):
class Cifar100(CIFAR100, HasPreprocessingParams):
@resolve_param("transforms", TransformsFactory())
def __init__(
self,
Expand Down Expand Up @@ -76,3 +92,17 @@ def __init__(
target_transform=target_transform,
download=download,
)

def get_dataset_preprocessing_params(self) -> Dict:
"""
Get the preprocessing params for the dataset.
It infers preprocessing params from transforms used in the dataset & class names
:return: (dict) Preprocessing params
"""

pipeline = get_torchvision_transforms_equivalent_processing(self.transforms)
params = dict(
image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
class_names=self.classes,
)
return params
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from typing import Union
from typing import Union, Dict

import torchvision.datasets as torch_datasets
from torchvision.transforms import Compose

from super_gradients.common.registry.registry import register_dataset
from super_gradients.common.object_names import Datasets
from super_gradients.common.object_names import Datasets, Processings
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training.datasets.classification_datasets.torchvision_utils import get_torchvision_transforms_equivalent_processing


@register_dataset(Datasets.IMAGENET_DATASET)
class ImageNetDataset(torch_datasets.ImageFolder):
class ImageNetDataset(torch_datasets.ImageFolder, HasPreprocessingParams):
"""ImageNetDataset dataset.
To use this Dataset you need to:
Expand Down Expand Up @@ -41,3 +43,17 @@ def __init__(self, root: str, transforms: Union[list, dict] = [], *args, **kwarg
if isinstance(transforms, list):
transforms = Compose(transforms)
super(ImageNetDataset, self).__init__(root, transform=transforms, *args, **kwargs)

def get_dataset_preprocessing_params(self) -> Dict:
"""
Get the preprocessing params for the dataset.
It infers preprocessing params from transforms used in the dataset & class names
:return: (dict) Preprocessing params
"""

pipeline = get_torchvision_transforms_equivalent_processing(self.transforms)
params = dict(
image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
class_names=self.classes,
)
return params
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import List, Any, Dict

from torchvision.datasets.vision import StandardTransform
from torchvision.transforms import Resize, ToTensor, Normalize, CenterCrop, Compose

from super_gradients.common.object_names import Processings


def get_torchvision_transforms_equivalent_processing(transforms: List[Any]) -> List[Dict[str, Any]]:
"""
Get the equivalent processing pipeline for torchvision transforms.
:return: List of Processings operations
"""
# Since we are using cv2.imread to read images, our model in fact is trained on BGR images.
# In our pipelines the convention that input images are RGB, so we need to reverse the channels to get BGR
# to match with the expected input of the model.
pipeline = []

if isinstance(transforms, StandardTransform):
transforms = transforms.transform

if isinstance(transforms, Compose):
transforms = transforms.transforms

for transform in transforms:
if isinstance(transform, ToTensor):
pipeline.append({Processings.StandardizeImage: {"max_value": 255}})
elif isinstance(transform, Normalize):
pipeline.append({Processings.NormalizeImage: {"mean": tuple(map(float, transform.mean)), "std": tuple(map(float, transform.std))}})
elif isinstance(transform, Resize):
pipeline.append({Processings.Resize: {"size": int(transform.size)}})
elif isinstance(transform, CenterCrop):
pipeline.append({Processings.CenterCrop: {"size": int(transform.size)}})
else:
raise ValueError(f"Unsupported transform: {transform}")

pipeline.append({Processings.ImagePermute: {"permutation": (2, 0, 1)}})
return pipeline
20 changes: 19 additions & 1 deletion src/super_gradients/training/kd_trainer/kd_trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Dict, Mapping, Any
from typing import Union, Dict, Mapping, Any, Optional

import hydra
import torch.nn
Expand All @@ -7,6 +7,7 @@

from super_gradients.common import MultiGPUMode, StrictLoad
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.module_interfaces import HasPredict, HasPreprocessingParams
from super_gradients.training import utils as core_utils, models
from super_gradients.training.dataloaders import dataloaders
from super_gradients.common.exceptions.kd_trainer_exceptions import (
Expand Down Expand Up @@ -330,3 +331,20 @@ def train(
valid_loader=valid_loader,
additional_configs_to_log=additional_configs_to_log,
)

def _get_preprocessing_from_valid_loader(self) -> Optional[dict]:
valid_loader = self.valid_loader

if isinstance(unwrap_model(self.net).student, HasPredict) and isinstance(valid_loader.dataset, HasPreprocessingParams):
try:
return valid_loader.dataset.get_dataset_preprocessing_params()
except Exception as e:
logger.warning(
f"Could not set preprocessing pipeline from the validation dataset:\n {e}.\n Before calling"
"predict make sure to call set_dataset_processing_params."
)

def _maybe_set_preprocessing_params_for_model_from_dataset(self):
processing_params = self._get_preprocessing_from_valid_loader()
if processing_params is not None:
unwrap_model(self.net).student.set_dataset_processing_params(**processing_params)
25 changes: 12 additions & 13 deletions src/super_gradients/training/processing/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Tuple, List, Union, Optional

import numpy as np
from PIL import Image
from torch import nn

from super_gradients.common.object_names import Processings
Expand All @@ -19,6 +18,7 @@
PaddingCoordinates,
_rescale_keypoints,
_shift_keypoints,
_rescale_image_with_pil,
)
from super_gradients.training.utils.predict import Prediction, DetectionPrediction, PoseEstimationPrediction

Expand Down Expand Up @@ -451,22 +451,21 @@ def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
:param image: Image, in (H, W, C) format.
:return: The resized image.
"""
image = Image.fromarray(image)
resized_image = image.resize((self.size, self.size))
resized_image = np.array(resized_image)
height, width = image.shape[:2]
output_shape = self.size, self.size
scale_factor = max(output_shape[0] / height, output_shape[1] / width)

return resized_image, None
if scale_factor != 1.0:
new_height, new_width = int(height * scale_factor), int(width * scale_factor)
image = _rescale_image_with_pil(image, target_shape=(new_height, new_width))

def get_equivalent_photometric_module(self) -> Optional[nn.Module]:
return None
return image, RescaleMetadata(original_shape=(height, width), scale_factor_h=scale_factor, scale_factor_w=scale_factor)

def infer_image_input_shape(self) -> Optional[Tuple[int, int]]:
"""
Infer the output image shape from the processing.
def get_equivalent_photometric_module(self) -> None:
return None

:return: (rows, cols) Returns the last known output shape for all the processings.
"""
return (self.size, self.size)
def infer_image_input_shape(self) -> None:
return None


@register_processing(Processings.CenterCrop)
Expand Down
9 changes: 6 additions & 3 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,9 +1429,7 @@ def forward(self, inputs, targets):
max_train_batches=self.max_train_batches,
)

processing_params = self._get_preprocessing_from_valid_loader()
if processing_params is not None:
unwrap_model(self.net).set_dataset_processing_params(**processing_params)
self._maybe_set_preprocessing_params_for_model_from_dataset()

try:
# HEADERS OF THE TRAINING PROGRESS
Expand Down Expand Up @@ -1578,6 +1576,11 @@ def forward(self, inputs, targets):
if not self.ddp_silent_mode:
self.sg_logger.close()

def _maybe_set_preprocessing_params_for_model_from_dataset(self):
processing_params = self._get_preprocessing_from_valid_loader()
if processing_params is not None:
unwrap_model(self.net).set_dataset_processing_params(**processing_params)

def _get_preprocessing_from_valid_loader(self) -> Optional[dict]:
valid_loader = self.valid_loader

Expand Down
17 changes: 17 additions & 0 deletions src/super_gradients/training/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@ def _rescale_image(image: np.ndarray, target_shape: Tuple[int, int]) -> np.ndarr
return cv2.resize(image, dsize=(width, height), interpolation=cv2.INTER_LINEAR)


def _rescale_image_with_pil(image: np.ndarray, target_shape: Tuple[int, int]) -> np.ndarray:
"""Rescale image to target_shape, without preserving aspect ratio using PIL.
OpenCV and PIL has slightly different implementations of interpolation methods.
OpenCV has faster resizing, however PIL is more accurate (not introducing aliasing artifacts).
We use this method in some preprocessing transforms where we want to keep the compatibility with
torchvision transforms.
:param image: Image to rescale. (H, W, C) or (H, W).
:param target_shape: Target shape to rescale to (H, W).
:return: Rescaled image.
"""
height, width = target_shape[:2]
from PIL import Image

return np.array(Image.fromarray(image).resize((width, height), Image.BILINEAR))


def _rescale_bboxes(targets: np.ndarray, scale_factors: Tuple[float, float]) -> np.ndarray:
"""Rescale bboxes to given scale factors, without preserving aspect ratio.
Expand Down
26 changes: 26 additions & 0 deletions tests/unit_tests/preprocessing_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import torch
import torchvision as tv

from super_gradients import Trainer
from super_gradients.common.factories.list_factory import ListFactory
Expand All @@ -12,6 +13,7 @@
from super_gradients.training import dataloaders
from super_gradients.training import models
from super_gradients.training.datasets import COCODetectionDataset
from super_gradients.training.datasets.classification_datasets.torchvision_utils import get_torchvision_transforms_equivalent_processing
from super_gradients.training.metrics import DetectionMetrics
from super_gradients.training.models import YoloXPostPredictionCallback
from super_gradients.training.processing import (
Expand Down Expand Up @@ -211,6 +213,30 @@ def test_processings_from_dataset_params(self):
result = processing_pipeline.preprocess_image(np.zeros((480, 640, 3)))
print(result)

def test_get_torchvision_transforms_equivalent_processing(self):
from PIL import Image

tv_transforms = tv.transforms.Compose(
[
tv.transforms.Resize(512),
tv.transforms.ToTensor(),
tv.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
]
)

input = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)

expected_output = tv_transforms(Image.fromarray(input)).numpy()

processing = get_torchvision_transforms_equivalent_processing(tv_transforms)

instantiated_processing = ListFactory(ProcessingFactory()).get(processing)
processing_pipeline = ComposeProcessing(instantiated_processing)
actual_output = processing_pipeline.preprocess_image(input)[0]

self.assertEqual(actual_output.shape, expected_output.shape)
np.testing.assert_allclose(actual_output, expected_output, atol=1e-5)


if __name__ == "__main__":
unittest.main()

0 comments on commit 885e2f0

Please sign in to comment.