Add Video Llava #29733

Merged · 57 commits · May 15, 2024
Commits
dce6678
add model draft
zucchini-nlp Mar 19, 2024
72626df
update docstring
zucchini-nlp Mar 20, 2024
8cca731
add tests
zucchini-nlp Mar 20, 2024
4ea4f70
support image and video as input
zucchini-nlp Mar 20, 2024
c36819d
update for better handling of mixed input and clean-up a bit
zucchini-nlp Mar 21, 2024
c1a8fd5
bug when mixed inputs & add tests
zucchini-nlp Apr 8, 2024
c591c75
Update README.md
zucchini-nlp Apr 8, 2024
5ff8d18
Merge remote-tracking branch 'upstream/main' into video_llava
zucchini-nlp Apr 8, 2024
a6bc68d
link to abstract of paper in README
zucchini-nlp Apr 8, 2024
eb309ed
fix test
zucchini-nlp Apr 8, 2024
2f46f6c
fix-copies
zucchini-nlp Apr 8, 2024
6b51b7e
Merge branch 'main' into video_llava
zucchini-nlp Apr 8, 2024
e112958
make tests happy
zucchini-nlp Apr 8, 2024
5cb6163
skip docstest for now
zucchini-nlp Apr 10, 2024
930147d
do not run doctest for now
zucchini-nlp Apr 18, 2024
24ec2b3
Merge remote-tracking branch 'upstream/main' into video_llava
zucchini-nlp Apr 18, 2024
142bfc0
Update src/transformers/models/video_llava/processing_video_llava.py
zucchini-nlp Apr 22, 2024
fdec895
Update src/transformers/models/video_llava/image_processing_video_lla…
zucchini-nlp Apr 22, 2024
e83251c
Update src/transformers/models/video_llava/image_processing_video_lla…
zucchini-nlp Apr 22, 2024
4fcfe72
Update src/transformers/models/video_llava/image_processing_video_lla…
zucchini-nlp Apr 22, 2024
327030d
Update src/transformers/models/video_llava/image_processing_video_lla…
zucchini-nlp Apr 22, 2024
33289a5
Update tests/models/video_llava/test_modeling_video_llava.py
zucchini-nlp Apr 22, 2024
dfef75a
Update src/transformers/models/video_llava/image_processing_video_lla…
zucchini-nlp Apr 22, 2024
ebf1042
address review comments
zucchini-nlp Apr 22, 2024
aa1b278
failing tests
zucchini-nlp Apr 22, 2024
7802922
Fix vocab_size in common tests for VLMs
zucchini-nlp Apr 23, 2024
9fce414
codestyle
zucchini-nlp Apr 23, 2024
e8b4569
Merge branch 'huggingface:main' into video_llava
zucchini-nlp Apr 23, 2024
bb1cc26
Update src/transformers/models/video_llava/configuration_video_llava.py
zucchini-nlp Apr 29, 2024
e2e92b2
Update src/transformers/models/video_llava/configuration_video_llava.py
zucchini-nlp Apr 29, 2024
5c77fff
Update src/transformers/models/video_llava/modeling_video_llava.py
zucchini-nlp Apr 29, 2024
99518cb
Update src/transformers/models/video_llava/modeling_video_llava.py
zucchini-nlp Apr 29, 2024
451fd72
Update docs/source/en/model_doc/video_llava.md
zucchini-nlp Apr 30, 2024
95a9a01
Update docs/source/en/model_doc/video_llava.md
zucchini-nlp Apr 30, 2024
347fa8c
Update src/transformers/models/video_llava/image_processing_video_lla…
zucchini-nlp Apr 30, 2024
3e2f1b4
Update docs/source/en/model_doc/video_llava.md
zucchini-nlp Apr 30, 2024
3cd1222
Update src/transformers/models/video_llava/processing_video_llava.py
zucchini-nlp Apr 30, 2024
242703a
Update tests/models/video_llava/test_modeling_video_llava.py
zucchini-nlp Apr 30, 2024
9c1a10d
Update tests/models/video_llava/test_modeling_video_llava.py
zucchini-nlp Apr 30, 2024
b4145e1
Update tests/models/video_llava/test_modeling_video_llava.py
zucchini-nlp Apr 30, 2024
5803d5a
PR suggestions
zucchini-nlp Apr 30, 2024
975d959
fix-copies
zucchini-nlp Apr 30, 2024
7f30e3b
Merge branch 'main' into video_llava
zucchini-nlp Apr 30, 2024
6bdad81
Merge branch 'huggingface:main' into video_llava
zucchini-nlp May 1, 2024
a817f31
Update src/transformers/models/video_llava/configuration_video_llava.py
zucchini-nlp May 8, 2024
dba80e2
Update src/transformers/models/video_llava/configuration_video_llava.py
zucchini-nlp May 8, 2024
6b3eafb
Merge remote-tracking branch 'upstream/main' into video_llava
zucchini-nlp May 8, 2024
ba4e125
add full example in docs
zucchini-nlp May 8, 2024
6cc8af1
clean-up with new model-id
zucchini-nlp May 10, 2024
885a5ae
[run-slow] video_llava
zucchini-nlp May 10, 2024
377aafe
update docstring
zucchini-nlp May 10, 2024
637b197
Merge branch 'main' into video_llava
zucchini-nlp May 10, 2024
a411347
[run-slow] video_llava
zucchini-nlp May 10, 2024
0d83eaf
Merge branch 'huggingface:main' into video_llava
zucchini-nlp May 14, 2024
8134039
remove all achive maps
zucchini-nlp May 15, 2024
8e15514
fix some tests
zucchini-nlp May 15, 2024
5d1e976
test was supposed to be skipped for llava :)
zucchini-nlp May 15, 2024
1 change: 0 additions & 1 deletion README_fr.md
@@ -288,7 +288,6 @@ Suivez les pages d'installation de Flax, PyTorch ou TensorFlow pour voir comment

Nombre actuel de points de contrôle : ![](https://img.shields.io/endpoint?url=https://huggingface.co/api/shields/models&color=brightgreen)


🤗 Transformers fournit actuellement les architectures suivantes: consultez [ici](https://huggingface.co/docs/transformers/model_summary) pour un résumé global de chacune d'entre elles.

Pour vérifier si chaque modèle a une implémentation en Flax, PyTorch ou TensorFlow, ou s'il a un tokenizer associé pris en charge par la bibliothèque 🤗 Tokenizers, consultez [ce tableau](https://huggingface.co/docs/transformers/index#supported-frameworks).
1 change: 0 additions & 1 deletion README_te.md
@@ -293,7 +293,6 @@ Flax, PyTorch లేదా TensorFlow యొక్క ఇన్‌స్టా

🤗 ట్రాన్స్‌ఫార్మర్లు ప్రస్తుతం కింది ఆర్కిటెక్చర్‌లను అందజేస్తున్నాయి: వాటిలో ప్రతి ఒక్కటి ఉన్నత స్థాయి సారాంశం కోసం [ఇక్కడ](https://huggingface.co/docs/transformers/model_summary) చూడండి.


ఈ అమలులు అనేక డేటాసెట్‌లలో పరీక్షించబడ్డాయి (ఉదాహరణ స్క్రిప్ట్‌లను చూడండి) మరియు అసలైన అమలుల పనితీరుతో సరిపోలాలి. మీరు [డాక్యుమెంటేషన్](https://github.com/huggingface/transformers/tree/main/examples) యొక్క ఉదాహరణల విభాగంలో పనితీరుపై మరిన్ని వివరాలను కనుగొనవచ్చు.

## ఇంకా నేర్చుకో
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -802,6 +802,8 @@
title: TVP
- local: model_doc/udop
title: UDOP
- local: model_doc/video_llava
title: VideoLlava
- local: model_doc/vilt
title: ViLT
- local: model_doc/vipllava
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -302,6 +302,7 @@ Flax), PyTorch, and/or TensorFlow.
| [UnivNet](model_doc/univnet) | ✅ | ❌ | ❌ |
| [UPerNet](model_doc/upernet) | ✅ | ❌ | ❌ |
| [VAN](model_doc/van) | ✅ | ❌ | ❌ |
| [VideoLlava](model_doc/video_llava) | ✅ | ❌ | ❌ |
| [VideoMAE](model_doc/videomae) | ✅ | ❌ | ❌ |
| [ViLT](model_doc/vilt) | ✅ | ❌ | ❌ |
| [VipLlava](model_doc/vipllava) | ✅ | ❌ | ❌ |
129 changes: 129 additions & 0 deletions docs/source/en/model_doc/video_llava.md
@@ -0,0 +1,129 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Video-LLaVA

## Overview

Video-LLaVA is an open-source multimodal LLM trained by fine-tuning LLaMA/Vicuna on multimodal instruction-following data generated by LLaVA-1.5 and VideoChat. It is an auto-regressive language model based on the transformer architecture. Video-LLaVA unifies visual representations into the language feature space, enabling an LLM to perform visual reasoning on both images and videos simultaneously.


The Video-LLaVA model was proposed in [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/abs/2311.10122) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munan Ning, Peng Jin, Li Yuan.

The abstract from the paper is the following:

*The Large Vision-Language Model (LVLM) has enhanced the performance of various downstream tasks in visual-language understanding. Most existing approaches encode images and videos into separate feature spaces, which are then fed as inputs to large language models. However, due to the lack of unified tokenization for images and videos, namely misalignment before projection, it becomes challenging for a Large Language Model (LLM) to learn multi-modal interactions from several poor projection layers. In this work, we unify visual representation into the language feature space to advance the foundational LLM towards a unified LVLM. As a result, we establish a simple but robust LVLM baseline, Video-LLaVA, which learns from a mixed dataset of images and videos, mutually enhancing each other. Video-LLaVA achieves superior performances on a broad range of 9 image benchmarks across 5 image question-answering datasets and 4 image benchmark toolkits. Additionally, our Video-LLaVA also outperforms Video-ChatGPT by 5.8%, 9.9%, 18.6%, and 10.1% on MSRVTT, MSVD, TGIF, and ActivityNet, respectively. Notably, extensive experiments demonstrate that Video-LLaVA mutually benefits images and videos within a unified visual representation, outperforming models designed specifically for images or videos. We aim for this work to provide modest insights into the multi-modal inputs for the LLM.*

Tips:

- We advise users to use `padding_side="left"` for batched generation, as it leads to more accurate results. Simply make sure to set `processor.tokenizer.padding_side = "left"` before generating; a batched-generation sketch is included after the examples below.

- Note that the model has not been explicitly trained to process multiple images/videos in the same prompt; although this is technically possible, you may experience inaccurate results.

- For better results, we recommend prompting the model with the correct prompt format, as in the example below:


```python
import av
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import VideoLlavaForConditionalGeneration, VideoLlavaProcessor

def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


model = VideoLlavaForConditionalGeneration.from_pretrained("RaushanTurganbay/video-llava-7b-hf", device_map="auto")
processor = VideoLlavaProcessor.from_pretrained("RaushanTurganbay/video-llava-7b-hf")

video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")

container = av.open(video_path)
total_frames = container.streams.video[0].frames
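# sample 8 frames spread evenly across the clip, since the model expects 8-frame video inputs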
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
video = read_video_pyav(container, indices)

prompt = "USER: <video>Why is this funny? ASSISTANT:"
inputs = processor(text=prompt, videos=video, return_tensors="pt")

out = model.generate(**inputs, max_new_tokens=40)
print(processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True))
```

For multi-turn conversations, change the prompt format to:

```bash
"USER: <video>What do you see in this video? ASSISTANT: A baby reading a book. USER: Why is the it funny? ASSISTANT:"
```

- Note that video inputs should contain exactly 8 frames, since the models were trained in that setting.
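
Following the padding tip above, here is a minimal sketch of batched generation with left padding. It is an illustration rather than part of the official documentation: it reuses the `model`, `processor`, and `video` objects from the example above and passes the same clip twice only to keep the snippet self-contained.

```python
# Minimal sketch: batched generation with left padding (see the first tip above).
processor.tokenizer.padding_side = "left"

prompts = [
    "USER: <video>Why is this funny? ASSISTANT:",
    "USER: <video>Describe the video in one sentence. ASSISTANT:",
]
# The same 8-frame clip is reused for both prompts purely for illustration; in practice
# each prompt would typically get its own clip.
inputs = processor(text=prompts, videos=[video, video], padding=True, return_tensors="pt")

out = model.generate(**inputs, max_new_tokens=40)
print(processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True))
```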



This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
The original code can be found [here](https://github.com/PKU-YuanGroup/Video-LLaVA).
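
Since the checkpoint handles images as well as videos, an image prompt and a video prompt can in principle be mixed in a single batch. The sketch below is a hedged illustration, not an official example: it assumes `VideoLlavaProcessor` accepts an `images=` argument and an `<image>` placeholder token (consistent with the "support image and video as input" commit above), and it reuses `model`, `processor`, and `video` from the earlier example.

```python
# Hedged sketch: one image prompt and one video prompt in a single padded batch.
import requests
from PIL import Image

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

prompts = [
    "USER: <image>What is shown in this image? ASSISTANT:",
    "USER: <video>Why is this funny? ASSISTANT:",
]
inputs = processor(text=prompts, images=image, videos=video, padding=True, return_tensors="pt")

out = model.generate(**inputs, max_new_tokens=40)
print(processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True))
```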


## VideoLlavaConfig

[[autodoc]] VideoLlavaConfig

## VideoLlavaImageProcessor

[[autodoc]] VideoLlavaImageProcessor

## VideoLlavaProcessor

[[autodoc]] VideoLlavaProcessor

## VideoLlavaForConditionalGeneration

[[autodoc]] VideoLlavaForConditionalGeneration
- forward
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -55,6 +55,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
16 changes: 16 additions & 0 deletions src/transformers/__init__.py
@@ -726,6 +726,7 @@
"UnivNetFeatureExtractor",
],
"models.upernet": ["UperNetConfig"],
"models.video_llava": ["VideoLlavaConfig"],
"models.videomae": ["VideoMAEConfig"],
"models.vilt": [
"ViltConfig",
@@ -3215,6 +3216,14 @@
"UperNetPreTrainedModel",
]
)
_import_structure["models.video_llava"].extend(
[
"VideoLlavaForConditionalGeneration",
"VideoLlavaImageProcessor",
"VideoLlavaPreTrainedModel",
"VideoLlavaProcessor",
]
)
_import_structure["models.videomae"].extend(
[
"VideoMAEForPreTraining",
@@ -5281,6 +5290,7 @@
UnivNetFeatureExtractor,
)
from .models.upernet import UperNetConfig
from .models.video_llava import VideoLlavaConfig
from .models.videomae import VideoMAEConfig
from .models.vilt import (
ViltConfig,
@@ -7384,6 +7394,12 @@
UperNetForSemanticSegmentation,
UperNetPreTrainedModel,
)
from .models.video_llava import (
VideoLlavaForConditionalGeneration,
VideoLlavaImageProcessor,
VideoLlavaPreTrainedModel,
VideoLlavaProcessor,
)
from .models.videomae import (
VideoMAEForPreTraining,
VideoMAEForVideoClassification,
3 changes: 3 additions & 0 deletions src/transformers/image_utils.py
@@ -65,6 +65,9 @@
] # noqa


VideoInput = Union[np.ndarray, "torch.Tensor", List[np.ndarray], List["torch.Tensor"]] # noqa


class ChannelDimension(ExplicitEnum):
FIRST = "channels_first"
LAST = "channels_last"
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -240,6 +240,7 @@
unispeech_sat,
univnet,
upernet,
video_llava,
videomae,
vilt,
vipllava,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -253,6 +253,7 @@
("univnet", "UnivNetConfig"),
("upernet", "UperNetConfig"),
("van", "VanConfig"),
("video_llava", "VideoLlavaConfig"),
("videomae", "VideoMAEConfig"),
("vilt", "ViltConfig"),
("vipllava", "VipLlavaConfig"),
@@ -538,6 +539,7 @@
("univnet", "UnivNet"),
("upernet", "UPerNet"),
("van", "VAN"),
("video_llava", "VideoLlava"),
("videomae", "VideoMAE"),
("vilt", "ViLT"),
("vipllava", "VipLlava"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
@@ -115,6 +115,7 @@
("udop", "LayoutLMv3ImageProcessor"),
("upernet", "SegformerImageProcessor"),
("van", "ConvNextImageProcessor"),
("video_llava", "VideoLlavaImageProcessor"),
("videomae", "VideoMAEImageProcessor"),
("vilt", "ViltImageProcessor"),
("vipllava", "CLIPImageProcessor"),
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -326,6 +326,7 @@
("tvlt", "TvltForPreTraining"),
("unispeech", "UniSpeechForPreTraining"),
("unispeech-sat", "UniSpeechSatForPreTraining"),
("video_llava", "VideoLlavaForConditionalGeneration"),
("videomae", "VideoMAEForPreTraining"),
("vipllava", "VipLlavaForConditionalGeneration"),
("visual_bert", "VisualBertForPreTraining"),
@@ -696,6 +697,7 @@
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"),
("video_llava", "VideoLlavaForConditionalGeneration"),
("vipllava", "VipLlavaForConditionalGeneration"),
("vision-encoder-decoder", "VisionEncoderDecoderModel"),
]
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
@@ -89,6 +89,7 @@
("tvp", "TvpProcessor"),
("unispeech", "Wav2Vec2Processor"),
("unispeech-sat", "Wav2Vec2Processor"),
("video_llava", "VideoLlavaProcessor"),
("vilt", "ViltProcessor"),
("vipllava", "LlavaProcessor"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -462,6 +462,7 @@
"T5TokenizerFast" if is_tokenizers_available() else None,
),
),
("video_llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
71 changes: 71 additions & 0 deletions src/transformers/models/video_llava/__init__.py
@@ -0,0 +1,71 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


_import_structure = {
    "configuration_video_llava": ["VideoLlavaConfig"],
    "processing_video_llava": ["VideoLlavaProcessor"],
}

try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_video_llava"] = ["VideoLlavaImageProcessor"]

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_video_llava"] = [
        "VideoLlavaPreTrainedModel",
        "VideoLlavaForConditionalGeneration",
    ]

if TYPE_CHECKING:
    from .configuration_video_llava import (
        VideoLlavaConfig,
    )
    from .processing_video_llava import VideoLlavaProcessor

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_video_llava import VideoLlavaImageProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_video_llava import (
            VideoLlavaForConditionalGeneration,
            VideoLlavaPreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)