Skip to content

Commit

Permalink
add color target type test
Browse files Browse the repository at this point in the history
  • Loading branch information
lijm1358 committed Jan 28, 2023
1 parent 964501b commit 67debd9
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions tests/datamodules/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,32 +41,35 @@ def _create_synth_Cityscapes_dataset(path_dir):
image_name = f"{base_name}_leftImg8bit.png"
instance_target_name = f"{base_name}_gtFine_instanceIds.png"
semantic_target_name = f"{base_name}_gtFine_labelIds.png"
color_target_name = f"{base_name}_gtFine_color.png"
Image.new("RGB", (2048, 1024)).save(images_dir / split / city / image_name)
Image.new("L", (2048, 1024)).save(fine_labels_dir / split / city / instance_target_name)
Image.new("L", (2048, 1024)).save(fine_labels_dir / split / city / semantic_target_name)
Image.new("RGBA", (2048, 1024)).save(fine_labels_dir / split / city / color_target_name)


def test_cityscapes_datamodule(datadir):
def test_cityscapes_datamodule(datadir, catch_warnings):
_create_synth_Cityscapes_dataset(datadir)

batch_size = 1
target_types = ["semantic", "instance"]
for target_type in target_types:
target_types = ["semantic", "instance", "color"]
target_sizes = [(1024, 2048), (1024, 2048), (4, 1024, 2048)]
for target_type, target_size in zip(target_types, target_sizes):
dm = CityscapesDataModule(datadir, num_workers=0, batch_size=batch_size, target_type=target_type)
loader = dm.train_dataloader()
img, mask = next(iter(loader))
assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
assert mask.size() == torch.Size([batch_size, 1024, 2048])

loader = dm.val_dataloader()
img, mask = next(iter(loader))
assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
assert mask.size() == torch.Size([batch_size, 1024, 2048])

loader = dm.test_dataloader()
img, mask = next(iter(loader))
assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
assert mask.size() == torch.Size([batch_size, 1024, 2048])
loader = dm.train_dataloader()
img, mask = next(iter(loader))
assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
assert mask.size() == torch.Size([batch_size, *target_size])

loader = dm.val_dataloader()
img, mask = next(iter(loader))
assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
assert mask.size() == torch.Size([batch_size, *target_size])

loader = dm.test_dataloader()
img, mask = next(iter(loader))
assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
assert mask.size() == torch.Size([batch_size, *target_size])


@pytest.mark.parametrize("val_split, train_len", [(0.2, 48_000), (5_000, 55_000)])
Expand Down

0 comments on commit 67debd9

Please sign in to comment.