Samples after training only contain white noise. #347

Open
Valerie9696 opened this issue May 14, 2023 · 7 comments

Comments

@Valerie9696

I wanted to try training Imagen and generating some samples, so I ran this part of the script from the README:

```python
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super-resolution ones)

imagen = Imagen(
    unets = (unet1, unet2),
    text_encoder_name = 't5-large',
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# wrap imagen with the trainer class

trainer = ImagenTrainer(imagen)

# mock images (get a lot of this) and text encodings from large T5
# in my run, these two mock tensors were replaced by my own text embeddings
# and my own sample images (3k for a first try)

text_embeds = torch.randn(64, 256, 1024).cuda()
images = torch.randn(64, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

loss = trainer(
    images,
    text_embeds = text_embeds,
    unet_number = 1,    # training on unet number 1 in this example, but you will have to also save checkpoints and then reload and continue training on unet number 2
    max_batch_size = 4  # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
)

trainer.update(unet_number = 1)

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = trainer.sample(texts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
], cond_scale = 3.)

images.shape # (2, 3, 256, 256)
```
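The snippet above performs a single `trainer.update` for unet 1. Assuming `trainer.sample` runs the whole cascade by default (which the `(2, 3, 256, 256)` output shape suggests), the final images also pass through unet 2, so both unets need many training steps. A minimal sketch of that outer loop, assuming a hypothetical `train_step(unet_number)` callable that does one update for the given unet and returns the loss:

```python
from typing import Callable, Dict, List, Sequence


def train_cascade(
    train_step: Callable[[int], float],
    unet_numbers: Sequence[int] = (1, 2),
    steps_per_unet: int = 100_000,
    log_every: int = 1_000,
) -> Dict[int, List[float]]:
    """Train each unet of the cascade in turn, recording the loss periodically.

    `train_step` is a hypothetical callable: it should run one forward pass plus
    one optimizer update for the given unet and return the loss as a float.
    """
    history: Dict[int, List[float]] = {}
    for unet_number in unet_numbers:
        history[unet_number] = []
        for step in range(steps_per_unet):
            loss = train_step(unet_number)
            # keep only every `log_every`-th loss so the history stays small
            if step % log_every == 0:
                history[unet_number].append(loss)
                print(f"unet {unet_number} step {step}: loss {loss:.4f}")
    return history
```

With `imagen_pytorch`, `train_step` would wrap the two calls from the snippet, `trainer(images, text_embeds = text_embeds, unet_number = unet_number, max_batch_size = 4)` followed by `trainer.update(unet_number = unet_number)`, ideally feeding a fresh batch of images and embeddings on every call.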

Afterwards, I ran the following to display the results:

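```python
from torchvision import transforms

# convert each (3, H, W) sample to a PIL image and open it in the default viewer
to_pil = transforms.ToPILImage()
for img in images:
    to_pil(img.cpu()).show()
```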

The training runs without errors, but at the end all I get are images full of white noise. Does anyone know where I am making a mistake? Or are 3k images simply not enough to get anything out of it? (I know it's not much, but unfortunately it's the maximum my device can handle.)
This is one such result:
[image]

I am really thankful for any advice on the matter.

@kirilllzaitsev

One reason could be an insufficient number of steps. Have you tried increasing it to the maximum that is feasible in your case and checking whether your training loop works?

Try to "remember" a single image first, i.e., optimize for N steps and observe that you can sample this image with high quality.
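The "remember a single image" check can be sketched as below, assuming a hypothetical `train_step()` callable that always trains on the same fixed batch (one image plus its caption embedding) and returns the loss. Because a diffusion model draws random timesteps at every step, single loss values are noisy, so the sketch compares averages over a window at the start and at the end:

```python
from typing import Callable, List, Tuple


def overfit_single_batch(
    train_step: Callable[[], float],
    num_steps: int = 2_000,
    window: int = 50,
) -> Tuple[List[float], float, float]:
    """Train repeatedly on ONE fixed batch and report how far the loss fell.

    `train_step` is a hypothetical callable that performs one update on the
    same batch every time and returns the loss as a float.
    Returns all losses plus the mean loss of the first and last `window` steps.
    """
    if num_steps <= 0:
        raise ValueError("num_steps must be positive")
    losses = [train_step() for _ in range(num_steps)]
    window = max(1, min(window, num_steps))
    start = sum(losses[:window]) / window
    end = sum(losses[-window:]) / window
    return losses, start, end
```

If `end` stays close to `start`, or samples drawn afterwards do not resemble the memorised image, the problem most likely lies in the data or the training loop rather than in the dataset size.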

@TheFusion21
Contributor

The script in the readme only contains mock images (random noise).
You need to load a sufficient dataset first and train on it.

@Valerie9696
Author

> The script in the readme only contains mock images (random noise). You need to load a sufficient dataset first and train on it.

I know; that is why I marked the mock-data lines in my code: I replaced them with my own text embeddings and sample images. The training set currently consists of 3k images and their embedded captions.

@Valerie9696
Author

> One reason could be an insufficient number of steps. Have you tried increasing it to the maximum feasible for your case amount and see if your training loop works?
>
> Try to "remember" a single image first, i.e., optimize for N steps and observe that you can sample this image with high quality.

Oh, I think this might be where I went wrong. How exactly do I increase the number of steps? Currently I am basically running the script above with my own images and embedded captions.

@kirilllzaitsev

Feel free to set a huge number of steps at first and plot your samples regularly. Note when your samples start looking like your inputs; that will give you an intuition for how long it takes to "overfit" your network.
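A minimal sketch of "many steps, sample regularly", assuming hypothetical `train_step(batch)` and `sample_fn(step)` callables: the first runs one update on a batch, the second generates samples (for instance with `trainer.sample(...)` as in the original snippet) and saves or plots them. The dataset is re-iterated each epoch, so a list or a PyTorch `DataLoader` both work, while a one-shot generator is rejected instead of looping forever:

```python
from typing import Callable, Iterable, TypeVar

Batch = TypeVar("Batch")


def train_with_periodic_samples(
    batches: Iterable[Batch],
    train_step: Callable[[Batch], object],
    sample_fn: Callable[[int], None],
    max_steps: int = 1_000_000,
    sample_every: int = 1_000,
) -> int:
    """Loop over the dataset for up to `max_steps` updates, sampling periodically.

    `train_step` and `sample_fn` are hypothetical callables supplied by the user.
    Returns the number of steps actually performed.
    """
    step = 0
    while step < max_steps:
        seen_any = False
        for batch in batches:  # re-iterating a list or DataLoader starts a new epoch
            seen_any = True
            train_step(batch)
            step += 1
            if step % sample_every == 0:
                sample_fn(step)  # e.g. save sample images tagged with the step
            if step >= max_steps:
                break
        if not seen_any:
            raise ValueError("`batches` yielded nothing; pass a re-iterable like a list or DataLoader")
    return step
```

Tagging each saved sample with its step number makes it easy to see at which point the outputs stop looking like noise.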

@stepanovD

Hey @Valerie9696! I have the same problem. Have you found a solution?

@asher-lab

Hello @stepanovD and @Valerie9696, can you share how you created the training data for Imagen? I'm new to PyTorch and still trying to connect the dots.
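One common way to organise image/caption training data is a folder where each image has a caption file with the same name (`cat.jpg` next to `cat.txt`). This layout is an assumption for illustration, not something `imagen_pytorch` requires; the hypothetical helper below just collects the pairs:

```python
from pathlib import Path
from typing import List, Tuple

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def collect_image_caption_pairs(folder: str) -> List[Tuple[Path, str]]:
    """Pair every image in `folder` with the caption in a same-named .txt file.

    Assumes the (hypothetical) layout cat.jpg + cat.txt; images without a
    caption file are skipped, and non-image files are ignored.
    """
    pairs: List[Tuple[Path, str]] = []
    for image_path in sorted(Path(folder).iterdir()):
        if image_path.suffix.lower() not in IMAGE_SUFFIXES:
            continue
        caption_path = image_path.with_suffix(".txt")
        if caption_path.exists():
            caption = caption_path.read_text(encoding="utf-8").strip()
            pairs.append((image_path, caption))
    return pairs
```

The captions would then be embedded with the same T5 model as configured in `text_encoder_name` (`t5-large` in the script above) and the images loaded as tensors, which gives the `text_embeds` and `images` inputs used in the original code.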


5 participants