diff --git a/docs/source/en/model_doc/siglip.md b/docs/source/en/model_doc/siglip.md
index 6ee237942ab99d..2795cfb711767c 100644
--- a/docs/source/en/model_doc/siglip.md
+++ b/docs/source/en/model_doc/siglip.md
@@ -27,7 +27,7 @@ The abstract from the paper is the following:
 ## Usage tips
 
 - Usage of SigLIP is similar to [CLIP](clip). The main difference is the training loss, which does not require a global view of all the pairwise similarities of images and texts within a batch. One needs to apply the sigmoid activation function to the logits, rather than the softmax.
-- Training is not yet supported. If you want to fine-tune SigLIP or train from scratch, refer to the loss function from [OpenCLIP](https://github.com/mlfoundations/open_clip/blob/73ad04ae7fb93ede1c02dc9040a828634cb1edf1/src/open_clip/loss.py#L307), which leverages various `torch.distributed` utilities.
+- Training is supported, but it does not use `torch.distributed` utilities, which may limit the scalability of the batch size. However, DDP and FSDP work on a single-node multi-GPU setup.
 - When using the standalone [`SiglipTokenizer`] or [`SiglipProcessor`], make sure to pass `padding="max_length"` as that's how the model was trained.
 - To get the same results as the pipeline, a prompt template of "This is a photo of {label}." should be used.
 
diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py
index 4c534bbce6ce8a..068f2173d7d679 100644
--- a/src/transformers/models/siglip/modeling_siglip.py
+++ b/src/transformers/models/siglip/modeling_siglip.py
@@ -1234,7 +1234,12 @@ def forward(
 
         loss = None
         if return_loss:
-            raise NotImplementedError("SigLIP loss to be implemented")
+            # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
+            eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
+            m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
+            loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
+            nll = -torch.sum(loglik, dim=-1)
+            loss = nll.mean()
 
         if not return_dict:
             output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
diff --git a/tests/models/siglip/test_modeling_siglip.py b/tests/models/siglip/test_modeling_siglip.py
index a8e1bb7b0f1264..8bdc995e51906b 100644
--- a/tests/models/siglip/test_modeling_siglip.py
+++ b/tests/models/siglip/test_modeling_siglip.py
@@ -335,27 +335,19 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
-    @unittest.skip
-    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training
+    @unittest.skip(reason="SiglipTextModel does not support standalone training")
     def test_training(self):
         pass
 
-    @unittest.skip
-    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing
+    @unittest.skip(reason="SiglipTextModel does not support standalone training")
     def test_training_gradient_checkpointing(self):
         pass
 
-    @unittest.skip(
-        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing_use_reentrant
+    @unittest.skip(reason="SiglipTextModel does not support standalone training")
     def test_training_gradient_checkpointing_use_reentrant(self):
         pass
 
-    @unittest.skip(
-        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing_use_reentrant_false
+    @unittest.skip(reason="SiglipTextModel does not support standalone training")
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass
 
@@ -481,22 +473,6 @@ def test_retain_grad_hidden_states_attentions(self):
     def test_model_get_set_embeddings(self):
         pass
 
-    @unittest.skip(reason="SiglipModel does not support training")
-    def test_training(self):
-        pass
-
-    @unittest.skip(reason="SiglipModel does not support training")
-    def test_training_gradient_checkpointing(self):
-        pass
-
-    @unittest.skip(reason="SiglipModel does not support training")
-    def test_training_gradient_checkpointing_use_reentrant(self):
-        pass
-
-    @unittest.skip(reason="SiglipModel does not support training")
-    def test_training_gradient_checkpointing_use_reentrant_false(self):
-        pass
-
     @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
     def test_initialization(self):
         pass
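
For reference, here is a minimal sketch of how the newly supported loss can be exercised end to end via `return_loss=True`. The checkpoint name, image URL, and paired captions below are illustrative assumptions and are not part of this patch:

```python
# Minimal fine-tuning-style sketch: compute the sigmoid loss through the model API.
import requests
from PIL import Image

from transformers import AutoProcessor, SiglipModel

model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # example image (assumption)
image = Image.open(requests.get(url, stream=True).raw)

# A paired batch: text i is the caption of image i, so logits_per_text is square.
texts = ["a photo of 2 cats", "a photo of 2 cats on a couch"]
images = [image, image]

# padding="max_length" matches how the model was trained (see the usage tips above).
inputs = processor(text=texts, images=images, padding="max_length", return_tensors="pt")

outputs = model(**inputs, return_loss=True)
outputs.loss.backward()  # gradients flow, so the model can now be fine-tuned
print(outputs.loss.item())
```

Since the loss is computed per device without `torch.distributed` all-gathers, the contrastive pairs are limited to whatever fits in each device's local batch; under DDP or FSDP the per-device losses are simply averaged, which is consistent with the batch-size scalability note in the usage tips.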