Has anyone succeeded in training this model? #34

Open
Jihun999 opened this issue Feb 16, 2024 · 31 comments

Comments

@Jihun999

I have been trying to train this model for a few days, but the reconstruction results are always abnormal. If anyone has successfully trained this model, could you share some training tips?

@Jihun999 closed this as not planned on Feb 16, 2024
@Jihun999 reopened this on Feb 16, 2024
@Jihun999
Author

The reconstructed images look like solid-color images.

@RobertLuo1

RobertLuo1 commented Feb 20, 2024

Can you show the reconstruction images after training?

@Jihun999
Author

[attached example image] It always looks like this.

@RobertLuo1

@bridenmj How many epochs did you use? Are you working on ImageNet pretraining?

@Jihun999
Author

Yes, I'm working on ImageNet pretraining; it has passed 12000 steps and the output image always looks the same. I tried LFQ in my own autoencoder and that training works well, so it looks like there is something wrong in the magvit2 model architecture.

@RobertLuo1

Actually, I reimplemented the model structure to align with the MAGVIT2 paper. But I find that the LFQ loss is negative and the recon loss converges easily with or without the GAN. The reconstructed images are blurry, but not a solid color. What about you? @Jihun999

@Jihun999
Author

Ok, I will reimplement the model first. Thank you for your comment.

@Jason3900

> Actually, I reimplemented the model structure to align with the MAGVIT2 paper. But I find that the LFQ loss is negative and the recon loss converges easily with or without the GAN. The reconstructed images are blurry, but not a solid color. What about you? @Jihun999

Hey, is it possible to share the code modifications for the model architecture alignment? Thanks a lot!

@lucidrains
Owner

Someone I know has trained it successfully.

@Jiushanhuadao

Wow, could I know who did it?

@StarCycle

@RobertLuo1 @Jihun999 @lucidrains If you successfully trained this model, would you like to share the pretrained weights and the modified model code?

@vinyesm

vinyesm commented May 16, 2024

Hello there,
Thanks @lucidrains for your work! I have had successful trainings on toy data (tried it on images and video) with the code in this fork https://github.com/vinyesm/magvit2-pytorch/blob/trainvideo/examples/train-on-video.py and with this video data https://huggingface.co/datasets/mavi88/phys101_frames/tree/main. What seemed to fix the issue was to stop using accelerate (I only train on one GPU).

I tried with only MSE and then also the other losses, and with/without attend_space layers. All work, but I did not try to tune hyperparameters.

[attached: two screenshots of reconstruction results]
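For orientation, here is a minimal training sketch following the README-style VideoTokenizer / VideoTokenizerTrainer API. The path, the reduced layer stack, and the hyperparameters are placeholders, and note that the stock trainer still uses accelerate internally; the fork's examples/train-on-video.py linked above (which avoids accelerate) is the actual working reference.

from magvit2_pytorch import VideoTokenizer, VideoTokenizerTrainer

# a smaller layer stack than the README's, just to keep the sketch short
tokenizer = VideoTokenizer(
    image_size = 128,
    init_dim = 64,
    max_dim = 512,
    codebook_size = 1024,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'compress_time',
        ('consecutive_residual', 2),
    )
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder = '/path/to/phys101_frames',  # placeholder path to the video data above
    dataset_type = 'videos',                     # 'videos' or 'images'
    batch_size = 4,
    grad_accum_every = 8,
    learning_rate = 2e-5,
    num_train_steps = 10_000
)

trainer.train()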

@lucidrains
Owner

Thank you for sharing this, Marina! I'll see if I can find the bug, and if worse comes to worst, I can always rewrite the training code in PyTorch Lightning.

@RobertLuo1

RobertLuo1 commented Jun 17, 2024

Hi, recently we have devoted a lot of effort to training the MAGVIT2 tokenizer, and we have now open-sourced a tokenizer trained on ImageNet. Feel free to use it. The project page is https://github.com/TencentARC/Open-MAGVIT2. Thanks @lucidrains so much for the reference code and discussions!

@ashah01

ashah01 commented Jul 23, 2024

Hey @lucidrains, I trained a MAGVIT2 tokenizer without modifying your implementation of the accelerate framework. As others have experienced, I initially saw just a solid block in the results/sampled.x.gif files. However, upon loading the model weights from my most recent checkpoint, I was able to get pretty good reconstructions in a sample script I wrote that performs inference without the accelerate framework. Additionally, the reconstruction MSE scores were consistent with the ones observed in your training script. This means that whatever bug others are experiencing is not the result of flawed model training, but rather of something going wrong with the gif rendering.
[attached: the results-folder gif and a frame reconstructed by the inference script]

*Note: the first file is the saved gif in the results folder. The ground-truth frames have a weird colour scheme because I normalized the frame pixels to [-1, 1]. The second file is a reconstructed frame from my inference script. MSE was ~0.011 after training on a V100 for 5 hours.
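For anyone who wants to reproduce this kind of check, a rough accelerate-free inference sketch is below. It assumes the README-style API (a load method on VideoTokenizer for trainer checkpoints, plus tokenize / decode_from_code_indices); the checkpoint path and video shape are placeholders, so verify the method names against the current repo before relying on them.

import torch
from magvit2_pytorch import VideoTokenizer

tokenizer = VideoTokenizer(image_size = 128)    # must be built with the same config used for training
tokenizer.load('./checkpoints/checkpoint.pt')   # placeholder path; assumes a load method exists, otherwise load the state dict manually
tokenizer.eval()

video = torch.randn(1, 3, 17, 128, 128)         # replace with real frames, (batch, channels, time, height, width)

with torch.no_grad():
    codes = tokenizer.tokenize(video)                  # discrete code indices
    recon = tokenizer.decode_from_code_indices(codes)  # decode back to pixels

print(torch.mean((recon - video) ** 2).item())  # reconstruction MSE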

@vincentcartillier

Hello everyone, I have also been struggling to train the model.

My goal is to first try to overfit magvit2 on a single video.

I haven't made any modifications to the code base at all. I am using accelerate and training on a single GPU.

Here is the result I get:
[attached: overfit reconstruction gif]

Here are the corresponding training curves:
[attached: six training-curve screenshots]

The reconstruction loss does seem to decrease and then plateau.

Let me know if you have any idea why this is happening.

Some additional info:
I am loading the video using the video loader from magvit (i.e. frames are normalized to [0, 1]).
I'm on Ubuntu 20.04, PyTorch 2.4 with CUDA 12.1, training on a single A40 GPU.
The VideoTokenizer I used is the one provided in the README, i.e.:

from magvit2_pytorch import VideoTokenizer

tokenizer = VideoTokenizer(
    image_size = 128,
    init_dim = 64,
    max_dim = 512,
    codebook_size = 1024,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'linear_attend_space',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
        'compress_time',
        ('consecutive_residual', 2),
        'compress_time',
        ('consecutive_residual', 2),
        'attend_time',
    )
)

@StarCycle

Hi @vincentcartillier,

Please check Tencent's https://github.com/TencentARC/Open-MAGVIT2, which is based on this implementation but modifies some parts.

@vincentcartillier

Thanks for your reply. I did come across the Open-MAGVIT2 repo, but correct me if I'm wrong, I don't think they've implemented the video tokenizer yet?
I believe only an image tokenizer has been evaluated on ImageNet.

@StarCycle

@vincentcartillier Oh yes, but they are developing a video tokenizer...

They had the same problem while training the image tokenizer and eventually fixed it. I guess @RobertLuo1 can answer your question ^^

@Jason3900

Typically, if the recon loss is below 0.03, you will see an outline of the video. What you encountered may indicate that the architecture is difficult to converge; when I manually re-implemented magvit2, it would produce a recognizable reconstruction within a few hundred steps. To debug, first skip quantization and use the encoder's output directly as the decoder's input while you adjust the overall model structure; once that works, add quantization back, as it is hard to train. BTW, this repo uses a 2D GAN that takes sampled frames as input, which is not aligned with the paper; you can use a 3D VQGAN instead. From my point of view, though, the discriminator's training is not the most important part; the encoder and decoder structures matter.
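To make that debugging suggestion concrete, here is a rough sketch of training the encoder/decoder with the quantizer skipped. The encoder/decoder modules and the training loop are placeholders for your own implementation; only the idea (feed continuous latents straight to the decoder and watch the reconstruction loss) comes from the comment above.

import torch
import torch.nn.functional as F
from torch import nn

class NoQuantAutoencoder(nn.Module):
    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, video):
        latents = self.encoder(video)    # continuous latents, quantizer skipped
        recon = self.decoder(latents)
        return F.mse_loss(recon, video)  # pure reconstruction objective

# usage sketch: the loss should drop within a few hundred steps on a single video;
# only once that works, put LFQ back between encoder and decoder
# model = NoQuantAutoencoder(my_encoder, my_decoder).cuda()
# loss = model(video_batch); loss.backward(); optimizer.step(); optimizer.zero_grad()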

@vincentcartillier

Got it. Thanks a lot for all the tips, I'll try these out. In the meantime, do you think you could share your re-implementation of magvit2? I'm assuming it is based on this repo.

@Jason3900

Sorry, I can't, because it's an internal project and is still under development > <. The implementation is not based on this project; I followed Google's magvit-v1 JAX repo and modified it. The adjustments between v1 and v2 are minimal.

@Jason3900

But I use vector-quantize-pytorch's LFQ, as it is used here.
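For reference, a minimal sketch of using LFQ from vector-quantize-pytorch on its own (argument values are illustrative; check that library's README for the exact current signature):

import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(
    codebook_size = 65536,      # must be a power of 2
    dim = 16,                   # feature dimension, log2(codebook_size) by convention
    entropy_loss_weight = 0.1,  # weight on the entropy auxiliary loss
    diversity_gamma = 1.
)

feats = torch.randn(1, 16, 32, 32)                       # e.g. an encoder feature map
quantized, indices, entropy_aux_loss = quantizer(feats)  # quantized latents, code indices, aux loss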

@vincentcartillier

Got it, totally understandable. Thanks a lot for all the tips, I'll give it a try!

@JingwWu

JingwWu commented Sep 4, 2024

@vincentcartillier I encountered the same difficult convergence problem as described by @Jason3900, but before that I found that the learning rate was not set correctly. When I checked the lr in the optimizer, it was extremely low (~1e-10). When I trained with a higher learning rate (e.g., 1e-5), I got better results.
5k iters: [attached sampled reconstruction]
40k iters: [attached sampled reconstruction]
Trained with only 1 video; LFQ is not used.

@vincentcartillier

Got it, thanks so much for the pointer. I think the reason you're seeing such a low learning rate is the warmup. You can set warmup_step=1 in the VideoTokenizerTrainer() setup parameters to bypass this.

If you haven't used LFQ, are you using FSQ (finite scalar quantization) instead, or something else?

Could you try running the same thing (same learning rate) with LFQ and see if it works? It would be great to know whether that's the source of the problem we're facing.
Also, could you share your setup (OS, CUDA, PyTorch version, Linux kernel version, GPU type, etc.)?
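Regarding the warmup point above, a small sketch of the two workarounds together. The warmup argument name is an assumption (the comment above writes warmup_step; the repo may spell it differently), as is the optimizer attribute used in the sanity check, so verify both against VideoTokenizerTrainer's actual signature.

from magvit2_pytorch import VideoTokenizerTrainer

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder = '/path/to/videos',
    dataset_type = 'videos',
    batch_size = 1,
    learning_rate = 1e-5,   # the higher lr @JingwWu reported working
    warmup_steps = 1,       # assumed name; effectively disables lr warmup
    num_train_steps = 10_000
)

# sanity check the lr the optimizer actually uses after a few steps
# print(trainer.optimizer.param_groups[0]['lr'])   # attribute name assumed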

@Jason3900

> @vincentcartillier I encountered the same difficult convergence problem as described by @Jason3900, but before that I found that the learning rate was not set correctly. When I checked the lr in the optimizer, it was extremely low (~1e-10). When I trained with a higher learning rate (e.g., 1e-5), I got better results. [reconstructions at 5k and 40k iters] Trained with only 1 video; LFQ is not used.

Yep, but it still shouldn't take that many steps to get a reconstruction with only 1 video. It may indicate that the model is too hard to converge.

@JingwWu

JingwWu commented Sep 5, 2024

@vincentcartillier sure!

> If you haven't used LFQ, are you using FSQ (finite scalar quantization) instead, or something else?

To test the convergence of this enc/dec arch, I just dropped the quantizer and used a continuous representation with dim 512.

> Could you try running the same thing (same learning rate) with LFQ and see if it works? It would be great to know whether that's the source of the problem we're facing.

Yes, I ran the same setting with LFQ, but couldn't get a converged result in 40k steps.

So I re-implemented the enc/dec arch (using code in this repo with little modification) according to the paper, and got a surprisingly good result this time. I still followed @Jason3900's suggestion and skipped quantization.

2k iters, 1 video, no quant with dim 18: [attached sampled reconstruction]

@vincentcartillier

vincentcartillier commented Sep 5, 2024

Amazing! Would you be comfortable sharing the code modifications you've made (maybe via a PR, or just sharing your fork)?
It's really promising; I guess the remaining challenge is to include the LFQ part, which, based on Jason's comment, should be a working piece already.

@vincentcartillier

I also got something kinda working. This is the same code, i.e. no modifications and the same settings as my initial post above, except I've changed the learning rate, or rather I've turned off the warmup (as @JingwWu pointed out). Here are the overfitting results I have after ~5k steps.
[attached: overfit reconstruction gif]
This is with the LFQ + original enc/dec implementation.
It is much better than before, but I would have expected an even cleaner output given that this is an overfitting experiment.

@JingwWu

JingwWu commented Sep 6, 2024

Yes, this is the modified CausalConv3d module. The key is to allow a stride in the spatial dims. This module gives correct results under the kernel size and stride settings mentioned in the paper; I've dropped other features for simplicity.

from torch import nn
import torch.nn.functional as F

# minimal helper definitions so the module is self-contained
def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

def is_odd(n):
    return (n % 2) == 1

class CausalConv3d(nn.Module):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size,
        pad_mode = 'constant',
        s_stride = 1,
        t_stride = 1,
        **kwargs
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        self.pad_mode = pad_mode
        time_pad = time_kernel_size - 1
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_pad = time_pad
        # F.pad order for 5D input: (W_left, W_right, H_top, H_bottom, T_front, T_back);
        # all temporal padding goes in front so the convolution stays causal
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        # the key change: allow a spatial stride (s_stride) alongside the temporal stride
        stride = (t_stride, s_stride, s_stride)
        self.conv = nn.Conv3d(
            chan_in,
            chan_out,
            kernel_size,
            stride = stride,
            **kwargs
        )

    def forward(self, x):
        # fall back to constant padding when the clip is shorter than the temporal pad
        pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'

        x = F.pad(x, self.time_causal_padding, mode = pad_mode)
        return self.conv(x)

Next is to implement ResBlock X, ResBlock X->Y, and ResBlockDown X->Y as described in the paper, using the CausalConv3d above, other common PyTorch nn modules (e.g., nn.Conv3d, nn.GroupNorm), and the Blur module in this repo. These are then assembled into the encoder and decoder; a rough example is sketched below.
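As an illustration of that assembly step, here is one possible ResBlock X->Y built on the CausalConv3d above. The norm placement, SiLU activation, and 1x1x1 shortcut are assumptions in the spirit of the paper, not JingwWu's actual implementation; the group counts assume channel widths divisible by 32.

class ResBlock(nn.Module):
    def __init__(self, chan_in, chan_out):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, chan_in)
        self.conv1 = CausalConv3d(chan_in, chan_out, kernel_size = 3)
        self.norm2 = nn.GroupNorm(32, chan_out)
        self.conv2 = CausalConv3d(chan_out, chan_out, kernel_size = 3)
        # 1x1x1 projection on the skip path when channel counts differ (the "X -> Y" case)
        self.shortcut = nn.Conv3d(chan_in, chan_out, 1) if chan_in != chan_out else nn.Identity()

    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.shortcut(x)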
