[Bugfix][Model] Add base class for vision-language models (#4809)
DarkLight1337 committed May 19, 2024
1 parent 2e9a222 commit f68470e
Showing 4 changed files with 53 additions and 29 deletions.
9 changes: 9 additions & 0 deletions tests/models/test_registry.py
@@ -0,0 +1,9 @@
import pytest

from vllm.model_executor.models import _MODELS, ModelRegistry


@pytest.mark.parametrize("model_cls", _MODELS)
def test_registry_imports(model_cls):
    # Ensure all model classes can be imported successfully
    ModelRegistry.load_model_cls(model_cls)
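For context, the registry call used by this test can also be exercised directly; resolving an architecture name is what triggers the (previously untested) import of its module. A minimal sketch, assuming the Llava architecture name is registered in _MODELS:

from vllm.model_executor.models import ModelRegistry

# Resolving an architecture name lazily imports its module, which is the
# code path the parametrized test above exercises for every _MODELS entry.
llava_cls = ModelRegistry.load_model_cls("LlavaForConditionalGeneration")
print(llava_cls)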
13 changes: 7 additions & 6 deletions vllm/model_executor/model_loader/loader.py
@@ -26,11 +26,7 @@
download_weights_from_hf, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
-from vllm.model_executor.models.llava import LlavaForConditionalGeneration
-
-_VISION_MODEL_CLASSES = [
-    LlavaForConditionalGeneration,
-]
+from vllm.model_executor.models.vlm_base import VisionLanguageModelBase

logger = init_logger(__name__)

@@ -73,7 +69,12 @@ def _get_model_initialization_kwargs(
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
elif model_class in _VISION_MODEL_CLASSES:
elif issubclass(model_class, VisionLanguageModelBase):
if vision_language_config is None:
raise ValueError("Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint "
"or engine arguments.")

extra_kwargs["vision_language_config"] = vision_language_config
return extra_kwargs

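The point of switching from the hard-coded _VISION_MODEL_CLASSES list to an issubclass check is that any model inheriting from VisionLanguageModelBase automatically receives vision_language_config at construction time, with no loader-side registration. A minimal sketch of that dispatch, using a hypothetical helper name rather than the exact loader code:

from typing import Any, Dict, Optional, Type

from torch import nn

from vllm.config import VisionLanguageConfig
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase


def vision_kwargs_for(
        model_class: Type[nn.Module],
        vision_language_config: Optional[VisionLanguageConfig]) -> Dict[str, Any]:
    # Hypothetical helper mirroring the loader's subclass-based dispatch.
    extra_kwargs: Dict[str, Any] = {}
    if issubclass(model_class, VisionLanguageModelBase):
        if vision_language_config is None:
            raise ValueError("Provide `image_input_type` and other vision "
                             "related configurations through LLM entrypoint "
                             "or engine arguments.")
        extra_kwargs["vision_language_config"] = vision_language_config
    return extra_kwargs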
48 changes: 25 additions & 23 deletions vllm/model_executor/models/llava.py
@@ -19,6 +19,8 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

+from .vlm_base import VisionLanguageModelBase
+
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
@@ -40,7 +42,7 @@ def __init__(self, vision_hidden_size: int, text_hidden_size: int,
text_hidden_size,
bias=True)

-    def forward(self, image_features):
+    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
@@ -50,29 +52,31 @@ def forward(self, image_features):
def _merge_vision_embeddings(input_ids: torch.Tensor,
                             inputs_embeds: torch.Tensor,
                             vision_embeddings: torch.Tensor,
-                             image_token_id: int):
+                             image_token_id: int) -> torch.Tensor:
    """In place merges in vision_embeddings with inputs_embeds."""
    mask = (input_ids == image_token_id)
-    inputs_embeds[mask] = vision_embeddings.view(-1,
+
+    image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
+    if mask.sum() != image_feature_size:
+        raise ValueError(f"image_feature_size should be {image_feature_size}, "
+                         f"but found: {mask.sum()}")
+
+    inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
                                                  vision_embeddings.shape[-1])
+
+    return inputs_embeds

-class LlavaForConditionalGeneration(nn.Module):
+class LlavaForConditionalGeneration(VisionLanguageModelBase):

    def __init__(self,
-                 config: "LlavaConfig",
+                 config: LlavaConfig,
                 vision_language_config: VisionLanguageConfig,
                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional["QuantizationConfig"] = None) -> None:
-        super().__init__()
-        self.config = config
-
-        self.vision_language_config = vision_language_config
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
+        super().__init__(vision_language_config)
+
-        assert self.vision_language_config, (
-            "Provide `image_input_type` and other vision "
-            "related configurations through LLM entrypoint "
-            "or engine arguments.")
+        self.config = config

if self.vision_language_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
@@ -98,14 +102,12 @@ def __init__(self,
config.vocab_size, logit_scale)
self.sampler = Sampler()

-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        image_input: Optional[torch.Tensor] = None
-    ) -> SamplerOutput:  # noqa: E501
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
+                kv_caches: List[torch.Tensor],
+                attn_metadata: AttentionMetadata,
+                image_input: Optional[torch.Tensor] = None) -> SamplerOutput:
"""Run forward pass for Llava 1.5.
One key thing to understand is the `input_ids` already accounts for the
@@ -172,7 +174,7 @@ def forward(
image_features = image_input
vision_embeddings = self.multi_modal_projector(image_features)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-            _merge_vision_embeddings(
+            inputs_embeds = _merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id)
input_ids = None
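To see why the new check in _merge_vision_embeddings matters, here is a small standalone sketch of the same merging logic on dummy tensors (sizes and token id are made up for illustration):

import torch

# Dummy sizes: 2 images, 3 patch embeddings per image, hidden size 8.
hidden_size = 8
vision_embeddings = torch.randn(2, 3, hidden_size)
image_token_id = 32000

# A toy prompt with exactly 2 * 3 = 6 image-token placeholders.
input_ids = torch.tensor([1, *(image_token_id,) * 6, 2])
inputs_embeds = torch.randn(input_ids.shape[0], hidden_size)

mask = (input_ids == image_token_id)
image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
if mask.sum() != image_feature_size:
    # The explicit check turns a shape mismatch into a clear error message
    # instead of the opaque failure the old view(-1, ...) call would produce.
    raise ValueError(f"image_feature_size should be {image_feature_size}, "
                     f"but found: {mask.sum()}")

# Placeholder rows are overwritten with the (projected) vision features.
inputs_embeds[mask] = vision_embeddings.view(image_feature_size, hidden_size)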
12 changes: 12 additions & 0 deletions vllm/model_executor/models/vlm_base.py
@@ -0,0 +1,12 @@
from torch import nn

from vllm.config import VisionLanguageConfig


class VisionLanguageModelBase(nn.Module):
    """Base class for all vision language models (VLMs)."""

    def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
        super().__init__()

        self.vision_language_config = vision_language_config
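With the base class in place, a new multimodal model only needs to inherit from it to opt into the loader handling shown above. A hedged sketch with a made-up model name and layer:

from torch import nn

from vllm.config import VisionLanguageConfig
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase


class MyToyVLM(VisionLanguageModelBase):
    # Hypothetical example; only the constructor wiring is sketched.

    def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
        # The base class stores the config on self, and the loader's
        # issubclass() check ensures it is passed in automatically.
        super().__init__(vision_language_config)
        self.dummy_proj = nn.Linear(16, 16)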
