
Questions about Text-to-Image training #380

Closed
xiaoxiaodadada opened this issue Jun 26, 2024 · 2 comments

Comments

@xiaoxiaodadada

I now have my own image-text pair dataset. How can I modify this code to train on it? If you can help me, I would be very grateful.

@Siddharth-Latthe-07

To train a model on your own image-text pair dataset with this setup, you'll need to modify the code to load your dataset, preprocess the images and text, and feed the batches into the model for training. Here are the steps:

  1. Prepare Your Dataset: Ensure your dataset is in a structured format, such as a CSV file, where each row contains an image file path and its corresponding text description (a sample CSV follows this list).
  2. Modify the Dataset Loader: Update the data loading part of your code to read your dataset and preprocess the images and text.
  3. Update the Training Loop: Adjust the training loop to use your dataset and fine-tune the model accordingly.
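For example, a minimal dataset.csv (the file name used in the snippet below) could look like the following; the paths and captions here are just placeholders:

image_path,caption
images/0001.jpg,A red bicycle leaning against a brick wall
images/0002.jpg,Two dogs playing in the snow

Note that CustomDataset below reads the columns by position (iloc[idx, 0] for the path, iloc[idx, 1] for the text), so keep the image path first and the caption second.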

Sample code snippets:
Modify the Dataset Loader

import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import CLIPProcessor

class CustomDataset(Dataset):
    def __init__(self, csv_file, processor, max_length=77):
        self.data = pd.read_csv(csv_file)
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data.iloc[idx, 0]
        text = self.data.iloc[idx, 1]

        image = Image.open(img_path).convert("RGB")
        inputs = self.processor(text=text, images=image, return_tensors="pt", max_length=self.max_length, padding="max_length", truncation=True)

        return inputs

# Initialize processor (modify as per your model)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
dataset = CustomDataset(csv_file="dataset.csv", processor=processor)

# Create DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
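Optionally, you can sanity-check one batch before training. With the default collate function, each processor output keeps a leading dimension of 1 per sample, which is why the training loop below calls squeeze(1); the shapes shown are what the openai/clip-vit-base-patch32 processor produces with batch_size=32:

# Inspect one batch to verify tensor shapes before training
batch = next(iter(dataloader))
print(batch['input_ids'].shape)       # torch.Size([32, 1, 77])
print(batch['pixel_values'].shape)    # torch.Size([32, 1, 3, 224, 224])
print(batch['attention_mask'].shape)  # torch.Size([32, 1, 77])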

Update the Training Loop

import torch
from transformers import CLIPModel, CLIPProcessor

# Initialize the model (modify as per your model)
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
loss_fn = torch.nn.CrossEntropyLoss()

# Training loop
num_epochs = 10  # set the number of epochs to suit your dataset
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch in dataloader:
        optimizer.zero_grad()

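        # Each sample's processor output has a leading dim of 1; default
        # collation stacks these to (batch, 1, ...), so squeeze out dim 1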
        input_ids = batch['input_ids'].squeeze(1).to(device)
        pixel_values = batch['pixel_values'].squeeze(1).to(device)
        attention_mask = batch['attention_mask'].squeeze(1).to(device)

        outputs = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text

        labels = torch.arange(logits_per_image.size(0), device=device)
        loss = (loss_fn(logits_per_image, labels) + loss_fn(logits_per_text, labels)) / 2
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

# Save the model
model.save_pretrained("path_to_save_your_model")
processor.save_pretrained("path_to_save_your_processor")
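Once training finishes, you can reload the fine-tuned model and processor for inference; the paths here are the placeholder save paths from above:

from transformers import CLIPModel, CLIPProcessor

# Reload the fine-tuned weights and the matching processor
model = CLIPModel.from_pretrained("path_to_save_your_model")
processor = CLIPProcessor.from_pretrained("path_to_save_your_processor")
model.eval()  # switch to inference mode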

Hope this helps
Thanks

@xiaoxiaodadada
Author

Thank you very much, this helped me a lot
