Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BaseKeypointsDataset now inherits from HasPreprocessingParams #1380

Merged
merged 2 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions src/super_gradients/module_interfaces/module_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Callable, Optional
from typing import Callable, Optional, TYPE_CHECKING

from torch import nn
from typing_extensions import Protocol, runtime_checkable

from super_gradients.training.processing.processing import Processing
if TYPE_CHECKING:
# This is a hack to avoid circular imports while still having type hints.
from super_gradients.training.processing.processing import Processing


@runtime_checkable
class HasPreprocessingParams(Protocol):
class HasPreprocessingParams:
"""
Protocol interface for torch datasets that support getting preprocessing params, later to be passed to a model
that obeys NeedsPreprocessingParams. This interface class serves a purpose of explicitly indicating whether a torch dataset has
Expand All @@ -16,7 +16,7 @@ class HasPreprocessingParams(Protocol):
"""

def get_dataset_preprocessing_params(self):
...
raise NotImplementedError(f"get_dataset_preprocessing_params is not implemented in the derived class {self.__class__.__name__}")


class HasPredict:
Expand All @@ -43,12 +43,11 @@ def get_input_channels(self) -> int:
"""
raise NotImplementedError(f"get_input_channels is not implemented in the derived class {self.__class__.__name__}")

def get_processing_params(self) -> Optional[Processing]:
def get_processing_params(self) -> Optional["Processing"]:
raise NotImplementedError(f"get_processing_params is not implemented in the derived class {self.__class__.__name__}")


@runtime_checkable
class SupportsReplaceNumClasses(Protocol):
class SupportsReplaceNumClasses:
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
"""
Protocol interface for modules that support replacing the number of classes.
Derived classes should implement the `replace_num_classes` method.
Expand All @@ -69,4 +68,4 @@ def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable
It takes existing nn.Module and returns a new one.
:return: None
"""
...
raise NotImplementedError(f"replace_num_classes is not implemented in the derived class {self.__class__.__name__}")
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from super_gradients.common.object_names import Datasets, Processings
from super_gradients.common.registry.registry import register_dataset
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training.utils.detection_utils import get_class_index_in_target
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.transforms.transforms import DetectionTransform, DetectionTargetsFormatTransform, DetectionTargetsFormat
Expand All @@ -30,7 +31,7 @@


@register_dataset(Datasets.DETECTION_DATASET)
class DetectionDataset(Dataset):
class DetectionDataset(Dataset, HasPreprocessingParams):
"""Detection dataset.
This is a boilerplate class to facilitate the implementation of datasets.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.object_names import Processings
from super_gradients.common.registry.registry import register_collate_function
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training.datasets.pose_estimation_datasets.target_generators import KeypointsTargetsGenerator
from super_gradients.training.transforms.keypoint_transforms import KeypointsCompose, KeypointTransform
from super_gradients.training.utils.visualization.utils import generate_color_mapping

logger = get_logger(__name__)


class BaseKeypointsDataset(Dataset):
class BaseKeypointsDataset(Dataset, HasPreprocessingParams):
"""
Base class for pose estimation datasets.
Descendants should implement the load_sample method to read a sample from the disk and return (image, mask, joints, extras) tuple.
Expand Down Expand Up @@ -116,7 +117,7 @@ def get_dataset_preprocessing_params(self):
"""
pipeline = self.transforms.get_equivalent_preprocessing()
params = dict(
conf=0.25,
conf=0.05,
image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
edge_links=self.edge_links,
edge_colors=self.edge_colors,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def get_dataset_preprocessing_params(self):
# to match with the expected input of the model.
pipeline = [Processings.ReverseImageChannels] + self.transforms.get_equivalent_preprocessing()
params = dict(
conf=0.25,
conf=0.05,
image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
edge_links=self.edge_links,
edge_colors=self.edge_colors,
Expand Down
18 changes: 15 additions & 3 deletions tests/unit_tests/preprocessing_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from pathlib import Path

import numpy as np
import torch

from super_gradients import Trainer
from super_gradients.common.factories.list_factory import ListFactory
from super_gradients.common.factories.processing_factory import ProcessingFactory
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training import dataloaders
from super_gradients.training import models
from super_gradients.training.datasets import COCODetectionDataset
from super_gradients.training.metrics import DetectionMetrics
Expand All @@ -20,7 +23,6 @@
)
from super_gradients.training.transforms import DetectionPaddedRescale, DetectionRGB2BGR
from super_gradients.training.utils.detection_utils import DetectionCollateFN, CrowdDetectionCollateFN
from super_gradients.training import dataloaders


class PreprocessingUnitTest(unittest.TestCase):
Expand Down Expand Up @@ -84,10 +86,12 @@ def test_setting_preprocessing_params_from_validation_set(self):
],
}
trainset = COCODetectionDataset(**train_dataset_params)
train_loader = dataloaders.get(dataset=trainset, dataloader_params={"collate_fn": DetectionCollateFN()})
self.assertIsInstance(trainset, HasPreprocessingParams)
train_loader = dataloaders.get(dataset=trainset, dataloader_params={"collate_fn": DetectionCollateFN(), "num_workers": 0})

valset = COCODetectionDataset(**val_dataset_params)
valid_loader = dataloaders.get(dataset=valset, dataloader_params={"collate_fn": CrowdDetectionCollateFN()})
self.assertIsInstance(valset, HasPreprocessingParams)
valid_loader = dataloaders.get(dataset=valset, dataloader_params={"collate_fn": CrowdDetectionCollateFN(), "num_workers": 0})

trainer = Trainer("test_setting_preprocessing_params_from_validation_set")

Expand Down Expand Up @@ -119,6 +123,10 @@ def test_setting_preprocessing_params_from_validation_set(self):
self.assertEqual(model._default_nms_iou, 0.65)
self.assertEqual(model._default_nms_conf, 0.5)

checkpoint_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_best.pth")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
self.assertTrue("processing_params" in checkpoint)

def test_setting_preprocessing_params_from_checkpoint(self):
model = models.get("yolox_s", num_classes=80)
self.assertTrue(model._image_processor is None)
Expand Down Expand Up @@ -187,6 +195,10 @@ def test_setting_preprocessing_params_from_checkpoint(self):
self.assertEqual(model._default_nms_iou, 0.65)
self.assertEqual(model._default_nms_conf, 0.5)

checkpoint_path = os.path.join(trainer.checkpoints_dir_path, "ckpt_best.pth")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
self.assertTrue("processing_params" in checkpoint)

def test_processings_from_dataset_params(self):
transforms = [DetectionRGB2BGR(prob=1), DetectionPaddedRescale(input_dim=(512, 512))]

Expand Down