Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove videos from test for DatasetFolder #7216

Merged
merged 1 commit into from
Feb 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 4 additions & 21 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,27 +1528,16 @@ def test_split(self, config):
class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DatasetFolder

# The dataset has no fixed return type since it is defined by the loader parameter. For testing, we use a loader
# that simply returns the path as type 'str' instead of loading anything. See the 'dataset_args()' method.
FEATURE_TYPES = (str, int)

_IMAGE_EXTENSIONS = ("jpg", "png")
_VIDEO_EXTENSIONS = ("avi", "mp4")
_EXTENSIONS = (*_IMAGE_EXTENSIONS, *_VIDEO_EXTENSIONS)
_EXTENSIONS = ("jpg", "png")

# DatasetFolder has two mutually exclusive parameters: 'extensions' and 'is_valid_file'. One of both is required.
# We only iterate over different 'extensions' here and handle the tests for 'is_valid_file' in the
# 'test_is_valid_file()' method.
DEFAULT_CONFIG = dict(extensions=_EXTENSIONS)
ADDITIONAL_CONFIGS = (
*datasets_utils.combinations_grid(extensions=[(ext,) for ext in _IMAGE_EXTENSIONS]),
dict(extensions=_IMAGE_EXTENSIONS),
*datasets_utils.combinations_grid(extensions=[(ext,) for ext in _VIDEO_EXTENSIONS]),
dict(extensions=_VIDEO_EXTENSIONS),
)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(extensions=[(ext,) for ext in _EXTENSIONS])

def dataset_args(self, tmpdir, config):
return tmpdir, lambda x: x
return tmpdir, datasets.folder.pil_loader

def inject_fake_data(self, tmpdir, config):
extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"])
Expand All @@ -1559,14 +1548,8 @@ def inject_fake_data(self, tmpdir, config):
if ext not in extensions:
continue

create_example_folder = (
datasets_utils.create_image_folder
if ext in self._IMAGE_EXTENSIONS
else datasets_utils.create_video_folder
)

num_examples = torch.randint(1, 3, size=()).item()
create_example_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples)
datasets_utils.create_image_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples)

num_examples_total += num_examples
classes.append(cls)
Expand Down