
Why use a meta-epoch training paradigm? #18

Open
DonaldRR opened this issue Dec 27, 2021 · 8 comments

DonaldRR commented Dec 27, 2021

Hi,
Thanks for your code, it has helped my research a lot.
One question came to mind while implementing cFlow in my own codebase: usually a model is trained on a batch of data, where the loss is reduced over the whole batch and then backpropagated. Here, however, I found that the loss is backpropagated for each sub-iteration, i.e. only part of the batch is sampled at a time.
This training paradigm confuses me somewhat. Does it work better than the normal way? Here are the reasons I surmise it might work (see the sketch after the list):

  • Randomly seeing only part of the batch makes the gradient updates more stochastic, which could produce a more robust model.
  • It saves GPU memory in each forward/backward pass when the overall batch size is high.
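
To make the comparison concrete, here is a minimal toy sketch of the two update schemes I have in mind; the linear layer, the shapes, and N below are made up for illustration and are not taken from this repo:

    import torch

    # Toy stand-ins only: a plain linear layer instead of the real flow decoder,
    # and made-up shapes. The point is the update pattern, not the model.
    toy_flow = torch.nn.Linear(272, 272)
    optimizer = torch.optim.Adam(toy_flow.parameters(), lr=2e-4)
    features = torch.randn(4096, 272)  # all fibers pooled from one batch of feature maps
    N = 256                            # fibers per sub-iteration

    # (a) the "normal" scheme: one loss over the whole pool, one backward/step
    optimizer.zero_grad()
    loss = toy_flow(features).pow(2).mean()  # dummy loss standing in for the flow NLL
    loss.backward()
    optimizer.step()

    # (b) the per-sub-iteration scheme I am asking about: shuffle the fibers,
    #     then backprop and step once per chunk of N fibers
    perm = torch.randperm(features.shape[0])
    for f in range(features.shape[0] // N):
        idx = perm[f * N:(f + 1) * N]
        optimizer.zero_grad()
        loss = toy_flow(features[idx]).pow(2).mean()
        loss.backward()
        optimizer.step()

In scheme (b) the optimizer takes many smaller, noisier steps per batch, which is what I meant by "more stochastic" above.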

Thanks!

@gudovskiy (Owner)

@DonaldRR could you point me to a line of code where "the loss is backpropagated for each sub-iteration -- only part of the batch is sampled"? I don't think I implemented anything like you described: meta/sub epochs are just there to add flexibility to the train/test phases.
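
Roughly, the split works like the sketch below; the names and counts are placeholders rather than the exact train.py code, but the idea is that training runs in blocks of sub-epochs and the test phase can be interleaved at meta-epoch granularity:

    # Placeholder sketch of the meta/sub-epoch split (names, prints, and counts
    # are illustrative only, not the exact train.py code).
    def train_one_sub_epoch():
        print("one pass over the training data")

    def evaluate():
        print("test phase: score the test set")

    META_EPOCHS, SUB_EPOCHS = 3, 2  # hypothetical values
    for meta_epoch in range(META_EPOCHS):
        for sub_epoch in range(SUB_EPOCHS):
            train_one_sub_epoch()
        evaluate()  # train/test phases interleave at meta-epoch boundaries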

@DonaldRR (Author)

> @DonaldRR could you point me to a line of code where "the loss is backpropagated for each sub-iteration -- only part of the batch is sampled"? I don't think I implemented anything like you described: meta/sub epochs are just there to add flexibility to the train/test phases.

Oops, I meant the fiber iteration. During each fiber iteration, N (= 256) features are sampled for loss computation and backpropagation.

See the lines starting at line 65 of train.py:

                for f in range(FIB):  # per-fiber processing: FIB sub-iterations over the batch
                    idx = torch.arange(f*N, (f+1)*N)   # f-th chunk of N fiber indices
                    c_p = c_r[perm[idx]]  # NxP condition vectors (random permutation)
                    e_p = e_r[perm[idx]]  # NxC encoder feature vectors (random permutation)
                    if 'cflow' in c.dec_arch:
                        z, log_jac_det = decoder(e_p, [c_p,])  # conditional flow decoder
                    else:
                        z, log_jac_det = decoder(e_p)
                    #
                    decoder_log_prob = get_logp(C, z, log_jac_det)
                    log_prob = decoder_log_prob / C  # likelihood per dim
                    loss = -log_theta(log_prob)
                    optimizer.zero_grad()
                    loss.mean().backward()  # backward pass for this chunk of N fibers only

@gudovskiy
Copy link
Owner

@DonaldRR I see. The number of feature vectors (fibers) in a feature map (tensor) can be large enough to exhaust GPU memory if they are all pushed through the flow model at once. So it is better to sample random feature vectors from a number of feature maps. Hence, your original post is on point :)
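
To put rough numbers on it (the shapes below are hypothetical, not tied to a specific backbone or config):

    # Rough fiber-count arithmetic with made-up shapes:
    B, C, H, W = 32, 272, 64, 64  # a batch of feature maps: B x C x H x W
    E = B * H * W                 # total number of fibers (feature vectors) in the pool
    N = 256                       # fibers processed per backward pass in the loop above
    FIB = E // N                  # number of sub-iterations needed to cover the pool
    print(E, FIB)                 # 131072 fibers -> 512 sub-batches of 256

Pushing all E fibers through the flow in one pass means holding activations for the whole pool, which is exactly what the per-fiber loop avoids.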

@PSZehnder

@gudovskiy Thank you for providing this excellent repo. Does training on sub-batches of fibers serve any purpose other than conserving memory? Could I remove this loop and process all the fibers in one shot if I have enough GPU memory to do so?

@gudovskiy (Owner)

@PSZehnder yes

Howie86 commented Jun 1, 2023

Hi,
Does the value of N need to be the same in train and test? When I change the value of N only in the test phase, I find that the inference results change slightly.
Thanks a lot.

@gudovskiy (Owner)

@Howie86 N should not change test results

Howie86 commented Jun 2, 2023

@gudovskiy Thanks for your reply, I understand.
