Bug/sg 000 merge failure for datasetparams (#1140)
* failure verification

* added set_struct before merge

* added set_struct before merge in val too

* release tag filter added back

* return type fixed

* used merge_with with open_dict

* added test
shaydeci authored Jun 7, 2023
1 parent a3191c4 commit 83ce129
Showing 2 changed files with 49 additions and 6 deletions.
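
For context: recipe configs composed by Hydra are typically in struct mode, so a plain OmegaConf.merge rejects any user-supplied dataset_params key that the defaults do not already declare. A minimal sketch of that failure mode (not part of this commit; the keys are toy values, not the real recipe config):

from omegaconf import OmegaConf
from omegaconf.errors import OmegaConfBaseException

defaults = OmegaConf.create({"input_dim": 640})
OmegaConf.set_struct(defaults, True)  # struct mode: unknown keys are rejected

try:
    OmegaConf.merge(defaults, {"input_dim": 384, "dummy_field": 10})
except OmegaConfBaseException as e:
    print(f"merge failed: {e}")  # "dummy_field" is not declared in the struct config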
15 changes: 11 additions & 4 deletions src/super_gradients/training/dataloaders/dataloaders.py
@@ -3,7 +3,7 @@
 import hydra
 import numpy as np
 import torch
-from omegaconf import OmegaConf, UnsupportedValueType
+from omegaconf import OmegaConf, UnsupportedValueType, DictConfig, open_dict
 from torch.utils.data import BatchSampler, DataLoader, TensorDataset, RandomSampler

 import super_gradients
@@ -101,12 +101,19 @@ def _process_dataset_params(cfg, dataset_params, train: bool):
         # >>> dataset_params = OmegaConf.merge(default_dataset_params, dataset_params)
         # >>> return hydra.utils.instantiate(dataset_params)
         # For some reason this breaks interpolation :shrug:
-
+        if not isinstance(dataset_params, DictConfig):
+            dataset_params = OmegaConf.create(dataset_params)
         if train:
-            cfg.train_dataset_params = OmegaConf.merge(cfg.train_dataset_params, dataset_params)
+            train_dataset_params = cfg.train_dataset_params
+            with open_dict(train_dataset_params):
+                train_dataset_params.merge_with(dataset_params)
+            cfg.train_dataset_params = train_dataset_params
             return hydra.utils.instantiate(cfg.train_dataset_params)
         else:
-            cfg.val_dataset_params = OmegaConf.merge(cfg.val_dataset_params, dataset_params)
+            val_dataset_params = cfg.val_dataset_params
+            with open_dict(val_dataset_params):
+                val_dataset_params.merge_with(dataset_params)
+            cfg.val_dataset_params = val_dataset_params
             return hydra.utils.instantiate(cfg.val_dataset_params)

     except UnsupportedValueType:
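
The replacement above follows the standard OmegaConf idiom for merging user overrides into a struct config: wrap the user-supplied dict in a DictConfig, open the destination node with open_dict so new keys may be added, and merge in place with merge_with. A minimal sketch of the same pattern with placeholder values (not taken from the repo's configs):

from omegaconf import OmegaConf, open_dict

defaults = OmegaConf.create({"input_dim": 640})
OmegaConf.set_struct(defaults, True)
overrides = OmegaConf.create({"input_dim": 384, "dummy_field": 10})

with open_dict(defaults):           # temporarily lift the struct flag
    defaults.merge_with(overrides)  # in-place merge; new keys are now accepted

print(OmegaConf.to_yaml(defaults))  # input_dim: 384, dummy_field: 10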
40 changes: 38 additions & 2 deletions tests/unit_tests/detection_dataset_test.py
@@ -1,13 +1,31 @@
 import unittest
 from pathlib import Path
+from typing import Dict

-from super_gradients.training.dataloaders import coco2017_train_yolo_nas
+from torch.utils.data import DataLoader
+from super_gradients.training.dataloaders import coco2017_train_yolo_nas, get_data_loader
 from super_gradients.training.datasets import COCODetectionDataset
 from super_gradients.training.datasets.data_formats.default_formats import LABEL_CXCYWH
 from super_gradients.training.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException
 from super_gradients.training.transforms import DetectionMosaic, DetectionTargetsFormatTransform, DetectionPaddedRescale


+class DummyCOCODetectionDatasetInheritor(COCODetectionDataset):
+    def __init__(self, json_file: str, subdir: str, dummy_field: int, *args, **kwargs):
+        super(DummyCOCODetectionDatasetInheritor, self).__init__(json_file=json_file, subdir=subdir, *args, **kwargs)
+        self.dummy_field = dummy_field
+
+
+def dummy_coco2017_inheritor_train_yolo_nas(dataset_params: Dict = None, dataloader_params: Dict = None) -> DataLoader:
+    return get_data_loader(
+        config_name="coco_detection_yolo_nas_dataset_params",
+        dataset_cls=DummyCOCODetectionDatasetInheritor,
+        train=True,
+        dataset_params=dataset_params,
+        dataloader_params=dataloader_params,
+    )
+
+
 class DetectionDatasetTest(unittest.TestCase):
     def setUp(self) -> None:
         self.mini_coco_data_dir = str(Path(__file__).parent.parent / "data" / "tinycoco")
@@ -23,7 +41,6 @@ def test_normal_coco_dataset_creation(self):
         COCODetectionDataset(**train_dataset_params)

     def test_coco_dataset_creation_with_wrong_classes(self):
-
         train_dataset_params = {
             "data_dir": self.mini_coco_data_dir,
             "subdir": "images/train2017",
@@ -88,6 +105,25 @@ def test_coco_detection_dataset_override_with_objects(self):
         self.assertEqual(batch[0].shape[2], 384)
         self.assertEqual(batch[0].shape[3], 384)

+    def test_coco_detection_dataset_override_with_new_entries(self):
+        train_dataset_params = {
+            "data_dir": self.mini_coco_data_dir,
+            "input_dim": 384,
+            "transforms": [
+                DetectionMosaic(input_dim=384),
+                DetectionPaddedRescale(input_dim=384, max_targets=10),
+                DetectionTargetsFormatTransform(max_targets=10, output_format=LABEL_CXCYWH),
+            ],
+            "dummy_field": 10,
+        }
+        train_dataloader_params = {"num_workers": 0}
+        dataloader = dummy_coco2017_inheritor_train_yolo_nas(dataset_params=train_dataset_params, dataloader_params=train_dataloader_params)
+        batch = next(iter(dataloader))
+        print(batch[0].shape)
+        self.assertEqual(batch[0].shape[2], 384)
+        self.assertEqual(batch[0].shape[3], 384)
+        self.assertEqual(dataloader.dataset.dummy_field, 10)
+

 if __name__ == "__main__":
     unittest.main()
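
From the user side, the new test mirrors what this change enables: passing dataset_params entries that the default recipe does not declare (here the extra dummy_field constructor argument of the dataset subclass). A hypothetical usage sketch with a built-in factory; the path and batch size below are placeholders, not values from this diff:

from super_gradients.training.dataloaders import coco2017_train_yolo_nas

loader = coco2017_train_yolo_nas(
    dataset_params={"data_dir": "/path/to/coco", "input_dim": 384},  # overrides merged into the recipe defaults
    dataloader_params={"num_workers": 0, "batch_size": 4},
)
batch = next(iter(loader))
print(batch[0].shape)  # batch[0] holds the image tensor, as in the test above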
