Skip to content

Commit

Permalink
Remove __target__ for Detection CollateFN (#1470)
Browse files Browse the repository at this point in the history
* remove __target__

* undo

* fix

* add __init__ __all__

* fix __call__ in CrowdDetectionCollateFN + use main import

---------

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
  • Loading branch information
Louis-Dupont and BloodAxe committed Oct 3, 2023
1 parent 98f3226 commit 94c0f9e
Show file tree
Hide file tree
Showing 20 changed files with 289 additions and 188 deletions.
6 changes: 2 additions & 4 deletions documentation/source/ObjectDetection.md
Original file line number Diff line number Diff line change
Expand Up @@ -500,8 +500,7 @@ train_dataloader_params:
worker_init_fn:
_target_: super_gradients.training.utils.utils.load_func
dotpath: super_gradients.training.datasets.datasets_utils.worker_init_reset_seed
collate_fn:
_target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
collate_fn: DetectionCollateFN

val_dataset_params:
data_dir: ${dataset_params.root_dir}
Expand All @@ -521,8 +520,7 @@ val_dataloader_params:
num_workers: 8
drop_last: True
pin_memory: True
collate_fn:
_target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
collate_fn: DetectionCollateFN
```

In your training recipe add/change the following lines to:
Expand Down
43 changes: 43 additions & 0 deletions src/super_gradients/common/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from .dataset_exceptions import (
EmptyDatasetException,
DatasetItemsException,
DatasetValidationException,
IllegalDatasetParameterException,
ParameterMismatchException,
UnsupportedBatchItemsFormat,
)
from .loss_exceptions import RequiredLossComponentReductionException, IllegalRangeForLossAttributeException
from .factory_exceptions import UnknownTypeException
from .kd_trainer_exceptions import (
KDModelException,
UnsupportedKDModelArgException,
UnsupportedKDArchitectureException,
ArchitectureKwargsException,
InconsistentParamsException,
TeacherKnowledgeException,
UndefinedNumClassesException,
)
from .sg_trainer_exceptions import IllegalDataloaderInitialization, UnsupportedOptimizerFormat, UnsupportedTrainingParameterFormat, GPUModeNotSetupError

# Public API of the exceptions package: each name listed here is imported from one of the
# sibling exception modules above and re-exported, so it can be imported from the package itself.
__all__ = [
    "EmptyDatasetException",
    "DatasetItemsException",
    "DatasetValidationException",
    "IllegalDatasetParameterException",
    "ParameterMismatchException",
    "UnsupportedBatchItemsFormat",
    "RequiredLossComponentReductionException",
    "IllegalRangeForLossAttributeException",
    "UnknownTypeException",
    "KDModelException",
    "UnsupportedKDModelArgException",
    "UnsupportedKDArchitectureException",
    "ArchitectureKwargsException",
    "InconsistentParamsException",
    "TeacherKnowledgeException",
    "UndefinedNumClassesException",
    "IllegalDataloaderInitialization",
    "UnsupportedOptimizerFormat",
    "UnsupportedTrainingParameterFormat",
    "GPUModeNotSetupError",
]
16 changes: 16 additions & 0 deletions src/super_gradients/common/exceptions/dataset_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Tuple, Type


class DatasetValidationException(Exception):
    """Plain `Exception` subclass with no extra behavior; presumably raised on dataset validation failures (confirm against callers)."""

    pass

Expand Down Expand Up @@ -46,3 +49,16 @@ def __init__(self, batch_items: tuple):
"To fix this, please change the implementation of your dataset __getitem__ method, so that it would return the items defined above.\n"
)
super().__init__(self.message)


class DatasetItemsException(Exception):
    """Raised by a collate function when a dataset sample does not contain the items it expects.

    The message names the collate type, the item names it expects, and what was actually received.
    """

    def __init__(self, data_sample: Tuple, collate_type: Type, expected_item_names: Tuple):
        """
        :param data_sample: item(s) returned by a dataset
        :param collate_type: type of the collate that caused the exception
        :param expected_item_names: tuple of names of items that are expected by the collate to be returned from the dataset
        """
        collate_type_name = collate_type.__name__

        # Describe what was really received: a sequence sample is reported with its actual length,
        # anything else is a single item and must not be described as "a tuple".
        if isinstance(data_sample, (tuple, list)):
            container_name = "tuple" if isinstance(data_sample, tuple) else "list"
            received = f"a {container_name} of len={len(data_sample)}"
        else:
            received = f"a single item of type `{type(data_sample).__name__}`"

        error_msg = f"`{collate_type_name}` only supports Datasets that return a tuple {expected_item_names}, but got {received}"
        super().__init__(error_msg)
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ train_dataloader_params:
worker_init_fn:
_target_: super_gradients.training.utils.utils.load_func
dotpath: super_gradients.training.datasets.datasets_utils.worker_init_reset_seed
collate_fn: # collate function for trainset
_target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
collate_fn: DetectionCollateFN

val_dataset_params:
data_dir: /data/coco # root path to coco data
Expand All @@ -80,8 +79,7 @@ val_dataloader_params:
num_workers: 8
drop_last: False
pin_memory: True
collate_fn: # collate function for valset
_target_: super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN
collate_fn: CrowdDetectionCollateFN


_convert_: all
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ train_dataloader_params:
worker_init_fn:
_target_: super_gradients.training.utils.utils.load_func
dotpath: super_gradients.training.datasets.datasets_utils.worker_init_reset_seed
collate_fn: # collate function for trainset
_target_: super_gradients.training.utils.detection_utils.PPYoloECollateFN
random_resize_sizes: [ 320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768 ]
random_resize_modes:
- 0 # cv::INTER_NEAREST
- 1 # cv::INTER_LINEAR
- 2 # cv::INTER_CUBIC
- 3 # cv::INTER_AREA
- 4 # cv::INTER_LANCZOS4
collate_fn:
PPYoloECollateFN:
random_resize_sizes: [ 320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768 ]
random_resize_modes:
- 0 # cv::INTER_NEAREST
- 1 # cv::INTER_LINEAR
- 2 # cv::INTER_CUBIC
- 3 # cv::INTER_AREA
- 4 # cv::INTER_LANCZOS4

val_dataset_params:
data_dir: /data/coco # root path to coco data
Expand Down Expand Up @@ -93,7 +93,6 @@ val_dataloader_params:
drop_last: False
shuffle: False
pin_memory: False
collate_fn: # collate function for valset
_target_: super_gradients.training.utils.detection_utils.CrowdDetectionPPYoloECollateFN
collate_fn: CrowdDetectionPPYoloECollateFN

_convert_: all
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ train_dataloader_params:
worker_init_fn:
_target_: super_gradients.training.utils.utils.load_func
dotpath: super_gradients.training.datasets.datasets_utils.worker_init_reset_seed
collate_fn: # collate function for trainset
_target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
collate_fn: DetectionCollateFN

val_dataset_params:
data_dir: /data/coco # root path to coco data
Expand All @@ -76,7 +75,6 @@ val_dataloader_params:
num_workers: 8
drop_last: False
pin_memory: True
collate_fn: # collate function for valset
_target_: super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN
collate_fn: CrowdDetectionCollateFN

_convert_: all
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ train_dataloader_params:
shuffle: True
drop_last: True
pin_memory: True
collate_fn:
_target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
collate_fn: DetectionCollateFN

val_dataset_params:
data_dir: /data/coco # TO FILL: Where the data is stored.
Expand Down Expand Up @@ -88,7 +87,6 @@ val_dataloader_params:
num_workers: 8
drop_last: False
pin_memory: True
collate_fn:
_target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
collate_fn: DetectionCollateFN

_convert_: all
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ train_dataloader_params:
shuffle: True
drop_last: True
pin_memory: True
collate_fn:
_target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
collate_fn: DetectionCollateFN

val_dataset_params:
data_dir: /data/coco # root path to coco data
Expand Down Expand Up @@ -86,7 +85,6 @@ val_dataloader_params:
drop_last: False
shuffle: False
pin_memory: True
collate_fn:
_target_: super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN
collate_fn: CrowdDetectionCollateFN

_convert_: all
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,14 @@ train_dataloader_params:
worker_init_fn:
_target_: super_gradients.training.utils.utils.load_func
dotpath: super_gradients.training.datasets.datasets_utils.worker_init_reset_seed
collate_fn: # collate function for trainset
_target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
collate_fn: DetectionCollateFN

val_dataloader_params:
batch_size: 64
num_workers: 8
drop_last: False
pin_memory: True
collate_fn: # collate function for trainset
_target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
collate_fn: DetectionCollateFN


_convert_: all
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ train_dataloader_params:
worker_init_fn:
_target_: super_gradients.training.utils.utils.load_func
dotpath: super_gradients.training.datasets.datasets_utils.worker_init_reset_seed
collate_fn: # collate function for trainset
_target_: super_gradients.training.utils.detection_utils.DetectionCollateFN
collate_fn: DetectionCollateFN

val_dataset_params:
data_dir: ${..data_dir} # root path to Roboflow datasets
Expand Down Expand Up @@ -94,8 +93,7 @@ val_dataloader_params:
drop_last: False
shuffle: False
pin_memory: True
collate_fn: # collate function for valset
_target_: super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN
collate_fn: CrowdDetectionCollateFN


_convert_: all
6 changes: 6 additions & 0 deletions src/super_gradients/training/utils/collate_fn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .detection_collate_fn import DetectionCollateFN
from .ppyoloe_collate_fn import PPYoloECollateFN
from .crowd_detection_collate_fn import CrowdDetectionCollateFN
from .crowd_detection_ppyoloe_collate_fn import CrowdDetectionPPYoloECollateFN

# Public API of the collate_fn package: the four collate classes imported above.
__all__ = ["DetectionCollateFN", "PPYoloECollateFN", "CrowdDetectionCollateFN", "CrowdDetectionPPYoloECollateFN"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Tuple, Dict

import torch

from super_gradients.common.registry import register_collate_function
from super_gradients.common.exceptions.dataset_exceptions import DatasetItemsException
from super_gradients.training.utils.collate_fn.detection_collate_fn import DetectionCollateFN


@register_collate_function()
class CrowdDetectionCollateFN(DetectionCollateFN):
    """
    Collate function for Yolox training with additional_batch_items that includes crowd targets
    """

    def __init__(self):
        super().__init__()
        # Each dataset sample must unpack into exactly these three items.
        self.expected_item_names = ("image", "targets", "crowd_targets")

    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Collate a batch of `(image, targets, crowd_targets)` samples.

        :param data: sequence of samples returned by the dataset
        :return: tuple of (images, targets, {"crowd_targets": crowd targets}), formatted by the parent class helpers
        :raises DatasetItemsException: if the samples cannot be unpacked into the three expected items
        """
        try:
            images_batch, labels_batch, crowd_labels_batch = zip(*data)
        except (ValueError, TypeError) as err:
            # Chain the original error so the underlying unpacking failure stays visible in the traceback.
            raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names) from err

        return self._format_images(images_batch), self._format_targets(labels_batch), {"crowd_targets": self._format_targets(crowd_labels_batch)}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Union, List, Tuple, Dict

import torch

from super_gradients.common.registry import register_collate_function
from super_gradients.common.exceptions.dataset_exceptions import DatasetItemsException
from super_gradients.training.utils.collate_fn.ppyoloe_collate_fn import PPYoloECollateFN


@register_collate_function()
class CrowdDetectionPPYoloECollateFN(PPYoloECollateFN):
    """
    Collate function for PPYoloE training with additional_batch_items that includes crowd targets
    """

    def __init__(self, random_resize_sizes: Union[List[int], None] = None, random_resize_modes: Union[List[int], None] = None):
        """
        :param random_resize_sizes: forwarded unchanged to `PPYoloECollateFN`
        :param random_resize_modes: forwarded unchanged to `PPYoloECollateFN`
        """
        super().__init__(random_resize_sizes, random_resize_modes)
        # Each dataset sample must unpack into exactly these three items.
        self.expected_item_names = ("image", "targets", "crowd_targets")

    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Collate a batch of `(image, targets, crowd_targets)` samples.

        :param data: sequence of samples returned by the dataset
        :return: tuple of (images, targets, {"crowd_targets": crowd targets}), formatted by the parent class helpers
        :raises DatasetItemsException: if the samples cannot be unpacked into the three expected items
        """
        # Random resize is applied only when resize sizes were configured.
        if self.random_resize_sizes is not None:
            data = self.random_resize(data)

        try:
            images_batch, labels_batch, crowd_labels_batch = zip(*data)
        except (ValueError, TypeError) as err:
            # Chain the original error so the underlying unpacking failure stays visible in the traceback.
            raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names) from err

        return self._format_images(images_batch), self._format_targets(labels_batch), {"crowd_targets": self._format_targets(crowd_labels_batch)}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Tuple, List, Union

import numpy as np
import torch

from super_gradients.common.registry import register_collate_function
from super_gradients.common.exceptions.dataset_exceptions import DatasetItemsException


@register_collate_function()
class DetectionCollateFN:
    """
    Collate function for Yolox training

    Turns a batch of `(image, targets)` samples into one stacked images tensor and one concatenated
    targets tensor whose first column is the index of the sample inside the batch.
    """

    def __init__(self):
        # Each dataset sample must unpack into exactly these items.
        self.expected_item_names = ("image", "targets")

    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param data: sequence of samples returned by the dataset, each expected to be an `(image, targets)` pair
        :return: tuple of (stacked images, concatenated targets)
        :raises DatasetItemsException: if the samples cannot be unpacked into the two expected items
        """
        try:
            images_batch, labels_batch = zip(*data)
        except (ValueError, TypeError) as err:
            # Chain the original error so the underlying unpacking failure stays visible in the traceback.
            raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names) from err

        return self._format_images(images_batch), self._format_targets(labels_batch)

    @staticmethod
    def _to_tensor(item: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Convert one batch item to a tensor without re-wrapping values that already are tensors."""
        # `torch.tensor(existing_tensor)` emits a copy-construct UserWarning and makes a redundant copy;
        # the later `torch.stack` / `torch.cat` allocate fresh storage anyway, so detaching is sufficient.
        if isinstance(item, torch.Tensor):
            return item.detach()
        return torch.tensor(item)

    def _format_images(self, images_batch: List[Union[torch.Tensor, np.ndarray]]) -> torch.Tensor:
        """
        Stack the images of a batch into a single tensor
        :param images_batch: a list of images, one per sample
        :return: the stacked images; when the last dimension of the stack is 3 it is moved to dim 1 and the result is cast to float
        """
        images_batch_stack = torch.stack([self._to_tensor(img) for img in images_batch], 0)
        if images_batch_stack.shape[3] == 3:
            images_batch_stack = torch.moveaxis(images_batch_stack, -1, 1).float()
        return images_batch_stack

    def _format_targets(self, labels_batch: List[Union[torch.Tensor, np.ndarray]]) -> torch.Tensor:
        """
        Stack a batch id column to targets and concatenate
        :param labels_batch: a list of targets per image (each of arbitrary length)
        :return: one tensor of targets of all images of shape [N, 6], where N is the total number of targets in a batch
                 and the 1st column is batch item index
        """
        labels_batch_indexed = []
        for i, labels in enumerate(self._to_tensor(labels) for labels in labels_batch):
            # Column holding the sample index, created with the same dtype/device as the labels.
            batch_column = labels.new_ones((labels.shape[0], 1)) * i
            labels_batch_indexed.append(torch.cat((batch_column, labels), dim=-1))
        return torch.cat(labels_batch_indexed, 0)
Loading

0 comments on commit 94c0f9e

Please sign in to comment.