Skip to content

Commit

Permalink
Fix video batching to videollava (#32139)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Merve Noyan <mervenoyan@Merve-MacBook-Pro.local>
  • Loading branch information
2 people authored and itazap committed Jul 25, 2024
1 parent 72a6d4b commit 4352e10
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,11 @@ def make_batched_videos(videos) -> List[VideoInput]:
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
return videos

elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]) and len(videos[0].shape) == 4:
return [list(video) for video in videos]
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
if isinstance(videos[0], PIL.Image.Image):
return [videos]
elif len(videos[0].shape) == 4:
return [list(video) for video in videos]

elif is_valid_image(videos) and len(videos.shape) == 4:
return [list(videos)]
Expand Down
43 changes: 32 additions & 11 deletions tests/models/video_llava/test_image_processing_video_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F
torchify=torchify,
)

def prepare_video_inputs(self, equal_resolution=False, torchify=False):
numpify = not torchify
def prepare_video_inputs(self, equal_resolution=False, numpify=False, torchify=False):
images = prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
Expand All @@ -108,15 +107,19 @@ def prepare_video_inputs(self, equal_resolution=False, torchify=False):
numpify=numpify,
torchify=torchify,
)

# let's simply copy the frames to fake a long video-clip
videos = []
for image in images:
if numpify:
video = image[None, ...].repeat(8, 0)
else:
video = image[None, ...].repeat(8, 1, 1, 1)
videos.append(video)
if numpify or torchify:
videos = []
for image in images:
if numpify:
video = image[None, ...].repeat(8, 0)
else:
video = image[None, ...].repeat(8, 1, 1, 1)
videos.append(video)
else:
videos = []
for pil_image in images:
videos.append([pil_image] * 8)

return videos

Expand Down Expand Up @@ -197,7 +200,7 @@ def test_call_numpy_videos(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random numpy tensors
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True)
video_inputs = self.image_processor_tester.prepare_video_inputs(numpify=True, equal_resolution=True)
for video in video_inputs:
self.assertIsInstance(video, np.ndarray)

Expand All @@ -211,6 +214,24 @@ def test_call_numpy_videos(self):
expected_output_video_shape = (5, 8, 3, 18, 18)
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)

def test_call_pil_videos(self):
    """Check `pixel_values_videos` shapes when videos are given as PIL frames.

    Runs the image processor once on a single video (a flat list of PIL
    frames) and once on the whole batch (a list of such lists), comparing
    the resulting tensor shape with the expected one each time.
    """
    processor = self.image_processing_class(**self.image_processor_dict)

    # Default flags (no numpify/torchify): every video is a list of PIL frames.
    pil_videos = self.image_processor_tester.prepare_video_inputs(equal_resolution=True)
    for frames in pil_videos:
        self.assertIsInstance(frames[0], Image.Image)

    # (input, expected shape) pairs, evaluated in order:
    # first the un-batched single video, then the batched list of videos.
    cases = (
        (pil_videos[0], (1, 8, 3, 18, 18)),
        (pil_videos, (5, 8, 3, 18, 18)),
    )
    for videos, expected_output_video_shape in cases:
        encoded_videos = processor(images=None, videos=videos, return_tensors="pt").pixel_values_videos
        self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)

def test_call_pytorch(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
Expand Down

0 comments on commit 4352e10

Please sign in to comment.