From f514ab64baae4862488e068c38829286ae9296fc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 24 Aug 2023 17:24:30 +0200 Subject: [PATCH] test output_format in video datasets (#7879) --- test/datasets_utils.py | 22 +++++++++++++++++----- torchvision/datasets/video_utils.py | 6 +++--- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 8afc6ddb369..f7a1b8dd3de 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -662,27 +662,39 @@ class VideoDatasetTestCase(DatasetTestCase): FEATURE_TYPES = (torch.Tensor, torch.Tensor, int) REQUIRED_PACKAGES = ("av",) - DEFAULT_FRAMES_PER_CLIP = 1 + FRAMES_PER_CLIP = 1 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.dataset_args = self._set_default_frames_per_clip(self.dataset_args) - def _set_default_frames_per_clip(self, inject_fake_data): + def _set_default_frames_per_clip(self, dataset_args): argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__) args_without_default = argspec.args[1 : (-len(argspec.defaults) if argspec.defaults else None)] frames_per_clip_last = args_without_default[-1] == "frames_per_clip" - @functools.wraps(inject_fake_data) + @functools.wraps(dataset_args) def wrapper(tmpdir, config): - args = inject_fake_data(tmpdir, config) + args = dataset_args(tmpdir, config) if frames_per_clip_last and len(args) == len(args_without_default) - 1: - args = (*args, self.DEFAULT_FRAMES_PER_CLIP) + args = (*args, self.FRAMES_PER_CLIP) return args return wrapper + def test_output_format(self): + for output_format in ["TCHW", "THWC"]: + with self.create_dataset(output_format=output_format) as (dataset, _): + for video, *_ in dataset: + if output_format == "TCHW": + num_frames, num_channels, *_ = video.shape + else: # output_format == "THWC": + num_frames, *_, num_channels = video.shape + + assert num_frames == self.FRAMES_PER_CLIP + assert num_channels == 3 + @test_all_configs def test_transforms_v2_wrapper(self, config): # `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py index bb1974b7a4f..df55518de37 100644 --- a/torchvision/datasets/video_utils.py +++ b/torchvision/datasets/video_utils.py @@ -187,9 +187,9 @@ def subset(self, indices: List[int]) -> "VideoClips": } return type(self)( video_paths, - self.num_frames, - self.step, - self.frame_rate, + clip_length_in_frames=self.num_frames, + frames_between_clips=self.step, + frame_rate=self.frame_rate, _precomputed_metadata=metadata, num_workers=self.num_workers, _video_width=self._video_width,