Skip to content

Commit

Permalink
update tests to run unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis-Dupont committed Nov 22, 2023
1 parent fef6347 commit b17bbd3
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 101 deletions.
6 changes: 6 additions & 0 deletions tests/deci_core_unit_test_suite_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
DynamicModelTests,
TestExportRecipe,
TestMixedPrecisionDisabled,
TestClassificationAdapter,
TestDetectionAdapter,
TestSegmentationAdapter,
)
from tests.end_to_end_tests import TestTrainer
from tests.unit_tests.test_convert_recipe_to_code import TestConvertRecipeToCode
Expand Down Expand Up @@ -172,6 +175,9 @@ def _add_modules_to_unit_tests_suite(self):
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestConvertRecipeToCode))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestVersionCheck))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestModelWeightAveraging))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestClassificationAdapter))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionAdapter))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestSegmentationAdapter))

def _add_modules_to_end_to_end_tests_suite(self):
"""
Expand Down
4 changes: 4 additions & 0 deletions tests/unit_tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tests.unit_tests.test_deprecate import TestDeprecationDecorator
from tests.unit_tests.test_models_factory import DynamicModelTests
from tests.unit_tests.test_mixed_precision_cpu import TestMixedPrecisionDisabled
from tests.unit_tests.test_data_adapters import TestClassificationAdapter, TestDetectionAdapter, TestSegmentationAdapter

__all__ = [
"CrashTipTest",
Expand Down Expand Up @@ -61,4 +62,7 @@
"TestMixedPrecisionDisabled",
"DynamicModelTests",
"TestExportRecipe",
"TestClassificationAdapter",
"TestDetectionAdapter",
"TestSegmentationAdapter",
]
214 changes: 113 additions & 101 deletions tests/unit_tests/test_data_adapters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
import tempfile
import numpy as np
import torch
from torch.utils.data import Dataset
Expand Down Expand Up @@ -70,48 +71,51 @@ def setUp(self) -> None:
]

def test_adapt_dataset_detection(self):

analyzer_ds = DetectionAnalysisManager(
report_title="test_adapt_dataset_detection",
train_data=self.dataset,
val_data=self.dataset,
class_names=list(map(str, range(6))),
image_channels=ImageChannels.from_str("RGB"),
use_cache=True,
is_label_first=False,
bbox_format="xywh",
)
analyzer_ds.run() # Run the analysis. This will create the cache.

loader = DetectionDataloaderAdapterFactory.from_dataset(dataset=self.dataset, config_path=analyzer_ds.data_config.cache_path, batch_size=2)

for expected_images_shape, expected_targets, (images, targets) in zip(self.expected_image_shapes_batches, self.expected_targets_batches, loader):
self.assertEqual(images.shape, expected_images_shape)
self.assertTrue(((0 <= images) & (images <= 255)).all()) # Should be 0-255
self.assertTrue(torch.equal(targets, expected_targets))
with tempfile.TemporaryDirectory() as tmpdirname:
analyzer_ds = DetectionAnalysisManager(
log_dir=tmpdirname,
report_title="test_adapt_dataset_detection",
train_data=self.dataset,
val_data=self.dataset,
class_names=list(map(str, range(6))),
image_channels=ImageChannels.from_str("RGB"),
use_cache=True,
is_label_first=False,
bbox_format="xywh",
)
analyzer_ds.run() # Run the analysis. This will create the cache.

loader = DetectionDataloaderAdapterFactory.from_dataset(dataset=self.dataset, config_path=analyzer_ds.data_config.cache_path, batch_size=2)

for expected_images_shape, expected_targets, (images, targets) in zip(self.expected_image_shapes_batches, self.expected_targets_batches, loader):
self.assertEqual(images.shape, expected_images_shape)
self.assertTrue(((0 <= images) & (images <= 255)).all()) # Should be 0-255
self.assertTrue(torch.equal(targets, expected_targets))

def test_adapt_dataloader_detection(self):

loader = DataLoader(self.dataset, batch_size=2)

analyzer_ds = DetectionAnalysisManager(
report_title="test_adapt_dataloader_detection",
train_data=loader,
val_data=loader,
class_names=list(map(str, range(6))),
image_channels=ImageChannels.from_str("RGB"),
use_cache=True,
is_label_first=False,
bbox_format="xywh",
)
analyzer_ds.run()
with tempfile.TemporaryDirectory() as tmpdirname:
analyzer_ds = DetectionAnalysisManager(
log_dir=tmpdirname,
report_title="test_adapt_dataloader_detection",
train_data=loader,
val_data=loader,
class_names=list(map(str, range(6))),
image_channels=ImageChannels.from_str("RGB"),
use_cache=True,
is_label_first=False,
bbox_format="xywh",
)
analyzer_ds.run()

loader = DetectionDataloaderAdapterFactory.from_dataloader(dataloader=loader, config_path=analyzer_ds.data_config.cache_path)
loader = DetectionDataloaderAdapterFactory.from_dataloader(dataloader=loader, config_path=analyzer_ds.data_config.cache_path)

for expected_images_shape, expected_targets, (images, targets) in zip(self.expected_image_shapes_batches, self.expected_targets_batches, loader):
self.assertEqual(images.shape, expected_images_shape)
self.assertTrue(((0 <= images) & (images <= 255)).all()) # Should be 0-255
self.assertTrue(torch.equal(targets, expected_targets))
for expected_images_shape, expected_targets, (images, targets) in zip(self.expected_image_shapes_batches, self.expected_targets_batches, loader):
self.assertEqual(images.shape, expected_images_shape)
self.assertTrue(((0 <= images) & (images <= 255)).all()) # Should be 0-255
self.assertTrue(torch.equal(targets, expected_targets))


class TestSegmentationAdapter(unittest.TestCase):
Expand All @@ -132,43 +136,47 @@ def setUp(self) -> None:

def test_adapt_dataset_segmentation(self):

analyzer_ds = SegmentationAnalysisManager(
report_title="test_adapt_dataset_segmentation",
train_data=self.dataset,
val_data=self.dataset,
class_names=list(map(str, range(6))),
image_channels=ImageChannels.from_str("RGB"),
use_cache=True,
is_batch=False,
)
analyzer_ds.run()

loader = SegmentationDataloaderAdapterFactory.from_dataset(dataset=self.dataset, config_path=analyzer_ds.data_config.cache_path, batch_size=2)

for expected_images_shape, expected_masks, (images, masks) in zip(self.expected_image_shapes_batches, self.expected_masks_batches, loader):
self.assertEqual(images.shape, expected_images_shape)
self.assertTrue((masks == expected_masks).all()) # Checking that the masks are as expected
with tempfile.TemporaryDirectory() as tmpdirname:
analyzer_ds = SegmentationAnalysisManager(
log_dir=tmpdirname,
report_title="test_adapt_dataset_segmentation",
train_data=self.dataset,
val_data=self.dataset,
class_names=list(map(str, range(6))),
image_channels=ImageChannels.from_str("RGB"),
use_cache=True,
is_batch=False,
)
analyzer_ds.run()

loader = SegmentationDataloaderAdapterFactory.from_dataset(dataset=self.dataset, config_path=analyzer_ds.data_config.cache_path, batch_size=2)

for expected_images_shape, expected_masks, (images, masks) in zip(self.expected_image_shapes_batches, self.expected_masks_batches, loader):
self.assertEqual(images.shape, expected_images_shape)
self.assertTrue((masks == expected_masks).all()) # Checking that the masks are as expected

def test_adapt_dataloader_segmentation(self):

loader = DataLoader(self.dataset, batch_size=2)

analyzer_ds = SegmentationAnalysisManager(
report_title="test_adapt_dataloader_segmentation",
train_data=loader,
val_data=loader,
class_names=list(map(str, range(6))),
image_channels=ImageChannels.from_str("RGB"),
use_cache=True,
is_batch=True,
)
analyzer_ds.run()
with tempfile.TemporaryDirectory() as tmpdirname:
analyzer_ds = SegmentationAnalysisManager(
log_dir=tmpdirname,
report_title="test_adapt_dataloader_segmentation",
train_data=loader,
val_data=loader,
class_names=list(map(str, range(6))),
image_channels=ImageChannels.from_str("RGB"),
use_cache=True,
is_batch=True,
)
analyzer_ds.run()

loader = DetectionDataloaderAdapterFactory.from_dataloader(dataloader=loader, config_path=analyzer_ds.data_config.cache_path)
loader = SegmentationDataloaderAdapterFactory.from_dataloader(dataloader=loader, config_path=analyzer_ds.data_config.cache_path)

for expected_images_shape, expected_masks, (images, masks) in zip(self.expected_image_shapes_batches, self.expected_masks_batches, loader):
self.assertEqual(images.shape, expected_images_shape)
self.assertTrue((masks == expected_masks).all()) # Checking that the masks are as expected
for expected_images_shape, expected_masks, (images, masks) in zip(self.expected_image_shapes_batches, self.expected_masks_batches, loader):
self.assertEqual(images.shape, expected_images_shape)
self.assertTrue((masks == expected_masks).all()) # Checking that the masks are as expected


class TestClassificationAdapter(unittest.TestCase):
Expand All @@ -187,47 +195,51 @@ def setUp(self) -> None:

def test_adapt_dataset_classification(self):

analyzer_ds = ClassificationAnalysisManager(
report_title="test_adapt_dataset_classification",
train_data=self.dataset,
val_data=self.dataset,
class_names=list(map(str, range(6))),
image_channels=ImageChannels.from_str("RGB"),
images_extractor="[0]",
labels_extractor="[1]",
use_cache=True,
is_batch=False,
)
analyzer_ds.run()

loader = ClassificationDataloaderAdapterFactory.from_dataset(dataset=self.dataset, config_path=analyzer_ds.data_config.cache_path, batch_size=2)

for expected_images_shape, expected_labels, (images, labels) in zip(self.expected_image_shapes_batches, self.expected_labels_batches, loader):
self.assertEqual(images.shape, expected_images_shape)
self.assertTrue(torch.equal(labels, expected_labels))
with tempfile.TemporaryDirectory() as tmpdirname:
analyzer_ds = ClassificationAnalysisManager(
log_dir=tmpdirname,
report_title="test_adapt_dataset_classification",
train_data=self.dataset,
val_data=self.dataset,
class_names=list(map(str, range(6))),
image_channels=ImageChannels.from_str("RGB"),
images_extractor="[0]",
labels_extractor="[1]",
use_cache=True,
is_batch=False,
)
analyzer_ds.run()

loader = ClassificationDataloaderAdapterFactory.from_dataset(dataset=self.dataset, config_path=analyzer_ds.data_config.cache_path, batch_size=2)

for expected_images_shape, expected_labels, (images, labels) in zip(self.expected_image_shapes_batches, self.expected_labels_batches, loader):
self.assertEqual(images.shape, expected_images_shape)
self.assertTrue(torch.equal(labels, expected_labels))

def test_adapt_dataloader_classification(self):

loader = DataLoader(self.dataset, batch_size=2)

analyzer_ds = ClassificationAnalysisManager(
report_title="test_adapt_dataloader_classification",
train_data=loader,
val_data=loader,
class_names=list(map(str, range(6))),
image_channels=ImageChannels.from_str("RGB"),
images_extractor="[0]",
labels_extractor="[1]",
use_cache=True,
is_batch=True,
)
analyzer_ds.run()

loader = ClassificationDataloaderAdapterFactory.from_dataloader(dataloader=loader, config_path=analyzer_ds.data_config.cache_path)

for expected_images_shape, expected_labels, (images, labels) in zip(self.expected_image_shapes_batches, self.expected_labels_batches, loader):
self.assertEqual(images.shape, expected_images_shape)
self.assertTrue(torch.equal(labels, expected_labels))
with tempfile.TemporaryDirectory() as tmpdirname:
analyzer_ds = ClassificationAnalysisManager(
log_dir=tmpdirname,
report_title="test_adapt_dataloader_classification",
train_data=loader,
val_data=loader,
class_names=list(map(str, range(6))),
image_channels=ImageChannels.from_str("RGB"),
images_extractor="[0]",
labels_extractor="[1]",
use_cache=True,
is_batch=True,
)
analyzer_ds.run()

loader = ClassificationDataloaderAdapterFactory.from_dataloader(dataloader=loader, config_path=analyzer_ds.data_config.cache_path)

for expected_images_shape, expected_labels, (images, labels) in zip(self.expected_image_shapes_batches, self.expected_labels_batches, loader):
self.assertEqual(images.shape, expected_images_shape)
self.assertTrue(torch.equal(labels, expected_labels))


if __name__ == "__main__":
Expand Down

0 comments on commit b17bbd3

Please sign in to comment.