Skip to content

Commit

Permalink
streamline v2 check
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Aug 22, 2023
1 parent 1efe583 commit 5358620
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import importlib
import inspect
import itertools
import multiprocessing
import os
import pathlib
import random
Expand Down Expand Up @@ -180,27 +179,30 @@ def check_transforms_v2_wrapper(dataset_test_case, *, config=None, supports_targ
from torchvision import datapoints
from torchvision.datasets import wrap_dataset_for_transforms_v2

def check_wrapped_samples(dataset):
for wrapped_sample in dataset:
assert tree_any(
lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
)

target_keyss = [None]
if supports_target_keys:
target_keyss.append("all")

for target_keys, multiprocessing_context in itertools.product(
target_keyss, multiprocessing.get_all_start_methods()
):
for target_keys in target_keyss:
with dataset_test_case.create_dataset(config) as (dataset, info):
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)

assert isinstance(wrapped_dataset, type(dataset))
assert len(wrapped_dataset) == info["num_examples"]

dataloader = DataLoader(
wrapped_dataset, num_workers=2, multiprocessing_context=multiprocessing_context, collate_fn=_no_collate
)
check_wrapped_samples(wrapped_dataset)

for wrapped_sample in dataloader:
assert tree_any(
lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
)
with dataset_test_case.create_dataset(config) as (dataset, _):
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)

check_wrapped_samples(dataloader)


class DatasetTestCase(unittest.TestCase):
Expand Down

0 comments on commit 5358620

Please sign in to comment.