diff --git a/src/super_gradients/module_interfaces/module_interfaces.py b/src/super_gradients/module_interfaces/module_interfaces.py index 6b9544dab4..a2d00ab1ba 100644 --- a/src/super_gradients/module_interfaces/module_interfaces.py +++ b/src/super_gradients/module_interfaces/module_interfaces.py @@ -1,13 +1,13 @@ -from typing import Callable, Optional +from typing import Callable, Optional, TYPE_CHECKING from torch import nn -from typing_extensions import Protocol, runtime_checkable -from super_gradients.training.processing.processing import Processing +if TYPE_CHECKING: + # This is a hack to avoid circular imports while still having type hints. + from super_gradients.training.processing.processing import Processing -@runtime_checkable -class HasPreprocessingParams(Protocol): +class HasPreprocessingParams: """ Protocol interface for torch datasets that support getting preprocessing params, later to be passed to a model that obeys NeedsPreprocessingParams. This interface class serves a purpose of explicitly indicating whether a torch dataset has @@ -16,7 +16,7 @@ class HasPreprocessingParams(Protocol): """ def get_dataset_preprocessing_params(self): - ... + raise NotImplementedError(f"get_dataset_preprocessing_params is not implemented in the derived class {self.__class__.__name__}") class HasPredict: @@ -43,12 +43,11 @@ def get_input_channels(self) -> int: """ raise NotImplementedError(f"get_input_channels is not implemented in the derived class {self.__class__.__name__}") - def get_processing_params(self) -> Optional[Processing]: + def get_processing_params(self) -> Optional["Processing"]: raise NotImplementedError(f"get_processing_params is not implemented in the derived class {self.__class__.__name__}") -@runtime_checkable -class SupportsReplaceNumClasses(Protocol): +class SupportsReplaceNumClasses: """ Protocol interface for modules that support replacing the number of classes. Derived classes should implement the `replace_num_classes` method. @@ -69,4 +68,4 @@ def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable It takes existing nn.Module and returns a new one. :return: None """ - ... + raise NotImplementedError(f"replace_num_classes is not implemented in the derived class {self.__class__.__name__}") diff --git a/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py b/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py index 75f2764b50..352a34247b 100644 --- a/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py +++ b/src/super_gradients/training/datasets/detection_datasets/detection_dataset.py @@ -16,6 +16,7 @@ from super_gradients.common.object_names import Datasets, Processings from super_gradients.common.registry.registry import register_dataset from super_gradients.common.decorators.factory_decorator import resolve_param +from super_gradients.module_interfaces import HasPreprocessingParams from super_gradients.training.utils.detection_utils import get_class_index_in_target from super_gradients.common.abstractions.abstract_logger import get_logger from super_gradients.training.transforms.transforms import DetectionTransform, DetectionTargetsFormatTransform, DetectionTargetsFormat @@ -30,7 +31,7 @@ @register_dataset(Datasets.DETECTION_DATASET) -class DetectionDataset(Dataset): +class DetectionDataset(Dataset, HasPreprocessingParams): """Detection dataset. This is a boilerplate class to facilitate the implementation of datasets. diff --git a/src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py b/src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py index 062f0b57d4..8054641428 100644 --- a/src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py +++ b/src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py @@ -8,6 +8,7 @@ from super_gradients.common.abstractions.abstract_logger import get_logger from super_gradients.common.object_names import Processings from super_gradients.common.registry.registry import register_collate_function +from super_gradients.module_interfaces import HasPreprocessingParams from super_gradients.training.datasets.pose_estimation_datasets.target_generators import KeypointsTargetsGenerator from super_gradients.training.transforms.keypoint_transforms import KeypointsCompose, KeypointTransform from super_gradients.training.utils.visualization.utils import generate_color_mapping @@ -15,7 +16,7 @@ logger = get_logger(__name__) -class BaseKeypointsDataset(Dataset): +class BaseKeypointsDataset(Dataset, HasPreprocessingParams): """ Base class for pose estimation datasets. Descendants should implement the load_sample method to read a sample from the disk and return (image, mask, joints, extras) tuple. @@ -116,7 +117,7 @@ def get_dataset_preprocessing_params(self): """ pipeline = self.transforms.get_equivalent_preprocessing() params = dict( - conf=0.25, + conf=0.05, image_processor={Processings.ComposeProcessing: {"processings": pipeline}}, edge_links=self.edge_links, edge_colors=self.edge_colors, diff --git a/src/super_gradients/training/datasets/pose_estimation_datasets/coco_keypoints.py b/src/super_gradients/training/datasets/pose_estimation_datasets/coco_keypoints.py index 77ee826dc8..7f71fe07af 100644 --- a/src/super_gradients/training/datasets/pose_estimation_datasets/coco_keypoints.py +++ b/src/super_gradients/training/datasets/pose_estimation_datasets/coco_keypoints.py @@ -216,7 +216,7 @@ def get_dataset_preprocessing_params(self): # to match with the expected input of the model. pipeline = [Processings.ReverseImageChannels] + self.transforms.get_equivalent_preprocessing() params = dict( - conf=0.25, + conf=0.05, image_processor={Processings.ComposeProcessing: {"processings": pipeline}}, edge_links=self.edge_links, edge_colors=self.edge_colors, diff --git a/tests/unit_tests/preprocessing_unit_test.py b/tests/unit_tests/preprocessing_unit_test.py index 0360c761b5..9416f309ba 100644 --- a/tests/unit_tests/preprocessing_unit_test.py +++ b/tests/unit_tests/preprocessing_unit_test.py @@ -3,10 +3,13 @@ from pathlib import Path import numpy as np +import torch from super_gradients import Trainer from super_gradients.common.factories.list_factory import ListFactory from super_gradients.common.factories.processing_factory import ProcessingFactory +from super_gradients.module_interfaces import HasPreprocessingParams +from super_gradients.training import dataloaders from super_gradients.training import models from super_gradients.training.datasets import COCODetectionDataset from super_gradients.training.metrics import DetectionMetrics @@ -20,7 +23,6 @@ ) from super_gradients.training.transforms import DetectionPaddedRescale, DetectionRGB2BGR from super_gradients.training.utils.detection_utils import DetectionCollateFN, CrowdDetectionCollateFN -from super_gradients.training import dataloaders class PreprocessingUnitTest(unittest.TestCase): @@ -84,10 +86,12 @@ def test_setting_preprocessing_params_from_validation_set(self): ], } trainset = COCODetectionDataset(**train_dataset_params) - train_loader = dataloaders.get(dataset=trainset, dataloader_params={"collate_fn": DetectionCollateFN()}) + self.assertIsInstance(trainset, HasPreprocessingParams) + train_loader = dataloaders.get(dataset=trainset, dataloader_params={"collate_fn": DetectionCollateFN(), "num_workers": 0}) valset = COCODetectionDataset(**val_dataset_params) - valid_loader = dataloaders.get(dataset=valset, dataloader_params={"collate_fn": CrowdDetectionCollateFN()}) + self.assertIsInstance(valset, HasPreprocessingParams) + valid_loader = dataloaders.get(dataset=valset, dataloader_params={"collate_fn": CrowdDetectionCollateFN(), "num_workers": 0}) trainer = Trainer("test_setting_preprocessing_params_from_validation_set") @@ -119,6 +123,10 @@ def test_setting_preprocessing_params_from_validation_set(self): self.assertEqual(model._default_nms_iou, 0.65) self.assertEqual(model._default_nms_conf, 0.5) + checkpoint_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_best.pth") + checkpoint = torch.load(checkpoint_path, map_location="cpu") + self.assertTrue("processing_params" in checkpoint) + def test_setting_preprocessing_params_from_checkpoint(self): model = models.get("yolox_s", num_classes=80) self.assertTrue(model._image_processor is None) @@ -187,6 +195,10 @@ def test_setting_preprocessing_params_from_checkpoint(self): self.assertEqual(model._default_nms_iou, 0.65) self.assertEqual(model._default_nms_conf, 0.5) + checkpoint_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_best.pth") + checkpoint = torch.load(checkpoint_path, map_location="cpu") + self.assertTrue("processing_params" in checkpoint) + def test_processings_from_dataset_params(self): transforms = [DetectionRGB2BGR(prob=1), DetectionPaddedRescale(input_dim=(512, 512))]