Skip to content

Commit

Permalink
Feature/sg 814 support yoloformat loader (#847)
Browse files Browse the repository at this point in the history
* first draft

* improve naming

* fix name

* remove comments

* wip

* add comment
  • Loading branch information
Louis-Dupont committed Apr 19, 2023
1 parent 876ee76 commit 4d06a20
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/super_gradients/common/object_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ class Dataloaders:
COCO2017_VAL_SSD_LITE_MOBILENET_V2 = "coco2017_val_ssd_lite_mobilenet_v2"
COCO2017_POSE_TRAIN = "coco2017_pose_train"
COCO2017_POSE_VAL = "coco2017_pose_val"
COCO_DETECTION_YOLO_FORMAT_TRAIN = "coco_detection_yolo_format_train"
COCO_DETECTION_YOLO_FORMAT_VAL = "coco_detection_yolo_format_val"
IMAGENET_TRAIN = "imagenet_train"
IMAGENET_VAL = "imagenet_val"
IMAGENET_EFFICIENTNET_TRAIN = "imagenet_efficientnet_train"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@

train_dataset_params:
data_dir: /data/coco # TO FILL: Where the data is stored.
images_dir: images/train2017 # TO FILL: Local path to directory that includes all the images. Path relative to `data_dir`. Can be the same as `labels_dir`.
labels_dir: labels/train2017 # TO FILL: Local path to directory that includes all the labels. Path relative to `data_dir`. Can be the same as `images_dir`.
classes: [ person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic light, fire hydrant, stop sign,
parking meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag,
tie, suitcase, frisbee, skis, snowboard, sports ball, kite, baseball bat, baseball glove, skateboard, surfboard,
tennis racket, bottle, wine glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot,
hot dog, pizza, donut, cake, chair, couch, potted plant, bed, dining table, toilet, tv, laptop, mouse, remote,
keyboard, cell phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy bear,
hair drier, toothbrush] # TO FILL: List of classes used in your dataset.
input_dim: [640, 640]
cache_dir:
cache: False
transforms:
- DetectionMosaic:
input_dim: ${dataset_params.train_dataset_params.input_dim}
prob: 1.
- DetectionRandomAffine:
degrees: 10. # rotation degrees, randomly sampled from [-degrees, degrees]
translate: 0.1 # image translation fraction
scales: [ 0.1, 2 ] # random rescale range (keeps size by padding/cropping) after mosaic transform.
shear: 2.0 # shear degrees, randomly sampled from [-degrees, degrees]
target_size: ${dataset_params.train_dataset_params.input_dim}
filter_box_candidates: True # whether to filter out transformed bboxes by edge size, area ratio, and aspect ratio.
wh_thr: 2 # edge size threshold when filter_box_candidates = True (pixels)
area_thr: 0.1 # threshold for area ratio between original image and the transformed one, when filter_box_candidates = True
ar_thr: 20 # aspect ratio threshold when filter_box_candidates = True
- DetectionMixup:
input_dim: ${dataset_params.train_dataset_params.input_dim}
mixup_scale: [ 0.5, 1.5 ] # random rescale range for the additional sample in mixup
prob: 1.0 # probability to apply per-sample mixup
flip_prob: 0.5 # probability to apply horizontal flip
- DetectionHSV:
prob: 1.0 # probability to apply HSV transform
hgain: 5 # HSV transform hue gain (randomly sampled from [-hgain, hgain])
sgain: 30 # HSV transform saturation gain (randomly sampled from [-sgain, sgain])
vgain: 30 # HSV transform value gain (randomly sampled from [-vgain, vgain])
- DetectionHorizontalFlip:
prob: 0.5 # probability to apply horizontal flip
- DetectionPaddedRescale:
input_dim: ${dataset_params.train_dataset_params.input_dim}
max_targets: 120
- DetectionTargetsFormatTransform:
input_dim: ${dataset_params.train_dataset_params.input_dim}
output_format: LABEL_CXCYWH
class_inclusion_list:
max_num_samples:

train_dataloader_params:
batch_size: 25
num_workers: 8
shuffle: True
drop_last: True
pin_memory: True
collate_fn:
_target_: super_gradients.training.utils.detection_utils.DetectionCollateFN

val_dataset_params:
data_dir: /data/coco # TO FILL: Where the data is stored.
images_dir: images/val2017 # TO FILL: Local path to directory that includes all the images. Path relative to `data_dir`. Can be the same as `labels_dir`.
labels_dir: labels/val2017 # TO FILL: Local path to directory that includes all the labels. Path relative to `data_dir`. Can be the same as `images_dir`.
classes: [ person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic light, fire hydrant, stop sign,
parking meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag,
tie, suitcase, frisbee, skis, snowboard, sports ball, kite, baseball bat, baseball glove, skateboard, surfboard,
tennis racket, bottle, wine glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot,
hot dog, pizza, donut, cake, chair, couch, potted plant, bed, dining table, toilet, tv, laptop, mouse, remote,
keyboard, cell phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy bear,
hair drier, toothbrush] # TO FILL: List of classes used in your dataset.
input_dim: [640, 640]
cache_dir:
cache: False
transforms:
- DetectionPaddedRescale:
input_dim: ${dataset_params.val_dataset_params.input_dim}
- DetectionTargetsFormatTransform:
max_targets: 50
input_dim: ${dataset_params.val_dataset_params.input_dim}
output_format: LABEL_CXCYWH
class_inclusion_list:
max_num_samples:

val_dataloader_params:
batch_size: 25
num_workers: 8
drop_last: False
pin_memory: True
collate_fn:
_target_: super_gradients.training.utils.detection_utils.DetectionCollateFN

_convert_: all
24 changes: 23 additions & 1 deletion src/super_gradients/training/dataloaders/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Cifar10,
Cifar100,
)
from super_gradients.training.datasets.detection_datasets import COCODetectionDataset, RoboflowDetectionDataset
from super_gradients.training.datasets.detection_datasets import COCODetectionDataset, RoboflowDetectionDataset, YoloDarknetFormatDetectionDataset
from super_gradients.training.datasets.detection_datasets.pascal_voc_detection import (
PascalVOCUnifiedDetectionTrainDataset,
PascalVOCDetectionDataset,
Expand Down Expand Up @@ -270,6 +270,28 @@ def roboflow_val_yolox(dataset_params: Dict = None, dataloader_params: Dict = No
)


@register_dataloader(Dataloaders.COCO_DETECTION_YOLO_FORMAT_TRAIN)
def coco_detection_yolo_format_train(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
    """Create the training DataLoader for a COCO dataset stored in YOLO/Darknet format.

    Delegates to `get_data_loader` using the
    "coco_detection_yolo_format_base_dataset_params" config, the
    `YoloDarknetFormatDetectionDataset` dataset class and `train=True`.

    :param dataset_params:    Forwarded unchanged to `get_data_loader`
                              (presumably overrides the config's dataset defaults - confirm in `get_data_loader`).
    :param dataloader_params: Forwarded unchanged to `get_data_loader`
                              (presumably overrides the config's dataloader defaults - confirm in `get_data_loader`).
    :return:                  Whatever `get_data_loader` returns; annotated as a `DataLoader`.
    """
    train_loader = get_data_loader(
        config_name="coco_detection_yolo_format_base_dataset_params",
        dataset_cls=YoloDarknetFormatDetectionDataset,
        train=True,
        dataset_params=dataset_params,
        dataloader_params=dataloader_params,
    )
    return train_loader


@register_dataloader(Dataloaders.COCO_DETECTION_YOLO_FORMAT_VAL)
def coco_detection_yolo_format_val(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
    """Create the validation DataLoader for a COCO dataset stored in YOLO/Darknet format.

    Delegates to `get_data_loader` using the
    "coco_detection_yolo_format_base_dataset_params" config, the
    `YoloDarknetFormatDetectionDataset` dataset class and `train=False`.

    :param dataset_params:    Forwarded unchanged to `get_data_loader`
                              (presumably overrides the config's dataset defaults - confirm in `get_data_loader`).
    :param dataloader_params: Forwarded unchanged to `get_data_loader`
                              (presumably overrides the config's dataloader defaults - confirm in `get_data_loader`).
    :return:                  Whatever `get_data_loader` returns; annotated as a `DataLoader`.
    """
    val_loader = get_data_loader(
        config_name="coco_detection_yolo_format_base_dataset_params",
        dataset_cls=YoloDarknetFormatDetectionDataset,
        train=False,
        dataset_params=dataset_params,
        dataloader_params=dataloader_params,
    )
    return val_loader


@register_dataloader(Dataloaders.IMAGENET_TRAIN)
def imagenet_train(dataset_params: Dict = None, dataloader_params: Dict = None, config_name="imagenet_dataset_params"):
return get_data_loader(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(

self.data_dir = data_dir
if not Path(data_dir).exists():
raise FileNotFoundError(f"data_dir={data_dir} not found. Please make sure that data_dir points toward your dataset.")
raise RuntimeError(f"data_dir={data_dir} not found. Please make sure that data_dir points toward your dataset.")

# Number of images that are available (regardless of ignored images)
self.n_available_samples = self._setup_data_source()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,20 +146,21 @@ def _setup_data_source(self) -> int:
logger.warning(f"{len(labels_not_in_images)} label files are not associated to any image.")

# Only keep names that are in both the images and the labels
valid_base_names = list(unique_image_file_base_names & unique_label_file_base_names)
valid_base_names = unique_image_file_base_names & unique_label_file_base_names
if len(valid_base_names) != len(all_images_file_names):
logger.warning(
f"As a consequence, "
f"{len(valid_base_names)}/{len(all_images_file_names)} images and "
f"{len(valid_base_names)}/{len(all_labels_file_names)} label files will be used."
)

self.images_file_names = list(
sorted(image_full_name for image_full_name in all_images_file_names if remove_file_extension(image_full_name) in valid_base_names)
)
self.labels_file_names = list(
sorted(label_full_name for label_full_name in all_labels_file_names if remove_file_extension(label_full_name) in valid_base_names)
)
self.images_file_names = []
self.labels_file_names = []
for image_full_name in all_images_file_names:
base_name = remove_file_extension(image_full_name)
if base_name in valid_base_names:
self.images_file_names.append(image_full_name)
self.labels_file_names.append(base_name + ".txt")
return len(self.images_file_names)

def _load_annotation(self, sample_id: int) -> dict:
Expand Down

0 comments on commit 4d06a20

Please sign in to comment.