VideoLLaVa: fix chat format in docs (huggingface#32083)
fix chat format
zucchini-nlp authored and amyeroberts committed Jul 19, 2024
1 parent ac2cb9d commit 33a6cbf
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions docs/source/en/model_doc/video_llava.md
@@ -98,7 +98,7 @@ indices = np.arange(0, total_frames, total_frames / 8).astype(int)
 video = read_video_pyav(container, indices)

 # For better results, we recommend prompting the model in the following format
-prompt = "USER: <video>Why is this funny? ASSISTANT:"
+prompt = "USER: <video>\nWhy is this funny? ASSISTANT:"
 inputs = processor(text=prompt, videos=video, return_tensors="pt")

 out = model.generate(**inputs, max_new_tokens=60)
@@ -108,7 +108,7 @@ processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
 For multi-turn conversations, change the prompt format to:

 ```bash
-"USER: <video>What do you see in this video? ASSISTANT: A baby reading a book. USER: Why is it funny? ASSISTANT:"
+"USER: <video>\nWhat do you see in this video? ASSISTANT: A baby reading a book. USER: Why is it funny? ASSISTANT:"
 ```

 ### Mixed Media Mode
@@ -123,7 +123,7 @@ import requests
 # Load an image and write a new prompt
 url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 image = Image.open(requests.get(url, stream=True).raw)
-prompt = "USER: <image> How many cats are there in the image? ASSISTANT: There are two cats. USER: <video>Why is this video funny? ASSISTANT:"
+prompt = "USER: <image>\nHow many cats are there in the image? ASSISTANT: There are two cats. USER: <video>\nWhy is this video funny? ASSISTANT:"

 inputs = processor(text=prompt, images=image, videos=clip, padding=True, return_tensors="pt")
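
For context, the corrected single-turn example reads end to end roughly as follows. This is a sketch assembled from the context lines above; the `read_video_pyav` helper is reconstructed to match its use here, so details may differ from the full doc page.

```python
import av
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import VideoLlavaForConditionalGeneration, VideoLlavaProcessor

def read_video_pyav(container, indices):
    # Decode only the frames at `indices` and stack them into a
    # (num_frames, height, width, 3) uint8 array. Reconstructed from the
    # diff context above; the doc page defines its own version.
    frames = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > indices[-1]:
            break
        if i in indices:
            frames.append(frame)
    return np.stack([f.to_ndarray(format="rgb24") for f in frames])

model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")

video_path = hf_hub_download(
    repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
container = av.open(video_path)
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 8).astype(int)  # sample 8 frames
video = read_video_pyav(container, indices)

# The "\n" after <video> is the point of this commit: it matches the chat
# format the model was trained with.
prompt = "USER: <video>\nWhy is this funny? ASSISTANT:"
inputs = processor(text=prompt, videos=video, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=60)
print(processor.batch_decode(out, skip_special_tokens=True)[0])
```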

6 changes: 3 additions & 3 deletions src/transformers/models/video_llava/modeling_video_llava.py
@@ -456,7 +456,7 @@ def forward(
 >>> model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
 >>> processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
->>> prompt = "USER: <video>Why is this video funny? ASSISTANT:"
+>>> prompt = "USER: <video>\nWhy is this video funny? ASSISTANT:"
 >>> video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
 >>> container = av.open(video_path)
@@ -476,8 +476,8 @@ def forward(
 >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw)
 >>> prompt = [
-... "USER: <image> How many cats do you see? ASSISTANT:",
-... "USER: <video>Why is this video funny? ASSISTANT:"
+... "USER: <image>\nHow many cats do you see? ASSISTANT:",
+... "USER: <video>\nWhy is this video funny? ASSISTANT:"
 ... ]
 >>> inputs = processor(text=prompt, images=image, videos=clip, padding=True, return_tensors="pt")
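
The docstring hunk ends at the processor call. A plausible completion of the mixed-media example, reusing `model` and `processor` from the sketch above and assuming `clip` is an 8-frame array read the same way as `video`, would be:

```python
import requests
from PIL import Image

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# "\n" after both <image> and <video>, per this commit.
prompts = [
    "USER: <image>\nHow many cats do you see? ASSISTANT:",
    "USER: <video>\nWhy is this video funny? ASSISTANT:",
]
# padding=True is required because the two prompts tokenize to different lengths.
inputs = processor(text=prompts, images=image, videos=clip, padding=True, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=40)
for text in processor.batch_decode(out, skip_special_tokens=True):
    print(text)
```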
