enforce pickleability for v2 transforms and wrapped datasets #7860

Merged (10 commits) on Aug 24, 2023
28 changes: 27 additions & 1 deletion test/datasets_utils.py
@@ -5,6 +5,7 @@
import itertools
import os
import pathlib
import platform
import random
import shutil
import string
@@ -548,7 +549,7 @@ def test_feature_types(self, config):
@test_all_configs
def test_num_examples(self, config):
with self.create_dataset(config) as (dataset, info):
assert len(dataset) == info["num_examples"]
assert len(list(dataset)) == len(dataset) == info["num_examples"]
Contributor Author:
We never actually consumed the dataset before. Thus, any failure that happens for a sample other than the first went undetected. Fortunately, only one test was broken, which I'll flag below.
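A hypothetical illustration of the failure mode (Flaky is a made-up stand-in, not a torchvision dataset): len() never touches a sample, while list() iterates via __getitem__ and surfaces the error.

class Flaky:
    def __len__(self):
        return 3

    def __getitem__(self, idx):
        if idx > 0:
            raise RuntimeError("broken sample")
        return idx

dataset = Flaky()
assert len(dataset) == 3  # passes without loading a single sample
list(dataset)  # iterates through __getitem__ and raises on the second sample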


@test_all_configs
def test_transforms(self, config):
@@ -692,6 +693,31 @@ def test_transforms_v2_wrapper(self, config):
super().test_transforms_v2_wrapper.__wrapped__(self, config)


def _no_collate(batch):
return batch


def check_transforms_v2_wrapper_spawn(dataset):
# On Linux, the DataLoader forks the main process by default. Forking is not available on macOS or Windows, so new
# subprocesses are spawned there instead. Spawning requires the whole pipeline, including the dataset, to be
# pickleable, which is what we are enforcing here.
if platform.system() != "Darwin":
pytest.skip("Multiprocessing spawning is only checked on macOS.")

from torch.utils.data import DataLoader
from torchvision import datapoints
from torchvision.datasets import wrap_dataset_for_transforms_v2

wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)

dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)

for wrapped_sample in dataloader:
assert tree_any(
lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
)


def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
r"""Create a random uint8 tensor.

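As a sketch of the failure mode this helper guards against (BadDataset is hypothetical): spawn-based workers receive the dataset via pickle, so a single unpicklable attribute, such as a locally defined lambda transform, breaks the pipeline before the first sample is loaded.

import pickle

class BadDataset:
    def __init__(self):
        self.transform = lambda x: x  # locally defined functions are not picklable

try:
    pickle.dumps(BadDataset())
except (pickle.PicklingError, AttributeError) as exc:
    print(f"a spawn worker would fail with: {exc}")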
57 changes: 56 additions & 1 deletion test/test_datasets.py
@@ -183,14 +183,18 @@ def test_combined_targets(self):
), "Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(target_type="category") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)


class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Caltech256

def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories"

categories = ((1, "ak47"), (127, "laptop-101"), (257, "clutter"))
categories = ((1, "ak47"), (2, "american-flag"), (3, "backpack"))
Contributor Author:
datasets.Caltech256 relies on the fact that all categories are present. When actually consuming the dataset (see above), the old fake data setup falls flat. Our options are:

  1. Fix the dataset to account for gaps in the categories.
  2. Create all categories as fake data.
  3. Create fake data that starts at the first category without any gaps, but does not cover all categories.

Option 3 is by far the least amount of work, so I went for that here.

num_images_per_category = 2

for idx, category in categories:
@@ -258,6 +262,10 @@ def inject_fake_data(self, tmpdir, config):

return split_to_num_examples[config["split"]]

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)


class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Cityscapes
@@ -382,6 +390,11 @@ def test_feature_types_target_polygon(self):
assert isinstance(polygon_img, PIL.Image.Image)
(polygon_target, info["expected_polygon_target"])

def test_transforms_v2_wrapper_spawn(self):
for target_type in ["instance", "semantic", ["instance", "semantic"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)


class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.ImageNet
@@ -413,6 +426,10 @@ def inject_fake_data(self, tmpdir, config):
torch.save((wnid_to_classes, None), tmpdir / "meta.bin")
return num_examples

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)


class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CIFAR10
@@ -607,6 +624,11 @@ def test_images_names_split(self):

assert merged_imgs_names == all_imgs_names

def test_transforms_v2_wrapper_spawn(self):
for target_type in ["identity", "bbox", ["identity", "bbox"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)


class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.VOCSegmentation
@@ -694,6 +716,10 @@ def add_bndbox(obj, bndbox=None):

return data

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)


class VOCDetectionTestCase(VOCSegmentationTestCase):
DATASET_CLASS = datasets.VOCDetection
@@ -714,6 +740,10 @@ def test_annotations(self):

assert object == info["annotation"]

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)


class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CocoDetection
@@ -784,6 +814,10 @@ def _create_json(self, root, name, content):
json.dump(content, fh)
return file

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)


class CocoCaptionsTestCase(CocoDetectionTestCase):
DATASET_CLASS = datasets.CocoCaptions
@@ -800,6 +834,11 @@ def test_captions(self):
_, captions = dataset[0]
assert tuple(captions) == tuple(info["captions"])

def test_transforms_v2_wrapper_spawn(self):
# We need to define this method; otherwise the test from the superclass
# would run.
pytest.skip("CocoCaptions is currently not supported by the v2 wrapper.")


class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.UCF101
@@ -966,6 +1005,10 @@ def inject_fake_data(self, tmpdir, config):
)
return num_videos_per_class * len(classes)

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(output_format="TCHW") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)


class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.HMDB51
@@ -1193,6 +1236,10 @@ def _create_segmentation(self, size):
def _file_stem(self, idx):
return f"2008_{idx:06d}"

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(mode="segmentation") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)


class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FakeData
@@ -1642,6 +1689,10 @@ def inject_fake_data(self, tmpdir, config):

return split_to_num_examples[config["train"]]

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)


class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SVHN
@@ -2516,6 +2567,10 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
breed_id = "-1"
return (image_id, class_id, species, breed_id)

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)


class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.StanfordCars
6 changes: 5 additions & 1 deletion test/test_transforms_v2.py
@@ -1,5 +1,6 @@
import itertools
import pathlib
import pickle
import random
import warnings

@@ -169,8 +170,11 @@ class TestSmoke:
next(make_vanilla_tensor_images()),
],
)
@pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_common(self, transform, adapter, container_type, image_or_video, device):
def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
transform = de_serialize(transform)

canvas_size = F.get_size(image_or_video)
input = dict(
image_or_video=image_or_video,
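The de_serialize parametrization runs every smoke test twice, once on the transform as constructed and once on a pickle round-tripped copy, so any transform that loses state under pickling fails the same assertions. A minimal sketch of the round-trip, assuming any v2 transform:

import pickle
from torchvision.transforms import v2

transform = v2.RandomHorizontalFlip(p=1.0)
restored = pickle.loads(pickle.dumps(transform))  # must behave identically to transform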
3 changes: 3 additions & 0 deletions test/test_transforms_v2_refactored.py
@@ -2,6 +2,7 @@
import decimal
import inspect
import math
import pickle
import re
from pathlib import Path
from unittest import mock
@@ -247,6 +248,8 @@ def _check_transform_v1_compatibility(transform, input):
def check_transform(transform_cls, input, *args, **kwargs):
transform = transform_cls(*args, **kwargs)

pickle.loads(pickle.dumps(transform))

output = transform(input)
assert isinstance(output, type(input))

4 changes: 4 additions & 0 deletions torchvision/datapoints/_dataset_wrapper.py
@@ -162,6 +162,7 @@ def __init__(self, dataset, target_keys):
raise TypeError(msg)

self._dataset = dataset
self._target_keys = target_keys
self._wrapper = wrapper_factory(dataset, target_keys)

# We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
@@ -197,6 +198,9 @@ def __getitem__(self, idx):
def __len__(self):
return len(self._dataset)

def __reduce__(self):
return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)


def raise_not_supported(description):
raise RuntimeError(
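A minimal sketch of the __reduce__ pattern used above (make_wrapper and Wrapper are hypothetical stand-ins for wrap_dataset_for_transforms_v2 and the wrapper class): instead of pickling the wrapper's state, which includes a closure that may not be picklable, pickling records the factory and its arguments, and unpickling re-runs the factory.

import pickle

def make_wrapper(dataset):
    return Wrapper(dataset)

class Wrapper:
    def __init__(self, dataset):
        self._dataset = dataset
        self._wrapper = lambda sample: sample  # stand-in for the unpicklable closure

    def __reduce__(self):
        # Recreate through the factory instead of pickling __dict__.
        return make_wrapper, (self._dataset,)

restored = pickle.loads(pickle.dumps(make_wrapper([1, 2, 3])))
assert restored._dataset == [1, 2, 3]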
14 changes: 7 additions & 7 deletions torchvision/datasets/widerface.py
@@ -137,13 +137,13 @@ def parse_train_val_annotations_file(self) -> None:
{
"img_path": img_path,
"annotations": {
"bbox": labels_tensor[:, 0:4], # x, y, width, height
"blur": labels_tensor[:, 4],
"expression": labels_tensor[:, 5],
"illumination": labels_tensor[:, 6],
"occlusion": labels_tensor[:, 7],
"pose": labels_tensor[:, 8],
"invalid": labels_tensor[:, 9],
"bbox": labels_tensor[:, 0:4].clone(), # x, y, width, height
Contributor Author:
Views on tensors cannot be pickled correctly. This means that, regardless of the v2 wrapper, datasets.WIDERFace has never worked in a spawn context; the sketch after this hunk illustrates the difference between a view and its clone under pickling.

"blur": labels_tensor[:, 4].clone(),
"expression": labels_tensor[:, 5].clone(),
"illumination": labels_tensor[:, 6].clone(),
"occlusion": labels_tensor[:, 7].clone(),
"pose": labels_tensor[:, 8].clone(),
"invalid": labels_tensor[:, 9].clone(),
},
}
)
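A sketch of why the clone() calls help, assuming plain pickle semantics (PyTorch serializes a view together with its entire underlying storage; exact sizes are illustrative):

import pickle
import torch

labels_tensor = torch.zeros(1000, 10)
view = labels_tensor[:, 4]           # shares storage with the full 1000x10 tensor
clone = labels_tensor[:, 4].clone()  # owns a contiguous copy of its 1000 elements

print(len(pickle.dumps(view)))   # on the order of the full storage
print(len(pickle.dumps(clone)))  # on the order of 1000 elements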