Skip to content

Commit

Permalink
BaseKeypointsDataset now inherits from HasPreprocessingParams (#1380)
Browse files Browse the repository at this point in the history
* Remove Protocol inheritance and replace with Mixins

* assert isinstance -> self.assertIsInstance
  • Loading branch information
BloodAxe committed Aug 21, 2023
1 parent 15802c5 commit 8c7dc64
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 17 deletions.
19 changes: 9 additions & 10 deletions src/super_gradients/module_interfaces/module_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Callable, Optional
from typing import Callable, Optional, TYPE_CHECKING

from torch import nn
from typing_extensions import Protocol, runtime_checkable

from super_gradients.training.processing.processing import Processing
if TYPE_CHECKING:
# This is a hack to avoid circular imports while still having type hints.
from super_gradients.training.processing.processing import Processing


@runtime_checkable
class HasPreprocessingParams(Protocol):
class HasPreprocessingParams:
"""
Protocol interface for torch datasets that support getting preprocessing params, later to be passed to a model
that obeys NeedsPreprocessingParams. This interface class serves a purpose of explicitly indicating whether a torch dataset has
Expand All @@ -16,7 +16,7 @@ class HasPreprocessingParams(Protocol):
"""

def get_dataset_preprocessing_params(self):
...
raise NotImplementedError(f"get_dataset_preprocessing_params is not implemented in the derived class {self.__class__.__name__}")


class HasPredict:
Expand All @@ -43,12 +43,11 @@ def get_input_channels(self) -> int:
"""
raise NotImplementedError(f"get_input_channels is not implemented in the derived class {self.__class__.__name__}")

def get_processing_params(self) -> Optional[Processing]:
def get_processing_params(self) -> Optional["Processing"]:
raise NotImplementedError(f"get_processing_params is not implemented in the derived class {self.__class__.__name__}")


@runtime_checkable
class SupportsReplaceNumClasses(Protocol):
class SupportsReplaceNumClasses:
"""
Protocol interface for modules that support replacing the number of classes.
Derived classes should implement the `replace_num_classes` method.
Expand All @@ -69,4 +68,4 @@ def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable
It takes existing nn.Module and returns a new one.
:return: None
"""
...
raise NotImplementedError(f"replace_num_classes is not implemented in the derived class {self.__class__.__name__}")
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from super_gradients.common.object_names import Datasets, Processings
from super_gradients.common.registry.registry import register_dataset
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training.utils.detection_utils import get_class_index_in_target
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.transforms.transforms import DetectionTransform, DetectionTargetsFormatTransform, DetectionTargetsFormat
Expand All @@ -30,7 +31,7 @@


@register_dataset(Datasets.DETECTION_DATASET)
class DetectionDataset(Dataset):
class DetectionDataset(Dataset, HasPreprocessingParams):
"""Detection dataset.
This is a boilerplate class to facilitate the implementation of datasets.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.object_names import Processings
from super_gradients.common.registry.registry import register_collate_function
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training.datasets.pose_estimation_datasets.target_generators import KeypointsTargetsGenerator
from super_gradients.training.transforms.keypoint_transforms import KeypointsCompose, KeypointTransform
from super_gradients.training.utils.visualization.utils import generate_color_mapping

logger = get_logger(__name__)


class BaseKeypointsDataset(Dataset):
class BaseKeypointsDataset(Dataset, HasPreprocessingParams):
"""
Base class for pose estimation datasets.
Descendants should implement the load_sample method to read a sample from the disk and return (image, mask, joints, extras) tuple.
Expand Down Expand Up @@ -116,7 +117,7 @@ def get_dataset_preprocessing_params(self):
"""
pipeline = self.transforms.get_equivalent_preprocessing()
params = dict(
conf=0.25,
conf=0.05,
image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
edge_links=self.edge_links,
edge_colors=self.edge_colors,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def get_dataset_preprocessing_params(self):
# to match with the expected input of the model.
pipeline = [Processings.ReverseImageChannels] + self.transforms.get_equivalent_preprocessing()
params = dict(
conf=0.25,
conf=0.05,
image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
edge_links=self.edge_links,
edge_colors=self.edge_colors,
Expand Down
18 changes: 15 additions & 3 deletions tests/unit_tests/preprocessing_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from pathlib import Path

import numpy as np
import torch

from super_gradients import Trainer
from super_gradients.common.factories.list_factory import ListFactory
from super_gradients.common.factories.processing_factory import ProcessingFactory
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training import dataloaders
from super_gradients.training import models
from super_gradients.training.datasets import COCODetectionDataset
from super_gradients.training.metrics import DetectionMetrics
Expand All @@ -20,7 +23,6 @@
)
from super_gradients.training.transforms import DetectionPaddedRescale, DetectionRGB2BGR
from super_gradients.training.utils.detection_utils import DetectionCollateFN, CrowdDetectionCollateFN
from super_gradients.training import dataloaders


class PreprocessingUnitTest(unittest.TestCase):
Expand Down Expand Up @@ -84,10 +86,12 @@ def test_setting_preprocessing_params_from_validation_set(self):
],
}
trainset = COCODetectionDataset(**train_dataset_params)
train_loader = dataloaders.get(dataset=trainset, dataloader_params={"collate_fn": DetectionCollateFN()})
self.assertIsInstance(trainset, HasPreprocessingParams)
train_loader = dataloaders.get(dataset=trainset, dataloader_params={"collate_fn": DetectionCollateFN(), "num_workers": 0})

valset = COCODetectionDataset(**val_dataset_params)
valid_loader = dataloaders.get(dataset=valset, dataloader_params={"collate_fn": CrowdDetectionCollateFN()})
self.assertIsInstance(valset, HasPreprocessingParams)
valid_loader = dataloaders.get(dataset=valset, dataloader_params={"collate_fn": CrowdDetectionCollateFN(), "num_workers": 0})

trainer = Trainer("test_setting_preprocessing_params_from_validation_set")

Expand Down Expand Up @@ -119,6 +123,10 @@ def test_setting_preprocessing_params_from_validation_set(self):
self.assertEqual(model._default_nms_iou, 0.65)
self.assertEqual(model._default_nms_conf, 0.5)

checkpoint_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_best.pth")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
self.assertTrue("processing_params" in checkpoint)

def test_setting_preprocessing_params_from_checkpoint(self):
model = models.get("yolox_s", num_classes=80)
self.assertTrue(model._image_processor is None)
Expand Down Expand Up @@ -187,6 +195,10 @@ def test_setting_preprocessing_params_from_checkpoint(self):
self.assertEqual(model._default_nms_iou, 0.65)
self.assertEqual(model._default_nms_conf, 0.5)

checkpoint_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_best.pth")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
self.assertTrue("processing_params" in checkpoint)

def test_processings_from_dataset_params(self):
transforms = [DetectionRGB2BGR(prob=1), DetectionPaddedRescale(input_dim=(512, 512))]

Expand Down

0 comments on commit 8c7dc64

Please sign in to comment.