Skip to content

Commit

Permalink
Add support for fine-tuning CLIP-like models using contrastive-image-…
Browse files Browse the repository at this point in the history
…text example (#29070)

* add support for siglip and chinese-clip model training with contrastive-image-text example

* codebase fixups
  • Loading branch information
tjs-intel committed Feb 20, 2024
1 parent 0996a10 commit ee3af60
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
("camembert", "CamembertConfig"),
("canine", "CanineConfig"),
("chinese_clip", "ChineseCLIPConfig"),
("chinese_clip_vision_model", "ChineseCLIPVisionConfig"),
("clap", "ClapConfig"),
("clip", "CLIPConfig"),
("clip_vision_model", "CLIPVisionConfig"),
Expand Down Expand Up @@ -512,6 +513,7 @@
("camembert", "CamemBERT"),
("canine", "CANINE"),
("chinese_clip", "Chinese-CLIP"),
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
("clap", "CLAP"),
("clip", "CLIP"),
("clip_vision_model", "CLIPVisionModel"),
Expand Down Expand Up @@ -773,6 +775,7 @@
("xclip", "x_clip"),
("clip_vision_model", "clip"),
("siglip_vision_model", "siglip"),
("chinese_clip_vision_model", "chinese_clip"),
]
)

Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
("camembert", "CamembertModel"),
("canine", "CanineModel"),
("chinese_clip", "ChineseCLIPModel"),
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
("clap", "ClapModel"),
("clip", "CLIPModel"),
("clip_vision_model", "CLIPVisionModel"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,7 @@ class ChineseCLIPVisionConfig(PretrainedConfig):
This is the configuration class to store the configuration of a [`ChineseCLIPModel`]. It is used to instantiate an
ChineseCLIP model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the ChineseCLIP
[OFA-Sys/chinese-clip-vit-base-patch16](https:
//huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture.
[OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..chinese_clip.configuration_chinese_clip import ChineseCLIPVisionConfig
from ..clip.configuration_clip import CLIPVisionConfig
from ..siglip.configuration_siglip import SiglipVisionConfig


logger = logging.get_logger(__name__)

VISION_MODEL_CONFIGS = {
"clip_vision_model": CLIPVisionConfig,
"chinese_clip_vision_model": ChineseCLIPVisionConfig,
"siglip_vision_model": SiglipVisionConfig,
}


class VisionTextDualEncoderConfig(PretrainedConfig):
r"""
Expand Down Expand Up @@ -85,12 +93,13 @@ def __init__(self, projection_dim=512, logit_scale_init_value=2.6592, **kwargs):
vision_model_type = vision_config.pop("model_type")
text_model_type = text_config.pop("model_type")

if vision_model_type == "clip":
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
elif vision_model_type == "clip_vision_model":
self.vision_config = CLIPVisionConfig(**vision_config)
vision_config_class = VISION_MODEL_CONFIGS.get(vision_model_type)
if vision_config_class is not None:
self.vision_config = vision_config_class(**vision_config)
else:
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
if hasattr(self.vision_config, "vision_config"):
self.vision_config = self.vision_config.vision_config

self.text_config = AutoConfig.for_model(text_model_type, **text_config)

Expand Down
1 change: 1 addition & 0 deletions utils/check_copies.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,7 @@ def check_model_list_copy(overwrite: bool = False):
"VisionTextDualEncoder",
"CLIPVisionModel",
"SiglipVisionModel",
"ChineseCLIPVisionModel",
]

# Template for new entries to add in the main README when we have missing models.
Expand Down
2 changes: 1 addition & 1 deletion utils/check_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _center_text(text: str, width: int) -> str:
"XLS-R": "Wav2Vec2",
"XLSR-Wav2Vec2": "Wav2Vec2",
}
MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel", "SiglipVisionModel"]
MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel", "SiglipVisionModel", "ChineseCLIPVisionModel"]


def get_model_table_from_auto_modules() -> str:
Expand Down

0 comments on commit ee3af60

Please sign in to comment.