From ca50d2cb2ed733d22d04a502a93cd45fd5074c59 Mon Sep 17 00:00:00 2001 From: Merve Noyan Date: Tue, 23 Jul 2024 13:23:23 +0300 Subject: [PATCH] Fix video batching to videollava (#32139) --------- Co-authored-by: Merve Noyan --- .../image_processing_video_llava.py | 7 ++- .../test_image_processing_video_llava.py | 43 ++++++++++++++----- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/video_llava/image_processing_video_llava.py b/src/transformers/models/video_llava/image_processing_video_llava.py index 82ac5869c01740..943c2fe51a0ef4 100644 --- a/src/transformers/models/video_llava/image_processing_video_llava.py +++ b/src/transformers/models/video_llava/image_processing_video_llava.py @@ -55,8 +55,11 @@ def make_batched_videos(videos) -> List[VideoInput]: if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): return videos - elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]) and len(videos[0].shape) == 4: - return [list(video) for video in videos] + elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): + if isinstance(videos[0], PIL.Image.Image): + return [videos] + elif len(videos[0].shape) == 4: + return [list(video) for video in videos] elif is_valid_image(videos) and len(videos.shape) == 4: return [list(videos)] diff --git a/tests/models/video_llava/test_image_processing_video_llava.py b/tests/models/video_llava/test_image_processing_video_llava.py index 4a5c2516267e13..03cfb033ffb91f 100644 --- a/tests/models/video_llava/test_image_processing_video_llava.py +++ b/tests/models/video_llava/test_image_processing_video_llava.py @@ -97,8 +97,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F torchify=torchify, ) - def prepare_video_inputs(self, equal_resolution=False, torchify=False): - numpify = not torchify + def prepare_video_inputs(self, equal_resolution=False, numpify=False, torchify=False): images = prepare_image_inputs( batch_size=self.batch_size, num_channels=self.num_channels, @@ -108,15 +107,19 @@ def prepare_video_inputs(self, equal_resolution=False, torchify=False): numpify=numpify, torchify=torchify, ) - # let's simply copy the frames to fake a long video-clip - videos = [] - for image in images: - if numpify: - video = image[None, ...].repeat(8, 0) - else: - video = image[None, ...].repeat(8, 1, 1, 1) - videos.append(video) + if numpify or torchify: + videos = [] + for image in images: + if numpify: + video = image[None, ...].repeat(8, 0) + else: + video = image[None, ...].repeat(8, 1, 1, 1) + videos.append(video) + else: + videos = [] + for pil_image in images: + videos.append([pil_image] * 8) return videos @@ -197,7 +200,7 @@ def test_call_numpy_videos(self): # Initialize image_processing image_processing = self.image_processing_class(**self.image_processor_dict) # create random numpy tensors - video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True) + video_inputs = self.image_processor_tester.prepare_video_inputs(numpify=True, equal_resolution=True) for video in video_inputs: self.assertIsInstance(video, np.ndarray) @@ -211,6 +214,24 @@ def test_call_numpy_videos(self): expected_output_video_shape = (5, 8, 3, 18, 18) self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape) + def test_call_pil_videos(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # the inputs come in list of lists batched format + video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True) + for video in video_inputs: + self.assertIsInstance(video[0], Image.Image) + + # Test not batched input + encoded_videos = image_processing(images=None, videos=video_inputs[0], return_tensors="pt").pixel_values_videos + expected_output_video_shape = (1, 8, 3, 18, 18) + self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape) + + # Test batched + encoded_videos = image_processing(images=None, videos=video_inputs, return_tensors="pt").pixel_values_videos + expected_output_video_shape = (5, 8, 3, 18, 18) + self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape) + def test_call_pytorch(self): # Initialize image_processing image_processing = self.image_processing_class(**self.image_processor_dict)