diff --git a/test/datasets_utils.py b/test/datasets_utils.py
index b6f22d766df..8afc6ddb369 100644
--- a/test/datasets_utils.py
+++ b/test/datasets_utils.py
@@ -5,6 +5,7 @@
 import itertools
 import os
 import pathlib
+import platform
 import random
 import shutil
 import string
@@ -548,7 +549,7 @@ def test_feature_types(self, config):
     @test_all_configs
     def test_num_examples(self, config):
         with self.create_dataset(config) as (dataset, info):
-            assert len(dataset) == info["num_examples"]
+            assert len(list(dataset)) == len(dataset) == info["num_examples"]

     @test_all_configs
     def test_transforms(self, config):
@@ -692,6 +693,31 @@ def test_transforms_v2_wrapper(self, config):
             super().test_transforms_v2_wrapper.__wrapped__(self, config)


+def _no_collate(batch):
+    return batch
+
+
+def check_transforms_v2_wrapper_spawn(dataset):
+    # On Linux, the DataLoader forks the main process by default. This is not available on macOS and Windows, where
+    # new subprocesses are spawned instead. Spawning requires the whole pipeline, including the dataset, to be
+    # pickleable, which is what we are enforcing here.
+    if platform.system() != "Darwin":
+        pytest.skip("Multiprocessing spawning is only checked on macOS.")
+
+    from torch.utils.data import DataLoader
+    from torchvision import datapoints
+    from torchvision.datasets import wrap_dataset_for_transforms_v2
+
+    wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
+
+    dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)
+
+    for wrapped_sample in dataloader:
+        assert tree_any(
+            lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
+        )
+
+
 def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
     r"""Create a random uint8 tensor.
diff --git a/test/test_datasets.py b/test/test_datasets.py
index ed6aa17d3f9..265316264f8 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -183,6 +183,10 @@ def test_combined_targets(self):
                 ), "Type of the combined target does not match the type of the corresponding individual target: "
                 f"{actual} is not {expected}",

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset(target_type="category") as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Caltech256
@@ -190,7 +194,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
     def inject_fake_data(self, tmpdir, config):
         tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories"

-        categories = ((1, "ak47"), (127, "laptop-101"), (257, "clutter"))
+        categories = ((1, "ak47"), (2, "american-flag"), (3, "backpack"))
         num_images_per_category = 2

         for idx, category in categories:
@@ -258,6 +262,10 @@ def inject_fake_data(self, tmpdir, config):

         return split_to_num_examples[config["split"]]

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Cityscapes
@@ -382,6 +390,11 @@ def test_feature_types_target_polygon(self):
             assert isinstance(polygon_img, PIL.Image.Image)
             (polygon_target, info["expected_polygon_target"])

+    def test_transforms_v2_wrapper_spawn(self):
+        for target_type in ["instance", "semantic", ["instance", "semantic"]]:
+            with self.create_dataset(target_type=target_type) as (dataset, _):
+                datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.ImageNet
@@ -413,6 +426,10 @@ def inject_fake_data(self, tmpdir, config):
         torch.save((wnid_to_classes, None), tmpdir / "meta.bin")
         return num_examples

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CIFAR10
@@ -607,6 +624,11 @@ def test_images_names_split(self):

         assert merged_imgs_names == all_imgs_names

+    def test_transforms_v2_wrapper_spawn(self):
+        for target_type in ["identity", "bbox", ["identity", "bbox"]]:
+            with self.create_dataset(target_type=target_type) as (dataset, _):
+                datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.VOCSegmentation
@@ -694,6 +716,10 @@ def add_bndbox(obj, bndbox=None):

         return data

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class VOCDetectionTestCase(VOCSegmentationTestCase):
     DATASET_CLASS = datasets.VOCDetection
@@ -714,6 +740,10 @@ def test_annotations(self):

             assert object == info["annotation"]

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CocoDetection
@@ -784,6 +814,10 @@ def _create_json(self, root, name, content):
             json.dump(content, fh)
         return file

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class CocoCaptionsTestCase(CocoDetectionTestCase):
     DATASET_CLASS = datasets.CocoCaptions
@@ -800,6 +834,11 @@ def test_captions(self):
             _, captions = dataset[0]
             assert tuple(captions) == tuple(info["captions"])

+    def test_transforms_v2_wrapper_spawn(self):
+        # We need to define this method, because otherwise the test from the super class will
+        # be run
+        pytest.skip("CocoCaptions is currently not supported by the v2 wrapper.")
+

 class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.UCF101
@@ -966,6 +1005,10 @@ def inject_fake_data(self, tmpdir, config):
         )
         return num_videos_per_class * len(classes)

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset(output_format="TCHW") as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.HMDB51
@@ -1193,6 +1236,10 @@ def _create_segmentation(self, size):
     def _file_stem(self, idx):
         return f"2008_{idx:06d}"

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset(mode="segmentation") as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.FakeData
@@ -1642,6 +1689,10 @@ def inject_fake_data(self, tmpdir, config):

         return split_to_num_examples[config["train"]]

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.SVHN
@@ -2516,6 +2567,10 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
             breed_id = "-1"
         return (image_id, class_id, species, breed_id)

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.StanfordCars
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 26dde640788..325d864dc05 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -1,5 +1,6 @@
 import itertools
 import pathlib
+import pickle
 import random
 import warnings
@@ -169,8 +170,11 @@ class TestSmoke:
             next(make_vanilla_tensor_images()),
         ],
     )
+    @pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
     @pytest.mark.parametrize("device", cpu_and_cuda())
-    def test_common(self, transform, adapter, container_type, image_or_video, device):
+    def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
+        transform = de_serialize(transform)
+
         canvas_size = F.get_size(image_or_video)
         input = dict(
             image_or_video=image_or_video,
diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index dce6229e84b..14842c85c4b 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -2,6 +2,7 @@
 import decimal
 import inspect
 import math
+import pickle
 import re
 from pathlib import Path
 from unittest import mock
@@ -247,6 +248,8 @@ def _check_transform_v1_compatibility(transform, input):

 def check_transform(transform_cls, input, *args, **kwargs):
     transform = transform_cls(*args, **kwargs)

+    pickle.loads(pickle.dumps(transform))
+
     output = transform(input)
     assert isinstance(output, type(input))
diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py
index 2ed601fec21..07a3e0ff733 100644
--- a/torchvision/datapoints/_dataset_wrapper.py
+++ b/torchvision/datapoints/_dataset_wrapper.py
@@ -162,6 +162,7 @@ def __init__(self, dataset, target_keys):
             raise TypeError(msg)

         self._dataset = dataset
+        self._target_keys = target_keys
         self._wrapper = wrapper_factory(dataset, target_keys)

         # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
@@ -197,6 +198,9 @@ def __getitem__(self, idx):
     def __len__(self):
         return len(self._dataset)

+    def __reduce__(self):
+        return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)
+

 def raise_not_supported(description):
     raise RuntimeError(
diff --git a/torchvision/datasets/widerface.py b/torchvision/datasets/widerface.py
index b46c7982d8b..aa520455ef1 100644
--- a/torchvision/datasets/widerface.py
+++ b/torchvision/datasets/widerface.py
@@ -137,13 +137,13 @@ def parse_train_val_annotations_file(self) -> None:
                     {
                         "img_path": img_path,
                         "annotations": {
-                            "bbox": labels_tensor[:, 0:4],  # x, y, width, height
-                            "blur": labels_tensor[:, 4],
-                            "expression": labels_tensor[:, 5],
-                            "illumination": labels_tensor[:, 6],
-                            "occlusion": labels_tensor[:, 7],
-                            "pose": labels_tensor[:, 8],
-                            "invalid": labels_tensor[:, 9],
+                            "bbox": labels_tensor[:, 0:4].clone(),  # x, y, width, height
+                            "blur": labels_tensor[:, 4].clone(),
+                            "expression": labels_tensor[:, 5].clone(),
+                            "illumination": labels_tensor[:, 6].clone(),
+                            "occlusion": labels_tensor[:, 7].clone(),
+                            "pose": labels_tensor[:, 8].clone(),
+                            "invalid": labels_tensor[:, 9].clone(),
                         },
                     }
                 )
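
For context, the new __reduce__ hook above is what lets a wrapped dataset survive a pickle round trip, which spawn-based DataLoader workers depend on. A minimal sketch of that behavior follows; the choice of CIFAR10 and the "data/" root are illustrative assumptions and not part of this patch:

    import pickle

    from torchvision.datasets import CIFAR10, wrap_dataset_for_transforms_v2

    # Any dataset supported by the v2 wrapper should work here; CIFAR10 is only a stand-in.
    wrapped = wrap_dataset_for_transforms_v2(CIFAR10("data/", download=True))

    # The wrapper stores a callable built by wrapper_factory() that standard pickle cannot serialize,
    # so __reduce__ reconstructs the object through wrap_dataset_for_transforms_v2 instead.
    restored = pickle.loads(pickle.dumps(wrapped))
    assert len(restored) == len(wrapped)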