Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Getting stuck on generate_images #417

Open
m0nologuer opened this issue Feb 21, 2022 · 0 comments
Open

Getting stuck on generate_images #417

m0nologuer opened this issue Feb 21, 2022 · 0 comments

Comments

@m0nologuer
Copy link

m0nologuer commented Feb 21, 2022

Here's my code -- no idea what's happening

import torch
from dalle_pytorch import DiscreteVAE, DALLE

from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

# Hyperparameters for this toy run.
BATCH_SIZE = 4
IMAGE_SIZE = 64
IMAGE_PATH = "."
EPOCHS = 1

# Discrete VAE that turns images into visual tokens for DALL-E.
vae = DiscreteVAE(
    image_size = IMAGE_SIZE,
    num_layers = 2,           # number of downsamples - here 64 / (2 ** 2) = (16 x 16 feature map)
    num_tokens = 1024,        # visual vocabulary size; the paper used 8192, smaller is fine for toy runs
    codebook_dim = 256,       # dimensionality of each codebook embedding
    hidden_dim = 32,          # hidden dimension of the encoder/decoder
    num_resnet_blocks = 1,    # resnet blocks per stage
    temperature = 0.9,        # gumbel-softmax temperature; lower = harder discretization
    straight_through = False, # straight-through gumbel estimator; unclear which setting is better
)

## Build the training dataset from images on disk.
# NOTE: the original script also created a random `images` tensor here, but it
# was never read before being rebound by the training loop below — removed.
dataset = ImageFolder(
    IMAGE_PATH,
    T.Compose([
        # ImageFolder yields PIL images; force 3-channel RGB before tensor conversion
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),      # shorter side -> IMAGE_SIZE
        T.CenterCrop(IMAGE_SIZE),  # square crop to IMAGE_SIZE x IMAGE_SIZE
        T.ToTensor(),              # -> float tensor in [0, 1], shape (3, H, W)
    ])
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

## Run VAE training for several epochs.
# BUG FIX: the original loop called loss.backward() with no optimizer, so
# gradients accumulated across batches and the weights were never updated —
# the model never actually trained.
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
count = 0
for epoch in range(EPOCHS):
    for images, labels in dataloader:  # DataLoader is already iterable; iter() was redundant
        optimizer.zero_grad()
        loss = vae(images, return_loss = True)
        loss.backward()
        optimizer.step()
        print(count)
        count = count + 1


# Text-to-image transformer wrapped around the (now trained) VAE.
dalle_config = dict(
    dim = 1024,
    vae = vae,                # image sequence length and visual token count are inferred from the VAE
    num_text_tokens = 1000,   # text vocabulary size
    text_seq_len = 16,        # tokens per caption
    depth = 12,               # transformer depth (paper-scale would be 64)
    heads = 16,               # attention heads
    dim_head = 64,            # per-head dimension
    attn_dropout = 0.1,       # dropout on attention weights
    ff_dropout = 0.1,         # dropout inside feed-forward layers
)
dalle = DALLE(**dalle_config)

# Dummy caption tokens and images for a single smoke-test training step.
text = torch.randint(0, 1000, (BATCH_SIZE, 16))
images = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)

loss = dalle(text, images, return_loss = True)
loss.backward()

# do the above for a long time with a lot of data ... then

# NOTE(review): generate_images samples autoregressively, one image token at a
# time — it is inherently slow on CPU/untrained models, not "stuck". Disable
# autograd during sampling: tracking gradients here only wastes memory and
# time. (generate_images may already guard this internally — harmless either
# way; TODO confirm against dalle_pytorch source.)
with torch.no_grad():
    images = dalle.generate_images(text)

img1 = images[0]
save_image(img1, 'img1.png')

# BUG FIX (comment only): with IMAGE_SIZE = 64 the output is (4, 3, 64, 64),
# not (4, 3, 256, 256) as the original comment claimed.
print(images.shape)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant