
Add training support for SigLIP #31495

Merged: 10 commits, Jul 5, 2024
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/siglip.md
@@ -27,7 +27,7 @@ The abstract from the paper is the following:
## Usage tips

- Usage of SigLIP is similar to [CLIP](clip). The main difference is the training loss, which does not require a global view of all the pairwise similarities of images and texts within a batch. One needs to apply the sigmoid activation function to the logits, rather than the softmax.
- Training is not yet supported. If you want to fine-tune SigLIP or train from scratch, refer to the loss function from [OpenCLIP](https://github.com/mlfoundations/open_clip/blob/73ad04ae7fb93ede1c02dc9040a828634cb1edf1/src/open_clip/loss.py#L307), which leverages various `torch.distributed` utilities.
- Training is supported, but it does not use `torch.distributed` utilities, which may limit the scalability of the batch size. However, DDP and FSDP work on a single-node multi-GPU setup. A minimal fine-tuning sketch is shown after this list.
- When using the standalone [`SiglipTokenizer`] or [`SiglipProcessor`], make sure to pass `padding="max_length"` as that's how the model was trained.
- To get the same results as the pipeline, a prompt template of "This is a photo of {label}." should be used.
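
Below is a minimal fine-tuning sketch, not part of the official documentation: the `google/siglip-base-patch16-224` checkpoint, the COCO image URL, and the optimizer settings are illustrative assumptions only. It shows `padding="max_length"` on the processor and `return_loss=True` to obtain the sigmoid loss added in this PR.

```python
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

# Checkpoint name is assumed for illustration.
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
model.train()

# Toy batch of two image-text pairs; the loss expects the same number of images and texts.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = ["a photo of 2 cats", "a photo of 2 dogs"]

# padding="max_length" because that is how the model was trained.
inputs = processor(text=texts, images=[image, image], padding="max_length", return_tensors="pt")

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# return_loss=True enables the sigmoid loss on the text-image logits.
outputs = model(**inputs, return_loss=True)
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()
```

For multi-GPU runs, the same forward/backward step can be wrapped in DDP or FSDP on a single node, as noted above.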

7 changes: 6 additions & 1 deletion src/transformers/models/siglip/modeling_siglip.py
@@ -1234,7 +1234,12 @@ def forward(

        loss = None
        if return_loss:
            raise NotImplementedError("SigLIP loss to be implemented")
            # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
            eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
            m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
            loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
            nll = -torch.sum(loglik, dim=-1)
            loss = nll.mean()

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
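
For reference, a standalone sketch of what this hunk computes, outside the modeling code: every text-image pair gets label -1 except the matching (diagonal) pairs, which get +1, and the log-sigmoid of the signed logits is summed per text and averaged over the batch. This is numerically the same as element-wise binary cross-entropy with logits against an identity target; the tensor shapes below are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# Illustrative square logits: n_texts == n_images == 4.
logits_per_text = torch.randn(4, 4)

# Loss as in the hunk above: +1 on the diagonal, -1 off-diagonal, log-sigmoid likelihood.
eye = torch.eye(logits_per_text.size(0))
m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
loglik = F.logsigmoid(m1_diag1 * logits_per_text)
loss = -torch.sum(loglik, dim=-1).mean()

# Equivalent view: binary cross-entropy with logits against an identity target,
# summed per row and averaged over the batch.
bce = F.binary_cross_entropy_with_logits(logits_per_text, eye, reduction="sum") / logits_per_text.size(0)

print(torch.allclose(loss, bce))  # True
```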
32 changes: 4 additions & 28 deletions tests/models/siglip/test_modeling_siglip.py
@@ -335,27 +335,19 @@ def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip
    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training
    @unittest.skip(reason="SiglipTextModel does not support standalone training")
    def test_training(self):
        pass

    @unittest.skip
    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing
    @unittest.skip(reason="SiglipTextModel does not support standalone training")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing_use_reentrant
    @unittest.skip(reason="SiglipTextModel does not support standalone training")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing_use_reentrant_false
    @unittest.skip(reason="SiglipTextModel does not support standalone training")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

@@ -481,22 +473,6 @@ def test_retain_grad_hidden_states_attentions(self):
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="SiglipModel does not support training")
    def test_training(self):
        pass

    @unittest.skip(reason="SiglipModel does not support training")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="SiglipModel does not support training")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="SiglipModel does not support training")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
    def test_initialization(self):
        pass