Skip to content

Commit

Permalink
Check class_id validity in DetectionDataset (#1536)
Browse files Browse the repository at this point in the history
* first draft

* fix

* fix

* fix detectiondataset

* Ensure that `checkpoint_num_classes` is propagated from YAML to model (#1533)

* Ensure that `checkpoint_num_classes` is propagated from YAML to models.get()

* Read checkpoint_num_classes via get_params

* Fixed a bug in export() that prevented exporting a model for BS>1 (#1530)

* Remove unnecessary files (#1540)

* remove check on base class

---------

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
  • Loading branch information
Louis-Dupont and BloodAxe committed Oct 17, 2023
1 parent b3b37e5 commit 845318d
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def _load_sample_annotation(self, sample_id: int) -> Dict[str, Union[np.ndarray,
# Filter out classes that are not in self.class_inclusion_list
if self.class_inclusion_list is not None:
sample_annotations = self._sub_class_annotation(annotation=sample_annotations)

return sample_annotations

def _load_all_annotations(self, n_samples: int) -> Tuple[Dict[int, Dict[str, Any]], Dict[int, Dict[str, Any]]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def _load_annotation(self, sample_id: int) -> dict:

yolo_format_target, invalid_labels = self._parse_yolo_label_file(
label_file_path=label_path,
num_classes=len(self.all_classes_list),
ignore_invalid_labels=self.ignore_invalid_labels,
show_warnings=self.show_all_warnings,
)
Expand All @@ -210,13 +211,20 @@ def _load_annotation(self, sample_id: int) -> dict:
return annotation

@staticmethod
def _parse_yolo_label_file(label_file_path: str, ignore_invalid_labels: bool = True, show_warnings: bool = True) -> Tuple[np.ndarray, List[str]]:
def _parse_yolo_label_file(
label_file_path: str,
ignore_invalid_labels: bool = True,
show_warnings: bool = True,
num_classes: Optional[int] = None,
) -> Tuple[np.ndarray, List[str]]:
"""Parse a single label file in yolo format.
#TODO: Add support for additional fields (with ConcatenatedTensorFormat)
:param label_file_path: Path to the label file in yolo format.
:param ignore_invalid_labels: Whether to ignore labels that fail to be parsed. If True ignores and logs a warning, otherwise raise an error.
:param show_warnings: Whether to show the warnings or not.
:param num_classes: Number of classes in the dataset. Used to ensure that class ids are within the range [0, num_classes - 1].
If None, ignore.
:return:
- labels: np.ndarray of shape (n_labels, 5) in yolo format (LABEL_NORMALIZED_CXCYWH)
Expand All @@ -229,12 +237,21 @@ def _parse_yolo_label_file(label_file_path: str, ignore_invalid_labels: bool = T
for line in filter(lambda x: x != "\n", lines):
try:
label_id, cx, cw, w, h = line.split()
labels_yolo_format.append([int(label_id), float(cx), float(cw), float(w), float(h)])
label_id, cx, cw, w, h = int(label_id), float(cx), float(cw), float(w), float(h)

if (num_classes is not None) and (label_id not in range(num_classes)):
raise ValueError(f"`class_id={label_id}` invalid. It should be between (0 - {num_classes - 1}).")

labels_yolo_format.append([label_id, cx, cw, w, h])
except Exception as e:
error_msg = (
f"Line `{line}` of file {label_file_path} will be ignored because not cannot be parsed to (label, cx, cy, w, h) format, "
f"with Exception:\n{e}"
)
if ignore_invalid_labels:
invalid_labels.append(line)
if show_warnings:
logger.warning(f"Line `{line}` of file {label_file_path} will be ignored because not in LABEL_NORMALIZED_CXCYWH format: {e}")
logger.warning(error_msg)
else:
raise e
raise RuntimeError(error_msg)
return np.array(labels_yolo_format) if labels_yolo_format else np.zeros((0, 5)), invalid_labels
3 changes: 2 additions & 1 deletion tests/deci_core_unit_test_suite_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from tests.end_to_end_tests import TestTrainer
from tests.unit_tests.detection_utils_test import TestDetectionUtils
from tests.unit_tests.detection_dataset_test import DetectionDatasetTest
from tests.unit_tests.detection_dataset_test import DetectionDatasetTest, TestParseYoloLabelFile
from tests.unit_tests.export_detection_model_test import TestDetectionModelExport
from tests.unit_tests.export_onnx_test import TestModelsONNXExport
from tests.unit_tests.export_pose_estimation_model_test import TestPoseEstimationModelExport
Expand Down Expand Up @@ -136,6 +136,7 @@ def _add_modules_to_unit_tests_suite(self):
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestRepVGGBlock))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(LocalCkptHeadReplacementTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DetectionDatasetTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestParseYoloLabelFile))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestModelsONNXExport))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(MaxBatchesLoopBreakTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestTrainingUtils))
Expand Down
30 changes: 29 additions & 1 deletion tests/unit_tests/detection_dataset_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import unittest
from unittest.mock import patch, mock_open
from pathlib import Path
from typing import Dict
import numpy as np

from torch.utils.data import DataLoader

from super_gradients import Trainer
from super_gradients.training import models, dataloaders
from super_gradients.training.dataloaders import coco2017_train_yolo_nas, get_data_loader
from super_gradients.training.datasets import COCODetectionDataset
from super_gradients.training.datasets import COCODetectionDataset, YoloDarknetFormatDetectionDataset
from super_gradients.training.datasets.data_formats.default_formats import LABEL_CXCYWH
from super_gradients.training.datasets.datasets_conf import COCO_DETECTION_CLASSES_LIST
from super_gradients.common.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException
Expand Down Expand Up @@ -194,5 +196,31 @@ def test_coco_detection_metrics_with_classwise_ap(self):
trainer.train(model=model, training_params=detection_train_params_yolox, train_loader=train_loader, valid_loader=valid_loader)


class TestParseYoloLabelFile(unittest.TestCase):
    """Unit tests for `YoloDarknetFormatDetectionDataset._parse_yolo_label_file`.

    `builtins.open` is patched with `mock_open`, so the label-file content is served
    from in-memory strings and no file needs to exist on disk.
    """

    def setUp(self):
        """Prepare the class count and the three label-file contents used by the tests."""
        self.num_classes = 3
        # Two well-formed lines, both with a class id inside [0, num_classes - 1].
        self.sample_data_valid = "0 0.5 0.5 0.1 0.1\n1 0.6 0.6 0.2 0.2"
        # First line has only two fields, so it cannot be unpacked into (label, cx, cy, w, h).
        self.sample_data_invalid_format = "0 0.5\n1 0.6 0.6 0.2 0.2"
        # Both class ids (-1 and 3) fall outside the range allowed for 3 classes.
        self.sample_data_invalid_class = "-1 0.5 0.5 0.1 0.1\n3 0.6 0.6 0.2 0.2"

    def _parse(self, file_content):
        """Run the parser on `file_content` as if it had been read from "mock_path".

        :param file_content: Text that the patched `open` will return.
        :return: The `(labels, invalid_labels)` tuple produced by the parser.
        """
        with patch("builtins.open", mock_open(read_data=file_content)):
            return YoloDarknetFormatDetectionDataset._parse_yolo_label_file("mock_path", num_classes=self.num_classes)

    def test_valid_label(self):
        """Every line is valid: all labels are returned and nothing is reported as invalid."""
        parsed, rejected = self._parse(self.sample_data_valid)
        expected = np.array([[0, 0.5, 0.5, 0.1, 0.1], [1, 0.6, 0.6, 0.2, 0.2]])
        np.testing.assert_array_equal(parsed, expected)
        self.assertEqual(rejected, [])

    def test_invalid_format(self):
        """A line with too few fields is reported as invalid; the well-formed line is kept."""
        parsed, rejected = self._parse(self.sample_data_invalid_format)
        expected = np.array([[1, 0.6, 0.6, 0.2, 0.2]])
        np.testing.assert_array_equal(parsed, expected)
        # The rejected line is reported verbatim, including its trailing newline.
        self.assertEqual(rejected, ["0 0.5\n"])

    def test_invalid_class(self):
        """Lines whose class id is out of range are all rejected, leaving no labels."""
        parsed, rejected = self._parse(self.sample_data_invalid_class)
        self.assertEqual(len(parsed), 0)
        self.assertEqual(rejected, ["-1 0.5 0.5 0.1 0.1\n", "3 0.6 0.6 0.2 0.2"])


if __name__ == "__main__":
unittest.main()

0 comments on commit 845318d

Please sign in to comment.