
Noise on all images for training #315

Open · axel588 opened this issue Feb 15, 2023 · 3 comments

axel588 commented Feb 15, 2023

Hello,

I'm trying to train on 50,000 16x16 images with alpha channels (RGBA),
but training for many steps doesn't give me any decent result, and
training from the CLI is very, very slow.

After training on an A100 for 12 hours I still get a completely noisy
image. How could I make smaller U-Nets that suit 16x16 images with an
alpha channel? I know this architecture doesn't like small images, but
is it still possible to make it train more efficiently without
producing pure noise?

Thanks, by the way, for the work you've done :)

```python
import torch
from imagen_pytorch import Unet, Imagen

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    channels = 4,
    layer_cross_attns = (False, True, True, True)
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    channels = 4,
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    channels = 4,
    image_sizes = (16, 64),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# my text embeds and image dataset are loaded here (code removed, but the
# image tensor has shape (50000, 4, 16, 16))

# feed images into imagen, training each unet in the cascade
for step in range(30000):
    for i in (1, 2):
        loss = imagen(images, text_embeds = text_embeds, unet_number = i)
        loss.backward()
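        # note: loss.backward() accumulates gradients here, but no optimizer
        # step is ever taken and the grads are never zeroed, so the weights
        # never update - this is what ImagenTrainer in the reply below handles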

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = imagen.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3.)

images.shape # (3, 4, 16, 16)
```
@alif-munim

Hey! I was having the same issue for a while. Try wrapping your imagen with the ImagenTrainer module (mentioned in the README.md) and using trainer.train_step() for your gradient updates.

```python
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    text_encoder_name = 't5-large',
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# wrap imagen with the trainer class

trainer = ImagenTrainer(imagen)

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(64, 256, 1024).cuda()
images = torch.randn(64, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

loss = trainer(
    images,
    text_embeds = text_embeds,
    unet_number = 1,            # training on unet number 1 in this example, but you will have to also save checkpoints and then reload and continue training on unet number 2
    max_batch_size = 4          # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
)

trainer.update(unet_number = 1)

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = trainer.sample(texts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
], cond_scale = 3.)

images.shape # (2, 3, 256, 256)
```

More info and an example training script are in this thread: #305
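For the checkpoint-and-continue step mentioned in the code comment above, here is a minimal sketch using the trainer.save / trainer.load methods from the README (the checkpoint path is just an example):

```python
# after training unet 1 for a while, persist the full trainer state
trainer.save('./checkpoint-unet1.pt')

# later (possibly in a fresh process): rebuild the same unets, imagen and
# trainer, reload the saved state, and continue with the second unet
trainer.load('./checkpoint-unet1.pt')

loss = trainer(
    images,
    text_embeds = text_embeds,
    unet_number = 2,
    max_batch_size = 4
)
trainer.update(unet_number = 2)
```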


axel578 commented Feb 16, 2023

Thanks, it helped a lot @alif-munim! But I have a dataset on Hugging Face:

```
DatasetDict({
    train: Dataset({
        features: ['image', 'text'],
        num_rows: 725277
    })
})
```

The images are PngImageFile objects and the texts are strings. I can't manage to create a custom dataset; right now I have this:

```python
import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torchvision import transforms
from tqdm import tqdm

class MCDataset(Dataset):
    def __init__(self, dataset_dict, is_train=True, is_skip=False):
        super().__init__()

        
        self.dataset = dataset_dict["train"]
        self.is_train = is_train
        self.transform = transforms.Compose([
            transforms.Resize(16),
            transforms.ToTensor(),
            # ToTensor scales pixel values to [0, 1], so an opaque alpha
            # channel must be filled with 1.0, not 255
            transforms.Lambda(lambda x: x if x.shape[0] == 4 else torch.cat([x, torch.full_like(x[:1], 1.0)]))
        ])
        # Convert to RGBA

        for i, sample in enumerate(self.dataset):
            if i % 10000 == 0:
              print("Images :"+str(i))
            image = sample["image"]
            if not isinstance(image, Image.Image):
                image = Image.open(image)
            if image.mode != "RGBA":
                image = image.convert("RGBA")
            self.dataset[i]["image"] = image
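            # caveat: indexing a Hugging Face Dataset returns a fresh dict,
            # so this assignment is not persisted back into the dataset;
            # converting inside __getitem__ (or via dataset.map) is safer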

        # Split train and validation sets
        if is_train:
            self.dataset = self.dataset[:int(0.98 * len(self.dataset))]
        else:
            self.dataset = self.dataset[int(0.98 * len(self.dataset)):]
        print('Preparing text encoding')
        # Text encoding
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.tokenizer = T5Tokenizer.from_pretrained("t5-base", model_max_length=512)
        self.model = T5ForConditionalGeneration.from_pretrained("t5-base").to(self.device)
        self.model.eval()
        print('Finished preparing text encoding')
        # keep the raw caption strings; encoding happens in __getitem__
        self.texts = list(tqdm(self.dataset["text"], desc="collecting captions"))

    def __len__(self):
        # after the slice above, self.dataset is a plain dict of columns,
        # so take the length from the caption list instead
        return len(self.texts)

    def __getitem__(self, idx):
        image = self.dataset["image"][idx]
        image = self.transform(image)
        # `device` was undefined here; use self.device, and truncate so
        # captions longer than 512 tokens don't raise an error
        enc = self.tokenizer(self.texts[idx], return_tensors="pt", padding="max_length",
            truncation=True, max_length=512).to(self.device)

        # forward pass through the encoder only, without tracking gradients
        with torch.no_grad():
            output = self.model.encoder(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                return_dict=True
            )
        # final hidden states, shape (1, 512, 768) for t5-base
        emb = output.last_hidden_state.cpu()
        
        return image, emb
```

But these are the dimensions I get:

```python
print(train_db[0][0].shape)
print(train_db[1][0].shape)
print(train_db[2][0].shape)

print(train_db[0][1].shape)
print(train_db[1][1].shape)
print(train_db[2][1].shape)

# images
# torch.Size([4, 16, 16])
# torch.Size([4, 16, 16])
# torch.Size([4, 16, 16])

# text embeddings
# torch.Size([1, 1, 512, 768])
# torch.Size([1, 1, 512, 768])
# torch.Size([1, 1, 512, 768])
```
What should the shape of the embedded text be? I've seen shapes like (4, 256, 768) in the examples, but I can't understand why that kind of dimension is used.

@alif-munim

@axel578 if you use the built-in t5 text encoding functions, you should get the correct dimensionality for your text embeddings. See https://github.com/lucidrains/imagen-pytorch#L26
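To expand on that: Imagen expects text_embeds shaped (batch, sequence length, embedding dim), e.g. (4, 256, 768) is 4 captions of up to 256 tokens each, with 768-dimensional token embeddings. The (1, 1, 512, 768) above has an extra leading dimension, and padding every caption to 512 tokens is unnecessary. A minimal sketch, assuming the t5_encode_text helper in imagen_pytorch.t5 (check your installed version for the exact name and signature):

```python
from imagen_pytorch.t5 import t5_encode_text

# returns one embedding per token, shaped (batch, seq_len, embed_dim),
# padded and masked consistently across the batch - the layout Imagen
# expects for text_embeds
text_embeds = t5_encode_text([
    'a stone sword',
    'a diamond pickaxe'
])

print(text_embeds.shape)  # e.g. torch.Size([2, seq_len, 768]) for a base T5
```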
